swag
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
182 lines
5.9 KiB
182 lines
5.9 KiB
3 years ago
|
# Copyright 2021 Zilliz. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import logging
|
||
|
import numpy
|
||
|
import os
|
||
|
from pathlib import Path
|
||
|
|
||
|
import towhee
|
||
|
from towhee.operator.base import NNOperator, OperatorFlag
|
||
|
from towhee.types.arg import arg, to_image_color
|
||
|
from towhee import register
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torchvision import transforms
|
||
|
|
||
|
from PIL import Image as PILImage
|
||
|
|
||
|
import warnings
|
||
|
|
||
|
# NOTE: this installs a process-wide 'ignore' filter for all warnings (not only
# this module's) as soon as the module is imported.
warnings.filterwarnings('ignore')

# Module-level logger named after this module, so its records can be filtered
# and configured separately instead of being emitted on the root logger.
log = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
# Input resolution (pixels) used to preprocess images for each supported model.
# The same value drives the resize/center-crop transforms and the example input
# used when tracing the model for TorchScript export.
_MODEL_RESOLUTION = {
    'vit_h14_in1k': 518,
    'vit_l16_in1k': 512,
    'vit_b16_in1k': 384,
    'regnety_16gf_in1k': 384,
    'regnety_32gf_in1k': 384,
    'regnety_128gf_in1k': 384,
}
# Fallback resolution for model names missing from ``_MODEL_RESOLUTION``.
_DEFAULT_RESOLUTION = 384


def _resolution_for(model_name: str) -> int:
    """Return the preprocessing resolution for ``model_name``, warning on unknown names."""
    if model_name not in _MODEL_RESOLUTION:
        log.warning('No transforms specified for model "%s", using resolution %d.',
                    model_name, _DEFAULT_RESOLUTION)
        return _DEFAULT_RESOLUTION
    return _MODEL_RESOLUTION[model_name]


@register(output_schema=['vec'])
class Swag(NNOperator):
    """
    Pytorch image embedding operator that uses SWAG models loaded from the
    ``facebookresearch/swag`` repository via ``torch.hub``.

    The embedding is the output of the model's ``trunk_output`` module; a 4-D
    output is average-pooled over its last two dimensions before flattening.

    Args:
        model_name (`str`):
            Which SWAG model to use for the embeddings, e.g. ``'vit_b16_in1k'``
            (see ``supported_model_names()``).
        skip_preprocess (`bool = False`):
            Whether to skip the resize / center-crop / normalize transforms. When
            skipped, the image is only converted to a float tensor in ``[0, 1]``,
            so the caller must supply an image the model can consume as-is.
    """

    def __init__(self, model_name: str, skip_preprocess: bool = False) -> None:
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.skip_tfms = skip_preprocess
        self.tfms = self.get_transforms(model_name)

        self.model_name = model_name
        self.model = torch.hub.load("facebookresearch/swag", model=model_name)
        self.model.to(self.device)
        self.model.eval()
        self.extract_features = FeatureExtractor(self.model)

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
        """
        Embed one image.

        Args:
            img (`towhee._types.Image`):
                The decoded image, converted to RGB by the ``@arg`` decorator.

        Returns:
            (`numpy.ndarray`)
                The flattened embedding vector.
        """
        pil_img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        if self.skip_tfms:
            # Still need a tensor: a PIL image has no ``.to(device)``.
            tensor = transforms.ToTensor()(pil_img)
        else:
            tensor = self.tfms(pil_img)
        batch = tensor.unsqueeze(0).to(self.device)

        # Inference only: don't record an autograd graph.
        with torch.no_grad():
            features, _ = self.extract_features(batch)
        if features.dim() == 4:
            # Global average pool over the last two (spatial) dims down to 1x1.
            features = nn.functional.adaptive_avg_pool2d(features, 1)
        return features.to('cpu').flatten().numpy()

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """
        Save the underlying model to disk.

        Args:
            format (`str`):
                ``'pytorch'`` saves the whole module with ``torch.save``;
                ``'torchscript'`` scripts the model (falling back to tracing) and
                saves it with a ``.pt`` suffix. ``'onnx'`` is not implemented yet.
            path (`str`):
                Target directory; ``'default'`` means the directory of this file.
                The file name is ``model_name`` with ``/`` replaced by ``-``.

        Raises:
            RuntimeError: if the TorchScript export fails.
        """
        if path == 'default':
            path = str(Path(__file__).parent)
        name = self.model_name.replace('/', '-')
        path = os.path.join(path, name)

        if format == 'pytorch':
            torch.save(self.model, path)
        elif format == 'torchscript':
            path = path + '.pt'
            # The example input matches what ``__call__`` feeds the model (the
            # preprocessing resolution, not a fixed 224) and lives on the same
            # device as the model.
            resolution = _resolution_for(self.model_name)
            inputs = torch.ones(1, 3, resolution, resolution, device=self.device)
            try:
                try:
                    jit_model = torch.jit.script(self.model)
                except Exception:
                    jit_model = torch.jit.trace(self.model, inputs, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        elif format in ('onnx', 'onxx'):
            # 'onxx' is still accepted because earlier code matched that spelling.
            log.error('Save model: format "onnx" is not implemented yet.')
        else:
            log.error(f'Save model: unsupported format "{format}".')

    @staticmethod
    def supported_model_names(format: str = None):
        """
        List the model names supported by this operator.

        Args:
            format (`str`, optional):
                ``None`` for all models, or ``'pytorch'`` for the models that can
                be saved in that format.

        Returns:
            (`list[str]`)
                Sorted model names; an empty list (with an error logged) for an
                unsupported ``format``.
        """
        # Every model with a known preprocessing resolution is supported.
        full_list = sorted(_MODEL_RESOLUTION)
        if format is None:
            model_list = full_list
        elif format == 'pytorch':
            to_remove = set()
            # Filter in order so the result stays sorted.
            model_list = [name for name in full_list if name not in to_remove]
        else:  # todo: format in {'torchscript', 'onnx', 'tensorrt'}
            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
            model_list = []
        return model_list

    @staticmethod
    def get_transforms(model_name):
        """
        Build the preprocessing pipeline for ``model_name``.

        The pipeline resizes the shorter side to the model's resolution (bicubic),
        center-crops a square of that size, converts to a tensor, and normalizes
        with the mean/std below (the commonly used ImageNet statistics).
        """
        resolution = _resolution_for(model_name)
        return transforms.Compose([
            transforms.Resize(
                resolution,
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ])
|
||
|
|
||
|
|
||
|
class FeatureExtractor(nn.Module):
    """
    Run a model and capture the output of its ``trunk_output`` child module.

    A forward hook is attached to ``model.trunk_output`` only for the duration of
    each ``forward`` call. Each call therefore returns the features of that call,
    and the wrapped model carries no hook closure between calls (a local closure
    cannot be pickled by ``torch.save``).

    Args:
        model (`nn.Module`):
            The model to wrap; it must have a direct child named ``trunk_output``.

    Raises:
        ValueError: if ``model`` has no direct child named ``trunk_output``.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Output of ``trunk_output`` captured during the most recent forward pass.
        self.features = None
        # Hook handle while a forward pass is running; None otherwise.
        self.handler = None
        if self._trunk() is None:
            raise ValueError('FeatureExtractor: model has no child module named "trunk_output".')

    def _trunk(self):
        """Return the wrapped model's direct child named ``trunk_output``, or None."""
        return dict(self.model.named_children()).get('trunk_output')

    def save_outputs_hook(self):
        """Return a forward hook that stores the hooked module's output in ``self.features``."""
        def fn(_, __, output):
            self.features = output
        return fn

    def forward(self, x):
        """
        Run the wrapped model on ``x``.

        Returns:
            A ``(features, outs)`` tuple: the ``trunk_output`` output from this
            call and the model's own output.
        """
        # Reset first so a result from an earlier call can never be returned.
        self.features = None
        self.handler = self._trunk().register_forward_hook(self.save_outputs_hook())
        try:
            outs = self.model(x)
        finally:
            # Always detach the hook, even if the forward pass raises.
            self.handler.remove()
            self.handler = None
        return self.features, outs
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Demo: embed one image and print the embedding's shape.
    # Usage: python <this file> [IMAGE_PATH] [MODEL_NAME]
    import sys

    from towhee import ops

    path = sys.argv[1] if len(sys.argv) > 1 else '/Users/mengjiagu/Desktop/models/data/image/animals10/bird.jpg'
    # Any name from Swag.supported_model_names(), e.g. 'regnety_16gf_in1k'.
    model_name = sys.argv[2] if len(sys.argv) > 2 else 'vit_b16_in1k'

    decoder = ops.image_decode.cv2()
    img = decoder(path)
    op = Swag(model_name)
    out = op(img)
    print(out.shape)
|