logo
Browse Source

Allow list as input

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
d6516ba189
  1. 8
      README.md
  2. 32
      timm_image.py

8
README.md

@ -79,12 +79,12 @@ It uses the pre-trained model specified by model name to generate an image embed
**Parameters:**
***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
The decoded image data in numpy.ndarray.
***data:*** *Union[List[towhee._types.Image], towhee._types.Image]*
The decoded image data as a numpy.ndarray. Accepts either a single image or a list of images for batch input.
**Returns:** *numpy.ndarray*
The image embedding extracted by model.
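If a single image is given, the output is an image embedding of shape (feature_dim,).
If a list of images is given, the output is a numpy.ndarray of shape (batch_num, feature_dim); note that a list containing exactly one image also yields shape (feature_dim,), because a leading batch dimension of size 1 is squeezed away.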

32
timm_image.py

@ -16,6 +16,7 @@ import logging
import numpy
import os
from pathlib import Path
from typing import List, Union
import towhee
from towhee.operator.base import NNOperator, OperatorFlag
@ -62,20 +63,31 @@ class TimmImage(NNOperator):
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
# NOTE(review): this span is a rendered diff hunk -- the +/- markers and the
# indentation were lost, so what look like the pre-change and post-change
# versions of `__call__` are interleaved below. Confirm against the repository
# before relying on it.
#
# Apparent pre-change version: takes a single image `img`.
@arg(1, to_image_color('RGB'))
def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
img = PILImage.fromarray(img.astype('uint8'), 'RGB')
if not self.skip_tfms:
# Apply the transform, then add a leading batch dimension of size 1.
img = self.tfms(img).unsqueeze(0)
img = img.to(self.device)
features = self.model.forward_features(img)
# Apparent post-change version: `data` is either one image or a list of images.
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
if isinstance(data, list):
# Batch path: convert (and, unless skip_tfms is set, transform) each image,
# then stack the results along a new leading dimension.
imgs = []
for img in data:
img = self.convert_img(img)
img = img if self.skip_tfms else self.tfms(img)
imgs.append(img)
inputs = torch.stack(imgs)
else:
# Single-image path: same per-image steps, then add a batch dimension of 1.
img = self.convert_img(data)
img = img if self.skip_tfms else self.tfms(img)
inputs = img.unsqueeze(0)
inputs = inputs.to(self.device)
features = self.model.forward_features(inputs)
# 4-D features are reduced by adaptive average pooling to a 1x1 spatial size.
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features)
# Apparent pre-change tail: flatten everything into a single 1-D vector.
features = features.to('cpu')
vec = features.flatten().detach().numpy()
return vec
# Apparent post-change tail: flatten each sample from dim 1 onward, then drop
# dim 0 only when it has size 1.
# NOTE(review): a one-element list therefore also yields a 1-D result rather
# than (1, feature_dim) -- confirm this is intended.
vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy()
return vecs
@arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image):
    """Wrap a decoded image array as an RGB-mode PIL image.

    The array is cast to ``uint8`` and then passed to
    ``PILImage.fromarray`` with mode ``'RGB'``.

    Note: presumably the ``@arg(1, to_image_color('RGB'))`` decorator
    converts the image argument to RGB channel order before this body
    runs -- confirm against towhee's ``arg`` / ``to_image_color``.

    Args:
        img: The decoded image; must support ``astype('uint8')``.

    Returns:
        The PIL image built from the ``uint8`` array.
    """
    as_uint8 = img.astype('uint8')
    return PILImage.fromarray(as_uint8, 'RGB')
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':

Loading…
Cancel
Save