nnfp
              
                 
                
            
          copied
				 3 changed files with 88 additions and 5 deletions
			
			
		| @ -0,0 +1,35 @@ | |||||
|  | from towhee import ops | ||||
|  | import torch | ||||
|  | import numpy | ||||
|  | import onnx | ||||
|  | import onnxruntime | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def to_numpy(tensor): | ||||
|  |     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | # decode = ops.audio_decode.ffmpeg() | ||||
|  | # audio = [x[0] for x in decode('path/to/audio.wav')] | ||||
|  | audio = torch.rand(10, 256, 32) | ||||
|  | 
 | ||||
|  | op = ops.audio_embedding.nnfp() | ||||
|  | out0 = op.get_op().model(audio) | ||||
|  | # print(out0) | ||||
|  | 
 | ||||
|  | op.get_op().save_model(format='pytorch') | ||||
|  | op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt') | ||||
|  | out1 = op.get_op().model(audio) | ||||
|  | assert((out0 == out1).all()) | ||||
|  | 
 | ||||
|  | op.get_op().save_model(format='onnx') | ||||
|  | op = ops.audio_embedding.nnfp() | ||||
|  | onnx_model = onnx.load('./saved/onnx/nnfp.onnx') | ||||
|  | onnx.checker.check_model(onnx_model) | ||||
|  | 
 | ||||
|  | ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx') | ||||
|  | ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} | ||||
|  | ort_outs = ort_session.run(None, ort_inputs) | ||||
|  | out2 = ort_outs[0] | ||||
|  | # print(out2) | ||||
|  | assert(numpy.allclose(to_numpy(out0), out2, rtol=1e-03, atol=1e-05)) | ||||
					Loading…
					
					
				
		Reference in new issue
	
	