6 changed files with 283 additions and 1 deletions
@@ -1,2 +1,84 @@

# dolg

# Image Embedding with DOLG

*author: David Wang*

## Description
This operator extracts features from images with [DOLG](https://arxiv.org/abs/2108.02927), which is specially designed for image retrieval tasks. It integrates local and global information within images into compact image representations. This operator is adapted from [dongkyuk/DOLG-pytorch](https://github.com/dongkyuk/DOLG-pytorch).

## Code Example

Load an image from path './towhee.jpg' to generate an image embedding.

*Write the pipeline in simplified style*:

```python
import towhee

towhee.glob('./towhee.jpg') \
    .image_decode.cv2() \
    .image_embedding.dolg(img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048) \
    .show()
```
<img src="https://towhee.io/image-embedding/dolg/raw/branch/main/result1.png" alt="result1" style="height:20px;"/>

*Write the same pipeline with explicit input and output names:*

```python
import towhee

towhee.glob['path']('./towhee.jpg') \
    .image_decode.cv2['path', 'img']() \
    .image_embedding.dolg['img', 'vec'](img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048) \
    .select('img', 'vec') \
    .show()
```
<img src="https://towhee.io/image-embedding/dolg/raw/branch/main/result2.png" alt="result2" style="height:60px;"/>
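
The embedding can also be collected as data instead of being rendered as a table, for example by ending the pipeline with `to_list()`. This is a minimal sketch following the pipelines above; `to_list()` and attribute access on the returned entities are assumed to behave as in other Towhee DataCollection examples.

```python
import towhee

# Same pipeline as above, but collect the result instead of displaying it.
dc = towhee.glob['path']('./towhee.jpg') \
    .image_decode.cv2['path', 'img']() \
    .image_embedding.dolg['img', 'vec'](img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048) \
    .select('vec') \
    .to_list()

vec = dc[0].vec  # numpy.ndarray with shape (output_dim,), here (2048,)
print(vec.shape)
```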

## Factory Constructor

Create the operator via the following factory method:

***image_embedding.dolg(img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048)***

**Parameters:**

***img_size***: *int*

The size the input image is scaled to before the embedding is extracted. A higher resolution yields more discriminative features but takes longer to compute.

***input_dim***: *int*

The number of input channels fed to the pretrained CNN backbone of the DOLG module (3 for RGB images).

***hidden_dim***: *int*

The hidden dimension size, which is also the output dimension of the local feature branch.

***output_dim***: *int*

The output dimension size, which equals the embedding size.
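
For reference, these parameters are passed straight through to the underlying `DolgNet` model defined in `dolg_impl.py` (shown further below). The following sketch illustrates what they mean in terms of tensor shapes; it instantiates the implementation class directly, which is not part of the operator's public interface, and assumes the snippet is run from the operator's directory with `torch` and `timm` installed.

```python
import torch

from dolg_impl import DolgNet  # implementation module shipped with this operator

# img_size=512, input_dim=3 (RGB channels), hidden_dim=1024, output_dim=2048.
# Building the model downloads pretrained ResNet-101 weights on first use.
model = DolgNet(img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048)
model.eval()

# A dummy RGB batch at the configured resolution.
dummy = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    emb = model(dummy)

print(emb.shape)  # torch.Size([2048]) -- the embedding size equals output_dim
```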

## Interface

An image embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input.
It uses a pre-trained DOLG model to generate an image embedding in ndarray.

**Parameters:**

***img***: *towhee.types.Image (a sub-class of numpy.ndarray)*

The decoded image data in towhee.types.Image (numpy.ndarray).

**Returns**: *numpy.ndarray*

The image embedding extracted by the model.

@@ -0,0 +1,19 @@

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .dolg import Dolg


def dolg(img_size=512, input_dim=3, hidden_dim=1024, output_dim=2048):
    return Dolg(img_size, input_dim, hidden_dim, output_dim)

@@ -0,0 +1,52 @@

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import numpy
import towhee
import sys
from pathlib import Path
from torchvision import transforms
from towhee.types.image_utils import to_pil

from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register


@register(output_schema=['vec'])
class Dolg(NNOperator):
    """
    DOLG Embedding Operator
    """
    def __init__(self, img_size, input_dim, hidden_dim, output_dim):
        super().__init__()
        sys.path.append(str(Path(__file__).parent))
        from dolg_impl import DolgNet
        self.model = DolgNet(img_size, input_dim, hidden_dim, output_dim)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Resize to img_size x img_size, convert to tensor and normalize
        # with ImageNet statistics.
        self.tfms = transforms.Compose([transforms.Resize([img_size, img_size]),
                                        transforms.ToTensor(),
                                        normalize])

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: numpy.ndarray):
        img = self.tfms(to_pil(img)).unsqueeze(0)
        self.model.eval()
        features = self.model(img)
        feature_vector = features.flatten().detach().numpy()
        return feature_vector

@@ -0,0 +1,129 @@

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Generalized-mean pooling: clamp, raise to power p, average pool, take the p-th root.
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'


class MultiAtrous(nn.Module):
    def __init__(self, in_channel, out_channel, size, dilation_rates=[3, 6, 9]):
        super().__init__()
        # Parallel dilated convolutions capture local context at multiple scales.
        self.dilated_convs = [
            nn.Conv2d(in_channel, int(out_channel/4),
                      kernel_size=3, dilation=rate, padding=rate)
            for rate in dilation_rates
        ]
        # Global-average-pooling branch, upsampled back to the feature map size.
        self.gap_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channel, int(out_channel/4), kernel_size=1),
            nn.ReLU(),
            nn.Upsample(size=(size, size), mode='bilinear')
        )
        self.dilated_convs.append(self.gap_branch)
        self.dilated_convs = nn.ModuleList(self.dilated_convs)

    def forward(self, x):
        local_feat = []
        for dilated_conv in self.dilated_convs:
            local_feat.append(dilated_conv(x))
        local_feat = torch.cat(local_feat, dim=1)
        return local_feat


class DolgLocalBranch(nn.Module):
    def __init__(self, img_size, in_channel, out_channel, hidden_channel=2048):
        super().__init__()
        self.multi_atrous = MultiAtrous(in_channel, hidden_channel, size=int(img_size/8))
        self.conv1x1_1 = nn.Conv2d(hidden_channel, out_channel, kernel_size=1)
        self.conv1x1_2 = nn.Conv2d(
            out_channel, out_channel, kernel_size=1, bias=False)
        self.conv1x1_3 = nn.Conv2d(out_channel, out_channel, kernel_size=1)

        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channel)
        self.softplus = nn.Softplus()

    def forward(self, x):
        local_feat = self.multi_atrous(x)

        local_feat = self.conv1x1_1(local_feat)
        local_feat = self.relu(local_feat)
        local_feat = self.conv1x1_2(local_feat)
        local_feat = self.bn(local_feat)

        # Spatial attention map computed from the local features.
        attention_map = self.relu(local_feat)
        attention_map = self.conv1x1_3(attention_map)
        attention_map = self.softplus(attention_map)

        # L2-normalize the local features and re-weight them by the attention map.
        local_feat = F.normalize(local_feat, p=2, dim=1)
        local_feat = local_feat * attention_map

        return local_feat


class OrthogonalFusion(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, local_feat, global_feat):
        global_feat_norm = torch.norm(global_feat, p=2, dim=1)
        # Project the local feature map onto the global descriptor.
        projection = torch.bmm(global_feat.unsqueeze(1), torch.flatten(
            local_feat, start_dim=2))
        projection = torch.bmm(global_feat.unsqueeze(
            2), projection).view(local_feat.size())
        projection = projection / \
            (global_feat_norm * global_feat_norm).view(-1, 1, 1, 1)
        # Keep only the component of the local features orthogonal to the global
        # descriptor, then concatenate it with the (broadcast) global descriptor.
        orthogonal_comp = local_feat - projection
        global_feat = global_feat.unsqueeze(-1).unsqueeze(-1)
        return torch.cat([global_feat.expand(orthogonal_comp.size()), orthogonal_comp], dim=1)


class DolgNet(nn.Module):
    def __init__(self, img_size, input_dim, hidden_dim, output_dim):
        super().__init__()
        # ResNet-101 backbone; out_indices=(2, 3) returns the stride-8 and
        # stride-16 feature maps used by the local and global branches.
        self.cnn = timm.create_model(
            'tv_resnet101',
            pretrained=True,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 3)
        )
        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(img_size, 512, hidden_dim)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gem_pool = GeM()
        self.fc_1 = nn.Linear(1024, hidden_dim)
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)
        #
        # self.criterion = ArcFace(
        #     in_features=output_dim,
        #     out_features=num_of_classes,
        #     scale_factor=30,
        #     margin=0.15,
        #     criterion=nn.CrossEntropyLoss()
        # )
        #

    def forward(self, x):
        output = self.cnn(x)

        local_feat = self.local_branch(output[0])  # ,hidden_channel,16,16
        global_feat = self.fc_1(self.gem_pool(output[1]).squeeze(3).squeeze(2))  # ,1024

        feat = self.orthogonal_fusion(local_feat, global_feat)
        feat = self.gap(feat).squeeze()
        feat = self.fc_2(feat)

        return feat