nnfp
copied
Jael Gu
2 years ago
4 changed files with 151 additions and 111 deletions
@@ -1,50 +0,0 @@
from towhee import ops

import torch
import numpy
import onnx
import onnxruntime

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def to_numpy(tensor):
    # Detach first so tensors that require grad can be converted to numpy.
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


# decode = ops.audio_decode.ffmpeg()
# audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32).to(device)
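# Note: this random tensor is only a stand-in for real decoded audio (see the
# commented-out decode path above). Judging from the benchmark script added
# below, the expected layout is (batch, n_mels, u), so (10, 256, 32) is
# assumed to match nnfp's default spectrogram dimensions.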
op = ops.audio_embedding.nnfp()
out0 = op.get_op().model(audio)
# print(out0)

# Test PyTorch
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(model_path='./saved/pytorch/nnfp.pt')
out1 = op.get_op().model(audio)
assert (out0 == out1).all()

# Test TorchScript
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(model_path='./saved/torchscript/nnfp.pt')
out2 = op.get_op().model(audio)
assert (out0 == out2).all()
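# The two tests above expect bit-identical outputs: the saved PyTorch and
# TorchScript models run the same weights in the same framework. The ONNX test
# below instead uses numpy.allclose with explicit tolerances, since a different
# runtime can introduce small floating-point deviations.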
# Test ONNX
op.get_op().save_model(format='onnx')
op = ops.audio_embedding.nnfp()
onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession(
    './saved/onnx/nnfp.onnx',
    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
ort_outs = ort_session.run(None, ort_inputs)
out3 = ort_outs[0]
# print(out3)
assert numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05)
@@ -0,0 +1,99 @@
from towhee import ops

import torch
import numpy
import onnx
import onnxruntime

import logging
import platform
import psutil

models = ['nnfp_default']

atol = 1e-3
log_path = 'nnfp_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
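# Each CSV row records one model: its name, then success/fail for loading the
# op, exporting to ONNX, validating the ONNX graph, running it under
# onnxruntime, and matching the PyTorch output within atol.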
logger = logging.getLogger('nnfp_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
status = None
for name in models:
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'

    # Flush the previous model's status row before starting a new one.
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5

    try:
        op = ops.audio_embedding.nnfp(model_name=name, device='cpu').get_op()
    except Exception as e:
        logger.error(f'Fail to load model {name}: {e}. Please check weights.')
        continue

    # Dummy batch shaped (1, n_mels, u), taken from the op's own parameters.
    data = torch.rand((1,) + (op.params['n_mels'], op.params['u']))

    try:
        # Reference output from the PyTorch model, compared against ONNX below.
        out1 = op.model(data).detach().numpy()
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Fall back for models whose weights are saved as external data.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')
    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        # sess.run returns a list of outputs; the embedding is the first one.
        out2 = sess.run(None, input_feed={'input': data.detach().numpy()})[0]
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        # allclose rather than exact equality: onnxruntime kernels can drift
        # slightly from PyTorch in floating point.
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue
# Write the last model's status row and close the report.
if status:
    f.write(','.join(status) + '\n')
f.close()

print('Finished.')
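# A minimal follow-up sketch (not part of the diff) for inspecting the
# results, assuming the script above has written onnx.csv to the working
# directory:
import csv

with open('onnx.csv') as results:
    for row in csv.DictReader(results):
        # Report every stage that did not succeed for this model.
        failed = [stage for stage, outcome in row.items() if outcome == 'fail']
        if failed:
            print(f"{row['model']}: failed at {', '.join(failed)}")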