| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -15,6 +15,7 @@ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import os | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import shutil | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from pathlib import Path | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from transformers import AutoTokenizer, AutoModel | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -47,6 +48,8 @@ class AutoTransformers(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model = AutoModel.from_pretrained(model_name).to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.configs = self.model.config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            model_list = self.supported_model_names() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if model_name not in model_list: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -87,57 +90,32 @@ class AutoTransformers(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            name = self.model_name.replace('/', '-') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = os.path.join(path, name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        inputs = self.tokenizer('[CLS]', return_tensors='pt')  # a dictionary | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        dummy_input = '[CLS]' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        inputs = self.tokenizer(dummy_input, return_tensors='pt')  # a dictionary | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if format == 'pytorch': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = path + '.pt' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            torch.save(self.model, path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            torch.save(self.model, path + '.pt') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        elif format == 'torchscript': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = path + '.pt' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            inputs = list(inputs.values()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    jit_model = torch.jit.script(self.model) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                except Exception: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    jit_model = torch.jit.trace(self.model, inputs, strict=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.jit.save(jit_model, path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.jit.save(jit_model, path + '.pt') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                log.error(f'Fail to save as torchscript: {e}.') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise RuntimeError(f'Fail to save as torchscript: {e}.') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        elif format == 'onnx': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = path + '.onnx' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            input_names = list(inputs.keys()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dynamic_axes = {} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for i_n in input_names: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                output_names = ['last_hidden_state'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for o_n in output_names: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.onnx.export(self.model, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  tuple(inputs.values()), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  path, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  input_names=input_names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  output_names=output_names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  dynamic_axes=dynamic_axes, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  opset_version=11, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  do_constant_folding=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  # enable_onnx_checker=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                print(e, '\nTrying with 2 outputs...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                output_names = ['last_hidden_state', 'pooler_output'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for o_n in output_names: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.onnx.export(self.model, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  tuple(inputs.values()), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  path, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  input_names=input_names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  output_names=output_names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  dynamic_axes=dynamic_axes, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  opset_version=11, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  do_constant_folding=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  # enable_onnx_checker=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from transformers.convert_graph_to_onnx import convert | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if os.path.isdir(path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                shutil.rmtree(path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = os.path.join(path, 'model.onnx') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            convert( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                model=self.model_name, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                output=Path(path), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                framework='pt', | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                opset=13 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # todo: elif format == 'tensorrt': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log.error(f'Unsupported format "{format}".') | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |