logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

48 lines
1.3 KiB

from towhee import ops
import warnings
import torch
import numpy
import onnx
import onnxruntime
# Target device for the test input: use CUDA when PyTorch reports it is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the host.

    ``Tensor.numpy()`` refuses tensors that track gradients, so such a tensor
    is detached from the autograd graph first. The tensor is then moved to
    the CPU, because NumPy cannot address device memory.

    Note: for a tensor that is already on the CPU, the returned array shares
    memory with the tensor rather than being a copy.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Round-trip check for the towhee NNFP audio-embedding operator.
# The model is saved in PyTorch, TorchScript and ONNX formats, and each saved
# artifact is reloaded and run on the same input as the original in-memory
# model. Its output must then match the original output.
#
# To feed real audio instead of random data, decode a file first, e.g.:
#   decode = ops.audio_decode.ffmpeg()
#   audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32).to(device)

op = ops.audio_embedding.nnfp()
# Inference only: no_grad avoids building an autograd graph on every forward.
with torch.no_grad():
    out0 = op.get_op().model(audio)

# Test PyTorch.
# torch.equal also requires matching shapes. Element-wise `==` followed by
# `.all()` would broadcast and could hide a shape mismatch.
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt')
with torch.no_grad():
    out1 = op.get_op().model(audio)
assert torch.equal(out0, out1), 'PyTorch checkpoint output differs from the original model'

# Test TorchScript (saved from the op reloaded in the PyTorch step above).
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(checkpoint_path='./saved/torchscript/nnfp.pt')
with torch.no_grad():
    out2 = op.get_op().model(audio)
assert torch.equal(out0, out2), 'TorchScript output differs from the original model'

# Test ONNX (saved from the op reloaded in the TorchScript step above).
op.get_op().save_model(format='onnx')
onnx_path = './saved/onnx/nnfp.onnx'
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
# Since onnxruntime 1.9, builds that ship GPU providers raise ValueError unless
# `providers` is passed explicitly. CPUExecutionProvider exists in every build;
# the tolerance below absorbs small numeric differences from the PyTorch run.
ort_session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
out3 = ort_session.run(None, ort_inputs)[0]
expected = to_numpy(out0)
# numpy.allclose broadcasts, so check the shape explicitly first.
assert out3.shape == expected.shape, f'ONNX output shape {out3.shape} != {expected.shape}'
assert numpy.allclose(expected, out3, rtol=1e-03, atol=1e-05), 'ONNX output differs from the original model'