diff --git a/DnS.png b/DnS.png
new file mode 100644
index 0000000..e05a23e
Binary files /dev/null and b/DnS.png differ
diff --git a/README.md b/README.md
index 6fed073..b7275fb 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,183 @@
-# distill-and-select
+# Video deduplication with Distill-and-Select
+
+*author: Chen Zhang*
+
+
+
+
+
+
+
+## Description
+
+This operator is made for the video deduplication task, based on [DnS: Distill-and-Select for Efficient and Accurate Video Indexing and Retrieval](https://arxiv.org/abs/2106.13266).
+Trained with a knowledge distillation approach on large, unlabelled datasets, DnS learns: a) Student Networks at different retrieval performance and computational efficiency trade-offs, and b) a Selection Network that, at test time, rapidly directs samples to the appropriate student to maintain both high retrieval performance and high computational efficiency.
+
+
+
+
+
+
+## Code Example
+
+Load a video from path './demo_video.flv' and decode it with the ffmpeg operator.
+
+Then use the distill_and_select operator to get the output of the specified model.
+
+The fine-grained student models return a 3-D output that keeps the temporal dimension, the coarse-grained student model returns a 1-D embedding representing the whole video, and the selector models return a scalar output.
+
+ *For feature_extractor model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='feature_extractor') \
+ .show()
+```
+
+
+
+ *For fg_att_student model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='fg_att_student') \
+ .show()
+```
+
+
+
+ *For fg_bin_student model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='fg_bin_student') \
+ .show()
+```
+
+
+
+ *For cg_student model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='cg_student') \
+ .show()
+```
+
+
+
+ *For selector_att model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='selector_att') \
+ .show()
+```
+
+
+
+ *For selector_bin model*:
+
+```python
+import towhee
+towhee.dc(['./demo_video.flv']) \
+ .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op(func=lambda x: [y for y in x]) \
+ .distill_and_select(model_name='selector_bin') \
+ .show()
+```
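+
+The output shapes described above can be checked by collecting the results instead of displaying them. A minimal sketch (assuming the same demo video and the DataCollection `to_list` helper), shown here for the fg_att_student model:
+
+```python
+import towhee
+out = towhee.dc(['./demo_video.flv']) \
+      .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+      .runas_op(func=lambda x: [y for y in x]) \
+      .distill_and_select(model_name='fg_att_student') \
+      .to_list()
+
+# one 512-d vector per region per sampled frame, e.g. (num_frames, 9, 512);
+# cg_student gives a 1-D video embedding and the selector models a scalar
+print(out[0].shape)
+```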
+
+
+
+
+
+*Write the same pipeline with explicit input/output name specifications, taking the cg_student model as an example:*
+
+```python
+import towhee
+towhee.dc['path'](['./demo_video.flv']) \
+ .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+ .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
+ .distill_and_select['frames', 'vec'](model_name='cg_student') \
+ .show()
+```
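+
+For deduplication, the coarse-grained embeddings of two videos can be compared directly: the cg_student embeddings are L2-normalized, so their inner product equals the cosine similarity. A minimal sketch, assuming two hypothetical local files './video_a.mp4' and './video_b.mp4' and the same `to_list` helper as above:
+
+```python
+import numpy as np
+import towhee
+embs = towhee.dc(['./video_a.mp4', './video_b.mp4']) \
+      .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+      .runas_op(func=lambda x: [y for y in x]) \
+      .distill_and_select(model_name='cg_student') \
+      .to_list()
+
+# inner product of L2-normalized embeddings = cosine similarity;
+# values close to 1.0 suggest near-duplicate videos
+similarity = float(np.dot(embs[0], embs[1]))
+print(similarity)
+```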
+
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***distill_and_select(model_name, \*\*kwargs)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+    One of the following:
+`feature_extractor`: Feature Extractor only,
+`fg_att_student`: Fine Grained Student with attention,
+`fg_bin_student`: Fine Grained Student with binarization,
+`cg_student`: Coarse Grained Student,
+`selector_att`: Selector Network with attention,
+`selector_bin`: Selector Network with binarization.
+
+
+ ***model_weight_path:*** *str*
+
+    Default is None, which downloads and uses the original pretrained weights.
+
+ ***feature_extractor:*** *Union[str, nn.Module]*
+
+    `None`, 'default', or a PyTorch nn.Module instance.
+`None` means the operator does not extract features from video data and expects embedding features as input.
+'default' means the operator uses the original pretrained feature extractor and can take video data as input.
+Or you can pass in an nn.Module instance as a custom feature extractor (see the sketch after the parameter list).
+Default is 'default'.
+
+    ***device:*** *str*
+
+    Model device, cpu or cuda. Default is None, which uses cuda if it is available, otherwise cpu.
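+
+*A minimal sketch of constructing the operator with a custom feature extractor. The module below is a hypothetical stand-in (any nn.Module mapping a batch of frames to per-frame, per-region features of dimension 512 should work), and the keyword argument is forwarded to the operator's constructor:*
+
+```python
+import torch
+import towhee
+from torch import nn
+
+class DummyExtractor(nn.Module):
+    # hypothetical stand-in: returns 9 random 512-d region vectors per frame
+    def forward(self, frames):
+        return torch.rand(frames.shape[0], 9, 512)
+
+towhee.dc(['./demo_video.flv']) \
+      .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \
+      .runas_op(func=lambda x: [y for y in x]) \
+      .distill_and_select(model_name='cg_student', feature_extractor=DummyExtractor()) \
+      .show()
+```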
+
+
+
+
+
+## Interface
+
+Get the output from your specified model.
+
+**Parameters:**
+
+ ***data:*** *List[towhee.types.VideoFrame]* or *Any*
+
+    The input type is List[VideoFrame] when using the default feature_extractor; otherwise, it is whatever type your custom feature_extractor expects (or precomputed features when feature_extractor is None).
+
+
+
+
+**Returns:** *numpy.ndarray*
+
+    The output of the specified model.
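+
+*When feature_extractor is set to None, the operator expects precomputed region-level features instead of video frames. A minimal sketch run on CPU, assuming the operator repository is on the Python path; the feature shape below is a hypothetical example:*
+
+```python
+import torch
+from distill_and_select import DistillAndSelect
+
+op = DistillAndSelect(model_name='cg_student', feature_extractor=None, device='cpu')
+
+# hypothetical precomputed features: (num_frames, num_regions, dims)
+features = torch.rand(30, 9, 512)
+vec = op(features)
+print(vec.shape)  # a 1-D video-level embedding, (1024,) with the default settings
+```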
+
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..519986b
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distill_and_select import DistillAndSelect
+
+
+def distill_and_select(model_name: str, **kwargs):
+ return DistillAndSelect(model_name, **kwargs)
+
diff --git a/demo_video.flv b/demo_video.flv
new file mode 100644
index 0000000..f577d0e
Binary files /dev/null and b/demo_video.flv differ
diff --git a/distill_and_select.py b/distill_and_select.py
new file mode 100644
index 0000000..07cc379
--- /dev/null
+++ b/distill_and_select.py
@@ -0,0 +1,112 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from typing import Union, Any
+from torchvision import transforms
+from towhee.operator.base import NNOperator
+from towhee import register
+from PIL import Image as PILImage
+from model.feature_extractor import FeatureExtractor
+from model.students import FineGrainedStudent, CoarseGrainedStudent
+from model.selector import SelectorNetwork
+from torch import nn
+
+
+@register(output_schema=['vec'])
+class DistillAndSelect(NNOperator):
+ """
+ DistillAndSelect
+ """
+
+ def __init__(self, model_name: str, model_weight_path: str = None,
+ feature_extractor: Union[str, nn.Module] = 'default', device: str = None):
+ """
+
+ Args:
+ model_name (`str`):
+                One of the following:
+ `feature_extractor`: Feature Extractor only,
+ `fg_att_student`: Fine Grained Student with attention,
+ `fg_bin_student`: Fine Grained Student with binarization,
+ `cg_student`: Coarse Grained Student,
+ `selector_att`: Selector Network with attention,
+ `selector_bin`: Selector Network with binarization.
+ model_weight_path (`str`):
+                Default is None, which downloads and uses the original pretrained weights.
+ feature_extractor (`Union[str, nn.Module]`):
+                `None`, 'default', or a PyTorch nn.Module instance.
+                `None` means the operator does not extract features from video data and expects embedding features as input.
+                'default' means the operator uses the original pretrained feature extractor and can take video data as input.
+                Alternatively, you can pass in an nn.Module instance as a custom feature extractor.
+                Default is 'default'.
+ device (`str`):
+ Model device, cpu or cuda.
+ """
+ super().__init__()
+ assert model_name in ['feature_extractor', 'fg_att_student', 'fg_bin_student', 'cg_student', 'selector_att',
+ 'selector_bin'], 'unsupported model.'
+ if device is None:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ self.device = device
+
+ self.model_name = model_name
+
+ self.feature_extractor = None
+ if feature_extractor == 'default':
+            self.feature_extractor = FeatureExtractor(dims=512).to(self.device).eval()
+ elif isinstance(feature_extractor, nn.Module):
+ self.feature_extractor = feature_extractor
+
+ self.model = None
+        # use the original pretrained weights unless a local weight path is given
+        pretrained = model_weight_path is None
+        if self.model_name == 'fg_att_student':
+            self.model = FineGrainedStudent(pretrained=pretrained, attention=True)
+        elif self.model_name == 'fg_bin_student':
+            self.model = FineGrainedStudent(pretrained=pretrained, binarization=True)
+        elif self.model_name == 'cg_student':
+            self.model = CoarseGrainedStudent(pretrained=pretrained)
+        elif self.model_name == 'selector_att':
+            self.model = SelectorNetwork(pretrained=pretrained, attention=True)
+        elif self.model_name == 'selector_bin':
+            self.model = SelectorNetwork(pretrained=pretrained, binarization=True)
+
+ if model_weight_path is not None:
+ self.model.load_state_dict(torch.load(model_weight_path))
+
+ if self.model is not None:
+            self.model.to(self.device).eval()
+
+ self.tfms = transforms.Compose([
+ transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.CenterCrop(256),
+ transforms.ToTensor(),
+ ])
+
+ def __call__(self, data: Any): # List[VideoFrame] when self.feature_extractor is not None
+ if self.feature_extractor is not None:
+ pil_img_list = []
+ for img in data:
+ pil_img = PILImage.fromarray(img, img.mode)
+ tfmed_img = self.tfms(pil_img).permute(1, 2, 0).unsqueeze(0)
+ pil_img_list.append(tfmed_img)
+ data = torch.concat(pil_img_list, dim=0) * 255
+ data = self.feature_extractor(data.to(self.device)).to(self.device)
+ if self.model_name == 'feature_extractor':
+ return data.cpu().detach().squeeze().numpy()
+ index_feature = self.model.index_video(data)
+ return index_feature.cpu().detach().squeeze().numpy()
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..ae697e1
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,18 @@
+from .layers import *
+from .losses import *
+from .similarities import *
+
+
+def check_dims(features, mask=None, axis=0):
+ if features.ndim == 4:
+ return features, mask
+ elif features.ndim == 3:
+ features = features.unsqueeze(axis)
+ if mask is not None:
+ mask = mask.unsqueeze(axis)
+ return features, mask
+ else:
+ raise Exception('Wrong shape of input video tensor. The shape of the tensor must be either '
+ '[N, T, R, D] or [T, R, D], where N is the batch size, T the number of frames, '
+ 'R the number of regions and D number of dimensions. '
+ 'Input video tensor has shape {}'.format(features.shape))
diff --git a/model/constraints.py b/model/constraints.py
new file mode 100644
index 0000000..6e97fa9
--- /dev/null
+++ b/model/constraints.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn.functional as F
+
+
+class L2Constrain(object):
+
+ def __init__(self, axis=-1, eps=1e-6):
+ self.axis = axis
+ self.eps = eps
+
+ def __call__(self, module):
+ if hasattr(module, 'weight'):
+ w = module.weight.data
+ module.weight.data = F.normalize(w, p=2, dim=self.axis, eps=self.eps)
+
+
+class NonNegConstrain(object):
+
+ def __init__(self, eps=1e-3):
+ self.eps = eps
+
+ def __call__(self, module):
+ if hasattr(module, 'weight'):
+ w = module.weight.data
+ module.weight.data = torch.clamp(w, min=self.eps)
diff --git a/model/feature_extractor.py b/model/feature_extractor.py
new file mode 100644
index 0000000..8eb900f
--- /dev/null
+++ b/model/feature_extractor.py
@@ -0,0 +1,46 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+from .layers import *
+
+
+class FeatureExtractor(nn.Module):
+
+    def __init__(self, network='resnet50', whitening=False, dims=3840):
+ super(FeatureExtractor, self).__init__()
+ self.normalizer = VideoNormalizer()
+
+ self.cnn = models.resnet50(pretrained=True)
+
+ self.rpool = RMAC()
+ self.layers = {'layer1': 28, 'layer2': 14, 'layer3': 6, 'layer4': 3}
+        if whitening or dims != 3840:
+ self.pca = PCA(dims)
+
+ def extract_region_vectors(self, x):
+ tensors = []
+ for nm, module in self.cnn._modules.items():
+ if nm not in {'avgpool', 'fc', 'classifier'}:
+ x = module(x).contiguous()
+ if nm in self.layers:
+ # region_vectors = self.rpool(x)
+ s = self.layers[nm]
+ region_vectors = F.max_pool2d(x, [s, s], int(np.ceil(s / 2)))
+ region_vectors = F.normalize(region_vectors, p=2, dim=1)
+ tensors.append(region_vectors)
+ for i in range(len(tensors)):
+ tensors[i] = F.normalize(F.adaptive_max_pool2d(tensors[i], tensors[-1].shape[2:]), p=2, dim=1)
+ x = torch.cat(tensors, 1)
+ x = x.view(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
+ x = F.normalize(x, p=2, dim=-1)
+ return x
+
+ def forward(self, x):
+ x = self.normalizer(x)
+ x = self.extract_region_vectors(x)
+ if hasattr(self, 'pca'):
+ x = self.pca(x)
+ return x
\ No newline at end of file
diff --git a/model/layers.py b/model/layers.py
new file mode 100644
index 0000000..e55ac73
--- /dev/null
+++ b/model/layers.py
@@ -0,0 +1,214 @@
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .constraints import L2Constrain
+
+
+class VideoNormalizer(nn.Module):
+
+ def __init__(self):
+ super(VideoNormalizer, self).__init__()
+ self.scale = nn.Parameter(torch.Tensor([255.]), requires_grad=False)
+ self.mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]), requires_grad=False)
+ self.std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]), requires_grad=False)
+
+ def forward(self, video):
+ video = video.float()
+ video = ((video / self.scale) - self.mean) / self.std
+ return video.permute(0, 3, 1, 2)
+
+
+class RMAC(nn.Module):
+
+ def __init__(self, L=[3]):
+ super(RMAC,self).__init__()
+ self.L = L
+
+ def forward(self, x):
+ return self.region_pooling(x, L=self.L)
+
+ def region_pooling(self, x, L=[3]):
+ ovr = 0.4 # desired overlap of neighboring regions
+ steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
+
+ W = x.shape[3]
+ H = x.shape[2]
+
+ w = min(W, H)
+ w2 = math.floor(w / 2.0 - 1)
+
+ b = (max(H, W) - w) / (steps - 1)
+ (tmp, idx) = torch.min(torch.abs(((w ** 2 - w * b) / w ** 2) - ovr), 0) # steps(idx) regions for long dimension
+
+ # region overplus per dimension
+ Wd = 0
+ Hd = 0
+ if H < W:
+ Wd = idx.item() + 1
+ elif H > W:
+ Hd = idx.item() + 1
+
+ vecs = []
+ for l in L:
+ wl = math.floor(2 * w / (l + 1))
+ wl2 = math.floor(wl / 2 - 1)
+
+ if l + Wd == 1:
+ b = 0
+ else:
+ b = (W - wl) / (l + Wd - 1)
+ cenW = torch.floor(wl2 + torch.tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates
+ if l + Hd == 1:
+ b = 0
+ else:
+ b = (H - wl) / (l + Hd - 1)
+ cenH = torch.floor(wl2 + torch.tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates
+
+ for i in cenH.long().tolist():
+ v = []
+ for j in cenW.long().tolist():
+ if wl == 0:
+ continue
+ R = x[:, :, i: i+wl, j: j+wl]
+ v.append(F.adaptive_max_pool2d(R, (1, 1)))
+ vecs.append(torch.cat(v, dim=3))
+ return torch.cat(vecs, dim=2)
+
+
+class PCA(nn.Module):
+
+ def __init__(self, n_components=None):
+ super(PCA, self).__init__()
+ pretrained_url = 'http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth'
+ white = torch.hub.load_state_dict_from_url(pretrained_url)
+ idx = torch.argsort(white['d'], descending=True)[: n_components]
+ d = white['d'][idx]
+ V = white['V'][:, idx]
+ D = torch.diag(1. / torch.sqrt(d + 1e-7))
+ self.mean = nn.Parameter(white['mean'], requires_grad=False)
+ self.DVt = nn.Parameter(torch.mm(D, V.T).T, requires_grad=False)
+
+ def forward(self, logits):
+ logits -= self.mean.expand_as(logits)
+ logits = torch.matmul(logits, self.DVt)
+ logits = F.normalize(logits, p=2, dim=-1)
+ return logits
+
+
+class Attention(nn.Module):
+
+ def __init__(self, dims, norm=False):
+ super(Attention, self).__init__()
+ self.norm = norm
+ if self.norm:
+ self.constrain = L2Constrain()
+ else:
+ self.transform = nn.Linear(dims, dims)
+ self.context_vector = nn.Linear(dims, 1, bias=False)
+ self.reset_parameters()
+
+ def forward(self, x):
+ if self.norm:
+ weights = self.context_vector(x)
+ weights = torch.add(torch.div(weights, 2.), .5)
+ else:
+ x_tr = torch.tanh(self.transform(x))
+ weights = self.context_vector(x_tr)
+ weights = torch.sigmoid(weights)
+ x = x * weights
+ return x, weights
+
+ def reset_parameters(self):
+ if self.norm:
+ nn.init.normal_(self.context_vector.weight)
+ self.constrain(self.context_vector)
+ else:
+ nn.init.xavier_uniform_(self.context_vector.weight)
+ nn.init.xavier_uniform_(self.transform.weight)
+ nn.init.zeros_(self.transform.bias)
+
+ def apply_contraint(self):
+ if self.norm:
+ self.constrain(self.context_vector)
+
+
+class BinarizationLayer(nn.Module):
+
+ def __init__(self, dims, bits=None, sigma=1e-6, ITQ_init=True):
+ super(BinarizationLayer, self).__init__()
+ self.sigma = sigma
+ if ITQ_init:
+ pretrained_url = 'https://mever.iti.gr/distill-and-select/models/itq_resnet50W_dns100k_1M.pth'
+ self.W = nn.Parameter(torch.hub.load_state_dict_from_url(pretrained_url)['proj'])
+ else:
+ if bits is None:
+ bits = dims
+ self.W = nn.Parameter(torch.rand(dims, bits))
+
+ def forward(self, x):
+ x = F.normalize(x, p=2, dim=-1)
+ x = torch.matmul(x, self.W)
+ if self.training:
+ x = torch.erf(x / np.sqrt(2 * self.sigma))
+ else:
+ x = torch.sign(x)
+ return x
+
+ def __repr__(self,):
+ return '{}(dims={}, bits={}, sigma={})'.format(
+ self.__class__.__name__, self.W.shape[0], self.W.shape[1], self.sigma)
+
+
+class NetVLAD(nn.Module):
+ """Acknowledgement to @lyakaap and @Nanne for providing their implementations"""
+
+ def __init__(self, dims, num_clusters, outdims=None):
+ super(NetVLAD, self).__init__()
+ self.num_clusters = num_clusters
+ self.dims = dims
+
+ self.centroids = nn.Parameter(torch.randn(num_clusters, dims) / math.sqrt(self.dims))
+ self.conv = nn.Conv2d(dims, num_clusters, kernel_size=1, bias=False)
+
+ if outdims is not None:
+ self.outdims = outdims
+ self.reduction_layer = nn.Linear(self.num_clusters * self.dims, self.outdims, bias=False)
+ else:
+ self.outdims = self.num_clusters * self.dims
+ self.norm = nn.LayerNorm(self.outdims)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ self.conv.weight = nn.Parameter(self.centroids.detach().clone().unsqueeze(-1).unsqueeze(-1))
+ if hasattr(self, 'reduction_layer'):
+ nn.init.normal_(self.reduction_layer.weight, std=1 / math.sqrt(self.num_clusters * self.dims))
+
+ def forward(self, x, mask=None):
+ N, C, T, R = x.shape
+
+ # soft-assignment
+ soft_assign = self.conv(x).view(N, self.num_clusters, -1)
+ soft_assign = F.softmax(soft_assign, dim=1).view(N, self.num_clusters, T, R)
+
+ x_flatten = x.view(N, C, -1)
+
+ vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
+ for cluster in range(self.num_clusters): # slower than non-looped, but lower memory usage
+ residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[cluster:cluster + 1, :].\
+ expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
+ residual = residual.view(N, C, T, R)
+ residual *= soft_assign[:, cluster:cluster + 1, :]
+ if mask is not None:
+ residual = residual.masked_fill((1 - mask.unsqueeze(1).unsqueeze(-1)).bool(), 0.0)
+ vlad[:, cluster:cluster+1, :] = residual.sum([-2, -1]).unsqueeze(1)
+
+ vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
+ vlad = vlad.view(x.size(0), -1) # flatten
+ vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
+
+ if hasattr(self, 'reduction_layer'):
+ vlad = self.reduction_layer(vlad)
+ return self.norm(vlad)
diff --git a/model/losses.py b/model/losses.py
new file mode 100644
index 0000000..dc9e1e6
--- /dev/null
+++ b/model/losses.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+
+class TripletLoss(nn.Module):
+
+ def __init__(self, gamma=1.0, similarity=True):
+ super(TripletLoss, self).__init__()
+ self.gamma = gamma
+ self.similarity = similarity
+
+ def forward(self, sim_pos, sim_neg):
+ if self.similarity:
+ loss = torch.clamp(sim_neg - sim_pos + self.gamma, min=0.)
+ else:
+ loss = torch.clamp(sim_pos - sim_neg + self.gamma, min=0.)
+ return loss.mean()
+
+
+class SimilarityRegularizationLoss(nn.Module):
+
+ def __init__(self, min_val=-1., max_val=1.):
+ super(SimilarityRegularizationLoss, self).__init__()
+ self.min_val = min_val
+ self.max_val = max_val
+
+ def forward(self, sim):
+ loss = torch.sum(torch.abs(torch.clamp(sim - self.min_val, max=0.)))
+ loss += torch.sum(torch.abs(torch.clamp(sim - self.max_val, min=0.)))
+ return loss
+
+ def __repr__(self,):
+ return '{}(min_val={}, max_val={})'.format(self.__class__.__name__, self.min_val, self.max_val)
diff --git a/model/selector.py b/model/selector.py
new file mode 100644
index 0000000..2b29e5c
--- /dev/null
+++ b/model/selector.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+
+from . import *
+
+
+model_urls = {
+ 'dns_selector_cg-fg_att': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_att.pth',
+ 'dns_selector_cg-fg_bin': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_bin.pth',
+}
+
+
+class MetadataModel(nn.Module):
+
+ def __init__(self,
+ input_size,
+ hidden_size=100,
+ num_layers=1
+ ):
+ super(MetadataModel, self).__init__()
+
+ model = [
+ nn.Linear(input_size, hidden_size, bias=False),
+ nn.BatchNorm1d(hidden_size),
+ nn.ReLU(),
+ nn.Dropout()
+ ]
+
+ for _ in range(num_layers):
+ model.extend([nn.Linear(hidden_size, hidden_size, bias=False),
+ nn.BatchNorm1d(hidden_size),
+ nn.ReLU(),
+ nn.Dropout()])
+
+ model.extend([nn.Linear(hidden_size, 1),
+ nn.Sigmoid()])
+ self.model = nn.Sequential(*model)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for m in self.model.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SelectorNetwork(nn.Module):
+
+ def __init__(self,
+ dims=512,
+ hidden_size=100,
+ num_layers=1,
+ attention=False,
+ binarization=False,
+ pretrained=False,
+ **kwargs
+ ):
+ super(SelectorNetwork, self).__init__()
+ self.attention = Attention(dims, norm=False)
+ self.visil_head = VideoComperator()
+ self.mlp = MetadataModel(3, hidden_size, num_layers)
+
+ if pretrained:
+ if not (attention or binarization):
+ raise Exception('No pretrained model provided for the selected settings. '
+ 'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
+ elif attention:
+ self.load_state_dict(
+ torch.hub.load_state_dict_from_url(
+ model_urls['dns_selector_cg-fg_att'])['model'])
+ elif binarization:
+ self.load_state_dict(
+ torch.hub.load_state_dict_from_url(
+ model_urls['dns_selector_cg-fg_bin'])['model'])
+
+ def get_network_name(self,):
+ return 'selector_network'
+
+ def index_video(self, x, mask=None):
+ x, mask = check_dims(x, mask)
+ sim = self.frame_to_frame_similarity(x)
+
+ sim_mask = None
+ if mask is not None:
+ sim_mask = torch.einsum("bik,bjk->bij", mask.unsqueeze(-1), mask.unsqueeze(-1))
+ sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
+
+ sim, sim_mask = self.visil_head(sim, sim_mask)
+
+ if sim_mask is not None:
+ sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
+ sim = torch.sum(sim, [1, 2]) / torch.sum(sim_mask, [1, 2])
+ else:
+ sim = torch.mean(sim, [1, 2])
+
+ return sim.unsqueeze(-1)
+
+ def frame_to_frame_similarity(self, x):
+ x, a = self.attention(x)
+ sim = torch.einsum("biok,bjpk->biopj", x, x)
+ return torch.mean(sim, [2, 3])
+
+ def forward(self, x):
+ return self.mlp(x)
diff --git a/model/similarities.py b/model/similarities.py
new file mode 100644
index 0000000..b138fe0
--- /dev/null
+++ b/model/similarities.py
@@ -0,0 +1,121 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TensorDot(nn.Module):
+
+ def __init__(self, pattern='iak,jbk->iabj', metric='cosine'):
+ super(TensorDot, self).__init__()
+ self.pattern = pattern
+ self.metric = metric
+
+ def forward(self, query, target):
+ if self.metric == 'cosine':
+ sim = torch.einsum(self.pattern, [query, target])
+ elif self.metric == 'euclidean':
+ sim = 1 - 2 * torch.einsum(self.pattern, [query, target])
+ elif self.metric == 'hamming':
+ sim = torch.einsum(self.pattern, query, target) / target.shape[-1]
+ return sim
+
+ def __repr__(self,):
+ return '{}(pattern={})'.format(self.__class__.__name__, self.pattern)
+
+
+class ChamferSimilarity(nn.Module):
+
+ def __init__(self, symmetric=False, axes=[1, 0]):
+ super(ChamferSimilarity, self).__init__()
+ self.axes = axes
+ if symmetric:
+ self.sim_fun = lambda x, m: self.symmetric_chamfer_similarity(x, mask=m, axes=axes)
+ else:
+ self.sim_fun = lambda x, m: self.chamfer_similarity(x, mask=m, max_axis=axes[0], mean_axis=axes[1])
+
+ def chamfer_similarity(self, s, mask=None, max_axis=1, mean_axis=0):
+ if mask is not None:
+ s = s.masked_fill((1 - mask).bool(), -np.inf)
+ s = torch.max(s, max_axis, keepdim=True)[0]
+ mask = torch.max(mask, max_axis, keepdim=True)[0]
+ s = s.masked_fill((1 - mask).bool(), 0.0)
+ s = torch.sum(s, mean_axis, keepdim=True)
+ s /= torch.sum(mask, mean_axis, keepdim=True)
+ else:
+ s = torch.max(s, max_axis, keepdim=True)[0]
+ s = torch.mean(s, mean_axis, keepdim=True)
+ return s.squeeze(max(max_axis, mean_axis)).squeeze(min(max_axis, mean_axis))
+
+ def symmetric_chamfer_similarity(self, s, mask=None, axes=[0, 1]):
+ return (self.chamfer_similarity(s, mask=mask, max_axis=axes[0], mean_axis=axes[1]) +
+ self.chamfer_similarity(s, mask=mask, max_axis=axes[1], mean_axis=axes[0])) / 2
+
+ def forward(self, s, mask=None):
+ return self.sim_fun(s, mask)
+
+ def __repr__(self,):
+ return '{}(max_axis={}, mean_axis={})'.format(self.__class__.__name__, self.axes[0], self.axes[1])
+
+
+class VideoComperator(nn.Module):
+
+ def __init__(self, in_channels=1, out_channels=1):
+ super(VideoComperator, self).__init__()
+
+ self.rpad1 = nn.ZeroPad2d(1)
+ self.conv1 = nn.Conv2d(in_channels, 32, 3)
+ self.pool1 = nn.MaxPool2d((2, 2), 2)
+
+ self.rpad2 = nn.ZeroPad2d(1)
+ self.conv2 = nn.Conv2d(32, 64, 3)
+ self.pool2 = nn.MaxPool2d((2, 2), 2)
+
+ self.rpad3 = nn.ZeroPad2d(1)
+ self.conv3 = nn.Conv2d(64, 128, 3)
+
+ self.fconv = nn.Conv2d(128, out_channels, 1)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, sim_matrix, mask=None):
+ if sim_matrix.ndim == 3:
+ sim_matrix = sim_matrix.unsqueeze(1)
+ elif sim_matrix.ndim != 4:
+            raise Exception('Input tensor to VideoComperator has to be 3- or 4-dimensional')
+
+ if mask is not None:
+ assert mask.shape[-2:] == sim_matrix.shape[-2:], 'Mask tensor must be of the same shape as similarity ' \
+ 'matrix in the last two dimensions. Mask shape is {} ' \
+ 'while similarity matrix is {}'.format(mask.shape[-2:],
+ sim_matrix.shape[-2:])
+ mask = mask.unsqueeze(1)
+
+ sim = self.rpad1(sim_matrix)
+ sim = self.conv1(sim)
+ if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
+ sim = F.relu(sim)
+ sim = self.pool1(sim)
+ if mask is not None: mask = self.pool1(mask)
+
+ sim = self.rpad2(sim)
+ sim = self.conv2(sim)
+ if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
+ sim = F.relu(sim)
+ sim = self.pool2(sim)
+ if mask is not None: mask = self.pool2(mask)
+
+ sim = self.rpad3(sim)
+ sim = self.conv3(sim)
+ if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0)
+ sim = F.relu(sim)
+
+ sim = self.fconv(sim)
+ return sim.squeeze(1), mask.squeeze(1) if mask is not None else None
diff --git a/model/students.py b/model/students.py
new file mode 100644
index 0000000..98f5688
--- /dev/null
+++ b/model/students.py
@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import *
+from einops import rearrange
+
+
+model_urls = {
+ 'dns_cg_student': 'https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth',
+ 'dns_fg_att_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_att_student.pth',
+ 'dns_fg_bin_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_bin_student.pth',
+}
+
+
+class CoarseGrainedStudent(nn.Module):
+
+ def __init__(self,
+ dims=512,
+ attention=True,
+ transformer=True,
+ transformer_heads=8,
+ transformer_feedforward_dims=2048,
+ transformer_layers=1,
+ netvlad=True,
+ netvlad_clusters=64,
+ netvlad_outdims=1024,
+ pretrained=False,
+ **kwargs
+ ):
+ super(CoarseGrainedStudent, self).__init__()
+ self.student_type = 'cg'
+
+ if attention:
+ self.attention = Attention(dims, norm=False)
+ if transformer:
+ encoder_layer = nn.TransformerEncoderLayer(dims,
+ transformer_heads,
+ transformer_feedforward_dims)
+ self.transformer = nn.TransformerEncoder(encoder_layer,
+ transformer_layers,
+ nn.LayerNorm(dims))
+ self.apply(self._init_weights)
+
+ if netvlad:
+ self.netvlad = NetVLAD(dims, netvlad_clusters, outdims=netvlad_outdims)
+
+ if pretrained:
+ self.load_state_dict(
+ torch.hub.load_state_dict_from_url(
+ model_urls['dns_cg_student'])['model'])
+
+ def get_network_name(self,):
+ return '{}_student'.format(self.student_type)
+
+ def calculate_video_similarity(self, query, target):
+ return torch.mm(query, torch.transpose(target, 0, 1))
+
+ def index_video(self, x, mask=None):
+ x, mask = check_dims(x, mask)
+
+ if hasattr(self, 'attention'):
+ x, a = self.attention(x)
+ x = torch.sum(x, 2)
+ x = F.normalize(x, p=2, dim=-1)
+
+ if hasattr(self, 'transformer'):
+ x = x.permute(1, 0, 2)
+ x = self.transformer(x, src_key_padding_mask=
+ (1 - mask).bool() if mask is not None else None)
+ x = x.permute(1, 0, 2)
+
+ if hasattr(self, 'netvlad'):
+ x = x.unsqueeze(2).permute(0, 3, 1, 2)
+ x = self.netvlad(x, mask=mask)
+ else:
+ if mask is not None:
+ x = x.masked_fill((1 - mask.unsqueeze(-1)).bool(), 0.0)
+ x = torch.sum(x, 1) / torch.sum(mask, 1, keepdim=True)
+ else:
+ x = torch.mean(x, 1)
+ return F.normalize(x, p=2, dim=-1)
+
+ def forward(self, anchors, positives, negatives,
+ anchors_masks=None, positive_masks=None, negative_masks=None):
+ pos_pairs = torch.sum(anchors * positives, 1, keepdim=True)
+ neg_pairs = torch.sum(anchors * negatives, 1, keepdim=True)
+ return pos_pairs, neg_pairs, None
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+
+class FineGrainedStudent(nn.Module):
+
+ def __init__(self,
+ dims=512,
+ attention=False,
+ binarization=False,
+ pretrained=False,
+ **kwargs
+ ):
+ super(FineGrainedStudent, self).__init__()
+ self.student_type = 'fg'
+ if attention and binarization:
+ raise Exception('Can\'t use \'attention=True\' and \'binarization=True\' at the same time. '
+ 'Select one of the two options.')
+ elif binarization:
+ self.fg_type = 'bin'
+ self.binarization = BinarizationLayer(dims)
+ elif attention:
+ self.fg_type = 'att'
+ self.attention = Attention(dims, norm=False)
+ else:
+ self.fg_type = 'none'
+
+ self.f2f_sim = ChamferSimilarity(axes=[3, 2])
+
+ self.visil_head = VideoComperator()
+ self.htanh = nn.Hardtanh()
+
+ self.v2v_sim = ChamferSimilarity(axes=[2, 1])
+
+ self.sim_criterion = SimilarityRegularizationLoss()
+
+ if pretrained:
+ if not (attention or binarization):
+ raise Exception('No pretrained model provided for the selected settings. '
+ 'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.')
+ self.load_state_dict(
+ torch.hub.load_state_dict_from_url(
+ model_urls['dns_fg_{}_student'.format(self.fg_type)])['model'])
+
+ def get_network_name(self,):
+ return '{}_{}_student'.format(self.student_type, self.fg_type)
+
+ def frame_to_frame_similarity(self, query, target, query_mask=None, target_mask=None, batched=False):
+ d = target.shape[-1]
+ sim_mask = None
+ if batched:
+ sim = torch.einsum('biok,bjpk->biopj', query, target)
+ sim = self.f2f_sim(sim)
+ if query_mask is not None and target_mask is not None:
+ sim_mask = torch.einsum('bik,bjk->bij', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
+ else:
+ sim = torch.einsum('aiok,bjpk->aiopjb', query, target)
+ sim = self.f2f_sim(sim)
+ sim = rearrange(sim, 'a i j b -> (a b) i j')
+ if query_mask is not None and target_mask is not None:
+ sim_mask = torch.einsum('aik,bjk->aijb', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1))
+ sim_mask = rearrange(sim_mask, 'a i j b -> (a b) i j')
+ if self.fg_type == 'bin':
+ sim /= d
+ if sim_mask is not None:
+ sim = sim.masked_fill((1 - sim_mask).bool(), 0.0)
+ return sim, sim_mask
+
+ def calculate_video_similarity(self, query, target, query_mask=None, target_mask=None):
+ query, query_mask = check_dims(query, query_mask)
+ target, target_mask = check_dims(target, target_mask)
+
+ sim, sim_mask = self.similarity_matrix(query, target, query_mask, target_mask)
+ sim = self.v2v_sim(sim, sim_mask)
+
+ return sim.view(query.shape[0], target.shape[0])
+
+ def similarity_matrix(self, query, target, query_mask=None, target_mask=None):
+ query, query_mask = check_dims(query, query_mask)
+ target, target_mask = check_dims(target, target_mask)
+
+ sim, sim_mask = self.frame_to_frame_similarity(query, target, query_mask, target_mask)
+ sim, sim_mask = self.visil_head(sim, sim_mask)
+ return self.htanh(sim), sim_mask
+
+ def index_video(self, x, mask=None):
+ if self.fg_type == 'bin':
+ x = self.binarization(x)
+ elif self.fg_type == 'att':
+ x, a = self.attention(x)
+ if mask is not None:
+ x = x.masked_fill((1 - mask).bool().unsqueeze(-1).unsqueeze(-1), 0.0)
+ return x
+
+ def forward(self, anchors, positives, negatives,
+ anchors_masks, positive_masks, negative_masks):
+ pos_sim, pos_mask = self.frame_to_frame_similarity(
+ anchors, positives, anchors_masks, positive_masks, batched=True)
+ neg_sim, neg_mask = self.frame_to_frame_similarity(
+ anchors, negatives, anchors_masks, negative_masks, batched=True)
+ sim, sim_mask = torch.cat([pos_sim, neg_sim], 0), torch.cat([pos_mask, neg_mask], 0)
+
+ sim, sim_mask = self.visil_head(sim, sim_mask)
+ loss = self.sim_criterion(sim)
+ sim = self.htanh(sim)
+ sim = self.v2v_sim(sim, sim_mask)
+
+ pos_pair, neg_pair = torch.chunk(sim.unsqueeze(-1), 2, dim=0)
+ return pos_pair, neg_pair, loss
diff --git a/output_imgs/cg_student.png b/output_imgs/cg_student.png
new file mode 100644
index 0000000..b2bbacf
Binary files /dev/null and b/output_imgs/cg_student.png differ
diff --git a/output_imgs/cg_student_specifical.png b/output_imgs/cg_student_specifical.png
new file mode 100644
index 0000000..e1b355c
Binary files /dev/null and b/output_imgs/cg_student_specifical.png differ
diff --git a/output_imgs/feature_extractor.png b/output_imgs/feature_extractor.png
new file mode 100644
index 0000000..2bf1241
Binary files /dev/null and b/output_imgs/feature_extractor.png differ
diff --git a/output_imgs/fg_att_student.png b/output_imgs/fg_att_student.png
new file mode 100644
index 0000000..451ff4c
Binary files /dev/null and b/output_imgs/fg_att_student.png differ
diff --git a/output_imgs/fg_bin_student.png b/output_imgs/fg_bin_student.png
new file mode 100644
index 0000000..22536c1
Binary files /dev/null and b/output_imgs/fg_bin_student.png differ
diff --git a/output_imgs/selector_att.png b/output_imgs/selector_att.png
new file mode 100644
index 0000000..7e17830
Binary files /dev/null and b/output_imgs/selector_att.png differ
diff --git a/output_imgs/selector_bin.png b/output_imgs/selector_bin.png
new file mode 100644
index 0000000..9c9de9c
Binary files /dev/null and b/output_imgs/selector_bin.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2fc395a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+einops
+towhee
+torchvision
+numpy
+Pillow
\ No newline at end of file