nnfp
copied
Jael Gu
2 years ago
3 changed files with 88 additions and 5 deletions
@ -0,0 +1,35 @@ |
|||
from towhee import ops |
|||
import torch |
|||
import numpy |
|||
import onnx |
|||
import onnxruntime |
|||
|
|||
|
|||
def to_numpy(tensor): |
|||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() |
|||
|
|||
|
|||
# decode = ops.audio_decode.ffmpeg() |
|||
# audio = [x[0] for x in decode('path/to/audio.wav')] |
|||
audio = torch.rand(10, 256, 32) |
|||
|
|||
op = ops.audio_embedding.nnfp() |
|||
out0 = op.get_op().model(audio) |
|||
# print(out0) |
|||
|
|||
op.get_op().save_model(format='pytorch') |
|||
op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt') |
|||
out1 = op.get_op().model(audio) |
|||
assert((out0 == out1).all()) |
|||
|
|||
op.get_op().save_model(format='onnx') |
|||
op = ops.audio_embedding.nnfp() |
|||
onnx_model = onnx.load('./saved/onnx/nnfp.onnx') |
|||
onnx.checker.check_model(onnx_model) |
|||
|
|||
ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx') |
|||
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} |
|||
ort_outs = ort_session.run(None, ort_inputs) |
|||
out2 = ort_outs[0] |
|||
# print(out2) |
|||
assert(numpy.allclose(to_numpy(out0), out2, rtol=1e-03, atol=1e-05)) |
Loading…
Reference in new issue