logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
25c68be2ac
  1. 41
      clip.py
  2. 410
      train_clip_with_hf_trainer.py

41
clip.py

@ -21,7 +21,8 @@ from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from towhee.models import clip
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
from train_clip_with_hf_trainer import train_with_hf_trainer
@register(output_schema=['vec'])
@ -32,15 +33,10 @@ class Clip(NNOperator):
def __init__(self, model_name: str, modality: str):
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
self.tokenize = clip.tokenize
self.tfms = transforms.Compose([
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
cfg = self._configs()[model_name]
self.model = CLIPModel.from_pretrained(cfg['name'])
self.tokenizer = CLIPTokenizer.from_pretrained(cfg['name'])
self.processor = CLIPProcessor.from_pretrained(cfg['name'])
def inference_single_data(self, data):
if self.modality == 'image':
@ -66,13 +62,30 @@ class Clip(NNOperator):
return results
def _inference_from_text(self, text):
text = self.tokenize(text).to(self.device)
text_features = self.model.encode_text(text)
tokens = self.tokenizer([text], padding=True, return_tensors="pt")
text_features = self.model.get_text_features(**tokens)
return text_features
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = to_pil(img)
image = self.tfms(img).unsqueeze(0).to(self.device)
image_features = self.model.encode_image(image)
inputs = processor(images=img, return_tensors="pt")
image_features = self.model.get_image_features(**inputs)
return image_features
def train(self, **kwargs):
data_args = kwargs.pop('data_args', None)
training_args = kwargs.pop('training_args', None)
train_with_hf_trainer(self.model, self.tokenizer, data_args, training_args)
def _configs(self):
config = {}
config['clip_vit_base_32'] = {}
config['clip_vit_base_32']['name'] = 'openai/clip-vit-base-patch16'
config['clip_vit_base_16'] = {}
config['clip_vit_base_16']['name'] = 'openai/clip-vit-base-patch32'
config['clip_vit_large_14'] = {}
config['clip_vit_large_14'] = 'openai/clip-vit-large-patch14'
config['clip_vit_large_14_336'] = {}
config['clip_vit_large_14_336']['name'] ='openai/clip-vit-large-patch14-336'
return config

410
train_clip_with_hf_trainer.py

@ -0,0 +1,410 @@
import logging
import os
import sys
import transformers
import dataclasses
import ipdb
from dataclasses import dataclass, field
from typing import Optional, List
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import (
MODEL_FOR_CAUSAL_LM_MAPPING,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
set_seed,
)
# We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module,
# so we jit it to be faster.
logger = logging.getLogger(__name__)
dataset_name_mapping = {
"image_caption_dataset.py": ("image_path", "caption"),
}
def dataclass_from_dict(klass, d):
try:
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
except:
return d # Not a dataclass field
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
image_column: Optional[str] = field(
default="image_path",
metadata={"help": "The name of the column in the datasets containing the full image file paths."},
)
caption_column: Optional[str] = field(
default="caption",
metadata={"help": "The name of the column in the datasets containing the image captions."},
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
)
max_seq_length: Optional[int] = field(
default=77,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
image_mean: Optional[str] = field(
default=None, metadata={"help": "image preprocessing mean"}
)
image_std: Optional[str] = field(
default=None, metadata={"help": "image preprocessing std"}
)
freeze_vision_model: bool = field(
default=False, metadata={"help":"Whether to freeze the vision model parameters or not."}
)
freeze_text_model: bool = field(
default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension == "json", "`validation_file` should be a json file."
class Transform(torch.nn.Module):
def __init__(self, image_size, mean, std):
super().__init__()
self.transforms = torch.nn.Sequential(
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
ConvertImageDtype(torch.float),
Normalize(mean, std),
)
def forward(self, x) -> torch.Tensor:
"""`x` should be an instance of `PIL.Image.Image`"""
with torch.no_grad():
x = self.transforms(x)
return x
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
return {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"return_loss": True,
}
def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
import evaluate
import datasets
from transformers import Trainer
from datasets import load_dataset
from towhee.trainer.training_config import get_dataclasses_help
print('**** DataTrainingArguments ****')
get_dataclasses_help(DataTrainingArguments)
data_args = dataclass_from_dict(DataTrainingArguments, data_args)
print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
# 2. Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
# 2. Setup logging
#+ training_args
#log_level = training_args.get_process_log_level()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
temp_cache_dir = data_args.cache_dir
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# Detecting last checkpoint and eventualy continue from last checkpoint
### place holder ###
### place holder ###
### place holder ###
# Load dataset
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files this script will use the first column for the full image path and the second column for the
# captions (unless you specify column names for this with the `image_column` and `caption_column` arguments).
#
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=temp_cache_dir,
keep_in_memory=False,
data_dir=data_args.data_dir,
# use_auth_token=True if model_args.use_auth_token else None,
)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
dataset = load_dataset(
extension,
data_files=data_files,
cache_dir=temp_cache_dir,
# use_auth_token=True if model_args.use_auth_token else None,
)
config = model.config
freeze_vision_model = data_args.freeze_vision_model
freeze_text_model = data_args.freeze_text_model
def _freeze_params(module):
for param in module.parameters():
param.requires_grad = False
if model_args.freeze_vision_model:
_freeze_params(model.vision_model)
if model_args.freeze_text_model:
_freeze_params(model.text_model)
if freeze_vision_model is True:
_freeze_params(model.vision_model)
if freeze_text_model is True:
_freeze_params(model.text_model)
set_seed(training_args.seed)
if training_args.do_train:
column_names = dataset["train"].column_names
elif training_args.do_eval:
column_names = dataset["validation"].column_names
elif training_args.do_predict:
column_names = dataset["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None)
if data_args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = data_args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if data_args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = data_args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
image_mean, image_std = data_args.image_mean, data_args.image_std
# Preprocessing the datasets.
# Initialize torchvision transforms and jit it for faster processing.
image_transformations = Transform(
config.vision_config.image_size, image_mean, image_std
)
image_transformations = torch.jit.script(image_transformations)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
#data_args
def tokenize_captions(examples):
captions = [caption for caption in examples[caption_column]]
text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
examples["input_ids"] = text_inputs.input_ids
examples["attention_mask"] = text_inputs.attention_mask
return examples
def transform_images(examples):
images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]]
examples["pixel_values"] = [image_transformations(image) for image in images]
return examples
def filter_corrupt_images(examples):
"""remove problematic images"""
valid_images = []
for image_file in examples[image_column]:
try:
Image.open(image_file)
valid_images.append(True)
except Exception:
valid_images.append(False)
return valid_images
if training_args.do_train:
if "train" not in dataset:
raise ValueError("--do_train requires a train dataset")
train_dataset = dataset["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
train_dataset = train_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
)
train_dataset = train_dataset.map(
function=tokenize_captions,
batched=True,
remove_columns=[col for col in column_names if col != image_column],
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on train dataset",
)
# Transform images on the fly as doing it on the whole dataset takes too much time.
train_dataset.set_transform(transform_images)
if training_args.do_eval:
if "validation" not in dataset:
raise ValueError("--do_eval requires a train validation")
eval_dataset = dataset["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
eval_dataset = eval_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
)
eval_dataset = eval_dataset.map(
function=tokenize_captions,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[col for col in column_names if col != image_column],
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on validation dataset",
)
# Transform images on the fly as doing it on the whole dataset takes too much time.
eval_dataset.set_transform(transform_images)
# Initalize our trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
data_collator=collate_fn,
)
# Training
last_checkpoint = None
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
Loading…
Cancel
Save