import dataclasses
import os
import pickle
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    import timm
    import torch
    import torch.backends.cudnn as cudnn
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn.parallel
    import torch.optim
    import torch.utils.data
    import torch.utils.data.distributed
    import torchvision.transforms as transforms
    from PIL import Image, ImageFilter
    from tqdm import tqdm

    from augly.image import (EncodingQuality, OneOf, RandomBlur,
                             RandomPixelization, RandomRotation, ShufflePixels)
    from augly.image.functional import overlay_emoji, overlay_image, overlay_text
    from augly.image.transforms import BaseTransform
    from augly.utils import pathmgr
    from augly.utils.base_paths import MODULE_BASE_DIR
    from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR

    from pytorch_metric_learning import losses
    from pytorch_metric_learning.utils import distributed as pml_dist
except ImportError:
    # Training dependencies are optional at import time; using the training
    # entry points below without them installed will fail at the point of use.
    pass


def dataclass_from_dict(klass, d):
    """Recursively build an instance of dataclass `klass` from a plain dict `d`.

    Non-dataclass values (ints, strings, and so on) are returned unchanged.
    """
    try:
        fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
    except (TypeError, KeyError):
        return d  # not a dataclass (or not a known field): return as-is
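

# For example, a plain dict of overrides becomes a TrainingArguments instance,
# with unspecified fields keeping their defaults:
#
#     args = dataclass_from_dict(TrainingArguments, {'epochs': 3, 'init_lr': 0.05})
#     assert args.epochs == 3 and args.batch_size == 128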


@dataclass
class TrainingArguments:
    output_dir: Optional[str] = field(
        default='./output', metadata={"help": "Directory in which to save checkpoints."}
    )
    distributed: Optional[bool] = field(
        default=False, metadata={"help": "If True, train on every GPU on this machine; otherwise use a single GPU."}
    )
    gpu: Optional[int] = field(
        default=0, metadata={"help": "Index of the GPU to use when `distributed` is False."}
    )
    start_epoch: Optional[int] = field(
        default=0, metadata={"help": "Epoch to start from."}
    )
    epochs: Optional[int] = field(
        default=6, metadata={"help": "Total number of training epochs."}
    )
    batch_size: Optional[int] = field(
        default=128, metadata={"help": "Total batch size summed over all GPUs."}
    )
    init_lr: Optional[float] = field(
        default=0.1, metadata={"help": "Initial learning rate for SGD."}
    )
    train_data_dir: Optional[str] = field(
        default=None, metadata={"help": "Directory containing the training images."}
    )


def train_isc(model, training_args):
    from towhee.trainer.training_config import get_dataclasses_help

    print('**** TrainingArguments ****')
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)
    if training_args.distributed:
        # Spawn one worker process per GPU.
        ngpus_per_node = torch.cuda.device_count()
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, model, training_args))
    else:
        main_worker(training_args.gpu, 1, model, training_args)
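

# Minimal usage sketch. The module does not fix the embedding model; any
# network mapping images to 256-d vectors (matching the CrossBatchMemory
# `embedding_size` below) will do. The model and path here are illustrative:
#
#     model = timm.create_model('tf_efficientnetv2_m_in21ft1k', num_classes=256)
#     train_isc(model, {'train_data_dir': '/path/to/jpegs', 'epochs': 6})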


def main_worker(gpu, ngpus_per_node, model, training_args):
    rank = gpu
    world_size = ngpus_per_node
    distributed = training_args.distributed
    if distributed:
        dist.init_process_group(backend='nccl', init_method='tcp://localhost:10001',
                                world_size=world_size, rank=rank)
        dist.barrier(device_ids=[rank])

    init_lr = training_args.init_lr  # used as-is; no linear batch-size scaling

    if distributed:
        # Synchronize BatchNorm statistics across GPUs.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # `batch_size` is the total across all GPUs; each worker takes its share.
    batch_size = int(training_args.batch_size / ngpus_per_node)
    workers = 8
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Contrastive loss backed by a cross-batch memory of recent embeddings,
    # which provides extra negatives beyond the current batch.
    memory_size = 20000
    pos_margin = 0.0
    neg_margin = 1.0
    loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin)
    loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256, memory_size=memory_size)
    if distributed:
        # Gather embeddings and labels from all workers before computing the loss.
        loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])

    # Exclude biases and other 1-D parameters (e.g. norm weights) from weight decay.
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    optim_params = [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': 1e-6},
    ]

    momentum = 0.9
    optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum)
    scaler = torch.cuda.amp.GradScaler()

    cudnn.benchmark = True

    train_dataset = get_dataset(training_args.train_data_dir)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    start_epoch = training_args.start_epoch
    epochs = training_args.epochs
    os.makedirs(training_args.output_dir, exist_ok=True)
    for epoch in range(start_epoch, epochs):
        if distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank)

        # Save a checkpoint from the main process only.
        if not distributed or rank == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'training_args': training_args,
            }, f'{training_args.output_dir}/checkpoint_epoch{epoch:04d}.pth.tar')
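

# Resuming is not wired in, but a saved checkpoint can be restored along these
# lines (file name hypothetical; under DDP the keys carry a 'module.' prefix):
#
#     ckpt = torch.load('./output/checkpoint_epoch0005.pth.tar', map_location='cpu')
#     model.load_state_dict(ckpt['state_dict'])
#     optimizer.load_state_dict(ckpt['optimizer'])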


def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
    loss_meter = AverageMeter('Loss', ':.4f')
    progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False, total=len(train_loader))

    model.train()

    for labels, images in progress:
        optimizer.zero_grad()

        labels = labels.cuda(rank, non_blocking=True)
        # `images` holds the ncrops view batches; stack them into one batch and
        # repeat the labels so both views of an image share its instance id.
        images = torch.cat(images, dim=0).cuda(rank, non_blocking=True)
        labels = torch.tile(labels, dims=(2,))

        with torch.cuda.amp.autocast():
            embeddings = model(images)
            loss = loss_fn(embeddings, labels)

        loss_meter.update(loss.item(), images.size(0))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=loss_meter.avg)


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ISCDataset(torch.utils.data.Dataset):
    """Image dataset whose sample index doubles as the instance label."""

    def __init__(self, paths, transforms):
        self.paths = paths
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        image = Image.open(self.paths[i])
        image = self.transforms(image)
        return i, image


class NCropsTransform:
    """Produce `ncrops` views of one image: a moderately augmented view
    followed by `ncrops - 1` hard-augmented views."""

    def __init__(self, aug_moderate, aug_hard, ncrops=2):
        self.aug_moderate = aug_moderate
        self.aug_hard = aug_hard
        self.ncrops = ncrops

    def __call__(self, x):
        return [self.aug_moderate(x)] + [self.aug_hard(x) for _ in range(self.ncrops - 1)]
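

# With the default ncrops=2, each dataset sample is [moderate_view, hard_view];
# the default collate function then yields a list of two [B, C, H, W] tensors,
# which train_one_epoch concatenates into a single [2B, C, H, W] batch.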


class GaussianBlur:
    """Gaussian blur augmentation from SimCLR (https://arxiv.org/abs/2002.05709)."""

    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=sigma))


class RandomOverlayText(BaseTransform):
    """Overlay one or two lines of random glyphs in a random font and color."""

    def __init__(
        self,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.opacity = opacity

        with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
            font_list = [s.strip() for s in f.readlines()]
        # Fonts excluded from sampling.
        blacklist = [
            'TypeMyMusic',
            'PainttheSky-Regular',
        ]
        self.font_list = [
            f for f in font_list
            if all(_ not in f for _ in blacklist)
        ]

        # Number of glyphs available per font (augly ships a .pkl charset
        # alongside each .ttf).
        self.font_lens = []
        for ff in self.font_list:
            font_file = Path(MODULE_BASE_DIR) / ff.replace('.ttf', '.pkl')
            with open(font_file, 'rb') as f:
                self.font_lens.append(len(pickle.load(f)))

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        i = random.randrange(0, len(self.font_list))
        kwargs = dict(
            font_file=Path(MODULE_BASE_DIR) / self.font_list[i],
            font_size=random.uniform(0.1, 0.3),
            color=[random.randrange(0, 256) for _ in range(3)],
            x_pos=random.uniform(0.0, 0.5),
            metadata=metadata,
            opacity=self.opacity,
        )
        try:
            for j in range(random.randrange(1, 3)):
                if j == 0:
                    y_pos = random.uniform(0.0, 0.5)
                else:
                    y_pos += kwargs['font_size']
                image = overlay_text(
                    image,
                    text=[random.randrange(0, self.font_lens[i]) for _ in range(random.randrange(5, 10))],
                    y_pos=y_pos,
                    **kwargs,
                )
            return image
        except OSError:
            # Some fonts fail to render; fall back to the current image.
            return image


class RandomOverlayImageAndResizedCrop(BaseTransform):
    """With probability `overlay_p`, paste the image onto (or under) another
    random training image before a moderate random resized crop; otherwise
    apply a harder-scale random resized crop alone."""

    def __init__(
        self,
        img_paths: List[Path],
        opacity_lower: float = 0.5,
        size_lower: float = 0.4,
        size_upper: float = 0.6,
        input_size: int = 224,
        moderate_scale_lower: float = 0.7,
        hard_scale_lower: float = 0.15,
        overlay_p: float = 0.05,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.img_paths = img_paths
        self.opacity_lower = opacity_lower
        self.size_lower = size_lower
        self.size_upper = size_upper
        self.input_size = input_size
        self.moderate_scale_lower = moderate_scale_lower
        self.hard_scale_lower = hard_scale_lower
        self.overlay_p = overlay_p

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        if random.uniform(0.0, 1.0) < self.overlay_p:
            # Randomly choose whether the input serves as overlay or background.
            if random.uniform(0.0, 1.0) > 0.5:
                background = Image.open(random.choice(self.img_paths))
                overlay = image
            else:
                background = image
                overlay = Image.open(random.choice(self.img_paths))

            overlay_size = random.uniform(self.size_lower, self.size_upper)
            image = overlay_image(
                background,
                overlay=overlay,
                opacity=random.uniform(self.opacity_lower, 1.0),
                overlay_size=overlay_size,
                x_pos=random.uniform(0.0, 1.0 - overlay_size),
                y_pos=random.uniform(0.0, 1.0 - overlay_size),
                metadata=metadata,
            )
            return transforms.RandomResizedCrop(self.input_size, scale=(self.moderate_scale_lower, 1.))(image)
        else:
            return transforms.RandomResizedCrop(self.input_size, scale=(self.hard_scale_lower, 1.))(image)


class RandomEmojiOverlay(BaseTransform):
    """Overlay a randomly chosen emoji at a random position and size."""

    def __init__(
        self,
        emoji_directory: str = SMILEY_EMOJI_DIR,
        opacity: float = 1.0,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.emoji_directory = emoji_directory
        self.emoji_paths = pathmgr.ls(emoji_directory)
        self.opacity = opacity

    def apply_transform(
        self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
    ) -> Image.Image:
        emoji_path = random.choice(self.emoji_paths)
        return overlay_emoji(
            image,
            emoji_path=os.path.join(self.emoji_directory, emoji_path),
            opacity=self.opacity,
            emoji_size=random.uniform(0.1, 0.3),
            x_pos=random.uniform(0.0, 1.0),
            y_pos=random.uniform(0.0, 1.0),
            metadata=metadata,
        )


class RandomEdgeEnhance(BaseTransform):
    def __init__(
        self,
        mode=ImageFilter.EDGE_ENHANCE,
        p: float = 1.0,
    ):
        super().__init__(p)
        self.mode = mode

    def apply_transform(self, image: Image.Image, *args) -> Image.Image:
        return image.filter(self.mode)


class ShuffledAug:
    """Apply every transform in `aug_list` exactly once, in random order."""

    def __init__(self, aug_list):
        self.aug_list = aug_list

    def __call__(self, x):
        # Sampling len(aug_list) items without replacement yields a permutation.
        shuffled_aug_list = random.sample(self.aug_list, len(self.aug_list))
        for op in shuffled_aug_list:
            x = op(x)
        return x


def convert2rgb(x):
    return x.convert('RGB')


def get_dataset(train_data_dir):
    input_size = 256
    ncrops = 2

    # The backbone is created only to read its normalization statistics from
    # `default_cfg`; no pretrained weights are needed for that.
    backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k', features_only=True, pretrained=False)

    train_paths = list(Path(train_data_dir).glob('**/*.jpg'))

    aug_moderate = [
        transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std'])
    ]

    aug_list = [
        transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
        RandomPixelization(p=0.25),
        ShufflePixels(factor=0.1, p=0.25),
        OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
        transforms.RandomGrayscale(p=0.25),
        RandomBlur(p=0.25),
        transforms.RandomPerspective(p=0.25),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        RandomOverlayText(p=0.25),
        RandomEmojiOverlay(p=0.25),
        OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
               RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
              p=0.25),
    ]
    aug_hard = [
        RandomRotation(p=0.25),
        RandomOverlayImageAndResizedCrop(
            train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
            input_size=input_size, moderate_scale_lower=0.7, hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
        ),
        ShuffledAug(aug_list),
        convert2rgb,
        transforms.ToTensor(),
        transforms.RandomErasing(value='random', p=0.25),
        transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']),
    ]

    train_dataset = ISCDataset(
        train_paths,
        NCropsTransform(
            transforms.Compose(aug_moderate),
            transforms.Compose(aug_hard),
            ncrops,
        ),
    )
    return train_dataset
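

# Quick shape check for the pipeline above (directory path is hypothetical):
#
#     ds = get_dataset('/path/to/jpegs')
#     label, views = ds[0]
#     # label: the sample index, used as the instance id
#     # views: [moderate_view, hard_view], each a [3, 256, 256] float tensor
#     #        for RGB inputs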