isc
copied
ChengZi
2 years ago
2 changed files with 481 additions and 6 deletions
@ -0,0 +1,468 @@ |
|||||
|
try: |
||||
|
import os |
||||
|
import pickle |
||||
|
import random |
||||
|
from pathlib import Path |
||||
|
from typing import Any, Dict, List, Optional |
||||
|
|
||||
|
import timm |
||||
|
import torch |
||||
|
import torch.backends.cudnn as cudnn |
||||
|
import torch.distributed as dist |
||||
|
import torch.multiprocessing as mp |
||||
|
import torch.nn.parallel |
||||
|
import torch.optim |
||||
|
import torch.utils.data |
||||
|
import torch.utils.data.distributed |
||||
|
import torchvision.transforms as transforms |
||||
|
from augly.image import (EncodingQuality, OneOf, |
||||
|
RandomBlur, RandomEmojiOverlay, RandomPixelization, |
||||
|
RandomRotation, ShufflePixels) |
||||
|
from augly.image.functional import overlay_emoji, overlay_image, overlay_text |
||||
|
from augly.image.transforms import BaseTransform |
||||
|
from augly.utils import pathmgr |
||||
|
from augly.utils.base_paths import MODULE_BASE_DIR |
||||
|
from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR |
||||
|
from PIL import Image, ImageFilter |
||||
|
from pytorch_metric_learning import losses |
||||
|
from pytorch_metric_learning.utils import distributed as pml_dist |
||||
|
from tqdm import tqdm |
||||
|
import dataclasses |
||||
|
from dataclasses import dataclass, field |
||||
|
from towhee.trainer.training_config import get_dataclasses_help |
||||
|
except: |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
def dataclass_from_dict(klass, d):
    """Recursively build a dataclass instance of type ``klass`` from dict ``d``.

    Values that do not correspond to a dataclass field (or a ``d`` that is
    not a mapping) are returned unchanged, so plain leaf values pass
    straight through.

    Args:
        klass: target dataclass type (or any type for leaf values).
        d: dict of field values, or a plain value for non-dataclass leaves.

    Returns:
        An instance of ``klass`` when conversion succeeds, otherwise ``d``
        itself.
    """
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except (TypeError, KeyError):
        # `klass` is not a dataclass, `d` is not a mapping, or a key has no
        # matching field — treat `d` as an already-final (leaf) value.
        # (Narrowed from a bare `except:` that also swallowed
        # KeyboardInterrupt/SystemExit.)
        return d
||||
|
|
||||
|
|
||||
|
@dataclass
class TrainingArguments:
    """Configuration for ISC training, consumed by `train_isc`.

    Instances are normally built from a plain dict via
    `dataclass_from_dict`, so every field is optional with a default.
    """

    # Directory where per-epoch checkpoints are written.
    output_dir: Optional[str] = field(
        default='./output', metadata={"help": "output checkpoint saving dir."}
    )
    # True: spawn one worker per visible GPU; False: single-GPU training.
    distributed: Optional[bool] = field(
        default=False, metadata={"help": "If true, use all gpu in your machine, else use only one gpu."}
    )
    # GPU index used when `distributed` is False.
    gpu: Optional[int] = field(
        default=0, metadata={"help": "When distributed is False, use this gpu No. in your machine."}
    )
    # First epoch number (useful when resuming).
    start_epoch: Optional[int] = field(
        default=0, metadata={"help": "Start epoch number."}
    )
    # Training stops before this epoch number.
    epochs: Optional[int] = field(
        default=6, metadata={"help": "End epoch number."}
    )
    # Global batch size; divided by the number of GPUs per worker.
    batch_size: Optional[int] = field(
        default=128, metadata={"help": "Total batch size in all gpu."}
    )
    # Initial SGD learning rate.
    init_lr: Optional[float] = field(
        default=0.1, metadata={"help": "init learning rate in SGD."}
    )
    # Root directory scanned recursively for training .jpg images.
    train_data_dir: Optional[str] = field(default=None,
                                          metadata={"help": "The dir containing all training data image files."})
||||
|
|
||||
|
|
||||
|
def train_isc(model, training_args):
    """Entry point for ISC (image similarity) training.

    Prints the available options, converts `training_args` to a
    `TrainingArguments` instance, and launches either multi-GPU
    (one spawned worker per GPU) or single-GPU training.

    Args:
        model: the embedding model to train.
        training_args: dict of training options (see `TrainingArguments`).
    """
    print('**** TrainingArguments ****')
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)
    # Truthiness check instead of the original `is True` identity test,
    # which would silently skip distributed mode for truthy non-bool values.
    if training_args.distributed:
        # One worker process per visible GPU; the worker's gpu index is its rank.
        ngpus_per_node = torch.cuda.device_count()
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, model, training_args))
    else:
        # Single-process training on the configured GPU.
        main_worker(training_args.gpu, 1, model, training_args)
||||
|
|
||||
|
|
||||
|
def main_worker(gpu, ngpus_per_node, model, training_args):
    """Train `model` on one process / one GPU.

    Runs the full training pipeline: optional DDP setup, loss/optimizer
    construction, data loading, per-epoch training, and checkpointing.

    Args:
        gpu: GPU index for this worker; also used as the process rank.
        ngpus_per_node: total number of workers (the world size).
        model: the embedding model to train.
        training_args: a `TrainingArguments` instance.
    """
    rank = gpu
    world_size = ngpus_per_node
    distributed = training_args.distributed
    if distributed:
        # One process per GPU on a single node; fixed local TCP rendezvous.
        dist.init_process_group(backend='nccl', init_method='tcp://localhost:10001',
                                world_size=world_size, rank=rank)
        torch.distributed.barrier(device_ids=[rank])

    # infer learning rate before changing batch size
    init_lr = training_args.init_lr

    if distributed:
        # Apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = training_args.batch_size
    # Per-process batch size: the configured value is the global total.
    batch_size = int(batch_size / ngpus_per_node)
    workers = 8  # DataLoader worker processes for this rank
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Contrastive loss backed by a cross-batch memory bank of embeddings.
    memory_size = 20000
    pos_margin = 0.0
    neg_margin = 1.0
    loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin)
    loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256, memory_size=memory_size)
    # NOTE(review): the distributed wrapper is applied unconditionally, even
    # when `distributed` is False — confirm this is safe without an
    # initialized process group.
    loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])

    # Weight decay is disabled for biases and 1-D (norm/gain) parameters.
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    optim_params = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': 1e-6}
    ]
    momentum = 0.9

    optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum)
    scaler = torch.cuda.amp.GradScaler()  # mixed-precision loss scaling

    cudnn.benchmark = True

    train_dataset = get_dataset(training_args.train_data_dir)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    start_epoch = training_args.start_epoch
    epochs = training_args.epochs
    for epoch in range(start_epoch, epochs):
        if distributed:
            # Reshuffle this rank's data partition every epoch.
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank)

        # Only rank 0 writes checkpoints in distributed mode.
        if not distributed or (distributed and rank == 0):
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'training_args': training_args,
            }, f'{training_args.output_dir}/checkpoint_epoch{epoch:04d}.pth.tar')
||||
|
|
||||
|
|
||||
|
def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
    """Run one training epoch with mixed-precision (AMP) updates.

    Args:
        train_loader: DataLoader yielding (labels, list-of-view-tensors).
        model: the embedding model (already on the right device).
        loss_fn: contrastive loss taking (embeddings, labels).
        optimizer: the SGD optimizer.
        scaler: `torch.cuda.amp.GradScaler` for loss scaling.
        epoch: current epoch number (for the progress bar).
        rank: CUDA device index to move batches onto.
    """
    # Renamed from `losses` to avoid shadowing the imported
    # `pytorch_metric_learning.losses` module.
    loss_meter = AverageMeter('Loss', ':.4f')
    progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False, total=len(train_loader))

    model.train()

    for labels, images in progress:
        optimizer.zero_grad()

        labels = labels.cuda(rank, non_blocking=True)
        # Each sample yields multiple augmented views; concatenate them
        # into one batch and repeat labels to match.
        images = torch.cat([
            image for image in images
        ], dim=0).cuda(rank, non_blocking=True)
        # NOTE(review): the repeat factor 2 assumes ncrops == 2 (see
        # get_dataset) — confirm if ncrops ever changes.
        labels = torch.tile(labels, dims=(2,))

        with torch.cuda.amp.autocast():
            embeddings = model(images)
            loss = loss_fn(embeddings, labels)

        loss_meter.update(loss.item(), images.size(0))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=loss_meter.avg)
||||
|
|
||||
|
|
||||
|
class AverageMeter(object):
    """Tracks the most recent value and running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record `val` (averaged over `n` samples) and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**vars(self))
||||
|
|
||||
|
|
||||
|
class ISCDataset(torch.utils.data.Dataset):
    """Dataset of image file paths where each index is its own label."""

    def __init__(self, paths, transforms):
        self.paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # The sample index doubles as the label: every image is treated
        # as its own class for contrastive training.
        img = Image.open(self.paths[i])
        return i, self.transforms(img)
||||
|
|
||||
|
|
||||
|
class NCropsTransform:
    """Produces `ncrops` views of one image: a moderately augmented query
    plus `ncrops - 1` hard-augmented keys."""

    def __init__(self, aug_moderate, aug_hard, ncrops=2):
        self.aug_moderate = aug_moderate
        self.aug_hard = aug_hard
        self.ncrops = ncrops

    def __call__(self, x):
        views = [self.aug_moderate(x)]
        for _ in range(self.ncrops - 1):
            views.append(self.aug_hard(x))
        return views
||||
|
|
||||
|
|
||||
|
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    A blur radius is drawn uniformly from the `sigma` range on each call.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Tuple default replaces the original mutable list default
        # (`sigma=[.1, 2.]`), which is shared across instances.
        self.sigma = sigma

    def __call__(self, x):
        """Return a blurred copy of PIL image `x` with a random radius."""
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))
||||
|
|
||||
|
|
||||
|
class RandomOverlayText(BaseTransform):
    """Overlays 1-2 rows of random text in a random font onto the image.

    Fonts come from augly's bundled font list minus a small blacklist of
    fonts that are excluded here.
    """

    def __init__(
        self,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        # opacity: opacity of the rendered text (1.0 = fully opaque).
        # p: probability of applying this transform.
        super().__init__(p)
        self.opacity = opacity

        with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
            font_list = [s.strip() for s in f.readlines()]
        blacklist = [
            'TypeMyMusic',
            'PainttheSky-Regular',
        ]
        self.font_list = [
            f for f in font_list
            if all(_ not in f for _ in blacklist)
        ]

        # Per-font glyph counts, read from each font's companion .pkl file;
        # glyph indices are later passed to overlay_text as "text".
        self.font_lens = []
        for ff in self.font_list:
            font_file = Path(MODULE_BASE_DIR) / ff.replace('.ttf', '.pkl')
            with open(font_file, 'rb') as f:
                self.font_lens.append(len(pickle.load(f)))

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        """Draw 1-2 rows of random glyphs onto `image`; returns the image
        unchanged if the chosen font fails to load (OSError)."""
        i = random.randrange(0, len(self.font_list))
        kwargs = dict(
            font_file=Path(MODULE_BASE_DIR) / self.font_list[i],
            font_size=random.uniform(0.1, 0.3),
            color=[random.randrange(0, 256) for _ in range(3)],
            x_pos=random.uniform(0.0, 0.5),
            metadata=metadata,
            opacity=self.opacity,
        )
        try:
            for j in range(random.randrange(1, 3)):
                if j == 0:
                    y_pos = random.uniform(0.0, 0.5)
                else:
                    # Each subsequent row starts one font-size below the last.
                    y_pos += kwargs['font_size']
                image = overlay_text(
                    image,
                    text=[random.randrange(0, self.font_lens[i]) for _ in range(random.randrange(5, 10))],
                    y_pos=y_pos,
                    **kwargs,
                )
            return image
        except OSError:
            # Some fonts fail to load/render; best-effort — return as-is.
            return image
||||
|
|
||||
|
|
||||
|
class RandomOverlayImageAndResizedCrop(BaseTransform):
    """With probability `overlay_p`, pastes a random training image over
    (or under) the input, then takes a random resized crop.

    The overlay branch uses the milder `moderate_scale_lower` crop scale;
    the non-overlay branch uses the harder `hard_scale_lower`.
    """

    def __init__(
        self,
        img_paths: List[Path],
        opacity_lower: float = 0.5,
        size_lower: float = 0.4,
        size_upper: float = 0.6,
        input_size: int = 224,
        moderate_scale_lower: float = 0.7,
        hard_scale_lower: float = 0.15,
        overlay_p: float = 0.05,
        p: float = 1.0,
    ):
        # img_paths: candidate images to use as overlay/background.
        # opacity_lower: lower bound for the random overlay opacity.
        # size_lower/size_upper: random overlay size range (fraction of image).
        # input_size: output side length of the final crop.
        # moderate_scale_lower/hard_scale_lower: crop-scale lower bounds.
        # overlay_p: probability of taking the overlay branch.
        # p: probability of applying the transform at all.
        super().__init__(p)
        self.img_paths = img_paths
        self.opacity_lower = opacity_lower
        self.size_lower = size_lower
        self.size_upper = size_upper
        self.input_size = input_size
        self.moderate_scale_lower = moderate_scale_lower
        self.hard_scale_lower = hard_scale_lower
        self.overlay_p = overlay_p

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:

        if random.uniform(0.0, 1.0) < self.overlay_p:
            # 50/50: input over random background, or random image over input.
            if random.uniform(0.0, 1.0) > 0.5:
                background = Image.open(random.choice(self.img_paths))
                overlay = image
            else:
                background = image
                overlay = Image.open(random.choice(self.img_paths))

            overlay_size = random.uniform(self.size_lower, self.size_upper)
            image = overlay_image(
                background,
                overlay=overlay,
                opacity=random.uniform(self.opacity_lower, 1.0),
                overlay_size=overlay_size,
                # Position chosen so the overlay stays fully inside the frame.
                x_pos=random.uniform(0.0, 1.0 - overlay_size),
                y_pos=random.uniform(0.0, 1.0 - overlay_size),
                metadata=metadata,
            )
            return transforms.RandomResizedCrop(self.input_size, scale=(self.moderate_scale_lower, 1.))(image)
        else:
            return transforms.RandomResizedCrop(self.input_size, scale=(self.hard_scale_lower, 1.))(image)
||||
|
|
||||
|
|
||||
|
class RandomEmojiOverlay(BaseTransform):
    """Overlays a random emoji at a random position and size.

    NOTE(review): this intentionally shadows `augly.image.RandomEmojiOverlay`
    imported at the top of the file — this local variant randomizes the
    emoji size and position on every call; confirm the shadowing is intended.
    """

    def __init__(
        self,
        emoji_directory: str = SMILEY_EMOJI_DIR,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        # emoji_directory: directory whose files are the candidate emojis.
        # opacity: opacity of the pasted emoji.
        # p: probability of applying this transform.
        super().__init__(p)
        self.emoji_directory = emoji_directory
        self.emoji_paths = pathmgr.ls(emoji_directory)
        self.opacity = opacity

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        """Paste one randomly chosen emoji onto `image`."""
        emoji_path = random.choice(self.emoji_paths)
        return overlay_emoji(
            image,
            emoji_path=os.path.join(self.emoji_directory, emoji_path),
            opacity=self.opacity,
            emoji_size=random.uniform(0.1, 0.3),
            x_pos=random.uniform(0.0, 1.0),
            y_pos=random.uniform(0.0, 1.0),
            metadata=metadata,
        )
||||
|
|
||||
|
|
||||
|
class RandomEdgeEnhance(BaseTransform):
    """Applies a fixed PIL edge-enhancement filter to the image."""

    def __init__(self, mode=ImageFilter.EDGE_ENHANCE, p: float = 1.0):
        """mode: a PIL filter (EDGE_ENHANCE or EDGE_ENHANCE_MORE);
        p: probability of applying the transform."""
        super().__init__(p)
        self.mode = mode

    def apply_transform(self, image: Image.Image, *args) -> Image.Image:
        # Extra positional args (metadata, bboxes, ...) are ignored.
        return image.filter(self.mode)
||||
|
|
||||
|
|
||||
|
class ShuffledAug:
    """Applies every augmentation in `aug_list`, in a freshly shuffled
    order on each call."""

    def __init__(self, aug_list):
        self.aug_list = aug_list

    def __call__(self, x):
        # random.sample draws without replacement: a random permutation,
        # each op applied exactly once.
        order = random.sample(self.aug_list, len(self.aug_list))
        result = x
        for transform in order:
            result = transform(result)
        return result
||||
|
|
||||
|
|
||||
|
def convert2rgb(x):
    """Force a PIL image into RGB mode (normalizes palette/alpha modes)."""
    return x.convert('RGB')
||||
|
|
||||
|
|
||||
|
def get_dataset(train_data_dir):
    """Build the ISC training dataset with moderate + hard augmentations.

    Each sample yields `ncrops` views of one image: one moderately
    augmented view plus heavily augmented view(s) (see NCropsTransform).

    Args:
        train_data_dir: directory scanned recursively for .jpg files.

    Returns:
        An ISCDataset over all found images.
    """
    input_size = 256
    ncrops = 2

    # NOTE(review): the backbone is created only to read its normalization
    # mean/std from default_cfg below; this downloads pretrained weights as
    # a side effect — consider a lighter way to obtain the config.
    backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k', features_only=True, pretrained=True)

    train_paths = list(Path(train_data_dir).glob('**/*.jpg'))

    # Mild augmentation used for the "query" view.
    aug_moderate = [
        transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std'])
    ]

    # Pool of heavy augmentations, applied in random order via ShuffledAug.
    aug_list = [
        transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
        RandomPixelization(p=0.25),
        ShufflePixels(factor=0.1, p=0.25),
        OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
        transforms.RandomGrayscale(p=0.25),
        RandomBlur(p=0.25),
        transforms.RandomPerspective(p=0.25),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        RandomOverlayText(p=0.25),
        RandomEmojiOverlay(p=0.25),  # the local class defined above, not augly's
        OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
               RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
              p=0.25),
    ]
    aug_hard = [
        RandomRotation(p=0.25),
        RandomOverlayImageAndResizedCrop(
            train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
            input_size=input_size, moderate_scale_lower=0.7, hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
        ),
        ShuffledAug(aug_list),
        # Force RGB before ToTensor — presumably some upstream ops can
        # produce non-RGB modes; confirm.
        convert2rgb,
        transforms.ToTensor(),
        transforms.RandomErasing(value='random', p=0.25),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']),
    ]

    train_dataset = ISCDataset(
        train_paths,
        NCropsTransform(
            transforms.Compose(aug_moderate),
            transforms.Compose(aug_hard),
            ncrops,
        ),
    )
    return train_dataset
Loading…
Reference in new issue