diff --git a/isc.py b/isc.py index 28d5704..7908977 100644 --- a/isc.py +++ b/isc.py @@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register from towhee.models import isc +# from towhee.dc2 import accelerate import torch from torch import nn @@ -36,6 +37,20 @@ warnings.filterwarnings('ignore') log = logging.getLogger('isc_op') +# @accelerate +class Model: + def __init__(self, timm_backbone, checkpoint_path, device): + self.device = device + self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) + self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, + backbone=self.backbone, p=3.0, eval_p=1.0) + self.model.eval() + + def __call__(self, x): + x = x.to(self.device) + return self.model(x) + + @register(output_schema=['vec']) class Isc(NNOperator): """ @@ -62,16 +77,13 @@ class Isc(NNOperator): if checkpoint_path is None: checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') - backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) - self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, - backbone=backbone, p=3.0, eval_p=1.0) - self.model.eval() + self.model = Model(self.timm_backbone, checkpoint_path, self.device) self.tfms = transforms.Compose([ transforms.Resize((img_size, img_size)), transforms.ToTensor(), - transforms.Normalize(mean=backbone.default_cfg['mean'], - std=backbone.default_cfg['std']) + transforms.Normalize(mean=self.backbone.default_cfg['mean'], + std=self.backbone.default_cfg['std']) ]) def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): @@ -96,6 +108,14 @@ class Isc(NNOperator): vecs = features.squeeze(0).detach().numpy() return vecs + @property + def _model(self): + return self.model.model + + @property + def backbone(self): + return self.model.backbone + def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) @@ -103,25 +123,28 @@ class Isc(NNOperator): os.makedirs(path, exist_ok=True) name = self.timm_backbone.replace('/', '-') path = os.path.join(path, name) + if format in ['pytorch', 'torchscript']: + path = path + '.pt' + elif format == 'onnx': + path = path + '.onnx' + else: + raise ValueError(f'Invalid format {format}.') dummy_input = torch.rand(1, 3, 224, 224) if format == 'pytorch': - path = path + '.pt' - torch.save(self.model, path) + torch.save(self._model, path) elif format == 'torchscript': - path = path + '.pt' try: try: - jit_model = torch.jit.script(self.model) + jit_model = torch.jit.script(self._model) except Exception: - jit_model = torch.jit.trace(self.model, dummy_input, strict=False) + jit_model = torch.jit.trace(self._model, dummy_input, strict=False) torch.jit.save(jit_model, path) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - path = path + '.onnx' try: - torch.onnx.export(self.model, + torch.onnx.export(self._model, dummy_input, path, input_names=['input_0'], @@ -139,12 +162,17 @@ class Isc(NNOperator): # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') + return path @arg(1, to_image_color('RGB')) def convert_img(self, img: towhee._types.Image): img = PILImage.fromarray(img.astype('uint8'), 'RGB') return img + @property + def supported_formats(self): + return ['onnx'] + # if __name__ == '__main__': # from towhee import ops diff --git a/test_onnx.py b/test_onnx.py index e01ad82..5725469 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -1,7 +1,6 @@ -import onnx -from isc import Isc - +import towhee from towhee import ops + import torch import numpy import onnx @@ -46,17 +45,17 @@ for name in models: onnx_path = f'saved/onnx/{saved_name}.onnx' try: - op = Isc(timm_backbone=name, device='cpu') + op = ops.image_embedding.isc(timm_backbone=name, device='cpu').get_op() except Exception as e: logger.error(f'Fail to load model {name}. Please check weights.') - data = torch.rand(1, 3, 224, 224) + data = torch.ones(1, 3, 224, 224) if status: f.write(','.join(status) + '\n') status = [name] + ['fail'] * 5 try: - out1 = op.model(data).detach().numpy() + out1 = op.model(data).cpu().detach().numpy() logger.info('OP LOADED.') status[1] = 'success' except Exception as e: @@ -84,7 +83,8 @@ for name in models: try: sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) - out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) + data = data.cpu().detach().numpy() + out2 = sess.run(None, input_feed={'input_0': data}) logger.info('ONNX WORKED.') status[4] = 'success' if numpy.allclose(out1, out2, atol=atol):