logo
Browse Source

separate model arguments.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
parent
commit
4e25e34aef
  1. 11
      README.md
  2. 3
      clip.py
  3. 36
      train_clip_with_hf_trainer.py

11
README.md

@@ -145,6 +145,7 @@ print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
## Fine-tune
### Requirement
If you want to train this operator, besides the dependencies in requirements.txt, you need to install these dependencies.
There is also an [example](https://github.com/towhee-io/examples/blob/main/image/text_image_search/2_deep_dive_text_image_search.ipynb) to show how to finetune it on a custom dataset.
```bash
! python -m pip install datasets evaluate
```
@@ -158,11 +159,10 @@ clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16
data_args = {
'dataset_name': 'ydshieh/coco_dataset_script',
'dataset_config_name': '2017',
'cache_dir': './cache',
'max_seq_length': 77,
'data_dir': path_to_your_coco_dataset,
'image_mean': [0.48145466, 0.4578275, 0.40821073],
"image_std": [0.26862954, 0.26130258, 0.27577711]
'image_std': [0.26862954, 0.26130258, 0.27577711]
}
training_args = {
'num_train_epochs': 3, # you can add epoch number to get a better metric.
@@ -174,8 +174,13 @@ training_args = {
'output_dir': './tmp/test-clip',
'overwrite_output_dir': True,
}
model_args = {
'freeze_vision_model': False,
'freeze_text_model': False,
'cache_dir': './cache'
}
clip_op.train(data_args=data_args, training_args=training_args)
clip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
```
### Dive deep and customize your training

3
clip.py

@@ -147,7 +147,8 @@ class Clip(NNOperator):
from train_clip_with_hf_trainer import train_with_hf_trainer
data_args = kwargs.pop('data_args', None)
training_args = kwargs.pop('training_args', None)
train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args)
model_args = kwargs.pop('model_args', None)
train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args, model_args)
def _configs(self):
config = {}

36
train_clip_with_hf_trainer.py

@@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
except:
return d # Not a dataclass field
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Separated out of DataTrainingArguments so model-level options (parameter
    freezing, pretrained-weights cache location) can be passed independently of
    dataset options, mirroring the HF example trainers.
    """

    # When True, vision-tower parameters are excluded from gradient updates
    # (consumed as model_args.freeze_vision_model by the training entry point).
    freeze_vision_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
    )
    # When True, text-tower parameters are excluded from gradient updates.
    freeze_text_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
    )
    # Directory for downloaded pretrained weights; None falls back to the
    # library default cache location.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
@dataclass
class DataTrainingArguments:
@@ -103,21 +117,12 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
image_mean: Optional[str] = field(
default=None, metadata={"help": "image preprocessing mean"}
)
image_std: Optional[str] = field(
default=None, metadata={"help": "image preprocessing std"}
)
freeze_vision_model: bool = field(
default=False, metadata={"help":"Whether to freeze the vision model parameters or not."}
)
freeze_text_model: bool = field(
default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
)
def __post_init__(self):
@@ -163,7 +168,7 @@ def collate_fn(examples):
}
def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
def train_with_hf_trainer(model, tokenizer, data_args, training_args, model_args, **kwargs):
import evaluate
import datasets
@@ -179,6 +184,11 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
print('**** ModelArguments ****')
get_dataclasses_help(ModelArguments)
model_args = dataclass_from_dict(ModelArguments, model_args)
# Setup logging
logging.basicConfig(
@@ -195,7 +205,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
temp_cache_dir = data_args.cache_dir
temp_cache_dir = model_args.cache_dir
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
@@ -254,8 +264,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
config = model.config
freeze_vision_model = data_args.freeze_vision_model
freeze_text_model = data_args.freeze_text_model
freeze_vision_model = model_args.freeze_vision_model
freeze_text_model = model_args.freeze_text_model
def _freeze_params(module):
for param in module.parameters():

Loading…
Cancel
Save