nnfp
copied
Jael Gu
2 years ago
4 changed files with 151 additions and 111 deletions
@ -1,50 +0,0 @@ |
|||||
from towhee import ops |
|
||||
|
|
||||
import warnings |
|
||||
|
|
||||
import torch |
|
||||
import numpy |
|
||||
import onnx |
|
||||
import onnxruntime |
|
||||
|
|
||||
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
|
||||
def to_numpy(tensor):
    """Return the contents of a torch tensor as a numpy array.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion. ``detach()`` is applied unconditionally: it is
    required for tensors that track gradients and is harmless for tensors
    that do not, so no ``requires_grad`` branch is needed.
    """
    return tensor.detach().cpu().numpy()
|
||||
|
|
||||
|
|
||||
# Smoke test for the nnfp operator's save_model(): the output of the default
# operator is compared against the outputs of the re-loaded PyTorch and
# TorchScript checkpoints and of the exported ONNX graph.
#
# NOTE: random input is used here. To try decoded audio instead:
#     decode = ops.audio_decode.ffmpeg()
#     audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32).to(device)

# Reference output produced by the default operator.
op = ops.audio_embedding.nnfp()
out0 = op.get_op().model(audio)

# Test Pytorch
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(model_path='./saved/pytorch/nnfp.pt')
out1 = op.get_op().model(audio)
assert ((out0 == out1).all())

# Test Torchscript
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(model_path='./saved/torchscript/nnfp.pt')
out2 = op.get_op().model(audio)
assert ((out0 == out2).all())

# Test ONNX
op.get_op().save_model(format='onnx')
onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession(
    './saved/onnx/nnfp.onnx',
    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])

# Feed the same input under the graph's first declared input name.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
ort_outs = ort_session.run(None, ort_inputs)
out3 = ort_outs[0]
# The ONNX output is compared with a tolerance rather than for exact equality.
assert (numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05))
|
@ -0,0 +1,99 @@ |
|||||
|
from towhee import ops |
||||
|
import torch |
||||
|
import numpy |
||||
|
import onnx |
||||
|
import onnxruntime |
||||
|
|
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
import logging |
||||
|
import platform |
||||
|
import psutil |
||||
|
|
||||
|
# ONNX export check for the nnfp operator.
#
# For every model name in `models` the script loads the operator, runs it on
# random input, exports it to ONNX, validates the exported file, runs it with
# onnxruntime and compares both outputs. One row per model is written to
# 'onnx.csv'; details go to the log file named by `log_path`.
models = ['nnfp_default']

atol = 1e-3  # absolute tolerance for the PyTorch vs. ONNX output comparison
log_path = 'nnfp_onnx.log'

logger = logging.getLogger('nnfp_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Everything (DEBUG and up) goes to the log file ...
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
# ... while only errors are echoed to the console.
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

# Record the environment the check ran in.
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')


def _check_model(name):
    """Run every export check for one model and return its CSV row.

    The returned list has six entries matching the CSV header
    (model, load_op, save_onnx, check_onnx, run_onnx, accuracy): the model
    name followed by 'success' or 'fail' for each step. A step that is never
    reached stays 'fail'.
    """
    status = [name] + ['fail'] * 5
    saved_name = name.replace('/', '-')
    # NOTE(review): assumes save_model(format='onnx') writes to this path - confirm.
    onnx_path = f'saved/onnx/{saved_name}.onnx'

    try:
        op = ops.audio_embedding.nnfp(model_name=name, device='cpu').get_op()
    except Exception:
        logger.error(f'Fail to load model {name}. Please check weights.')
        # Without an operator nothing below can run; report the model as failed
        # instead of falling through with `op` unbound.
        return status

    # Random input of shape (1, n_mels, u), sized from the operator's params.
    data = torch.rand((1,) + (op.params['n_mels'], op.params['u']))

    try:
        out1 = op.model(data).detach().numpy()
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    # A failed check is recorded but does not stop the run step below.
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Retry without loading external tensor data.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        out2 = sess.run(None, input_feed={'input': data.detach().numpy()})
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        # run() returns a list of outputs; compare the first one with the
        # PyTorch result.
        if numpy.allclose(out1, out2[0], atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')

    return status


# The context manager closes the CSV even if a model check raises, and each
# row is written as soon as its model has been processed.
with open('onnx.csv', 'w+') as f:
    f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
    for name in models:
        logger.info(f'***{name}***')
        status = _check_model(name)
        f.write(','.join(status) + '\n')

print('Finished.')
Loading…
Reference in new issue