|  |  | @ -22,6 +22,9 @@ from towhee.operator import NNOperator | 
			
		
	
		
			
				
					|  |  |  | from towhee import register | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | import warnings | 
			
		
	
		
			
				
					|  |  |  | import logging | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | log = logging.getLogger('run_op') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | warnings.filterwarnings('ignore') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -83,6 +86,7 @@ class AutoTransformers(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |             os.makedirs(path, exist_ok=True) | 
			
		
	
		
			
				
					|  |  |  |             name = self.model_name.replace('/', '-') | 
			
		
	
		
			
				
					|  |  |  |             path = os.path.join(path, name) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         inputs = self.tokenizer('[CLS]', return_tensors='pt')  # a dictionary | 
			
		
	
		
			
				
					|  |  |  |         if format == 'pytorch': | 
			
		
	
		
			
				
					|  |  |  |             path = path + '.pt' | 
			
		
	
	
		
			
				
					|  |  | @ -101,37 +105,39 @@ class AutoTransformers(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |                 raise RuntimeError(f'Fail to save as torchscript: {e}.') | 
			
		
	
		
			
				
					|  |  |  |         elif format == 'onnx': | 
			
		
	
		
			
				
					|  |  |  |             path = path + '.onnx' | 
			
		
	
		
			
				
					|  |  |  |             input_names = list(inputs.keys()) | 
			
		
	
		
			
				
					|  |  |  |             dynamic_axes = {} | 
			
		
	
		
			
				
					|  |  |  |             for i_n in input_names: | 
			
		
	
		
			
				
					|  |  |  |                 dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |             try: | 
			
		
	
		
			
				
					|  |  |  |                 output_names = ['last_hidden_state'] | 
			
		
	
		
			
				
					|  |  |  |                 for o_n in output_names: | 
			
		
	
		
			
				
					|  |  |  |                     dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |                 torch.onnx.export(self.model, | 
			
		
	
		
			
				
					|  |  |  |                                   tuple(inputs.values()), | 
			
		
	
		
			
				
					|  |  |  |                                   path, | 
			
		
	
		
			
				
					|  |  |  |                                   input_names=list(inputs.keys()), | 
			
		
	
		
			
				
					|  |  |  |                                   output_names=["last_hidden_state"], | 
			
		
	
		
			
				
					|  |  |  |                                   dynamic_axes={ | 
			
		
	
		
			
				
					|  |  |  |                                       "input_ids": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "token_type_ids": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "attention_mask": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "last_hidden_state": {0: "batch_size"}, | 
			
		
	
		
			
				
					|  |  |  |                                   }, | 
			
		
	
		
			
				
					|  |  |  |                                   opset_version=13, | 
			
		
	
		
			
				
					|  |  |  |                                   input_names=input_names, | 
			
		
	
		
			
				
					|  |  |  |                                   output_names=output_names, | 
			
		
	
		
			
				
					|  |  |  |                                   dynamic_axes=dynamic_axes, | 
			
		
	
		
			
				
					|  |  |  |                                   opset_version=11, | 
			
		
	
		
			
				
					|  |  |  |                                   do_constant_folding=True, | 
			
		
	
		
			
				
					|  |  |  |                                   # enable_onnx_checker=True, | 
			
		
	
		
			
				
					|  |  |  |                                   ) | 
			
		
	
		
			
				
					|  |  |  |             except Exception as e: | 
			
		
	
		
			
				
					|  |  |  |                 print(e, '\nTrying with 2 outputs...') | 
			
		
	
		
			
				
					|  |  |  |                 output_names = ['last_hidden_state', 'pooler_output'] | 
			
		
	
		
			
				
					|  |  |  |                 for o_n in output_names: | 
			
		
	
		
			
				
					|  |  |  |                     dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} | 
			
		
	
		
			
				
					|  |  |  |                 torch.onnx.export(self.model, | 
			
		
	
		
			
				
					|  |  |  |                                   tuple(inputs.values()), | 
			
		
	
		
			
				
					|  |  |  |                                   path, | 
			
		
	
		
			
				
					|  |  |  |                                   input_names=["input_ids", "token_type_ids", "attention_mask"],  # list(inputs.keys()) | 
			
		
	
		
			
				
					|  |  |  |                                   output_names=["last_hidden_state", "pooler_output"], | 
			
		
	
		
			
				
					|  |  |  |                                   opset_version=13, | 
			
		
	
		
			
				
					|  |  |  |                                   dynamic_axes={ | 
			
		
	
		
			
				
					|  |  |  |                                       "input_ids": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "token_type_ids": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "attention_mask": {0: "batch_size", 1: "input_length"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "last_hidden_state": {0: "batch_size"}, | 
			
		
	
		
			
				
					|  |  |  |                                       "pooler_outputs": {0: "batch_size"} | 
			
		
	
		
			
				
					|  |  |  |                                   }) | 
			
		
	
		
			
				
					|  |  |  |                                   input_names=input_names, | 
			
		
	
		
			
				
					|  |  |  |                                   output_names=output_names, | 
			
		
	
		
			
				
					|  |  |  |                                   dynamic_axes=dynamic_axes, | 
			
		
	
		
			
				
					|  |  |  |                                   opset_version=11, | 
			
		
	
		
			
				
					|  |  |  |                                   do_constant_folding=True, | 
			
		
	
		
			
				
					|  |  |  |                                   # enable_onnx_checker=True, | 
			
		
	
		
			
				
					|  |  |  |                                   ) | 
			
		
	
		
			
				
					|  |  |  |         # todo: elif format == 'tensorrt': | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             log.error(f'Unsupported format "{format}".') | 
			
		
	
	
		
			
				
					|  |  | 
 |