logo
Browse Source

update the model with wrapper.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
efa84801fe
  1. 71
      clip.py

71
clip.py

@ -24,26 +24,53 @@ from towhee import register
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
#from towhee.dc2 import accelerate #from towhee.dc2 import accelerate
#@accelerate
def create_model(model_name, modality, checkpoint_path, device):
hf_clip_model = CLIPModel.from_pretrained(model_name)
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
hf_clip_model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
hf_clip_model.to(device)
hf_clip_model.eval()
if modality == 'image':
clip = CLIPModelVision(hf_clip_model)
elif modality == 'text':
clip = CLIPModelText(hf_clip_model)
else:
raise ValueError("modality[{}] not implemented.".format(modality))
model = Model(clip)
return model
class CLIPModelVision(nn.Module): class CLIPModelVision(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.model = model
self.backbone = model
def forward(self, pixel_values): def forward(self, pixel_values):
image_embeds = self.model.get_image_features(pixel_values)
image_embeds = self.backbone.get_image_features(pixel_values)
return image_embeds return image_embeds
#@accelerate
class CLIPModelText(nn.Module): class CLIPModelText(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.model = model
self.backbone = model
def forward(self, input_ids, attention_mask): def forward(self, input_ids, attention_mask):
text_embeds = self.model.get_text_features(input_ids, attention_mask)
text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
return text_embeds return text_embeds
#@accelerate
class Model:
def __init__(self, model):
self.model = model
def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs)
return outs
@register(output_schema=['vec']) @register(output_schema=['vec'])
class Clip(NNOperator): class Clip(NNOperator):
@ -56,18 +83,8 @@ class Clip(NNOperator):
self.device = device self.device = device
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
cfg = self._configs()[model_name] cfg = self._configs()[model_name]
try:
clip_model = CLIPModel.from_pretrained(cfg)
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if self.modality == 'image':
self.model = CLIPModelVision(self._model)
elif self.modality == 'text':
self.model = CLIPModelText(self._model)
else:
raise ValueError("modality[{}] not implemented.".format(self.modality))
self.model = create_model(cfg, modality, checkpoint_path, device)
self.tokenizer = CLIPTokenizer.from_pretrained(cfg) self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
self.processor = CLIPProcessor.from_pretrained(cfg) self.processor = CLIPProcessor.from_pretrained(cfg)
@ -115,7 +132,7 @@ class Clip(NNOperator):
from train_clip_with_hf_trainer import train_with_hf_trainer from train_clip_with_hf_trainer import train_with_hf_trainer
data_args = kwargs.pop('data_args', None) data_args = kwargs.pop('data_args', None)
training_args = kwargs.pop('training_args', None) training_args = kwargs.pop('training_args', None)
train_with_hf_trainer(self.model, self.tokenizer, data_args, training_args)
train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args)
def _configs(self): def _configs(self):
config = {} config = {}
@ -148,21 +165,7 @@ class Clip(NNOperator):
@property @property
def _model(self): def _model(self):
cfg = self._configs()[self.model_name]
try:
hf_clip_model = CLIPModel.from_pretrained(cfg)
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
hf_clip_model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
hf_clip_model.to(self.device)
hf_clip_model.eval()
return hf_clip_model
return self.model.model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
import os import os
@ -222,7 +225,7 @@ class Clip(NNOperator):
else: else:
raise ValueError("modality[{}] not implemented.".format(self.modality)) raise ValueError("modality[{}] not implemented.".format(self.modality))
onnx_export(self.model,
onnx_export(self._model,
(dict(inputs),), (dict(inputs),),
f=Path(output_file), f=Path(output_file),
input_names= input_names, input_names= input_names,

Loading…
Cancel
Save