|  | @ -13,6 +13,9 @@ | 
		
	
		
			
				|  |  | # limitations under the License. |  |  | # limitations under the License. | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import numpy |  |  | import numpy | 
		
	
		
			
				|  |  |  |  |  | import os | 
		
	
		
			
				|  |  |  |  |  | import torch | 
		
	
		
			
				|  |  |  |  |  | from pathlib import Path | 
		
	
		
			
				|  |  | from transformers import AutoTokenizer, AutoModel |  |  | from transformers import AutoTokenizer, AutoModel | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | from towhee.operator import NNOperator |  |  | from towhee.operator import NNOperator | 
		
	
	
		
			
				|  | @ -40,6 +43,7 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |         self.model_name = model_name |  |  |         self.model_name = model_name | 
		
	
		
			
				|  |  |         try: |  |  |         try: | 
		
	
		
			
				|  |  |             self.model = AutoModel.from_pretrained(model_name) |  |  |             self.model = AutoModel.from_pretrained(model_name) | 
		
	
		
			
				|  |  |  |  |  |             self.model.eval() | 
		
	
		
			
				|  |  |         except Exception as e: |  |  |         except Exception as e: | 
		
	
		
			
				|  |  |             model_list = get_model_list() |  |  |             model_list = get_model_list() | 
		
	
		
			
				|  |  |             if model_name not in model_list: |  |  |             if model_name not in model_list: | 
		
	
	
		
			
				|  | @ -65,13 +69,29 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |             log.error(f'Invalid input for the model: {self.model_name}') |  |  |             log.error(f'Invalid input for the model: {self.model_name}') | 
		
	
		
			
				|  |  |             raise e |  |  |             raise e | 
		
	
		
			
				|  |  |         try: |  |  |         try: | 
		
	
		
			
				|  |  |             features = outs.last_hidden_state.squeeze(0) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |             features = outs['last_hidden_state'].squeeze(0) | 
		
	
		
			
				|  |  |         except Exception as e: |  |  |         except Exception as e: | 
		
	
		
			
				|  |  |             log.error(f'Fail to extract features by model: {self.model_name}') |  |  |             log.error(f'Fail to extract features by model: {self.model_name}') | 
		
	
		
			
				|  |  |             raise e |  |  |             raise e | 
		
	
		
			
				|  |  |         vec = features.detach().numpy() |  |  |         vec = features.detach().numpy() | 
		
	
		
			
				|  |  |         return vec |  |  |         return vec | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |     def save_model(self, jit: bool = True, destination: str = 'default'): | 
		
	
		
			
				|  |  |  |  |  |         if destination == 'default': | 
		
	
		
			
				|  |  |  |  |  |             path = str(Path(__file__).parent) | 
		
	
		
			
				|  |  |  |  |  |             destination = os.path.join(path, self.model_name + '.pt') | 
		
	
		
			
				|  |  |  |  |  |         inputs = self.tokenizer('[CLS]', return_tensors='pt') | 
		
	
		
			
				|  |  |  |  |  |         inputs = list(inputs.values()) | 
		
	
		
			
				|  |  |  |  |  |         if jit: | 
		
	
		
			
				|  |  |  |  |  |             try: | 
		
	
		
			
				|  |  |  |  |  |                 traced_model = torch.jit.trace(self.model, inputs, strict=False) | 
		
	
		
			
				|  |  |  |  |  |                 torch.jit.save(traced_model, destination) | 
		
	
		
			
				|  |  |  |  |  |             except Exception as e: | 
		
	
		
			
				|  |  |  |  |  |                 log.error(f'Fail to save as torchscript: {e}.') | 
		
	
		
			
				|  |  |  |  |  |                 raise RuntimeError(f'Fail to save as torchscript: {e}.') | 
		
	
		
			
				|  |  |  |  |  |         else: | 
		
	
		
			
				|  |  |  |  |  |             torch.save(self.model, destination) | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def get_model_list(): |  |  | def get_model_list(): | 
		
	
		
			
				|  |  |     full_list = [ |  |  |     full_list = [ | 
		
	
	
		
			
				|  | 
 |