# This training script is adapted from https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment

# other requirements: torchvision, torchmetrics==0.7.0

def similarity_loss(y, tau):
    # Pairwise similarity logits between all embeddings in the batch,
    # scaled by the temperature tau.
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        # Drop the self-similarity entry, then log-softmax over the rest.
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        # Rows come in (original, augmented) pairs: for even i the positive
        # sits at i + 1 (index i after removing self); for odd i it sits at
        # i - 1 (unshifted).
        Ls.append(softmax[i if i % 2 == 0 else i - 1])
    Ls = torch.stack(Ls)

    loss = torch.sum(Ls) / -y.shape[0]
    return loss
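
# Illustrative sanity check (not part of the original script): rows of `y`
# are expected to interleave original and augmented segments along dim 0
# and to be L2-normalized, e.g.:
#
#     y = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
#     loss = similarity_loss(y, tau=0.05)  # scalar; lower when pairs match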

class SetEpochCallback(Callback):
    # Forwards the current epoch index to the dataloader so it can reshuffle
    # its segments at the start of every epoch.
    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)

class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []
        # Reshuffle the training data at the start of every epoch.
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tune...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()

        # Interleave original/augmented segments along dim 0, then apply
        # SpecAugment.
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)

        # Mixed-precision forward pass; the loss is scaled before backward
        # to avoid fp16 gradient underflow.
        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()

        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        # Use a larger minibatch on GPUs with more than ~11 GB of memory.
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        # Even rows are original segments, odd rows their augmented versions.
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # Compute the validation score on the GPU: the similarity of each
        # augmented embedding to its own original.
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # For each augmented embedding, count how many originals score at
        # least as high as its true match; rank 1 means the true original
        # is retrieved top-1.
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)

        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs
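
# Toy walk-through of the validation metric in `evaluate` (illustrative, not
# from the original script): with two validation segments, y_embed has four
# rows ordered [org0, aug0, org1, aug1].  If
#     sim(aug0, org0) = 0.9, sim(aug0, org1) = 0.4,
#     sim(aug1, org1) = 0.8, sim(aug1, org0) = 0.3,
# then each augmented row scores its own original highest, ranks == [1, 1],
# and the printed validate score is 2 / 2 = 1.0.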

def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)

    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)

    nnfp_trainer.train()
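
# Example entry point (illustrative only; the `NNFp` import and the config
# path are placeholders, not a verified API of this repo):
#
#     from .nnfp import NNFp            # hypothetical model class
#     model = NNFp(read_config('config.json')).to('cuda')
#     train_nnfp(model, 'config.json')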