"""Training pipeline for an image-similarity (ISC) descriptor model.

Trains a backbone with a contrastive loss plus a cross-batch memory bank,
using heavy AugLy / torchvision augmentations (one "moderate" and one "hard"
view per image).  Supports single-GPU and multi-GPU (DistributedDataParallel)
training via ``train_isc``.
"""
import dataclasses
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

try:
    # Heavy / optional third-party dependencies.  Failures are tolerated so
    # this module can still be imported for introspection (e.g. to read
    # TrainingArguments); actually running training requires all of them.
    import pickle
    import random
    from pathlib import Path

    import timm
    import torch
    import torch.backends.cudnn as cudnn
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn.parallel
    import torch.optim
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    from augly.image import (EncodingQuality, OneOf, RandomBlur,
                             RandomEmojiOverlay, RandomPixelization,
                             RandomRotation, ShufflePixels)
    from augly.image.functional import overlay_emoji, overlay_image, overlay_text
    from augly.image.transforms import BaseTransform
    from augly.utils import pathmgr
    from augly.utils.base_paths import MODULE_BASE_DIR
    from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR
    from PIL import Image, ImageFilter
    from pytorch_metric_learning import losses
    from pytorch_metric_learning.utils import distributed as pml_dist
    from tqdm import tqdm
except Exception:  # noqa: BLE001 - best-effort import of optional deps
    pass


def dataclass_from_dict(klass, d):
    """Recursively build a ``klass`` dataclass instance from a plain dict.

    Nested dicts are converted using the declared field types.  Anything
    that is not a dict matching a dataclass (including an already-built
    instance of ``klass``) is returned unchanged.
    """
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except Exception:  # noqa: BLE001 - `d` is not a dict for a dataclass field
        return d


@dataclass
class TrainingArguments:
    """Configuration for ``train_isc``."""

    output_dir: Optional[str] = field(
        default='./output',
        metadata={"help": "output checkpoint saving dir."}
    )
    distributed: Optional[bool] = field(
        default=False,
        metadata={"help": "If true, use all gpu in your machine, else use only one gpu."}
    )
    gpu: Optional[int] = field(
        default=0,
        metadata={"help": "When distributed is False, use this gpu No. in your machine."}
    )
    start_epoch: Optional[int] = field(
        default=0,
        metadata={"help": "Start epoch number."}
    )
    epochs: Optional[int] = field(
        default=6,
        metadata={"help": "End epoch number."}
    )
    batch_size: Optional[int] = field(
        default=128,
        metadata={"help": "Total batch size in all gpu."}
    )
    init_lr: Optional[float] = field(
        default=0.1,
        metadata={"help": "init learning rate in SGD."}
    )
    train_data_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The dir containing all training data image files."}
    )


def train_isc(model, training_args):
    """Entry point: train ``model`` according to ``training_args``.

    ``training_args`` may be a plain dict or a TrainingArguments instance.
    Spawns one worker per GPU when ``distributed`` is True, otherwise runs
    a single worker on the configured GPU.
    """
    from towhee.trainer.training_config import get_dataclasses_help
    print('**** TrainingArguments ****')
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)
    if training_args.distributed is True:
        ngpus_per_node = torch.cuda.device_count()
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, model, training_args))
    else:
        main_worker(training_args.gpu, 1, model, training_args)


def main_worker(gpu, ngpus_per_node, model, training_args):
    """Per-process training loop: set up DDP, loss, optimizer, data, and train.

    Args:
        gpu: local GPU index; doubles as the distributed rank.
        ngpus_per_node: world size (1 when not distributed).
        model: the embedding model to train (must output 256-d embeddings,
            to match the CrossBatchMemory embedding_size below).
        training_args: a TrainingArguments instance.
    """
    rank = gpu
    world_size = ngpus_per_node
    distributed = training_args.distributed
    if distributed:
        dist.init_process_group(backend='nccl',
                                init_method='tcp://localhost:10001',
                                world_size=world_size,
                                rank=rank)
        torch.distributed.barrier(device_ids=[rank])

    # infer learning rate before changing batch size
    init_lr = training_args.init_lr

    if distributed:
        # Apply SyncBN so batch statistics are shared across processes.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)

    # The configured batch size is the global one; split it per process.
    batch_size = training_args.batch_size
    batch_size = int(batch_size / ngpus_per_node)
    workers = 8
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Contrastive loss backed by a cross-batch memory bank of past embeddings.
    memory_size = 20000
    pos_margin = 0.0
    neg_margin = 1.0
    loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin)
    loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256,
                                      memory_size=memory_size)
    # NOTE(review): the distributed wrapper is applied unconditionally, even
    # in single-GPU mode — confirm this is intended with this version of
    # pytorch_metric_learning.
    loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])

    # Standard trick: no weight decay on biases / 1-d (norm) parameters.
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    optim_params = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': 1e-6}
    ]
    momentum = 0.9
    optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum)
    scaler = torch.cuda.amp.GradScaler()

    cudnn.benchmark = True

    train_dataset = get_dataset(training_args.train_data_dir)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler,
        drop_last=True)

    start_epoch = training_args.start_epoch
    epochs = training_args.epochs
    for epoch in range(start_epoch, epochs):
        if distributed:
            # Reshuffle shards deterministically per epoch.
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler,
                        epoch, rank)
        # Only rank 0 writes checkpoints in distributed mode.
        if not distributed or (distributed and rank == 0):
            os.makedirs(training_args.output_dir, exist_ok=True)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'training_args': training_args,
            }, f'{training_args.output_dir}/checkpoint_epoch{epoch:04d}.pth.tar')


def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
    """Run one epoch of AMP training over ``train_loader``.

    Each dataset item yields (index, [moderate_view, hard_view]); the two
    views are concatenated along the batch dimension and the sample indices
    are tiled so both views of one image share a label.
    """
    # Named `loss_meter` (not `losses`) to avoid shadowing the
    # pytorch_metric_learning.losses module imported at the top.
    loss_meter = AverageMeter('Loss', ':.4f')
    progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False,
                    total=len(train_loader))
    model.train()

    for labels, images in progress:
        optimizer.zero_grad()

        labels = labels.cuda(rank, non_blocking=True)
        images = torch.cat([
            image for image in images
        ], dim=0).cuda(rank, non_blocking=True)
        # Two crops per image => duplicate each label once.
        labels = torch.tile(labels, dims=(2,))

        with torch.cuda.amp.autocast():
            embeddings = model(images)
            loss = loss_fn(embeddings, labels)
        loss_meter.update(loss.item(), images.size(0))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=loss_meter.avg)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ISCDataset(torch.utils.data.Dataset):
    """Image dataset that returns ``(index, transformed_image)`` pairs.

    The index serves as the instance label for contrastive training.
    """

    def __init__(
        self,
        paths,
        transforms,
    ):
        self.paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        image = Image.open(self.paths[i])
        image = self.transforms(image)
        return i, image


class NCropsTransform:
    """Take n random crops of one image as the query and key."""

    def __init__(self, aug_moderate, aug_hard, ncrops=2):
        self.aug_moderate = aug_moderate
        self.aug_hard = aug_hard
        self.ncrops = ncrops

    def __call__(self, x):
        # First view uses the moderate pipeline; the rest use the hard one.
        return [self.aug_moderate(x)] + [self.aug_hard(x) for _ in range(self.ncrops - 1)]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=(.1, 2.)):
        # Tuple default (not a list) so the default is immutable.
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x


class RandomOverlayText(BaseTransform):
    """Overlay 1-2 lines of random glyphs from a random AugLy font."""

    def __init__(
        self,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.opacity = opacity

        with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
            font_list = [s.strip() for s in f.readlines()]
            # Fonts known to render badly are excluded.
            blacklist = [
                'TypeMyMusic',
                'PainttheSky-Regular',
            ]
            self.font_list = [
                f for f in font_list
                if all(_ not in f for _ in blacklist)
            ]

        # For each font, the number of available glyph indices (from the
        # companion .pkl shipped alongside the .ttf).
        self.font_lens = []
        for ff in self.font_list:
            font_file = Path(MODULE_BASE_DIR) / ff.replace('.ttf', '.pkl')
            with open(font_file, 'rb') as f:
                self.font_lens.append(len(pickle.load(f)))

    def apply_transform(
        self,
        image: Image.Image,
        metadata: Optional[List[Dict[str, Any]]] = None,
        bboxes=None,
        bbox_format=None
    ) -> Image.Image:
        i = random.randrange(0, len(self.font_list))
        kwargs = dict(
            font_file=Path(MODULE_BASE_DIR) / self.font_list[i],
            font_size=random.uniform(0.1, 0.3),
            color=[random.randrange(0, 256) for _ in range(3)],
            x_pos=random.uniform(0.0, 0.5),
            metadata=metadata,
            opacity=self.opacity,
        )
        try:
            for j in range(random.randrange(1, 3)):
                if j == 0:
                    y_pos = random.uniform(0.0, 0.5)
                else:
                    # Stack subsequent lines below the first.
                    y_pos += kwargs['font_size']
                image = overlay_text(
                    image,
                    text=[random.randrange(0, self.font_lens[i])
                          for _ in range(random.randrange(5, 10))],
                    y_pos=y_pos,
                    **kwargs,
                )
            return image
        except OSError:
            # Some fonts fail to load/render; fall back to the current image.
            return image


class RandomOverlayImageAndResizedCrop(BaseTransform):
    """Occasionally paste another training image on top, then random-resized-crop.

    With probability ``overlay_p`` an overlay is applied (either this image
    on a random background or vice versa) followed by a *moderate* crop;
    otherwise only a *hard* (more aggressive) crop is applied.
    """

    def __init__(
        self,
        img_paths: List[Path],
        opacity_lower: float = 0.5,
        size_lower: float = 0.4,
        size_upper: float = 0.6,
        input_size: int = 224,
        moderate_scale_lower: float = 0.7,
        hard_scale_lower: float = 0.15,
        overlay_p: float = 0.05,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.img_paths = img_paths
        self.opacity_lower = opacity_lower
        self.size_lower = size_lower
        self.size_upper = size_upper
        self.input_size = input_size
        self.moderate_scale_lower = moderate_scale_lower
        self.hard_scale_lower = hard_scale_lower
        self.overlay_p = overlay_p

    def apply_transform(
        self,
        image: Image.Image,
        metadata: Optional[List[Dict[str, Any]]] = None,
        bboxes=None,
        bbox_format=None
    ) -> Image.Image:
        if random.uniform(0.0, 1.0) < self.overlay_p:
            # Randomly decide which image is background vs. overlay.
            if random.uniform(0.0, 1.0) > 0.5:
                background = Image.open(random.choice(self.img_paths))
                overlay = image
            else:
                background = image
                overlay = Image.open(random.choice(self.img_paths))

            overlay_size = random.uniform(self.size_lower, self.size_upper)
            image = overlay_image(
                background,
                overlay=overlay,
                opacity=random.uniform(self.opacity_lower, 1.0),
                overlay_size=overlay_size,
                x_pos=random.uniform(0.0, 1.0 - overlay_size),
                y_pos=random.uniform(0.0, 1.0 - overlay_size),
                metadata=metadata,
            )
            return transforms.RandomResizedCrop(
                self.input_size, scale=(self.moderate_scale_lower, 1.))(image)
        else:
            return transforms.RandomResizedCrop(
                self.input_size, scale=(self.hard_scale_lower, 1.))(image)


# NOTE: intentionally redefines (shadows) augly's RandomEmojiOverlay imported
# above — this version picks a random emoji file, size, and position per call.
class RandomEmojiOverlay(BaseTransform):
    """Overlay a random smiley emoji at a random size and position."""

    def __init__(
        self,
        emoji_directory: str = SMILEY_EMOJI_DIR,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.emoji_directory = emoji_directory
        self.emoji_paths = pathmgr.ls(emoji_directory)
        self.opacity = opacity

    def apply_transform(
        self,
        image: Image.Image,
        metadata: Optional[List[Dict[str, Any]]] = None,
        bboxes=None,
        bbox_format=None
    ) -> Image.Image:
        emoji_path = random.choice(self.emoji_paths)
        return overlay_emoji(
            image,
            emoji_path=os.path.join(self.emoji_directory, emoji_path),
            opacity=self.opacity,
            emoji_size=random.uniform(0.1, 0.3),
            x_pos=random.uniform(0.0, 1.0),
            y_pos=random.uniform(0.0, 1.0),
            metadata=metadata,
        )


class RandomEdgeEnhance(BaseTransform):
    """Apply a PIL edge-enhance filter with probability ``p``."""

    def __init__(
        self,
        mode=ImageFilter.EDGE_ENHANCE,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.mode = mode

    def apply_transform(self, image: Image.Image, *args) -> Image.Image:
        return image.filter(self.mode)


class ShuffledAug:
    """Apply a list of augmentations in a random order each call."""

    def __init__(self, aug_list):
        self.aug_list = aug_list

    def __call__(self, x):
        # without replacement
        shuffled_aug_list = random.sample(self.aug_list, len(self.aug_list))
        for op in shuffled_aug_list:
            x = op(x)
        return x


def convert2rgb(x):
    """Force a PIL image into RGB mode (drops alpha / palette)."""
    return x.convert('RGB')


def get_dataset(train_data_dir):
    """Build the ISC training dataset over all ``*.jpg`` under ``train_data_dir``.

    Each sample yields a moderate view and a hard (heavily augmented) view
    of the same image; normalization statistics come from the timm
    EfficientNetV2-M config (the backbone here is only used for its
    ``default_cfg`` mean/std).
    """
    input_size = 256
    ncrops = 2
    backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k',
                                 features_only=True, pretrained=True)
    train_paths = list(Path(train_data_dir).glob('**/*.jpg'))

    aug_moderate = [
        transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=backbone.default_cfg['mean'],
                             std=backbone.default_cfg['std'])
    ]

    # Pool of augmentations applied in random order by ShuffledAug.
    aug_list = [
        transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
        RandomPixelization(p=0.25),
        ShufflePixels(factor=0.1, p=0.25),
        OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
        transforms.RandomGrayscale(p=0.25),
        RandomBlur(p=0.25),
        transforms.RandomPerspective(p=0.25),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        RandomOverlayText(p=0.25),
        RandomEmojiOverlay(p=0.25),
        OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
               RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
              p=0.25),
    ]
    aug_hard = [
        RandomRotation(p=0.25),
        RandomOverlayImageAndResizedCrop(
            train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
            input_size=input_size, moderate_scale_lower=0.7,
            hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
        ),
        ShuffledAug(aug_list),
        convert2rgb,
        transforms.ToTensor(),
        transforms.RandomErasing(value='random', p=0.25),
        transforms.Normalize(mean=backbone.default_cfg['mean'],
                             std=backbone.default_cfg['std']),
    ]

    train_dataset = ISCDataset(
        train_paths,
        NCropsTransform(
            transforms.Compose(aug_moderate),
            transforms.Compose(aug_hard),
            ncrops,
        ),
    )
    return train_dataset