From c042d6c39de670982294e917dcdaa30c688cc694 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Fri, 13 Jan 2023 14:04:56 +0800
Subject: [PATCH] Remove ONNX & accelerate

Signed-off-by: Jael Gu
---
 README.md    |   1 -
 s_bert.py    | 143 ++------------------------------------------------
 test_onnx.py | 103 -------------------------------
 3 files changed, 3 insertions(+), 244 deletions(-)
 delete mode 100644 test_onnx.py

diff --git a/README.md b/README.md
index d4ae5e5..74c9ebf 100644
--- a/README.md
+++ b/README.md
@@ -111,5 +111,4 @@ from towhee import ops
 
 op = ops.sentence_embedding.sentence_transformers().get_op()
 full_list = op.supported_model_names()
-onnx_list = op.supported_model_names(format='onnx')
 ```
diff --git a/s_bert.py b/s_bert.py
index 4a42326..ee31d9f 100644
--- a/s_bert.py
+++ b/s_bert.py
@@ -31,35 +31,6 @@
 logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
 log = logging.getLogger('op_sbert')
 
-
-class ConvertModel(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.net = model
-        try:
-            self.input_names = self.net.tokenizer.model_input_names
-        except AttributeError:
-            self.input_names = list(self.net.tokenize(['test']).keys())
-
-    def forward(self, *args, **kwargs):
-        if args:
-            assert kwargs == {}, 'Only accept neither args or kwargs as inputs.'
-            assert len(args) == len(self.input_names)
-            for k, v in zip(self.input_names, args):
-                kwargs[k] = v
-        outs = self.net(kwargs)
-        return outs['sentence_embedding']
-
-
-# @accelerate
-class Model:
-    def __init__(self, model):
-        self.model = model
-
-    def __call__(self, **features):
-        outs = self.model(features)
-        return outs['sentence_embedding']
-
 class STransformers(NNOperator):
     """
     Operator using pretrained Sentence Transformers
@@ -72,7 +43,7 @@ class STransformers(NNOperator):
         else:
             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if self.model_name:
-            self.model = Model(self._model)
+            self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
             pass
@@ -82,118 +53,13 @@ class STransformers(NNOperator):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-        embs = self.model(**inputs).cpu().detach().numpy()
+        embs = self.model.encode(sentences)  # return numpy.ndarray
         if isinstance(txt, str):
             embs = embs.squeeze(0)
         else:
             embs = list(embs)
         return embs
 
-    @property
-    def _model(self):
-        m = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
-        m.eval()
-        return m
-
-    @property
-    def supported_formats(self):
-        return ['onnx']
-
-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
-            try:
-                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
-            except Exception as e:
-                log.error(e)
-                log.warning(f'Fail to load tokenizer with sentence-transformers/{self.model_name}.'
-                            f'Trying to load tokenizer with self.model_name...')
-                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-        if k in cfg:
-            max_seq_len = cfg[k]
-        else:
-            max_seq_len = None
-        return max_seq_len
-
-    def save_model(self, format: str = 'pytorch', path: str = 'default'):
-        if path == 'default':
-            path = str(Path(__file__).parent)
-        path = os.path.join(path, 'saved', format)
-        os.makedirs(path, exist_ok=True)
-        name = self.model_name.replace('/', '-')
-        path = os.path.join(path, name)
-        if format in ['pytorch', 'torchscript']:
-            path = path + '.pt'
-        elif format == 'onnx':
-            path = path + '.onnx'
-        else:
-            raise AttributeError(f'Invalid format {format}.')
-        dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
-        if format == 'pytorch':
-            torch.save(self._model, path)
-        elif format == 'torchscript':
-            try:
-                try:
-                    jit_model = torch.jit.script(self._model)
-                except Exception:
-                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
-                torch.jit.save(jit_model, path)
-            except Exception as e:
-                log.error(f'Fail to save as torchscript: {e}.')
-                raise RuntimeError(f'Fail to save as torchscript: {e}.')
-        elif format == 'onnx':
-            new_model = ConvertModel(self._model)
-            input_names = list(dummy_input.keys())
-            dynamic_axes = {}
-            for i_n, i_v in dummy_input.items():
-                if len(i_v.shape) == 1:
-                    dynamic_axes[i_n] = {0: 'batch_size'}
-                else:
-                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
-            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
-            try:
-                torch.onnx.export(new_model,
-                                  tuple(dummy_input.values()),
-                                  path,
-                                  input_names=input_names,
-                                  output_names=['output_0'],
-                                  opset_version=13,
-                                  dynamic_axes=dynamic_axes,
-                                  do_constant_folding=True
-                                  )
-            except Exception as e:
-                log.error(f'Fail to save as onnx: {e}.')
-                raise RuntimeError(f'Fail to save as onnx: {e}.')
-        # todo: elif format == 'tensorrt':
-        else:
-            log.error(f'Unsupported format "{format}".')
-        return Path(path).resolve()
-
     @staticmethod
     def supported_model_names(format: str = None):
         import requests
@@ -212,10 +78,7 @@ class STransformers(NNOperator):
             to_remove = []
             assert set(to_remove).issubset(set(full_list))
             model_list = list(set(full_list) - set(to_remove))
-        elif format == 'onnx':
-            to_remove = []
-            assert set(to_remove).issubset(set(full_list))
-            model_list = list(set(full_list) - set(to_remove))
         else:
             log.error(f'Invalid or unsupported format "{format}".')
+            raise ValueError(f'Invalid or unsupported format "{format}".')
         return model_list
diff --git a/test_onnx.py b/test_onnx.py
deleted file mode 100644
index 245aec8..0000000
--- a/test_onnx.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from towhee import ops
-import numpy
-import onnx
-import onnxruntime
-
-import os
-from pathlib import Path
-import logging
-import platform
-import psutil
-
-op = ops.sentence_embedding.sbert().get_op()
-# full_models = op.supported_model_names()
-# checked_models = AutoTransformers.supported_model_names(format='onnx')
-# models = [x for x in full_models if x not in checked_models]
-models = ['all-MiniLM-L12-v2']
-test_txt = 'hello, world.'
-atol = 1e-3
-log_path = 'sbert.log'
-f = open('onnx.csv', 'w+')
-f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
-
-logger = logging.getLogger('sbert_onnx')
-logger.setLevel(logging.DEBUG)
-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-fh = logging.FileHandler(log_path)
-fh.setLevel(logging.DEBUG)
-fh.setFormatter(formatter)
-logger.addHandler(fh)
-ch = logging.StreamHandler()
-ch.setLevel(logging.ERROR)
-ch.setFormatter(formatter)
-logger.addHandler(ch)
-
-logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
-logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
-             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
-             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
-logger.debug(f'cpu: {psutil.cpu_count()}')
-
-
-status = None
-for name in models:
-    logger.info(f'***{name}***')
-    saved_name = name.replace('/', '-')
-    onnx_path = f'saved/onnx/{saved_name}.onnx'
-    if status:
-        f.write(','.join(status) + '\n')
-    status = [name] + ['fail'] * 5
-    try:
-        op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op()
-        out1 = op(test_txt)
-        logger.info('OP LOADED.')
-        status[1] = 'success'
-    except Exception as e:
-        logger.error(f'FAIL TO LOAD OP: {e}')
-        continue
-    try:
-        op.save_model('onnx')
-        logger.info('ONNX SAVED.')
-        status[2] = 'success'
-    except Exception as e:
-        logger.error(f'FAIL TO SAVE ONNX: {e}')
-        continue
-    try:
-        try:
-            onnx_model = onnx.load(onnx_path)
-            onnx.checker.check_model(onnx_model)
-        except Exception:
-            saved_onnx = onnx.load(onnx_path, load_external_data=False)
-            onnx.checker.check_model(saved_onnx)
-        logger.info('ONNX CHECKED.')
-        status[3] = 'success'
-    except Exception as e:
-        logger.error(f'FAIL TO CHECK ONNX: {e}')
-        pass
-    try:
-        inputs = op._model.tokenize([test_txt])
-        sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
-        onnx_inputs = {}
-        for n in sess.get_inputs():
-            k = n.name
-            if k in inputs:
-                onnx_inputs[k] = inputs[k].cpu().detach().numpy()
-        out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
-        logger.info('ONNX WORKED.')
-        status[4] = 'success'
-        if numpy.allclose(out1, out2, atol=atol):
-            logger.info('Check accuracy: OK')
-            status[5] = 'success'
-        else:
-            logger.info(f'Check accuracy: atol is larger than {atol}.')
-    except Exception as e:
-        logger.error(f'FAIL TO RUN ONNX: {e}')
-        continue
-
-if status:
-    f.write(','.join(status) + '\n')
-
-print('Finished.')
-
-
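
Post-patch usage sketch (not part of the diff): with the ONNX export path and
the accelerate wrapper removed, the operator holds a SentenceTransformer
directly and embeds via model.encode(). A minimal example, assuming the op
alias from the README snippet above and the 'all-MiniLM-L12-v2' checkpoint
that the removed test script used; adjust both to your towhee setup:

    from towhee import ops

    # get_op() returns the underlying STransformers operator instance;
    # model_name and device kwargs follow the removed test_onnx.py call.
    op = ops.sentence_embedding.sentence_transformers(
        model_name='all-MiniLM-L12-v2', device='cpu').get_op()

    emb = op('hello, world.')        # single string -> 1-D numpy.ndarray
    embs = op(['hello', 'world'])    # list of strings -> list of 1-D arrays

    # Listing models: format='onnx' is gone; an unsupported format now
    # logs an error and raises ValueError.
    full_list = op.supported_model_names()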