From 7eeccf4ebb741b593b7c9d8152ab6acba8001fb4 Mon Sep 17 00:00:00 2001 From: wxywb Date: Mon, 26 Sep 2022 14:47:46 +0800 Subject: [PATCH] fix the gpu usage. Signed-off-by: wxywb --- taiyi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/taiyi.py b/taiyi.py index 51084e0..27123a5 100644 --- a/taiyi.py +++ b/taiyi.py @@ -41,6 +41,9 @@ class Taiyi(NNOperator): self.clip_model = CLIPModel.from_pretrained(config['clip_model']) self.processor = CLIPProcessor.from_pretrained(config['processor']) + self.text_encoder.to(self.device) + self.clip_model.to(self.device) + def inference_single_data(self, data): if self.modality == 'image': vec = self._inference_from_image(data) @@ -72,7 +75,7 @@ class Taiyi(NNOperator): @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): image = to_pil(img) - image = self.processor(images=image, return_tensors="pt") + image = self.processor(images=image, return_tensors="pt").to(self.device) image_features = self.clip_model.get_image_features(**image) return image_features