@@ -16,6 +16,7 @@ import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()
 
-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
 
 
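Reviewer note on the `Model.__call__` hunk: the tokenizer produces CPU tensors, so each keyword tensor is now copied onto the operator's device before the forward pass; without that, a GPU-resident `SentenceTransformer` would raise a device-mismatch error. A minimal sketch of the same pattern outside the operator (the model name is only an example and not part of this patch):

```python
# Hedged sketch of the device-moving pattern introduced above.
# 'all-MiniLM-L6-v2' is an example checkpoint, not taken from the patch.
import torch
from sentence_transformers import SentenceTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

features = model.tokenize(['hello world'])                  # tensors start on CPU
features = {k: v.to(device) for k, v in features.items()}   # move them to the model's device
embedding = model(features)['sentence_embedding']           # shape: (1, embedding_dim)
```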
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()
 
     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-#         for k, v in inputs.items():
-#             inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']
 
-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
             try:
                 tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
             except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-            if k in cfg:
-                max_seq_len = cfg[k]
-            else:
-                max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #         if k in cfg:
+    #             max_seq_len = cfg[k]
+    #         else:
+    #             max_seq_len = None
+    #     return max_seq_len
 
     @property
     def _model(self):
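Reviewer note on `get_tokenizer`: when the wrapped `SentenceTransformer` exposes `tokenize`, that bound method is returned as-is; otherwise a Hugging Face tokenizer is wrapped with `functools.partial`, taking `max_length` from `AutoConfig.max_position_embeddings` instead of the cache-file lookup that the now-commented `max_seq_length` property performed. A rough equivalent of the fallback path, with an example model name that is not asserted by this patch:

```python
# Illustrative fallback tokenizer (assumes the model exists on the Hugging Face hub).
from functools import partial
from transformers import AutoTokenizer, AutoConfig

name = 'sentence-transformers/all-MiniLM-L6-v2'      # example checkpoint only
tokenizer = AutoTokenizer.from_pretrained(name)
conf = AutoConfig.from_pretrained(name)

tokenize = partial(tokenizer,
                   padding=True,
                   truncation='longest_first',
                   max_length=conf.max_position_embeddings,
                   return_tensors='pt')

batch = tokenize(['a first sentence', 'a second, slightly longer sentence'])
# batch is a dict of tensors: input_ids, attention_mask (and token_type_ids for BERT-style models)
```

One caveat: for some sentence-transformers checkpoints `max_position_embeddings` (often 512) is larger than the `max_seq_length` the model was actually trained with, which is what the removed `sentence_bert_config.json` lookup used to return.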
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
             else:
                 raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
                     dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
             dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
             try:
-                torch.onnx.export(new_model,
+                torch.onnx.export(new_model.to('cpu'),
                                   tuple(dummy_input.values()),
                                   path,
                                   input_names=input_names,
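Reviewer note on the export hunk: moving `new_model` to CPU keeps the traced weights on the same device as the CPU-side `dummy_input` tensors during `torch.onnx.export`. If it helps verification, a quick check of the exported file could look like the sketch below; it assumes `onnxruntime` is installed and that `input_names` matches the keys of `dummy_input`, neither of which is guaranteed by this patch:

```python
# Hypothetical post-export sanity check (not part of the patch).
import onnxruntime as ort

sess = ort.InferenceSession(str(path))                        # same path passed to torch.onnx.export
ort_inputs = {k: v.numpy() for k, v in dummy_input.items()}   # assumes keys match input_names
emb = sess.run(None, ort_inputs)[0]                           # first output corresponds to 'output_0'
print(emb.shape)                                              # expected: (batch_size, emb_dim)
```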
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
             'sentence-t5-xxl',
             'sentence-t5-xl',
             'sentence-t5-large',
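For an end-to-end check of the new tokenizer and device handling, a smoke test along these lines might be useful; it assumes the operator can be constructed directly with a `model_name` argument (as the `__init__` hunk suggests) and uses an example model name, neither of which this patch asserts:

```python
# Hypothetical smoke test for the updated operator (constructor argument
# and model name are assumptions for illustration only).
op = STransformers(model_name='all-MiniLM-L6-v2')

single = op('hello world')                 # str input  -> 1-D numpy embedding
batch = op(['hello world', 'hi there'])    # list input -> 2-D numpy array of shape (2, dim)
print(single.shape, batch.shape)
```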