diff --git a/DnS.png b/DnS.png new file mode 100644 index 0000000..e05a23e Binary files /dev/null and b/DnS.png differ diff --git a/README.md b/README.md index 6fed073..b7275fb 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,183 @@ -# distill-and-select +# Video deduplication with Distill-and-Select + +*author: Chen Zhang* + + +
+ + + + +## Description + +This operator is made for the video deduplication task and is based on [DnS: Distill-and-Select for Efficient and Accurate Video Indexing and Retrieval](https://arxiv.org/abs/2106.13266). +Trained with knowledge distillation on large, unlabelled datasets, DnS learns: a) Student Networks at different retrieval-performance/computational-efficiency trade-offs and b) a Selection Network that, at test time, rapidly directs samples to the appropriate student so that both high retrieval performance and high computational efficiency are maintained. + +![](DnS.png) + +
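+*The selection logic above can be illustrated with a toy sketch. This is not part of this operator's API; the function name, the confidence value and the 0.5 threshold below are purely illustrative. It only shows the routing idea: keep the cheap coarse-grained similarity when the selector is confident, otherwise fall back to the slower fine-grained student.* + +```python +def dns_route(coarse_sim: float, selector_confidence: float, fine_sim_fn, threshold: float = 0.5) -> float: +    # keep the cheap coarse-grained score when the selector is confident enough +    if selector_confidence >= threshold: +        return coarse_sim +    # otherwise re-rank the pair with the (slower) fine-grained student +    return fine_sim_fn() + +# toy usage with stand-in values +print(dns_route(coarse_sim=0.83, selector_confidence=0.9, fine_sim_fn=lambda: 0.79)) +``` + +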
+ + +## Code Example + +Load a video from the path './demo_video.flv' and decode it with the ffmpeg operator. + +Then use the distill_and_select operator to get the output of the specified model. + +The fine-grained student models return a 3-D output that keeps the temporal dimension, the coarse-grained student returns a 1-D embedding representing the whole video, and the selector models return a scalar. + + *For feature_extractor model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='feature_extractor') \ +    .show() +``` +![](output_imgs/feature_extractor.png) + + + *For fg_att_student model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='fg_att_student') \ +    .show() +``` +![](output_imgs/fg_att_student.png) + + + *For fg_bin_student model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='fg_bin_student') \ +    .show() +``` +![](output_imgs/fg_bin_student.png) + + + *For cg_student model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='cg_student') \ +    .show() +``` +![](output_imgs/cg_student.png) + + + *For selector_att model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='selector_att') \ +    .show() +``` +![](output_imgs/selector_att.png) + + + *For selector_bin model*: + +```python +import towhee +towhee.dc(['./demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='selector_bin') \ +    .show() +``` +![](output_imgs/selector_bin.png) + + + + +*The same pipeline can also be written with explicit input/output schema specifications, taking the cg_student model as an example:* + +```python +import towhee +towhee.dc['path'](['./demo_video.flv']) \ +    .video_decode.ffmpeg['path', 'frames'](start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \ +    .distill_and_select['frames', 'vec'](model_name='cg_student') \ +    .show() +``` +![](output_imgs/cg_student_specifical.png) + +
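+*For deduplication itself, the embeddings returned by the operator can then be compared directly. The snippet below is a minimal sketch rather than part of this operator: it assumes the coarse-grained student (whose video-level embeddings are L2-normalized, so the inner product is a cosine similarity), assumes both video paths exist, and uses 0.9 only as an illustrative decision threshold.* + +```python +import numpy as np +import towhee + +# embed two videos with the coarse-grained student +embs = towhee.dc(['./demo_video.flv', './demo_video.flv']) \ +    .video_decode.ffmpeg(start_time=0.0, end_time=1000.0, sample_type='time_step_sample', args={'time_step': 1}) \ +    .runas_op(func=lambda x: [y for y in x]) \ +    .distill_and_select(model_name='cg_student') \ +    .to_list() + +similarity = float(np.dot(embs[0], embs[1]))  # cosine similarity of the two video embeddings +print('duplicate' if similarity > 0.9 else 'distinct')  # 0.9 is an illustrative threshold +``` + +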
+ + + +## Factory Constructor + +Create the operator via the following factory method: + +***distill_and_select(model_name, \*\*kwargs)*** + +**Parameters:** + +​ ***model_name:*** *str* + +​ Can be one of: +`feature_extractor`: Feature Extractor only, +`fg_att_student`: Fine Grained Student with attention, +`fg_bin_student`: Fine Grained Student with binarization, +`cg_student`: Coarse Grained Student, +`selector_att`: Selector Network with attention, +`selector_bin`: Selector Network with binarization. + + +​ ***model_weight_path:*** *str* + +​ Default is None, which downloads and uses the original pretrained weights. + +​ ***feature_extractor:*** *Union[str, nn.Module]* + +​ `None`, 'default' or a PyTorch nn.Module instance. +`None` means the operator does not extract features from video data; it expects pre-extracted embedding features as input. +'default' means the original pretrained feature-extraction weights are used, so the operator can take video data as input. +You can also pass in an nn.Module instance as a custom feature extractor. +Default is `default`. + +​ ***device:*** *str* +​ Model device, 'cpu' or 'cuda'. Default is None, which selects 'cuda' if it is available and 'cpu' otherwise. + +
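+*Outside of a pipeline, the operator class can also be constructed directly with these parameters. The snippet below is a minimal sketch: it assumes the files of this repository are importable from the working directory, and `frames` is a hypothetical list of decoded VideoFrame objects.* + +```python +from distill_and_select import DistillAndSelect + +# build the coarse-grained student on CPU; pretrained weights are downloaded automatically +op = DistillAndSelect(model_name='cg_student', device='cpu') + +# emb = op(frames)  # -> 1-D numpy.ndarray embedding for the whole video +``` + +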
+ + + +## Interface + +Get the output of the specified model. + +**Parameters:** + +​ ***data:*** *List[towhee.types.VideoFrame]* or *Any* + +​ The input type is List[VideoFrame] when using the default feature_extractor; otherwise it is whatever input type your custom feature_extractor expects. + + + + +**Returns:** *numpy.ndarray* + +​ The output of the specified model. + + + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..519986b --- /dev/null +++ b/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .distill_and_select import DistillAndSelect + + +def distill_and_select(model_name: str, **kwargs): +    return DistillAndSelect(model_name, **kwargs) + diff --git a/demo_video.flv b/demo_video.flv new file mode 100644 index 0000000..f577d0e Binary files /dev/null and b/demo_video.flv differ diff --git a/distill_and_select.py b/distill_and_select.py new file mode 100644 index 0000000..07cc379 --- /dev/null +++ b/distill_and_select.py @@ -0,0 +1,112 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Union, Any +from torchvision import transforms +from towhee.operator.base import NNOperator +from towhee import register +from PIL import Image as PILImage +from model.feature_extractor import FeatureExtractor +from model.students import FineGrainedStudent, CoarseGrainedStudent +from model.selector import SelectorNetwork +from torch import nn + + +@register(output_schema=['vec']) +class DistillAndSelect(NNOperator): +    """ +    DistillAndSelect +    """ + +    def __init__(self, model_name: str, model_weight_path: str = None, +                 feature_extractor: Union[str, nn.Module] = 'default', device: str = None): +        """ + +        Args: +            model_name (`str`): +                Can be one of: +                `feature_extractor`: Feature Extractor only, +                `fg_att_student`: Fine Grained Student with attention, +                `fg_bin_student`: Fine Grained Student with binarization, +                `cg_student`: Coarse Grained Student, +                `selector_att`: Selector Network with attention, +                `selector_bin`: Selector Network with binarization. +            model_weight_path (`str`): +                Default is None, which downloads and uses the original pretrained weights. +            feature_extractor (`Union[str, nn.Module]`): +                `None`, 'default' or a PyTorch nn.Module instance. +                `None` means the operator does not extract features from video data; it expects pre-extracted embedding features as input.
+                'default' means the original pretrained feature-extraction weights are used, so the operator can take video data as input. +                You can also pass in an nn.Module instance as a custom feature extractor. +                Default is `default`. +            device (`str`): +                Model device, cpu or cuda. Default is None, which selects cuda if it is available. +        """ +        super().__init__() +        assert model_name in ['feature_extractor', 'fg_att_student', 'fg_bin_student', 'cg_student', 'selector_att', +                              'selector_bin'], 'unsupported model.' +        if device is None: +            self.device = "cuda" if torch.cuda.is_available() else "cpu" +        else: +            self.device = device + +        self.model_name = model_name + +        self.feature_extractor = None +        if feature_extractor == 'default': +            self.feature_extractor = FeatureExtractor(dims=512).to(self.device).eval() +        elif isinstance(feature_extractor, nn.Module): +            self.feature_extractor = feature_extractor + +        self.model = None +        pretrained = True if model_weight_path is None else None +        if self.model_name == 'fg_att_student': +            self.model = FineGrainedStudent(pretrained=pretrained, attention=True) +        elif self.model_name == 'fg_bin_student': +            self.model = FineGrainedStudent(pretrained=pretrained, binarization=True) + +        elif self.model_name == 'cg_student': +            self.model = CoarseGrainedStudent(pretrained=pretrained) + +        elif self.model_name == 'selector_att': +            self.model = SelectorNetwork(pretrained=pretrained, attention=True) +        elif self.model_name == 'selector_bin': +            self.model = SelectorNetwork(pretrained=pretrained, binarization=True) + +        if model_weight_path is not None: +            self.model.load_state_dict(torch.load(model_weight_path)) + +        if self.model is not None: +            self.model.to(self.device).eval() + +        self.tfms = transforms.Compose([ +            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), +            transforms.CenterCrop(256), +            transforms.ToTensor(), +        ]) + +    def __call__(self, data: Any):  # List[VideoFrame] when self.feature_extractor is not None +        if self.feature_extractor is not None: +            pil_img_list = [] +            for img in data: +                pil_img = PILImage.fromarray(img, img.mode) +                tfmed_img = self.tfms(pil_img).permute(1, 2, 0).unsqueeze(0) +                pil_img_list.append(tfmed_img) +            data = torch.concat(pil_img_list, dim=0) * 255 +            data = self.feature_extractor(data.to(self.device)).to(self.device) +        if self.model_name == 'feature_extractor': +            return data.cpu().detach().squeeze().numpy() +        index_feature = self.model.index_video(data) +        return index_feature.cpu().detach().squeeze().numpy() diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..ae697e1 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,18 @@ +from .layers import * +from .losses import * +from .similarities import * + + +def check_dims(features, mask=None, axis=0): +    if features.ndim == 4: +        return features, mask +    elif features.ndim == 3: +        features = features.unsqueeze(axis) +        if mask is not None: +            mask = mask.unsqueeze(axis) +        return features, mask +    else: +        raise Exception('Wrong shape of input video tensor. The shape of the tensor must be either ' +                        '[N, T, R, D] or [T, R, D], where N is the batch size, T the number of frames, ' +                        'R the number of regions and D number of dimensions. 
' + 'Input video tensor has shape {}'.format(features.shape)) diff --git a/model/constraints.py b/model/constraints.py new file mode 100644 index 0000000..6e97fa9 --- /dev/null +++ b/model/constraints.py @@ -0,0 +1,25 @@ +import torch +import torch.nn.functional as F + + +class L2Constrain(object): + + def __init__(self, axis=-1, eps=1e-6): + self.axis = axis + self.eps = eps + + def __call__(self, module): + if hasattr(module, 'weight'): + w = module.weight.data + module.weight.data = F.normalize(w, p=2, dim=self.axis, eps=self.eps) + + +class NonNegConstrain(object): + + def __init__(self, eps=1e-3): + self.eps = eps + + def __call__(self, module): + if hasattr(module, 'weight'): + w = module.weight.data + module.weight.data = torch.clamp(w, min=self.eps) diff --git a/model/feature_extractor.py b/model/feature_extractor.py new file mode 100644 index 0000000..8eb900f --- /dev/null +++ b/model/feature_extractor.py @@ -0,0 +1,46 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +from .layers import * + + +class FeatureExtractor(nn.Module): + + def __init__(self, network='resnet50', whiteninig=False, dims=3840): + super(FeatureExtractor, self).__init__() + self.normalizer = VideoNormalizer() + + self.cnn = models.resnet50(pretrained=True) + + self.rpool = RMAC() + self.layers = {'layer1': 28, 'layer2': 14, 'layer3': 6, 'layer4': 3} + if whiteninig or dims != 3840: + self.pca = PCA(dims) + + def extract_region_vectors(self, x): + tensors = [] + for nm, module in self.cnn._modules.items(): + if nm not in {'avgpool', 'fc', 'classifier'}: + x = module(x).contiguous() + if nm in self.layers: + # region_vectors = self.rpool(x) + s = self.layers[nm] + region_vectors = F.max_pool2d(x, [s, s], int(np.ceil(s / 2))) + region_vectors = F.normalize(region_vectors, p=2, dim=1) + tensors.append(region_vectors) + for i in range(len(tensors)): + tensors[i] = F.normalize(F.adaptive_max_pool2d(tensors[i], tensors[-1].shape[2:]), p=2, dim=1) + x = torch.cat(tensors, 1) + x = x.view(x.shape[0], x.shape[1], -1).permute(0, 2, 1) + x = F.normalize(x, p=2, dim=-1) + return x + + def forward(self, x): + x = self.normalizer(x) + x = self.extract_region_vectors(x) + if hasattr(self, 'pca'): + x = self.pca(x) + return x \ No newline at end of file diff --git a/model/layers.py b/model/layers.py new file mode 100644 index 0000000..e55ac73 --- /dev/null +++ b/model/layers.py @@ -0,0 +1,214 @@ +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from .constraints import L2Constrain + + +class VideoNormalizer(nn.Module): + + def __init__(self): + super(VideoNormalizer, self).__init__() + self.scale = nn.Parameter(torch.Tensor([255.]), requires_grad=False) + self.mean = nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]), requires_grad=False) + self.std = nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]), requires_grad=False) + + def forward(self, video): + video = video.float() + video = ((video / self.scale) - self.mean) / self.std + return video.permute(0, 3, 1, 2) + + +class RMAC(nn.Module): + + def __init__(self, L=[3]): + super(RMAC,self).__init__() + self.L = L + + def forward(self, x): + return self.region_pooling(x, L=self.L) + + def region_pooling(self, x, L=[3]): + ovr = 0.4 # desired overlap of neighboring regions + steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension + + W = x.shape[3] + H = x.shape[2] + + w = min(W, H) + w2 = math.floor(w / 2.0 - 1) + + 
b = (max(H, W) - w) / (steps - 1) + (tmp, idx) = torch.min(torch.abs(((w ** 2 - w * b) / w ** 2) - ovr), 0) # steps(idx) regions for long dimension + + # region overplus per dimension + Wd = 0 + Hd = 0 + if H < W: + Wd = idx.item() + 1 + elif H > W: + Hd = idx.item() + 1 + + vecs = [] + for l in L: + wl = math.floor(2 * w / (l + 1)) + wl2 = math.floor(wl / 2 - 1) + + if l + Wd == 1: + b = 0 + else: + b = (W - wl) / (l + Wd - 1) + cenW = torch.floor(wl2 + torch.tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates + if l + Hd == 1: + b = 0 + else: + b = (H - wl) / (l + Hd - 1) + cenH = torch.floor(wl2 + torch.tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates + + for i in cenH.long().tolist(): + v = [] + for j in cenW.long().tolist(): + if wl == 0: + continue + R = x[:, :, i: i+wl, j: j+wl] + v.append(F.adaptive_max_pool2d(R, (1, 1))) + vecs.append(torch.cat(v, dim=3)) + return torch.cat(vecs, dim=2) + + +class PCA(nn.Module): + + def __init__(self, n_components=None): + super(PCA, self).__init__() + pretrained_url = 'http://ndd.iti.gr/visil/pca_resnet50_vcdb_1M.pth' + white = torch.hub.load_state_dict_from_url(pretrained_url) + idx = torch.argsort(white['d'], descending=True)[: n_components] + d = white['d'][idx] + V = white['V'][:, idx] + D = torch.diag(1. / torch.sqrt(d + 1e-7)) + self.mean = nn.Parameter(white['mean'], requires_grad=False) + self.DVt = nn.Parameter(torch.mm(D, V.T).T, requires_grad=False) + + def forward(self, logits): + logits -= self.mean.expand_as(logits) + logits = torch.matmul(logits, self.DVt) + logits = F.normalize(logits, p=2, dim=-1) + return logits + + +class Attention(nn.Module): + + def __init__(self, dims, norm=False): + super(Attention, self).__init__() + self.norm = norm + if self.norm: + self.constrain = L2Constrain() + else: + self.transform = nn.Linear(dims, dims) + self.context_vector = nn.Linear(dims, 1, bias=False) + self.reset_parameters() + + def forward(self, x): + if self.norm: + weights = self.context_vector(x) + weights = torch.add(torch.div(weights, 2.), .5) + else: + x_tr = torch.tanh(self.transform(x)) + weights = self.context_vector(x_tr) + weights = torch.sigmoid(weights) + x = x * weights + return x, weights + + def reset_parameters(self): + if self.norm: + nn.init.normal_(self.context_vector.weight) + self.constrain(self.context_vector) + else: + nn.init.xavier_uniform_(self.context_vector.weight) + nn.init.xavier_uniform_(self.transform.weight) + nn.init.zeros_(self.transform.bias) + + def apply_contraint(self): + if self.norm: + self.constrain(self.context_vector) + + +class BinarizationLayer(nn.Module): + + def __init__(self, dims, bits=None, sigma=1e-6, ITQ_init=True): + super(BinarizationLayer, self).__init__() + self.sigma = sigma + if ITQ_init: + pretrained_url = 'https://mever.iti.gr/distill-and-select/models/itq_resnet50W_dns100k_1M.pth' + self.W = nn.Parameter(torch.hub.load_state_dict_from_url(pretrained_url)['proj']) + else: + if bits is None: + bits = dims + self.W = nn.Parameter(torch.rand(dims, bits)) + + def forward(self, x): + x = F.normalize(x, p=2, dim=-1) + x = torch.matmul(x, self.W) + if self.training: + x = torch.erf(x / np.sqrt(2 * self.sigma)) + else: + x = torch.sign(x) + return x + + def __repr__(self,): + return '{}(dims={}, bits={}, sigma={})'.format( + self.__class__.__name__, self.W.shape[0], self.W.shape[1], self.sigma) + + +class NetVLAD(nn.Module): + """Acknowledgement to @lyakaap and @Nanne for providing their implementations""" + + def __init__(self, dims, num_clusters, 
outdims=None): + super(NetVLAD, self).__init__() + self.num_clusters = num_clusters + self.dims = dims + + self.centroids = nn.Parameter(torch.randn(num_clusters, dims) / math.sqrt(self.dims)) + self.conv = nn.Conv2d(dims, num_clusters, kernel_size=1, bias=False) + + if outdims is not None: + self.outdims = outdims + self.reduction_layer = nn.Linear(self.num_clusters * self.dims, self.outdims, bias=False) + else: + self.outdims = self.num_clusters * self.dims + self.norm = nn.LayerNorm(self.outdims) + self.reset_parameters() + + def reset_parameters(self): + self.conv.weight = nn.Parameter(self.centroids.detach().clone().unsqueeze(-1).unsqueeze(-1)) + if hasattr(self, 'reduction_layer'): + nn.init.normal_(self.reduction_layer.weight, std=1 / math.sqrt(self.num_clusters * self.dims)) + + def forward(self, x, mask=None): + N, C, T, R = x.shape + + # soft-assignment + soft_assign = self.conv(x).view(N, self.num_clusters, -1) + soft_assign = F.softmax(soft_assign, dim=1).view(N, self.num_clusters, T, R) + + x_flatten = x.view(N, C, -1) + + vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device) + for cluster in range(self.num_clusters): # slower than non-looped, but lower memory usage + residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - self.centroids[cluster:cluster + 1, :].\ + expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) + residual = residual.view(N, C, T, R) + residual *= soft_assign[:, cluster:cluster + 1, :] + if mask is not None: + residual = residual.masked_fill((1 - mask.unsqueeze(1).unsqueeze(-1)).bool(), 0.0) + vlad[:, cluster:cluster+1, :] = residual.sum([-2, -1]).unsqueeze(1) + + vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization + vlad = vlad.view(x.size(0), -1) # flatten + vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize + + if hasattr(self, 'reduction_layer'): + vlad = self.reduction_layer(vlad) + return self.norm(vlad) diff --git a/model/losses.py b/model/losses.py new file mode 100644 index 0000000..dc9e1e6 --- /dev/null +++ b/model/losses.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + + +class TripletLoss(nn.Module): + + def __init__(self, gamma=1.0, similarity=True): + super(TripletLoss, self).__init__() + self.gamma = gamma + self.similarity = similarity + + def forward(self, sim_pos, sim_neg): + if self.similarity: + loss = torch.clamp(sim_neg - sim_pos + self.gamma, min=0.) + else: + loss = torch.clamp(sim_pos - sim_neg + self.gamma, min=0.) + return loss.mean() + + +class SimilarityRegularizationLoss(nn.Module): + + def __init__(self, min_val=-1., max_val=1.): + super(SimilarityRegularizationLoss, self).__init__() + self.min_val = min_val + self.max_val = max_val + + def forward(self, sim): + loss = torch.sum(torch.abs(torch.clamp(sim - self.min_val, max=0.))) + loss += torch.sum(torch.abs(torch.clamp(sim - self.max_val, min=0.))) + return loss + + def __repr__(self,): + return '{}(min_val={}, max_val={})'.format(self.__class__.__name__, self.min_val, self.max_val) diff --git a/model/selector.py b/model/selector.py new file mode 100644 index 0000000..2b29e5c --- /dev/null +++ b/model/selector.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +from . 
import * + + +model_urls = { + 'dns_selector_cg-fg_att': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_att.pth', + 'dns_selector_cg-fg_bin': 'https://mever.iti.gr/distill-and-select/models/dns_selector_cg-fg_bin.pth', +} + + +class MetadataModel(nn.Module): + + def __init__(self, + input_size, + hidden_size=100, + num_layers=1 + ): + super(MetadataModel, self).__init__() + + model = [ + nn.Linear(input_size, hidden_size, bias=False), + nn.BatchNorm1d(hidden_size), + nn.ReLU(), + nn.Dropout() + ] + + for _ in range(num_layers): + model.extend([nn.Linear(hidden_size, hidden_size, bias=False), + nn.BatchNorm1d(hidden_size), + nn.ReLU(), + nn.Dropout()]) + + model.extend([nn.Linear(hidden_size, 1), + nn.Sigmoid()]) + self.model = nn.Sequential(*model) + self.reset_parameters() + + def reset_parameters(self): + for m in self.model.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + + def forward(self, x): + return self.model(x) + + +class SelectorNetwork(nn.Module): + + def __init__(self, + dims=512, + hidden_size=100, + num_layers=1, + attention=False, + binarization=False, + pretrained=False, + **kwargs + ): + super(SelectorNetwork, self).__init__() + self.attention = Attention(dims, norm=False) + self.visil_head = VideoComperator() + self.mlp = MetadataModel(3, hidden_size, num_layers) + + if pretrained: + if not (attention or binarization): + raise Exception('No pretrained model provided for the selected settings. ' + 'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.') + elif attention: + self.load_state_dict( + torch.hub.load_state_dict_from_url( + model_urls['dns_selector_cg-fg_att'])['model']) + elif binarization: + self.load_state_dict( + torch.hub.load_state_dict_from_url( + model_urls['dns_selector_cg-fg_bin'])['model']) + + def get_network_name(self,): + return 'selector_network' + + def index_video(self, x, mask=None): + x, mask = check_dims(x, mask) + sim = self.frame_to_frame_similarity(x) + + sim_mask = None + if mask is not None: + sim_mask = torch.einsum("bik,bjk->bij", mask.unsqueeze(-1), mask.unsqueeze(-1)) + sim = sim.masked_fill((1 - sim_mask).bool(), 0.0) + + sim, sim_mask = self.visil_head(sim, sim_mask) + + if sim_mask is not None: + sim = sim.masked_fill((1 - sim_mask).bool(), 0.0) + sim = torch.sum(sim, [1, 2]) / torch.sum(sim_mask, [1, 2]) + else: + sim = torch.mean(sim, [1, 2]) + + return sim.unsqueeze(-1) + + def frame_to_frame_similarity(self, x): + x, a = self.attention(x) + sim = torch.einsum("biok,bjpk->biopj", x, x) + return torch.mean(sim, [2, 3]) + + def forward(self, x): + return self.mlp(x) diff --git a/model/similarities.py b/model/similarities.py new file mode 100644 index 0000000..b138fe0 --- /dev/null +++ b/model/similarities.py @@ -0,0 +1,121 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +class TensorDot(nn.Module): + + def __init__(self, pattern='iak,jbk->iabj', metric='cosine'): + super(TensorDot, self).__init__() + self.pattern = pattern + self.metric = metric + + def forward(self, query, target): + if self.metric == 'cosine': + sim = torch.einsum(self.pattern, [query, target]) + elif self.metric == 'euclidean': + sim = 1 - 2 * torch.einsum(self.pattern, [query, target]) + elif self.metric == 'hamming': + sim = torch.einsum(self.pattern, query, target) / target.shape[-1] + return sim + + def __repr__(self,): + return '{}(pattern={})'.format(self.__class__.__name__, self.pattern) + + +class 
ChamferSimilarity(nn.Module): + + def __init__(self, symmetric=False, axes=[1, 0]): + super(ChamferSimilarity, self).__init__() + self.axes = axes + if symmetric: + self.sim_fun = lambda x, m: self.symmetric_chamfer_similarity(x, mask=m, axes=axes) + else: + self.sim_fun = lambda x, m: self.chamfer_similarity(x, mask=m, max_axis=axes[0], mean_axis=axes[1]) + + def chamfer_similarity(self, s, mask=None, max_axis=1, mean_axis=0): + if mask is not None: + s = s.masked_fill((1 - mask).bool(), -np.inf) + s = torch.max(s, max_axis, keepdim=True)[0] + mask = torch.max(mask, max_axis, keepdim=True)[0] + s = s.masked_fill((1 - mask).bool(), 0.0) + s = torch.sum(s, mean_axis, keepdim=True) + s /= torch.sum(mask, mean_axis, keepdim=True) + else: + s = torch.max(s, max_axis, keepdim=True)[0] + s = torch.mean(s, mean_axis, keepdim=True) + return s.squeeze(max(max_axis, mean_axis)).squeeze(min(max_axis, mean_axis)) + + def symmetric_chamfer_similarity(self, s, mask=None, axes=[0, 1]): + return (self.chamfer_similarity(s, mask=mask, max_axis=axes[0], mean_axis=axes[1]) + + self.chamfer_similarity(s, mask=mask, max_axis=axes[1], mean_axis=axes[0])) / 2 + + def forward(self, s, mask=None): + return self.sim_fun(s, mask) + + def __repr__(self,): + return '{}(max_axis={}, mean_axis={})'.format(self.__class__.__name__, self.axes[0], self.axes[1]) + + +class VideoComperator(nn.Module): + + def __init__(self, in_channels=1, out_channels=1): + super(VideoComperator, self).__init__() + + self.rpad1 = nn.ZeroPad2d(1) + self.conv1 = nn.Conv2d(in_channels, 32, 3) + self.pool1 = nn.MaxPool2d((2, 2), 2) + + self.rpad2 = nn.ZeroPad2d(1) + self.conv2 = nn.Conv2d(32, 64, 3) + self.pool2 = nn.MaxPool2d((2, 2), 2) + + self.rpad3 = nn.ZeroPad2d(1) + self.conv3 = nn.Conv2d(64, 128, 3) + + self.fconv = nn.Conv2d(128, out_channels, 1) + + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, sim_matrix, mask=None): + if sim_matrix.ndim == 3: + sim_matrix = sim_matrix.unsqueeze(1) + elif sim_matrix.ndim != 4: + raise Exception('Input tensor to VideoComperator have to be 3- or 4-dimensional') + + if mask is not None: + assert mask.shape[-2:] == sim_matrix.shape[-2:], 'Mask tensor must be of the same shape as similarity ' \ + 'matrix in the last two dimensions. Mask shape is {} ' \ + 'while similarity matrix is {}'.format(mask.shape[-2:], + sim_matrix.shape[-2:]) + mask = mask.unsqueeze(1) + + sim = self.rpad1(sim_matrix) + sim = self.conv1(sim) + if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0) + sim = F.relu(sim) + sim = self.pool1(sim) + if mask is not None: mask = self.pool1(mask) + + sim = self.rpad2(sim) + sim = self.conv2(sim) + if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0) + sim = F.relu(sim) + sim = self.pool2(sim) + if mask is not None: mask = self.pool2(mask) + + sim = self.rpad3(sim) + sim = self.conv3(sim) + if mask is not None: sim = sim.masked_fill((1 - mask).bool(), 0.0) + sim = F.relu(sim) + + sim = self.fconv(sim) + return sim.squeeze(1), mask.squeeze(1) if mask is not None else None diff --git a/model/students.py b/model/students.py new file mode 100644 index 0000000..98f5688 --- /dev/null +++ b/model/students.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . 
import * +from einops import rearrange + + +model_urls = { + 'dns_cg_student': 'https://mever.iti.gr/distill-and-select/models/dns_cg_student.pth', + 'dns_fg_att_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_att_student.pth', + 'dns_fg_bin_student': 'https://mever.iti.gr/distill-and-select/models/dns_fg_bin_student.pth', +} + + +class CoarseGrainedStudent(nn.Module): + + def __init__(self, + dims=512, + attention=True, + transformer=True, + transformer_heads=8, + transformer_feedforward_dims=2048, + transformer_layers=1, + netvlad=True, + netvlad_clusters=64, + netvlad_outdims=1024, + pretrained=False, + **kwargs + ): + super(CoarseGrainedStudent, self).__init__() + self.student_type = 'cg' + + if attention: + self.attention = Attention(dims, norm=False) + if transformer: + encoder_layer = nn.TransformerEncoderLayer(dims, + transformer_heads, + transformer_feedforward_dims) + self.transformer = nn.TransformerEncoder(encoder_layer, + transformer_layers, + nn.LayerNorm(dims)) + self.apply(self._init_weights) + + if netvlad: + self.netvlad = NetVLAD(dims, netvlad_clusters, outdims=netvlad_outdims) + + if pretrained: + self.load_state_dict( + torch.hub.load_state_dict_from_url( + model_urls['dns_cg_student'])['model']) + + def get_network_name(self,): + return '{}_student'.format(self.student_type) + + def calculate_video_similarity(self, query, target): + return torch.mm(query, torch.transpose(target, 0, 1)) + + def index_video(self, x, mask=None): + x, mask = check_dims(x, mask) + + if hasattr(self, 'attention'): + x, a = self.attention(x) + x = torch.sum(x, 2) + x = F.normalize(x, p=2, dim=-1) + + if hasattr(self, 'transformer'): + x = x.permute(1, 0, 2) + x = self.transformer(x, src_key_padding_mask= + (1 - mask).bool() if mask is not None else None) + x = x.permute(1, 0, 2) + + if hasattr(self, 'netvlad'): + x = x.unsqueeze(2).permute(0, 3, 1, 2) + x = self.netvlad(x, mask=mask) + else: + if mask is not None: + x = x.masked_fill((1 - mask.unsqueeze(-1)).bool(), 0.0) + x = torch.sum(x, 1) / torch.sum(mask, 1, keepdim=True) + else: + x = torch.mean(x, 1) + return F.normalize(x, p=2, dim=-1) + + def forward(self, anchors, positives, negatives, + anchors_masks=None, positive_masks=None, negative_masks=None): + pos_pairs = torch.sum(anchors * positives, 1, keepdim=True) + neg_pairs = torch.sum(anchors * negatives, 1, keepdim=True) + return pos_pairs, neg_pairs, None + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +class FineGrainedStudent(nn.Module): + + def __init__(self, + dims=512, + attention=False, + binarization=False, + pretrained=False, + **kwargs + ): + super(FineGrainedStudent, self).__init__() + self.student_type = 'fg' + if attention and binarization: + raise Exception('Can\'t use \'attention=True\' and \'binarization=True\' at the same time. 
' + 'Select one of the two options.') + elif binarization: + self.fg_type = 'bin' + self.binarization = BinarizationLayer(dims) + elif attention: + self.fg_type = 'att' + self.attention = Attention(dims, norm=False) + else: + self.fg_type = 'none' + + self.f2f_sim = ChamferSimilarity(axes=[3, 2]) + + self.visil_head = VideoComperator() + self.htanh = nn.Hardtanh() + + self.v2v_sim = ChamferSimilarity(axes=[2, 1]) + + self.sim_criterion = SimilarityRegularizationLoss() + + if pretrained: + if not (attention or binarization): + raise Exception('No pretrained model provided for the selected settings. ' + 'Use either \'attention=True\' or \'binarization=True\' to load a pretrained model.') + self.load_state_dict( + torch.hub.load_state_dict_from_url( + model_urls['dns_fg_{}_student'.format(self.fg_type)])['model']) + + def get_network_name(self,): + return '{}_{}_student'.format(self.student_type, self.fg_type) + + def frame_to_frame_similarity(self, query, target, query_mask=None, target_mask=None, batched=False): + d = target.shape[-1] + sim_mask = None + if batched: + sim = torch.einsum('biok,bjpk->biopj', query, target) + sim = self.f2f_sim(sim) + if query_mask is not None and target_mask is not None: + sim_mask = torch.einsum('bik,bjk->bij', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1)) + else: + sim = torch.einsum('aiok,bjpk->aiopjb', query, target) + sim = self.f2f_sim(sim) + sim = rearrange(sim, 'a i j b -> (a b) i j') + if query_mask is not None and target_mask is not None: + sim_mask = torch.einsum('aik,bjk->aijb', query_mask.unsqueeze(-1), target_mask.unsqueeze(-1)) + sim_mask = rearrange(sim_mask, 'a i j b -> (a b) i j') + if self.fg_type == 'bin': + sim /= d + if sim_mask is not None: + sim = sim.masked_fill((1 - sim_mask).bool(), 0.0) + return sim, sim_mask + + def calculate_video_similarity(self, query, target, query_mask=None, target_mask=None): + query, query_mask = check_dims(query, query_mask) + target, target_mask = check_dims(target, target_mask) + + sim, sim_mask = self.similarity_matrix(query, target, query_mask, target_mask) + sim = self.v2v_sim(sim, sim_mask) + + return sim.view(query.shape[0], target.shape[0]) + + def similarity_matrix(self, query, target, query_mask=None, target_mask=None): + query, query_mask = check_dims(query, query_mask) + target, target_mask = check_dims(target, target_mask) + + sim, sim_mask = self.frame_to_frame_similarity(query, target, query_mask, target_mask) + sim, sim_mask = self.visil_head(sim, sim_mask) + return self.htanh(sim), sim_mask + + def index_video(self, x, mask=None): + if self.fg_type == 'bin': + x = self.binarization(x) + elif self.fg_type == 'att': + x, a = self.attention(x) + if mask is not None: + x = x.masked_fill((1 - mask).bool().unsqueeze(-1).unsqueeze(-1), 0.0) + return x + + def forward(self, anchors, positives, negatives, + anchors_masks, positive_masks, negative_masks): + pos_sim, pos_mask = self.frame_to_frame_similarity( + anchors, positives, anchors_masks, positive_masks, batched=True) + neg_sim, neg_mask = self.frame_to_frame_similarity( + anchors, negatives, anchors_masks, negative_masks, batched=True) + sim, sim_mask = torch.cat([pos_sim, neg_sim], 0), torch.cat([pos_mask, neg_mask], 0) + + sim, sim_mask = self.visil_head(sim, sim_mask) + loss = self.sim_criterion(sim) + sim = self.htanh(sim) + sim = self.v2v_sim(sim, sim_mask) + + pos_pair, neg_pair = torch.chunk(sim.unsqueeze(-1), 2, dim=0) + return pos_pair, neg_pair, loss diff --git a/output_imgs/cg_student.png b/output_imgs/cg_student.png new 
file mode 100644 index 0000000..b2bbacf Binary files /dev/null and b/output_imgs/cg_student.png differ diff --git a/output_imgs/cg_student_specifical.png b/output_imgs/cg_student_specifical.png new file mode 100644 index 0000000..e1b355c Binary files /dev/null and b/output_imgs/cg_student_specifical.png differ diff --git a/output_imgs/feature_extractor.png b/output_imgs/feature_extractor.png new file mode 100644 index 0000000..2bf1241 Binary files /dev/null and b/output_imgs/feature_extractor.png differ diff --git a/output_imgs/fg_att_student.png b/output_imgs/fg_att_student.png new file mode 100644 index 0000000..451ff4c Binary files /dev/null and b/output_imgs/fg_att_student.png differ diff --git a/output_imgs/fg_bin_student.png b/output_imgs/fg_bin_student.png new file mode 100644 index 0000000..22536c1 Binary files /dev/null and b/output_imgs/fg_bin_student.png differ diff --git a/output_imgs/selector_att.png b/output_imgs/selector_att.png new file mode 100644 index 0000000..7e17830 Binary files /dev/null and b/output_imgs/selector_att.png differ diff --git a/output_imgs/selector_bin.png b/output_imgs/selector_bin.png new file mode 100644 index 0000000..9c9de9c Binary files /dev/null and b/output_imgs/selector_bin.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2fc395a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +numpy +einops +towhee +Pillow \ No newline at end of file