Browse Source
        
      
      Add providers to onnxruntime
      
        Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 2 changed files with 
8 additions and 
2 deletions
			 
			
		 
		
			
				- 
					
					
					 
					nn_fingerprint.py
				
- 
					
					
					 
					test.py
				
				
				
					
						
							
								
									
	
		
			
				|  | @ -60,7 +60,10 @@ class NNFingerprint(NNOperator): | 
		
	
		
			
				|  |  |             model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') |  |  |             model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') | 
		
	
		
			
				|  |  |         if model_path.endswith('.onnx'): |  |  |         if model_path.endswith('.onnx'): | 
		
	
		
			
				|  |  |             log.warning('Using onnx.') |  |  |             log.warning('Using onnx.') | 
		
	
		
			
				|  |  |             self.model = onnxruntime.InferenceSession(model_path) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |             self.model = onnxruntime.InferenceSession( | 
		
	
		
			
				|  |  |  |  |  |             model_path, | 
		
	
		
			
				|  |  |  |  |  |             providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] | 
		
	
		
			
				|  |  |  |  |  |             ) | 
		
	
		
			
				|  |  |         else: |  |  |         else: | 
		
	
		
			
				|  |  |             state_dict = torch.load(model_path, map_location=self.device) |  |  |             state_dict = torch.load(model_path, map_location=self.device) | 
		
	
		
			
				|  |  |             if isinstance(state_dict, torch.nn.Module): |  |  |             if isinstance(state_dict, torch.nn.Module): | 
		
	
	
		
			
				|  | 
 | 
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				|  | @ -39,7 +39,10 @@ op = ops.audio_embedding.nnfp() | 
		
	
		
			
				|  |  | onnx_model = onnx.load('./saved/onnx/nnfp.onnx') |  |  | onnx_model = onnx.load('./saved/onnx/nnfp.onnx') | 
		
	
		
			
				|  |  | onnx.checker.check_model(onnx_model) |  |  | onnx.checker.check_model(onnx_model) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx') |  |  |  | 
		
	
		
			
				|  |  |  |  |  | ort_session = onnxruntime.InferenceSession( | 
		
	
		
			
				|  |  |  |  |  | './saved/onnx/nnfp.onnx', | 
		
	
		
			
				|  |  |  |  |  | providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} |  |  | ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} | 
		
	
		
			
				|  |  | ort_outs = ort_session.run(None, ort_inputs) |  |  | ort_outs = ort_session.run(None, ort_inputs) | 
		
	
		
			
				|  |  | out3 = ort_outs[0] |  |  | out3 = ort_outs[0] | 
		
	
	
		
			
				|  | 
 |