@@ -37,6 +37,7 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()


def create_model(model_name, checkpoint_path, device):
    model = AutoModel.from_pretrained(model_name).to(device)
    if hasattr(model, 'pooler') and model.pooler:
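The two silencing calls in this hunk fit together roughly as sketched below; the alias behind t_logging is an assumption, since the import statement sits outside the diff context.

import os
import warnings

# Assumed import behind the t_logging alias used in the hunk above.
from transformers.utils import logging as t_logging

warnings.filterwarnings('ignore')            # suppress Python-level warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'     # hide TensorFlow INFO/WARNING output
t_logging.set_verbosity_error()              # limit transformers logging to errors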
@@ -50,6 +51,7 @@ def create_model(model_name, checkpoint_path, device):
    model.eval()
    return model


# @accelerate
class Model:
    def __init__(self, model_name, checkpoint_path, device):
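Across the two create_model hunks the helper follows a familiar load/strip/eval pattern. A minimal sketch of that pattern is below; the body between the pooler check and model.eval() is not visible in the diff, so the pooler removal and checkpoint handling shown here are assumptions, not the file's actual code.

import torch
from transformers import AutoModel

def create_model_sketch(model_name, checkpoint_path, device):
    # Load the pretrained backbone and move it to the requested device.
    model = AutoModel.from_pretrained(model_name).to(device)
    # Assumption: architectures with a pooler head (e.g. BERT) have it dropped
    # so the operator only exposes raw hidden states.
    if hasattr(model, 'pooler') and model.pooler is not None:
        model.pooler = None
    # Assumption: an optional fine-tuned checkpoint overrides the pretrained weights.
    if checkpoint_path is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()    # inference only: freeze dropout/batch-norm behaviour
    return model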
@@ -131,7 +133,7 @@ class AutoTransformers(NNOperator):

    @property
    def _model(self):
-        return self.model.model
+        return self.model.model.to('cpu')

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        if output_file == 'default':
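The property now hands the wrapped model back on CPU, so callers such as the ONNX branch below no longer need to relocate it themselves. An illustrative stand-in for that pattern (class and attribute names here are placeholders, not the operator's real ones):

import torch

class ModelHolder:
    """Placeholder wrapper illustrating the property change above."""
    def __init__(self, model: torch.nn.Module):
        self.model = model

    @property
    def _model(self) -> torch.nn.Module:
        # Module.to('cpu') moves the parameters in place and returns the module,
        # so every access yields a CPU-resident model ready for export/tracing.
        return self.model.to('cpu')

holder = ModelHolder(torch.nn.Linear(8, 2))
assert next(holder._model.parameters()).device.type == 'cpu'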
@@ -165,7 +167,7 @@ class AutoTransformers(NNOperator):
        elif model_type == 'onnx':
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export
            self._model = self._model.to('cpu')

            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
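For reference, the FeaturesManager/export pair used in this branch can be exercised end to end roughly as below; the checkpoint name is illustrative, and the export call is written positionally so it matches the transformers 4.x releases that still ship the transformers.onnx module.

from pathlib import Path
from transformers import AutoModel, AutoTokenizer
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager

model_name = 'bert-base-uncased'                          # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to('cpu')   # the exporter traces on CPU

# Resolve the architecture-specific ONNX config, as in the hunk above.
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
    model, feature='default')
onnx_config = model_onnx_config(model.config)

# Trace the model and write the ONNX graph to disk.
export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path('model.onnx'))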