logo
Browse Source

Support onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
786e263a01
  1. 149
      s_bert.py
  2. 98
      test_onnx.py

149
s_bert.py

@ -28,7 +28,38 @@ import warnings
warnings.filterwarnings('ignore')
# Quiet the sentence_transformers library's own logging; the operator logs
# through its own named logger below.
logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
# Fix: removed the stale duplicate assignment to the 'op_sbert' logger that
# was immediately overwritten by this one.
log = logging.getLogger('op_s_transformers')
class ConvertModel(torch.nn.Module):
    """Adapter that makes a sentence-transformers model exportable to ONNX.

    ONNX export feeds inputs positionally, while the wrapped model expects a
    single dict of named tensors. This wrapper maps positional tensors back to
    the tokenizer's input names and returns only the sentence embedding.
    """

    def __init__(self, model):
        super().__init__()
        self.net = model
        # Prefer the tokenizer's declared input names; otherwise probe the
        # model's own tokenize() with a dummy sentence to discover them.
        try:
            self.input_names = self.net.tokenizer.model_input_names
        except AttributeError:
            self.input_names = list(self.net.tokenize(['test']).keys())

    def forward(self, *args, **kwargs):
        """Run the wrapped model; accepts all-positional OR all-keyword inputs."""
        if args:
            # Fix: the original message read "Only accept neither args or
            # kwargs", which states the opposite of the intended contract.
            assert kwargs == {}, 'Only accept either args or kwargs as inputs.'
            assert len(args) == len(self.input_names), \
                f'Expect {len(self.input_names)} inputs, got {len(args)}.'
            for k, v in zip(self.input_names, args):
                kwargs[k] = v
        outs = self.net(kwargs)
        return outs['sentence_embedding']
# @accelerate
class Model:
    """Thin callable wrapper holding a SentenceTransformer in eval mode."""

    def __init__(self, model_name, device):
        self.device = device
        self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
        self.model.eval()

    def __call__(self, **features):
        # Forward the named tensors as one dict and keep only the embedding.
        output = self.model(features)
        embedding = output['sentence_embedding']
        return embedding
class STransformers(NNOperator):
@ -38,12 +69,11 @@ class STransformers(NNOperator):
def __init__(self, model_name: str = None, device: str = None):
self.model_name = model_name
if device:
self.device = device
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
if self.model_name:
self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
self.model = Model(model_name=self.model_name, device=self.device)
else:
log.warning('The operator is initialized without specified model.')
pass
@ -53,13 +83,113 @@ class STransformers(NNOperator):
sentences = [txt]
else:
sentences = txt
embs = self.model.encode(sentences) # return numpy.ndarray
inputs = self.tokenize(sentences)
embs = self.model(**inputs).cpu().detach().numpy()
if isinstance(txt, str):
embs = embs.squeeze(0)
else:
embs = list(embs)
return embs
@property
def supported_formats(self):
return ['onnx']
def tokenize(self, x):
try:
outs = self._model.tokenize(x)
except Exception:
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
except Exception:
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
outs = tokenizer(
x,
padding=True, truncation='longest_first', max_length=self.max_seq_length,
return_tensors='pt',
)
return outs
@property
def max_seq_length(self):
import json
from torch.hub import _get_torch_home
torch_cache = _get_torch_home()
sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
if not os.path.exists(cfg_path):
cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
k = 'max_position_embeddings'
else:
k = 'max_seq_length'
with open(cfg_path) as f:
cfg = json.load(f)
if k in cfg:
max_seq_len = cfg[k]
else:
max_seq_len = None
return max_seq_len
@property
def _model(self):
return self.model.model
    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Export the underlying model to disk.

        Args:
            format: one of 'pytorch', 'torchscript', 'onnx'.
            path: target directory; 'default' resolves to
                <this file's dir>/saved/<format>/.

        Returns:
            Resolved Path of the written model file.

        Raises:
            AttributeError: if format is not one of the supported values.
            RuntimeError: if torchscript or onnx serialization fails.
        """
        if path == 'default':
            path = str(Path(__file__).parent)
        # NOTE(review): indentation reconstructed from a flattened view —
        # assumes only the line above is guarded by the 'default' check.
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        # '/' in model names would create nested dirs; flatten to '-'.
        name = self.model_name.replace('/', '-')
        path = os.path.join(path, name)
        if format in ['pytorch', 'torchscript']:
            path = path + '.pt'
        elif format == 'onnx':
            path = path + '.onnx'
        else:
            raise AttributeError(f'Invalid format {format}.')
        # Minimal tokenized input used both for jit tracing and ONNX export.
        dummy_text = ['[CLS]']
        dummy_input = self.tokenize(dummy_text)
        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'torchscript':
            try:
                # Prefer scripting (shape-generic); fall back to tracing with
                # strict=False so dict outputs are allowed.
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.')
        elif format == 'onnx':
            # ConvertModel maps ONNX's positional inputs back to named tensors.
            new_model = ConvertModel(self._model)
            input_names = list(dummy_input.keys())
            # Mark batch (and sequence, for 2-D inputs) dims as dynamic so the
            # exported graph accepts variable batch sizes / lengths.
            dynamic_axes = {}
            for i_n, i_v in dummy_input.items():
                if len(i_v.shape) == 1:
                    dynamic_axes[i_n] = {0: 'batch_size'}
                else:
                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
            try:
                torch.onnx.export(new_model,
                                  tuple(dummy_input.values()),
                                  path,
                                  input_names=input_names,
                                  output_names=['output_0'],
                                  opset_version=13,
                                  dynamic_axes=dynamic_axes,
                                  do_constant_folding=True
                                  )
            except Exception as e:
                log.error(f'Fail to save as onnx: {e}.')
                raise RuntimeError(f'Fail to save as onnx: {e}.')
        # todo: elif format == 'tensorrt':
        else:
            # Unreachable: invalid formats already raised AttributeError above.
            log.error(f'Unsupported format "{format}".')
        return Path(path).resolve()
@staticmethod
def supported_model_names(format: str = None):
import requests
@ -78,7 +208,10 @@ class STransformers(NNOperator):
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
else:
raise ValueError(f'Invalid or unsupported format "{format}".')
log.error(f'Invalid or unsupported format "{format}".')
return model_list

98
test_onnx.py

@ -0,0 +1,98 @@
from towhee import ops
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil

# Smoke-test script: for each model, load the operator, export to ONNX,
# validate the graph, run it with onnxruntime and compare embeddings against
# the eager output. Results are appended as rows of onnx.csv.
op = ops.sentence_embedding.sbert().get_op()
# full_models = op.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['all-MiniLM-L12-v2']

test_txt = 'hello, world.'
atol = 1e-3
log_path = 'sbert_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

# File handler records everything; console only shows errors.
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

# Record host info so failures can be correlated with resources.
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')

status = None
for name in models:
    logger.info(f'***{name}***')
    # Flush the previous model's row before starting a new one.
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5
    try:
        op = ops.text_embedding.sentence_transformers(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
    try:
        onnx_path = str(op.save_model('onnx'))
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Retry without external data for models saved in split files.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        # Best-effort: a check failure should not block the run attempt below.
        logger.error(f'FAIL TO CHECK ONNX: {e}')
        pass
    try:
        inputs = op.tokenize([test_txt])
        sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
        # Feed only the tensors the session actually declares as inputs.
        onnx_inputs = {}
        for n in sess.get_inputs():
            k = n.name
            if k in inputs:
                onnx_inputs[k] = inputs[k].cpu().detach().numpy()
        out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue

# Write the last model's row, then release the CSV handle.
if status:
    f.write(','.join(status) + '\n')
f.close()  # fix: the file was previously never closed/flushed
print('Finished.')
Loading…
Cancel
Save