diff --git a/README.md b/README.md
index 868a0ad..28cf3c4 100644
--- a/README.md
+++ b/README.md
@@ -145,6 +145,7 @@ print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
 ## Fine-tune
 ### Requirement
 If you want to train this operator, besides dependency in requirements.txt, you need install these dependencies.
+There is also an [example](https://github.com/towhee-io/examples/blob/main/image/text_image_search/2_deep_dive_text_image_search.ipynb) to show how to finetune it on a custom dataset.
 ```python
 ! python -m pip install datasets evaluate
 ```
@@ -158,11 +159,10 @@ clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16
 data_args = {
     'dataset_name': 'ydshieh/coco_dataset_script',
     'dataset_config_name': '2017',
-    'cache_dir': './cache',
     'max_seq_length': 77,
     'data_dir': path_to_your_coco_dataset,
     'image_mean': [0.48145466, 0.4578275, 0.40821073],
-    "image_std": [0.26862954, 0.26130258, 0.27577711]
+    'image_std': [0.26862954, 0.26130258, 0.27577711]
 }
 training_args = {
     'num_train_epochs': 3, # you can add epoch number to get a better metric.
@@ -174,8 +174,13 @@ training_args = {
     'output_dir': './tmp/test-clip',
     'overwrite_output_dir': True,
 }
+model_args = {
+    'freeze_vision_model': False,
+    'freeze_text_model': False,
+    'cache_dir': './cache'
+}
 
-clip_op.train(data_args=data_args, training_args=training_args)
+clip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
 ```
 
 ### Dive deep and customize your training
diff --git a/clip.py b/clip.py
index b3892fe..0ebd8ea 100644
--- a/clip.py
+++ b/clip.py
@@ -147,7 +147,8 @@ class Clip(NNOperator):
         from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args)
+        model_args = kwargs.pop('model_args', None)
+        train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args, model_args)
 
     def _configs(self):
         config = {}
diff --git a/train_clip_with_hf_trainer.py b/train_clip_with_hf_trainer.py
index 2f53662..99365c4 100644
--- a/train_clip_with_hf_trainer.py
+++ b/train_clip_with_hf_trainer.py
@@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
     except:
         return d  # Not a dataclass field
 
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+    freeze_vision_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
+    )
+    freeze_text_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
 
 @dataclass
 class DataTrainingArguments:
@@ -103,21 +117,12 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
-    cache_dir: Optional[str] = field(
-        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
-    )
     image_mean: Optional[str] = field(
         default=None, metadata={"help": "image preprocessing mean"}
     )
     image_std: Optional[str] = field(
         default=None, metadata={"help": "image preprocessing std"}
     )
-    freeze_vision_model: bool = field(
-        default=False, metadata={"help":"Whether to freeze the vision model parameters or not."}
-    )
-    freeze_text_model: bool = field(
-        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
-    )
 
     def __post_init__(self):
@@ -163,7 +168,7 @@ def collate_fn(examples):
     }
 
 
-def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
+def train_with_hf_trainer(model, tokenizer, data_args, training_args, model_args, **kwargs):
     import evaluate
     import datasets
@@ -179,6 +184,11 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     print('**** TrainingArguments ****')
     get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)
+
+    print('**** ModelArguments ****')
+    get_dataclasses_help(ModelArguments)
+    model_args = dataclass_from_dict(ModelArguments, model_args)
+
 
     # Setup logging
     logging.basicConfig(
@@ -195,7 +205,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
-    temp_cache_dir = data_args.cache_dir
+    temp_cache_dir = model_args.cache_dir
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
@@ -254,8 +264,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
 
     config = model.config
 
-    freeze_vision_model = data_args.freeze_vision_model
-    freeze_text_model = data_args.freeze_text_model
+    freeze_vision_model = model_args.freeze_vision_model
+    freeze_text_model = model_args.freeze_text_model
 
     def _freeze_params(module):
         for param in module.parameters():