Browse Source
        
      
      Update onnx export config
      
        Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 1 changed files with 
16 additions and 
7 deletions
			 
			
		 
		
			
				- 
					
					
					 
					auto_transformers.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -142,6 +142,7 @@ class AutoTransformers(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |     @property | 
			
		
	
		
			
				
					|  |  |  |     def onnx_config(self): | 
			
		
	
		
			
				
					|  |  |  |         from transformers.onnx.features import FeaturesManager | 
			
		
	
		
			
				
					|  |  |  |         try: | 
			
		
	
		
			
				
					|  |  |  |             model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( | 
			
		
	
		
			
				
					|  |  |  |                 self._model, feature='default') | 
			
		
	
		
			
				
					|  |  |  |             old_config = model_onnx_config(self.model_config) | 
			
		
	
	
		
			
				
					|  |  | @ -149,6 +150,14 @@ class AutoTransformers(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |                 'inputs': dict(old_config.inputs), | 
			
		
	
		
			
				
					|  |  |  |                 'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']} | 
			
		
	
		
			
				
					|  |  |  |             } | 
			
		
	
		
			
				
					|  |  |  |         except Exception: | 
			
		
	
		
			
				
					|  |  |  |             input_dict = {} | 
			
		
	
		
			
				
					|  |  |  |             for k in self.tokenizer.model_input_names: | 
			
		
	
		
			
				
					|  |  |  |                 input_dict[k] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |             onnx_config = { | 
			
		
	
		
			
				
					|  |  |  |                 'inputs': input_dict, | 
			
		
	
		
			
				
					|  |  |  |                 'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}} | 
			
		
	
		
			
				
					|  |  |  |             } | 
			
		
	
		
			
				
					|  |  |  |         return onnx_config | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def post_proc(self, token_embeddings, inputs): | 
			
		
	
	
		
			
				
					|  |  | 
 |