from towhee import ops import warnings import torch import numpy import onnx import onnxruntime def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() # decode = ops.audio_decode.ffmpeg() # audio = [x[0] for x in decode('path/to/audio.wav')] audio = torch.rand(10, 256, 32) op = ops.audio_embedding.nnfp() out0 = op.get_op().model(audio) # print(out0) # Test Pytorch op.get_op().save_model(format='pytorch') op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt') out1 = op.get_op().model(audio) assert ((out0 == out1).all()) # Test Torchscript op.get_op().save_model(format='torchscript') op = ops.audio_embedding.nnfp(checkpoint_path='./saved/torchscript/nnfp.pt') out2 = op.get_op().model(audio) assert ((out0 == out2).all()) # Test ONNX op.get_op().save_model(format='onnx') op = ops.audio_embedding.nnfp() onnx_model = onnx.load('./saved/onnx/nnfp.onnx') onnx.checker.check_model(onnx_model) ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx') ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} ort_outs = ort_session.run(None, ort_inputs) out3 = ort_outs[0] # print(out3) assert (numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05))