logo
Browse Source

Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 4 years ago
parent
commit
71d25f60e1
  1. 6
      efficientnet_image_embedding.py

6
efficientnet_image_embedding.py

@ -18,11 +18,12 @@ import torch
from torchvision import transforms
import sys
from pathlib import Path
import numpy
from towhee.operator import Operator
class EfficientnetEmbeddingOperator(Operator):
class EfficientnetImageEmbeddingOperator(Operator):
"""
Embedding extractor using efficientnet.
Args:
@ -41,8 +42,9 @@ class EfficientnetEmbeddingOperator(Operator):
self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ])
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
img = self.tfms(Image.open(img_path)).unsqueeze(0)
features = self.model._model.extract_features(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(features.flatten().detach().numpy())

Loading…
Cancel
Save