sbert
copied
6 changed files with 459 additions and 1 deletions
@ -1,2 +1,115 @@ |
|||
# sbert |
|||
# Sentence Embedding with Sentence Transformers |
|||
|
|||
*author: [Jael Gu](https://github.com/jaelgu)* |
|||
|
|||
<br /> |
|||
|
|||
## Description |
|||
|
|||
This operator takes a sentence or a list of sentences in string as input. |
|||
It generates an embedding vector in numpy.ndarray for each sentence, which captures the input sentence's core semantic elements. |
|||
This operator is implemented with pre-trained models from [Sentence Transformers](https://www.sbert.net/). |
|||
|
|||
<br /> |
|||
|
|||
## Code Example |
|||
|
|||
Use the pre-trained model "all-MiniLM-L12-v2" |
|||
to generate a text embedding for the sentence "This is a sentence.". |
|||
|
|||
*Write a same pipeline with explicit inputs/outputs name specifications:* |
|||
|
|||
- **option 1 (towhee>=0.9.0):** |
|||
```python |
|||
from towhee.dc2 import pipe, ops, DataCollection |
|||
|
|||
p = ( |
|||
pipe.input('sentence') |
|||
.map('sentence', 'vec', ops.sentence_embedding.sbert(model_name='all-MiniLM-L12-v2')) |
|||
.output('sentence', 'vec') |
|||
) |
|||
|
|||
DataCollection(p('This is a sentence.')).show() |
|||
``` |
|||
|
|||
<img src="./result.png" width="800px"/> |
|||
|
|||
- **option 2:** |
|||
|
|||
```python |
|||
import towhee |
|||
|
|||
( |
|||
towhee.dc['sentence'](['This is a sentence.']) |
|||
.sentence_embedding.sbert['sentence', 'vec'](model_name='all-MiniLM-L12-v2') |
|||
.show() |
|||
) |
|||
``` |
|||
|
|||
<br /> |
|||
|
|||
## Factory Constructor |
|||
|
|||
Create the operator via the following factory method: |
|||
|
|||
***text_embedding.sbert(model_name='all-MiniLM-L12-v2')*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***model_name***: *str* |
|||
|
|||
The model name in string. Supported model names: |
|||
|
|||
Refer to [SBert Doc](https://www.sbert.net/docs/pretrained_models.html). |
|||
Please note that only models listed `supported_model_names` are tested. |
|||
You can refer to [Towhee Pipeline]() for model performance. |
|||
|
|||
***device***: *str* |
|||
|
|||
The device to run model, defaults to None. |
|||
If None, it will use 'cuda' automatically when cuda is available. |
|||
|
|||
<br /> |
|||
|
|||
## Interface |
|||
|
|||
The operator takes a sentence or a list of sentences in string as input. |
|||
It loads tokenizer and pre-trained model using model name, |
|||
and then returns text embedding in numpy.ndarray. |
|||
|
|||
***__call__(txt)*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***txt***: *Union[List[str], str]* |
|||
|
|||
A sentence or a list of sentences in string. |
|||
|
|||
|
|||
**Returns**: |
|||
|
|||
*Union[List[numpy.ndarray], numpy.ndarray]* |
|||
|
|||
If input is a sentence in string, then it returns an embedding vector of shape (dim,) in numpy.ndarray. |
|||
If input is a list of sentences, then it returns a list of embedding vectors, each of which a numpy.ndarray in shape of (dim,). |
|||
|
|||
<br/> |
|||
|
|||
***supported_model_names(format=None)*** |
|||
|
|||
Get a list of all supported model names or supported model names for specified model format. |
|||
|
|||
**Parameters:** |
|||
|
|||
***format***: *str* |
|||
|
|||
The model format such as 'pytorch', defaults to None. |
|||
If None, it will return a full list of supported model names. |
|||
|
|||
```python |
|||
from towhee import ops |
|||
|
|||
op = ops.sentence_embedding.sentence_transformers().get_op() |
|||
full_list = op.supported_model_names() |
|||
onnx_list = op.supported_model_names(format='onnx') |
|||
``` |
|||
|
@ -0,0 +1,19 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .s_bert import STransformers |
|||
|
|||
|
|||
def sbert(*args, **kwargs): |
|||
return STransformers(*args, **kwargs) |
@ -0,0 +1,2 @@ |
|||
sentence_transformers |
|||
torch |
After Width: | Height: | Size: 6.0 KiB |
@ -0,0 +1,221 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import logging |
|||
import numpy |
|||
from typing import Union, List |
|||
from pathlib import Path |
|||
|
|||
import torch |
|||
from sentence_transformers import SentenceTransformer |
|||
|
|||
from towhee.operator import NNOperator |
|||
# from towhee.dc2 import accelerate |
|||
|
|||
import os |
|||
import warnings |
|||
|
|||
warnings.filterwarnings('ignore') |
|||
logging.getLogger('sentence_transformers').setLevel(logging.ERROR) |
|||
log = logging.getLogger('op_sbert') |
|||
|
|||
|
|||
class ConvertModel(torch.nn.Module): |
|||
def __init__(self, model): |
|||
super().__init__() |
|||
self.net = model |
|||
try: |
|||
self.input_names = self.net.tokenizer.model_input_names |
|||
except AttributeError: |
|||
self.input_names = list(self.net.tokenize(['test']).keys()) |
|||
|
|||
def forward(self, *args, **kwargs): |
|||
if args: |
|||
assert kwargs == {}, 'Only accept neither args or kwargs as inputs.' |
|||
assert len(args) == len(self.input_names) |
|||
for k, v in zip(self.input_names, args): |
|||
kwargs[k] = v |
|||
outs = self.net(kwargs) |
|||
return outs['sentence_embedding'] |
|||
|
|||
|
|||
# @accelerate |
|||
class Model: |
|||
def __init__(self, model): |
|||
self.model = model |
|||
|
|||
def __call__(self, **features): |
|||
outs = self.model(features) |
|||
return outs['sentence_embedding'] |
|||
|
|||
|
|||
class STransformers(NNOperator): |
|||
""" |
|||
Operator using pretrained Sentence Transformers |
|||
""" |
|||
|
|||
def __init__(self, model_name: str = None, device: str = None): |
|||
self.model_name = model_name |
|||
if device: |
|||
self.device = device |
|||
else: |
|||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|||
if self.model_name: |
|||
self.model = Model(self._model) |
|||
else: |
|||
log.warning('The operator is initialized without specified model.') |
|||
pass |
|||
|
|||
def __call__(self, txt: Union[List[str], str]): |
|||
if isinstance(txt, str): |
|||
sentences = [txt] |
|||
else: |
|||
sentences = txt |
|||
inputs = self.tokenize(sentences) |
|||
embs = self.model(**inputs).cpu().detach().numpy() |
|||
if isinstance(txt, str): |
|||
embs = embs.squeeze(0) |
|||
else: |
|||
embs = list(embs) |
|||
return embs |
|||
|
|||
@property |
|||
def _model(self): |
|||
m = SentenceTransformer(model_name_or_path=self.model_name, device=self.device) |
|||
m.eval() |
|||
return m |
|||
|
|||
@property |
|||
def supported_formats(self): |
|||
return ['onnx'] |
|||
|
|||
def tokenize(self, x): |
|||
try: |
|||
outs = self._model.tokenize(x) |
|||
except Exception: |
|||
from transformers import AutoTokenizer |
|||
try: |
|||
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name) |
|||
except Exception as e: |
|||
log.error(e) |
|||
log.warning(f'Fail to load tokenizer with sentence-transformers/{self.model_name}.' |
|||
f'Trying to load tokenizer with self.model_name...') |
|||
tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
|||
outs = tokenizer( |
|||
x, |
|||
padding=True, truncation='longest_first', max_length=self.max_seq_length, |
|||
return_tensors='pt', |
|||
) |
|||
return outs |
|||
|
|||
@property |
|||
def max_seq_length(self): |
|||
import json |
|||
from torch.hub import _get_torch_home |
|||
torch_cache = _get_torch_home() |
|||
sbert_cache = os.path.join(torch_cache, 'sentence_transformers') |
|||
cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json') |
|||
if not os.path.exists(cfg_path): |
|||
cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json') |
|||
k = 'max_position_embeddings' |
|||
else: |
|||
k = 'max_seq_length' |
|||
with open(cfg_path) as f: |
|||
cfg = json.load(f) |
|||
if k in cfg: |
|||
max_seq_len = cfg[k] |
|||
else: |
|||
max_seq_len = None |
|||
return max_seq_len |
|||
|
|||
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|||
if path == 'default': |
|||
path = str(Path(__file__).parent) |
|||
path = os.path.join(path, 'saved', format) |
|||
os.makedirs(path, exist_ok=True) |
|||
name = self.model_name.replace('/', '-') |
|||
path = os.path.join(path, name) |
|||
if format in ['pytorch', 'torchscript']: |
|||
path = path + '.pt' |
|||
elif format == 'onnx': |
|||
path = path + '.onnx' |
|||
else: |
|||
raise AttributeError(f'Invalid format {format}.') |
|||
dummy_text = ['[CLS]'] |
|||
dummy_input = self.tokenize(dummy_text) |
|||
if format == 'pytorch': |
|||
torch.save(self._model, path) |
|||
elif format == 'torchscript': |
|||
try: |
|||
try: |
|||
jit_model = torch.jit.script(self._model) |
|||
except Exception: |
|||
jit_model = torch.jit.trace(self._model, dummy_input, strict=False) |
|||
torch.jit.save(jit_model, path) |
|||
except Exception as e: |
|||
log.error(f'Fail to save as torchscript: {e}.') |
|||
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|||
elif format == 'onnx': |
|||
new_model = ConvertModel(self._model) |
|||
input_names = list(dummy_input.keys()) |
|||
dynamic_axes = {} |
|||
for i_n, i_v in dummy_input.items(): |
|||
if len(i_v.shape) == 1: |
|||
dynamic_axes[i_n] = {0: 'batch_size'} |
|||
else: |
|||
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} |
|||
dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'} |
|||
try: |
|||
torch.onnx.export(new_model, |
|||
tuple(dummy_input.values()), |
|||
path, |
|||
input_names=input_names, |
|||
output_names=['output_0'], |
|||
opset_version=13, |
|||
dynamic_axes=dynamic_axes, |
|||
do_constant_folding=True |
|||
) |
|||
except Exception as e: |
|||
log.error(f'Fail to save as onnx: {e}.') |
|||
raise RuntimeError(f'Fail to save as onnx: {e}.') |
|||
# todo: elif format == 'tensorrt': |
|||
else: |
|||
log.error(f'Unsupported format "{format}".') |
|||
return Path(path).resolve() |
|||
|
|||
@staticmethod |
|||
def supported_model_names(format: str = None): |
|||
import requests |
|||
req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html") |
|||
data = req.text |
|||
full_list = [] |
|||
for line in data.split('\r\n'): |
|||
line = line.replace(' ', '') |
|||
if line.startswith('"name":'): |
|||
name = line.split(':')[-1].replace('"', '').replace(',', '') |
|||
full_list.append(name) |
|||
full_list.sort() |
|||
if format is None: |
|||
model_list = full_list |
|||
elif format == 'pytorch': |
|||
to_remove = [] |
|||
assert set(to_remove).issubset(set(full_list)) |
|||
model_list = list(set(full_list) - set(to_remove)) |
|||
elif format == 'onnx': |
|||
to_remove = [] |
|||
assert set(to_remove).issubset(set(full_list)) |
|||
model_list = list(set(full_list) - set(to_remove)) |
|||
else: |
|||
log.error(f'Invalid or unsupported format "{format}".') |
|||
return model_list |
@ -0,0 +1,103 @@ |
|||
from towhee import ops |
|||
import numpy |
|||
import onnx |
|||
import onnxruntime |
|||
|
|||
import os |
|||
from pathlib import Path |
|||
import logging |
|||
import platform |
|||
import psutil |
|||
|
|||
op = ops.sentence_embedding.sbert().get_op() |
|||
# full_models = op.supported_model_names() |
|||
# checked_models = AutoTransformers.supported_model_names(format='onnx') |
|||
# models = [x for x in full_models if x not in checked_models] |
|||
models = ['all-MiniLM-L12-v2'] |
|||
test_txt = 'hello, world.' |
|||
atol = 1e-3 |
|||
log_path = 'sbert.log' |
|||
f = open('onnx.csv', 'w+') |
|||
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') |
|||
|
|||
logger = logging.getLogger('sbert_onnx') |
|||
logger.setLevel(logging.DEBUG) |
|||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|||
fh = logging.FileHandler(log_path) |
|||
fh.setLevel(logging.DEBUG) |
|||
fh.setFormatter(formatter) |
|||
logger.addHandler(fh) |
|||
ch = logging.StreamHandler() |
|||
ch.setLevel(logging.ERROR) |
|||
ch.setFormatter(formatter) |
|||
logger.addHandler(ch) |
|||
|
|||
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
|||
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' |
|||
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' |
|||
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') |
|||
logger.debug(f'cpu: {psutil.cpu_count()}') |
|||
|
|||
|
|||
status = None |
|||
for name in models: |
|||
logger.info(f'***{name}***') |
|||
saved_name = name.replace('/', '-') |
|||
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|||
if status: |
|||
f.write(','.join(status) + '\n') |
|||
status = [name] + ['fail'] * 5 |
|||
try: |
|||
op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op() |
|||
out1 = op(test_txt) |
|||
logger.info('OP LOADED.') |
|||
status[1] = 'success' |
|||
except Exception as e: |
|||
logger.error(f'FAIL TO LOAD OP: {e}') |
|||
continue |
|||
try: |
|||
op.save_model('onnx') |
|||
logger.info('ONNX SAVED.') |
|||
status[2] = 'success' |
|||
except Exception as e: |
|||
logger.error(f'FAIL TO SAVE ONNX: {e}') |
|||
continue |
|||
try: |
|||
try: |
|||
onnx_model = onnx.load(onnx_path) |
|||
onnx.checker.check_model(onnx_model) |
|||
except Exception: |
|||
saved_onnx = onnx.load(onnx_path, load_external_data=False) |
|||
onnx.checker.check_model(saved_onnx) |
|||
logger.info('ONNX CHECKED.') |
|||
status[3] = 'success' |
|||
except Exception as e: |
|||
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|||
pass |
|||
try: |
|||
inputs = op._model.tokenize([test_txt]) |
|||
sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) |
|||
onnx_inputs = {} |
|||
for n in sess.get_inputs(): |
|||
k = n.name |
|||
if k in inputs: |
|||
onnx_inputs[k] = inputs[k].cpu().detach().numpy() |
|||
out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0) |
|||
logger.info('ONNX WORKED.') |
|||
status[4] = 'success' |
|||
if numpy.allclose(out1, out2, atol=atol): |
|||
logger.info('Check accuracy: OK') |
|||
status[5] = 'success' |
|||
else: |
|||
logger.info(f'Check accuracy: atol is larger than {atol}.') |
|||
except Exception as e: |
|||
logger.error(f'FAIL TO RUN ONNX: {e}') |
|||
continue |
|||
|
|||
if status: |
|||
f.write(','.join(status) + '\n') |
|||
|
|||
print('Finished.') |
|||
|
|||
|
|||
|
Loading…
Reference in new issue