diff --git a/isc.py b/isc.py
index f5b999f..4fbaf83 100644
--- a/isc.py
+++ b/isc.py
@@ -14,7 +14,6 @@
 import logging
 import os
-import numpy
 from typing import Union, List
 from pathlib import Path
@@ -33,6 +32,8 @@
 import timm
 import warnings
 
+from .train_isc import train_isc
+
 warnings.filterwarnings('ignore')
 
 log = logging.getLogger('isc_op')
@@ -43,7 +44,7 @@ class Model:
         self.device = device
         self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
         self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
-                                      backbone=self.backbone, p=3.0, eval_p=1.0)
+                                      backbone=self.backbone, p=1.0, eval_p=1.0)
         self.model.eval()
 
     def __call__(self, x):
@@ -80,10 +81,10 @@ class Isc(NNOperator):
         self.model = Model(self.timm_backbone, checkpoint_path, self.device)
 
         self.tfms = transforms.Compose([
-                transforms.Resize((img_size, img_size)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=self.backbone.default_cfg['mean'],
-                                     std=self.backbone.default_cfg['std'])
+            transforms.Resize((img_size, img_size)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=self.backbone.default_cfg['mean'],
+                                 std=self.backbone.default_cfg['std'])
         ])
 
     def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
@@ -174,6 +175,12 @@ class Isc(NNOperator):
     def supported_formats(self):
         return ['onnx']
 
+    def train(self, training_config=None,
+              train_dataset=None,
+              eval_dataset=None,
+              resume_checkpoint_path=None, **kwargs):
+        training_args = kwargs.pop('training_args', None)
+        train_isc(self.model.model, training_args)
 
 # if __name__ == '__main__':
 #     from towhee import ops
diff --git a/train_isc.py b/train_isc.py
new file mode 100644
index 0000000..d3b972c
--- /dev/null
+++ b/train_isc.py
@@ -0,0 +1,468 @@
+# Lazy imports: the operator should still load when the optional training
+# dependencies (augly, pytorch-metric-learning, ...) are not installed.
+try:
+    import os
+    import pickle
+    import random
+    from pathlib import Path
+    from typing import Any, Dict, List, Optional
+
+    import timm
+    import torch
+    import torch.backends.cudnn as cudnn
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+    import torch.nn.parallel
+    import torch.optim
+    import torch.utils.data
+    import torch.utils.data.distributed
+    import torchvision.transforms as transforms
+    from augly.image import (EncodingQuality, OneOf,
+                             RandomBlur, RandomPixelization,
+                             RandomRotation, ShufflePixels)
+    from augly.image.functional import overlay_emoji, overlay_image, overlay_text
+    from augly.image.transforms import BaseTransform
+    from augly.utils import pathmgr
+    from augly.utils.base_paths import MODULE_BASE_DIR
+    from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR
+    from PIL import Image, ImageFilter
+    from pytorch_metric_learning import losses
+    from pytorch_metric_learning.utils import distributed as pml_dist
+    from tqdm import tqdm
+    import dataclasses
+    from dataclasses import dataclass, field
+    from towhee.trainer.training_config import get_dataclasses_help
+except ImportError:
+    pass
+
+
+def dataclass_from_dict(klass, d):
+    try:
+        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
+        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
+    except (TypeError, KeyError):
+        return d  # not a dataclass field, return the value as-is
+
+
+@dataclass
+class TrainingArguments:
+    output_dir: Optional[str] = field(
+        default='./output', metadata={"help": "Output directory for saving checkpoints."}
+    )
+    distributed: Optional[bool] = field(
+        default=False, metadata={"help": "If True, train on all GPUs in the machine; otherwise use a single GPU."}
+    )
+    gpu: Optional[int] = field(
+        default=0,
+        metadata={"help": "Index of the GPU to use when distributed is False."}
+    )
+    start_epoch: Optional[int] = field(
+        default=0, metadata={"help": "Start epoch number."}
+    )
+    epochs: Optional[int] = field(
+        default=6, metadata={"help": "End epoch number."}
+    )
+    batch_size: Optional[int] = field(
+        default=128, metadata={"help": "Total batch size summed over all GPUs."}
+    )
+    init_lr: Optional[float] = field(
+        default=0.1, metadata={"help": "Initial learning rate for SGD."}
+    )
+    train_data_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Directory containing all training image files."}
+    )
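+
+
+# Example call, assuming the operator is loaded through towhee (the data path
+# is a placeholder). `training_args` is a plain dict; `dataclass_from_dict`
+# above fills any omitted keys with the defaults:
+#
+#     op = ops.image_embedding.isc()
+#     op.train(training_args={
+#         'train_data_dir': '/path/to/training_images',
+#         'output_dir': './output',
+#         'distributed': False,
+#         'epochs': 6,
+#     })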
metadata={"help": "When distributed is False, use this gpu No. in your machine."} + ) + start_epoch: Optional[int] = field( + default=0, metadata={"help": "Start epoch number."} + ) + epochs: Optional[int] = field( + default=6, metadata={"help": "End epoch number."} + ) + batch_size: Optional[int] = field( + default=128, metadata={"help": "Total batch size in all gpu."} + ) + init_lr: Optional[float] = field( + default=0.1, metadata={"help": "init learning rate in SGD."} + ) + train_data_dir: Optional[str] = field(default=None, + metadata={"help": "The dir containing all training data image files."}) + + +def train_isc(model, training_args): + print('**** TrainingArguments ****') + get_dataclasses_help(TrainingArguments) + training_args = dataclass_from_dict(TrainingArguments, training_args) + if training_args.distributed is True: + ngpus_per_node = torch.cuda.device_count() + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, model, training_args)) + else: + main_worker(training_args.gpu, 1, model, training_args) + + +def main_worker(gpu, ngpus_per_node, model, training_args): + rank = gpu + world_size = ngpus_per_node + distributed = training_args.distributed + if distributed: + dist.init_process_group(backend='nccl', init_method='tcp://localhost:10001', + world_size=world_size, rank=rank) + torch.distributed.barrier(device_ids=[rank]) + + # infer learning rate before changing batch size + init_lr = training_args.init_lr + + if distributed: + # Apply SyncBN + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + torch.cuda.set_device(gpu) + model.cuda(gpu) + batch_size = training_args.batch_size + batch_size = int(batch_size / ngpus_per_node) + workers = 8 + if distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) + + memory_size = 20000 + pos_margin = 0.0 + neg_margin = 1.0 + loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin) + loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256, memory_size=memory_size) + loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank]) + + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name: + no_decay.append(param) + else: + decay.append(param) + + optim_params = [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': 1e-6} + ] + momentum = 0.9 + + optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum) + scaler = torch.cuda.amp.GradScaler() + + cudnn.benchmark = True + + train_dataset = get_dataset(training_args.train_data_dir) + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), + num_workers=workers, pin_memory=True, sampler=train_sampler, drop_last=True) + start_epoch = training_args.start_epoch + epochs = training_args.epochs + for epoch in range(start_epoch, epochs): + if distributed: + train_sampler.set_epoch(epoch) + train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank) + + if not distributed or (distributed and rank == 0): + torch.save({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'training_args': training_args, + }, 
+
+
+def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
+    # Named `loss_meter` to avoid shadowing the imported `losses` module.
+    loss_meter = AverageMeter('Loss', ':.4f')
+    progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False, total=len(train_loader))
+
+    model.train()
+
+    for labels, images in progress:
+        optimizer.zero_grad()
+
+        labels = labels.cuda(rank, non_blocking=True)
+        # `images` is a list of `ncrops` view batches; stack them along the
+        # batch axis and repeat the labels so both views of image i share label i.
+        images = torch.cat([
+            image for image in images
+        ], dim=0).cuda(rank, non_blocking=True)
+        labels = torch.tile(labels, dims=(2,))
+
+        with torch.cuda.amp.autocast():
+            embeddings = model(images)
+            loss = loss_fn(embeddings, labels)
+
+        loss_meter.update(loss.item(), images.size(0))
+
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        progress.set_postfix(loss=loss_meter.avg)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
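+
+
+# Each sample's label is its index in the dataset, so with NCropsTransform the
+# augmented views of image i are the only positives for label i -- the setup
+# CrossBatchMemory + ContrastiveLoss expects. Shapes per batch (assuming
+# batch_size=B, ncrops=2):
+#
+#     labels, images = next(iter(train_loader))  # labels: (B,), images: list of 2 tensors (B, 3, H, W)
+#     torch.cat(images, dim=0).shape             # (2B, 3, H, W)
+#     torch.tile(labels, dims=(2,)).shape        # (2B,)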
+
+
+class ISCDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        paths,
+        transforms,
+    ):
+        self.paths = paths
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, i):
+        image = Image.open(self.paths[i])
+        image = self.transforms(image)
+        # The sample index doubles as its contrastive label.
+        return i, image
+
+
+class NCropsTransform:
+    """Take n random crops of one image as the query and key."""
+
+    def __init__(self, aug_moderate, aug_hard, ncrops=2):
+        self.aug_moderate = aug_moderate
+        self.aug_hard = aug_hard
+        self.ncrops = ncrops
+
+    def __call__(self, x):
+        return [self.aug_moderate(x)] + [self.aug_hard(x) for _ in range(self.ncrops - 1)]
+
+
+class GaussianBlur(object):
+    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
+
+    def __init__(self, sigma=[.1, 2.]):
+        self.sigma = sigma
+
+    def __call__(self, x):
+        sigma = random.uniform(self.sigma[0], self.sigma[1])
+        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
+        return x
+
+
+class RandomOverlayText(BaseTransform):
+    def __init__(
+        self,
+        opacity: float = 1.0,
+        p: float = 1.0,
+    ):
+        super().__init__(p)
+        self.opacity = opacity
+
+        with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
+            font_list = [s.strip() for s in f.readlines()]
+        blacklist = [
+            'TypeMyMusic',
+            'PainttheSky-Regular',
+        ]
+        self.font_list = [
+            f for f in font_list
+            if all(_ not in f for _ in blacklist)
+        ]
+
+        # Each font ships with a pickle listing its available glyphs.
+        self.font_lens = []
+        for ff in self.font_list:
+            font_file = Path(MODULE_BASE_DIR) / ff.replace('.ttf', '.pkl')
+            with open(font_file, 'rb') as f:
+                self.font_lens.append(len(pickle.load(f)))
+
+    def apply_transform(
+        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
+    ) -> Image.Image:
+        i = random.randrange(0, len(self.font_list))
+        kwargs = dict(
+            font_file=Path(MODULE_BASE_DIR) / self.font_list[i],
+            font_size=random.uniform(0.1, 0.3),
+            color=[random.randrange(0, 256) for _ in range(3)],
+            x_pos=random.uniform(0.0, 0.5),
+            metadata=metadata,
+            opacity=self.opacity,
+        )
+        try:
+            for j in range(random.randrange(1, 3)):
+                if j == 0:
+                    y_pos = random.uniform(0.0, 0.5)
+                else:
+                    y_pos += kwargs['font_size']
+                image = overlay_text(
+                    image,
+                    text=[random.randrange(0, self.font_lens[i]) for _ in range(random.randrange(5, 10))],
+                    y_pos=y_pos,
+                    **kwargs,
+                )
+            return image
+        except OSError:
+            return image
+
+
+class RandomOverlayImageAndResizedCrop(BaseTransform):
+    def __init__(
+        self,
+        img_paths: List[Path],
+        opacity_lower: float = 0.5,
+        size_lower: float = 0.4,
+        size_upper: float = 0.6,
+        input_size: int = 224,
+        moderate_scale_lower: float = 0.7,
+        hard_scale_lower: float = 0.15,
+        overlay_p: float = 0.05,
+        p: float = 1.0,
+    ):
+        super().__init__(p)
+        self.img_paths = img_paths
+        self.opacity_lower = opacity_lower
+        self.size_lower = size_lower
+        self.size_upper = size_upper
+        self.input_size = input_size
+        self.moderate_scale_lower = moderate_scale_lower
+        self.hard_scale_lower = hard_scale_lower
+        self.overlay_p = overlay_p
+
+    def apply_transform(
+        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
+    ) -> Image.Image:
+
+        if random.uniform(0.0, 1.0) < self.overlay_p:
+            # Paste the image onto a random training image, or vice versa.
+            if random.uniform(0.0, 1.0) > 0.5:
+                background = Image.open(random.choice(self.img_paths))
+                overlay = image
+            else:
+                background = image
+                overlay = Image.open(random.choice(self.img_paths))
+
+            overlay_size = random.uniform(self.size_lower, self.size_upper)
+            image = overlay_image(
+                background,
+                overlay=overlay,
+                opacity=random.uniform(self.opacity_lower, 1.0),
+                overlay_size=overlay_size,
+                x_pos=random.uniform(0.0, 1.0 - overlay_size),
+                y_pos=random.uniform(0.0, 1.0 - overlay_size),
+                metadata=metadata,
+            )
+            return transforms.RandomResizedCrop(self.input_size, scale=(self.moderate_scale_lower, 1.))(image)
+        else:
+            return transforms.RandomResizedCrop(self.input_size, scale=(self.hard_scale_lower, 1.))(image)
+
+
+class RandomEmojiOverlay(BaseTransform):
+    def __init__(
+        self,
+        emoji_directory: str = SMILEY_EMOJI_DIR,
+        opacity: float = 1.0,
+        p: float = 1.0,
+    ):
+        super().__init__(p)
+        self.emoji_directory = emoji_directory
+        self.emoji_paths = pathmgr.ls(emoji_directory)
+        self.opacity = opacity
+
+    def apply_transform(
+        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
+    ) -> Image.Image:
+        emoji_path = random.choice(self.emoji_paths)
+        return overlay_emoji(
+            image,
+            emoji_path=os.path.join(self.emoji_directory, emoji_path),
+            opacity=self.opacity,
+            emoji_size=random.uniform(0.1, 0.3),
+            x_pos=random.uniform(0.0, 1.0),
+            y_pos=random.uniform(0.0, 1.0),
+            metadata=metadata,
+        )
+
+
+class RandomEdgeEnhance(BaseTransform):
+    def __init__(
+        self,
+        mode=ImageFilter.EDGE_ENHANCE,
+        p: float = 1.0,
+    ):
+        super().__init__(p)
+        self.mode = mode
+
+    def apply_transform(self, image: Image.Image, *args) -> Image.Image:
+        return image.filter(self.mode)
+
+
+class ShuffledAug:
+
+    def __init__(self, aug_list):
+        self.aug_list = aug_list
+
+    def __call__(self, x):
+        # Apply every aug in the list, in a random order (without replacement).
+        shuffled_aug_list = random.sample(self.aug_list, len(self.aug_list))
+        for op in shuffled_aug_list:
+            x = op(x)
+        return x
+
+
+def convert2rgb(x):
+    return x.convert('RGB')
+
+
+def get_dataset(train_data_dir):
+    input_size = 256
+    ncrops = 2
+
+    # Instantiated only to read the normalization stats from its default config.
+    backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k', features_only=True, pretrained=True)
+
+    train_paths = list(Path(train_data_dir).glob('**/*.jpg'))
+
+    aug_moderate = [
+        transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std'])
+    ]
+
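+    # The first crop gets the moderate pipeline above; every further crop gets
+    # `aug_hard` below, which chains copy-style edits (text/emoji/image
+    # overlays, pixelization, re-encoding, ...) in a shuffled order.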
+    aug_list = [
+        transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
+        RandomPixelization(p=0.25),
+        ShufflePixels(factor=0.1, p=0.25),
+        OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
+        transforms.RandomGrayscale(p=0.25),
+        RandomBlur(p=0.25),
+        transforms.RandomPerspective(p=0.25),
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.RandomVerticalFlip(p=0.1),
+        RandomOverlayText(p=0.25),
+        RandomEmojiOverlay(p=0.25),
+        OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
+               RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
+              p=0.25),
+    ]
+    aug_hard = [
+        RandomRotation(p=0.25),
+        RandomOverlayImageAndResizedCrop(
+            train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
+            input_size=input_size, moderate_scale_lower=0.7, hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
+        ),
+        ShuffledAug(aug_list),
+        convert2rgb,
+        transforms.ToTensor(),
+        transforms.RandomErasing(value='random', p=0.25),
+        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']),
+    ]
+
+    train_dataset = ISCDataset(
+        train_paths,
+        NCropsTransform(
+            transforms.Compose(aug_moderate),
+            transforms.Compose(aug_hard),
+            ncrops,
+        ),
+    )
+    return train_dataset
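+
+
+# Quick smoke test of the augmentation pipeline alone (illustrative; the data
+# directory is a placeholder and augly's bundled font/emoji assets are required):
+# if __name__ == '__main__':
+#     dataset = get_dataset('/path/to/training_images')
+#     idx, views = dataset[0]
+#     print(idx, [v.shape for v in views])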