logo
Browse Source

Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 4 years ago
parent
commit
71d25f60e1
  1. 6
      efficientnet_image_embedding.py

6
efficientnet_image_embedding.py

@ -18,11 +18,12 @@ import torch
from torchvision import transforms from torchvision import transforms
import sys import sys
from pathlib import Path from pathlib import Path
import numpy
from towhee.operator import Operator from towhee.operator import Operator
class EfficientnetEmbeddingOperator(Operator):
class EfficientnetImageEmbeddingOperator(Operator):
""" """
Embedding extractor using efficientnet. Embedding extractor using efficientnet.
Args: Args:
@ -41,8 +42,9 @@ class EfficientnetEmbeddingOperator(Operator):
self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ])
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
img = self.tfms(Image.open(img_path)).unsqueeze(0) img = self.tfms(Image.open(img_path)).unsqueeze(0)
features = self.model._model.extract_features(img) features = self.model._model.extract_features(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(features.flatten().detach().numpy()) return Outputs(features.flatten().detach().numpy())

Loading…
Cancel
Save