logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
7126190325
  1. 17
      timm_image.py

17
timm_image.py

@ -64,17 +64,14 @@ class TimmImage(NNOperator):
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
if isinstance(data, list):
imgs = []
for img in data:
img = self.convert_img(img)
img = img if self.skip_tfms else self.tfms(img)
imgs.append(img)
inputs = torch.stack(imgs)
else:
img = self.convert_img(data)
if not isinstance(data, list):
data = [data]
imgs = []
for img in data:
img = self.convert_img(img)
img = img if self.skip_tfms else self.tfms(img) img = img if self.skip_tfms else self.tfms(img)
inputs = img.unsqueeze(0)
imgs.append(img)
inputs = torch.stack(imgs)
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
features = self.model.forward_features(inputs) features = self.model.forward_features(inputs)
if features.dim() == 4: if features.dim() == 4:

Loading…
Cancel
Save