|
@ -16,6 +16,7 @@ import logging |
|
|
import numpy |
|
|
import numpy |
|
|
import os |
|
|
import os |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
import towhee |
|
|
import towhee |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
@ -62,20 +63,31 @@ class TimmImage(NNOperator): |
|
|
self.tfms = create_transform(**self.config) |
|
|
self.tfms = create_transform(**self.config) |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.skip_tfms = skip_preprocess |
|
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
|
|
def __call__(self, img: towhee._types.Image) -> numpy.ndarray: |
|
|
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
|
|
|
if not self.skip_tfms: |
|
|
|
|
|
img = self.tfms(img).unsqueeze(0) |
|
|
|
|
|
img = img.to(self.device) |
|
|
|
|
|
features = self.model.forward_features(img) |
|
|
|
|
|
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
|
|
|
if isinstance(data, list): |
|
|
|
|
|
imgs = [] |
|
|
|
|
|
for img in data: |
|
|
|
|
|
img = self.convert_img(img) |
|
|
|
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
|
|
|
imgs.append(img) |
|
|
|
|
|
inputs = torch.stack(imgs) |
|
|
|
|
|
else: |
|
|
|
|
|
img = self.convert_img(data) |
|
|
|
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
|
|
|
inputs = img.unsqueeze(0) |
|
|
|
|
|
inputs = inputs.to(self.device) |
|
|
|
|
|
features = self.model.forward_features(inputs) |
|
|
if features.dim() == 4: |
|
|
if features.dim() == 4: |
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
features = global_pool(features) |
|
|
features = global_pool(features) |
|
|
|
|
|
|
|
|
features = features.to('cpu') |
|
|
|
|
|
vec = features.flatten().detach().numpy() |
|
|
|
|
|
return vec |
|
|
|
|
|
|
|
|
vecs = features.to('cpu').flatten(1).squeeze(0).detach().numpy() |
|
|
|
|
|
return vecs |
|
|
|
|
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
|
|
def convert_img(self, img: towhee._types.Image): |
|
|
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
|
|
|
return img |
|
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
if path == 'default': |
|
|
if path == 'default': |
|
|