
Allow list as input

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Author: Jael Gu, 2 years ago
Commit: d6516ba189
Changed files:
1. README.md (8 changes)
2. timm_image.py (32 changes)

README.md

@@ -79,12 +79,12 @@ It uses the pre-trained model specified by model name to generate an image embed
 **Parameters:**
-***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
-The decoded image data in numpy.ndarray.
+***data:*** *Union[List[towhee._types.Image], towhee._types.Image]*
+The decoded image data in numpy.ndarray. It allows both single input and a list for batch input.
 **Returns:** *numpy.ndarray*
-The image embedding extracted by model.
+If only 1 image input, then output is an image embedding in shape of (feature_dim,).
+If a list of images as input, then output is a numpy.ndarray in shape of (batch_num, feature_dim).
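For context, a minimal usage sketch of the new contract follows; the TimmImage constructor arguments and the towhee._types.Image(array, mode) wrapper are assumptions inferred from the surrounding operator code, not part of this commit.

```python
import numpy
import towhee
from timm_image import TimmImage  # operator class modified in this commit

# Assumed constructor signature and Image wrapper; a random array stands in
# for a decoded image here.
op = TimmImage(model_name='resnet50')
arr = numpy.random.randint(0, 256, (224, 224, 3), dtype=numpy.uint8)
img = towhee._types.Image(arr, 'RGB')  # assumed (array, mode) constructor

vec = op(img)          # single image -> numpy.ndarray of shape (feature_dim,)
vecs = op([img, img])  # list input   -> numpy.ndarray of shape (2, feature_dim)
```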

timm_image.py

@@ -16,6 +16,7 @@ import logging
 import numpy
 import os
 from pathlib import Path
+from typing import List, Union
 import towhee
 from towhee.operator.base import NNOperator, OperatorFlag
@@ -62,20 +63,31 @@ class TimmImage(NNOperator):
         self.tfms = create_transform(**self.config)
         self.skip_tfms = skip_preprocess

-    @arg(1, to_image_color('RGB'))
-    def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
-        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
-        if not self.skip_tfms:
-            img = self.tfms(img).unsqueeze(0)
-        img = img.to(self.device)
-        features = self.model.forward_features(img)
+    def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
+        if isinstance(data, list):
+            imgs = []
+            for img in data:
+                img = self.convert_img(img)
+                img = img if self.skip_tfms else self.tfms(img)
+                imgs.append(img)
+            inputs = torch.stack(imgs)
+        else:
+            img = self.convert_img(data)
+            img = img if self.skip_tfms else self.tfms(img)
+            inputs = img.unsqueeze(0)
+        inputs = inputs.to(self.device)
+        features = self.model.forward_features(inputs)
         if features.dim() == 4:
             global_pool = nn.AdaptiveAvgPool2d(1)
             features = global_pool(features)
-        features = features.to('cpu')
-        vec = features.flatten().detach().numpy()
-        return vec
+        vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy()
+        return vecs
+
+    @arg(1, to_image_color('RGB'))
+    def convert_img(self, img: towhee._types.Image):
+        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
+        return img

     def save_model(self, format: str = 'pytorch', path: str = 'default'):
         if path == 'default':
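The single-vector versus batch-matrix return shapes documented in the README come from the flatten(1).squeeze(0) step: flatten(1) keeps the batch dimension and flattens everything after it, while squeeze(0) drops the batch dimension only when it equals 1 (the single-image path). A small sketch of that shape behaviour, using random tensors in place of real pooled features:

```python
import torch

# After AdaptiveAvgPool2d(1), features have shape (batch, feature_dim, 1, 1);
# random tensors stand in for real model outputs here.
single = torch.rand(1, 512, 1, 1)  # one image
batch = torch.rand(4, 512, 1, 1)   # four images

print(single.flatten(1).squeeze(0).shape)  # torch.Size([512])    -> (feature_dim,)
print(batch.flatten(1).squeeze(0).shape)   # torch.Size([4, 512]) -> (batch_num, feature_dim)
```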
