From d6516ba18957735676b8b6172dc5ed4309842d52 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 24 Jun 2022 11:34:39 +0800 Subject: [PATCH] Allow list as input Signed-off-by: Jael Gu --- README.md | 8 ++++---- timm_image.py | 32 ++++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 4313627..c40195d 100644 --- a/README.md +++ b/README.md @@ -79,12 +79,12 @@ It uses the pre-trained model specified by model name to generate an image embed **Parameters:** -***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)* - -The decoded image data in numpy.ndarray. +***data:*** *Union[List[towhee._types.Image], towhee._types.Image]* +The decoded image data in numpy.ndarray. It accepts both a single image and a list of images for batch input. **Returns:** *numpy.ndarray* -The image embedding extracted by model. +If a single image is given as input, the output is an image embedding with shape (feature_dim,). +If a list of images is given as input, the output is a numpy.ndarray with shape (batch_num, feature_dim). 
diff --git a/timm_image.py b/timm_image.py index c7f3843..8ef3583 100644 --- a/timm_image.py +++ b/timm_image.py @@ -16,6 +16,7 @@ import logging import numpy import os from pathlib import Path +from typing import List, Union import towhee from towhee.operator.base import NNOperator, OperatorFlag @@ -62,20 +63,31 @@ class TimmImage(NNOperator): self.tfms = create_transform(**self.config) self.skip_tfms = skip_preprocess - @arg(1, to_image_color('RGB')) - def __call__(self, img: towhee._types.Image) -> numpy.ndarray: - img = PILImage.fromarray(img.astype('uint8'), 'RGB') - if not self.skip_tfms: - img = self.tfms(img).unsqueeze(0) - img = img.to(self.device) - features = self.model.forward_features(img) + def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): + if isinstance(data, list): + imgs = [] + for img in data: + img = self.convert_img(img) + img = img if self.skip_tfms else self.tfms(img) + imgs.append(img) + inputs = torch.stack(imgs) + else: + img = self.convert_img(data) + img = img if self.skip_tfms else self.tfms(img) + inputs = img.unsqueeze(0) + inputs = inputs.to(self.device) + features = self.model.forward_features(inputs) if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1) features = global_pool(features) - features = features.to('cpu') - vec = features.flatten().detach().numpy() - return vec + vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy() + return vecs + + @arg(1, to_image_color('RGB')) + def convert_img(self, img: towhee._types.Image): + img = PILImage.fromarray(img.astype('uint8'), 'RGB') + return img def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default':