From 5e65fbaa33c8ae26546e4b3dddf4138e573745e6 Mon Sep 17 00:00:00 2001
From: jinlingxu06
Date: Wed, 1 Mar 2023 15:38:28 +0800
Subject: [PATCH] add training example to blip.

Signed-off-by: jinlingxu06
---
 README.md                     |  69 ++++++
 blip.py                       |   5 +-
 train_blip_with_hf_trainer.py |  36 +--
 train_clip_with_hf_trainer.py | 412 ----------------------------------
 4 files changed, 94 insertions(+), 428 deletions(-)
 delete mode 100644 train_clip_with_hf_trainer.py

diff --git a/README.md b/README.md
index 69e3735..67d55f6 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,75 @@ op.save_model('onnx', 'test.onnx')
 
 ​	The data embedding extracted by model.
 
+***supported_model_names(format=None)***
+
+Get a list of all supported model names, or only the model names supported for a specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+​	The model format, such as 'pytorch' or 'torchscript'.
+
+```python
+from towhee import ops
+
+
+op = ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
+```
+
+<br />
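+For example, you could check the ONNX-supported list before exporting the model with `save_model`. This is a minimal sketch that only uses the calls shown above:
+
+```python
+from towhee import ops
+
+op = ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
+
+# Export to ONNX only if this model name is reported as ONNX-supported.
+if 'blip_itm_base_coco' in op.supported_model_names(format='onnx'):
+    op.save_model('onnx', 'test.onnx')
+```
+
+<br />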
+
+## Fine-tune
+### Requirements
+If you want to train this operator, besides the dependencies in requirements.txt, you also need to install the following dependencies.
+There is also an [example](https://github.com/towhee-io/examples/blob/main/image/text_image_search/2_deep_dive_text_image_search.ipynb) that shows how to fine-tune it on a custom dataset.
+```python
+! python -m pip install datasets
+```
+### Get started
+
+```python
+import towhee
+
+blip_op = towhee.ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
+
+data_args = {
+    'dataset_name': 'ydshieh/coco_dataset_script',
+    'dataset_config_name': '2017',
+    'max_seq_length': 77,
+    'data_dir': path_to_your_coco_dataset,  # set this to the local path of your COCO dataset.
+    'image_mean': [0.48145466, 0.4578275, 0.40821073],
+    'image_std': [0.26862954, 0.26130258, 0.27577711]
+}
+training_args = {
+    'num_train_epochs': 3,  # increase the number of epochs for a better metric.
+    'per_device_train_batch_size': 8,
+    'per_device_eval_batch_size': 8,
+    'do_train': True,
+    'do_eval': True,
+    'remove_unused_columns': False,
+    'output_dir': './tmp/test-blip',
+    'overwrite_output_dir': True,
+}
+model_args = {
+    'freeze_vision_model': False,
+    'freeze_text_model': False,
+    'cache_dir': './cache'
+}
+
+blip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
+```
+
+### Dive deep and customize your training
+You can modify the [training script](https://towhee.io/image-text-embedding/blip/src/branch/main/train_blip_with_hf_trainer.py) to customize training in your own way.
+Or you can refer to the original [Hugging Face Transformers training examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text).
+
+
+
 
diff --git a/blip.py b/blip.py
index 8005669..21e0240 100644
--- a/blip.py
+++ b/blip.py
@@ -234,14 +234,15 @@ class Blip(NNOperator):
         import pathlib
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
-        from train_clip_with_hf_trainer import train_with_hf_trainer
+        from train_blip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
+        model_args = kwargs.pop('model_args', None)
 
         model_finetune = self._model.backbone
         model_finetune.forward = MethodType(_forward, model_finetune)
         model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
 
-        train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
+        train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args, model_args)
 
 
     @property
diff --git a/train_blip_with_hf_trainer.py b/train_blip_with_hf_trainer.py
index f1c26d7..ca2a89e 100644
--- a/train_blip_with_hf_trainer.py
+++ b/train_blip_with_hf_trainer.py
@@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
     except:
         return d  # Not a dataclass field
 
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
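+
+    These options are read from the `model_args` dict passed to `train_with_hf_trainer`:
+    `freeze_vision_model` and `freeze_text_model` stop gradient updates for the corresponding
+    sub-model during fine-tuning, and `cache_dir` is where downloaded pretrained files are stored.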
+ """ + freeze_vision_model: bool = field( + default=False, metadata={"help": "Whether to freeze the vision model parameters or not."} + ) + freeze_text_model: bool = field( + default=False, metadata={"help": "Whether to freeze the text model parameters or not."} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) @dataclass class DataTrainingArguments: @@ -103,22 +117,12 @@ class DataTrainingArguments: default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) image_mean: Optional[str] = field( default=None, metadata={"help": "image preprocessing mean"} ) image_std: Optional[str] = field( default=None, metadata={"help": "image preprocessing std"} ) - freeze_vision_model: bool = field( - default=False, metadata={"help":"Whether to freeze the vision model parameters or not."} - ) - freeze_text_model: bool = field( - default=False, metadata={"help": "Whether to freeze the text model parameters or not."} - ) - def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -163,7 +167,7 @@ def collate_fn(examples): } -def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): +def train_with_hf_trainer(model, tokenizer, data_args, training_args, model_args, **kwargs): import evaluate import datasets @@ -180,6 +184,10 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): get_dataclasses_help(TrainingArguments) training_args = dataclass_from_dict(TrainingArguments, training_args) + print('**** ModelArguments ****') + get_dataclasses_help(ModelArguments) + model_args = dataclass_from_dict(ModelArguments, model_args) + # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -195,7 +203,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() - temp_cache_dir = data_args.cache_dir + temp_cache_dir = model_args.cache_dir logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" @@ -254,8 +262,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): config = model.config - freeze_vision_model = data_args.freeze_vision_model - freeze_text_model = data_args.freeze_text_model + freeze_vision_model = model_args.freeze_vision_model + freeze_text_model = model_args.freeze_text_model def _freeze_params(module): for param in module.parameters(): diff --git a/train_clip_with_hf_trainer.py b/train_clip_with_hf_trainer.py deleted file mode 100644 index f1c26d7..0000000 --- a/train_clip_with_hf_trainer.py +++ /dev/null @@ -1,412 +0,0 @@ -import logging -import os -import sys -import transformers -import dataclasses -from dataclasses import dataclass, field -from typing import Optional, List - -import torch -from datasets import load_dataset -from PIL import Image -from torchvision.io import ImageReadMode, read_image -from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize -from torchvision.transforms.functional import InterpolationMode - 
-from transformers import ( - MODEL_FOR_CAUSAL_LM_MAPPING, - TrainingArguments, - default_data_collator, - is_torch_tpu_available, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint - -# We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module, -# so we jit it to be faster. - - -logger = logging.getLogger(__name__) - -dataset_name_mapping = { - "image_caption_dataset.py": ("image_path", "caption"), -} - - -def dataclass_from_dict(klass, d): - try: - fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)} - return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d}) - except: - return d # Not a dataclass field - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."}) - image_column: Optional[str] = field( - default="image_path", - metadata={"help": "The name of the column in the datasets containing the full image file paths."}, - ) - caption_column: Optional[str] = field( - default="caption", - metadata={"help": "The name of the column in the datasets containing the image captions."}, - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a jsonlines file)."} - ) - validation_file: Optional[str] = field( - default=None, - metadata={"help": "An optional input evaluation data file (a jsonlines file)."}, - ) - max_seq_length: Optional[int] = field( - default=77, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." 
- ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - image_mean: Optional[str] = field( - default=None, metadata={"help": "image preprocessing mean"} - ) - image_std: Optional[str] = field( - default=None, metadata={"help": "image preprocessing std"} - ) - freeze_vision_model: bool = field( - default=False, metadata={"help":"Whether to freeze the vision model parameters or not."} - ) - freeze_text_model: bool = field( - default=False, metadata={"help": "Whether to freeze the text model parameters or not."} - ) - - - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension == "json", "`validation_file` should be a json file." - - -class Transform(torch.nn.Module): - def __init__(self, image_size, mean, std): - super().__init__() - self.transforms = torch.nn.Sequential( - Resize([image_size], interpolation=InterpolationMode.BICUBIC), - CenterCrop(image_size), - ConvertImageDtype(torch.float), - Normalize(mean, std), - ) - - def forward(self, x) -> torch.Tensor: - """`x` should be an instance of `PIL.Image.Image`""" - with torch.no_grad(): - x = self.transforms(x) - return x - -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long) - attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long) - return { - "pixel_values": pixel_values, - "input_ids": input_ids, - "attention_mask": attention_mask, - "return_loss": True, - } - - -def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): - - import evaluate - import datasets - - from transformers import Trainer - from datasets import load_dataset - from towhee.trainer.training_config import get_dataclasses_help - - print('**** DataTrainingArguments ****') - get_dataclasses_help(DataTrainingArguments) - data_args = dataclass_from_dict(DataTrainingArguments, data_args) - - print('**** TrainingArguments ****') - get_dataclasses_help(TrainingArguments) - training_args = dataclass_from_dict(TrainingArguments, training_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - # Setup logging - #+ training_args - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - 
transformers.utils.logging.enable_explicit_format() - - temp_cache_dir = data_args.cache_dir - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint and eventualy continue from last checkpoint - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - - # Load dataset - # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files this script will use the first column for the full image path and the second column for the - # captions (unless you specify column names for this with the `image_column` and `caption_column` arguments). - # - - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=temp_cache_dir, - keep_in_memory=False, - data_dir=data_args.data_dir, - # use_auth_token=True if model_args.use_auth_token else None, - ) - else: - data_files = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - extension = data_args.train_file.split(".")[-1] - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.validation_file.split(".")[-1] - dataset = load_dataset( - extension, - data_files=data_files, - cache_dir=temp_cache_dir, - # use_auth_token=True if model_args.use_auth_token else None, - ) - - config = model.config - - freeze_vision_model = data_args.freeze_vision_model - freeze_text_model = data_args.freeze_text_model - - def _freeze_params(module): - for param in module.parameters(): - param.requires_grad = False - - if model_args.freeze_vision_model: - _freeze_params(model.vision_model) - - if model_args.freeze_text_model: - _freeze_params(model.text_model) - - if freeze_vision_model is True: - _freeze_params(model.vision_model) - - if freeze_text_model is True: - _freeze_params(model.text_model) - - set_seed(training_args.seed) - - if training_args.do_train: - column_names = dataset["train"].column_names - elif training_args.do_eval: - column_names = dataset["validation"].column_names - else: - logger.info("There is nothing to do. 
Please pass `do_train`, `do_eval`.") - return - - dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None) - if data_args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] - else: - image_column = data_args.image_column - if image_column not in column_names: - raise ValueError( - f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}" - ) - if data_args.caption_column is None: - caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] - else: - caption_column = data_args.caption_column - if caption_column not in column_names: - raise ValueError( - f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}" - ) - - - image_mean, image_std = data_args.image_mean, data_args.image_std - # Preprocessing the datasets. - # Initialize torchvision transforms and jit it for faster processing. - image_transformations = Transform( - config.vision_config.image_size, image_mean, image_std - ) - image_transformations = torch.jit.script(image_transformations) - - # Preprocessing the datasets. - # We need to tokenize input captions and transform the images. - #data_args - - def tokenize_captions(examples): - captions = [caption for caption in examples[caption_column]] - text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True) - examples["input_ids"] = text_inputs.input_ids - examples["attention_mask"] = text_inputs.attention_mask - return examples - - def transform_images(examples): - images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] - examples["pixel_values"] = [image_transformations(image) for image in images] - return examples - - def filter_corrupt_images(examples): - """remove problematic images""" - valid_images = [] - for image_file in examples[image_column]: - try: - Image.open(image_file) - valid_images.append(True) - except Exception: - valid_images.append(False) - return valid_images - - if training_args.do_train: - if "train" not in dataset: - raise ValueError("--do_train requires a train dataset") - train_dataset = dataset["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - train_dataset = train_dataset.filter( - filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers - ) - train_dataset = train_dataset.map( - function=tokenize_captions, - batched=True, - remove_columns=[col for col in column_names if col != image_column], - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on train dataset", - ) - - # Transform images on the fly as doing it on the whole dataset takes too much time. 
- train_dataset.set_transform(transform_images) - - if training_args.do_eval: - if "validation" not in dataset: - raise ValueError("--do_eval requires a train validation") - eval_dataset = dataset["validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - eval_dataset = eval_dataset.filter( - filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers - ) - eval_dataset = eval_dataset.map( - function=tokenize_captions, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[col for col in column_names if col != image_column], - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on validation dataset", - ) - - # Transform images on the fly as doing it on the whole dataset takes too much time. - eval_dataset.set_transform(transform_images) - - # Initalize our trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - data_collator=collate_fn, - ) - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - trainer.save_model() - trainer.log_metrics("train", train_result.metrics) - trainer.save_metrics("train", train_result.metrics) - trainer.save_state() - # Evaluation - if training_args.do_eval: - metrics = trainer.evaluate() - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - -