|
|
@ -24,26 +24,53 @@ from towhee import register |
|
|
|
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor |
|
|
|
#from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
#@accelerate |
|
|
|
|
|
|
|
def create_model(model_name, modality, checkpoint_path, device): |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(model_name) |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
hf_clip_model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
hf_clip_model.to(device) |
|
|
|
hf_clip_model.eval() |
|
|
|
|
|
|
|
if modality == 'image': |
|
|
|
clip = CLIPModelVision(hf_clip_model) |
|
|
|
elif modality == 'text': |
|
|
|
clip = CLIPModelText(hf_clip_model) |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(modality)) |
|
|
|
model = Model(clip) |
|
|
|
return model |
|
|
|
|
|
|
|
class CLIPModelVision(nn.Module): |
|
|
|
def __init__(self, model): |
|
|
|
super().__init__() |
|
|
|
self.model = model |
|
|
|
self.backbone = model |
|
|
|
|
|
|
|
def forward(self, pixel_values): |
|
|
|
image_embeds = self.model.get_image_features(pixel_values) |
|
|
|
image_embeds = self.backbone.get_image_features(pixel_values) |
|
|
|
return image_embeds |
|
|
|
|
|
|
|
#@accelerate |
|
|
|
class CLIPModelText(nn.Module): |
|
|
|
def __init__(self, model): |
|
|
|
super().__init__() |
|
|
|
self.model = model |
|
|
|
self.backbone = model |
|
|
|
|
|
|
|
def forward(self, input_ids, attention_mask): |
|
|
|
text_embeds = self.model.get_text_features(input_ids, attention_mask) |
|
|
|
text_embeds = self.backbone.get_text_features(input_ids, attention_mask) |
|
|
|
return text_embeds |
|
|
|
|
|
|
|
#@accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, model): |
|
|
|
self.model = model |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
outs = self.model(*args, **kwargs) |
|
|
|
return outs |
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
|
class Clip(NNOperator): |
|
|
@ -56,18 +83,8 @@ class Clip(NNOperator): |
|
|
|
self.device = device |
|
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
cfg = self._configs()[model_name] |
|
|
|
try: |
|
|
|
clip_model = CLIPModel.from_pretrained(cfg) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
|
|
|
|
if self.modality == 'image': |
|
|
|
self.model = CLIPModelVision(self._model) |
|
|
|
elif self.modality == 'text': |
|
|
|
self.model = CLIPModelText(self._model) |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
|
self.model = create_model(cfg, modality, checkpoint_path, device) |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(cfg) |
|
|
|
self.processor = CLIPProcessor.from_pretrained(cfg) |
|
|
|
|
|
|
@ -115,7 +132,7 @@ class Clip(NNOperator): |
|
|
|
from train_clip_with_hf_trainer import train_with_hf_trainer |
|
|
|
data_args = kwargs.pop('data_args', None) |
|
|
|
training_args = kwargs.pop('training_args', None) |
|
|
|
train_with_hf_trainer(self.model, self.tokenizer, data_args, training_args) |
|
|
|
train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args) |
|
|
|
|
|
|
|
def _configs(self): |
|
|
|
config = {} |
|
|
@ -148,21 +165,7 @@ class Clip(NNOperator): |
|
|
|
|
|
|
|
@property |
|
|
|
def _model(self): |
|
|
|
cfg = self._configs()[self.model_name] |
|
|
|
try: |
|
|
|
hf_clip_model = CLIPModel.from_pretrained(cfg) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
if self.checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(self.checkpoint_path, map_location=self.device) |
|
|
|
hf_clip_model.load_state_dict(state_dict) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
hf_clip_model.to(self.device) |
|
|
|
hf_clip_model.eval() |
|
|
|
return hf_clip_model |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |
|
|
|
import os |
|
|
@ -222,7 +225,7 @@ class Clip(NNOperator): |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
|
|
|
|
|
onnx_export(self.model, |
|
|
|
onnx_export(self._model, |
|
|
|
(dict(inputs),), |
|
|
|
f=Path(output_file), |
|
|
|
input_names= input_names, |
|
|
|