From 7ef22297552fb7af52ffdba9a4a0fcaf1b69cc77 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 15 Dec 2022 11:27:23 +0800 Subject: [PATCH] Modify op to support Triton & test onnx Signed-off-by: Jael Gu --- performance.md => benchmark/performance.md | 0 nn_fingerprint.py | 113 ++++++++++----------- test.py | 50 --------- test_onnx.py | 99 ++++++++++++++++++ 4 files changed, 151 insertions(+), 111 deletions(-) rename performance.md => benchmark/performance.md (100%) delete mode 100644 test.py create mode 100644 test_onnx.py diff --git a/performance.md b/benchmark/performance.md similarity index 100% rename from performance.md rename to benchmark/performance.md diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 131ab36..61b9506 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -32,24 +32,11 @@ from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec from .configs import default_params, hop25_params, distill_params warnings.filterwarnings('ignore') -log = logging.getLogger() +log = logging.getLogger('nnfp_op') -@register(output_schema=['vecs']) -class NNFingerprint(NNOperator): - """ - Audio embedding operator using Neural Network Fingerprint - """ - - def __init__(self, - model_name: str = 'nnfp_default', - model_path: str = None, - framework: str = 'pytorch', - device: str = None - ): - super().__init__(framework=framework) - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' +class Model: + def __init__(self, model_name, device='cpu', model_path=None): self.device = device if model_name == 'nnfp_default': self.params = default_params @@ -64,51 +51,61 @@ class NNFingerprint(NNOperator): if model_path is None: path = str(Path(__file__).parent) model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') - if model_path.endswith('.onnx'): - log.warning('Using onnx.') - import onnxruntime - self.model = onnxruntime.InferenceSession( - model_path, - providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] - ) + + try: + state_dict = torch.jit.load(model_path, map_location=self.device) + except Exception: + state_dict = torch.load(model_path, map_location=self.device) + if isinstance(state_dict, torch.nn.Module): + self.model = state_dict else: - try: - state_dict = torch.jit.load(model_path, map_location=self.device) - except Exception: - state_dict = torch.load(model_path, map_location=self.device) - if isinstance(state_dict, torch.nn.Module): - self.model = state_dict - else: - dim = self.params['dim'] - h = self.params['h'] - u = self.params['u'] - f_bin = self.params['n_mels'] - n_seg = int(self.params['segment_size'] * self.params['sample_rate']) - t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length'] - log.info('Creating model with parameters...') - self.model = NNFp( - dim=dim, h=h, u=u, - in_f=f_bin, in_t=t, - fuller=self.params['fuller'], - activation=self.params['activation'] - ).to(self.device) - self.model.load_state_dict(state_dict) - self.model.eval() + dim = self.params['dim'] + h = self.params['h'] + u = self.params['u'] + f_bin = self.params['n_mels'] + n_seg = int(self.params['segment_size'] * self.params['sample_rate']) + t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length'] + log.info('Creating model with parameters...') + self.model = NNFp( + dim=dim, h=h, u=u, + in_f=f_bin, in_t=t, + fuller=self.params['fuller'], + activation=self.params['activation'] + ).to(self.device) + self.model.load_state_dict(state_dict) + self.model.eval() log.info('Model is loaded.') + def __call__(self, data: 'Tensor'): + return self.model(data) + + +@register(output_schema=['vecs']) +class NNFingerprint(NNOperator): + """ + Audio embedding operator using Neural Network Fingerprint + """ + def __init__(self, + model_name: str = 'nnfp_default', + model_path: str = None, + framework: str = 'pytorch', + device: str = None + ): + super().__init__(framework=framework) + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device + self.model_name = model_name + self.model = Model(model_name=model_name, device=self.device, model_path=model_path) + self.params = self.model.params + def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: audio_tensors = self.preprocess(data) if audio_tensors.device != self.device: audio_tensors = audio_tensors.to(self.device) # print(audio_tensors.shape) - if isinstance(self.model, onnxruntime.InferenceSession): - audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \ - else audio_tensors.cpu().numpy() - ort_inputs = {self.model.get_inputs()[0].name: audio_numpy} - outs = self.model.run(None, ort_inputs)[0] - else: - features = self.model(audio_tensors) - outs = features.detach().cpu().numpy() + features = self.model(audio_tensors) + outs = features.detach().cpu().numpy() return outs def preprocess(self, frames: Union[str, List[AudioFrame]]): @@ -183,7 +180,7 @@ class NNFingerprint(NNOperator): path = str(Path(__file__).parent) path = os.path.join(path, 'saved', format) os.makedirs(path, exist_ok=True) - name = 'nnfp' + name = self.model_name.replace('/', '-') path = os.path.join(path, name) dummy_input = torch.rand( (1,) + (self.params['n_mels'], self.params['u']) @@ -208,7 +205,7 @@ class NNFingerprint(NNOperator): elif format == 'onnx': path = path + '.onnx' try: - torch.onnx.export(self.model, + torch.onnx.export(self.model.model, dummy_input, path, input_names=['input'], @@ -225,9 +222,3 @@ class NNFingerprint(NNOperator): # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') - - def input_schema(self): - return [(AudioFrame, (1024,))] - - def output_schema(self): - return [(numpy.ndarray, (-1, self.params['dim']))] diff --git a/test.py b/test.py deleted file mode 100644 index 87cd27d..0000000 --- a/test.py +++ /dev/null @@ -1,50 +0,0 @@ -from towhee import ops - -import warnings - -import torch -import numpy -import onnx -import onnxruntime - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - - -# decode = ops.audio_decode.ffmpeg() -# audio = [x[0] for x in decode('path/to/audio.wav')] -audio = torch.rand(10, 256, 32).to(device) - -op = ops.audio_embedding.nnfp() -out0 = op.get_op().model(audio) -# print(out0) - -# Test Pytorch -op.get_op().save_model(format='pytorch') -op = ops.audio_embedding.nnfp(model_path='./saved/pytorch/nnfp.pt') -out1 = op.get_op().model(audio) -assert ((out0 == out1).all()) - -# Test Torchscript -op.get_op().save_model(format='torchscript') -op = ops.audio_embedding.nnfp(model_path='./saved/torchscript/nnfp.pt') -out2 = op.get_op().model(audio) -assert ((out0 == out2).all()) - -# Test ONNX -op.get_op().save_model(format='onnx') -op = ops.audio_embedding.nnfp() -onnx_model = onnx.load('./saved/onnx/nnfp.onnx') -onnx.checker.check_model(onnx_model) - -ort_session = onnxruntime.InferenceSession( -'./saved/onnx/nnfp.onnx', -providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) - -ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} -ort_outs = ort_session.run(None, ort_inputs) -out3 = ort_outs[0] -# print(out3) -assert (numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05)) diff --git a/test_onnx.py b/test_onnx.py new file mode 100644 index 0000000..208644f --- /dev/null +++ b/test_onnx.py @@ -0,0 +1,99 @@ +from towhee import ops +import torch +import numpy +import onnx +import onnxruntime + +import os +from pathlib import Path +import logging +import platform +import psutil + +models = ['nnfp_default'] + +atol = 1e-3 +log_path = 'nnfp_onnx.log' +f = open('onnx.csv', 'w+') +f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') + +logger = logging.getLogger('nnfp_onnx') +logger.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +fh = logging.FileHandler(log_path) +fh.setLevel(logging.DEBUG) +fh.setFormatter(formatter) +logger.addHandler(fh) +ch = logging.StreamHandler() +ch.setLevel(logging.ERROR) +ch.setFormatter(formatter) +logger.addHandler(ch) + +logger.debug(f'machine: {platform.platform()}-{platform.processor()}') +logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' + f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' + f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') +logger.debug(f'cpu: {psutil.cpu_count()}') + + +status = None +for name in models: + logger.info(f'***{name}***') + saved_name = name.replace('/', '-') + onnx_path = f'saved/onnx/{saved_name}.onnx' + + try: + op = ops.audio_embedding.nnfp(model_name=name, device='cpu').get_op() + except Exception as e: + logger.error(f'Fail to load model {name}. Please check weights.') + + data = torch.rand((1,) + (op.params['n_mels'], op.params['u'])) + if status: + f.write(','.join(status) + '\n') + status = [name] + ['fail'] * 5 + + try: + out1 = op.model(data).detach().numpy() + logger.info('OP LOADED.') + status[1] = 'success' + except Exception as e: + logger.error(f'FAIL TO LOAD OP: {e}') + continue + try: + op.save_model(format='onnx') + logger.info('ONNX SAVED.') + status[2] = 'success' + except Exception as e: + logger.error(f'FAIL TO SAVE ONNX: {e}') + continue + try: + try: + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + except Exception: + saved_onnx = onnx.load(onnx_path, load_external_data=False) + onnx.checker.check_model(saved_onnx) + logger.info('ONNX CHECKED.') + status[3] = 'success' + except Exception as e: + logger.error(f'FAIL TO CHECK ONNX: {e}') + pass + try: + sess = onnxruntime.InferenceSession(onnx_path, + providers=onnxruntime.get_available_providers()) + out2 = sess.run(None, input_feed={'input': data.detach().numpy()}) + logger.info('ONNX WORKED.') + status[4] = 'success' + if numpy.allclose(out1, out2, atol=atol): + logger.info('Check accuracy: OK') + status[5] = 'success' + else: + logger.info(f'Check accuracy: atol is larger than {atol}.') + except Exception as e: + logger.error(f'FAIL TO RUN ONNX: {e}') + continue + +if status: + f.write(','.join(status) + '\n') + +print('Finished.')