diff --git a/taiyi.py b/taiyi.py index 51084e0..27123a5 100644 --- a/taiyi.py +++ b/taiyi.py @@ -41,6 +41,9 @@ class Taiyi(NNOperator): self.clip_model = CLIPModel.from_pretrained(config['clip_model']) self.processor = CLIPProcessor.from_pretrained(config['processor']) + self.text_encoder.to(self.device) + self.clip_model.to(self.device) + def inference_single_data(self, data): if self.modality == 'image': vec = self._inference_from_image(data) @@ -72,7 +75,7 @@ class Taiyi(NNOperator): @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): image = to_pil(img) - image = self.processor(images=image, return_tensors="pt") + image = self.processor(images=image, return_tensors="pt").to(self.device) image_features = self.clip_model.get_image_features(**image) return image_features