Browse Source
        
      
      Update save onnx in evaluation
      
        Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 1 changed files with 
21 additions and 
1 deletions
			 
			
		 
		
			
				- 
					
					
					 
					benchmark/run.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -122,7 +122,27 @@ elif args.format == 'onnx': | 
			
		
	
		
			
				
					|  |  |  |     collection_name = collection_name + '_onnx' | 
			
		
	
		
			
				
					|  |  |  |     saved_name = model_name.replace('/', '-') | 
			
		
	
		
			
				
					|  |  |  |     if not os.path.exists(onnx_path): | 
			
		
	
		
			
				
					|  |  |  |         try: | 
			
		
	
		
			
				
					|  |  |  |             op.save_model(format='onnx', path=onnx_path[:-5]) | 
			
		
	
		
			
				
					|  |  |  |         except Exception: | 
			
		
	
		
			
				
					|  |  |  |             inputs = op.tokenizer('This is test.', return_tensors='pt') | 
			
		
	
		
			
				
					|  |  |  |             input_names = list(inputs.keys()) | 
			
		
	
		
			
				
					|  |  |  |             dynamic_axes = {} | 
			
		
	
		
			
				
					|  |  |  |             for i_n in input_names: | 
			
		
	
		
			
				
					|  |  |  |                 dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |             output_names = ['last_hidden_state'] | 
			
		
	
		
			
				
					|  |  |  |             for o_n in output_names: | 
			
		
	
		
			
				
					|  |  |  |                 dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |             torch.onnx.export( | 
			
		
	
		
			
				
					|  |  |  |                 op.model, | 
			
		
	
		
			
				
					|  |  |  |                 tuple(inputs.values()), | 
			
		
	
		
			
				
					|  |  |  |                 onnx_path, | 
			
		
	
		
			
				
					|  |  |  |                 input_names=input_names, | 
			
		
	
		
			
				
					|  |  |  |                 output_names=output_names, | 
			
		
	
		
			
				
					|  |  |  |                 dynamic_axes=dynamic_axes, | 
			
		
	
		
			
				
					|  |  |  |                 opset_version=14, | 
			
		
	
		
			
				
					|  |  |  |                 do_constant_folding=True, | 
			
		
	
		
			
				
					|  |  |  |                 ) | 
			
		
	
		
			
				
					|  |  |  |     sess = onnxruntime.InferenceSession(onnx_path, | 
			
		
	
		
			
				
					|  |  |  |                                         providers=onnxruntime.get_available_providers()) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | 
 |