
build DnS operator

main
ChengZi 3 years ago
parent
commit
9401a60fb6
  1. BIN
      DnS.png
  2. 183
      README.md
  3. 20
      __init__.py
  4. BIN
      demo_video.flv
  5. 112
      distill_and_select.py
  6. 18
      model/__init__.py
  7. 25
      model/constraints.py
  8. 46
      model/feature_extractor.py
  9. 214
      model/layers.py
  10. 33
      model/losses.py
  11. 106
      model/selector.py
  12. 121
      model/similarities.py
  13. 203
      model/students.py
  14. BIN
      output_imgs/cg_student.png
  15. BIN
      output_imgs/cg_student_specifical.png
  16. BIN
      output_imgs/feature_extractor.png
  17. BIN
      output_imgs/fg_att_student.png
  18. BIN
      output_imgs/fg_bin_student.png
  19. BIN
      output_imgs/selector_att.png
  20. BIN
      output_imgs/selector_bin.png
  21. 4
      requirements.txt

BIN
DnS.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 424 KiB

183
README.md

@@ -1,2 +1,183 @@
# distill-and-select
# Video deduplication with Distill-and-Select
*author: Chen Zhang*
<br />
## Description
This operator is made for the video deduplication task, based on [DnS: Distill-and-Select for Efficient and Accurate Video Indexing and Retrieval](https://arxiv.org/abs/2106.13266).
Trained with knowledge distillation on large, unlabelled datasets, DnS learns: a) Student Networks at different retrieval-performance and computational-efficiency trade-offs, and b) a Selection Network that, at test time, rapidly directs samples to the appropriate student to maintain both high retrieval performance and high computational efficiency.
![](DnS.png)
<br />
## Code Example
Load a video from './demo_video.flv' and decode it with the ffmpeg video-decode operator.
Then use the distill_and_select operator to get the output of the specified model.
A fine-grained student model returns a 3-D output that keeps temporal information, a coarse-grained student model returns a 1-D embedding representing the whole video, and a selector model returns a scalar output.
*For feature_extractor model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='feature_extractor') \
.show()
```
![](output_imgs/feature_extractor.png)
*For fg_att_student model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='fg_att_student') \
.show()
```
![](output_imgs/fg_att_student.png)
*For fg_bin_student model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='fg_bin_student') \
.show()
```
![](output_imgs/fg_bin_student.png)
*For cg_student model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='cg_student') \
.show()
```
![](output_imgs/cg_student.png)
*For selector_att model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='selector_att') \
.show()
```
![](output_imgs/selector_att.png)
*For selector_bin model*:
```python
import towhee
towhee.dc(['./demo_video.flv']) \
.video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op(func=lambda x: [y for y in x]) \
.distill_and_select(model_name='selector_bin') \
.show()
```
![](output_imgs/selector_bin.png)
*Write the same pipeline with explicit input/output names, taking the cg_student model as an example:*
```python
import towhee
towhee.dc['path'](['./demo_video.flv']) \
.video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
.runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
.distill_and_select['frames', 'vec'](model_name='cg_student') \
.show()
```
![](output_imgs/cg_student_specifical.png)
<br />
## Factory Constructor
Create the operator via the following factory method:
***distill_and_select(model_name, \*\*kwargs)***
**Parameters:**
​ ***model_name:*** *str*
​ Can be one of the following:
`feature_extractor`: Feature Extractor only,
`fg_att_student`: Fine Grained Student with attention,
`fg_bin_student`: Fine Grained Student with binarization,
`cg_student`: Coarse Grained Student,
`selector_att`: Selector Network with attention,
`selector_bin`: Selector Network with binarization.
​ ***model_weight_path:*** *str*
​ Default is None, which downloads and uses the original pretrained weights.
​ ***feature_extractor:*** *Union[str, nn.Module]*
​ `None`, 'default', or a PyTorch nn.Module instance.
`None` means the operator does not extract features from video data and takes embedding features as input.
'default' means the original pretrained feature-extraction weights are used and the operator can take video data as input.
Alternatively, you can pass in an nn.Module instance as a custom feature extractor.
Default is `default`.
​ ***device:*** *str*
​ Model device, cpu or cuda.
<br />
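Besides the pipeline-style calls above, the operator class can also be constructed directly from this repository. The snippet below is a minimal sketch that is not part of the original examples; it assumes it is executed from the repository root and uses illustrative parameter values.
```python
# Minimal sketch (assumption: executed from the repository root of this operator).
# It calls the constructor defined in distill_and_select.py with the parameters
# documented above.
from distill_and_select import DistillAndSelect

# 'default' downloads the original pretrained feature extractor, so the operator
# can take decoded video frames as input; any of the six model names is accepted.
op = DistillAndSelect(model_name='cg_student',
                      feature_extractor='default',
                      device='cpu')
```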
## Interface
Get the output from the specified model.
**Parameters:**
​ ***data:*** *List[towhee.types.VideoFrame]* or *Any*
​ The input type is List[VideoFrame] when the default feature_extractor is used; otherwise it is whatever input type your custom feature_extractor expects.
**Returns:** *numpy.ndarray*
​ The output of the specified model.
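If the operator is built with `feature_extractor=None`, it takes pre-extracted embedding features instead of video frames. The snippet below is a minimal sketch of that case; the feature tensor shape is an illustrative assumption (frames x regions x dimensions).
```python
# Minimal sketch (assumption: executed from the repository root; the tensor
# shape below is only illustrative).
import torch
from distill_and_select import DistillAndSelect

op = DistillAndSelect(model_name='cg_student',
                      feature_extractor=None,   # the operator now expects embedding features
                      device='cpu')

features = torch.rand(30, 9, 512)  # [T frames, R regions, D dims] dummy features
vec = op(features)                 # numpy.ndarray: a 1-D embedding for cg_student
print(vec.shape)
```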

20
__init__.py

@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .distill_and_select import DistillAndSelect
def distill_and_select(model_name: str, **kwargs):
return DistillAndSelect(model_name, **kwargs)

BIN
demo_video.flv

Binary file not shown.

112
distill_and_select.py

@@ -0,0 +1,112 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Union, Any
from torchvision import transforms
from towhee.operator.base import NNOperator
from towhee import register
from PIL import Image as PILImage
from model.feature_extractor import FeatureExtractor
from model.students import FineGrainedStudent, CoarseGrainedStudent
from model.selector import SelectorNetwork
from torch import nn
@register(output_schema=['vec'])
class DistillAndSelect(NNOperator):
"""
DistillAndSelect
"""
def __init__(self, model_name: str, model_weight_path: str = None,
feature_extractor: Union[str, nn.Module] = 'default', device: str = None):
"""
Args:
model_name (`str`):
Can be one of the following:
`feature_extractor`: Feature Extractor only,
`fg_att_student`: Fine Grained Student with attention,
`fg_bin_student`: Fine Grained Student with binarization,
`cg_student`: Coarse Grained Student,
`selector_att`: Selector Network with attention,
`selector_bin`: Selector Network with binarization.
model_weight_path (`str`):
Default is None, which downloads and uses the original pretrained weights.
feature_extractor (`Union[str, nn.Module]`):
`None`, 'default', or a PyTorch nn.Module instance.
`None` means the operator does not extract features from video data and takes embedding features as input.
'default' means the original pretrained feature-extraction weights are used and the operator can take video data as input.
Alternatively, you can pass in an nn.Module instance as a custom feature extractor.
Default is `default`.
device (`str`):
Model device, cpu or cuda.
"""
super().__init__()
assert model_name in ['feature_extractor', 'fg_att_student', 'fg_bin_student', 'cg_student', 'selector_att',
'selector_bin'], 'unsupported model.'
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.model_name = model_name
self.feature_extractor = None
if feature_extractor == 'default':
self.feature_extractor = FeatureExtractor(dims=512).to(self.device).eval()
elif isinstance(feature_extractor, nn.Module):
self.feature_extractor = feature_extractor
self.model = None
pretrained = True if model_weight_path is None else False
if self.model_name == 'fg_att_student':
self.model = FineGrainedStudent(pretrained=pretrained, attention=True)
elif self.model_name == 'fg_bin_student':
self.model = FineGrainedStudent(pretrained=pretrained, binarization=True)
elif self.model_name == 'cg_student':
self.model = CoarseGrainedStudent(pretrained=pretrained)
elif self.model_name == 'selector_att':
self.model = SelectorNetwork(pretrained=pretrained, attention=True)
elif self.model_name == 'selector_bin':
self.model = SelectorNetwork(pretrained=pretrained, binarization=True)
if model_weight_path is not None:
self.model.load_state_dict(torch.load(model_weight_path))
if self.model is not None:
self.model.to(self.device).eval()
self.tfms = transforms.Compose([
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(256),
transforms.ToTensor(),
])
def __call__(self, data: Any): # List[VideoFrame] when self.feature_extractor is not None
if self.feature_extractor is not None:
pil_img_list = []
for img in data:
pil_img = PILImage.fromarray(img, img.mode)
tfmed_img = self.tfms(pil_img).permute(1, 2, 0).unsqueeze(0)
pil_img_list.append(tfmed_img)
data = torch.concat(pil_img_list, dim=0) * 255
data = self.feature_extractor(data.to(self.device)).to(self.device)
if self.model_name == 'feature_extractor':
return data.cpu().detach().squeeze().numpy()
index_feature = self.model.index_video(data)
return index_feature.cpu().detach().squeeze().numpy()

18
model/__init__.py

@@ -0,0 +1,18 @@
from .layers import *
from .losses import *
from .similarities import *
def check_dims(features, mask=None, axis=0):
if features.ndim == 4:
return features, mask
elif features.ndim == 3:
features = features.unsqueeze(axis)
if mask is not None:
mask = mask.unsqueeze(axis)
return features, mask
else:
raise Exception('Wrong shape of input video tensor. The shape of the tensor must be either '
'[N, T, R, D] or [T, R, D], where N is the batch size, T the number of frames, '
'R the number of regions and D the number of dimensions. '
'Input video tensor has shape {}'.format(features.shape))

25
model/constraints.py

@@ -0,0 +1,25 @@
import torch
import torch.nn.functional as F
class L2Constrain(object):
def __init__(self, axis=-1, eps=1e-6):
self.axis = axis
self.eps = eps
def __call__(self, module):
if hasattr(module, 'weight'):
w = module.weight.data
module.weight.data = F.normalize(w, p=2, dim=self.axis, eps=self.eps)
class NonNegConstrain(object):
def __init__(self, eps=1e-3):
self.eps = eps
def __call__(self, module):
if hasattr(module, 'weight'):
w = module.weight.data
module.weight.data = torch.clamp(w, min=self.eps)

46
model/feature_extractor.py

@@ -0,0 +1,46 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from .layers import *
class FeatureExtractor(nn.Module):
def __init__(self, network='resnet50', whiteninig=False, dims=3840):
super(FeatureExtractor, self).__init__()
self.normalizer = VideoNormalizer()
self.cnn = models.resnet50(pretrained=True)
self.rpool = RMAC()
self.layers = {'layer1': 28, 'layer2': 14, 'layer3': 6, 'layer4': 3}
if whiteninig or dims != 3840:
self.pca = PCA(dims)
def extract_region_vectors(self, x):
tensors = []
for nm, module in self.cnn._modules.items():
if nm not in {'avgpool', 'fc', 'classifier'}:
x = module(x).contiguous()
if nm in self.layers:
# region_vectors = self.rpool(x)
s = self.layers[nm]
region_vectors = F.max_pool2d(x, [s, s], int(np.ceil(s / 2)))
region_vectors = F.normalize(region_vectors, p=2, dim=1)
tensors.append(region_vectors)
for i in range(len(tensors)):
tensors[i] = F.normalize(F.adaptive_max_pool2d(tensors[i], tensors[-1].shape[2:]), p=2, dim=1)
x = torch.cat(tensors, 1)
x = x.view(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
x = F.normalize(x, p=2, dim=-1)
return x
def forward(self, x):
x = self.normalizer(x)
x = self.extract_region_vectors(x)
if hasattr(self, 'pca'):
x = self.pca(x)
return x

214
model/layers.py

@@ -0,0 +1,214 @@
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .constraints import L2Constrain
class VideoNormalizer(nn.Module):
def __init__(self):
super(VideoNormalizer, self).__init__()
self.scale = nn.Parameter(torch.Tensor([255.]), requires_grad=False)
self.mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]), requires_grad=False)
self.std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]), requires_grad=False)
def forward(self, video):
video = video.float()
video = ((video / self.scale) - self.mean) / self.std
return video.permute(0, 3, 1, 2)
class RMAC(nn.Module):
def __init__(self, L=[3]):
super(RMAC,self).__init__()
self.L = L
def forward(self, x):
return self.region_pooling(x, L=self.L)
def region_pooling(self, x, L=[3]):
ovr = 0.4 # desired overlap of neighboring regions
steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
W = x.shape[3]
H = x.shape[2]
w = min(W, H)
w2 = math.floor(w / 2.0 - 1)
b = (max(H, W) - w) / (steps - 1)
(tmp, idx) = torch.min(torch.abs(((w ** 2 - w * b) / w ** 2) - ovr), 0) # steps(idx) regions for long dimension
# region overplus per dimension
Wd = 0
Hd = 0
if H < W:
Wd = idx.item() + 1
elif H > W:
Hd = idx.item() + 1
vecs = []
for l in L:
wl = math.floor(2 * w / (l + 1))
wl2 = math.floor(wl / 2 - 1)
if l + Wd == 1:
b = 0
else:
b = (W - wl) / (l + Wd - 1)
cenW = torch.floor(wl2 + torch.tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates
if l + Hd == 1:
b = 0
else:
b = (H - wl) / (l + Hd - 1)
cenH = torch.floor(wl2 + torch.tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates
for i in cenH.long().tolist():
v = []
for j in cenW.long().tolist():
if wl == 0:
continue
R = x[:, :, i: i+wl, j: j+wl]
v.append(F.adaptive_max_pool2d(R, (1, 1)))
vecs.append(torch.cat(v, dim=3))
return torch.cat(vecs, dim=2)
class PCA(nn.Module):
def __init__(self, n_components=None):
super(PCA, self).__init__()
pretrained_url = 'http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth'
white = torch.hub.load_state_dict_from_url(pretrained_url)
idx = torch.argsort(white['d'], descending=True)[: n_components]
d = white['d'][idx]
V = white['V'][:, idx]
D = torch.diag(1. / torch.sqrt(d + 1e-7))
self.mean = nn.Parameter(white['mean'], requires_grad=False)
self.DVt = nn.Parameter(torch.mm(D, V.T).T, requires_grad=False)
def forward(self, logits):
logits -= self.mean.expand_as(logits)
logits = torch.matmul(logits, self.DVt)
logits = F.normalize(logits, p=2, dim=-1)
return logits
class Attention(nn.Module):
def __init__(self, dims, norm=False):
super(Attention, self).__init__()
self.norm = norm
if self.norm:
self.constrain = L2Constrain()
else:
self.transform = nn.Linear(dims, dims)
self.context_vector = nn.Linear(dims, 1, bias=False)
self.reset_parameters()
def forward(self, x):
if self.norm:
weights = self.context_vector(x)
weights = torch.add(torch.div(weights, 2.), .5)
else:
x_tr = torch.tanh(self.transform(x))
weights = self.context_vector(x_tr)
weights = torch.sigmoid(weights)
x = x * weights
return x, weights
def reset_parameters(self):
if self.norm:
nn.init.normal_(self.context_vector.weight)
self.constrain(self.context_vector)
else:
nn.init.xavier_uniform_(self.context_vector.weight)
nn.init.xavier_uniform_(self.transform.weight)
nn.init.zeros_(self.transform.bias)
def apply_contraint(self):
if self.norm:
self.constrain(self.context_vector)
class BinarizationLayer(nn.Module):
def __init__(self, dims, bits=None, sigma=1e-6, ITQ_init=True):
super(BinarizationLayer, self).__init__()
self.sigma = sigma
if ITQ_init:
pretrained_url = 'https://mever.iti.gr/distill-and-select/models/itq_resnet50W_dns100k_1M.pth'
self.W = nn.Parameter(torch.hub.load_state_dict_from_url(pretrained_url)['proj'])
else:
if bits is None:
bits = dims
self.W = nn.Parameter(torch.rand(dims, bits))
def forward(self, x):
x = F.normalize(x, p=2, dim=-1)
x = torch.matmul(x, self.W)
if self.training:
x = torch.erf(x / np.sqrt(2 * self.sigma))
else:
x = torch.sign(x)
return x
def __repr__(self,):
return '{}(dims={}, bits={}, sigma={})'.format(
self.__class__.__name__, self.W.shape[0], self.W.shape[1], self.sigma)
class NetVLAD(nn.Module):
"""Acknowledgement to @lyakaap and @Nanne for providing their implementations"""
def __init__(self, dims, num_clusters, outdims=None):
super(NetVLAD, self).__init__()
self.num_clusters = num_clusters
self.dims = dims
self.centroids = nn.Parameter(torch.randn(num_clusters, dims) / math.sqrt(self.dims))
self.conv = nn.Conv2d(dims, num_clusters, kernel_size=1, bias=False)
if outdims is not None:
self.outdims = outdims
self.reduction_layer = nn.Linear(self.num_clusters * self.dims, self.outdims, bias=False)
else:
self.outdims = self.num_clusters * self.dims
self.norm = nn.LayerNorm(self.outdims)
self.reset_parameters()
def reset_parameters(self):
self.conv.weight = nn.Parameter(self.centroids.detach().clone().unsqueeze(-1).unsqueeze(-1))
if hasattr(self, 'reduction_layer'):
nn.init.normal_(self.reduction_layer.weight, std=1 / math.sqrt(self.num_clusters * self.dims))
def forward(self, x, mask=None):
N, C, T, R = x.shape
# soft-assignment
soft_assign = self.conv(x).view(N, self.num_clusters, -1)
soft_assign = F.softmax(soft_assign, dim=1).view(N, self.num_clusters, T, R)
x_flatten = x.view(N, C, -1)
vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
for cluster in range(self.num_clusters): # slower than non-looped, but lower memory usage
residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[cluster:cluster + 1, :].\
expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
residual = residual.view(N, C, T, R)
residual *= soft_assign[:, cluster:cluster + 1, :]
if mask is not None:
residual = residual.masked_fill((1 - mask.unsqueeze(1).unsqueeze(-1)).bool(), 0.0)
vlad[:, cluster:cluster+1, :] = residual.sum([-2, -1]).unsqueeze(1)
vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
vlad = vlad.view(x.size(0), -1) # flatten
vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
if hasattr(self, 'reduction_layer'):
vlad = self.reduction_layer(vlad)
return self.norm(vlad)

33
model/losses.py

@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
class TripletLoss(nn.Module):
def __init__(self, gamma=1.0, similarity=True):
super(TripletLoss, self).__init__()
self.gamma = gamma
self.similarity = similarity
def forward(self, sim_pos, sim_neg):
if self.similarity:
loss = torch.clamp(sim_neg - sim_pos + self.gamma, min=0.)
else:
loss = torch.clamp(sim_pos - sim_neg + self.gamma, min=0.)
return loss.mean()
class SimilarityRegularizationLoss(nn.Module):
def __init__(self, min_val=-1., max_val=1.):
super(SimilarityRegularizationLoss, self).__init__()
self.min_val = min_val
self.max_val = max_val
def forward(self, sim):
loss = torch.sum(torch.abs(torch.clamp(sim - self.min_val, max=0.)))
loss += torch.sum(torch.abs(torch.clamp(sim - self.max_val, min=0.)))
return loss
def __repr__(self,):
return '{}(min_val={}, max_val={})'.format(self.__class__.__name__, self.min_val, self.max_val)

106
model/selector.py

@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
from . import *
model_urls = {
'dns_selector_cg-fg_att': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_att.pth',
'dns_selector_cg-fg_bin': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_bin.pth',
}
class MetadataModel(nn.Module):
def __init__(self,
input_size,
hidden_size=100,
num_layers=1
):
super(MetadataModel, self).__init__()
model = [
nn.Linear(input_size, hidden_size, bias=False),
nn.BatchNorm1d(hidden_size),
nn.ReLU(),
nn.Dropout()
]
for _ in range(num_layers):
model.extend([nn.Linear(hidden_size, hidden_size, bias=False),
nn.BatchNorm1d(hidden_size),
nn.ReLU(),
nn.Dropout()])
model.extend([nn.Linear(hidden_size, 1),
nn.Sigmoid()])
self.model = nn.Sequential(*model)
self.reset_parameters()
def reset_parameters(self):
for m in self.model.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
def forward(self, x):
return self.model(x)
class SelectorNetwork(nn.Module):
def __init__(self,
dims=512,
hidden_size=100,
num_layers=1,
attention=False,
binarization=False,
pretrained=False,
**kwargs
):
super(SelectorNetwork, self).__init__()
self.attention = Attention(dims, norm=False)
self.visil_head = VideoComperator()
self.mlp = MetadataModel(3, hidden_size, num_layers)
if pretrained:
if not (attention or binarization):
raise Exception('No pretrained model provided for the selected settings. '
'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
elif attention:
self.load_state_dict(
torch.hub.load_state_dict_from_url(
model_urls['dns_selector_cg-fg_att'])['model'])
elif binarization:
self.load_state_dict(
torch.hub.load_state_dict_from_url(
model_urls['dns_selector_cg-fg_bin'])['model'])
def get_network_name(self,):
return 'selector_network'
def index_video(self, x, mask=None):
x, mask = check_dims(x, mask)
sim = self.frame_to_frame_similarity(x)
sim_mask = None
if mask is not None:
sim_mask = torch.einsum("bik,bjk->bij", mask.unsqueeze(-1), mask.unsqueeze(-1))
sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
sim, sim_mask = self.visil_head(sim, sim_mask)
if sim_mask is not None:
sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
sim = torch.sum(sim, [1, 2]) / torch.sum(sim_mask, [1, 2])
else:
sim = torch.mean(sim, [1, 2])
return sim.unsqueeze(-1)
def frame_to_frame_similarity(self, x):
x, a = self.attention(x)
sim = torch.einsum("biok,bjpk->biopj", x, x)
return torch.mean(sim, [2, 3])
def forward(self, x):
return self.mlp(x)

121
model/similarities.py

@@ -0,0 +1,121 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class TensorDot(nn.Module):
def __init__(self, pattern='iak,jbk->iabj', metric='cosine'):
super(TensorDot, self).__init__()
self.pattern = pattern
self.metric = metric
def forward(self, query, target):
if self.metric == 'cosine':
sim = torch.einsum(self.pattern, [query, target])
elif self.metric == 'euclidean':
sim = 1 - 2 * torch.einsum(self.pattern, [query, target])
elif self.metric == 'hamming':
sim = torch.einsum(self.pattern, query, target) / target.shape[-1]
return sim
def __repr__(self,):
return '{}(pattern={})'.format(self.__class__.__name__, self.pattern)
class ChamferSimilarity(nn.Module):
def __init__(self, symmetric=False, axes=[1, 0]):
super(ChamferSimilarity, self).__init__()
self.axes = axes
if symmetric:
self.sim_fun = lambda x, m: self.symmetric_chamfer_similarity(x, mask=m, axes=axes)
else:
self.sim_fun = lambda x, m: self.chamfer_similarity(x, mask=m, max_axis=axes[0], mean_axis=axes[1])
def chamfer_similarity(self, s, mask=None, max_axis=1, mean_axis=0):
if mask is not None:
s = s.masked_fill((1 - mask).bool(), -np.inf)
s = torch.max(s, max_axis, keepdim=True)[0]
mask = torch.max(mask, max_axis, keepdim=True)[0]
s = s.masked_fill((1 - mask).bool(), 0.0)
s = torch.sum(s, mean_axis, keepdim=True)
s /= torch.sum(mask, mean_axis, keepdim=True)
else:
s = torch.max(s, max_axis, keepdim=True)[0]
s = torch.mean(s, mean_axis, keepdim=True)
return s.squeeze(max(max_axis, mean_axis)).squeeze(min(max_axis, mean_axis))
def symmetric_chamfer_similarity(self, s, mask=None, axes=[0, 1]):
return (self.chamfer_similarity(s, mask=mask, max_axis=axes[0], mean_axis=axes[1]) +
self.chamfer_similarity(s, mask=mask, max_axis=axes[1], mean_axis=axes[0])) / 2
def forward(self, s, mask=None):
return self.sim_fun(s, mask)
def __repr__(self,):
return '{}(max_axis={}, mean_axis={})'.format(self.__class__.__name__, self.axes[0], self.axes[1])
class VideoComperator(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super(VideoComperator, self).__init__()
self.rpad1 = nn.ZeroPad2d(1)
self.conv1 = nn.Conv2d(in_channels, 32, 3)
self.pool1 = nn.MaxPool2d((2, 2), 2)
self.rpad2 = nn.ZeroPad2d(1)
self.conv2 = nn.Conv2d(32, 64, 3)
self.pool2 = nn.MaxPool2d((2, 2), 2)
self.rpad3 = nn.ZeroPad2d(1)
self.conv3 = nn.Conv2d(64, 128, 3)
self.fconv = nn.Conv2d(128, out_channels, 1)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, sim_matrix, mask=None):
if sim_matrix.ndim == 3:
sim_matrix = sim_matrix.unsqueeze(1)
elif sim_matrix.ndim != 4:
raise Exception('Input tensor to VideoComperator has to be 3- or 4-dimensional')
if mask is not None:
assert mask.shape[-2:] == sim_matrix.shape[-2:], 'Mask tensor must be of the same shape as similarity ' \
'matrix in the last two dimensions. Mask shape is {} ' \
'while similarity matrix is {}'.format(mask.shape[-2:],
sim_matrix.shape[-2:])
mask = mask.unsqueeze(1)
sim = self.rpad1(sim_matrix)
sim = self.conv1(sim)
if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
sim = F.relu(sim)
sim = self.pool1(sim)
if mask is not None: mask = self.pool1(mask)
sim = self.rpad2(sim)
sim = self.conv2(sim)
if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
sim = F.relu(sim)
sim = self.pool2(sim)
if mask is not None: mask = self.pool2(mask)
sim = self.rpad3(sim)
sim = self.conv3(sim)
if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
sim = F.relu(sim)
sim = self.fconv(sim)
return sim.squeeze(1), mask.squeeze(1) if mask is not None else None

203
model/students.py

@@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import *
from einops import rearrange
model_urls = {
'dns_cg_student': 'https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth',
'dns_fg_att_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_att_student.pth',
'dns_fg_bin_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_bin_student.pth',
}
class CoarseGrainedStudent(nn.Module):
def __init__(self,
dims=512,
attention=True,
transformer=True,
transformer_heads=8,
transformer_feedforward_dims=2048,
transformer_layers=1,
netvlad=True,
netvlad_clusters=64,
netvlad_outdims=1024,
pretrained=False,
**kwargs
):
super(CoarseGrainedStudent, self).__init__()
self.student_type = 'cg'
if attention:
self.attention = Attention(dims, norm=False)
if transformer:
encoder_layer = nn.TransformerEncoderLayer(dims,
transformer_heads,
transformer_feedforward_dims)
self.transformer = nn.TransformerEncoder(encoder_layer,
transformer_layers,
nn.LayerNorm(dims))
self.apply(self._init_weights)
if netvlad:
self.netvlad = NetVLAD(dims, netvlad_clusters, outdims=netvlad_outdims)
if pretrained:
self.load_state_dict(
torch.hub.load_state_dict_from_url(
model_urls['dns_cg_student'])['model'])
def get_network_name(self,):
return '{}_student'.format(self.student_type)
def calculate_video_similarity(self, query, target):
return torch.mm(query, torch.transpose(target, 0, 1))
def index_video(self, x, mask=None):
x, mask = check_dims(x, mask)
if hasattr(self, 'attention'):
x, a = self.attention(x)
x = torch.sum(x, 2)
x = F.normalize(x, p=2, dim=-1)
if hasattr(self, 'transformer'):
x = x.permute(1, 0, 2)
x = self.transformer(x, src_key_padding_mask=
(1 - mask).bool() if mask is not None else None)
x = x.permute(1, 0, 2)
if hasattr(self, 'netvlad'):
x = x.unsqueeze(2).permute(0, 3, 1, 2)
x = self.netvlad(x, mask=mask)
else:
if mask is not None:
x = x.masked_fill((1 - mask.unsqueeze(-1)).bool(), 0.0)
x = torch.sum(x, 1) / torch.sum(mask, 1, keepdim=True)
else:
x = torch.mean(x, 1)
return F.normalize(x, p=2, dim=-1)
def forward(self, anchors, positives, negatives,
anchors_masks=None, positive_masks=None, negative_masks=None):
pos_pairs = torch.sum(anchors * positives, 1, keepdim=True)
neg_pairs = torch.sum(anchors * negatives, 1, keepdim=True)
return pos_pairs, neg_pairs, None
def _init_weights(self, m):
if isinstance(m, nn.Linear):
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
class FineGrainedStudent(nn.Module):
def __init__(self,
dims=512,
attention=False,
binarization=False,
pretrained=False,
**kwargs
):
super(FineGrainedStudent, self).__init__()
self.student_type = 'fg'
if attention and binarization:
raise Exception('Can\'t use \'attention=True\' and \'binarization=True\' at the same time. '
'Select one of the two options.')
elif binarization:
self.fg_type = 'bin'
self.binarization = BinarizationLayer(dims)
elif attention:
self.fg_type = 'att'
self.attention = Attention(dims, norm=False)
else:
self.fg_type = 'none'
self.f2f_sim = ChamferSimilarity(axes=[3, 2])
self.visil_head = VideoComperator()
self.htanh = nn.Hardtanh()
self.v2v_sim = ChamferSimilarity(axes=[2, 1])
self.sim_criterion = SimilarityRegularizationLoss()
if pretrained:
if not (attention or binarization):
raise Exception('No pretrained model provided for the selected settings. '
'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
self.load_state_dict(
torch.hub.load_state_dict_from_url(
model_urls['dns_fg_{}_student'.format(self.fg_type)])['model'])
def get_network_name(self,):
return '{}_{}_student'.format(self.student_type, self.fg_type)
def frame_to_frame_similarity(self, query, target, query_mask=None, target_mask=None, batched=False):
d = target.shape[-1]
sim_mask = None
if batched:
sim = torch.einsum('biok,bjpk->biopj', query, target)
sim = self.f2f_sim(sim)
if query_mask is not None and target_mask is not None:
sim_mask = torch.einsum('bik,bjk->bij', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
else:
sim = torch.einsum('aiok,bjpk->aiopjb', query, target)
sim = self.f2f_sim(sim)
sim = rearrange(sim, 'a i j b -> (a b) i j')
if query_mask is not None and target_mask is not None:
sim_mask = torch.einsum('aik,bjk->aijb', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
sim_mask = rearrange(sim_mask, 'a i j b -> (a b) i j')
if self.fg_type == 'bin':
sim /= d
if sim_mask is not None:
sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
return sim, sim_mask
def calculate_video_similarity(self, query, target, query_mask=None, target_mask=None):
query, query_mask = check_dims(query, query_mask)
target, target_mask = check_dims(target, target_mask)
sim, sim_mask = self.similarity_matrix(query, target, query_mask, target_mask)
sim = self.v2v_sim(sim, sim_mask)
return sim.view(query.shape[0], target.shape[0])
def similarity_matrix(self, query, target, query_mask=None, target_mask=None):
query, query_mask = check_dims(query, query_mask)
target, target_mask = check_dims(target, target_mask)
sim, sim_mask = self.frame_to_frame_similarity(query, target, query_mask, target_mask)
sim, sim_mask = self.visil_head(sim, sim_mask)
return self.htanh(sim), sim_mask
def index_video(self, x, mask=None):
if self.fg_type == 'bin':
x = self.binarization(x)
elif self.fg_type == 'att':
x, a = self.attention(x)
if mask is not None:
x = x.masked_fill((1 - mask).bool().unsqueeze(-1).unsqueeze(-1), 0.0)
return x
def forward(self, anchors, positives, negatives,
anchors_masks, positive_masks, negative_masks):
pos_sim, pos_mask = self.frame_to_frame_similarity(
anchors, positives, anchors_masks, positive_masks, batched=True)
neg_sim, neg_mask = self.frame_to_frame_similarity(
anchors, negatives, anchors_masks, negative_masks, batched=True)
sim, sim_mask = torch.cat([pos_sim, neg_sim], 0), torch.cat([pos_mask, neg_mask], 0)
sim, sim_mask = self.visil_head(sim, sim_mask)
loss = self.sim_criterion(sim)
sim = self.htanh(sim)
sim = self.v2v_sim(sim, sim_mask)
pos_pair, neg_pair = torch.chunk(sim.unsqueeze(-1), 2, dim=0)
return pos_pair, neg_pair, loss

BIN
output_imgs/cg_student.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 KiB

BIN
output_imgs/cg_student_specifical.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
output_imgs/feature_extractor.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

BIN
output_imgs/fg_att_student.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

BIN
output_imgs/fg_bin_student.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

BIN
output_imgs/selector_att.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

BIN
output_imgs/selector_bin.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.4 KiB

4
requirements.txt

@@ -0,0 +1,4 @@
torch
einops
towhee
Pillow