From 71d25f60e1467b026b74ef35af5a3d063cdcb9da Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 15 Dec 2021 20:59:13 +0800 Subject: [PATCH] Update Signed-off-by: shiyu22 --- efficientnet_image_embedding.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py index d45bd86..ecd47dd 100644 --- a/efficientnet_image_embedding.py +++ b/efficientnet_image_embedding.py @@ -18,11 +18,12 @@ import torch from torchvision import transforms import sys from pathlib import Path +import numpy from towhee.operator import Operator -class EfficientnetEmbeddingOperator(Operator): +class EfficientnetImageEmbeddingOperator(Operator): """ Embedding extractor using efficientnet. Args: @@ -41,8 +42,9 @@ class EfficientnetEmbeddingOperator(Operator): self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) - def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): + def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) img = self.tfms(Image.open(img_path)).unsqueeze(0) features = self.model._model.extract_features(img) + Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(features.flatten().detach().numpy())