towhee
/
distill-and-select
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
214 lines
7.6 KiB
214 lines
7.6 KiB
import math
|
|
import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .constraints import L2Constrain
|
|
|
|
|
|
class VideoNormalizer(nn.Module):
    """Normalize raw uint8 video frames for an ImageNet-pretrained backbone.

    Scales pixel values from [0, 255] to [0, 1], applies the standard
    ImageNet channel mean/std, and moves channels to dim 1.

    Input:  (N, H, W, C) tensor with values in [0, 255].
    Output: (N, C, H, W) float tensor.
    """

    def __init__(self):
        super(VideoNormalizer, self).__init__()
        # Stored as non-trainable parameters so they follow the module's
        # device/dtype and appear in the state dict.
        self.scale = nn.Parameter(torch.Tensor([255.]), requires_grad=False)
        self.mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]), requires_grad=False)
        self.std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]), requires_grad=False)

    def forward(self, video):
        # (x / 255 - mean) / std, broadcasting over the trailing channel dim.
        frames = video.float()
        frames = frames.div(self.scale).sub(self.mean).div(self.std)
        # NHWC -> NCHW for convolutional backbones.
        return frames.permute(0, 3, 1, 2)
|
|
|
|
|
|
class RMAC(nn.Module):
    """Regional max pooling over a convolutional feature map.

    For each scale level ``l`` in ``L``, a grid of square regions is placed
    over the (H, W) plane and each region is max-pooled to a single
    C-dimensional vector; the pooled regions are tiled back into a small
    spatial grid in the returned tensor.

    NOTE(review): the mutable default ``L=[3]`` is never mutated here, so it
    is harmless, but callers should still prefer passing their own list.
    """

    def __init__(self, L=[3]):
        # L: list of scale levels; level l yields roughly l regions along
        # the short side of the feature map.
        super(RMAC,self).__init__()
        self.L = L

    def forward(self, x):
        # x: feature map of shape (N, C, H, W)
        return self.region_pooling(x, L=self.L)

    def region_pooling(self, x, L=[3]):
        ovr = 0.4 # desired overlap of neighboring regions
        steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension

        W = x.shape[3]
        H = x.shape[2]

        # region side length is governed by the short side of the map
        w = min(W, H)
        w2 = math.floor(w / 2.0 - 1)

        # candidate strides along the long dimension; choose the step count
        # whose resulting region overlap is closest to the target `ovr`
        b = (max(H, W) - w) / (steps - 1)
        (tmp, idx) = torch.min(torch.abs(((w ** 2 - w * b) / w ** 2) - ovr), 0) # steps(idx) regions for long dimension

        # region overplus per dimension
        # (extra regions are added along the longer dimension only)
        Wd = 0
        Hd = 0
        if H < W:
            Wd = idx.item() + 1
        elif H > W:
            Hd = idx.item() + 1

        vecs = []
        for l in L:
            # region side and half-side for this scale level
            wl = math.floor(2 * w / (l + 1))
            wl2 = math.floor(wl / 2 - 1)

            # horizontal stride; b == 0 when a single region spans the axis
            if l + Wd == 1:
                b = 0
            else:
                b = (W - wl) / (l + Wd - 1)
            cenW = torch.floor(wl2 + torch.tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates
            # vertical stride, same construction as above
            if l + Hd == 1:
                b = 0
            else:
                b = (H - wl) / (l + Hd - 1)
            cenH = torch.floor(wl2 + torch.tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates

            # max-pool each region to 1x1; regions in the same row are
            # concatenated along dim 3, rows along dim 2
            for i in cenH.long().tolist():
                v = []
                for j in cenW.long().tolist():
                    if wl == 0:
                        # degenerate region size -- skip
                        # NOTE(review): if wl == 0, `v` stays empty and the
                        # torch.cat below receives an empty list; confirm
                        # upstream feature maps are large enough.
                        continue
                    R = x[:, :, i: i+wl, j: j+wl]
                    v.append(F.adaptive_max_pool2d(R, (1, 1)))
                vecs.append(torch.cat(v, dim=3))
        return torch.cat(vecs, dim=2)
|
|
|
|
|
|
class PCA(nn.Module):
    """PCA whitening layer initialised from pretrained statistics.

    Downloads a precomputed mean and eigendecomposition (keys ``mean``,
    ``d`` eigenvalues, ``V`` eigenvectors), keeps the ``n_components``
    largest components, and applies centering, whitening and L2
    normalisation to incoming feature vectors.
    """

    def __init__(self, n_components=None):
        # n_components: number of principal components to keep
        # (None keeps all of them).
        super(PCA, self).__init__()
        pretrained_url = 'http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth'
        white = torch.hub.load_state_dict_from_url(pretrained_url)
        # Select the components with the largest eigenvalues.
        idx = torch.argsort(white['d'], descending=True)[: n_components]
        d = white['d'][idx]
        V = white['V'][:, idx]
        # Whitening matrix: scale each component by 1/sqrt(eigenvalue),
        # with a small epsilon for numerical stability.
        D = torch.diag(1. / torch.sqrt(d + 1e-7))
        self.mean = nn.Parameter(white['mean'], requires_grad=False)
        self.DVt = nn.Parameter(torch.mm(D, V.T).T, requires_grad=False)

    def forward(self, logits):
        # Out-of-place subtraction: the original used `logits -= ...`,
        # which mutated the caller's tensor in place (and would raise on a
        # leaf tensor that requires grad).
        logits = logits - self.mean.expand_as(logits)
        logits = torch.matmul(logits, self.DVt)
        logits = F.normalize(logits, p=2, dim=-1)
        return logits
|
|
|
|
|
|
class Attention(nn.Module):
    """Scalar gating attention over feature vectors.

    Each input vector is multiplied by a scalar weight produced by a learned
    context vector. With ``norm=True`` the weight is an affine-rescaled raw
    projection kept on the unit sphere by an L2 constraint; otherwise it is
    a sigmoid over a tanh-transformed projection.
    """

    def __init__(self, dims, norm=False):
        # dims: dimensionality of the input vectors
        # norm: use the L2-constrained linear variant instead of tanh+sigmoid
        super(Attention, self).__init__()
        self.norm = norm
        # NOTE: submodule creation order is kept identical to preserve
        # parameter registration and RNG consumption order.
        if self.norm:
            self.constrain = L2Constrain()
        else:
            self.transform = nn.Linear(dims, dims)
        self.context_vector = nn.Linear(dims, 1, bias=False)
        self.reset_parameters()

    def forward(self, x):
        if self.norm:
            # Affine rescale of the raw projection: w/2 + 0.5.
            weights = self.context_vector(x).div(2.).add(.5)
        else:
            weights = torch.sigmoid(self.context_vector(torch.tanh(self.transform(x))))
        # Gate the inputs and also expose the weights to the caller.
        return x * weights, weights

    def reset_parameters(self):
        if self.norm:
            nn.init.normal_(self.context_vector.weight)
            self.constrain(self.context_vector)
            return
        nn.init.xavier_uniform_(self.context_vector.weight)
        nn.init.xavier_uniform_(self.transform.weight)
        nn.init.zeros_(self.transform.bias)

    def apply_contraint(self):
        # Name kept as-is (including the typo) for caller compatibility.
        if self.norm:
            self.constrain(self.context_vector)
|
|
|
|
|
|
class BinarizationLayer(nn.Module):
    """Hashing layer mapping L2-normalised inputs to binary codes.

    During training the sign function is relaxed to a scaled ``erf`` so the
    layer remains differentiable; in eval mode it emits hard sign codes.
    """

    def __init__(self, dims, bits=None, sigma=1e-6, ITQ_init=True):
        # dims: input dimensionality
        # bits: code length (defaults to dims when not given)
        # sigma: temperature of the erf relaxation of sign()
        # ITQ_init: load the pretrained ITQ projection instead of random init
        super(BinarizationLayer, self).__init__()
        self.sigma = sigma
        if ITQ_init:
            pretrained_url = 'https://mever.iti.gr/distill-and-select/models/itq_resnet50W_dns100k_1M.pth'
            self.W = nn.Parameter(torch.hub.load_state_dict_from_url(pretrained_url)['proj'])
        else:
            self.W = nn.Parameter(torch.rand(dims, dims if bits is None else bits))

    def forward(self, x):
        # Project the unit-normalised input onto the hashing directions.
        projected = torch.matmul(F.normalize(x, p=2, dim=-1), self.W)
        if self.training:
            # Smooth surrogate of sign() so gradients can flow.
            return torch.erf(projected / np.sqrt(2 * self.sigma))
        return torch.sign(projected)

    def __repr__(self):
        return '{}(dims={}, bits={}, sigma={})'.format(
            self.__class__.__name__, self.W.shape[0], self.W.shape[1], self.sigma)
|
|
|
|
|
|
class NetVLAD(nn.Module):
    """NetVLAD aggregation layer: soft-assigns local features to learned
    centroids and aggregates the residuals into one L2-normalised
    descriptor, optionally projected to ``outdims`` and LayerNorm-ed.

    Acknowledgement to @lyakaap and @Nanne for providing their implementations"""

    def __init__(self, dims, num_clusters, outdims=None):
        # dims: channel dimension C of the input feature map
        # num_clusters: number of VLAD centroids
        # outdims: optional projection size for the flattened VLAD vector
        super(NetVLAD, self).__init__()
        self.num_clusters = num_clusters
        self.dims = dims

        # Centroids scaled by 1/sqrt(dims) at init.
        self.centroids = nn.Parameter(torch.randn(num_clusters, dims) / math.sqrt(self.dims))
        # 1x1 conv produces per-location cluster similarity logits.
        self.conv = nn.Conv2d(dims, num_clusters, kernel_size=1, bias=False)

        if outdims is not None:
            self.outdims = outdims
            self.reduction_layer = nn.Linear(self.num_clusters * self.dims, self.outdims, bias=False)
        else:
            self.outdims = self.num_clusters * self.dims
        self.norm = nn.LayerNorm(self.outdims)
        self.reset_parameters()

    def reset_parameters(self):
        # Initialise the assignment conv from the centroids themselves.
        self.conv.weight = nn.Parameter(self.centroids.detach().clone().unsqueeze(-1).unsqueeze(-1))
        if hasattr(self, 'reduction_layer'):
            nn.init.normal_(self.reduction_layer.weight, std=1 / math.sqrt(self.num_clusters * self.dims))

    def forward(self, x, mask=None):
        # x: (N, C, T, R) feature map -- presumably T time steps x R regions;
        # TODO(review): confirm axis semantics against the caller.
        # mask: optional (N, T) validity mask -- inferred from the unsqueeze
        # pattern below; verify against the caller.
        N, C, T, R = x.shape

        # soft-assignment
        soft_assign = self.conv(x).view(N, self.num_clusters, -1)
        soft_assign = F.softmax(soft_assign, dim=1).view(N, self.num_clusters, T, R)

        x_flatten = x.view(N, C, -1)

        vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
        for cluster in range(self.num_clusters): # slower than non-looped, but lower memory usage
            # Residual of every local feature w.r.t. this centroid,
            # weighted by its soft-assignment to the cluster.
            residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[cluster:cluster + 1, :].\
                expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
            residual = residual.view(N, C, T, R)
            residual *= soft_assign[:, cluster:cluster + 1, :]
            if mask is not None:
                # Zero out contributions from masked-out time steps.
                residual = residual.masked_fill((1 - mask.unsqueeze(1).unsqueeze(-1)).bool(), 0.0)
            vlad[:, cluster:cluster+1, :] = residual.sum([-2, -1]).unsqueeze(1)

        vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
        vlad = vlad.view(x.size(0), -1) # flatten
        vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize

        if hasattr(self, 'reduction_layer'):
            vlad = self.reduction_layer(vlad)
        return self.norm(vlad)
|