|  | @ -97,15 +97,22 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |             The local checkpoint path. |  |  |             The local checkpoint path. | 
		
	
		
			
				|  |  |         tokenizer (`object`): |  |  |         tokenizer (`object`): | 
		
	
		
			
				|  |  |             The tokenizer to tokenize input text as model inputs. |  |  |             The tokenizer to tokenize input text as model inputs. | 
		
	
		
			
				|  |  |  |  |  |         pool (`str`): | 
		
	
		
			
				|  |  |  |  |  |             The type of post-process pooling after token embeddings, defaults to "mean". Options: "mean", "cls" | 
		
	
		
			
				|  |  |     """ |  |  |     """ | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     def __init__(self, |  |  |     def __init__(self, | 
		
	
		
			
				|  |  |                  model_name: str = None, |  |  |                  model_name: str = None, | 
		
	
		
			
				|  |  |                  checkpoint_path: str = None, |  |  |                  checkpoint_path: str = None, | 
		
	
		
			
				|  |  |                  tokenizer: object = None, |  |  |                  tokenizer: object = None, | 
		
	
		
			
				|  |  |  |  |  |                  pool: str = 'mean', | 
		
	
		
			
				|  |  |                  device: str = None, |  |  |                  device: str = None, | 
		
	
		
			
				|  |  |                  ): |  |  |                  ): | 
		
	
		
			
				|  |  |         super().__init__() |  |  |         super().__init__() | 
		
	
		
			
				|  |  |  |  |  |         if pool not in ['mean', 'cls']: | 
		
	
		
			
				|  |  |  |  |  |             log.warning('Invalid pool %s, using mean pooling instead.', pool) | 
		
	
		
			
				|  |  |  |  |  |             pool = 'mean' | 
		
	
		
			
				|  |  |  |  |  |         self.pool = pool | 
		
	
		
			
				|  |  |         if device: |  |  |         if device: | 
		
	
		
			
				|  |  |             self.device = device |  |  |             self.device = device | 
		
	
		
			
				|  |  |         else: |  |  |         else: | 
		
	
	
		
			
				|  | @ -145,7 +152,10 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |         except Exception as e: |  |  |         except Exception as e: | 
		
	
		
			
				|  |  |             log.error(f'Invalid input for the model: {self.model_name}') |  |  |             log.error(f'Invalid input for the model: {self.model_name}') | 
		
	
		
			
				|  |  |             raise e |  |  |             raise e | 
		
	
		
			
				|  |  |         outs = self.post_proc(outs, inputs) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         if self.pool == 'mean': | 
		
	
		
			
				|  |  |  |  |  |             outs = self.mean_pool(outs, inputs) | 
		
	
		
			
				|  |  |  |  |  |         elif self.pool == 'cls': | 
		
	
		
			
				|  |  |  |  |  |             outs = self.cls_pool(outs) | 
		
	
		
			
				|  |  |         features = outs.detach().numpy() |  |  |         features = outs.detach().numpy() | 
		
	
		
			
				|  |  |         if isinstance(data, str): |  |  |         if isinstance(data, str): | 
		
	
		
			
				|  |  |             features = features.squeeze(0) |  |  |             features = features.squeeze(0) | 
		
	
	
		
			
				|  | @ -184,7 +194,7 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |             } |  |  |             } | 
		
	
		
			
				|  |  |         return onnx_config |  |  |         return onnx_config | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     def post_proc(self, token_embeddings, inputs): |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     def mean_pool(self, token_embeddings, inputs): | 
		
	
		
			
				|  |  |         token_embeddings = token_embeddings |  |  |         token_embeddings = token_embeddings | 
		
	
		
			
				|  |  |         attention_mask = inputs['attention_mask'] |  |  |         attention_mask = inputs['attention_mask'] | 
		
	
		
			
				|  |  |         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() |  |  |         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | 
		
	
	
		
			
				|  | @ -192,6 +202,15 @@ class AutoTransformers(NNOperator): | 
		
	
		
			
				|  |  |             token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) |  |  |             token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | 
		
	
		
			
				|  |  |         return sentence_embs |  |  |         return sentence_embs | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
    def cls_pool(self, token_embeddings):
        """
        Pool token embeddings by selecting the entry at the first position
        (used when `pool` is "cls").

        Args:
            token_embeddings:
                An indexable object exposing `.shape`.
                NOTE(review): presumably a tensor shaped (batch, seq_len, dim)
                or (seq_len, dim) - confirm against the caller.

        Returns:
            `token_embeddings[:, 0, :]` when the input has 3 dimensions,
            `token_embeddings[0]` when it has 2 dimensions.

        Raises:
            RuntimeError: if the input has neither 2 nor 3 dimensions.
        """
        _shape = token_embeddings.shape
        rank = len(_shape)
        # Reject unsupported ranks up front so the happy paths stay flat.
        if rank not in (2, 3):
            raise RuntimeError(f'Invalid shape of token embeddings: {_shape}')
        if rank == 2:
            return token_embeddings[0]
        return token_embeddings[:, 0, :]
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |  |  |     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): | 
		
	
		
			
				|  |  |         if output_file == 'default': |  |  |         if output_file == 'default': | 
		
	
		
			
				|  |  |             output_file = str(Path(__file__).parent) |  |  |             output_file = str(Path(__file__).parent) | 
		
	
	
		
			
				|  | 
 |