diff --git a/README.md b/README.md
index 0313b5a..94e4010 100644
--- a/README.md
+++ b/README.md
@@ -73,16 +73,10 @@ Create the operator via the following factory method
***model_name:*** *str*
The model name of CLIP. Supported model names:
-- clip_resnet_r50
-- clip_resnet_r101
-- clip_vit_b32
-- clip_vit_b16
-- clip_resnet_r50x4
-- clip_resnet_r50x16
-- clip_resnet_r50x64
-- clip_vit_l14
-- clip_vit_l14@336px
-
+- clip_vit_base_patch16
+- clip_vit_base_patch32
+- clip_vit_large_patch14
+- clip_vit_large_patch14_336
***modality:*** *str*
@@ -90,12 +84,40 @@ Create the operator via the following factory method
+***checkpoint_path***: *str*
+
+The path to a local checkpoint, defaults to None.
+If None, the operator will download and load the pretrained model specified by `model_name` from Hugging Face transformers.
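+
+For example, to load fine-tuned weights from a local file (a minimal sketch; the checkpoint path below is only a placeholder):
+
+```python
+from towhee import ops
+
+# '/path/to/checkpoint.pt' is a hypothetical local checkpoint produced by fine-tuning
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch32', modality='text',
+                                   checkpoint_path='/path/to/checkpoint.pt').get_op()
+```
+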
## Interface
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generates an embedding in ndarray.
+***save_model(format='pytorch', path='default')***
+
+Save the model to a local path in the specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+    The format of the saved model, defaults to 'pytorch'.
+
+***path***: *str*
+
+    The path where the model is saved. By default, the model is saved to the operator directory.
+
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+op.save_model('onnx', 'test.onnx')
+```
+
+
+
**Parameters:**
@@ -109,6 +131,66 @@ An image-text embedding operator takes a [towhee image](link/to/towhee/image/api
The data embedding extracted by model.
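+
+For instance, with the text modality the operator can be called directly on a string (a minimal sketch; the sentence below is only an illustration):
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text').get_op()
+emb = op('a photo of two cats sleeping on a couch')  # returns the text embedding as an ndarray
+print(emb.shape)
+```
+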
+***supported_model_names(format=None)***
+
+Get a list of all supported model names, or the supported model names for a specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+    The model format, such as 'pytorch' or 'torchscript'.
+
+```python
+from towhee import ops
+
+
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
+```
+
+
+
+## Fine-tune
+### Requirement
+If you want to train this operator, besides the dependencies in requirements.txt, you also need to install the following dependencies:
+```python
+! python -m pip install datasets evaluate
+```
+### Get started
+
+```python
+import towhee
+
+clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
+
+data_args = {
+ 'dataset_name': 'ydshieh/coco_dataset_script',
+ 'dataset_config_name': '2017',
+ 'cache_dir': './cache',
+ 'max_seq_length': 77,
+ 'data_dir': path_to_your_coco_dataset,
+ 'image_mean': [0.48145466, 0.4578275, 0.40821073],
+    'image_std': [0.26862954, 0.26130258, 0.27577711]
+}
+training_args = {
+    'num_train_epochs': 3, # increase the number of epochs for a better metric.
+ 'per_device_train_batch_size': 8,
+ 'per_device_eval_batch_size': 8,
+ 'do_train': True,
+ 'do_eval': True,
+ 'remove_unused_columns': False,
+ 'output_dir': './tmp/test-clip',
+ 'overwrite_output_dir': True,
+}
+
+```
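+
+With the arguments prepared, you can launch fine-tuning. The call below is a minimal sketch that assumes the operator exposes a `train` method which forwards `data_args` and `training_args` to the training script:
+
+```python
+# hypothetical invocation: forwards the dicts above to train_clip_with_hf_trainer.py
+clip_op.train(data_args=data_args, training_args=training_args)
+```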
+
+### Dive deep and customize your training
+You can modify the [training script](https://towhee.io/image-text-embedding/clip/src/branch/main/train_clip_with_hf_trainer.py) to suit your own needs,
+or refer to the original [Hugging Face transformers training examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text).
diff --git a/clip.py b/clip.py
index 2041191..ec9f1fc 100644
--- a/clip.py
+++ b/clip.py
@@ -54,7 +54,17 @@ class Clip(NNOperator):
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
cfg = self._configs()[model_name]
- clip_model = CLIPModel.from_pretrained(cfg)
+ try:
+ clip_model = CLIPModel.from_pretrained(cfg)
+ except Exception as e:
+            log.error(f"Failed to load model by name: {self.model_name}")
+ raise e
+        if checkpoint_path:
+            try:
+                state_dict = torch.load(checkpoint_path, map_location=self.device)
+                # load the fine-tuned weights into the underlying CLIP model before it is wrapped below
+                clip_model.load_state_dict(state_dict)
+            except Exception as e:
+                log.error(f"Failed to load state dict from {checkpoint_path}: {e}")
+                raise e
if self.modality == 'image':
self.model = CLIPModelVision(clip_model)
elif self.modality == 'text':
@@ -112,10 +122,10 @@ class Clip(NNOperator):
def _configs(self):
config = {}
- config['clip_vit_base_32'] = 'openai/clip-vit-base-patch16'
- config['clip_vit_base_16'] = 'openai/clip-vit-base-patch32'
- config['clip_vit_large_14'] = 'openai/clip-vit-large-patch14'
- config['clip_vit_large_14_336'] ='openai/clip-vit-large-patch14-336'
+ config['clip_vit_base_patch16'] = 'openai/clip-vit-base-patch16'
+ config['clip_vit_base_patch32'] = 'openai/clip-vit-base-patch32'
+ config['clip_vit_large_patch14'] = 'openai/clip-vit-large-patch14'
+        config['clip_vit_large_patch14_336'] = 'openai/clip-vit-large-patch14-336'
return config
@property
@@ -130,10 +140,10 @@ class Clip(NNOperator):
def supported_model_names(format: str = None):
if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
model_list = [
- 'openai/clip-vit-base-patch16',
- 'openai/clip-vit-base-patch32',
- 'openai/clip-vit-large-patch14',
- 'openai/clip-vit-large-patch14-336'
+ 'clip_vit_base_patch16',
+ 'clip_vit_base_patch32',
+ 'clip_vit_large_patch14',
+ 'clip_vit_large_patch14_336'
]
else:
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".')