Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
15 additions and
9 deletions
-
README.md
-
timm_image.py
|
|
@ -84,7 +84,8 @@ It uses the pre-trained model specified by model name to generate an image embed |
|
|
|
The decoded image data in numpy.ndarray. It allows both single input and a list for batch input. |
|
|
|
|
|
|
|
|
|
|
|
**Returns:** *numpy.ndarray* |
|
|
|
**Returns:** *Union[List[numpy.ndarray], numpy.ndarray]* |
|
|
|
|
|
|
|
If only 1 image input, then output is an image embedding in shape of (feature_dim,). |
|
|
|
If a list of images as input, then output is a numpy.ndarray in shape of (batch_num, feature_dim). |
|
|
|
If a list of images as input, then output is a same-length list of numpy.ndarray, |
|
|
|
each of which represents an image embedding in shape of (feature_dim,). |
|
|
|
|
|
@ -65,20 +65,25 @@ class TimmImage(NNOperator): |
|
|
|
|
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
|
if not isinstance(data, list): |
|
|
|
data = [data] |
|
|
|
imgs = [] |
|
|
|
for img in data: |
|
|
|
imgs = [data] |
|
|
|
else: |
|
|
|
imgs = data |
|
|
|
img_list = [] |
|
|
|
for img in imgs: |
|
|
|
img = self.convert_img(img) |
|
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
|
imgs.append(img) |
|
|
|
inputs = torch.stack(imgs) |
|
|
|
img_list.append(img) |
|
|
|
inputs = torch.stack(img_list) |
|
|
|
inputs = inputs.to(self.device) |
|
|
|
features = self.model.forward_features(inputs) |
|
|
|
if features.dim() == 4: |
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
|
features = global_pool(features) |
|
|
|
|
|
|
|
vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy() |
|
|
|
features = features.to('cpu').flatten(1) |
|
|
|
if isinstance(data, list): |
|
|
|
vecs = list(features.detach().numpy()) |
|
|
|
else: |
|
|
|
vecs = features.squeeze(0).detach().numpy() |
|
|
|
return vecs |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|