nnfp
              
                 
                
            
          copied
			You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            50 lines
          
        
        
          
            1.4 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            50 lines
          
        
        
          
            1.4 KiB
          
        
        
      | from towhee import ops | |
| 
 | |
| import warnings | |
| 
 | |
| import torch | |
| import numpy | |
| import onnx | |
| import onnxruntime | |
| 
 | |
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
| 
 | |
def to_numpy(tensor):
    """Return *tensor* as a NumPy array, moving it to the CPU if needed.

    ``Tensor.numpy()`` refuses tensors that require grad, so such tensors are
    detached from the autograd graph first; all others are converted directly.
    When *tensor* already lives on the CPU the result may share its memory.
    """
    if not tensor.requires_grad:
        return tensor.cpu().numpy()
    detached = tensor.detach()
    return detached.cpu().numpy()
| 
 | |
| 
 | |
# Round-trip export test for the towhee ``audio_embedding.nnfp`` operator:
# compute a reference output with the freshly loaded model, save the model as
# PyTorch, TorchScript and ONNX, reload each artifact and check that it
# reproduces the reference output.

PYTORCH_PATH = './saved/pytorch/nnfp.pt'
TORCHSCRIPT_PATH = './saved/torchscript/nnfp.pt'
ONNX_PATH = './saved/onnx/nnfp.onnx'

# Tolerances for the non-exact comparisons (TorchScript and ONNX Runtime).
RTOL = 1e-03
ATOL = 1e-05

# ONNX Runtime execution providers in order of preference; only the ones this
# onnxruntime build actually reports are requested.
ORT_PREFERRED_PROVIDERS = (
    'TensorrtExecutionProvider',
    'CUDAExecutionProvider',
    'CPUExecutionProvider',
)


def _assert_matches(reference, candidate, what, *, exact=False):
    """Raise ``AssertionError`` unless *candidate* reproduces *reference*.

    Args:
        reference: NumPy array produced by the original eager model.
        candidate: NumPy array produced by a reloaded / exported model.
        what: Label used in the error message.
        exact: Require identical values instead of the ``RTOL``/``ATOL``
            tolerance.

    Shapes are compared first because elementwise ``==`` and
    ``numpy.allclose`` broadcast, which could let a wrongly shaped output
    pass. An explicit ``raise`` (rather than ``assert``) keeps the check
    active under ``python -O``.
    """
    if reference.shape != candidate.shape:
        raise AssertionError(
            f'{what}: output shape {candidate.shape} != reference shape {reference.shape}')
    if exact:
        matches = numpy.array_equal(reference, candidate)
    else:
        matches = numpy.allclose(reference, candidate, rtol=RTOL, atol=ATOL)
    if not matches:
        max_err = float(numpy.max(numpy.abs(reference - candidate)))
        raise AssertionError(f'{what}: output mismatch (max abs error {max_err:.3g})')


# To test on real audio instead of random input:
#   decode = ops.audio_decode.ffmpeg()
#   audio = [x[0] for x in decode('path/to/audio.wav')]
audio = torch.rand(10, 256, 32).to(device)

op = ops.audio_embedding.nnfp()
# Inference only: no_grad avoids building an autograd graph for the outputs.
with torch.no_grad():
    out0 = op.get_op().model(audio)
reference = to_numpy(out0)

# Test PyTorch: reloading the saved weights into the same eager model is
# expected to reproduce the reference exactly.
op.get_op().save_model(format='pytorch')
op = ops.audio_embedding.nnfp(model_path=PYTORCH_PATH)
with torch.no_grad():
    out1 = op.get_op().model(audio)
_assert_matches(reference, to_numpy(out1), 'PyTorch', exact=True)

# Test TorchScript: the JIT may fuse or reorder ops, so results are compared
# with a tolerance rather than exact equality.
op.get_op().save_model(format='torchscript')
op = ops.audio_embedding.nnfp(model_path=TORCHSCRIPT_PATH)
with torch.no_grad():
    out2 = op.get_op().model(audio)
_assert_matches(reference, to_numpy(out2), 'TorchScript')

# Test ONNX.
# NOTE(review): ``op`` here is the op loaded from the TorchScript file in the
# previous step, so the ONNX model is exported from it — confirm this
# chaining is intended.
op.get_op().save_model(format='onnx')
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)

available_providers = set(onnxruntime.get_available_providers())
providers = [p for p in ORT_PREFERRED_PROVIDERS if p in available_providers]
ort_session = onnxruntime.InferenceSession(ONNX_PATH, providers=providers)

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
out3 = ort_session.run(None, ort_inputs)[0]
_assert_matches(reference, out3, 'ONNX')
