diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py index d45bd86..ecd47dd 100644 --- a/efficientnet_image_embedding.py +++ b/efficientnet_image_embedding.py @@ -18,11 +18,12 @@ import torch from torchvision import transforms import sys from pathlib import Path +import numpy from towhee.operator import Operator -class EfficientnetEmbeddingOperator(Operator): +class EfficientnetImageEmbeddingOperator(Operator): """ Embedding extractor using efficientnet. Args: @@ -41,8 +42,9 @@ class EfficientnetEmbeddingOperator(Operator): self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) - def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): + def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) img = self.tfms(Image.open(img_path)).unsqueeze(0) features = self.model._model.extract_features(img) + Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(features.flatten().detach().numpy())