From efa84801fe41b56d12a89395555d704790feabf0 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Mon, 6 Feb 2023 11:35:11 +0000
Subject: [PATCH] update the model with wrapper.

Signed-off-by: wxywb
---
 clip.py | 71 ++++++++++++++++++++++++++++++---------------------------
 1 file changed, 37 insertions(+), 34 deletions(-)

diff --git a/clip.py b/clip.py
index 14e23ea..5ffae49 100644
--- a/clip.py
+++ b/clip.py
@@ -24,26 +24,53 @@ from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
 #from towhee.dc2 import accelerate
-#@accelerate
+
+def create_model(model_name, modality, checkpoint_path, device):
+    hf_clip_model = CLIPModel.from_pretrained(model_name)
+    if checkpoint_path:
+        try:
+            state_dict = torch.load(checkpoint_path, map_location=device)
+            hf_clip_model.load_state_dict(state_dict)
+        except Exception as e:
+            log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
+    hf_clip_model.to(device)
+    hf_clip_model.eval()
+
+    if modality == 'image':
+        clip = CLIPModelVision(hf_clip_model)
+    elif modality == 'text':
+        clip = CLIPModelText(hf_clip_model)
+    else:
+        raise ValueError("modality[{}] not implemented.".format(modality))
+    model = Model(clip)
+    return model
+
+
 class CLIPModelVision(nn.Module):
     def __init__(self, model):
         super().__init__()
-        self.model = model
+        self.backbone = model
 
     def forward(self, pixel_values):
-        image_embeds = self.model.get_image_features(pixel_values)
+        image_embeds = self.backbone.get_image_features(pixel_values)
         return image_embeds
 
-#@accelerate
 class CLIPModelText(nn.Module):
     def __init__(self, model):
         super().__init__()
-        self.model = model
+        self.backbone = model
 
     def forward(self, input_ids, attention_mask):
-        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+        text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
         return text_embeds
 
+#@accelerate
+class Model:
+    def __init__(self, model):
+        self.model = model
+
+    def __call__(self, *args, **kwargs):
+        outs = self.model(*args, **kwargs)
+        return outs
+
 @register(output_schema=['vec'])
 class Clip(NNOperator):
@@ -56,18 +83,8 @@ class Clip(NNOperator):
         self.device = device
         self.checkpoint_path = checkpoint_path
         cfg = self._configs()[model_name]
-        try:
-            clip_model = CLIPModel.from_pretrained(cfg)
-        except Exception as e:
-            log.error(f"Fail to load model by name: {self.model_name}")
-            raise e
 
-        if self.modality == 'image':
-            self.model = CLIPModelVision(self._model)
-        elif self.modality == 'text':
-            self.model = CLIPModelText(self._model)
-        else:
-            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        self.model = create_model(cfg, modality, checkpoint_path, device)
 
         self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
         self.processor = CLIPProcessor.from_pretrained(cfg)
@@ -115,7 +132,7 @@ class Clip(NNOperator):
         from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        train_with_hf_trainer(self.model, self.tokenizer, data_args, training_args)
+        train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args)
 
     def _configs(self):
         config = {}
@@ -148,21 +165,7 @@ class Clip(NNOperator):
 
     @property
     def _model(self):
-        cfg = self._configs()[self.model_name]
-        try:
-            hf_clip_model = CLIPModel.from_pretrained(cfg)
-        except Exception as e:
-            log.error(f"Fail to load model by name: {self.model_name}")
-            raise e
-        if self.checkpoint_path:
-            try:
-                state_dict = torch.load(self.checkpoint_path, map_location=self.device)
-                hf_clip_model.load_state_dict(state_dict)
-            except Exception as e:
-                log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
-        hf_clip_model.to(self.device)
-        hf_clip_model.eval()
-        return hf_clip_model
+        return self.model.model
 
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         import os
@@ -222,7 +225,7 @@ class Clip(NNOperator):
         else:
             raise ValueError("modality[{}] not implemented.".format(self.modality))
 
-        onnx_export(self.model,
+        onnx_export(self._model,
                     (dict(inputs),),
                     f=Path(output_file),
                     input_names= input_names,
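
Usage sketch for the new create_model factory and Model wrapper: a minimal example, assuming clip.py is importable as a module and that 'openai/clip-vit-base-patch16' is an already-resolved checkpoint name (the operator itself maps model names to checkpoints through Clip._configs() before calling create_model).

    import torch
    from transformers import CLIPTokenizer

    from clip import create_model  # assumption: clip.py is on the import path

    # Image side: Model.__call__ forwards straight to
    # CLIPModelVision.forward(pixel_values), i.e. CLIPModel.get_image_features.
    image_model = create_model('openai/clip-vit-base-patch16', 'image', None, 'cpu')
    pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch standing in for a real image
    with torch.no_grad():
        image_vec = image_model(pixel_values)   # (1, 512) for the base CLIP checkpoints

    # Text side: the wrapper forwards (input_ids, attention_mask) to
    # CLIPModel.get_text_features.
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')
    tokens = tokenizer(['a photo of a cat'], return_tensors='pt', padding=True)
    text_model = create_model('openai/clip-vit-base-patch16', 'text', None, 'cpu')
    with torch.no_grad():
        text_vec = text_model(tokens['input_ids'], tokens['attention_mask'])

Judging by the relocated #@accelerate comment, the plain-Python Model wrapper appears intended to give towhee's accelerate a single callable entry point to decorate, independent of modality.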