logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

473 lines
16 KiB

import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import os
try:
import pickle
import random
from pathlib import Path
from PIL import Image, ImageFilter
import timm
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from augly.image import (EncodingQuality, OneOf,
RandomBlur, RandomEmojiOverlay, RandomPixelization,
RandomRotation, ShufflePixels)
from augly.image.functional import overlay_emoji, overlay_image, overlay_text
from augly.image.transforms import BaseTransform
from augly.utils import pathmgr
from augly.utils.base_paths import MODULE_BASE_DIR
from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist
except:
pass
def dataclass_from_dict(klass, d):
try:
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
except:
return d # Not a dataclass field
@dataclass
class TrainingArguments:
output_dir: Optional[str] = field(
default='./output', metadata={"help": "output checkpoint saving dir."}
)
distributed: Optional[bool] = field(
default=False, metadata={"help": "If true, use all gpu in your machine, else use only one gpu."}
)
gpu: Optional[int] = field(
default=0, metadata={"help": "When distributed is False, use this gpu No. in your machine."}
)
start_epoch: Optional[int] = field(
default=0, metadata={"help": "Start epoch number."}
)
epochs: Optional[int] = field(
default=6, metadata={"help": "End epoch number."}
)
batch_size: Optional[int] = field(
default=128, metadata={"help": "Total batch size in all gpu."}
)
init_lr: Optional[float] = field(
default=0.1, metadata={"help": "init learning rate in SGD."}
)
train_data_dir: Optional[str] = field(default=None,
metadata={"help": "The dir containing all training data image files."})
def train_isc(model, training_args):
from towhee.trainer.training_config import get_dataclasses_help
print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
if training_args.distributed is True:
ngpus_per_node = torch.cuda.device_count()
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, model, training_args))
else:
main_worker(training_args.gpu, 1, model, training_args)
def main_worker(gpu, ngpus_per_node, model, training_args):
rank = gpu
world_size = ngpus_per_node
distributed = training_args.distributed
if distributed:
dist.init_process_group(backend='nccl', init_method='tcp://localhost:10001',
world_size=world_size, rank=rank)
torch.distributed.barrier(device_ids=[rank])
# infer learning rate before changing batch size
init_lr = training_args.init_lr
if distributed:
# Apply SyncBN
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
torch.cuda.set_device(gpu)
model.cuda(gpu)
batch_size = training_args.batch_size
batch_size = int(batch_size / ngpus_per_node)
workers = 8
if distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
memory_size = 20000
pos_margin = 0.0
neg_margin = 1.0
loss_fn = losses.ContrastiveLoss(pos_margin=pos_margin, neg_margin=neg_margin)
loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=256, memory_size=memory_size)
loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or "gain" in name:
no_decay.append(param)
else:
decay.append(param)
optim_params = [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': 1e-6}
]
momentum = 0.9
optimizer = torch.optim.SGD(optim_params, init_lr, momentum=momentum)
scaler = torch.cuda.amp.GradScaler()
cudnn.benchmark = True
train_dataset = get_dataset(training_args.train_data_dir)
if distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
num_workers=workers, pin_memory=True, sampler=train_sampler, drop_last=True)
start_epoch = training_args.start_epoch
epochs = training_args.epochs
for epoch in range(start_epoch, epochs):
if distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank)
if not distributed or (distributed and rank == 0):
if not os.path.exists(training_args.output_dir):
os.mkdir(training_args.output_dir)
torch.save({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'training_args': training_args,
}, f'{training_args.output_dir}/checkpoint_epoch{epoch:04d}.pth.tar')
def train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank):
losses = AverageMeter('Loss', ':.4f')
progress = tqdm(train_loader, desc=f'epoch {epoch}', leave=False, total=len(train_loader))
model.train()
for labels, images in progress:
optimizer.zero_grad()
labels = labels.cuda(rank, non_blocking=True)
images = torch.cat([
image for image in images
], dim=0).cuda(rank, non_blocking=True)
labels = torch.tile(labels, dims=(2,))
with torch.cuda.amp.autocast():
embeddings = model(images)
loss = loss_fn(embeddings, labels)
losses.update(loss.item(), images.size(0))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
progress.set_postfix(loss=losses.avg)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
class ISCDataset(torch.utils.data.Dataset):
def __init__(
self,
paths,
transforms,
):
self.paths = paths
self.transforms = transforms
def __len__(self):
return len(self.paths)
def __getitem__(self, i):
image = Image.open(self.paths[i])
image = self.transforms(image)
return i, image
class NCropsTransform:
"""Take n random crops of one image as the query and key."""
def __init__(self, aug_moderate, aug_hard, ncrops=2):
self.aug_moderate = aug_moderate
self.aug_hard = aug_hard
self.ncrops = ncrops
def __call__(self, x):
return [self.aug_moderate(x)] + [self.aug_hard(x) for _ in range(self.ncrops - 1)]
class GaussianBlur(object):
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
def __init__(self, sigma=[.1, 2.]):
self.sigma = sigma
def __call__(self, x):
sigma = random.uniform(self.sigma[0], self.sigma[1])
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
return x
class RandomOverlayText(BaseTransform):
def __init__(
self,
opacity: float = 1.0,
p: float = 1.0,
):
super().__init__(p)
self.opacity = opacity
with open(Path(FONTS_DIR) / FONT_LIST_PATH) as f:
font_list = [s.strip() for s in f.readlines()]
blacklist = [
'TypeMyMusic',
'PainttheSky-Regular',
]
self.font_list = [
f for f in font_list
if all(_ not in f for _ in blacklist)
]
self.font_lens = []
for ff in self.font_list:
font_file = Path(MODULE_BASE_DIR) / ff.replace('.ttf', '.pkl')
with open(font_file, 'rb') as f:
self.font_lens.append(len(pickle.load(f)))
def apply_transform(
self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
) -> Image.Image:
i = random.randrange(0, len(self.font_list))
kwargs = dict(
font_file=Path(MODULE_BASE_DIR) / self.font_list[i],
font_size=random.uniform(0.1, 0.3),
color=[random.randrange(0, 256) for _ in range(3)],
x_pos=random.uniform(0.0, 0.5),
metadata=metadata,
opacity=self.opacity,
)
try:
for j in range(random.randrange(1, 3)):
if j == 0:
y_pos = random.uniform(0.0, 0.5)
else:
y_pos += kwargs['font_size']
image = overlay_text(
image,
text=[random.randrange(0, self.font_lens[i]) for _ in range(random.randrange(5, 10))],
y_pos=y_pos,
**kwargs,
)
return image
except OSError:
return image
class RandomOverlayImageAndResizedCrop(BaseTransform):
def __init__(
self,
img_paths: List[Path],
opacity_lower: float = 0.5,
size_lower: float = 0.4,
size_upper: float = 0.6,
input_size: int = 224,
moderate_scale_lower: float = 0.7,
hard_scale_lower: float = 0.15,
overlay_p: float = 0.05,
p: float = 1.0,
):
super().__init__(p)
self.img_paths = img_paths
self.opacity_lower = opacity_lower
self.size_lower = size_lower
self.size_upper = size_upper
self.input_size = input_size
self.moderate_scale_lower = moderate_scale_lower
self.hard_scale_lower = hard_scale_lower
self.overlay_p = overlay_p
def apply_transform(
self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
) -> Image.Image:
if random.uniform(0.0, 1.0) < self.overlay_p:
if random.uniform(0.0, 1.0) > 0.5:
background = Image.open(random.choice(self.img_paths))
overlay = image
else:
background = image
overlay = Image.open(random.choice(self.img_paths))
overlay_size = random.uniform(self.size_lower, self.size_upper)
image = overlay_image(
background,
overlay=overlay,
opacity=random.uniform(self.opacity_lower, 1.0),
overlay_size=overlay_size,
x_pos=random.uniform(0.0, 1.0 - overlay_size),
y_pos=random.uniform(0.0, 1.0 - overlay_size),
metadata=metadata,
)
return transforms.RandomResizedCrop(self.input_size, scale=(self.moderate_scale_lower, 1.))(image)
else:
return transforms.RandomResizedCrop(self.input_size, scale=(self.hard_scale_lower, 1.))(image)
class RandomEmojiOverlay(BaseTransform):
def __init__(
self,
emoji_directory: str = SMILEY_EMOJI_DIR,
opacity: float = 1.0,
p: float = 1.0,
):
super().__init__(p)
self.emoji_directory = emoji_directory
self.emoji_paths = pathmgr.ls(emoji_directory)
self.opacity = opacity
def apply_transform(
self, image: Image.Image, metadata: Optional[List[Dict[str, Any]]] = None, bboxes=None, bbox_format=None
) -> Image.Image:
emoji_path = random.choice(self.emoji_paths)
return overlay_emoji(
image,
emoji_path=os.path.join(self.emoji_directory, emoji_path),
opacity=self.opacity,
emoji_size=random.uniform(0.1, 0.3),
x_pos=random.uniform(0.0, 1.0),
y_pos=random.uniform(0.0, 1.0),
metadata=metadata,
)
class RandomEdgeEnhance(BaseTransform):
def __init__(
self,
mode=ImageFilter.EDGE_ENHANCE,
p: float = 1.0,
):
super().__init__(p)
self.mode = mode
def apply_transform(self, image: Image.Image, *args) -> Image.Image:
return image.filter(self.mode)
class ShuffledAug:
def __init__(self, aug_list):
self.aug_list = aug_list
def __call__(self, x):
# without replacement
shuffled_aug_list = random.sample(self.aug_list, len(self.aug_list))
for op in shuffled_aug_list:
x = op(x)
return x
def convert2rgb(x):
return x.convert('RGB')
def get_dataset(train_data_dir):
input_size = 256
ncrops = 2
backbone = timm.create_model('tf_efficientnetv2_m_in21ft1k', features_only=True, pretrained=True)
train_paths = list(Path(train_data_dir).glob('**/*.jpg'))
aug_moderate = [
transforms.RandomResizedCrop(input_size, scale=(0.7, 1.)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std'])
]
aug_list = [
transforms.ColorJitter(0.7, 0.7, 0.7, 0.2),
RandomPixelization(p=0.25),
ShufflePixels(factor=0.1, p=0.25),
OneOf([EncodingQuality(quality=q) for q in [10, 20, 30, 50]], p=0.25),
transforms.RandomGrayscale(p=0.25),
RandomBlur(p=0.25),
transforms.RandomPerspective(p=0.25),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.1),
RandomOverlayText(p=0.25),
RandomEmojiOverlay(p=0.25),
OneOf([RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE),
RandomEdgeEnhance(mode=ImageFilter.EDGE_ENHANCE_MORE)],
p=0.25),
]
aug_hard = [
RandomRotation(p=0.25),
RandomOverlayImageAndResizedCrop(
train_paths, opacity_lower=0.6, size_lower=0.4, size_upper=0.6,
input_size=input_size, moderate_scale_lower=0.7, hard_scale_lower=0.15, overlay_p=0.05, p=1.0,
),
ShuffledAug(aug_list),
convert2rgb,
transforms.ToTensor(),
transforms.RandomErasing(value='random', p=0.25),
transforms.Normalize(mean=backbone.default_cfg['mean'], std=backbone.default_cfg['std']),
]
train_dataset = ISCDataset(
train_paths,
NCropsTransform(
transforms.Compose(aug_moderate),
transforms.Compose(aug_hard),
ncrops,
),
)
return train_dataset