From 254ea6162e344572a3ef89851a073996a56a5b02 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 24 Jun 2022 15:07:20 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- README.md | 5 +++-- timm_image.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c40195d..d19c969 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,8 @@ It uses the pre-trained model specified by model name to generate an image embed The decoded image data in numpy.ndarray. It allows both single input and a list for batch input. -**Returns:** *numpy.ndarray* +**Returns:** *Union[List[numpy.ndarray], numpy.ndarray]* If only 1 image input, then output is an image embedding in shape of (feature_dim,). -If a list of images as input, then output is a numpy.ndarray in shape of (batch_num, feature_dim). +If a list of images is given as input, then output is a same-length list of numpy.ndarray, +each of which represents an image embedding in shape of (feature_dim,). 
diff --git a/timm_image.py b/timm_image.py index d587b7a..aafd30b 100644 --- a/timm_image.py +++ b/timm_image.py @@ -65,20 +65,25 @@ class TimmImage(NNOperator): def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): if not isinstance(data, list): - data = [data] - imgs = [] - for img in data: + imgs = [data] + else: + imgs = data + img_list = [] + for img in imgs: img = self.convert_img(img) img = img if self.skip_tfms else self.tfms(img) - imgs.append(img) - inputs = torch.stack(imgs) + img_list.append(img) + inputs = torch.stack(img_list) inputs = inputs.to(self.device) features = self.model.forward_features(inputs) if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1) features = global_pool(features) - - vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy() + features = features.to('cpu').flatten(1) + if isinstance(data, list): + vecs = list(features.detach().numpy()) + else: + vecs = features.squeeze(0).detach().numpy() return vecs @arg(1, to_image_color('RGB'))