logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

214 lines
7.6 KiB

import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .constraints import L2Constrain
class VideoNormalizer(nn.Module):
    """Normalize raw video frames with fixed ImageNet statistics.

    Input frames are channels-last with 3 channels (the permute below maps
    dim 3 to the channel axis); pixel values are divided by 255, centered
    with the ImageNet per-channel mean, scaled by the per-channel std, and
    returned channels-first.
    """

    def __init__(self):
        super(VideoNormalizer, self).__init__()
        # Fixed constants stored as non-trainable parameters so they move
        # with the module across devices and appear in the state dict.
        self.scale = nn.Parameter(torch.Tensor([255.]), requires_grad=False)
        self.mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]), requires_grad=False)
        self.std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]), requires_grad=False)

    def forward(self, video):
        """Return the normalized batch permuted to (N, C, H, W)."""
        normalized = video.float().div(self.scale).sub(self.mean).div(self.std)
        return normalized.permute(0, 3, 1, 2)
class RMAC(nn.Module):
    """Regional Maximum Activation of Convolutions (R-MAC) pooling.

    Tiles the spatial plane of a feature map with overlapping square
    regions at the scales listed in ``L``, max-pools each region to a
    single vector, and keeps the pooled regions arranged on a spatial
    grid in the output (rows stacked on dim 2, columns on dim 3).
    """

    def __init__(self, L=[3]):
        # NOTE(review): mutable default argument is shared across instances;
        # it is never mutated here, so it is harmless in practice.
        super(RMAC,self).__init__()
        self.L = L

    def forward(self, x):
        # x: feature map, indexed as (N, C, H, W) below.
        return self.region_pooling(x, L=self.L)

    def region_pooling(self, x, L=[3]):
        ovr = 0.4 # desired overlap of neighboring regions
        steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension

        W = x.shape[3]
        H = x.shape[2]

        # w: side of the largest square fitting the short dimension.
        w = min(W, H)
        # NOTE(review): w2 is computed but never used (wl2 below is the one
        # actually consumed) — kept for fidelity with the reference code.
        w2 = math.floor(w / 2.0 - 1)

        # Pick the step count whose induced overlap is closest to `ovr`
        # for the long dimension; tmp (the min value) is unused.
        b = (max(H, W) - w) / (steps - 1)
        (tmp, idx) = torch.min(torch.abs(((w ** 2 - w * b) / w ** 2) - ovr), 0) # steps(idx) regions for long dimension

        # region overplus per dimension
        Wd = 0
        Hd = 0
        if H < W:
            Wd = idx.item() + 1
        elif H > W:
            Hd = idx.item() + 1

        vecs = []
        for l in L:
            # Region side length and half-width offset at this scale.
            wl = math.floor(2 * w / (l + 1))
            wl2 = math.floor(wl / 2 - 1)

            # Horizontal region centers: evenly spaced, b = 0 when there is
            # only a single region (avoids division by zero).
            if l + Wd == 1:
                b = 0
            else:
                b = (W - wl) / (l + Wd - 1)
            cenW = torch.floor(wl2 + torch.tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates
            # Vertical region centers, analogously.
            if l + Hd == 1:
                b = 0
            else:
                b = (H - wl) / (l + Hd - 1)
            cenH = torch.floor(wl2 + torch.tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates

            for i in cenH.long().tolist():
                v = []
                for j in cenW.long().tolist():
                    if wl == 0:
                        continue
                    # Crop the square region and max-pool it to one vector
                    # per channel (output spatial size 1x1).
                    R = x[:, :, i: i+wl, j: j+wl]
                    v.append(F.adaptive_max_pool2d(R, (1, 1)))
                # NOTE(review): if wl == 0, `v` stays empty and torch.cat
                # raises — presumably inputs are always large enough that
                # this never triggers; verify against callers.
                vecs.append(torch.cat(v, dim=3))
        return torch.cat(vecs, dim=2)
class PCA(nn.Module):
    """PCA-whitening layer with statistics pre-trained on 1M VCDB features.

    ``__init__`` downloads the whitening statistics (``mean``, eigenvalues
    ``d``, eigenvectors ``V``) via ``torch.hub`` (network access required),
    keeps the top ``n_components`` eigenpairs, and stores the combined
    whitening projection as a non-trainable parameter.

    Args:
        n_components: number of principal components to keep; ``None``
            keeps all components.
    """

    def __init__(self, n_components=None):
        super(PCA, self).__init__()
        pretrained_url = 'http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth'
        white = torch.hub.load_state_dict_from_url(pretrained_url)
        # Sort eigenpairs by decreasing eigenvalue, keep the first n_components
        # (slicing with None keeps everything).
        idx = torch.argsort(white['d'], descending=True)[: n_components]
        d = white['d'][idx]
        V = white['V'][:, idx]
        # Whitening scales each component by 1/sqrt(eigenvalue); the epsilon
        # guards against division blow-up for near-zero eigenvalues.
        D = torch.diag(1. / torch.sqrt(d + 1e-7))
        self.mean = nn.Parameter(white['mean'], requires_grad=False)
        self.DVt = nn.Parameter(torch.mm(D, V.T).T, requires_grad=False)

    def forward(self, logits):
        """Center, whiten, and L2-normalize ``logits``.

        Args:
            logits: (..., dims) feature tensor.

        Returns:
            (..., n_components) whitened, L2-normalized tensor.
        """
        # BUG FIX: the original `logits -= self.mean...` mutated the caller's
        # tensor in-place (silently corrupting features reused afterwards and
        # failing on leaf tensors that require grad). Use an out-of-place
        # subtraction; broadcasting makes expand_as unnecessary.
        logits = logits - self.mean
        logits = torch.matmul(logits, self.DVt)
        logits = F.normalize(logits, p=2, dim=-1)
        return logits
class Attention(nn.Module):
    """Self-gating attention over feature vectors.

    Each input vector is scored by a learned context vector and rescaled by
    the resulting scalar weight. With ``norm=True`` the score of the
    (L2-constrained) context vector is mapped linearly from [-1, 1] to
    [0, 1]; otherwise a tanh projection followed by a sigmoid produces the
    weight.
    """

    def __init__(self, dims, norm=False):
        super(Attention, self).__init__()
        self.norm = norm
        if self.norm:
            self.constrain = L2Constrain()
        else:
            self.transform = nn.Linear(dims, dims)
        self.context_vector = nn.Linear(dims, 1, bias=False)
        self.reset_parameters()

    def forward(self, x):
        """Return (weighted input, weights); weights lie in (0, 1) / [0, 1]."""
        if self.norm:
            # Linear remap of the raw score from [-1, 1] to [0, 1].
            weights = self.context_vector(x) / 2. + .5
        else:
            hidden = torch.tanh(self.transform(x))
            weights = torch.sigmoid(self.context_vector(hidden))
        return x * weights, weights

    def reset_parameters(self):
        # Init order matters for seeded reproducibility: context vector first,
        # then the transform weight and bias.
        if self.norm:
            nn.init.normal_(self.context_vector.weight)
            self.constrain(self.context_vector)
        else:
            nn.init.xavier_uniform_(self.context_vector.weight)
            nn.init.xavier_uniform_(self.transform.weight)
            nn.init.zeros_(self.transform.bias)

    def apply_contraint(self):
        # Re-project the context vector onto the constraint set (norm mode only).
        if self.norm:
            self.constrain(self.context_vector)
class BinarizationLayer(nn.Module):
    """Project L2-normalized features with a matrix ``W`` and binarize.

    In training mode the non-differentiable sign is replaced by a scaled
    ``erf`` surrogate (sharpness set by ``sigma``); in eval mode the hard
    sign is applied. By default ``W`` is loaded from a pre-trained ITQ
    checkpoint; otherwise it is a random (dims, bits) matrix.
    """

    def __init__(self, dims, bits=None, sigma=1e-6, ITQ_init=True):
        super(BinarizationLayer, self).__init__()
        self.sigma = sigma
        if ITQ_init:
            # Pre-trained ITQ projection, fetched via torch.hub (needs network).
            pretrained_url = 'https://mever.iti.gr/distill-and-select/models/itq_resnet50W_dns100k_1M.pth'
            self.W = nn.Parameter(torch.hub.load_state_dict_from_url(pretrained_url)['proj'])
        else:
            # Default to a square projection when no bit count is given.
            self.W = nn.Parameter(torch.rand(dims, dims if bits is None else bits))

    def forward(self, x):
        projected = torch.matmul(F.normalize(x, p=2, dim=-1), self.W)
        if self.training:
            # Smooth sign surrogate keeps gradients flowing during training.
            return torch.erf(projected / np.sqrt(2 * self.sigma))
        return torch.sign(projected)

    def __repr__(self):
        return '{}(dims={}, bits={}, sigma={})'.format(
            self.__class__.__name__, self.W.shape[0], self.W.shape[1], self.sigma)
class NetVLAD(nn.Module):
    """Acknowledgement to @lyakaap and @Nanne for providing their implementations

    NetVLAD pooling: soft-assigns every (T, R) feature location to
    `num_clusters` learned centroids, accumulates residuals per cluster,
    intra- and inter-normalizes the result, and optionally reduces it to
    `outdims` before a final LayerNorm.
    """

    def __init__(self, dims, num_clusters, outdims=None):
        super(NetVLAD, self).__init__()
        self.num_clusters = num_clusters
        self.dims = dims
        # Centroids scaled by 1/sqrt(dims); the 1x1 conv producing the
        # assignment logits is re-tied to them in reset_parameters().
        self.centroids = nn.Parameter(torch.randn(num_clusters, dims) / math.sqrt(self.dims))
        self.conv = nn.Conv2d(dims, num_clusters, kernel_size=1, bias=False)
        if outdims is not None:
            self.outdims = outdims
            self.reduction_layer = nn.Linear(self.num_clusters * self.dims, self.outdims, bias=False)
        else:
            # No reduction: the descriptor keeps its full K*C size.
            self.outdims = self.num_clusters * self.dims
        self.norm = nn.LayerNorm(self.outdims)
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the assignment conv weights from the centroids
        # (detached copy reshaped to a 1x1 kernel).
        self.conv.weight = nn.Parameter(self.centroids.detach().clone().unsqueeze(-1).unsqueeze(-1))
        if hasattr(self, 'reduction_layer'):
            nn.init.normal_(self.reduction_layer.weight, std=1 / math.sqrt(self.num_clusters * self.dims))

    def forward(self, x, mask=None):
        """Aggregate an (N, C, T, R) feature tensor into an (N, outdims) descriptor.

        mask: optional — presumably (N, T) with 1 marking valid positions,
        given the unsqueeze pattern below; TODO confirm against callers.
        """
        N, C, T, R = x.shape
        # soft-assignment
        soft_assign = self.conv(x).view(N, self.num_clusters, -1)
        soft_assign = F.softmax(soft_assign, dim=1).view(N, self.num_clusters, T, R)

        x_flatten = x.view(N, C, -1)
        vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
        for cluster in range(self.num_clusters): # slower than non-looped, but lower memory usage
            # Residual of every location to this cluster's centroid,
            # weighted by its soft-assignment probability.
            residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[cluster:cluster + 1, :].\
                expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
            residual = residual.view(N, C, T, R)
            residual *= soft_assign[:, cluster:cluster + 1, :]
            if mask is not None:
                # Zero out contributions of masked (invalid) positions.
                residual = residual.masked_fill((1 - mask.unsqueeze(1).unsqueeze(-1)).bool(), 0.0)
            vlad[:, cluster:cluster+1, :] = residual.sum([-2, -1]).unsqueeze(1)

        vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
        vlad = vlad.view(x.size(0), -1) # flatten
        vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize

        if hasattr(self, 'reduction_layer'):
            vlad = self.reduction_layer(vlad)
        return self.norm(vlad)