
fix in triton

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
Branch: main
junjie.jiang committed 2 years ago
parent commit be6635fae0
clip.py (33 changed lines)
@@ -22,7 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
-#from towhee.dc2 import accelerate
+# from towhee.dc2 import accelerate


 def create_model(model_name, modality, checkpoint_path, device):
@@ -42,8 +42,7 @@ def create_model(model_name, modality, checkpoint_path, device):
         clip = CLIPModelText(hf_clip_model)
     else:
         raise ValueError("modality[{}] not implemented.".format(modality))
-    model = Model(clip)
-    return model
+    return clip

 class CLIPModelVision(nn.Module):
     def __init__(self, model):
@@ -63,15 +62,23 @@ class CLIPModelText(nn.Module):
         text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
         return text_embeds

-#@accelerate
+# @accelerate
 class Model:
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, model_name, modality, checkpoint_path, device):
+        self.model = create_model(model_name, modality, checkpoint_path, device)
+        self.device = device

     def __call__(self, *args, **kwargs):
-        outs = self.model(*args, **kwargs)
+        new_args = []
+        for item in args:
+            new_args.append(item.to(self.device))
+        new_kwargs = {}
+        for k, value in kwargs.items():
+            new_kwargs[k] = value.to(self.device)
+        outs = self.model(*new_args, **new_kwargs)
         return outs

 @register(output_schema=['vec'])
 class Clip(NNOperator):
     """
@@ -82,11 +89,11 @@ class Clip(NNOperator):
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        cfg = self._configs()[model_name]
+        real_name = self._configs()[model_name]
-        self.model = create_model(cfg, modality, checkpoint_path, device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
-        self.processor = CLIPProcessor.from_pretrained(cfg)
+        self.model = Model(real_name, modality, checkpoint_path, device)
+        self.tokenizer = CLIPTokenizer.from_pretrained(real_name)
+        self.processor = CLIPProcessor.from_pretrained(real_name)

     def inference_single_data(self, data):
         if self.modality == 'image':
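The constructor now resolves the operator's short model name to a Hugging Face checkpoint id once (`real_name`) and reuses it for the wrapper, tokenizer, and processor. A hedged sketch of that resolution step with plain transformers calls; the `'clip_vit_base_patch16'` key and its mapping are assumptions in the spirit of `_configs()`, not copied from the repo:

```python
from transformers import CLIPProcessor, CLIPTokenizer

# Assumed shape of the _configs() mapping; the real dict lives in the repo.
configs = {'clip_vit_base_patch16': 'openai/clip-vit-base-patch16'}
real_name = configs['clip_vit_base_patch16']

# Both preprocessing objects are keyed off the same resolved checkpoint id,
# mirroring the patched Clip.__init__.
tokenizer = CLIPTokenizer.from_pretrained(real_name)
processor = CLIPProcessor.from_pretrained(real_name)
```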
@@ -113,14 +120,14 @@
     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
+        text_features = self.model(tokens['input_ids'], tokens['attention_mask'])
         return text_features

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
         inputs = self.processor(images=img, return_tensors="pt")
-        image_features = self.model(inputs['pixel_values'].to(self.device))
+        image_features = self.model(inputs['pixel_values'])
         return image_features

     def train(self, **kwargs):
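These last two hunks drop the per-call `.to(self.device)` because the wrapper's `__call__` now does it. An end-to-end sketch of the text path using raw transformers objects; `encode_text` is a stand-in for `Model.__call__` plus `CLIPModelText`, and the checkpoint name is assumed as above:

```python
import torch
from transformers import CLIPModel, CLIPTokenizer

name = 'openai/clip-vit-base-patch16'  # assumed checkpoint id
device = 'cuda' if torch.cuda.is_available() else 'cpu'

backbone = CLIPModel.from_pretrained(name).to(device).eval()
tokenizer = CLIPTokenizer.from_pretrained(name)

def encode_text(input_ids, attention_mask):
    # Single placement point, like Model.__call__ after this commit:
    # the caller hands over CPU tensors untouched.
    return backbone.get_text_features(
        input_ids.to(device), attention_mask.to(device)
    )

tokens = tokenizer(['a photo of a cat'], padding=True, return_tensors='pt')
with torch.no_grad():
    vec = encode_text(tokens['input_ids'], tokens['attention_mask'])
print(vec.shape)  # torch.Size([1, 512]) for ViT-B/16
```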
