isc
copied
ChengZi
2 years ago
2 changed files with 481 additions and 6 deletions
@ -0,0 +1,468 @@ |
|||
try: |
|||
import os |
|||
import pickle |
|||
import random |
|||
from pathlib import Path |
|||
from typing import Any, Dict, List, Optional |
|||
|
|||
import timm |
|||
import torch |
|||
import torch.backends.cudnn as cudnn |
|||
import torch.distributed as dist |
|||
import torch.multiprocessing as mp |
|||
import torch.nn.parallel |
|||
import torch.optim |
|||
import torch.utils.data |
|||
import torch.utils.data.distributed |
|||
import torchvision.transforms as transforms |
|||
from augly.image import (EncodingQuality, OneOf, |
|||
RandomBlur, RandomEmojiOverlay, RandomPixelization, |
|||
RandomRotation, ShufflePixels) |
|||
from augly.image.functional import overlay_emoji, overlay_image, overlay_text |
|||
from augly.image.transforms import BaseTransform |
|||
from augly.utils import pathmgr |
|||
from augly.utils.base_paths import MODULE_BASE_DIR |
|||
from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR |
|||
from PIL import Image, ImageFilter |
|||
from pytorch_metric_learning import losses |
|||
from pytorch_metric_learning.utils import distributed as pml_dist |
|||
from tqdm import tqdm |
|||
import dataclasses |
|||
from dataclasses import dataclass, field |
|||
from towhee.trainer.training_config import get_dataclasses_help |
|||
except: |
|||
pass |
|||
|
|||
|
|||
def dataclass_from_dict(klass, d):
    """Recursively build an instance of dataclass ``klass`` from a plain dict ``d``.

    Falls back to returning ``d`` unchanged when it cannot be mapped onto
    ``klass`` (``klass`` is not a dataclass, ``d`` is not a mapping, or ``d``
    contains keys that are not fields of ``klass``).

    BUGFIX: the original bare ``except:`` swallowed *every* exception,
    including KeyboardInterrupt/SystemExit; only the failure modes of the
    conversion itself are caught now.
    """
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{key: dataclass_from_dict(fieldtypes[key], d[key]) for key in d})
    except (TypeError, KeyError):
        # Not a dataclass / not a mapping / unknown key: treat as a leaf value.
        return d
|||
|
|||
|
|||
@dataclass
class TrainingArguments:
    """Hyper-parameters and runtime options consumed by ``train_isc``."""

    # Checkpoint destination; one file is written per epoch.
    output_dir: Optional[str] = field(
        default='./output',
        metadata={"help": "output checkpoint saving dir."},
    )
    # Multi-GPU switch: spawn one worker per visible GPU when True.
    distributed: Optional[bool] = field(
        default=False,
        metadata={"help": "If true, use all gpu in your machine, else use only one gpu."},
    )
    # Which single GPU to use when not distributed.
    gpu: Optional[int] = field(
        default=0,
        metadata={"help": "When distributed is False, use this gpu No. in your machine."},
    )
    start_epoch: Optional[int] = field(
        default=0,
        metadata={"help": "Start epoch number."},
    )
    epochs: Optional[int] = field(
        default=6,
        metadata={"help": "End epoch number."},
    )
    # Global batch size; divided across workers in distributed mode.
    batch_size: Optional[int] = field(
        default=128,
        metadata={"help": "Total batch size in all gpu."},
    )
    init_lr: Optional[float] = field(
        default=0.1,
        metadata={"help": "init learning rate in SGD."},
    )
    train_data_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The dir containing all training data image files."},
    )
|||
|
|||
|
|||
def train_isc(model, training_args):
    """Entry point for ISC contrastive training.

    Args:
        model: the backbone network to train (moved to GPU inside the worker).
        training_args: a ``TrainingArguments`` instance or a plain dict with
            the same keys (converted via ``dataclass_from_dict``).
    """
    print('**** TrainingArguments ****')
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)
    # BUGFIX: `is True` silently sent truthy-but-not-True values (e.g. 1,
    # as may arrive from a plain dict config) down the single-GPU path;
    # test truthiness instead.
    if training_args.distributed:
        ngpus_per_node = torch.cuda.device_count()
        # One worker process per GPU; mp.spawn passes the rank as first arg.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, model, training_args))
    else:
        main_worker(training_args.gpu, 1, model, training_args)
|||
|
|||
|
|||
def main_worker(gpu, ngpus_per_node, model, training_args):
    """Per-process training worker (one per GPU when distributed).

    Args:
        gpu: GPU index assigned to this process; doubles as the distributed rank.
        ngpus_per_node: number of worker processes (1 when not distributed).
        model: network mapping an image batch to 256-d embeddings
            (CrossBatchMemory below is constructed with embedding_size=256).
        training_args: resolved ``TrainingArguments``.
    """
    rank = gpu
    world_size = ngpus_per_node
    distributed = training_args.distributed
    if distributed:
        dist.init_process_group(backend='nccl', init_method='tcp://localhost:10001',
                                world_size=world_size, rank=rank)
        torch.distributed.barrier(device_ids=[rank])

    # infer learning rate before changing batch size
    init_lr = training_args.init_lr

    if distributed:
        # Synchronize batch-norm statistics across workers.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # The configured batch size is global; split it evenly among workers.
    batch_size = training_args.batch_size
    batch_size = int(batch_size / ngpus_per_node)
    workers = 8
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    memory_size = 20000
    pos_margin = 0.0
    neg_margin = 1.0
    loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin)
    # Cross-batch memory keeps the last `memory_size` embeddings as extra negatives.
    loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256, memory_size=memory_size)
    if distributed:
        # BUGFIX: only wrap the loss for cross-process gathering when a
        # process group actually exists; the original wrapped unconditionally,
        # which cannot work in single-GPU mode where torch.distributed is
        # never initialized.
        loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])

    # No weight decay for biases and 1-d (norm/gain) parameters.
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    optim_params = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': 1e-6}
    ]
    momentum = 0.9

    optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum)
    scaler = torch.cuda.amp.GradScaler()

    cudnn.benchmark = True

    train_dataset = get_dataset(training_args.train_data_dir)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    start_epoch = training_args.start_epoch
    epochs = training_args.epochs
    for epoch in range(start_epoch, epochs):
        if distributed:
            # Re-seed the sampler so each epoch shuffles shards differently.
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank)

        # Checkpoint from a single process to avoid concurrent writes.
        if not distributed or (distributed and rank == 0):
            # BUGFIX: torch.save fails if the output directory does not exist.
            os.makedirs(training_args.output_dir, exist_ok=True)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'training_args': training_args,
            }, f'{training_args.output_dir}/checkpoint_epoch{epoch:04d}.pth.tar')
|||
|
|||
|
|||
def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
    """Run one training epoch with AMP mixed precision.

    Each batch yields ``(labels, images)`` where ``images`` is a list of
    augmented views per sample; views are concatenated along the batch
    dimension and labels tiled to match.
    """
    # BUGFIX: renamed from `losses`, which shadowed the module-level
    # `pytorch_metric_learning.losses` import.
    loss_meter = AverageMeter('Loss', ':.4f')
    progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False, total=len(train_loader))

    model.train()

    for labels, images in progress:
        optimizer.zero_grad()

        labels = labels.cuda(rank, non_blocking=True)
        # Stack the per-sample crops into one large batch.
        images = torch.cat([
            image for image in images
        ], dim=0).cuda(rank, non_blocking=True)
        # NOTE(review): hard-coded tile factor 2 assumes ncrops == 2
        # (matches get_dataset) — confirm if ncrops ever changes.
        labels = torch.tile(labels, dims=(2,))

        with torch.cuda.amp.autocast():
            embeddings = model(images)
            loss = loss_fn(embeddings, labels)

        loss_meter.update(loss.item(), images.size(0))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=loss_meter.avg)
|||
|
|||
|
|||
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # Wipe all statistics back to an empty state.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # Record `val` observed `n` times and refresh the running mean.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Inject the format spec (e.g. ':.4f') into the display template.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(name=self.name, val=self.val, avg=self.avg)
|||
|
|||
|
|||
class ISCDataset(torch.utils.data.Dataset):
    """Image-file dataset where each item's label is its own index."""

    def __init__(
        self,
        paths,
        transforms,
    ):
        self.paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # Every image is its own class, so the index doubles as the label.
        augmented = self.transforms(Image.open(self.paths[i]))
        return i, augmented
|||
|
|||
|
|||
class NCropsTransform:
    """Produce ``ncrops`` views of one image: one moderate, the rest hard."""

    def __init__(self, aug_moderate, aug_hard, ncrops=2):
        self.aug_moderate = aug_moderate
        self.aug_hard = aug_hard
        self.ncrops = ncrops

    def __call__(self, x):
        views = [self.aug_moderate(x)]
        for _ in range(self.ncrops - 1):
            views.append(self.aug_hard(x))
        return views
|||
|
|||
|
|||
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=(.1, 2.)):
        # BUGFIX: immutable tuple default instead of a mutable list shared
        # across all instances (classic mutable-default-argument pitfall).
        # Any 2-sequence (min_sigma, max_sigma) is still accepted.
        self.sigma = sigma

    def __call__(self, x):
        # Sample a blur radius uniformly from [sigma[0], sigma[1]].
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
|||
|
|||
|
|||
class RandomOverlayText(BaseTransform):
    """Overlay one or two lines of random glyphs from a random augly font."""

    def __init__(
        self,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.opacity = opacity

        # Load augly's font list, dropping fonts known to render badly.
        with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
            font_list = [s.strip() for s in f.readlines()]
        blacklist = [
            'TypeMyMusic',
            'PainttheSky-Regular',
        ]
        self.font_list = [
            name for name in font_list
            if not any(bad in name for bad in blacklist)
        ]

        # Glyph count per font (from the companion .pkl files), used to
        # sample valid glyph indices in apply_transform.
        self.font_lens = []
        for font_name in self.font_list:
            pkl_path = Path(MODULE_BASE_DIR) / font_name.replace('.ttf', '.pkl')
            with open(pkl_path, 'rb') as f:
                self.font_lens.append(len(pickle.load(f)))

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        font_idx = random.randrange(0, len(self.font_list))
        kwargs = dict(
            font_file=Path(MODULE_BASE_DIR) / self.font_list[font_idx],
            font_size=random.uniform(0.1, 0.3),
            color=[random.randrange(0, 256) for _ in range(3)],
            x_pos=random.uniform(0.0, 0.5),
            metadata=metadata,
            opacity=self.opacity,
        )
        try:
            # Draw 1-2 lines; each subsequent line sits one font-size lower.
            for line_no in range(random.randrange(1, 3)):
                if line_no == 0:
                    y_pos = random.uniform(0.0, 0.5)
                else:
                    y_pos += kwargs['font_size']
                image = overlay_text(
                    image,
                    text=[random.randrange(0, self.font_lens[font_idx])
                          for _ in range(random.randrange(5, 10))],
                    y_pos=y_pos,
                    **kwargs,
                )
            return image
        except OSError:
            # Some fonts fail to render; fall back to the image drawn so far.
            return image
|||
|
|||
|
|||
class RandomOverlayImageAndResizedCrop(BaseTransform):
    """With probability ``overlay_p``, blend the image with a random training
    image (either as background or overlay) then apply a moderate random
    resized crop; otherwise apply a harder (smaller-scale) crop only."""

    def __init__(
        self,
        img_paths: List[Path],
        opacity_lower: float = 0.5,
        size_lower: float = 0.4,
        size_upper: float = 0.6,
        input_size: int = 224,
        moderate_scale_lower: float = 0.7,
        hard_scale_lower: float = 0.15,
        overlay_p: float = 0.05,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.img_paths = img_paths
        self.opacity_lower = opacity_lower
        self.size_lower = size_lower
        self.size_upper = size_upper
        self.input_size = input_size
        self.moderate_scale_lower = moderate_scale_lower
        self.hard_scale_lower = hard_scale_lower
        self.overlay_p = overlay_p
        # PERF: build the crop transforms once instead of on every call; the
        # constructor consumes no randomness, so sampling is unchanged.
        self._moderate_crop = transforms.RandomResizedCrop(input_size, scale=(moderate_scale_lower, 1.))
        self._hard_crop = transforms.RandomResizedCrop(input_size, scale=(hard_scale_lower, 1.))

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:

        if random.uniform(0.0, 1.0) < self.overlay_p:
            # Coin flip decides which image is background vs. overlay.
            if random.uniform(0.0, 1.0) > 0.5:
                background = Image.open(random.choice(self.img_paths))
                overlay = image
            else:
                background = image
                overlay = Image.open(random.choice(self.img_paths))

            overlay_size = random.uniform(self.size_lower, self.size_upper)
            image = overlay_image(
                background,
                overlay=overlay,
                opacity=random.uniform(self.opacity_lower, 1.0),
                overlay_size=overlay_size,
                # Keep the overlay fully inside the background.
                x_pos=random.uniform(0.0, 1.0 - overlay_size),
                y_pos=random.uniform(0.0, 1.0 - overlay_size),
                metadata=metadata,
            )
            return self._moderate_crop(image)
        else:
            return self._hard_crop(image)
|||
|
|||
|
|||
class RandomEmojiOverlay(BaseTransform):
    """Overlay a random emoji at a random position and size.

    NOTE(review): this class shadows the ``RandomEmojiOverlay`` imported from
    augly at the top of the file — presumably a deliberate re-implementation;
    confirm before renaming either.
    """

    def __init__(
        self,
        emoji_directory: str = SMILEY_EMOJI_DIR,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.emoji_directory = emoji_directory
        # Enumerate available emoji files once, up front.
        self.emoji_paths = pathmgr.ls(emoji_directory)
        self.opacity = opacity

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        chosen = random.choice(self.emoji_paths)
        full_path = os.path.join(self.emoji_directory, chosen)
        return overlay_emoji(
            image,
            emoji_path=full_path,
            opacity=self.opacity,
            emoji_size=random.uniform(0.1, 0.3),
            x_pos=random.uniform(0.0, 1.0),
            y_pos=random.uniform(0.0, 1.0),
            metadata=metadata,
        )
|||
|
|||
|
|||
class RandomEdgeEnhance(BaseTransform):
    """Apply a fixed PIL edge-enhance filter; randomness comes only from ``p``."""

    def __init__(
        self,
        mode=ImageFilter.EDGE_ENHANCE,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.mode = mode

    def apply_transform(self, image: Image.Image, *args) -> Image.Image:
        filtered = image.filter(self.mode)
        return filtered
|||
|
|||
|
|||
class ShuffledAug:
    """Apply every augmentation in ``aug_list``, in a fresh random order per call."""

    def __init__(self, aug_list):
        self.aug_list = aug_list

    def __call__(self, x):
        # random.sample without replacement == a random permutation.
        for op in random.sample(self.aug_list, len(self.aug_list)):
            x = op(x)
        return x
|||
|
|||
|
|||
def convert2rgb(x):
    """Coerce an image to 3-channel RGB mode (drops alpha / palette modes)."""
    rgb = x.convert('RGB')
    return rgb
|||
|
|||
|
|||
def get_dataset(train_data_dir):
    """Build the ISC training dataset with paired moderate/hard augmentations.

    Args:
        train_data_dir: directory searched recursively for ``*.jpg`` files.

    Returns:
        An ``ISCDataset`` yielding ``(index, [moderate_view, hard_view])``.
    """
    input_size = 256
    ncrops = 2

    # Only the normalization constants (default_cfg mean/std) are read from
    # this model. PERF/BUGFIX: pretrained=False — the original downloaded the
    # full pretrained checkpoint just to read these static config values.
    backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k', features_only=True, pretrained=False)

    train_paths = list(Path(train_data_dir).glob('**/*.jpg'))

    # Mild pipeline for the "query" view.
    aug_moderate = [
        transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std'])
    ]

    # Pool of strong augmentations applied in random order by ShuffledAug.
    aug_list = [
        transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
        RandomPixelization(p=0.25),
        ShufflePixels(factor=0.1, p=0.25),
        OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
        transforms.RandomGrayscale(p=0.25),
        RandomBlur(p=0.25),
        transforms.RandomPerspective(p=0.25),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        RandomOverlayText(p=0.25),
        RandomEmojiOverlay(p=0.25),
        OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
               RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
              p=0.25),
    ]
    # Aggressive pipeline for the "key" view.
    aug_hard = [
        RandomRotation(p=0.25),
        RandomOverlayImageAndResizedCrop(
            train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
            input_size=input_size, moderate_scale_lower=0.7, hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
        ),
        ShuffledAug(aug_list),
        convert2rgb,
        transforms.ToTensor(),
        transforms.RandomErasing(value='random', p=0.25),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']),
    ]

    train_dataset = ISCDataset(
        train_paths,
        NCropsTransform(
            transforms.Compose(aug_moderate),
            transforms.Compose(aug_hard),
            ncrops,
        ),
    )
    return train_dataset
Loading…
Reference in new issue