From ec4aee0a13a1a68016c5a29da8ecfa88144bcec2 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 28 Dec 2022 14:38:18 +0800 Subject: [PATCH] Support TritonServe Signed-off-by: Jael Gu --- README.md | 10 +++++- nn_fingerprint.py | 78 +++++++++++++++++++++++++++-------------------- test_onnx.py | 2 +- 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index adb3953..9ff543b 100644 --- a/README.md +++ b/README.md @@ -110,4 +110,12 @@ Accepted formats: 'pytorch', 'torchscript, 'onnx', 'tensorrt' (in progress) *path: str* Path to save model, defaults to 'default'. -The default path is under 'saved' in the same directory of operator cache. \ No newline at end of file +The default path is under 'saved' in the same directory of operator cache. + +```python +from towhee import ops + +op = ops.audio_embedding.nnfp(device='cpu').get_op() +op.save_model('onnx', 'test.onnx') +``` + PosixPath('/Home/.towhee/operators/audio-embedding/nnfp/main/test.onnx') \ No newline at end of file diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 0245b58..d2fe4a4 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -28,6 +28,7 @@ from towhee import register from towhee.types.audio_frame import AudioFrame from towhee.models.nnfp import NNFp from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec +# from towhee.dc2 import accelerate from .configs import default_params, hop25_params, distill_params @@ -35,23 +36,11 @@ warnings.filterwarnings('ignore') log = logging.getLogger('nnfp_op') +# @accelerate class Model: - def __init__(self, model_name, device='cpu', model_path=None): + def __init__(self, params, device='cpu', model_path=None): self.device = device - if model_name == 'nnfp_default': - self.params = default_params - elif model_name == 'nnfp_hop25': - self.params = hop25_params - elif model_name == 'nnfp_distill': - self.params == distill_params - else: - raise ValueError('Invalid model name. 
Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]') - log.info('Loading model...') - if model_path is None: - path = str(Path(__file__).parent) - model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') - try: state_dict = torch.jit.load(model_path, map_location=self.device) except Exception: @@ -59,18 +48,18 @@ class Model: if isinstance(state_dict, torch.nn.Module): self.model = state_dict else: - dim = self.params['dim'] - h = self.params['h'] - u = self.params['u'] - f_bin = self.params['n_mels'] - n_seg = int(self.params['segment_size'] * self.params['sample_rate']) - t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length'] + dim = params['dim'] + h = params['h'] + u = params['u'] + f_bin = params['n_mels'] + n_seg = int(params['segment_size'] * params['sample_rate']) + t = (n_seg + params['hop_length'] - 1) // params['hop_length'] log.info('Creating model with parameters...') self.model = NNFp( dim=dim, h=h, u=u, in_f=f_bin, in_t=t, - fuller=self.params['fuller'], - activation=self.params['activation'] + fuller=params['fuller'], + activation=params['activation'] ).to(self.device) self.model.load_state_dict(state_dict) self.model.eval() @@ -96,19 +85,34 @@ class NNFingerprint(NNOperator): device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name - self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path) - self.model = self.accelerate_model.model - self.params = self.accelerate_model.params + + if model_name == 'nnfp_default': + self.params = default_params + elif model_name == 'nnfp_hop25': + self.params = hop25_params + elif model_name == 'nnfp_distill': + self.params = distill_params + else: + raise ValueError('Invalid model name. 
Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]') + + if model_path is None: + path = str(Path(__file__).parent) + model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') + self.model = Model(params=self.params, device=self.device, model_path=model_path) def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: audio_tensors = self.preprocess(data) if audio_tensors.device != self.device: audio_tensors = audio_tensors.to(self.device) # print(audio_tensors.shape) - features = self.accelerate_model(audio_tensors) + features = self.model(audio_tensors) outs = features.detach().cpu().numpy() return outs + @property + def _model(self): + return self.model.model + def preprocess(self, frames: Union[str, List[AudioFrame]]): if isinstance(frames, str): audio, sr = torchaudio.load(frames) @@ -176,6 +180,10 @@ class NNFingerprint(NNOperator): log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) return wav.astype(dtype) + @property + def supported_formats(self): + return ['onnx'] + def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) @@ -183,30 +191,33 @@ class NNFingerprint(NNOperator): os.makedirs(path, exist_ok=True) name = self.model_name.replace('/', '-') path = os.path.join(path, name) + if format in ['torchscript', 'pytorch']: + path = path + '.pt' + elif format == 'onnx': + path = path + '.onnx' + else: + raise ValueError(f'Invalid format {format}.') dummy_input = torch.rand( (1,) + (self.params['n_mels'], self.params['u']) ).to(self.device) if format == 'pytorch': - path = path + '.pt' - torch.save(self.model, path) + torch.save(self._model, path) elif format == 'torchscript': - path = path + '.pt' try: try: - jit_model = torch.jit.script(self.model) + jit_model = torch.jit.script(self._model) except Exception: log.warning( 'Failed to directly export as torchscript.' 
'Using dummy input in shape of %s now.', dummy_input.shape) - jit_model = torch.jit.trace(self.model, dummy_input, strict=False) + jit_model = torch.jit.trace(self._model, dummy_input, strict=False) torch.jit.save(jit_model, path) except Exception as e: log.error('Fail to save as torchscript: %s.', e) raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - path = path + '.onnx' try: - torch.onnx.export(self.model, + torch.onnx.export(self._model, dummy_input, path, input_names=['input'], @@ -223,3 +234,4 @@ class NNFingerprint(NNOperator): # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') + return Path(path).resolve() diff --git a/test_onnx.py b/test_onnx.py index 208644f..c69ea0a 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -60,7 +60,7 @@ for name in models: logger.error(f'FAIL TO LOAD OP: {e}') continue try: - op.save_model(format='onnx') + op.save_model(format='onnx', path=onnx_path) logger.info('ONNX SAVED.') status[2] = 'success' except Exception as e: