import torch
import torch.nn as nn
import torch.nn.functional as F
from . import *  # provides Attention, NetVLAD, BinarizationLayer, ChamferSimilarity, VideoComperator, SimilarityRegularizationLoss, check_dims
from einops import rearrange

# URLs of the pretrained student checkpoints
model_urls = {
    'dns_cg_student': 'https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth',
    'dns_fg_att_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_att_student.pth',
    'dns_fg_bin_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_bin_student.pth',
}

# Coarse-grained student: indexes each video as a single global descriptor,
# so video-to-video similarity reduces to a dot product.
class CoarseGrainedStudent(nn.Module):

    def __init__(self,
                 dims=512,
                 attention=True,
                 transformer=True,
                 transformer_heads=8,
                 transformer_feedforward_dims=2048,
                 transformer_layers=1,
                 netvlad=True,
                 netvlad_clusters=64,
                 netvlad_outdims=1024,
                 pretrained=False,
                 **kwargs
                 ):
        super(CoarseGrainedStudent, self).__init__()
        self.student_type = 'cg'

        if attention:
            self.attention = Attention(dims, norm=False)
        if transformer:
            encoder_layer = nn.TransformerEncoderLayer(dims,
                                                       transformer_heads,
                                                       transformer_feedforward_dims)
            self.transformer = nn.TransformerEncoder(encoder_layer,
                                                     transformer_layers,
                                                     nn.LayerNorm(dims))
            self.apply(self._init_weights)
        if netvlad:
            self.netvlad = NetVLAD(dims, netvlad_clusters, outdims=netvlad_outdims)

        if pretrained:
            self.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    model_urls['dns_cg_student'])['model'])

    def get_network_name(self,):
        return '{}_student'.format(self.student_type)

    def calculate_video_similarity(self, query, target):
        # Dot products of L2-normalized video descriptors, i.e. cosine similarities
        return torch.mm(query, torch.transpose(target, 0, 1))

    def index_video(self, x, mask=None):
        x, mask = check_dims(x, mask)
        if hasattr(self, 'attention'):
            x, a = self.attention(x)
        x = torch.sum(x, 2)  # pool the region axis into per-frame descriptors
        x = F.normalize(x, p=2, dim=-1)
        if hasattr(self, 'transformer'):
            # nn.TransformerEncoder expects [frames, batch, dims]
            x = x.permute(1, 0, 2)
            x = self.transformer(x, src_key_padding_mask=(1 - mask).bool() if mask is not None else None)
            x = x.permute(1, 0, 2)
        if hasattr(self, 'netvlad'):
            x = x.unsqueeze(2).permute(0, 3, 1, 2)  # [batch, dims, frames, 1] for NetVLAD
            x = self.netvlad(x, mask=mask)
        else:
            # masked (or plain) average pooling over frames
            if mask is not None:
                x = x.masked_fill((1 - mask.unsqueeze(-1)).bool(), 0.0)
                x = torch.sum(x, 1) / torch.sum(mask, 1, keepdim=True)
            else:
                x = torch.mean(x, 1)
        return F.normalize(x, p=2, dim=-1)

    def forward(self, anchors, positives, negatives,
                anchors_masks=None, positive_masks=None, negative_masks=None):
        # Inputs are already-indexed video descriptors; the masks are unused here
        pos_pairs = torch.sum(anchors * positives, 1, keepdim=True)
        neg_pairs = torch.sum(anchors * negatives, 1, keepdim=True)
        return pos_pairs, neg_pairs, None

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
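
# Usage sketch (not part of the original file; the feature shapes are
# assumptions inferred from how index_video consumes its input):
#
#   model = CoarseGrainedStudent(pretrained=True).eval()
#   with torch.no_grad():
#       # region-level features: [batch, frames, regions, dims]
#       q = model.index_video(torch.randn(1, 32, 9, 512))
#       t = model.index_video(torch.randn(1, 64, 9, 512))
#       sims = model.calculate_video_similarity(q, t)  # [1, 1] cosine similarity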

# Fine-grained student: keeps frame-level descriptors and scores video pairs
# with a ViSiL-style comparator head.
class FineGrainedStudent(nn.Module):

    def __init__(self,
                 dims=512,
                 attention=False,
                 binarization=False,
                 pretrained=False,
                 **kwargs
                 ):
        super(FineGrainedStudent, self).__init__()
        self.student_type = 'fg'
        if attention and binarization:
            raise Exception('Can\'t use \'attention=True\' and \'binarization=True\' at the same time. '
                            'Select one of the two options.')
        elif binarization:
            self.fg_type = 'bin'
            self.binarization = BinarizationLayer(dims)
        elif attention:
            self.fg_type = 'att'
            self.attention = Attention(dims, norm=False)
        else:
            self.fg_type = 'none'

        # Chamfer similarity over the two region axes (frame-to-frame)
        # and over the two frame axes (video-to-video)
        self.f2f_sim = ChamferSimilarity(axes=[3, 2])
        self.visil_head = VideoComperator()
        self.htanh = nn.Hardtanh()
        self.v2v_sim = ChamferSimilarity(axes=[2, 1])
        self.sim_criterion = SimilarityRegularizationLoss()

        if pretrained:
            if not (attention or binarization):
                raise Exception('No pretrained model provided for the selected settings. '
                                'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
            self.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    model_urls['dns_fg_{}_student'.format(self.fg_type)])['model'])

    def get_network_name(self,):
        return '{}_{}_student'.format(self.student_type, self.fg_type)

    def frame_to_frame_similarity(self, query, target, query_mask=None, target_mask=None, batched=False):
        d = target.shape[-1]
        sim_mask = None
        if batched:
            # region-to-region dot products within aligned batch pairs
            sim = torch.einsum('biok,bjpk->biopj', query, target)
            sim = self.f2f_sim(sim)  # reduce region axes -> [batch, query_frames, target_frames]
            if query_mask is not None and target_mask is not None:
                sim_mask = torch.einsum('bik,bjk->bij', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
        else:
            # all query-target combinations: one similarity map per (query, target) pair
            sim = torch.einsum('aiok,bjpk->aiopjb', query, target)
            sim = self.f2f_sim(sim)
            sim = rearrange(sim, 'a i j b -> (a b) i j')
            if query_mask is not None and target_mask is not None:
                sim_mask = torch.einsum('aik,bjk->aijb', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
                sim_mask = rearrange(sim_mask, 'a i j b -> (a b) i j')
        if self.fg_type == 'bin':
            sim /= d  # rescale dot products of binarized descriptors to [-1, 1]
        if sim_mask is not None:
            sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
        return sim, sim_mask

    def calculate_video_similarity(self, query, target, query_mask=None, target_mask=None):
        query, query_mask = check_dims(query, query_mask)
        target, target_mask = check_dims(target, target_mask)

        sim, sim_mask = self.similarity_matrix(query, target, query_mask, target_mask)
        sim = self.v2v_sim(sim, sim_mask)
        return sim.view(query.shape[0], target.shape[0])

    def similarity_matrix(self, query, target, query_mask=None, target_mask=None):
        query, query_mask = check_dims(query, query_mask)
        target, target_mask = check_dims(target, target_mask)

        sim, sim_mask = self.frame_to_frame_similarity(query, target, query_mask, target_mask)
        sim, sim_mask = self.visil_head(sim, sim_mask)
        return self.htanh(sim), sim_mask

    def index_video(self, x, mask=None):
        # Per-frame descriptors: binarized, attention-weighted, or raw features
        if self.fg_type == 'bin':
            x = self.binarization(x)
        elif self.fg_type == 'att':
            x, a = self.attention(x)
        if mask is not None:
            x = x.masked_fill((1 - mask).bool().unsqueeze(-1).unsqueeze(-1), 0.0)
        return x

    def forward(self, anchors, positives, negatives,
                anchors_masks, positive_masks, negative_masks):
        pos_sim, pos_mask = self.frame_to_frame_similarity(
            anchors, positives, anchors_masks, positive_masks, batched=True)
        neg_sim, neg_mask = self.frame_to_frame_similarity(
            anchors, negatives, anchors_masks, negative_masks, batched=True)
        sim, sim_mask = torch.cat([pos_sim, neg_sim], 0), torch.cat([pos_mask, neg_mask], 0)

        sim, sim_mask = self.visil_head(sim, sim_mask)
        loss = self.sim_criterion(sim)  # regularization on raw comparator outputs
        sim = self.htanh(sim)
        sim = self.v2v_sim(sim, sim_mask)
        pos_pair, neg_pair = torch.chunk(sim.unsqueeze(-1), 2, dim=0)
        return pos_pair, neg_pair, loss
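
# Usage sketch (not part of the original file; shapes are illustrative
# assumptions): index frame-level descriptors, then score the pair with the
# ViSiL-style comparator.
#
#   model = FineGrainedStudent(attention=True, pretrained=True).eval()
#   with torch.no_grad():
#       q = model.index_video(torch.randn(1, 32, 9, 512))
#       t = model.index_video(torch.randn(1, 64, 9, 512))
#       sim = model.calculate_video_similarity(q, t)  # [1, 1] video-level score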