# Python dependencies declared for this operator (requirements.txt):
#   numpy
#   timm>=0.5.4
import numpy
import torch
from typing import NamedTuple

from towhee.operator.base import NNOperator
from towhee.utils.pil_utils import to_pil
from towhee.types.image import Image as towheeImage

from torch import nn

from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config
# Imported from `timm.models` rather than `timm.models.factory`: the latter is
# a deprecated alias in newer timm releases and warns on import.
from timm.models import create_model

import warnings
warnings.filterwarnings('ignore')


class TimmImage(NNOperator):
    """
    Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection.

    The model is created with pretrained weights, moved to CUDA when it is
    available (CPU otherwise) and switched to eval mode. The preprocessing
    transform is built from the data config resolved for that model.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
        num_classes (`int`):
            Forwarded unchanged to timm's `create_model`. Embeddings are taken
            from `forward_features`, not from the model's full forward pass.
    """
    def __init__(self, model_name: str, num_classes: int = 1000) -> None:
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
        self.model.to(self.device)
        self.model.eval()
        config = resolve_data_config({}, model=self.model)
        self.tfms = create_transform(**config)

    def __call__(self, image: 'towheeImage') -> NamedTuple('Outputs', [('vec', numpy.ndarray)]):
        """
        Compute the embedding of a single image.

        Args:
            image (`towheeImage`):
                Input image; converted to a PIL image before the model's
                preprocessing transform is applied.

        Returns:
            A named tuple `Outputs` with one field `vec`: the model features
            flattened into a one-dimensional `numpy.ndarray`.
        """
        # Add a leading batch dimension of size 1 and move to the model's device.
        img = self.tfms(to_pil(image)).unsqueeze(0)
        img = img.to(self.device)

        # Inference only: disabling autograd avoids building and retaining a
        # graph for every call.
        with torch.no_grad():
            features = self.model.forward_features(img)
            if features.dim() == 4:
                # 4-D features are average-pooled down to a 1x1 map
                # (assumes a (batch, channels, height, width) layout -- TODO confirm).
                global_pool = nn.AdaptiveAvgPool2d(1)
                features = global_pool(features)
            # NOTE(review): features of any other rank are flattened as-is; if a
            # model returns an unpooled token sequence the vector will contain
            # every token -- confirm this is intended for such models.

        features = features.to('cpu')
        feature_vector = features.flatten().detach().numpy()
        Outputs = NamedTuple('Outputs', [('vec', numpy.ndarray)])
        return Outputs(feature_vector)