From 3af58627d63c52fe1b396930d67598c97dfcc8e7 Mon Sep 17 00:00:00 2001 From: wxywb Date: Wed, 18 Jan 2023 15:59:24 +0800 Subject: [PATCH] add default value for device. Signed-off-by: wxywb --- clip.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/clip.py b/clip.py index e3de597..6e78f36 100644 --- a/clip.py +++ b/clip.py @@ -22,6 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor +#from towhee.dc2 import accelerate #@accelerate class CLIPModelVision(nn.Module): @@ -49,10 +50,10 @@ class Clip(NNOperator): """ CLIP multi-modal embedding operator """ - def __init__(self, model_name: str, modality: str, device, checkpoint_path): + def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None): self.model_name = model_name self.modality = modality - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device cfg = self._configs()[model_name] try: clip_model = CLIPModel.from_pretrained(cfg) @@ -71,6 +72,7 @@ class Clip(NNOperator): self.model = CLIPModelText(clip_model) else: raise ValueError("modality[{}] not implemented.".format(self.modality)) + self.model.to(self.device) self.tokenizer = CLIPTokenizer.from_pretrained(cfg) self.processor = CLIPProcessor.from_pretrained(cfg) @@ -99,14 +101,14 @@ class Clip(NNOperator): def _inference_from_text(self, text): tokens = self.tokenizer([text], padding=True, return_tensors="pt") - text_features = self.model(tokens['input_ids'],tokens['attention_mask']) + text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device)) return text_features @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): img = to_pil(img) inputs = self.processor(images=img, return_tensors="pt") - image_features = self.model(inputs['pixel_values']) + image_features = self.model(inputs['pixel_values'].to(self.device)) return image_features def train(self, **kwargs):