diff --git a/test_onnx.py b/test_onnx.py index 2ded7d9..c06718a 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -11,8 +11,8 @@ import logging import platform import psutil -models = TimmImage.supported_model_names()[:2] -# models = ['resnet50'] +# models = TimmImage.supported_model_names()[:2] +models = ['resnet50'] atol = 1e-3 log_path = 'timm_onnx.log' @@ -57,7 +57,7 @@ for name in models: status = [name] + ['fail'] * 5 try: - out1 = op.model.forward_features(data).detach().numpy() + out1 = op.accelerate_model(data).detach().numpy() logger.info('OP LOADED.') status[1] = 'success' except Exception as e: diff --git a/timm_image.py b/timm_image.py index 98bf302..89ceb41 100644 --- a/timm_image.py +++ b/timm_image.py @@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register from towhee.types import Image +from towhee.dc2 import accelerate import torch from torch import nn @@ -40,6 +41,17 @@ warnings.filterwarnings('ignore') log = logging.getLogger('timm_op') +@accelerate +class Model: + def __init__(self, model_name, device, num_classes): + self.model = create_model(model_name, pretrained=True, num_classes=num_classes) + self.model.eval() + self.model.to(device) + + def __call__(self, x: torch.Tensor): + return self.model.forward_features(x) + + @register(output_schema=['vec']) class TimmImage(NNOperator): """ @@ -65,10 +77,12 @@ class TimmImage(NNOperator): self.device = device self.model_name = model_name if self.model_name: - self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) - self.model.eval() - self.model.to(self.device) - + self.accelerate_model = Model( + model_name=model_name, + device=self.device, + num_classes=num_classes + ) + self.model = self.accelerate_model.model self.config = resolve_data_config({}, model=self.model) self.tfms = create_transform(**self.config) self.skip_tfms = skip_preprocess @@ -88,7 +102,7 @@ class TimmImage(NNOperator): img_list.append(img) inputs = torch.stack(img_list) inputs = inputs.to(self.device) - features = self.model.forward_features(inputs) + features = self.accelerate_model(inputs) if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) features = global_pool(features) @@ -111,12 +125,16 @@ class TimmImage(NNOperator): os.makedirs(path, exist_ok=True) name = self.model_name.replace('/', '-') path = os.path.join(path, name) + if format in ['pytorch', 'torchscript']: + path = path + '.pt' + elif format == 'onnx': + path = path + '.onnx' + else: + raise AttributeError(f'Invalid format {format}.') dummy_input = torch.rand((1,) + self.config['input_size']) if format == 'pytorch': - path = path + '.pt' torch.save(self.model, path) elif format == 'torchscript': - path = path + '.pt' try: try: jit_model = torch.jit.script(self.model) @@ -127,7 +145,6 @@ class TimmImage(NNOperator): log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - path = path + '.onnx' self.model.forward = self.model.forward_features try: torch.onnx.export(self.model, @@ -227,11 +244,9 @@ class TimmImage(NNOperator): log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') return model_list - def input_schema(self): - return [(Image, (-1, -1, 3))] - - def output_schema(self): - image = Image(numpy.random.randn(480, 480, 3), "RGB") - ret = self(image) - data_type = type(ret.reshape(-1)[0]) - return [(data_type, ret.shape)] + @property + def supported_formats(self): + if self.model_name in self.supported_model_names(format='onnx'): + return ['onnx'] + else: + return []