
add training example to blip.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
Branch: main
Author: jinlingxu06 · 2 years ago
Commit: 5e65fbaa33
Changed files (4):
  1. README.md (69 lines changed)
  2. blip.py (5 lines changed)
  3. train_blip_with_hf_trainer.py (36 lines changed)
  4. train_clip_with_hf_trainer.py (412 lines deleted)

README.md (69 lines changed)

@@ -125,6 +125,75 @@ op.save_model('onnx', 'test.onnx')
The data embedding extracted by the model.
***supported_model_names(format=None)***
Get a list of all supported model names or supported model names for specified model format.
**Parameters:**
***format***: *str*
The model format, such as 'pytorch' or 'torchscript'.
```python
from towhee import ops
op = ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
full_list = op.supported_model_names()
onnx_list = op.supported_model_names(format='onnx')
print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
```
<br />
## Fine-tune
### Requirements
If you want to train this operator, you need to install the following dependencies in addition to those in requirements.txt.
There is also an [example](https://github.com/towhee-io/examples/blob/main/image/text_image_search/2_deep_dive_text_image_search.ipynb) showing how to fine-tune it on a custom dataset.
```python
! python -m pip install datasets
```
### Get started
```python
import towhee
blip_op = towhee.ops.image_text_embedding.blip(model_name='blip_itm_base_coco', modality='image').get_op()
data_args = {
    'dataset_name': 'ydshieh/coco_dataset_script',
    'dataset_config_name': '2017',
    'max_seq_length': 77,
    'data_dir': path_to_your_coco_dataset,
    'image_mean': [0.48145466, 0.4578275, 0.40821073],
    'image_std': [0.26862954, 0.26130258, 0.27577711]
}
training_args = {
    'num_train_epochs': 3,  # increase the epoch number to get better metrics
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'do_train': True,
    'do_eval': True,
    'remove_unused_columns': False,
    'output_dir': './tmp/test-blip',
    'overwrite_output_dir': True,
}
model_args = {
    'freeze_vision_model': False,
    'freeze_text_model': False,
    'cache_dir': './cache'
}
blip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
```
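After training, the HF `Trainer` used under the hood writes the fine-tuned backbone to `output_dir` via `save_model()`. A minimal reload sketch, assuming the checkpoint saved by the trainer stays compatible with the BLIP ITM architecture (`BlipForImageTextRetrieval` is the standard transformers class for `blip-itm` checkpoints):
```python
from transformers import BlipForImageTextRetrieval

# Assumption: './tmp/test-blip' is the output_dir used above and contains
# the config + weights written by Trainer.save_model().
model = BlipForImageTextRetrieval.from_pretrained('./tmp/test-blip')
model.eval()
```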
### Dive deep and customize your training
You can modify the [training script](https://towhee.io/image-text-embedding/blip/src/branch/main/train_blip_with_hf_trainer.py) to suit your own needs,
or refer to the original [Hugging Face Transformers training examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text).
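Every key in the `training_args` dict is forwarded into the HF `TrainingArguments` dataclass, so any of its standard fields can be set without touching the script. A small sketch (all keys below are stock `TrainingArguments` fields, not operator-specific options):
```python
# Reuses blip_op, data_args and model_args from the example above.
training_args = {
    'num_train_epochs': 10,
    'per_device_train_batch_size': 8,
    'learning_rate': 5e-5,   # stock TrainingArguments field
    'warmup_steps': 100,
    'fp16': True,            # mixed precision, if your GPU supports it
    'do_train': True,
    'do_eval': True,
    'remove_unused_columns': False,
    'output_dir': './tmp/test-blip',
    'overwrite_output_dir': True,
}
blip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)
```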

blip.py (5 lines changed)

@@ -234,14 +234,15 @@ class Blip(NNOperator):
        import pathlib
        path = str(pathlib.Path(__file__).parent)
        sys.path.append(path)
-       from train_clip_with_hf_trainer import train_with_hf_trainer
+       from train_blip_with_hf_trainer import train_with_hf_trainer
        data_args = kwargs.pop('data_args', None)
        training_args = kwargs.pop('training_args', None)
+       model_args = kwargs.pop('model_args', None)
        model_finetune = self._model.backbone
        model_finetune.forward = MethodType(_forward, model_finetune)
        model_finetune.logit_scale = torch.nn.Parameter(torch.ones([]) * model_finetune.config.logit_scale_init_value)
-       train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args)
+       train_with_hf_trainer(model_finetune, self.processor.tokenizer, data_args, training_args, model_args)

    @property
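Aside: the hunk above relies on `types.MethodType` to swap a custom `forward` onto a single model instance. A self-contained sketch of that pattern (the `_forward` body here is illustrative only, not the operator's real implementation):
```python
from types import MethodType

class Backbone:
    def forward(self, x):
        return x

def _forward(self, x):
    # Illustrative override; the operator's real _forward adapts BLIP
    # outputs to the interface the HF contrastive trainer expects.
    return Backbone.forward(self, x) * 2

m = Backbone()
m.forward = MethodType(_forward, m)  # binds _forward to this instance only
assert m.forward(3) == 6
```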

train_blip_with_hf_trainer.py (36 lines changed)

@@ -40,6 +40,20 @@ def dataclass_from_dict(klass, d):
    except:
        return d  # Not a dataclass field

+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+    freeze_vision_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
+    )
+    freeze_text_model: bool = field(
+        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )

@dataclass
class DataTrainingArguments:
@@ -103,22 +117,12 @@ class DataTrainingArguments:
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
-    cache_dir: Optional[str] = field(
-        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
-    )
    image_mean: Optional[str] = field(
        default=None, metadata={"help": "image preprocessing mean"}
    )
    image_std: Optional[str] = field(
        default=None, metadata={"help": "image preprocessing std"}
    )
-    freeze_vision_model: bool = field(
-        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
-    )
-    freeze_text_model: bool = field(
-        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
-    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -163,7 +167,7 @@ def collate_fn(examples):
    }

-def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
+def train_with_hf_trainer(model, tokenizer, data_args, training_args, model_args, **kwargs):
    import evaluate
    import datasets
@@ -180,6 +184,10 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)

+    print('**** ModelArguments ****')
+    get_dataclasses_help(ModelArguments)
+    model_args = dataclass_from_dict(ModelArguments, model_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -195,7 +203,7 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

-    temp_cache_dir = data_args.cache_dir
+    temp_cache_dir = model_args.cache_dir
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
@@ -254,8 +262,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    config = model.config

-    freeze_vision_model = data_args.freeze_vision_model
-    freeze_text_model = data_args.freeze_text_model
+    freeze_vision_model = model_args.freeze_vision_model
+    freeze_text_model = model_args.freeze_text_model

    def _freeze_params(module):
        for param in module.parameters():
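Aside: the plain dicts shown in the README are converted into these dataclasses by `dataclass_from_dict`. A runnable sketch of that conversion using the new `ModelArguments` (field definitions copied from the hunk above, metadata trimmed for brevity):
```python
import dataclasses
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:
    freeze_vision_model: bool = field(default=False)
    freeze_text_model: bool = field(default=False)
    cache_dir: Optional[str] = field(default=None)

def dataclass_from_dict(klass, d):
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except:
        return d  # Not a dataclass field

args = dataclass_from_dict(ModelArguments, {'freeze_vision_model': True, 'cache_dir': './cache'})
print(args.freeze_vision_model, args.freeze_text_model, args.cache_dir)
# True False ./cache
```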

train_clip_with_hf_trainer.py (412 lines deleted)

@@ -1,412 +0,0 @@
import logging
import os
import sys
import transformers
import dataclasses
from dataclasses import dataclass, field
from typing import Optional, List

import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import (
    MODEL_FOR_CAUSAL_LM_MAPPING,
    TrainingArguments,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

# We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module,
# so we jit it to be faster.

logger = logging.getLogger(__name__)

dataset_name_mapping = {
    "image_caption_dataset.py": ("image_path", "caption"),
}


def dataclass_from_dict(klass, d):
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except:
        return d  # Not a dataclass field


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
    image_column: Optional[str] = field(
        default="image_path",
        metadata={"help": "The name of the column in the datasets containing the full image file paths."},
    )
    caption_column: Optional[str] = field(
        default="caption",
        metadata={"help": "The name of the column in the datasets containing the image captions."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
    )
    max_seq_length: Optional[int] = field(
        default=77,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    image_mean: Optional[str] = field(
        default=None, metadata={"help": "image preprocessing mean"}
    )
    image_std: Optional[str] = field(
        default=None, metadata={"help": "image preprocessing std"}
    )
    freeze_vision_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
    )
    freeze_text_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension == "json", "`validation_file` should be a json file."


class Transform(torch.nn.Module):
    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        """`x` should be an instance of `PIL.Image.Image`"""
        with torch.no_grad():
            x = self.transforms(x)
        return x


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }


def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    import evaluate
    import datasets
    from transformers import Trainer
    from datasets import load_dataset
    from towhee.trainer.training_config import get_dataclasses_help

    print('**** DataTrainingArguments ****')
    get_dataclasses_help(DataTrainingArguments)
    data_args = dataclass_from_dict(DataTrainingArguments, data_args)

    print('**** TrainingArguments ****')
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Setup logging
    #+ training_args
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    temp_cache_dir = data_args.cache_dir
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint and eventualy continue from last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Load dataset
    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full image path and the second column for the
    # captions (unless you specify column names for this with the `image_column` and `caption_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=temp_cache_dir,
            keep_in_memory=False,
            data_dir=data_args.data_dir,
            # use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=temp_cache_dir,
            # use_auth_token=True if model_args.use_auth_token else None,
        )

    config = model.config

    freeze_vision_model = data_args.freeze_vision_model
    freeze_text_model = data_args.freeze_text_model

    def _freeze_params(module):
        for param in module.parameters():
            param.requires_grad = False

    if model_args.freeze_vision_model:
        _freeze_params(model.vision_model)
    if model_args.freeze_text_model:
        _freeze_params(model.text_model)
    if freeze_vision_model is True:
        _freeze_params(model.vision_model)
    if freeze_text_model is True:
        _freeze_params(model.text_model)

    set_seed(training_args.seed)

    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval`.")
        return

    dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None)
    if data_args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = data_args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = data_args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    image_mean, image_std = data_args.image_mean, data_args.image_std

    # Preprocessing the datasets.
    # Initialize torchvision transforms and jit it for faster processing.
    image_transformations = Transform(
        config.vision_config.image_size, image_mean, image_std
    )
    image_transformations = torch.jit.script(image_transformations)

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    #data_args
    def tokenize_captions(examples):
        captions = [caption for caption in examples[caption_column]]
        text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
        examples["input_ids"] = text_inputs.input_ids
        examples["attention_mask"] = text_inputs.attention_mask
        return examples

    def transform_images(examples):
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]]
        examples["pixel_values"] = [image_transformations(image) for image in images]
        return examples

    def filter_corrupt_images(examples):
        """remove problematic images"""
        valid_images = []
        for image_file in examples[image_column]:
            try:
                Image.open(image_file)
                valid_images.append(True)
            except Exception:
                valid_images.append(False)
        return valid_images

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        train_dataset = train_dataset.filter(
            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
        )
        train_dataset = train_dataset.map(
            function=tokenize_captions,
            batched=True,
            remove_columns=[col for col in column_names if col != image_column],
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )
        # Transform images on the fly as doing it on the whole dataset takes too much time.
        train_dataset.set_transform(transform_images)

    if training_args.do_eval:
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a train validation")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        eval_dataset = eval_dataset.filter(
            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
        )
        eval_dataset = eval_dataset.map(
            function=tokenize_captions,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[col for col in column_names if col != image_column],
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )
        # Transform images on the fly as doing it on the whole dataset takes too much time.
        eval_dataset.set_transform(transform_images)

    # Initalize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        data_collator=collate_fn,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)