diff --git a/s_bert.py b/s_bert.py
index ee31d9f..5f221a5 100644
--- a/s_bert.py
+++ b/s_bert.py
@@ -28,7 +28,38 @@
 import warnings
 warnings.filterwarnings('ignore')
 logging.getLogger('sentence_transformers').setLevel(logging.ERROR)
-log = logging.getLogger('op_sbert')
+log = logging.getLogger('op_s_transformers')
+
+
+class ConvertModel(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.net = model
+        try:
+            self.input_names = self.net.tokenizer.model_input_names
+        except AttributeError:
+            self.input_names = list(self.net.tokenize(['test']).keys())
+
+    def forward(self, *args, **kwargs):
+        if args:
+            assert kwargs == {}, 'Only accept neither args or kwargs as inputs.'
+            assert len(args) == len(self.input_names)
+            for k, v in zip(self.input_names, args):
+                kwargs[k] = v
+        outs = self.net(kwargs)
+        return outs['sentence_embedding']
+
+
+# @accelerate
+class Model:
+    def __init__(self, model_name, device):
+        self.device = device
+        self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
+        self.model.eval()
+
+    def __call__(self, **features):
+        outs = self.model(features)
+        return outs['sentence_embedding']
 
 
 class STransformers(NNOperator):
@@ -38,12 +69,11 @@ class STransformers(NNOperator):
 
     def __init__(self, model_name: str = None, device: str = None):
         self.model_name = model_name
-        if device:
-            self.device = device
-        else:
-            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = device
         if self.model_name:
-            self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
+            self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
         pass
@@ -53,13 +83,113 @@ def __call__(self, txt):
             sentences = [txt]
         else:
             sentences = txt
-        embs = self.model.encode(sentences)  # return numpy.ndarray
+        inputs = self.tokenize(sentences)
+        embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
         else:
             embs = list(embs)
         return embs
 
+    @property
+    def supported_formats(self):
+        return ['onnx']
+
+    def tokenize(self, x):
+        try:
+            outs = self._model.tokenize(x)
+        except Exception:
+            from transformers import AutoTokenizer
+            try:
+                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+            except Exception:
+                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            outs = tokenizer(
+                x,
+                padding=True, truncation='longest_first', max_length=self.max_seq_length,
+                return_tensors='pt',
+            )
+        return outs
+
+    @property
+    def max_seq_length(self):
+        import json
+        from torch.hub import _get_torch_home
+        torch_cache = _get_torch_home()
+        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+        if not os.path.exists(cfg_path):
+            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+            k = 'max_position_embeddings'
+        else:
+            k = 'max_seq_length'
+        with open(cfg_path) as f:
+            cfg = json.load(f)
+        if k in cfg:
+            max_seq_len = cfg[k]
+        else:
+            max_seq_len = None
+        return max_seq_len
+
+    @property
+    def _model(self):
+        return self.model.model
+
+    def save_model(self, format: str = 'pytorch', path: str = 'default'):
+        if path == 'default':
+            path = str(Path(__file__).parent)
+        path = os.path.join(path, 'saved', format)
+        os.makedirs(path, exist_ok=True)
+        name = self.model_name.replace('/', '-')
+        path = os.path.join(path, name)
+        if format in ['pytorch', 'torchscript']:
+            path = path + '.pt'
+        elif format == 'onnx':
+            path = path + '.onnx'
+        else:
+            raise AttributeError(f'Invalid format {format}.')
+        dummy_text = ['[CLS]']
+        dummy_input = self.tokenize(dummy_text)
+        if format == 'pytorch':
+            torch.save(self._model, path)
+        elif format == 'torchscript':
+            try:
+                try:
+                    jit_model = torch.jit.script(self._model)
+                except Exception:
+                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
+                torch.jit.save(jit_model, path)
+            except Exception as e:
+                log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.')
+        elif format == 'onnx':
+            new_model = ConvertModel(self._model)
+            input_names = list(dummy_input.keys())
+            dynamic_axes = {}
+            for i_n, i_v in dummy_input.items():
+                if len(i_v.shape) == 1:
+                    dynamic_axes[i_n] = {0: 'batch_size'}
+                else:
+                    dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
+            dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
+            try:
+                torch.onnx.export(new_model,
+                                  tuple(dummy_input.values()),
+                                  path,
+                                  input_names=input_names,
+                                  output_names=['output_0'],
+                                  opset_version=13,
+                                  dynamic_axes=dynamic_axes,
+                                  do_constant_folding=True
+                                  )
+            except Exception as e:
+                log.error(f'Fail to save as onnx: {e}.')
+                raise RuntimeError(f'Fail to save as onnx: {e}.')
+        # todo: elif format == 'tensorrt':
+        else:
+            log.error(f'Unsupported format "{format}".')
+        return Path(path).resolve()
+
     @staticmethod
     def supported_model_names(format: str = None):
         import requests
@@ -78,7 +208,11 @@ class STransformers(NNOperator):
             to_remove = []
             assert set(to_remove).issubset(set(full_list))
             model_list = list(set(full_list) - set(to_remove))
+        elif format == 'onnx':
+            to_remove = []
+            assert set(to_remove).issubset(set(full_list))
+            model_list = list(set(full_list) - set(to_remove))
         else:
-            raise ValueError(f'Invalid or unsupported format "{format}".')
             log.error(f'Invalid or unsupported format "{format}".')
+            model_list = []
         return model_list
diff --git a/test_onnx.py b/test_onnx.py
new file mode 100644
index 0000000..5f259ac
--- /dev/null
+++ b/test_onnx.py
@@ -0,0 +1,98 @@
+from towhee import ops
+import numpy
+import onnx
+import onnxruntime
+
+import os
+from pathlib import Path
+import logging
+import platform
+import psutil
+
+op = ops.sentence_embedding.sbert().get_op()
+# full_models = op.supported_model_names()
+# checked_models = AutoTransformers.supported_model_names(format='onnx')
+# models = [x for x in full_models if x not in checked_models]
+models = ['all-MiniLM-L12-v2']
+test_txt = 'hello, world.'
+atol = 1e-3
+log_path = 'sbert_onnx.log'
+f = open('onnx.csv', 'w+')
+f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
+
+logger = logging.getLogger('transformers_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
+
+
+status = None
+for name in models:
+    logger.info(f'***{name}***')
+    if status:
+        f.write(','.join(status) + '\n')
+    status = [name] + ['fail'] * 5
+    try:
+        op = ops.text_embedding.sentence_transformers(model_name=name, device='cpu').get_op()
+        out1 = op(test_txt)
+        logger.info('OP LOADED.')
+        status[1] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO LOAD OP: {e}')
+        continue
+    try:
+        onnx_path = str(op.save_model('onnx'))
+        logger.info('ONNX SAVED.')
+        status[2] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO SAVE ONNX: {e}')
+        continue
+    try:
+        try:
+            onnx_model = onnx.load(onnx_path)
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            saved_onnx = onnx.load(onnx_path, load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logger.info('ONNX CHECKED.')
+        status[3] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO CHECK ONNX: {e}')
+        pass
+    try:
+        inputs = op.tokenize([test_txt])
+        sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
+        onnx_inputs = {}
+        for n in sess.get_inputs():
+            k = n.name
+            if k in inputs:
+                onnx_inputs[k] = inputs[k].cpu().detach().numpy()
+        out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
+        logger.info('ONNX WORKED.')
+        status[4] = 'success'
+        if numpy.allclose(out1, out2, atol=atol):
+            logger.info('Check accuracy: OK')
+            status[5] = 'success'
+        else:
+            logger.info(f'Check accuracy: atol is larger than {atol}.')
+    except Exception as e:
+        logger.error(f'FAIL TO RUN ONNX: {e}')
+        continue
+
+if status:
+    f.write(','.join(status) + '\n')
+
+print('Finished.')
\ No newline at end of file