timm
copied
Jael Gu
3 years ago
2 changed files with 47 additions and 0 deletions
@ -0,0 +1,2 @@ |
|||||
|
numpy |
||||
|
timm>=0.5.4 |
@ -0,0 +1,45 @@ |
|||||
|
import numpy |
||||
|
import torch |
||||
|
from typing import NamedTuple |
||||
|
|
||||
|
from towhee.operator.base import NNOperator |
||||
|
from towhee.utils.pil_utils import to_pil |
||||
|
from towhee.types.image import Image as towheeImage |
||||
|
|
||||
|
from torch import nn |
||||
|
|
||||
|
from timm.data.transforms_factory import create_transform |
||||
|
from timm.data import resolve_data_config |
||||
|
from timm.models.factory import create_model |
||||
|
|
||||
|
import warnings |
||||
|
warnings.filterwarnings('ignore') |
||||
|
|
||||
|
class TimmImage(NNOperator): |
||||
|
""" |
||||
|
Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection. |
||||
|
Args: |
||||
|
model_name (`str`): |
||||
|
Which model to use for the embeddings. |
||||
|
""" |
||||
|
def __init__(self, model_name: str, num_classes: int = 1000) -> None: |
||||
|
super().__init__() |
||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
||||
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
||||
|
self.model.to(self.device) |
||||
|
self.model.eval() |
||||
|
config = resolve_data_config({}, model=self.model) |
||||
|
self.tfms = create_transform(**config) |
||||
|
|
||||
|
def __call__(self, image: 'towheeImage') -> NamedTuple('Outputs', [('vec', numpy.ndarray)]): |
||||
|
img = self.tfms(to_pil(image)).unsqueeze(0) |
||||
|
img = img.to(self.device) |
||||
|
features = self.model.forward_features(img) |
||||
|
if features.dim() == 4: |
||||
|
global_pool = nn.AdaptiveAvgPool2d(1) |
||||
|
features = global_pool(features) |
||||
|
|
||||
|
features = features.to('cpu') |
||||
|
feature_vector = features.flatten().detach().numpy() |
||||
|
Outputs = NamedTuple('Outputs', [('vec', numpy.ndarray)]) |
||||
|
return Outputs(feature_vector) |
Loading…
Reference in new issue