sbert
copied
6 changed files with 459 additions and 1 deletions
@ -1,2 +1,115 @@ |
|||||
# sbert |
|
||||
|
# Sentence Embedding with Sentence Transformers |
||||
|
|
||||
|
*author: [Jael Gu](https://github.com/jaelgu)* |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Description |
||||
|
|
||||
|
This operator takes a sentence or a list of sentences in string as input. |
||||
|
It generates an embedding vector in numpy.ndarray for each sentence, which captures the input sentence's core semantic elements. |
||||
|
This operator is implemented with pre-trained models from [Sentence Transformers](https://www.sbert.net/). |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
Use the pre-trained model "all-MiniLM-L12-v2" |
||||
|
to generate a text embedding for the sentence "This is a sentence.". |
||||
|
|
||||
|
*Write a same pipeline with explicit inputs/outputs name specifications:* |
||||
|
|
||||
|
- **option 1 (towhee>=0.9.0):** |
||||
|
```python |
||||
|
from towhee.dc2 import pipe, ops, DataCollection |
||||
|
|
||||
|
p = ( |
||||
|
pipe.input('sentence') |
||||
|
.map('sentence', 'vec', ops.sentence_embedding.sbert(model_name='all-MiniLM-L12-v2')) |
||||
|
.output('sentence', 'vec') |
||||
|
) |
||||
|
|
||||
|
DataCollection(p('This is a sentence.')).show() |
||||
|
``` |
||||
|
|
||||
|
<img src="./result.png" width="800px"/> |
||||
|
|
||||
|
- **option 2:** |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
( |
||||
|
towhee.dc['sentence'](['This is a sentence.']) |
||||
|
.sentence_embedding.sbert['sentence', 'vec'](model_name='all-MiniLM-L12-v2') |
||||
|
.show() |
||||
|
) |
||||
|
``` |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method: |
||||
|
|
||||
|
***text_embedding.sbert(model_name='all-MiniLM-L12-v2')*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***model_name***: *str* |
||||
|
|
||||
|
The model name in string. Supported model names: |
||||
|
|
||||
|
Refer to [SBert Doc](https://www.sbert.net/docs/pretrained_models.html). |
||||
|
Please note that only models listed `supported_model_names` are tested. |
||||
|
You can refer to [Towhee Pipeline]() for model performance. |
||||
|
|
||||
|
***device***: *str* |
||||
|
|
||||
|
The device to run model, defaults to None. |
||||
|
If None, it will use 'cuda' automatically when cuda is available. |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
The operator takes a sentence or a list of sentences in string as input. |
||||
|
It loads tokenizer and pre-trained model using model name, |
||||
|
and then returns text embedding in numpy.ndarray. |
||||
|
|
||||
|
***__call__(txt)*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***txt***: *Union[List[str], str]* |
||||
|
|
||||
|
​ A sentence or a list of sentences in string. |
||||
|
|
||||
|
|
||||
|
**Returns**: |
||||
|
|
||||
|
*Union[List[numpy.ndarray], numpy.ndarray]* |
||||
|
|
||||
|
​ If input is a sentence in string, then it returns an embedding vector of shape (dim,) in numpy.ndarray. |
||||
|
If input is a list of sentences, then it returns a list of embedding vectors, each of which a numpy.ndarray in shape of (dim,). |
||||
|
|
||||
|
<br/> |
||||
|
|
||||
|
***supported_model_names(format=None)*** |
||||
|
|
||||
|
Get a list of all supported model names or supported model names for specified model format. |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***format***: *str* |
||||
|
|
||||
|
​ The model format such as 'pytorch', defaults to None. |
||||
|
If None, it will return a full list of supported model names. |
||||
|
|
||||
|
```python |
||||
|
from towhee import ops |
||||
|
|
||||
|
op = ops.sentence_embedding.sentence_transformers().get_op() |
||||
|
full_list = op.supported_model_names() |
||||
|
onnx_list = op.supported_model_names(format='onnx') |
||||
|
``` |
||||
|
@ -0,0 +1,19 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .s_bert import STransformers |
||||
|
|
||||
|
|
||||
|
def sbert(*args, **kwargs): |
||||
|
return STransformers(*args, **kwargs) |
@ -0,0 +1,2 @@ |
|||||
|
sentence_transformers |
||||
|
torch |
After Width: | Height: | Size: 6.0 KiB |
@ -0,0 +1,221 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import logging |
||||
|
import numpy |
||||
|
from typing import Union, List |
||||
|
from pathlib import Path |
||||
|
|
||||
|
import torch |
||||
|
from sentence_transformers import SentenceTransformer |
||||
|
|
||||
|
from towhee.operator import NNOperator |
||||
|
# from towhee.dc2 import accelerate |
||||
|
|
||||
|
import os |
||||
|
import warnings |
||||
|
|
||||
|
warnings.filterwarnings('ignore') |
||||
|
logging.getLogger('sentence_transformers').setLevel(logging.ERROR) |
||||
|
log = logging.getLogger('op_sbert') |
||||
|
|
||||
|
|
||||
|
class ConvertModel(torch.nn.Module): |
||||
|
def __init__(self, model): |
||||
|
super().__init__() |
||||
|
self.net = model |
||||
|
try: |
||||
|
self.input_names = self.net.tokenizer.model_input_names |
||||
|
except AttributeError: |
||||
|
self.input_names = list(self.net.tokenize(['test']).keys()) |
||||
|
|
||||
|
def forward(self, *args, **kwargs): |
||||
|
if args: |
||||
|
assert kwargs == {}, 'Only accept neither args or kwargs as inputs.' |
||||
|
assert len(args) == len(self.input_names) |
||||
|
for k, v in zip(self.input_names, args): |
||||
|
kwargs[k] = v |
||||
|
outs = self.net(kwargs) |
||||
|
return outs['sentence_embedding'] |
||||
|
|
||||
|
|
||||
|
# @accelerate |
||||
|
class Model: |
||||
|
def __init__(self, model): |
||||
|
self.model = model |
||||
|
|
||||
|
def __call__(self, **features): |
||||
|
outs = self.model(features) |
||||
|
return outs['sentence_embedding'] |
||||
|
|
||||
|
|
||||
|
class STransformers(NNOperator): |
||||
|
""" |
||||
|
Operator using pretrained Sentence Transformers |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str = None, device: str = None): |
||||
|
self.model_name = model_name |
||||
|
if device: |
||||
|
self.device = device |
||||
|
else: |
||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
||||
|
if self.model_name: |
||||
|
self.model = Model(self._model) |
||||
|
else: |
||||
|
log.warning('The operator is initialized without specified model.') |
||||
|
pass |
||||
|
|
||||
|
def __call__(self, txt: Union[List[str], str]): |
||||
|
if isinstance(txt, str): |
||||
|
sentences = [txt] |
||||
|
else: |
||||
|
sentences = txt |
||||
|
inputs = self.tokenize(sentences) |
||||
|
embs = self.model(**inputs).cpu().detach().numpy() |
||||
|
if isinstance(txt, str): |
||||
|
embs = embs.squeeze(0) |
||||
|
else: |
||||
|
embs = list(embs) |
||||
|
return embs |
||||
|
|
||||
|
@property |
||||
|
def _model(self): |
||||
|
m = SentenceTransformer(model_name_or_path=self.model_name, device=self.device) |
||||
|
m.eval() |
||||
|
return m |
||||
|
|
||||
|
@property |
||||
|
def supported_formats(self): |
||||
|
return ['onnx'] |
||||
|
|
||||
|
def tokenize(self, x): |
||||
|
try: |
||||
|
outs = self._model.tokenize(x) |
||||
|
except Exception: |
||||
|
from transformers import AutoTokenizer |
||||
|
try: |
||||
|
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name) |
||||
|
except Exception as e: |
||||
|
log.error(e) |
||||
|
log.warning(f'Fail to load tokenizer with sentence-transformers/{self.model_name}.' |
||||
|
f'Trying to load tokenizer with self.model_name...') |
||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
||||
|
outs = tokenizer( |
||||
|
x, |
||||
|
padding=True, truncation='longest_first', max_length=self.max_seq_length, |
||||
|
return_tensors='pt', |
||||
|
) |
||||
|
return outs |
||||
|
|
||||
|
@property |
||||
|
def max_seq_length(self): |
||||
|
import json |
||||
|
from torch.hub import _get_torch_home |
||||
|
torch_cache = _get_torch_home() |
||||
|
sbert_cache = os.path.join(torch_cache, 'sentence_transformers') |
||||
|
cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json') |
||||
|
if not os.path.exists(cfg_path): |
||||
|
cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json') |
||||
|
k = 'max_position_embeddings' |
||||
|
else: |
||||
|
k = 'max_seq_length' |
||||
|
with open(cfg_path) as f: |
||||
|
cfg = json.load(f) |
||||
|
if k in cfg: |
||||
|
max_seq_len = cfg[k] |
||||
|
else: |
||||
|
max_seq_len = None |
||||
|
return max_seq_len |
||||
|
|
||||
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
||||
|
if path == 'default': |
||||
|
path = str(Path(__file__).parent) |
||||
|
path = os.path.join(path, 'saved', format) |
||||
|
os.makedirs(path, exist_ok=True) |
||||
|
name = self.model_name.replace('/', '-') |
||||
|
path = os.path.join(path, name) |
||||
|
if format in ['pytorch', 'torchscript']: |
||||
|
path = path + '.pt' |
||||
|
elif format == 'onnx': |
||||
|
path = path + '.onnx' |
||||
|
else: |
||||
|
raise AttributeError(f'Invalid format {format}.') |
||||
|
dummy_text = ['[CLS]'] |
||||
|
dummy_input = self.tokenize(dummy_text) |
||||
|
if format == 'pytorch': |
||||
|
torch.save(self._model, path) |
||||
|
elif format == 'torchscript': |
||||
|
try: |
||||
|
try: |
||||
|
jit_model = torch.jit.script(self._model) |
||||
|
except Exception: |
||||
|
jit_model = torch.jit.trace(self._model, dummy_input, strict=False) |
||||
|
torch.jit.save(jit_model, path) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to save as torchscript: {e}.') |
||||
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
||||
|
elif format == 'onnx': |
||||
|
new_model = ConvertModel(self._model) |
||||
|
input_names = list(dummy_input.keys()) |
||||
|
dynamic_axes = {} |
||||
|
for i_n, i_v in dummy_input.items(): |
||||
|
if len(i_v.shape) == 1: |
||||
|
dynamic_axes[i_n] = {0: 'batch_size'} |
||||
|
else: |
||||
|
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} |
||||
|
dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'} |
||||
|
try: |
||||
|
torch.onnx.export(new_model, |
||||
|
tuple(dummy_input.values()), |
||||
|
path, |
||||
|
input_names=input_names, |
||||
|
output_names=['output_0'], |
||||
|
opset_version=13, |
||||
|
dynamic_axes=dynamic_axes, |
||||
|
do_constant_folding=True |
||||
|
) |
||||
|
except Exception as e: |
||||
|
log.error(f'Fail to save as onnx: {e}.') |
||||
|
raise RuntimeError(f'Fail to save as onnx: {e}.') |
||||
|
# todo: elif format == 'tensorrt': |
||||
|
else: |
||||
|
log.error(f'Unsupported format "{format}".') |
||||
|
return Path(path).resolve() |
||||
|
|
||||
|
@staticmethod |
||||
|
def supported_model_names(format: str = None): |
||||
|
import requests |
||||
|
req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html") |
||||
|
data = req.text |
||||
|
full_list = [] |
||||
|
for line in data.split('\r\n'): |
||||
|
line = line.replace(' ', '') |
||||
|
if line.startswith('"name":'): |
||||
|
name = line.split(':')[-1].replace('"', '').replace(',', '') |
||||
|
full_list.append(name) |
||||
|
full_list.sort() |
||||
|
if format is None: |
||||
|
model_list = full_list |
||||
|
elif format == 'pytorch': |
||||
|
to_remove = [] |
||||
|
assert set(to_remove).issubset(set(full_list)) |
||||
|
model_list = list(set(full_list) - set(to_remove)) |
||||
|
elif format == 'onnx': |
||||
|
to_remove = [] |
||||
|
assert set(to_remove).issubset(set(full_list)) |
||||
|
model_list = list(set(full_list) - set(to_remove)) |
||||
|
else: |
||||
|
log.error(f'Invalid or unsupported format "{format}".') |
||||
|
return model_list |
@ -0,0 +1,103 @@ |
|||||
|
from towhee import ops |
||||
|
import numpy |
||||
|
import onnx |
||||
|
import onnxruntime |
||||
|
|
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
import logging |
||||
|
import platform |
||||
|
import psutil |
||||
|
|
||||
|
op = ops.sentence_embedding.sbert().get_op() |
||||
|
# full_models = op.supported_model_names() |
||||
|
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
||||
|
# models = [x for x in full_models if x not in checked_models] |
||||
|
models = ['all-MiniLM-L12-v2'] |
||||
|
test_txt = 'hello, world.' |
||||
|
atol = 1e-3 |
||||
|
log_path = 'sbert.log' |
||||
|
f = open('onnx.csv', 'w+') |
||||
|
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') |
||||
|
|
||||
|
logger = logging.getLogger('sbert_onnx') |
||||
|
logger.setLevel(logging.DEBUG) |
||||
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
||||
|
fh = logging.FileHandler(log_path) |
||||
|
fh.setLevel(logging.DEBUG) |
||||
|
fh.setFormatter(formatter) |
||||
|
logger.addHandler(fh) |
||||
|
ch = logging.StreamHandler() |
||||
|
ch.setLevel(logging.ERROR) |
||||
|
ch.setFormatter(formatter) |
||||
|
logger.addHandler(ch) |
||||
|
|
||||
|
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
||||
|
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' |
||||
|
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' |
||||
|
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') |
||||
|
logger.debug(f'cpu: {psutil.cpu_count()}') |
||||
|
|
||||
|
|
||||
|
status = None |
||||
|
for name in models: |
||||
|
logger.info(f'***{name}***') |
||||
|
saved_name = name.replace('/', '-') |
||||
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
||||
|
if status: |
||||
|
f.write(','.join(status) + '\n') |
||||
|
status = [name] + ['fail'] * 5 |
||||
|
try: |
||||
|
op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op() |
||||
|
out1 = op(test_txt) |
||||
|
logger.info('OP LOADED.') |
||||
|
status[1] = 'success' |
||||
|
except Exception as e: |
||||
|
logger.error(f'FAIL TO LOAD OP: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
op.save_model('onnx') |
||||
|
logger.info('ONNX SAVED.') |
||||
|
status[2] = 'success' |
||||
|
except Exception as e: |
||||
|
logger.error(f'FAIL TO SAVE ONNX: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
try: |
||||
|
onnx_model = onnx.load(onnx_path) |
||||
|
onnx.checker.check_model(onnx_model) |
||||
|
except Exception: |
||||
|
saved_onnx = onnx.load(onnx_path, load_external_data=False) |
||||
|
onnx.checker.check_model(saved_onnx) |
||||
|
logger.info('ONNX CHECKED.') |
||||
|
status[3] = 'success' |
||||
|
except Exception as e: |
||||
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
||||
|
pass |
||||
|
try: |
||||
|
inputs = op._model.tokenize([test_txt]) |
||||
|
sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) |
||||
|
onnx_inputs = {} |
||||
|
for n in sess.get_inputs(): |
||||
|
k = n.name |
||||
|
if k in inputs: |
||||
|
onnx_inputs[k] = inputs[k].cpu().detach().numpy() |
||||
|
out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0) |
||||
|
logger.info('ONNX WORKED.') |
||||
|
status[4] = 'success' |
||||
|
if numpy.allclose(out1, out2, atol=atol): |
||||
|
logger.info('Check accuracy: OK') |
||||
|
status[5] = 'success' |
||||
|
else: |
||||
|
logger.info(f'Check accuracy: atol is larger than {atol}.') |
||||
|
except Exception as e: |
||||
|
logger.error(f'FAIL TO RUN ONNX: {e}') |
||||
|
continue |
||||
|
|
||||
|
if status: |
||||
|
f.write(','.join(status) + '\n') |
||||
|
|
||||
|
print('Finished.') |
||||
|
|
||||
|
|
||||
|
|
Loading…
Reference in new issue