|
|
@ -22,6 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
|
from towhee import register |
|
|
|
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor |
|
|
|
#from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
#@accelerate |
|
|
|
class CLIPModelVision(nn.Module): |
|
|
@ -49,10 +50,10 @@ class Clip(NNOperator): |
|
|
|
""" |
|
|
|
CLIP multi-modal embedding operator |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str, modality: str, device, checkpoint_path): |
|
|
|
def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None): |
|
|
|
self.model_name = model_name |
|
|
|
self.modality = modality |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
self.device = device |
|
|
|
cfg = self._configs()[model_name] |
|
|
|
try: |
|
|
|
clip_model = CLIPModel.from_pretrained(cfg) |
|
|
@ -71,6 +72,7 @@ class Clip(NNOperator): |
|
|
|
self.model = CLIPModelText(clip_model) |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
|
self.model.to(self.device) |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(cfg) |
|
|
|
self.processor = CLIPProcessor.from_pretrained(cfg) |
|
|
|
|
|
|
@ -99,14 +101,14 @@ class Clip(NNOperator): |
|
|
|
|
|
|
|
def _inference_from_text(self, text): |
|
|
|
tokens = self.tokenizer([text], padding=True, return_tensors="pt") |
|
|
|
text_features = self.model(tokens['input_ids'],tokens['attention_mask']) |
|
|
|
text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device)) |
|
|
|
return text_features |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def _inference_from_image(self, img): |
|
|
|
img = to_pil(img) |
|
|
|
inputs = self.processor(images=img, return_tensors="pt") |
|
|
|
image_features = self.model(inputs['pixel_values']) |
|
|
|
image_features = self.model(inputs['pixel_values'].to(self.device)) |
|
|
|
return image_features |
|
|
|
|
|
|
|
def train(self, **kwargs): |
|
|
|