@@ -16,6 +16,7 @@ import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()
 
-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
 
 
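The reworked Model.__call__ accepts the tokenizer output as keyword tensors and moves each one onto the operator's device before the forward pass. A minimal sketch of the same pattern outside the operator, assuming a CUDA-or-CPU device and the 'all-MiniLM-L6-v2' checkpoint (both illustrative, not taken from this diff):

import torch
from sentence_transformers import SentenceTransformer

# 'all-MiniLM-L6-v2' is an illustrative checkpoint, not taken from this PR.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
model.eval()

# Tokenize on CPU, then move every tensor to the target device,
# mirroring the new_kwargs loop in Model.__call__ above.
features = model.tokenize(['hello world', 'another sentence'])
features = {k: v.to(device) for k, v in features.items()}

with torch.no_grad():
    outs = model(features)           # SentenceTransformer forward takes a feature dict
emb = outs['sentence_embedding']     # shape: (batch_size, embedding_dim)
print(emb.shape)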
			
		
	
	
		
			
				
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()
 
     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-#         for k, v in inputs.items():
-#             inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']
 
-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
+            try:
+                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
+            except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-            if k in cfg:
-                max_seq_len = cfg[k]
-            else:
-                max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #         if k in cfg:
+    #             max_seq_len = cfg[k]
+    #         else:
+    #             max_seq_len = None
+    #     return max_seq_len
 
     @property
     def _model(self):
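The new get_tokenizer either reuses the underlying model's own tokenize method or falls back to a Hugging Face tokenizer wrapped in functools.partial, so self._tokenize(sentences) behaves the same in both cases. A minimal sketch of the fallback path, assuming the 'sentence-transformers/all-MiniLM-L6-v2' checkpoint (illustrative only, not from this PR):

from functools import partial
from transformers import AutoConfig, AutoTokenizer

name = 'sentence-transformers/all-MiniLM-L6-v2'   # illustrative model name
tokenizer = AutoTokenizer.from_pretrained(name)
conf = AutoConfig.from_pretrained(name)

# Pre-bind padding/truncation options so the callable takes only the sentences,
# matching the signature of self._model.tokenize.
tokenize = partial(tokenizer,
                   padding=True,
                   truncation='longest_first',
                   max_length=conf.max_position_embeddings,
                   return_tensors='pt')

inputs = tokenize(['hello world', 'a longer second sentence'])
print(inputs['input_ids'].shape)   # (batch_size, padded_sequence_length)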
			
		
	
	
		
			
				
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
             else:
                 raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
                     dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
             dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
             try:
-                torch.onnx.export(new_model,
+                torch.onnx.export(new_model.to('cpu'),
                                   tuple(dummy_input.values()),
                                   path,
                                   input_names=input_names,
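The export hunk moves the wrapped model to CPU before torch.onnx.export, presumably so the traced weights and the CPU dummy input from the tokenizer live on the same device. A minimal sketch of an export with the same dynamic-axes layout, using a toy module and output path (both illustrative assumptions, not the operator's actual model):

import torch


class TinyEncoder(torch.nn.Module):
    # Toy stand-in for the wrapped sentence-embedding model (illustrative only).
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(30522, 32)

    def forward(self, input_ids, attention_mask):
        # Masked mean pooling over token embeddings.
        mask = attention_mask.unsqueeze(-1).float()
        x = self.emb(input_ids) * mask
        return x.sum(dim=1) / mask.sum(dim=1)


model = TinyEncoder().to('cpu').eval()
dummy_input = {
    'input_ids': torch.ones(1, 4, dtype=torch.long),
    'attention_mask': torch.ones(1, 4, dtype=torch.long),
}
torch.onnx.export(model,
                  tuple(dummy_input.values()),
                  'model.onnx',                       # illustrative output path
                  input_names=list(dummy_input.keys()),
                  output_names=['output_0'],
                  dynamic_axes={
                      'input_ids': {0: 'batch_size', 1: 'sequence_length'},
                      'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
                      'output_0': {0: 'batch_size', 1: 'emb_dim'},
                  })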
			
		
	
	
		
			
				
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
             'sentence-t5-xxl',
             'sentence-t5-xl',
             'sentence-t5-large',