diff --git a/README.md b/README.md
index eabb5ef..9ad9307 100644
--- a/README.md
+++ b/README.md
@@ -25,14 +25,14 @@ __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch',
 
 - supported types: `str`, default is None, using pretrained weights
 
 ```python
-__call__(self, img_path: str)
+__call__(self, image: 'towhee.types.Image')
 ```
 
 **Args:**
 
-- img_path:
-  - the input image path
-  - supported types: `str`
+- image:
+  - the input image
+  - supported types: `towhee.types.Image`
 
 **Returns:**
diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py
index 11e7661..ba0857f 100644
--- a/efficientnet_image_embedding.py
+++ b/efficientnet_image_embedding.py
@@ -17,15 +17,15 @@ from PIL import Image
 import torch
 from torchvision import transforms
 import sys
+import towhee
 from pathlib import Path
 import numpy
 from towhee.operator import Operator
+from towhee.utils.pil_utils import to_pil
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 import os
-import warnings
-warnings.filterwarnings("ignore")
 
 
 class EfficientnetImageEmbedding(Operator):
     """
@@ -51,8 +51,8 @@ class EfficientnetImageEmbedding(Operator):
         config = resolve_data_config({}, model=self.model._model)
         self.tfms = create_transform(**config)
 
-    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
+    def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
-        img = self.tfms(Image.open(img_path)).unsqueeze(0)
+        img = self.tfms(to_pil(image)).unsqueeze(0)
         features = self.model(img)
         return Outputs(features.flatten().detach().numpy())