From 951ae888e72407621f72987df1c23f884b18f3f6 Mon Sep 17 00:00:00 2001 From: wxywb Date: Wed, 29 Dec 2021 20:01:30 +0800 Subject: [PATCH] change the input format. --- README.md | 9 +++++---- resnet_image_embedding.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index be4df3b..d3330d6 100644 --- a/README.md +++ b/README.md @@ -24,14 +24,15 @@ __init__(self, model_name: str, framework: str = 'pytorch') - supported types: `str`, default is 'pytorch' ```python -__call__(self, img_tensor: torch.Tensor) +__call__(self, image: 'towhee.types.Image') ``` **Args:** -- img_tensor: - - the input image tensor - - supported types: `torch.Tensor` + image: + - the input image + - supported types: `towhee.types.Image` + **Returns:** diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index d0c91c8..b62b9a4 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -22,6 +22,7 @@ from typing import NamedTuple import os from torchvision.transforms import InterpolationMode from towhee.operator import Operator +from towhee.utils.pil_utils import to_pil import warnings warnings.filterwarnings("ignore") @@ -44,9 +45,8 @@ class ResnetImageEmbedding(Operator): transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - - def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): - img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0) + def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): + img = self.tfms(to_pil(image)).unsqueeze(0) embedding = self.model(img) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(embedding)