## This training script is adapted from https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment

# other requirements: torchvision, torchmetrics==0.7.0


def similarity_loss(y, tau):
    """NT-Xent-style contrastive loss.

    Rows of ``y`` alternate (original, augmented) embeddings of the same
    segment; for each row, the positive is its paired view and every other
    row is a negative. ``tau`` is the softmax temperature.
    """
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        # Drop the self-similarity a[i, i], then index the positive pair:
        # after the drop, row i+1 shifts to index i (i even) while row i-1
        # stays at index i-1 (i odd).
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])
    Ls = torch.stack(Ls)
    loss = torch.sum(Ls) / -y.shape[0]
    return loss


class SetEpochCallback(Callback):
    """Propagates the current epoch to the dataloader so it can reshuffle."""

    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)


class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset,
                 train_dataloader, eval_dataloader, model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset,
                         train_dataloader, eval_dataloader, model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()  # for mixed-precision training
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluating before fine-tuning...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()
        # (batch, 2, ...) -> (2 * batch, ...): interleave original/augmented views.
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)
        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.lr_scheduler.step()
        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {'step_loss': lossnum,
                     'epoch_loss': np.mean(self.losses),
                     'epoch_metric': 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        # Embed the validation set, then measure top-1 retrieval accuracy:
        # an augmented segment counts as correct only when its own original
        # segment outranks every other candidate.
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640  # larger eval batches on GPUs with > 11 GB memory
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # Compute the validation score on GPU. First collect each augmented
        # embedding's similarity to its own original (the "self score")...
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # ...then count how many originals score at least as high; a rank of
        # exactly 1 means only the true match did.
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs


def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)

    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)
    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)
    nnfp_trainer.train()
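

# A quick sanity check for `similarity_loss` (an illustrative addition, not
# part of the original pfann script): input rows must alternate
# (original, augmented) views of the same segment, which is the layout
# `train_step` produces via `torch.flatten(x, 0, 1)`. Run it with
# `python -m <package>.<module>` so the relative imports above resolve;
# the package path here is a placeholder.
if __name__ == '__main__':
    demo_embeds = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
    print('sanity-check loss: %f' % similarity_loss(demo_embeds, 0.05).item())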