diff --git a/README.md b/README.md index 432c853..d4ae5e5 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,115 @@ -# sbert +# Sentence Embedding with Sentence Transformers +*author: [Jael Gu](https://github.com/jaelgu)* + +
+ +## Description + +This operator takes a sentence or a list of sentences in string as input. +It generates an embedding vector in numpy.ndarray for each sentence, which captures the input sentence's core semantic elements. +This operator is implemented with pre-trained models from [Sentence Transformers](https://www.sbert.net/). + +
+ +## Code Example + +Use the pre-trained model "all-MiniLM-L12-v2" +to generate a text embedding for the sentence "This is a sentence.". + +*Write a pipeline with explicit input/output name specifications:* + +- **option 1 (towhee>=0.9.0):** +```python +from towhee.dc2 import pipe, ops, DataCollection + +p = ( + pipe.input('sentence') + .map('sentence', 'vec', ops.sentence_embedding.sbert(model_name='all-MiniLM-L12-v2')) + .output('sentence', 'vec') +) + +DataCollection(p('This is a sentence.')).show() +``` + + + +- **option 2:** + +```python +import towhee + +( + towhee.dc['sentence'](['This is a sentence.']) + .sentence_embedding.sbert['sentence', 'vec'](model_name='all-MiniLM-L12-v2') + .show() +) +``` + +
+ +## Factory Constructor + +Create the operator via the following factory method: + +***sentence_embedding.sbert(model_name='all-MiniLM-L12-v2')*** + +**Parameters:** + +***model_name***: *str* + +The model name as a string. Supported model names: + +Refer to [SBert Doc](https://www.sbert.net/docs/pretrained_models.html). +Please note that only models listed by `supported_model_names` are tested. +You can refer to [Towhee Pipeline]() for model performance. + +***device***: *str* + +The device to run the model on, defaults to None. +If None, it will use 'cuda' automatically when cuda is available, and 'cpu' otherwise. + +
+ +## Interface + +The operator takes a sentence or a list of sentences in string as input. +It loads the tokenizer and the pre-trained model by model name, +and then returns the text embedding in numpy.ndarray. + +***__call__(txt)*** + +**Parameters:** + +***txt***: *Union[List[str], str]* + + A sentence or a list of sentences in string. + + +**Returns**: + +*Union[List[numpy.ndarray], numpy.ndarray]* + + If input is a sentence in string, then it returns an embedding vector of shape (dim,) in numpy.ndarray. +If input is a list of sentences, then it returns a list of embedding vectors, each of which is a numpy.ndarray in shape of (dim,). + +
+ +***supported_model_names(format=None)*** + +Get a list of all supported model names or supported model names for a specified model format. + +**Parameters:** + +***format***: *str* + + The model format such as 'pytorch', defaults to None. +If None, it will return a full list of supported model names. + +```python +from towhee import ops + +op = ops.sentence_embedding.sbert().get_op() +full_list = op.supported_model_names() +onnx_list = op.supported_model_names(format='onnx') +``` diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..d7afb3c --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .s_bert import STransformers + + +def sbert(*args, **kwargs): + return STransformers(*args, **kwargs) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..63e3e64 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +sentence_transformers +torch \ No newline at end of file diff --git a/result.png b/result.png new file mode 100644 index 0000000..7be79f1 Binary files /dev/null and b/result.png differ diff --git a/s_bert.py b/s_bert.py new file mode 100644 index 0000000..4a42326 --- /dev/null +++ b/s_bert.py @@ -0,0 +1,221 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy
from typing import Optional, Union, List
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer

from towhee.operator import NNOperator
# from towhee.dc2 import accelerate

import os
import warnings

warnings.filterwarnings('ignore')
logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
log = logging.getLogger('op_sbert')


class ConvertModel(torch.nn.Module):
    """
    Adapter that lets the wrapped model be called with positional tensors.

    The wrapped model is called with a single dict of features; this module
    rebuilds that dict from positional arguments (as used by
    ``torch.onnx.export`` in ``STransformers.save_model``) and returns only
    the ``'sentence_embedding'`` entry of the model output.
    """

    def __init__(self, model):
        super().__init__()
        self.net = model
        try:
            self.input_names = self.net.tokenizer.model_input_names
        except AttributeError:
            # No usable tokenizer attribute: infer the names from a sample tokenization.
            self.input_names = list(self.net.tokenize(['test']).keys())

    def forward(self, *args, **kwargs):
        """
        Run the wrapped model on positional OR keyword features (not both).

        Positional arguments are matched, in order, with ``self.input_names``.
        """
        if args:
            assert kwargs == {}, 'Only accept either args or kwargs as inputs.'
            assert len(args) == len(self.input_names)
            kwargs = dict(zip(self.input_names, args))
        outs = self.net(kwargs)
        return outs['sentence_embedding']


# @accelerate
class Model:
    """
    Thin callable wrapper: forwards keyword features to the model as one dict
    and returns the ``'sentence_embedding'`` entry of its output.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, **features):
        outs = self.model(features)
        return outs['sentence_embedding']


class STransformers(NNOperator):
    """
    Operator using pretrained Sentence Transformers

    Args:
        model_name (`str`):
            Name passed to ``SentenceTransformer(model_name_or_path=...)``.
            When omitted, the operator is created without a model and only a
            warning is logged.
        device (`str`):
            Device to run the model on. Defaults to 'cuda' when
            ``torch.cuda.is_available()`` is true, otherwise 'cpu'.
    """

    def __init__(self, model_name: Optional[str] = None, device: Optional[str] = None):
        self.model_name = model_name
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.model_name:
            self.model = Model(self._model)
        else:
            log.warning('The operator is initialized without specified model.')
            # Keep the attribute defined so later access fails with a clear error.
            self.model = None

    def __call__(self, txt: Union[List[str], str]) -> Union[List[numpy.ndarray], numpy.ndarray]:
        """
        Embed one sentence or a list of sentences.

        Args:
            txt: A sentence, or a list of sentences.

        Returns:
            One ``numpy.ndarray`` for a string input (batch axis squeezed out);
            a list of ``numpy.ndarray`` (one per sentence) for a list input.

        Raises:
            RuntimeError: If the operator was created without a model name.
        """
        if self.model is None:
            raise RuntimeError('The operator is initialized without specified model.')
        single = isinstance(txt, str)
        sentences = [txt] if single else txt
        if not single and len(sentences) == 0:
            # Nothing to embed; avoid tokenizing an empty batch.
            return []
        # tokenize() is not guaranteed to return tensors on self.device, so move them explicitly.
        inputs = self._to_device(self.tokenize(sentences))
        with torch.no_grad():
            embs = self.model(**inputs)
        embs = embs.cpu().detach().numpy()
        return embs.squeeze(0) if single else list(embs)

    def _to_device(self, inputs):
        """Return *inputs* as a plain dict with every tensor value moved to ``self.device``."""
        return {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

    @property
    def _model(self):
        """
        The loaded ``SentenceTransformer`` in eval mode.

        The instance is cached per (model_name, device) so that repeated
        accesses (e.g. one per ``tokenize`` call) do not reload the model.
        """
        key = (self.model_name, self.device)
        cached = getattr(self, '_model_cache', None)
        if cached is None or cached[0] != key:
            m = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
            m.eval()
            cached = (key, m)
            self._model_cache = cached
        return cached[1]

    @property
    def supported_formats(self):
        return ['onnx']

    def tokenize(self, x):
        """
        Tokenize a list of sentences.

        Uses the model's own ``tokenize`` first. If that raises, falls back to
        a ``transformers.AutoTokenizer`` loaded from
        ``'sentence-transformers/' + model_name`` and then from ``model_name``.
        """
        try:
            outs = self._model.tokenize(x)
        except Exception as err:
            # Best-effort fallback; record why the primary path failed instead of hiding it.
            log.warning('Fail to tokenize with the sentence-transformers model: %s', err)
            from transformers import AutoTokenizer
            try:
                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
            except Exception as e:
                log.error(e)
                log.warning(f'Fail to load tokenizer with sentence-transformers/{self.model_name}.'
                            f'Trying to load tokenizer with self.model_name...')
                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            outs = tokenizer(
                x,
                padding=True, truncation='longest_first', max_length=self.max_seq_length,
                return_tensors='pt',
            )
        return outs

    @property
    def max_seq_length(self):
        """
        Maximum sequence length read from the cached model configuration.

        Looks under ``<torch home>/sentence_transformers`` for
        ``sentence_bert_config.json`` (key ``max_seq_length``) and, if absent,
        for ``config.json`` (key ``max_position_embeddings``).

        Returns:
            The configured value, or None when no config file or key is found.
        """
        import json
        from torch.hub import _get_torch_home
        sbert_cache = os.path.join(_get_torch_home(), 'sentence_transformers')
        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
        k = 'max_seq_length'
        if not os.path.exists(cfg_path):
            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
            k = 'max_position_embeddings'
        if not os.path.exists(cfg_path):
            # No cached config at either location: behave like a missing key.
            log.warning('No cached model config found under %s.', sbert_cache)
            return None
        with open(cfg_path, encoding='utf-8') as f:
            cfg = json.load(f)
        return cfg.get(k)

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """
        Save the model to ``<path>/saved/<format>/<model name><suffix>``.

        Args:
            format: One of 'pytorch', 'torchscript' or 'onnx'.
            path: Base directory; 'default' means the directory of this file.

        Returns:
            The resolved ``pathlib.Path`` of the saved file.

        Raises:
            AttributeError: If ``format`` is not one of the accepted values.
            RuntimeError: If the torchscript or onnx export fails.
        """
        # Validate before touching the filesystem so an invalid format leaves no directory behind.
        if format in ['pytorch', 'torchscript']:
            suffix = '.pt'
        elif format == 'onnx':
            suffix = '.onnx'
        else:
            raise AttributeError(f'Invalid format {format}.')
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self.model_name.replace('/', '-')
        path = os.path.join(path, name) + suffix
        model = self._model
        dummy_text = ['[CLS]']
        # Example inputs must live on the same device as the model being traced/exported.
        dummy_input = self._to_device(self.tokenize(dummy_text))
        if format == 'pytorch':
            torch.save(model, path)
        elif format == 'torchscript':
            try:
                try:
                    jit_model = torch.jit.script(model)
                except Exception:
                    # Scripting failed; fall back to tracing with the dummy input.
                    jit_model = torch.jit.trace(model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # format == 'onnx'
            new_model = ConvertModel(model)
            input_names = list(dummy_input.keys())
            # The positional tensors below follow dummy_input's key order, so name them by that order.
            new_model.input_names = input_names
            dynamic_axes = {}
            for i_n, i_v in dummy_input.items():
                if len(i_v.shape) == 1:
                    dynamic_axes[i_n] = {0: 'batch_size'}
                else:
                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
            try:
                torch.onnx.export(new_model,
                                  tuple(dummy_input.values()),
                                  path,
                                  input_names=input_names,
                                  output_names=['output_0'],
                                  opset_version=13,
                                  dynamic_axes=dynamic_axes,
                                  do_constant_folding=True
                                  )
            except Exception as e:
                log.error(f'Fail to save as onnx: {e}.')
                raise RuntimeError(f'Fail to save as onnx: {e}.') from e
        # todo: format == 'tensorrt'
        return Path(path).resolve()

    @staticmethod
    def supported_model_names(format: str = None):
        """
        List model names scraped from the sbert.net pretrained-models page.

        Args:
            format: None for the full list, or 'pytorch' / 'onnx' for the list
                minus the models excluded for that format (currently none).

        Returns:
            A sorted list of model names.

        Raises:
            ValueError: If ``format`` is not None, 'pytorch' or 'onnx'.
            requests.RequestException: If the page cannot be fetched.
        """
        import requests
        req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html", timeout=30)
        # Fail loudly instead of parsing an error page into an empty list.
        req.raise_for_status()
        full_list = []
        # splitlines() handles both '\r\n' and '\n' line endings.
        for line in req.text.splitlines():
            line = line.replace(' ', '').strip()
            if line.startswith('"name":'):
                name = line.split(':')[-1].replace('"', '').replace(',', '')
                full_list.append(name)
        full_list.sort()
        if format is None:
            return full_list
        # Per-format exclusion lists (models known not to work); both are empty for now.
        excluded = {'pytorch': [], 'onnx': []}
        if format not in ('pytorch', 'onnx'):
            log.error(f'Invalid or unsupported format "{format}".')
            raise ValueError(f'Invalid or unsupported format "{format}".')
        return sorted(set(full_list) - set(excluded[format]))


# ---- test_onnx.py (separate new file in the original diff: index 0000000..245aec8) ----
from towhee import ops
import numpy
import onnx
import onnxruntime

import os
from pathlib import Path
import logging
import platform
import psutil

op = ops.sentence_embedding.sbert().get_op()
# full_models = op.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
# Script body of test_onnx.py: for each model, load the operator, export it to
# ONNX, validate the file, run it with onnxruntime and compare with the
# operator's own output. One CSV row per model records which steps succeeded.
models = ['all-MiniLM-L12-v2']
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'sbert.log'

logger = logging.getLogger('sbert_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Everything goes to the log file; only errors are echoed to the console.
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')


def _check_model(name):
    """
    Run all checks for one model name.

    Returns:
        ``[name, load_op, save_onnx, check_onnx, run_onnx, accuracy]`` where
        each step is 'success' or 'fail'. A failed load, save or run stops the
        remaining steps; a failed ONNX check does not.
    """
    status = [name] + ['fail'] * 5
    try:
        op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status
    try:
        # Use the path returned by save_model rather than guessing a path relative to the cwd.
        onnx_path = str(op.save_model('onnx'))
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Retry without loading external data before giving up on the check.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        # A failed check is recorded but does not prevent trying to run the model.
        logger.error(f'FAIL TO CHECK ONNX: {e}')
    try:
        inputs = op._model.tokenize([test_txt])
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        # Feed only the inputs the exported graph actually declares.
        onnx_inputs = {}
        for n in sess.get_inputs():
            k = n.name
            if k in inputs:
                onnx_inputs[k] = inputs[k].cpu().detach().numpy()
        out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
    return status


# The context manager guarantees the CSV is flushed and closed even if a check raises.
with open('onnx.csv', 'w+') as f:
    f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
    for name in models:
        logger.info(f'***{name}***')
        f.write(','.join(_check_model(name)) + '\n')

print('Finished.')