diff --git a/README.md b/README.md
index e2cfae6..0c5d656 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,149 @@
-# transformers
+# Sentence Embedding with Transformers
+*author: [Jael Gu](https://github.com/jaelgu)*
+
+
+
+## Description
+
+A sentence embedding operator generates one embedding vector in ndarray for each input text.
+The embedding represents the semantic information of the whole input text as one vector.
+This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+
+
+
+## Code Example
+
+Use the pre-trained model 'sentence-transformers/paraphrase-albert-small-v2'
+to generate an embedding for the sentence "Hello, world.".
+
+*Write a pipeline with explicit input/output name specifications:*
+
+- **option 1 (towhee>=0.9.0):**
+```python
+from towhee.dc2 import pipe, ops, DataCollection
+
+p = (
+ pipe.input('text')
+ .map('text', 'vec',
+ ops.sentence_embedding.transformers(model_name='sentence-transformers/paraphrase-albert-small-v2'))
+ .output('text', 'vec')
+)
+
+DataCollection(p('Hello, world.')).show()
+```
+
+
+
+- **option 2:**
+
+```python
+import towhee
+
+(
+ towhee.dc['text'](['Hello, world.'])
+ .sentence_embedding.transformers['text', 'vec'](
+ model_name='sentence-transformers/paraphrase-albert-small-v2')
+ .show()
+)
+```
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***sentence_embedding.transformers(model_name=None, checkpoint_path=None, tokenizer=None, device=None, norm=False)***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string, defaults to None.
+If None, the operator will be initialized without a specified model.
+
+Supported model names: refer to `supported_model_names` below.
+
+***checkpoint_path***: *str*
+
+The path to local checkpoint, defaults to None.
+If None, the operator will download and load pretrained model by `model_name` from Huggingface transformers.
+
+
+
+***tokenizer***: *object*
+
+The method to tokenize input text, defaults to None.
+If None, the operator will use default tokenizer by `model_name` from Huggingface transformers.
+
+
+
+## Interface
+
+The operator takes a piece of text in string as input.
+It loads the tokenizer and pre-trained model using the model name,
+and then returns a text embedding in numpy.ndarray.
+
+***\_\_call\_\_(data)***
+
+**Parameters:**
+
+***data***: *Union[str, list]*
+
+ The text in string or a list of texts.
+
+**Returns**:
+
+*numpy.ndarray or list*
+
+ The text embedding (or token embeddings) extracted by model.
+If `data` is string, the operator returns an embedding in numpy.ndarray with shape of (dim,).
+If `data` is a list, the operator returns a list of embedding(s) with length of input list.
+
+
+
+***save_model(model_type='pytorch', output_file='default')***
+
+Save model to local with specified format.
+
+**Parameters:**
+
+***model_type***: *str*
+
+ The format to export model as, such as 'pytorch', 'torchscript', 'onnx',
+defaults to 'pytorch'.
+
+***output_file***: *str*
+
+ The path where exported model is saved to.
+By default, it will save model to `saved` directory under the operator cache.
+
+```python
+from towhee import ops
+
+op = ops.sentence_embedding.transformers(model_name='sentence-transformers/paraphrase-albert-small-v2').get_op()
+op.save_model('onnx', 'test.onnx')
+```
+PosixPath('/Home/.towhee/operators/sentence-embedding/transformers/main/test.onnx')
+
+
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names or supported model names for specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The model format such as 'pytorch', 'torchscript', 'onnx'.
+
+```python
+from towhee import ops
+
+
+op = ops.sentence_embedding.transformers().get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+```
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..2cc07d3
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto_transformers import AutoTransformers
+
+
+def transformers(*args, **kwargs):
+    """Build the sentence-embedding operator.
+
+    All positional and keyword arguments are forwarded unchanged to
+    ``AutoTransformers``.
+
+    Returns:
+        The newly constructed ``AutoTransformers`` instance.
+    """
+    operator = AutoTransformers(*args, **kwargs)
+    return operator
diff --git a/auto_transformers.py b/auto_transformers.py
new file mode 100644
index 0000000..0a601a2
--- /dev/null
+++ b/auto_transformers.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import os
+import torch
+import shutil
+from pathlib import Path
+from typing import Union
+from collections import OrderedDict
+
+from transformers import AutoModel
+
+from towhee.operator import NNOperator
+from towhee import register
+# from towhee.dc2 import accelerate
+
+import warnings
+import logging
+from transformers import logging as t_logging
+
+log = logging.getLogger('run_op')
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
+
+
+# @accelerate
+class Model:
+    """Callable wrapper that returns only the `last_hidden_state` of a model's output."""
+
+    def __init__(self, model):
+        self.model = model
+
+    def __call__(self, *args, **kwargs):
+        # Force dict-style outputs so the hidden states can be looked up by key.
+        outs = self.model(*args, **kwargs, return_dict=True)
+        return outs['last_hidden_state']
+
+
+@register(output_schema=['vec'])
+class AutoTransformers(NNOperator):
+    """
+    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.
+
+    Args:
+        model_name (`str`):
+            The model name to load a pretrained model from transformers.
+            If empty, the operator is created without loading a model.
+        checkpoint_path (`str`):
+            The local checkpoint path.
+        tokenizer (`object`):
+            The tokenizer to tokenize input text as model inputs.
+        device (`str`):
+            The device to run on. If None, it is resolved from the operator's
+            device id on first access.
+        norm (`bool`):
+            Whether to normalize the sentence embeddings before returning them.
+    """
+
+    def __init__(self,
+                 model_name: str = None,
+                 checkpoint_path: str = None,
+                 tokenizer: object = None,
+                 device: str = None,
+                 norm: bool = False
+                 ):
+        super().__init__()
+        self._device = device
+        self.model_name = model_name
+        self.user_tokenizer = tokenizer
+        self.norm = norm
+        self.checkpoint_path = checkpoint_path
+
+        if self.model_name:
+            self.model = Model(self._model)
+        else:
+            log.warning('The operator is initialized without specified model.')
+
+    def __call__(self, data: Union[str, list]) -> Union[numpy.ndarray, list]:
+        """
+        Embed one text or a list of texts.
+
+        Args:
+            data (`Union[str, list]`):
+                A single text or a list of texts.
+
+        Returns:
+            One numpy.ndarray if `data` is a string, otherwise a list holding
+            one numpy.ndarray per input text.
+        """
+        single = isinstance(data, str)
+        txt = [data] if single else data
+        try:
+            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device)
+        except Exception as e:
+            log.error(f'Fail to tokenize inputs: {e}')
+            raise
+        try:
+            outs = self.model(**inputs)
+        except Exception:
+            log.error(f'Invalid input for the model: {self.model_name}')
+            raise
+        outs = self.post_proc(outs, inputs)
+        if self.norm:
+            outs = torch.nn.functional.normalize(outs)
+        features = outs.cpu().detach().numpy()
+        # Drop the batch axis for a single text; split into per-text arrays otherwise.
+        return features.squeeze(0) if single else list(features)
+
+    @property
+    def _model(self):
+        """
+        Load the pretrained model for `model_name` and put it in eval mode.
+
+        Every access loads the model again, so keep the result in a local
+        variable when it is needed more than once.
+        """
+        model = AutoModel.from_pretrained(self.model_name).to(self.device)
+        # Drop the pooler when the model has one.
+        if hasattr(model, 'pooler') and model.pooler:
+            model.pooler = None
+        if self.checkpoint_path:
+            try:
+                state_dict = torch.load(self.checkpoint_path, map_location=self.device)
+                model.load_state_dict(state_dict)
+            except Exception:
+                # Only logged: the model keeps whatever weights it has at this point.
+                log.error(f'Fail to load weights from {self.checkpoint_path}')
+        model.eval()
+        return model
+
+    @property
+    def device(self):
+        """Return the device, resolving it from `_device_id` on first access if unset."""
+        if self._device is None:
+            if self._device_id < 0:
+                self._device = torch.device('cpu')
+            else:
+                self._device = torch.device(self._device_id)
+        return self._device
+
+    @property
+    def model_config(self):
+        """Return the configuration loaded by `AutoConfig` for `model_name`."""
+        from transformers import AutoConfig
+        configs = AutoConfig.from_pretrained(self.model_name)
+        return configs
+
+    @property
+    def onnx_config(self):
+        """Return the ONNX export inputs and the `last_hidden_state` output as a dict."""
+        from transformers.onnx.features import FeaturesManager
+        _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
+            self._model, feature='default')
+        old_config = model_onnx_config(self.model_config)
+        onnx_config = {
+            'inputs': dict(old_config.inputs),
+            'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
+        }
+        return onnx_config
+
+    @property
+    def tokenizer(self):
+        """Return the user-supplied tokenizer, or the `AutoTokenizer` for `model_name`."""
+        from transformers import AutoTokenizer
+        try:
+            if self.user_tokenizer:
+                t = self.user_tokenizer
+            else:
+                t = AutoTokenizer.from_pretrained(self.model_name)
+            # Inputs are tokenized with padding, so make sure a pad token is set.
+            if not t.pad_token:
+                t.pad_token = '[PAD]'
+        except Exception:
+            log.error('Fail to load tokenizer.')
+            raise
+        return t
+
+    def post_proc(self, token_embeddings, inputs):
+        """
+        Average token embeddings over dim 1, weighted by the attention mask.
+
+        Args:
+            token_embeddings: The tensor of token embeddings returned by the model.
+            inputs: The tokenized inputs; only `attention_mask` is read.
+
+        Returns:
+            The tensor of pooled sentence embeddings.
+        """
+        token_embeddings = token_embeddings.to(self.device)
+        attention_mask = inputs['attention_mask'].to(self.device)
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        # The clamp keeps the divisor away from zero.
+        sentence_embs = torch.sum(
+            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sentence_embs
+
+    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
+        """
+        Export the model to a local file.
+
+        Args:
+            model_type (`str`):
+                The export format: 'pytorch', 'torchscript' or 'onnx'.
+            output_file (`str`):
+                The target file path. With 'default', the file is named after the model
+                and written to `saved/<model_type>/` next to this source file.
+
+        Returns:
+            The resolved path of the output file.
+
+        Raises:
+            AttributeError: If `output_file` is 'default' and `model_type` is unsupported.
+            RuntimeError: If the torchscript export fails.
+        """
+        if output_file == 'default':
+            # Validate the format before creating any directory.
+            if model_type in ['pytorch', 'torchscript']:
+                suffix = '.pt'
+            elif model_type == 'onnx':
+                suffix = '.onnx'
+            else:
+                raise AttributeError('Unsupported model_type.')
+            save_dir = os.path.join(str(Path(__file__).parent), 'saved', model_type)
+            os.makedirs(save_dir, exist_ok=True)
+            output_file = os.path.join(save_dir, self.model_name.replace('/', '-') + suffix)
+        elif model_type not in ['pytorch', 'torchscript', 'onnx']:
+            # todo: support 'tensorrt'
+            log.error(f'Unsupported format "{model_type}".')
+            return Path(output_file).resolve()
+
+        # `_model` reloads the weights on every access, so load once and reuse.
+        model = self._model
+        dummy_input = 'test sentence'
+        # Keep the dummy inputs on the same device as the model.
+        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
+        if model_type == 'pytorch':
+            torch.save(model, output_file)
+        elif model_type == 'torchscript':
+            inputs = list(inputs.values())
+            try:
+                try:
+                    jit_model = torch.jit.script(model)
+                except Exception:
+                    # Fall back to tracing when scripting is not possible.
+                    jit_model = torch.jit.trace(model, inputs, strict=False)
+                torch.jit.save(jit_model, output_file)
+            except Exception as e:
+                log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
+        elif model_type == 'onnx':
+            # `onnx_config` is rebuilt on every access as well, so read it once.
+            onnx_config = self.onnx_config
+            dynamic_axes = {**onnx_config['inputs'], **onnx_config['outputs']}
+            torch.onnx.export(
+                model,
+                tuple(inputs.values()),
+                output_file,
+                input_names=list(onnx_config['inputs'].keys()),
+                output_names=list(onnx_config['outputs'].keys()),
+                dynamic_axes=dynamic_axes,
+                opset_version=torch.onnx.constant_folding_opset_versions[-1],
+                do_constant_folding=True,
+            )
+        return Path(output_file).resolve()
+
+    @property
+    def supported_formats(self):
+        """Return ['onnx'] if `model_name` is in the ONNX list, otherwise ['pytorch']."""
+        onnxes = self.supported_model_names(format='onnx')
+        if self.model_name in onnxes:
+            return ['onnx']
+        else:
+            return ['pytorch']
+
+    @staticmethod
+    def supported_model_names(format: str = None):
+        """
+        List supported model names, optionally restricted to one model format.
+
+        Args:
+            format (`str`):
+                One of 'pytorch', 'torchscript', 'onnx', or None for the full list.
+
+        Returns:
+            The list of supported model names.
+
+        Raises:
+            ValueError: If `format` is not one of the formats above.
+        """
+        full_list = []
+        full_list.sort()
+        if format is None:
+            return full_list
+        # Names taken out of the full list for each format.
+        # todo: 'tensorrt'
+        removed_by_format = {
+            'pytorch': [],
+            'torchscript': [],
+            'onnx': [],
+        }
+        if format not in removed_by_format:
+            message = f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript", "onnx".'
+            log.error(message)
+            # Fail explicitly rather than returning an undefined result.
+            raise ValueError(message)
+        to_remove = removed_by_format[format]
+        assert set(to_remove).issubset(set(full_list))
+        return sorted(set(full_list) - set(to_remove))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5e12bcc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+transformers
+sentencepiece
+protobuf
+
+towhee
+torch
diff --git a/result.png b/result.png
new file mode 100644
index 0000000..7924280
Binary files /dev/null and b/result.png differ
diff --git a/test_onnx.py b/test_onnx.py
new file mode 100644
index 0000000..87dc20e
--- /dev/null
+++ b/test_onnx.py
@@ -0,0 +1,111 @@
+from towhee import ops
+import torch
+import numpy
+import onnx
+import onnxruntime
+
+import os
+from pathlib import Path
+import logging
+import platform
+import psutil
+
+import warnings
+from transformers import logging as t_logging
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+warnings.filterwarnings("ignore")
+t_logging.set_verbosity_error()
+
+# full_models = AutoTransformers.supported_model_names()
+# checked_models = AutoTransformers.supported_model_names(format='onnx')
+# models = [x for x in full_models if x not in checked_models]
+models = ['distilbert-base-cased', 'sentence-transformers/paraphrase-albert-small-v2']
+test_txt = 'hello, world.'
+atol = 1e-3
+log_path = 'transformers_onnx.log'
+
+logger = logging.getLogger('transformers_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
+
+
+def check_model(name):
+    """
+    Run the load, save, check, run and accuracy stages for one model.
+
+    Returns a list holding the model name followed by 'success' or 'fail' for
+    each stage; every stage after the first failure stays 'fail'.
+    """
+    logger.info(f'***{name}***')
+    status = [name] + ['fail'] * 5
+    try:
+        op = ops.sentence_embedding.transformers(model_name=name).get_op()
+        out1 = op(test_txt)
+        logger.info('OP LOADED.')
+        status[1] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO LOAD OP: {e}')
+        return status
+    try:
+        # Use the path returned by the operator rather than assuming it saves under the cwd.
+        onnx_path = str(op.save_model(model_type='onnx'))
+        logger.info('ONNX SAVED.')
+        status[2] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO SAVE ONNX: {e}')
+        return status
+    try:
+        try:
+            onnx_model = onnx.load(onnx_path)
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            # Retry without loading external data.
+            saved_onnx = onnx.load(onnx_path, load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logger.info('ONNX CHECKED.')
+        status[3] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO CHECK ONNX: {e}')
+        return status
+    try:
+        sess = onnxruntime.InferenceSession(onnx_path,
+                                            providers=onnxruntime.get_available_providers())
+        inputs = op.tokenizer(test_txt, return_tensors='np')
+        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
+        new_inputs = op.tokenizer(test_txt, return_tensors='pt')
+        # Bring the pooled result back to a CPU array before comparing with numpy.
+        out2 = op.post_proc(torch.from_numpy(out2), new_inputs).cpu().detach().numpy()
+        logger.info('ONNX WORKED.')
+        status[4] = 'success'
+        if numpy.allclose(out1, out2, atol=atol):
+            logger.info('Check accuracy: OK')
+            status[5] = 'success'
+        else:
+            logger.info(f'Check accuracy: atol is larger than {atol}.')
+    except Exception as e:
+        logger.error(f'FAIL TO RUN ONNX: {e}')
+    return status
+
+
+# The context manager closes the report file even if a check raises unexpectedly.
+with open('onnx.csv', 'w+') as f:
+    f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
+    for name in models:
+        f.write(','.join(check_model(name)) + '\n')
+
+print('Finished.')