diff --git a/clip.py b/clip.py
index ec05a33..beebab4 100644
--- a/clip.py
+++ b/clip.py
@@ -22,7 +22,7 @@
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
-#from towhee.dc2 import accelerate
+# from towhee.dc2 import accelerate
 
 def create_model(model_name, modality, checkpoint_path, device):
@@ -42,8 +42,7 @@ def create_model(model_name, modality, checkpoint_path, device):
         clip = CLIPModelText(hf_clip_model)
     else:
         raise ValueError("modality[{}] not implemented.".format(modality))
-    model = Model(clip)
-    return model
+    return clip
 
 class CLIPModelVision(nn.Module):
     def __init__(self, model):
@@ -63,17 +62,25 @@ class CLIPModelText(nn.Module):
         text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
         return text_embeds
 
-#@accelerate
+# @accelerate
 class Model:
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, model_name, modality, checkpoint_path, device):
+        self.model = create_model(model_name, modality, checkpoint_path, device)
+        self.device = device
 
     def __call__(self, *args, **kwargs):
-        outs = self.model(*args, **kwargs)
+        new_args = []
+        for item in args:
+            new_args.append(item.to(self.device))
+        new_kwargs = {}
+        for k, value in kwargs.items():
+            new_kwargs[k] = value.to(self.device)
+        outs = self.model(*new_args, **new_kwargs)
         return outs
+
 
 @register(output_schema=['vec'])
-class Clip(NNOperator):
+class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
@@ -82,11 +89,11 @@ class Clip(NNOperator):
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        cfg = self._configs()[model_name]
+        real_name = self._configs()[model_name]
 
-        self.model = create_model(cfg, modality, checkpoint_path, device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
-        self.processor = CLIPProcessor.from_pretrained(cfg)
+        self.model = Model(real_name, modality, checkpoint_path, device)
+        self.tokenizer = CLIPTokenizer.from_pretrained(real_name)
+        self.processor = CLIPProcessor.from_pretrained(real_name)
 
     def inference_single_data(self, data):
         if self.modality == 'image':
@@ -113,14 +120,14 @@ class Clip(NNOperator):
 
     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
+        text_features = self.model(tokens['input_ids'], tokens['attention_mask'])
         return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
         inputs = self.processor(images=img, return_tensors="pt")
-        image_features = self.model(inputs['pixel_values'].to(self.device))
+        image_features = self.model(inputs['pixel_values'])
         return image_features
 
     def train(self, **kwargs):
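
After this patch, device placement is handled inside Model.__call__, so the _inference_from_* helpers can hand over CPU tensors as-is. The snippet below is a minimal usage sketch, not part of the patch: it assumes the operator is instantiated directly from this module, that the constructor accepts these keyword arguments, and that the model-name key 'clip_vit_base_patch16' exists in Clip._configs() (all of these are assumptions to verify against this repository).

# Minimal usage sketch (assumed key 'clip_vit_base_patch16' and keyword
# arguments; check Clip._configs() and Clip.__init__ in this repo).
from clip import Clip

# Text embedding: the tokenized inputs stay on CPU at the call site;
# Model.__call__ moves them to the operator's device before the forward pass.
text_op = Clip(model_name='clip_vit_base_patch16', modality='text',
               device='cpu', checkpoint_path=None)
text_vec = text_op.inference_single_data('a photo of two cats on a sofa')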