diff --git a/README.md b/README.md
index 0313b5a..94e4010 100644
--- a/README.md
+++ b/README.md
@@ -73,16 +73,10 @@ Create the operator via the following factory method
 ​
 ***model_name:*** *str*
 ​ The model name of CLIP. Supported model names:
-- clip_resnet_r50
-- clip_resnet_r101
-- clip_vit_b32
-- clip_vit_b16
-- clip_resnet_r50x4
-- clip_resnet_r50x16
-- clip_resnet_r50x64
-- clip_vit_l14
-- clip_vit_l14@336px
-
+- clip_vit_base_patch16
+- clip_vit_base_patch32
+- clip_vit_large_patch14
+- clip_vit_large_patch14_336
 ​
 ***modality:*** *str*
@@ -90,12 +84,40 @@ Create the operator via the following factory method
+***checkpoint_path***: *str*
+
+The path to a local checkpoint; defaults to None.
+If None, the operator will download and load the pretrained model specified by `model_name` from Hugging Face Transformers.
 
 ## Interface
 
 An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
 
+***save_model(format='pytorch', path='default')***
+
+Save the model locally in the specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+​ The format of the saved model; defaults to 'pytorch'.
+
+***path***: *str*
+
+​ The path the model is saved to. By default, the model is saved to the operator directory.
+
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+op.save_model('onnx', 'test.onnx')
+```
+
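+A locally fine-tuned checkpoint can also be loaded through `checkpoint_path` when creating the operator. The sketch below assumes you already have a state dict saved with `torch.save`; the file name is only a placeholder:
+
+```python
+from towhee import ops
+
+# 'path/to/clip_finetuned.pt' is a placeholder for your own checkpoint
+# (a state dict saved with torch.save).
+op = ops.image_text_embedding.clip(
+    model_name='clip_vit_base_patch32',
+    modality='image',
+    checkpoint_path='path/to/clip_finetuned.pt'
+).get_op()
+```
+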
+
+
 **Parameters:**
@@ -109,6 +131,66 @@ An image-text embedding operator takes a [towhee image](link/to/towhee/image/api
 ​ The data embedding extracted by model.
+***supported_model_names(format=None)***
+
+Get a list of all supported model names, or of the model names supported for a specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+​ The model format, such as 'pytorch' or 'torchscript'.
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
+```
+
+
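+Below is a minimal sketch of calling the operator directly to extract an embedding; the sample sentence is arbitrary and the embedding size depends on the chosen model:
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch32', modality='text').get_op()
+vec = op('a black and white cat sitting on a sofa')  # returns a numpy.ndarray embedding
+print(vec.shape)
+```
+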
+
+## Fine-tune
+### Requirement
+If you want to train this operator, you need to install the following dependencies in addition to those in requirements.txt.
+```bash
+python -m pip install datasets evaluate
+```
+### Get started
+
+```python
+import towhee
+
+clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+
+data_args = {
+    'dataset_name': 'ydshieh/coco_dataset_script',
+    'dataset_config_name': '2017',
+    'cache_dir': './cache',
+    'max_seq_length': 77,
+    'data_dir': path_to_your_coco_dataset,  # replace with the path to your local COCO dataset
+    'image_mean': [0.48145466, 0.4578275, 0.40821073],
+    'image_std': [0.26862954, 0.26130258, 0.27577711]
+}
+training_args = {
+    'num_train_epochs': 3,  # increase the number of epochs for better metrics
+    'per_device_train_batch_size': 8,
+    'per_device_eval_batch_size': 8,
+    'do_train': True,
+    'do_eval': True,
+    'remove_unused_columns': False,
+    'output_dir': './tmp/test-clip',
+    'overwrite_output_dir': True,
+}
+
+# Kick off fine-tuning with the arguments above; see the training script
+# linked below for the full set of supported arguments.
+clip_op.train(data_args=data_args, training_args=training_args)
+```
+
+### Dive deep and customize your training
+You can modify the [training script](https://towhee.io/image-text-embedding/clip/src/branch/main/train_clip_with_hf_trainer.py) to customize the training in your own way.
+You can also refer to the original [Hugging Face Transformers training examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text).
diff --git a/clip.py b/clip.py
index 2041191..ec9f1fc 100644
--- a/clip.py
+++ b/clip.py
@@ -54,7 +54,17 @@ class Clip(NNOperator):
         self.modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         cfg = self._configs()[model_name]
-        clip_model = CLIPModel.from_pretrained(cfg)
+        try:
+            clip_model = CLIPModel.from_pretrained(cfg)
+        except Exception as e:
+            log.error(f"Failed to load model by name: {self.model_name}")
+            raise e
+        if checkpoint_path:
+            try:
+                state_dict = torch.load(checkpoint_path, map_location=self.device)
+                clip_model.load_state_dict(state_dict)  # self.model is not created yet, so load into clip_model
+            except Exception as e:
+                log.error(f"Failed to load state dict from {checkpoint_path}: {e}")
         if self.modality == 'image':
             self.model = CLIPModelVision(clip_model)
         elif self.modality == 'text':
@@ -112,10 +122,10 @@ class Clip(NNOperator):
     def _configs(self):
         config = {}
-        config['clip_vit_base_32'] = 'openai/clip-vit-base-patch16'
-        config['clip_vit_base_16'] = 'openai/clip-vit-base-patch32'
-        config['clip_vit_large_14'] = 'openai/clip-vit-large-patch14'
-        config['clip_vit_large_14_336'] ='openai/clip-vit-large-patch14-336'
+        config['clip_vit_base_patch16'] = 'openai/clip-vit-base-patch16'
+        config['clip_vit_base_patch32'] = 'openai/clip-vit-base-patch32'
+        config['clip_vit_large_patch14'] = 'openai/clip-vit-large-patch14'
+        config['clip_vit_large_patch14_336'] = 'openai/clip-vit-large-patch14-336'
         return config
 
     @property
@@ -130,10 +140,10 @@ class Clip(NNOperator):
     def supported_model_names(format: str = None):
         if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
             model_list = [
-                'openai/clip-vit-base-patch16',
-                'openai/clip-vit-base-patch32',
-                'openai/clip-vit-large-patch14',
-                'openai/clip-vit-large-patch14-336'
+                'clip_vit_base_patch16',
+                'clip_vit_base_patch32',
+                'clip_vit_large_patch14',
+                'clip_vit_large_patch14_336'
             ]
         else:
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')