import numpy
import torch
from typing import NamedTuple

from towhee.operator.base import NNOperator
from towhee.utils.pil_utils import to_pil
from towhee.types.image import Image as towheeImage
from torch import nn
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config
from timm.models.factory import create_model

import warnings
warnings.filterwarnings('ignore')


class TimmImage(NNOperator):
    """
    Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
        num_classes (`int`):
            Forwarded to timm's `create_model` as `num_classes`. Defaults to 1000.
    """

    def __init__(self, model_name: str, num_classes: int = 1000) -> None:
        """
        Load the pretrained model and build its matching preprocessing pipeline.

        The model is moved to CUDA when `torch.cuda.is_available()` is true,
        otherwise it stays on the CPU, and is switched to eval mode.
        """
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
        self.model.to(self.device)
        self.model.eval()
        # Derive the preprocessing transform from the data config resolved for this model.
        config = resolve_data_config({}, model=self.model)
        self.tfms = create_transform(**config)

    def __call__(self, image: 'towheeImage') -> NamedTuple('Outputs', [('vec', numpy.ndarray)]):
        """
        Compute an embedding vector for a single image.

        Args:
            image (`towheeImage`):
                The image to embed; it is converted with `to_pil` before the
                model's transform is applied.

        Returns:
            An `Outputs` named tuple whose `vec` field is the flattened output
            of `model.forward_features` as a numpy array.
        """
        # Add a leading batch dimension of size 1 before moving to the model's device.
        img = self.tfms(to_pil(image)).unsqueeze(0)
        img = img.to(self.device)

        # Inference only: skip autograd bookkeeping so no graph or intermediate
        # activations are retained for a result that is detached anyway.
        with torch.no_grad():
            features = self.model.forward_features(img)
            if features.dim() == 4:
                # 4-D feature maps are reduced with global average pooling;
                # outputs of any other rank are flattened as returned.
                global_pool = nn.AdaptiveAvgPool2d(1)
                features = global_pool(features)

        features = features.to('cpu')
        feature_vector = features.flatten().detach().numpy()
        Outputs = NamedTuple('Outputs', [('vec', numpy.ndarray)])
        return Outputs(feature_vector)