
Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Author: Jael Gu, 2 years ago
Commit: 254ea6162e
Changed files:
1. README.md (5 changed lines)
2. timm_image.py (19 changed lines)

README.md

@@ -84,7 +84,8 @@ It uses the pre-trained model specified by model name to generate an image embed
 The decoded image data in numpy.ndarray. It allows both single input and a list for batch input.
-**Returns:** *numpy.ndarray*
+**Returns:** *Union[List[numpy.ndarray], numpy.ndarray]*
 If only 1 image input, then output is an image embedding in shape of (feature_dim,).
-If a list of images as input, then output is a numpy.ndarray in shape of (batch_num, feature_dim).
+If a list of images as input, then output is a same-length list of numpy.ndarray,
+each of which represents an image embedding in shape of (feature_dim,).
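The updated contract can be illustrated with a small stand-in function rather than the real operator; `fake_embed` and `FEATURE_DIM` are made up for this sketch, and the actual embedding dimension depends on the chosen timm model:

```python
from typing import List, Union

import numpy as np

FEATURE_DIM = 512  # placeholder; the real value depends on the timm model


def fake_embed(data) -> Union[List[np.ndarray], np.ndarray]:
    """Stand-in that mimics the return contract documented above."""
    if isinstance(data, list):
        # Batch input -> same-length list of (feature_dim,) vectors.
        return [np.zeros(FEATURE_DIM, dtype=np.float32) for _ in data]
    # Single input -> one (feature_dim,) vector.
    return np.zeros(FEATURE_DIM, dtype=np.float32)


single = fake_embed("img.jpg")
batch = fake_embed(["img1.jpg", "img2.jpg"])
assert single.shape == (FEATURE_DIM,)
assert isinstance(batch, list) and len(batch) == 2
assert batch[0].shape == (FEATURE_DIM,)
```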

timm_image.py

@@ -65,20 +65,25 @@ class TimmImage(NNOperator):
     def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
         if not isinstance(data, list):
-            data = [data]
-        imgs = []
-        for img in data:
+            imgs = [data]
+        else:
+            imgs = data
+        img_list = []
+        for img in imgs:
             img = self.convert_img(img)
             img = img if self.skip_tfms else self.tfms(img)
-            imgs.append(img)
-        inputs = torch.stack(imgs)
+            img_list.append(img)
+        inputs = torch.stack(img_list)
         inputs = inputs.to(self.device)
         features = self.model.forward_features(inputs)
         if features.dim() == 4:
             global_pool = nn.AdaptiveAvgPool2d(1)
             features = global_pool(features)
-        vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy()
+        features = features.to('cpu').flatten(1)
+        if isinstance(data, list):
+            vecs = list(features.detach().numpy())
+        else:
+            vecs = features.squeeze(0).detach().numpy()
         return vecs
 
     @arg(1, to_image_color('RGB'))
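The reason for the extra branch is easiest to see on a dummy feature tensor: the old one-liner only squeezes a leading dimension of size 1, so a real batch came back as a single (batch_num, feature_dim) array instead of a list. A self-contained sketch follows; the shapes are made up, and only the flatten/squeeze logic mirrors the diff:

```python
import torch

batch_num, feature_dim = 3, 512
# Dummy pooled features, e.g. what AdaptiveAvgPool2d(1) would produce.
features = torch.randn(batch_num, feature_dim, 1, 1)

# Old post-processing: squeeze(0) is a no-op for batch_num > 1,
# so the result is one (batch_num, feature_dim) array.
old_vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy()
print(old_vecs.shape)                    # (3, 512)

# New post-processing for list input: unpack into per-image vectors.
flat = features.to('cpu').flatten(1)
new_vecs = list(flat.detach().numpy())
print(len(new_vecs), new_vecs[0].shape)  # 3 (512,)

# Single-image input keeps the (feature_dim,) shape.
single = torch.randn(1, feature_dim, 1, 1)
single_vec = single.to('cpu').flatten(1).squeeze(0).detach().numpy()
print(single_vec.shape)                  # (512,)
```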
