|
|
|
from towhee import ops
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy
|
|
|
|
import onnx
|
|
|
|
import onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
def to_numpy(tensor):
    """Convert a torch tensor into a NumPy array held in CPU memory.

    Args:
        tensor: a ``torch.Tensor``; it may or may not require gradients.

    Returns:
        The ``numpy.ndarray`` produced by ``tensor.cpu().numpy()``, with the
        tensor detached from the autograd graph first when it requires grad.
    """
    # Tensor.numpy() rejects tensors that require grad, so detach those first.
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
# Input preparation. A real recording could be decoded instead of using random
# data (kept for reference, currently disabled):
#   decode = ops.audio_decode.ffmpeg()
#   audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32)

# Reference run: output of the default operator's model, which the format
# checks further down compare their own outputs against.
op = ops.audio_embedding.nnfp()
reference_model = op.get_op().model
out0 = reference_model(audio)
# print(out0)
|
|
|
|
|
|
|
|
# Test Pytorch
# Save the current model in 'pytorch' format, then build a fresh operator from
# the checkpoint path (presumably the file save_model just wrote - confirm) and
# require its output to equal the reference output element for element.
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt')
pytorch_model = op.get_op().model
out1 = pytorch_model(audio)
pytorch_matches = out0 == out1
assert pytorch_matches.all()
|
|
|
|
|
|
|
|
# Test Torchscript
# Same round trip as above, but through the 'torchscript' format: save, reload
# via checkpoint_path (presumably the file save_model just wrote - confirm),
# and require exact element-wise equality with the reference output.
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(checkpoint_path='./saved/torchscript/nnfp.pt')
torchscript_model = op.get_op().model
out2 = torchscript_model(audio)
torchscript_matches = out0 == out2
assert torchscript_matches.all()
|
|
|
|
|
|
|
|
# Test ONNX
# Export the current model to ONNX, validate the exported graph, run it with
# ONNX Runtime on the same input, and compare against the reference output.
op.get_op().save_model(format='onnx')
# NOTE(review): this re-created default operator is not used again in this
# section; kept in case code further down relies on `op` - confirm.
op = ops.audio_embedding.nnfp()

# Structural validation of the exported model file.
onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
onnx.checker.check_model(onnx_model)

# ONNX Runtime >= 1.9 raises ValueError when `providers` is omitted on builds
# that ship non-CPU execution providers, so name the provider explicitly.
# CPU is used because the input tensor is created on the CPU.
ort_session = onnxruntime.InferenceSession(
    './saved/onnx/nnfp.onnx',
    providers=['CPUExecutionProvider'],
)
# Feed the input tensor under the name of the graph's first declared input.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
ort_outs = ort_session.run(None, ort_inputs)
out3 = ort_outs[0]
# print(out3)

# Only approximate agreement is required here, unlike the exact-equality
# checks used for the other formats.
assert numpy.allclose(to_numpy(out0), out3, rtol=1e-03, atol=1e-05)
|