diff --git a/README.md b/README.md
index 432c853..d4ae5e5 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,115 @@
-# sbert
+# Sentence Embedding with Sentence Transformers
+*author: [Jael Gu](https://github.com/jaelgu)*
+
+
+
+## Description
+
+This operator takes a sentence or a list of sentences in string as input.
+It generates an embedding vector in numpy.ndarray for each sentence, which captures the input sentence's core semantic elements.
+This operator is implemented with pre-trained models from [Sentence Transformers](https://www.sbert.net/).
+
+
+
+## Code Example
+
+Use the pre-trained model "all-MiniLM-L12-v2"
+to generate a text embedding for the sentence "This is a sentence.".
+
+*Write a same pipeline with explicit inputs/outputs name specifications:*
+
+- **option 1 (towhee>=0.9.0):**
+```python
+from towhee.dc2 import pipe, ops, DataCollection
+
+p = (
+ pipe.input('sentence')
+ .map('sentence', 'vec', ops.sentence_embedding.sbert(model_name='all-MiniLM-L12-v2'))
+ .output('sentence', 'vec')
+)
+
+DataCollection(p('This is a sentence.')).show()
+```
+
+
+
+- **option 2:**
+
+```python
+import towhee
+
+(
+ towhee.dc['sentence'](['This is a sentence.'])
+ .sentence_embedding.sbert['sentence', 'vec'](model_name='all-MiniLM-L12-v2')
+ .show()
+)
+```
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***text_embedding.sbert(model_name='all-MiniLM-L12-v2')***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string. Supported model names:
+
+Refer to [SBert Doc](https://www.sbert.net/docs/pretrained_models.html).
+Please note that only models listed `supported_model_names` are tested.
+You can refer to [Towhee Pipeline]() for model performance.
+
+***device***: *str*
+
+The device to run model, defaults to None.
+If None, it will use 'cuda' automatically when cuda is available.
+
+
+
+## Interface
+
+The operator takes a sentence or a list of sentences in string as input.
+It loads tokenizer and pre-trained model using model name,
+and then returns text embedding in numpy.ndarray.
+
+***__call__(txt)***
+
+**Parameters:**
+
+***txt***: *Union[List[str], str]*
+
+ A sentence or a list of sentences in string.
+
+
+**Returns**:
+
+*Union[List[numpy.ndarray], numpy.ndarray]*
+
+ If input is a sentence in string, then it returns an embedding vector of shape (dim,) in numpy.ndarray.
+If input is a list of sentences, then it returns a list of embedding vectors, each of which a numpy.ndarray in shape of (dim,).
+
+
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names or supported model names for specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The model format such as 'pytorch', defaults to None.
+If None, it will return a full list of supported model names.
+
+```python
+from towhee import ops
+
+op = ops.sentence_embedding.sentence_transformers().get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+```
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..d7afb3c
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .s_bert import STransformers
+
+
+def sbert(*args, **kwargs):
+ return STransformers(*args, **kwargs)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..63e3e64
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+sentence_transformers
+torch
\ No newline at end of file
diff --git a/result.png b/result.png
new file mode 100644
index 0000000..7be79f1
Binary files /dev/null and b/result.png differ
diff --git a/s_bert.py b/s_bert.py
new file mode 100644
index 0000000..4a42326
--- /dev/null
+++ b/s_bert.py
@@ -0,0 +1,221 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy
+from typing import Union, List
+from pathlib import Path
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+from towhee.operator import NNOperator
+# from towhee.dc2 import accelerate
+
+import os
+import warnings
+
+warnings.filterwarnings('ignore')
+logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
+log = logging.getLogger('op_sbert')
+
+
+class ConvertModel(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.net = model
+ try:
+ self.input_names = self.net.tokenizer.model_input_names
+ except AttributeError:
+ self.input_names = list(self.net.tokenize(['test']).keys())
+
+ def forward(self, *args, **kwargs):
+ if args:
+ assert kwargs == {}, 'Only accept neither args or kwargs as inputs.'
+ assert len(args) == len(self.input_names)
+ for k, v in zip(self.input_names, args):
+ kwargs[k] = v
+ outs = self.net(kwargs)
+ return outs['sentence_embedding']
+
+
+# @accelerate
+class Model:
+ def __init__(self, model):
+ self.model = model
+
+ def __call__(self, **features):
+ outs = self.model(features)
+ return outs['sentence_embedding']
+
+
+class STransformers(NNOperator):
+ """
+ Operator using pretrained Sentence Transformers
+ """
+
+ def __init__(self, model_name: str = None, device: str = None):
+ self.model_name = model_name
+ if device:
+ self.device = device
+ else:
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ if self.model_name:
+ self.model = Model(self._model)
+ else:
+ log.warning('The operator is initialized without specified model.')
+ pass
+
+ def __call__(self, txt: Union[List[str], str]):
+ if isinstance(txt, str):
+ sentences = [txt]
+ else:
+ sentences = txt
+ inputs = self.tokenize(sentences)
+ embs = self.model(**inputs).cpu().detach().numpy()
+ if isinstance(txt, str):
+ embs = embs.squeeze(0)
+ else:
+ embs = list(embs)
+ return embs
+
+ @property
+ def _model(self):
+ m = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
+ m.eval()
+ return m
+
+ @property
+ def supported_formats(self):
+ return ['onnx']
+
+ def tokenize(self, x):
+ try:
+ outs = self._model.tokenize(x)
+ except Exception:
+ from transformers import AutoTokenizer
+ try:
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+ except Exception as e:
+ log.error(e)
+ log.warning(f'Fail to load tokenizer with sentence-transformers/{self.model_name}.'
+ f'Trying to load tokenizer with self.model_name...')
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+ outs = tokenizer(
+ x,
+ padding=True, truncation='longest_first', max_length=self.max_seq_length,
+ return_tensors='pt',
+ )
+ return outs
+
+ @property
+ def max_seq_length(self):
+ import json
+ from torch.hub import _get_torch_home
+ torch_cache = _get_torch_home()
+ sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+ cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+ if not os.path.exists(cfg_path):
+ cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+ k = 'max_position_embeddings'
+ else:
+ k = 'max_seq_length'
+ with open(cfg_path) as f:
+ cfg = json.load(f)
+ if k in cfg:
+ max_seq_len = cfg[k]
+ else:
+ max_seq_len = None
+ return max_seq_len
+
+ def save_model(self, format: str = 'pytorch', path: str = 'default'):
+ if path == 'default':
+ path = str(Path(__file__).parent)
+ path = os.path.join(path, 'saved', format)
+ os.makedirs(path, exist_ok=True)
+ name = self.model_name.replace('/', '-')
+ path = os.path.join(path, name)
+ if format in ['pytorch', 'torchscript']:
+ path = path + '.pt'
+ elif format == 'onnx':
+ path = path + '.onnx'
+ else:
+ raise AttributeError(f'Invalid format {format}.')
+ dummy_text = ['[CLS]']
+ dummy_input = self.tokenize(dummy_text)
+ if format == 'pytorch':
+ torch.save(self._model, path)
+ elif format == 'torchscript':
+ try:
+ try:
+ jit_model = torch.jit.script(self._model)
+ except Exception:
+ jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
+ torch.jit.save(jit_model, path)
+ except Exception as e:
+ log.error(f'Fail to save as torchscript: {e}.')
+ raise RuntimeError(f'Fail to save as torchscript: {e}.')
+ elif format == 'onnx':
+ new_model = ConvertModel(self._model)
+ input_names = list(dummy_input.keys())
+ dynamic_axes = {}
+ for i_n, i_v in dummy_input.items():
+ if len(i_v.shape) == 1:
+ dynamic_axes[i_n] = {0: 'batch_size'}
+ else:
+ dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
+ dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
+ try:
+ torch.onnx.export(new_model,
+ tuple(dummy_input.values()),
+ path,
+ input_names=input_names,
+ output_names=['output_0'],
+ opset_version=13,
+ dynamic_axes=dynamic_axes,
+ do_constant_folding=True
+ )
+ except Exception as e:
+ log.error(f'Fail to save as onnx: {e}.')
+ raise RuntimeError(f'Fail to save as onnx: {e}.')
+ # todo: elif format == 'tensorrt':
+ else:
+ log.error(f'Unsupported format "{format}".')
+ return Path(path).resolve()
+
+ @staticmethod
+ def supported_model_names(format: str = None):
+ import requests
+ req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html")
+ data = req.text
+ full_list = []
+ for line in data.split('\r\n'):
+ line = line.replace(' ', '')
+ if line.startswith('"name":'):
+ name = line.split(':')[-1].replace('"', '').replace(',', '')
+ full_list.append(name)
+ full_list.sort()
+ if format is None:
+ model_list = full_list
+ elif format == 'pytorch':
+ to_remove = []
+ assert set(to_remove).issubset(set(full_list))
+ model_list = list(set(full_list) - set(to_remove))
+ elif format == 'onnx':
+ to_remove = []
+ assert set(to_remove).issubset(set(full_list))
+ model_list = list(set(full_list) - set(to_remove))
+ else:
+ log.error(f'Invalid or unsupported format "{format}".')
+ return model_list
diff --git a/test_onnx.py b/test_onnx.py
new file mode 100644
index 0000000..245aec8
--- /dev/null
+++ b/test_onnx.py
@@ -0,0 +1,103 @@
+from towhee import ops
+import numpy
+import onnx
+import onnxruntime
+
+import os
+from pathlib import Path
+import logging
+import platform
+import psutil
+
+op = ops.sentence_embedding.sbert().get_op()
+# full_models = op.supported_model_names()
+# checked_models = AutoTransformers.supported_model_names(format='onnx')
+# models = [x for x in full_models if x not in checked_models]
+models = ['all-MiniLM-L12-v2']
+test_txt = 'hello, world.'
+atol = 1e-3
+log_path = 'sbert.log'
+f = open('onnx.csv', 'w+')
+f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
+
+logger = logging.getLogger('sbert_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
+ f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
+ f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
+
+
+status = None
+for name in models:
+ logger.info(f'***{name}***')
+ saved_name = name.replace('/', '-')
+ onnx_path = f'saved/onnx/{saved_name}.onnx'
+ if status:
+ f.write(','.join(status) + '\n')
+ status = [name] + ['fail'] * 5
+ try:
+ op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op()
+ out1 = op(test_txt)
+ logger.info('OP LOADED.')
+ status[1] = 'success'
+ except Exception as e:
+ logger.error(f'FAIL TO LOAD OP: {e}')
+ continue
+ try:
+ op.save_model('onnx')
+ logger.info('ONNX SAVED.')
+ status[2] = 'success'
+ except Exception as e:
+ logger.error(f'FAIL TO SAVE ONNX: {e}')
+ continue
+ try:
+ try:
+ onnx_model = onnx.load(onnx_path)
+ onnx.checker.check_model(onnx_model)
+ except Exception:
+ saved_onnx = onnx.load(onnx_path, load_external_data=False)
+ onnx.checker.check_model(saved_onnx)
+ logger.info('ONNX CHECKED.')
+ status[3] = 'success'
+ except Exception as e:
+ logger.error(f'FAIL TO CHECK ONNX: {e}')
+ pass
+ try:
+ inputs = op._model.tokenize([test_txt])
+ sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
+ onnx_inputs = {}
+ for n in sess.get_inputs():
+ k = n.name
+ if k in inputs:
+ onnx_inputs[k] = inputs[k].cpu().detach().numpy()
+ out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
+ logger.info('ONNX WORKED.')
+ status[4] = 'success'
+ if numpy.allclose(out1, out2, atol=atol):
+ logger.info('Check accuracy: OK')
+ status[5] = 'success'
+ else:
+ logger.info(f'Check accuracy: atol is larger than {atol}.')
+ except Exception as e:
+ logger.error(f'FAIL TO RUN ONNX: {e}')
+ continue
+
+if status:
+ f.write(','.join(status) + '\n')
+
+print('Finished.')
+
+
+