# This training script is written with reference to https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment

# other requirements: torchvision, torchmetrics==0.7.0

def similarity_loss(y, tau):
    # NT-Xent-style contrastive loss: y holds 2N embeddings ordered as
    # [org_0, aug_0, org_1, aug_1, ...]; each row's positive is its pair.
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        # drop the self-similarity a[i, i] before the softmax
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        # after removing index i, the paired sample sits at index i (even rows)
        # or i - 1 (odd rows)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])
    Ls = torch.stack(Ls)
    loss = torch.sum(Ls) / -y.shape[0]
    return loss
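
# A minimal sanity check for similarity_loss, assuming 2N L2-normalized
# embeddings arranged as [org_0, aug_0, org_1, aug_1, ...] (the shapes below
# are illustrative, not values from the config):
#
#     y = torch.nn.functional.normalize(torch.randn(16, 128), dim=1)
#     loss = similarity_loss(y, tau=0.05)   # scalar tensor, lower is better
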
class SetEpochCallback(Callback):
    """Forward the current epoch index to the dataloader at the start of each epoch."""

    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)
class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()  # gradient scaler for mixed-precision training
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tune...')
        self.evaluate(self.model, dict())
    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()
        # merge the (original, augmented) pair dimension into the batch dimension
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)
        with autocast():  # mixed-precision forward pass
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.lr_scheduler.step()
        lossnum = float(loss.item())
        # note: self.losses is never cleared, so epoch_loss is a running mean
        # over the whole fine-tuning run
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs
    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640  # larger eval batches on GPUs with more than ~11 GB
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        # embeddings alternate original / augmented
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)
        # compute validation score on GPU, 320 originals at a time:
        # self_score[i] is the similarity between augmented query i and its own original
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)
        # ranks[i] counts how many originals score at least as high as query i's own
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())  # queries whose own original ranks first
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs
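
# Dense-matrix sketch of the ranking metric computed in evaluate() above
# (names here are illustrative): with `aug` and `org` as the two embedding
# matrices,
#
#     A = aug @ org.T                       # full similarity matrix (N, N)
#     self_score = A.diagonal()             # each query vs its own original
#     ranks = (A >= self_score[:, None]).sum(dim=1)
#     top1 = (ranks == 1).float().mean()    # fraction ranked first
#
# evaluate() computes the same quantity in 320-column chunks to bound GPU memory.
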
def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)

    # training loader: shuffled, augmented (original, augmented) pairs
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    # validation loader: deterministic order with evaluation-time shift
    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)
    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)
    nnfp_trainer.train()
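
# Example entry point, as a minimal sketch: `NNFp` and the config path below
# are hypothetical placeholders, not defined in this file.
#
#     from .nnfp import NNFp
#     model = NNFp(read_config('nnfp_config.json')).to('cuda')
#     train_nnfp(model, 'nnfp_config.json')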