
Commit 4e25e34aef on branch main
jinlingxu06, 2 years ago

separate model arguments.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>

Changed files:
1. README.md (11 changed lines)
2. clip.py (3 changed lines)
3. train_clip_with_hf_trainer.py (36 changed lines)

README.md (11 changed lines)

@@ -145,6 +145,7 @@ print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
 ## Fine-tune
 ### Requirement
 If you want to train this operator, besides the dependencies in requirements.txt, you need to install these dependencies.
+There is also an [example](https://github.com/towhee-io/examples/blob/main/image/text_image_search/2_deep_dive_text_image_search.ipynb) showing how to fine-tune it on a custom dataset.
 ```python
 ! python -m pip install datasets evaluate
 ```
@@ -158,11 +159,10 @@ clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16
 data_args = {
     'dataset_name': 'ydshieh/coco_dataset_script',
     'dataset_config_name': '2017',
-    'cache_dir': './cache',
     'max_seq_length': 77,
     'data_dir': path_to_your_coco_dataset,
     'image_mean': [0.48145466, 0.4578275, 0.40821073],
-    "image_std": [0.26862954, 0.26130258, 0.27577711]
+    'image_std': [0.26862954, 0.26130258, 0.27577711]
 }
 training_args = {
     'num_train_epochs': 3,  # you can increase the number of epochs to get a better metric.
@@ -174,8 +174,13 @@ training_args = {
     'output_dir': './tmp/test-clip',
     'overwrite_output_dir': True,
 }
+model_args = {
+    'freeze_vision_model': False,
+    'freeze_text_model': False,
+    'cache_dir': './cache'
+}
-clip_op.train(data_args=data_args, training_args=training_args)
+clip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
 ```
 ### Dive deep and customize your training
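
As a usage note, the new `model_args` dict mirrors the `ModelArguments` dataclass introduced further down in this commit. Below is a minimal sketch, reusing `clip_op`, `data_args`, and `training_args` from the README snippet above, of freezing the vision tower so only the text encoder is updated during fine-tuning; the values are example choices, not defaults from the commit.

```python
# Illustrative variant of the README example: keep the vision parameters fixed
# and fine-tune only the text side.
model_args = {
    'freeze_vision_model': True,   # vision tower parameters are not updated
    'freeze_text_model': False,    # text tower parameters are trained
    'cache_dir': './cache'         # where downloaded pretrained weights are cached
}
clip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
```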

clip.py (3 changed lines)

@@ -147,7 +147,8 @@ class Clip(NNOperator):
         from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
-        train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args)
+        model_args = kwargs.pop('model_args', None)
+        train_with_hf_trainer(self._model.backbone, self.tokenizer, data_args, training_args, model_args)

     def _configs(self):
         config = {}

train_clip_with_hf_trainer.py (36 changed lines)

@@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
     except:
         return d  # Not a dataclass field

+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+    freeze_vision_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
+    )
+    freeze_text_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )

 @dataclass
 class DataTrainingArguments:

@@ -103,21 +117,12 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
-    cache_dir: Optional[str] = field(
-        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
-    )
     image_mean: Optional[str] = field(
         default=None, metadata={"help": "image preprocessing mean"}
     )
     image_std: Optional[str] = field(
         default=None, metadata={"help": "image preprocessing std"}
     )
-    freeze_vision_model: bool = field(
-        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
-    )
-    freeze_text_model: bool = field(
-        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
-    )

     def __post_init__(self):

@@ -163,7 +168,7 @@ def collate_fn(examples):
     }

-def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
+def train_with_hf_trainer(model, tokenizer, data_args, training_args, model_args, **kwargs):

     import evaluate
     import datasets

@@ -179,6 +184,11 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     print('**** TrainingArguments ****')
     get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)
+    print('**** ModelArguments ****')
+    get_dataclasses_help(ModelArguments)
+    model_args = dataclass_from_dict(ModelArguments, model_args)

     # Setup logging
     logging.basicConfig(

@@ -195,7 +205,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

-    temp_cache_dir = data_args.cache_dir
+    temp_cache_dir = model_args.cache_dir
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"

@@ -254,8 +264,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     config = model.config

-    freeze_vision_model = data_args.freeze_vision_model
-    freeze_text_model = data_args.freeze_text_model
+    freeze_vision_model = model_args.freeze_vision_model
+    freeze_text_model = model_args.freeze_text_model

     def _freeze_params(module):
         for param in module.parameters():
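
To make the new flow concrete: the plain `model_args` dict handed over from `clip.py` is turned into the `ModelArguments` dataclass by `dataclass_from_dict` before `cache_dir` and the freeze flags are read. The repo's helper (partially visible above) also recurses into nested dataclass fields; the sketch below is a simplified stand-in, shown only to illustrate the conversion step.

```python
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class ModelArguments:
    freeze_vision_model: bool = field(default=False)
    freeze_text_model: bool = field(default=False)
    cache_dir: Optional[str] = field(default=None)

def dataclass_from_dict(klass, d):
    # Simplified stand-in for the repo's helper: keep only the keys that are
    # actual dataclass fields and build an instance from them.
    valid = {f.name for f in fields(klass)}
    return klass(**{k: v for k, v in d.items() if k in valid})

model_args = dataclass_from_dict(ModelArguments, {'freeze_vision_model': True, 'cache_dir': './cache'})
print(model_args)
# ModelArguments(freeze_vision_model=True, freeze_text_model=False, cache_dir='./cache')
```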
