|
@@ -22,7 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
-#from towhee.dc2 import accelerate
+# from towhee.dc2 import accelerate


 def create_model(model_name, modality, checkpoint_path, device):
@@ -42,8 +42,7 @@ def create_model(model_name, modality, checkpoint_path, device):
         clip = CLIPModelText(hf_clip_model)
     else:
         raise ValueError("modality[{}] not implemented.".format(modality))
-    model = Model(clip)
-    return model
+    return clip

 class CLIPModelVision(nn.Module):
     def __init__(self, model):
@@ -63,17 +62,25 @@ class CLIPModelText(nn.Module):
         text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
         return text_embeds

-#@accelerate
+# @accelerate
 class Model:
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, model_name, modality, checkpoint_path, device):
+        self.model = create_model(model_name, modality, checkpoint_path, device)
+        self.device = device

     def __call__(self, *args, **kwargs):
-        outs = self.model(*args, **kwargs)
+        new_args = []
+        for item in args:
+            new_args.append(item.to(self.device))
+        new_kwargs = {}
+        for k, value in kwargs.items():
+            new_kwargs[k] = value.to(self.device)
+        outs = self.model(*new_args, **new_kwargs)
         return outs

+
 @register(output_schema=['vec'])
 class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
@@ -82,11 +89,11 @@ class Clip(NNOperator):
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        cfg = self._configs()[model_name]
+        real_name = self._configs()[model_name]
-        self.model = create_model(cfg, modality, checkpoint_path, device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
-        self.processor = CLIPProcessor.from_pretrained(cfg)
+        self.model = Model(real_name, modality, checkpoint_path, device)
+        self.tokenizer = CLIPTokenizer.from_pretrained(real_name)
+        self.processor = CLIPProcessor.from_pretrained(real_name)

     def inference_single_data(self, data):
         if self.modality == 'image':
@@ -113,14 +120,14 @@ class Clip(NNOperator):

     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
+        text_features = self.model(tokens['input_ids'], tokens['attention_mask'])
         return text_features

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
         inputs = self.processor(images=img, return_tensors="pt")
-        image_features = self.model(inputs['pixel_values'].to(self.device))
+        image_features = self.model(inputs['pixel_values'])
         return image_features

     def train(self, **kwargs):
|
|
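Net effect of the diff: device placement moves out of the Clip operator and into the Model wrapper, which now builds the backbone itself via create_model and calls .to(self.device) on every tensor argument before forwarding. Below is a rough usage sketch for reviewers; the exact Clip.__init__ signature, the _configs() keys, and the return type of inference_single_data are not visible in these hunks, so the names and arguments here are assumptions rather than the confirmed API.

# Illustrative sketch only -- argument names and the config key are assumptions
# drawn from the attributes visible in the diff, not from the full file.
op = Clip(model_name='clip_vit_base_patch16',  # assumed _configs() key
          modality='text',
          checkpoint_path=None,
          device='cpu')

# The tokenizer output stays on CPU here; Model.__call__ (see the third hunk)
# moves each tensor to op.device before calling the CLIP backbone, which is why
# the explicit .to(self.device) calls were removed from _inference_from_text
# and _inference_from_image.
vec = op.inference_single_data('a photo of two cats on a couch')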