towhee
/
distill-and-select
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
46 lines
1.6 KiB
46 lines
1.6 KiB
3 years ago
|
import torch
|
||
|
import numpy as np
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
import torchvision.models as models
|
||
|
|
||
|
from .layers import *
|
||
|
|
||
|
|
||
|
class FeatureExtractor(nn.Module):
    """Extract L2-normalised regional descriptors from several CNN stages.

    The input is normalised, pushed through the backbone stage by stage,
    and the output of every stage listed in ``self.layers`` is max-pooled
    into region vectors. The per-stage region maps are then resized to a
    common grid, concatenated along the channel axis and flattened into one
    descriptor per region. An optional ``PCA`` module is applied last.
    """

    def __init__(self, network='resnet50', whiteninig=False, dims=3840):
        """Build the normaliser, the backbone and the optional PCA layer.

        Args:
            network: Accepted for interface compatibility but never read;
                the backbone is always ``models.resnet50(pretrained=True)``.
            whiteninig: When truthy, a ``PCA(dims)`` module is attached.
                The spelling is part of the public keyword interface and
                is kept as-is.
            dims: Passed to ``PCA``. Any value other than 3840 also causes
                the PCA module to be attached.
                NOTE(review): 3840 is presumably the concatenated channel
                count of the four pooled stages -- confirm.
        """
        super().__init__()
        self.normalizer = VideoNormalizer()

        self.cnn = models.resnet50(pretrained=True)

        # Not used by any method of this class; kept so the set of
        # registered sub-modules stays the same as before.
        self.rpool = RMAC()
        # Backbone stage name -> side length of the square max-pooling
        # window applied to that stage's output.
        # NOTE(review): the sizes look tuned to a fixed input resolution --
        # confirm against the caller.
        self.layers = {'layer1': 28, 'layer2': 14, 'layer3': 6, 'layer4': 3}
        if whiteninig or dims != 3840:
            self.pca = PCA(dims)

    def extract_region_vectors(self, x):
        """Run the backbone and return one descriptor per pooled region.

        Args:
            x: Input tensor fed to the first backbone module.
                Assumes a 4-D image batch, since 2-D pooling is applied to
                the stage outputs -- TODO confirm exact layout with caller.

        Returns:
            A 3-D tensor laid out as (batch, regions, channels), where
            ``channels`` is the sum of the channel counts of all pooled
            stages and each descriptor is L2-normalised along the last
            axis.
        """
        tensors = []
        for nm, module in self.cnn._modules.items():
            # The classification head is skipped; only the feature stages
            # are executed.
            if nm in {'avgpool', 'fc', 'classifier'}:
                continue
            x = module(x).contiguous()
            if nm not in self.layers:
                continue
            s = self.layers[nm]
            # Window of s x s with a stride of ceil(s / 2).
            region_vectors = F.max_pool2d(x, [s, s], int(np.ceil(s / 2)))
            region_vectors = F.normalize(region_vectors, p=2, dim=1)
            tensors.append(region_vectors)

        # Bring every stage onto the spatial grid of the last pooled stage
        # so the maps can be concatenated channel-wise. The right-hand side
        # still reads the un-resized list, exactly like the in-place loop
        # it replaces (resizing the last map to its own size keeps its
        # shape).
        tensors = [
            F.normalize(
                F.adaptive_max_pool2d(t, tensors[-1].shape[2:]), p=2, dim=1
            )
            for t in tensors
        ]
        x = torch.cat(tensors, 1)
        # (batch, channels, H, W) -> (batch, H * W, channels)
        x = x.view(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
        x = F.normalize(x, p=2, dim=-1)
        return x

    def forward(self, x):
        """Normalise ``x``, extract region vectors, then apply PCA if present.

        Args:
            x: Input passed to ``self.normalizer``.

        Returns:
            The output of ``extract_region_vectors``, transformed by
            ``self.pca`` when that attribute was created in ``__init__``.
        """
        x = self.normalizer(x)
        x = self.extract_region_vectors(x)
        # ``pca`` only exists when whitening or a non-default ``dims`` was
        # requested at construction time.
        if hasattr(self, 'pca'):
            x = self.pca(x)
        return x
|