From c947c6ea84b23f3f8dbeb1114312873f27858017 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 16 Dec 2022 14:06:18 +0800 Subject: [PATCH] Debug to support TritonServer Signed-off-by: Jael Gu --- test_onnx.py | 2 +- timm_image.py | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/test_onnx.py b/test_onnx.py index c06718a..ab7f718 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -57,7 +57,7 @@ for name in models: status = [name] + ['fail'] * 5 try: - out1 = op.accelerate_model(data).detach().numpy() + out1 = op.model(data).detach().numpy() logger.info('OP LOADED.') status[1] = 'success' except Exception as e: diff --git a/timm_image.py b/timm_image.py index 89ceb41..c9ad998 100644 --- a/timm_image.py +++ b/timm_image.py @@ -45,6 +45,7 @@ log = logging.getLogger('timm_op') class Model: def __init__(self, model_name, device, num_classes): self.model = create_model(model_name, pretrained=True, num_classes=num_classes) + self.config = resolve_data_config({}, model=self.model) self.model.eval() self.model.to(device) @@ -77,13 +78,12 @@ class TimmImage(NNOperator): self.device = device self.model_name = model_name if self.model_name: - self.accelerate_model = Model( + self.model = Model( model_name=model_name, device=self.device, num_classes=num_classes ) - self.model = self.accelerate_model.model - self.config = resolve_data_config({}, model=self.model) + self.config = self.model.config self.tfms = create_transform(**self.config) self.skip_tfms = skip_preprocess else: @@ -102,7 +102,7 @@ class TimmImage(NNOperator): img_list.append(img) inputs = torch.stack(img_list) inputs = inputs.to(self.device) - features = self.accelerate_model(inputs) + features = self.model(inputs) if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) features = global_pool(features) @@ -113,6 +113,10 @@ class TimmImage(NNOperator): vecs = features.squeeze(0).detach().numpy() return vecs + @property + def _model(self): + return self.model.model + @arg(1, to_image_color('RGB')) def convert_img(self, img: towhee._types.Image): img = PILImage.fromarray(img.astype('uint8'), 'RGB') @@ -133,21 +137,21 @@ class TimmImage(NNOperator): raise AttributeError(f'Invalid format {format}.') dummy_input = torch.rand((1,) + self.config['input_size']) if format == 'pytorch': - torch.save(self.model, path) + torch.save(self._model, path) elif format == 'torchscript': try: try: - jit_model = torch.jit.script(self.model) + jit_model = torch.jit.script(self._model) except Exception: - jit_model = torch.jit.trace(self.model, dummy_input, strict=False) + jit_model = torch.jit.trace(self._model, dummy_input, strict=False) torch.jit.save(jit_model, path) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - self.model.forward = self.model.forward_features + self._model.forward = self._model.forward_features try: - torch.onnx.export(self.model, + torch.onnx.export(self._model, dummy_input, path, input_names=['input_0'],