|
|
@ -45,6 +45,7 @@ log = logging.getLogger('timm_op') |
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, device, num_classes): |
|
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
|
self.config = resolve_data_config({}, model=self.model) |
|
|
|
self.model.eval() |
|
|
|
self.model.to(device) |
|
|
|
|
|
|
@ -77,13 +78,12 @@ class TimmImage(NNOperator): |
|
|
|
self.device = device |
|
|
|
self.model_name = model_name |
|
|
|
if self.model_name: |
|
|
|
self.accelerate_model = Model( |
|
|
|
self.model = Model( |
|
|
|
model_name=model_name, |
|
|
|
device=self.device, |
|
|
|
num_classes=num_classes |
|
|
|
) |
|
|
|
self.model = self.accelerate_model.model |
|
|
|
self.config = resolve_data_config({}, model=self.model) |
|
|
|
self.config = self.model.config |
|
|
|
self.tfms = create_transform(**self.config) |
|
|
|
self.skip_tfms = skip_preprocess |
|
|
|
else: |
|
|
@ -102,7 +102,7 @@ class TimmImage(NNOperator): |
|
|
|
img_list.append(img) |
|
|
|
inputs = torch.stack(img_list) |
|
|
|
inputs = inputs.to(self.device) |
|
|
|
features = self.accelerate_model(inputs) |
|
|
|
features = self.model(inputs) |
|
|
|
if features.dim() == 4: |
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) |
|
|
|
features = global_pool(features) |
|
|
@ -113,6 +113,10 @@ class TimmImage(NNOperator): |
|
|
|
vecs = features.squeeze(0).detach().numpy() |
|
|
|
return vecs |
|
|
|
|
|
|
|
@property |
|
|
|
def _model(self): |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def convert_img(self, img: towhee._types.Image): |
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
@ -133,21 +137,21 @@ class TimmImage(NNOperator): |
|
|
|
raise AttributeError(f'Invalid format {format}.') |
|
|
|
dummy_input = torch.rand((1,) + self.config['input_size']) |
|
|
|
if format == 'pytorch': |
|
|
|
torch.save(self.model, path) |
|
|
|
torch.save(self._model, path) |
|
|
|
elif format == 'torchscript': |
|
|
|
try: |
|
|
|
try: |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
|
jit_model = torch.jit.script(self._model) |
|
|
|
except Exception: |
|
|
|
jit_model = torch.jit.trace(self.model, dummy_input, strict=False) |
|
|
|
jit_model = torch.jit.trace(self._model, dummy_input, strict=False) |
|
|
|
torch.jit.save(jit_model, path) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
self.model.forward = self.model.forward_features |
|
|
|
self._model.forward = self._model.forward_features |
|
|
|
try: |
|
|
|
torch.onnx.export(self.model, |
|
|
|
torch.onnx.export(self._model, |
|
|
|
dummy_input, |
|
|
|
path, |
|
|
|
input_names=['input_0'], |
|
|
|