towhee
/
distill-and-select
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
106 lines
3.4 KiB
106 lines
3.4 KiB
import torch
|
|
import torch.nn as nn
|
|
|
|
from . import *
|
|
|
|
|
|
# URLs of pretrained selector checkpoints, fetched via
# torch.hub.load_state_dict_from_url in SelectorNetwork.__init__.
# One checkpoint per head variant: attention ('att') or binarization ('bin').
model_urls = {
    'dns_selector_cg-fg_att': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_att.pth',
    'dns_selector_cg-fg_bin': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_bin.pth',
}
|
|
|
|
|
|
class MetadataModel(nn.Module):
    """MLP that maps a feature vector to a single score in [0, 1].

    Architecture: an input projection block, then ``num_layers`` hidden
    blocks (each Linear -> BatchNorm1d -> ReLU -> Dropout), and finally a
    Linear -> Sigmoid head producing one scalar per sample.
    """

    def __init__(self,
                 input_size,
                 hidden_size=100,
                 num_layers=1
                 ):
        """
        Args:
            input_size (int): dimensionality of the input features.
            hidden_size (int): width of every hidden layer.
            num_layers (int): number of *additional* hidden blocks stacked
                after the input projection block.
        """
        super().__init__()

        # Bias is omitted on the hidden Linear layers because the following
        # BatchNorm1d re-centers activations, making the bias redundant.
        layers = [
            nn.Linear(input_size, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(),
        ]

        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_size, hidden_size, bias=False),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(),
            ])

        # Scalar head: sigmoid keeps the output interpretable as a score.
        layers.extend([nn.Linear(hidden_size, 1),
                       nn.Sigmoid()])
        self.model = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Linear weight with Xavier-uniform init."""
        for m in self.model.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``x``."""
        return self.model(x)
|
|
|
|
|
|
class SelectorNetwork(nn.Module):
    """Selector model combining a similarity head with a metadata MLP.

    Optionally loads one of the pretrained checkpoints from ``model_urls``
    (attention or binarization variant).
    """

    def __init__(self,
                 dims=512,
                 hidden_size=100,
                 num_layers=1,
                 attention=False,
                 binarization=False,
                 pretrained=False,
                 **kwargs
                 ):
        super(SelectorNetwork, self).__init__()
        self.attention = Attention(dims, norm=False)
        self.visil_head = VideoComperator()
        self.mlp = MetadataModel(3, hidden_size, num_layers)

        if pretrained:
            if not (attention or binarization):
                raise Exception("No pretrained model provided for the selected settings. "
                                "Use either 'attention=True' or 'binarization=True' to load a pretrained model.")
            # attention takes precedence when both flags are set,
            # matching the original if/elif ordering.
            checkpoint_key = 'dns_selector_cg-fg_att' if attention else 'dns_selector_cg-fg_bin'
            checkpoint = torch.hub.load_state_dict_from_url(model_urls[checkpoint_key])
            self.load_state_dict(checkpoint['model'])

    def get_network_name(self):
        """Identifier string for this network."""
        return 'selector_network'

    def index_video(self, x, mask=None):
        """Reduce a video's frame features to one similarity-based scalar per sample.

        Returns a tensor of shape ``(batch, 1)``. When ``mask`` is given,
        padded positions are zeroed and the mean is taken over valid
        frame pairs only.
        """
        x, mask = check_dims(x, mask)
        similarity = self.frame_to_frame_similarity(x)

        pair_mask = None
        if mask is not None:
            # Outer product of the per-frame mask -> valid frame-pair mask.
            pair_mask = torch.einsum("bik,bjk->bij", mask.unsqueeze(-1), mask.unsqueeze(-1))
            similarity = similarity.masked_fill((1 - pair_mask).bool(), 0.0)

        similarity, pair_mask = self.visil_head(similarity, pair_mask)

        if pair_mask is None:
            similarity = torch.mean(similarity, [1, 2])
        else:
            # Masked mean: zero invalid pairs, divide by valid-pair count.
            similarity = similarity.masked_fill((1 - pair_mask).bool(), 0.0)
            similarity = torch.sum(similarity, [1, 2]) / torch.sum(pair_mask, [1, 2])

        return similarity.unsqueeze(-1)

    def frame_to_frame_similarity(self, x):
        """Attention-weighted similarity between all frame pairs of ``x``."""
        weighted, _ = self.attention(x)
        pairwise = torch.einsum("biok,bjpk->biopj", weighted, weighted)
        return torch.mean(pairwise, [2, 3])

    def forward(self, x):
        """Score metadata features ``x`` with the selector MLP."""
        return self.mlp(x)
|