logo
Browse Source

Modify op to support Triton & test onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
7ef2229755
  1. 0
      benchmark/performance.md
  2. 113
      nn_fingerprint.py
  3. 50
      test.py
  4. 99
      test_onnx.py

0
performance.md → benchmark/performance.md

113
nn_fingerprint.py

@ -32,24 +32,11 @@ from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
from .configs import default_params, hop25_params, distill_params from .configs import default_params, hop25_params, distill_params
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
log = logging.getLogger()
log = logging.getLogger('nnfp_op')
@register(output_schema=['vecs'])
class NNFingerprint(NNOperator):
"""
Audio embedding operator using Neural Network Fingerprint
"""
def __init__(self,
model_name: str = 'nnfp_default',
model_path: str = None,
framework: str = 'pytorch',
device: str = None
):
super().__init__(framework=framework)
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Model:
def __init__(self, model_name, device='cpu', model_path=None):
self.device = device self.device = device
if model_name == 'nnfp_default': if model_name == 'nnfp_default':
self.params = default_params self.params = default_params
@ -64,51 +51,61 @@ class NNFingerprint(NNOperator):
if model_path is None: if model_path is None:
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
if model_path.endswith('.onnx'):
log.warning('Using onnx.')
import onnxruntime
self.model = onnxruntime.InferenceSession(
model_path,
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
)
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
state_dict = torch.load(model_path, map_location=self.device)
if isinstance(state_dict, torch.nn.Module):
self.model = state_dict
else: else:
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
state_dict = torch.load(model_path, map_location=self.device)
if isinstance(state_dict, torch.nn.Module):
self.model = state_dict
else:
dim = self.params['dim']
h = self.params['h']
u = self.params['u']
f_bin = self.params['n_mels']
n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
log.info('Creating model with parameters...')
self.model = NNFp(
dim=dim, h=h, u=u,
in_f=f_bin, in_t=t,
fuller=self.params['fuller'],
activation=self.params['activation']
).to(self.device)
self.model.load_state_dict(state_dict)
self.model.eval()
dim = self.params['dim']
h = self.params['h']
u = self.params['u']
f_bin = self.params['n_mels']
n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
log.info('Creating model with parameters...')
self.model = NNFp(
dim=dim, h=h, u=u,
in_f=f_bin, in_t=t,
fuller=self.params['fuller'],
activation=self.params['activation']
).to(self.device)
self.model.load_state_dict(state_dict)
self.model.eval()
log.info('Model is loaded.') log.info('Model is loaded.')
def __call__(self, data: 'Tensor'):
return self.model(data)
@register(output_schema=['vecs'])
class NNFingerprint(NNOperator):
"""
Audio embedding operator using Neural Network Fingerprint
"""
def __init__(self,
model_name: str = 'nnfp_default',
model_path: str = None,
framework: str = 'pytorch',
device: str = None
):
super().__init__(framework=framework)
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.params = self.model.params
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data) audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device: if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device) audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape) # print(audio_tensors.shape)
if isinstance(self.model, onnxruntime.InferenceSession):
audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \
else audio_tensors.cpu().numpy()
ort_inputs = {self.model.get_inputs()[0].name: audio_numpy}
outs = self.model.run(None, ort_inputs)[0]
else:
features = self.model(audio_tensors)
outs = features.detach().cpu().numpy()
features = self.model(audio_tensors)
outs = features.detach().cpu().numpy()
return outs return outs
def preprocess(self, frames: Union[str, List[AudioFrame]]): def preprocess(self, frames: Union[str, List[AudioFrame]]):
@ -183,7 +180,7 @@ class NNFingerprint(NNOperator):
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
path = os.path.join(path, 'saved', format) path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
name = 'nnfp'
name = self.model_name.replace('/', '-')
path = os.path.join(path, name) path = os.path.join(path, name)
dummy_input = torch.rand( dummy_input = torch.rand(
(1,) + (self.params['n_mels'], self.params['u']) (1,) + (self.params['n_mels'], self.params['u'])
@ -208,7 +205,7 @@ class NNFingerprint(NNOperator):
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx' path = path + '.onnx'
try: try:
torch.onnx.export(self.model,
torch.onnx.export(self.model.model,
dummy_input, dummy_input,
path, path,
input_names=['input'], input_names=['input'],
@ -225,9 +222,3 @@ class NNFingerprint(NNOperator):
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':
else: else:
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')
def input_schema(self):
return [(AudioFrame, (1024,))]
def output_schema(self):
return [(numpy.ndarray, (-1, self.params['dim']))]

50
test.py

@ -1,50 +0,0 @@
from towhee import ops
import warnings
import torch
import numpy
import onnx
import onnxruntime
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# decode = ops.audio_decode.ffmpeg()
# audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32).to(device)
op = ops.audio_embedding.nnfp()
out0 = op.get_op().model(audio)
# print(out0)
# Test Pytorch
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(model_path='./saved/pytorch/nnfp.pt')
out1 = op.get_op().model(audio)
assert ((out0 == out1).all())
# Test Torchscript
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(model_path='./saved/torchscript/nnfp.pt')
out2 = op.get_op().model(audio)
assert ((out0 == out2).all())
# Test ONNX
op.get_op().save_model(format='onnx')
op = ops.audio_embedding.nnfp()
onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(
'./saved/onnx/nnfp.onnx',
providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
ort_outs = ort_session.run(None, ort_inputs)
out3 = ort_outs[0]
# print(out3)
assert (numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05))

99
test_onnx.py

@ -0,0 +1,99 @@
from towhee import ops
import torch
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil
models = ['nnfp_default']
atol = 1e-3
log_path = 'nnfp_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
logger = logging.getLogger('nnfp_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
status = None
for name in models:
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
try:
op = ops.audio_embedding.nnfp(model_name=name, device='cpu').get_op()
except Exception as e:
logger.error(f'Fail to load model {name}. Please check weights.')
data = torch.rand((1,) + (op.params['n_mels'], op.params['u']))
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
try:
out1 = op.model(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e:
logger.error(f'FAIL TO SAVE ONNX: {e}')
continue
try:
try:
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
except Exception:
saved_onnx = onnx.load(onnx_path, load_external_data=False)
onnx.checker.check_model(saved_onnx)
logger.info('ONNX CHECKED.')
status[3] = 'success'
except Exception as e:
logger.error(f'FAIL TO CHECK ONNX: {e}')
pass
try:
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
out2 = sess.run(None, input_feed={'input': data.detach().numpy()})
logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol):
logger.info('Check accuracy: OK')
status[5] = 'success'
else:
logger.info(f'Check accuracy: atol is larger than {atol}.')
except Exception as e:
logger.error(f'FAIL TO RUN ONNX: {e}')
continue
if status:
f.write(','.join(status) + '\n')
print('Finished.')
Loading…
Cancel
Save