Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
7 additions and
10 deletions
-
timm_image.py
|
@ -64,17 +64,14 @@ class TimmImage(NNOperator): |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.skip_tfms = skip_preprocess |
|
|
|
|
|
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
if isinstance(data, list): |
|
|
|
|
|
imgs = [] |
|
|
|
|
|
for img in data: |
|
|
|
|
|
img = self.convert_img(img) |
|
|
|
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
|
|
|
imgs.append(img) |
|
|
|
|
|
inputs = torch.stack(imgs) |
|
|
|
|
|
else: |
|
|
|
|
|
img = self.convert_img(data) |
|
|
|
|
|
|
|
|
if not isinstance(data, list): |
|
|
|
|
|
data = [data] |
|
|
|
|
|
imgs = [] |
|
|
|
|
|
for img in data: |
|
|
|
|
|
img = self.convert_img(img) |
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
inputs = img.unsqueeze(0) |
|
|
|
|
|
|
|
|
imgs.append(img) |
|
|
|
|
|
inputs = torch.stack(imgs) |
|
|
inputs = inputs.to(self.device) |
|
|
inputs = inputs.to(self.device) |
|
|
features = self.model.forward_features(inputs) |
|
|
features = self.model.forward_features(inputs) |
|
|
if features.dim() == 4: |
|
|
if features.dim() == 4: |
|
|