@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
except :
return d # Not a dataclass field
@dataclass
class ModelArguments :
"""
Arguments pertaining to which model / config / tokenizer we are going to fine - tune , or train from scratch .
"""
freeze_vision_model : bool = field (
default = False , metadata = { " help " : " Whether to freeze the vision model parameters or not. " }
)
freeze_text_model : bool = field (
default = False , metadata = { " help " : " Whether to freeze the text model parameters or not. " }
)
cache_dir : Optional [ str ] = field (
default = None , metadata = { " help " : " Where do you want to store the pretrained models downloaded from s3 " }
)
@dataclass
class DataTrainingArguments :
@ -103,21 +117,12 @@ class DataTrainingArguments:
default = None ,
metadata = { " help " : " The number of processes to use for the preprocessing. " } ,
)
cache_dir : Optional [ str ] = field (
default = None , metadata = { " help " : " Where do you want to store the pretrained models downloaded from s3 " }
)
image_mean : Optional [ str ] = field (
default = None , metadata = { " help " : " image preprocessing mean " }
)
image_std : Optional [ str ] = field (
default = None , metadata = { " help " : " image preprocessing std " }
)
freeze_vision_model : bool = field (
default = False , metadata = { " help " : " Whether to freeze the vision model parameters or not. " }
)
freeze_text_model : bool = field (
default = False , metadata = { " help " : " Whether to freeze the text model parameters or not. " }
)
def __post_init__ ( self ) :
@ -163,7 +168,7 @@ def collate_fn(examples):
}
def train_with_hf_trainer ( model , tokenizer , data_args , training_args , * * kwargs ) :
def train_with_hf_trainer ( model , tokenizer , data_args , training_args , model_args , * * kwargs ) :
import evaluate
import datasets
@ -179,6 +184,11 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
print ( ' **** TrainingArguments **** ' )
get_dataclasses_help ( TrainingArguments )
training_args = dataclass_from_dict ( TrainingArguments , training_args )
print ( ' **** ModelArguments **** ' )
get_dataclasses_help ( ModelArguments )
model_args = dataclass_from_dict ( ModelArguments , model_args )
# Setup logging
logging . basicConfig (
@ -195,7 +205,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
transformers . utils . logging . enable_default_handler ( )
transformers . utils . logging . enable_explicit_format ( )
temp_cache_dir = data _args. cache_dir
temp_cache_dir = model _args. cache_dir
logger . warning (
f " Process rank: { training_args . local_rank } , device: { training_args . device } , n_gpu: { training_args . n_gpu } "
+ f " distributed training: { bool ( training_args . local_rank != - 1 ) } , 16-bits training: { training_args . fp16 } "
@ -254,8 +264,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
config = model . config
freeze_vision_model = data _args. freeze_vision_model
freeze_text_model = data _args. freeze_text_model
freeze_vision_model = model _args. freeze_vision_model
freeze_text_model = model _args. freeze_text_model
def _freeze_params ( module ) :
for param in module . parameters ( ) :