logo
Browse Source

fix in triton

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
be6635fae0
  1. 29
      clip.py

29
clip.py

@ -42,8 +42,7 @@ def create_model(model_name, modality, checkpoint_path, device):
clip = CLIPModelText(hf_clip_model) clip = CLIPModelText(hf_clip_model)
else: else:
raise ValueError("modality[{}] not implemented.".format(modality)) raise ValueError("modality[{}] not implemented.".format(modality))
model = Model(clip)
return model
return clip
class CLIPModelVision(nn.Module): class CLIPModelVision(nn.Module):
def __init__(self, model): def __init__(self, model):
@ -65,13 +64,21 @@ class CLIPModelText(nn.Module):
# @accelerate # @accelerate
class Model: class Model:
def __init__(self, model):
self.model = model
def __init__(self, model_name, modality, checkpoint_path, device):
self.model = create_model(model_name, modality, checkpoint_path, device)
self.device = device
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs)
new_args = []
for item in args:
new_args.append(item.to(self.device))
new_kwargs = {}
for k, value in kwargs.items():
new_kwargs[k] = value.to(self.device)
outs = self.model(*new_args, **new_kwargs)
return outs return outs
@register(output_schema=['vec']) @register(output_schema=['vec'])
class Clip(NNOperator): class Clip(NNOperator):
""" """
@ -82,11 +89,11 @@ class Clip(NNOperator):
self.modality = modality self.modality = modality
self.device = device self.device = device
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
cfg = self._configs()[model_name]
real_name = self._configs()[model_name]
self.model = create_model(cfg, modality, checkpoint_path, device)
self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
self.processor = CLIPProcessor.from_pretrained(cfg)
self.model = Model(real_name, modality, checkpoint_path, device)
self.tokenizer = CLIPTokenizer.from_pretrained(real_name)
self.processor = CLIPProcessor.from_pretrained(real_name)
def inference_single_data(self, data): def inference_single_data(self, data):
if self.modality == 'image': if self.modality == 'image':
@ -113,14 +120,14 @@ class Clip(NNOperator):
def _inference_from_text(self, text): def _inference_from_text(self, text):
tokens = self.tokenizer([text], padding=True, return_tensors="pt") tokens = self.tokenizer([text], padding=True, return_tensors="pt")
text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
text_features = self.model(tokens['input_ids'], tokens['attention_mask'])
return text_features return text_features
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def _inference_from_image(self, img): def _inference_from_image(self, img):
img = to_pil(img) img = to_pil(img)
inputs = self.processor(images=img, return_tensors="pt") inputs = self.processor(images=img, return_tensors="pt")
image_features = self.model(inputs['pixel_values'].to(self.device))
image_features = self.model(inputs['pixel_values'])
return image_features return image_features
def train(self, **kwargs): def train(self, **kwargs):

Loading…
Cancel
Save