towhee / distill-and-select
203 lines · 8.0 KiB
import torch
import torch.nn as nn
import torch.nn.functional as F

# Wildcard import supplying the building blocks used below: Attention,
# NetVLAD, BinarizationLayer, ChamferSimilarity, VideoComperator,
# SimilarityRegularizationLoss and check_dims.
from . import *
from einops import rearrange

# URLs of the pretrained student checkpoints.
model_urls = {
    'dns_cg_student': 'https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth',
    'dns_fg_att_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_att_student.pth',
    'dns_fg_bin_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_bin_student.pth',
}

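# Illustrative only (not part of the original file): a minimal sketch of
# loading one of the checkpoints above by hand. It simply mirrors the
# `pretrained=True` branches below; each checkpoint is a dict whose 'model'
# key holds the state_dict.
#
#   state = torch.hub.load_state_dict_from_url(model_urls['dns_cg_student'])
#   model = CoarseGrainedStudent(pretrained=False)
#   model.load_state_dict(state['model'])
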
class CoarseGrainedStudent(nn.Module):

    def __init__(self,
                 dims=512,
                 attention=True,
                 transformer=True,
                 transformer_heads=8,
                 transformer_feedforward_dims=2048,
                 transformer_layers=1,
                 netvlad=True,
                 netvlad_clusters=64,
                 netvlad_outdims=1024,
                 pretrained=False,
                 **kwargs
                 ):
        super(CoarseGrainedStudent, self).__init__()
        self.student_type = 'cg'

        if attention:
            self.attention = Attention(dims, norm=False)
        if transformer:
            encoder_layer = nn.TransformerEncoderLayer(dims,
                                                       transformer_heads,
                                                       transformer_feedforward_dims)
            self.transformer = nn.TransformerEncoder(encoder_layer,
                                                     transformer_layers,
                                                     nn.LayerNorm(dims))
            self.apply(self._init_weights)

        if netvlad:
            self.netvlad = NetVLAD(dims, netvlad_clusters, outdims=netvlad_outdims)

        if pretrained:
            self.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    model_urls['dns_cg_student'])['model'])

    def get_network_name(self,):
        return '{}_student'.format(self.student_type)

    def calculate_video_similarity(self, query, target):
        # Dot product between L2-normalized video-level embeddings,
        # i.e. cosine similarity.
        return torch.mm(query, torch.transpose(target, 0, 1))

    def index_video(self, x, mask=None):
        # x: [N, T, R, D] (videos, frames, regions, feature dims); see check_dims.
        x, mask = check_dims(x, mask)

        if hasattr(self, 'attention'):
            x, a = self.attention(x)
        # Aggregate region features into one descriptor per frame.
        x = torch.sum(x, 2)
        x = F.normalize(x, p=2, dim=-1)

        if hasattr(self, 'transformer'):
            # nn.TransformerEncoder expects [T, N, D].
            x = x.permute(1, 0, 2)
            x = self.transformer(x, src_key_padding_mask=
                                 (1 - mask).bool() if mask is not None else None)
            x = x.permute(1, 0, 2)

        if hasattr(self, 'netvlad'):
            # NetVLAD pooling over time; input reshaped to [N, D, T, 1].
            x = x.unsqueeze(2).permute(0, 3, 1, 2)
            x = self.netvlad(x, mask=mask)
        else:
            # Fallback: masked (or plain) average pooling over time.
            if mask is not None:
                x = x.masked_fill((1 - mask.unsqueeze(-1)).bool(), 0.0)
                x = torch.sum(x, 1) / torch.sum(mask, 1, keepdim=True)
            else:
                x = torch.mean(x, 1)
        return F.normalize(x, p=2, dim=-1)

    def forward(self, anchors, positives, negatives,
                anchors_masks=None, positive_masks=None, negative_masks=None):
        pos_pairs = torch.sum(anchors * positives, 1, keepdim=True)
        neg_pairs = torch.sum(anchors * negatives, 1, keepdim=True)
        return pos_pairs, neg_pairs, None

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


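# Illustrative only (not part of the original file): a minimal smoke-test
# sketch for the coarse-grained student. The helper name `_demo_coarse_grained_student`
# is hypothetical; it assumes the layer implementations pulled in by the
# wildcard import above and random features of shape [N, T, R, D].
def _demo_coarse_grained_student():
    cg = CoarseGrainedStudent(pretrained=False)
    videos = torch.randn(2, 30, 9, 512)  # two videos, 30 frames, 9 regions
    embeddings = cg.index_video(videos)  # -> [2, netvlad_outdims]
    return cg.calculate_video_similarity(embeddings, embeddings)  # -> [2, 2]

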
class FineGrainedStudent(nn.Module):

    def __init__(self,
                 dims=512,
                 attention=False,
                 binarization=False,
                 pretrained=False,
                 **kwargs
                 ):
        super(FineGrainedStudent, self).__init__()
        self.student_type = 'fg'
        if attention and binarization:
            raise Exception('Can\'t use \'attention=True\' and \'binarization=True\' at the same time. '
                            'Select one of the two options.')
        elif binarization:
            self.fg_type = 'bin'
            self.binarization = BinarizationLayer(dims)
        elif attention:
            self.fg_type = 'att'
            self.attention = Attention(dims, norm=False)
        else:
            self.fg_type = 'none'

        # Frame-to-frame Chamfer similarity over the region axes.
        self.f2f_sim = ChamferSimilarity(axes=[3, 2])

        self.visil_head = VideoComperator()
        self.htanh = nn.Hardtanh()

        # Video-to-video Chamfer similarity over the frame axes.
        self.v2v_sim = ChamferSimilarity(axes=[2, 1])

        self.sim_criterion = SimilarityRegularizationLoss()

        if pretrained:
            if not (attention or binarization):
                raise Exception('No pretrained model provided for the selected settings. '
                                'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
            self.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    model_urls['dns_fg_{}_student'.format(self.fg_type)])['model'])

    def get_network_name(self,):
        return '{}_{}_student'.format(self.student_type, self.fg_type)

    def frame_to_frame_similarity(self, query, target, query_mask=None, target_mask=None, batched=False):
        d = target.shape[-1]
        sim_mask = None
        if batched:
            # Region-level similarities between aligned query/target pairs.
            sim = torch.einsum('biok,bjpk->biopj', query, target)
            sim = self.f2f_sim(sim)
            if query_mask is not None and target_mask is not None:
                sim_mask = torch.einsum('bik,bjk->bij', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
        else:
            # All query videos against all target videos.
            sim = torch.einsum('aiok,bjpk->aiopjb', query, target)
            sim = self.f2f_sim(sim)
            sim = rearrange(sim, 'a i j b -> (a b) i j')
            if query_mask is not None and target_mask is not None:
                sim_mask = torch.einsum('aik,bjk->aijb', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
                sim_mask = rearrange(sim_mask, 'a i j b -> (a b) i j')
        if self.fg_type == 'bin':
            # Rescale dot products of binarized features by the feature dimension.
            sim /= d
        if sim_mask is not None:
            sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
        return sim, sim_mask

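    # Illustrative note (not part of the original file): a shape walkthrough
    # for the non-batched branch, assuming query [A, Tq, R, D] and target
    # [B, Tt, R, D]:
    #   einsum('aiok,bjpk->aiopjb')        -> [A, Tq, R, R, Tt, B]
    #   ChamferSimilarity(axes=[3, 2])     -> [A, Tq, Tt, B]  (reduce one region
    #                                         axis by max, the other by mean)
    #   rearrange('a i j b -> (a b) i j')  -> [A*B, Tq, Tt]
    # i.e. one frame-to-frame similarity matrix per query/target pair.
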
    def calculate_video_similarity(self, query, target, query_mask=None, target_mask=None):
        query, query_mask = check_dims(query, query_mask)
        target, target_mask = check_dims(target, target_mask)

        sim, sim_mask = self.similarity_matrix(query, target, query_mask, target_mask)
        sim = self.v2v_sim(sim, sim_mask)

        return sim.view(query.shape[0], target.shape[0])

    def similarity_matrix(self, query, target, query_mask=None, target_mask=None):
        query, query_mask = check_dims(query, query_mask)
        target, target_mask = check_dims(target, target_mask)

        sim, sim_mask = self.frame_to_frame_similarity(query, target, query_mask, target_mask)
        sim, sim_mask = self.visil_head(sim, sim_mask)
        return self.htanh(sim), sim_mask

    def index_video(self, x, mask=None):
        # Apply the selected feature transformation (binarization or attention).
        if self.fg_type == 'bin':
            x = self.binarization(x)
        elif self.fg_type == 'att':
            x, a = self.attention(x)
        if mask is not None:
            x = x.masked_fill((1 - mask).bool().unsqueeze(-1).unsqueeze(-1), 0.0)
        return x

    def forward(self, anchors, positives, negatives,
                anchors_masks, positive_masks, negative_masks):
        pos_sim, pos_mask = self.frame_to_frame_similarity(
            anchors, positives, anchors_masks, positive_masks, batched=True)
        neg_sim, neg_mask = self.frame_to_frame_similarity(
            anchors, negatives, anchors_masks, negative_masks, batched=True)
        sim, sim_mask = torch.cat([pos_sim, neg_sim], 0), torch.cat([pos_mask, neg_mask], 0)

        sim, sim_mask = self.visil_head(sim, sim_mask)
        loss = self.sim_criterion(sim)
        sim = self.htanh(sim)
        sim = self.v2v_sim(sim, sim_mask)

        pos_pair, neg_pair = torch.chunk(sim.unsqueeze(-1), 2, dim=0)
        return pos_pair, neg_pair, loss
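

# Illustrative only (not part of the original file): a minimal smoke-test
# sketch for the fine-grained attention student. The helper name
# `_demo_fine_grained_student` is hypothetical; it assumes the layers from
# the wildcard import and random [N, T, R, D] features.
def _demo_fine_grained_student():
    fg = FineGrainedStudent(attention=True, pretrained=False)
    query = torch.randn(1, 20, 9, 512)
    target = torch.randn(1, 32, 9, 512)
    q_feat = fg.index_video(query)   # region-attention-weighted features
    t_feat = fg.index_video(target)
    return fg.calculate_video_similarity(q_feat, t_feat)  # -> [1, 1]


if __name__ == '__main__':
    print(_demo_coarse_grained_student())
    print(_demo_fine_grained_student())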