From b56e5e1fa45206bc748a7a6eb52572ec0b649929 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 12 Jan 2023 13:27:43 +0800
Subject: [PATCH] Add files

Signed-off-by: Jael Gu

---
 README.md            | 149 ++++++++++++++++++++++++-
 __init__.py          |  19 ++++
 auto_transformers.py | 256 +++++++++++++++++++++++++++++++++++++++++++
 requirements.txt     |   7 ++
 result.png           | Bin 0 -> 5855 bytes
 test_onnx.py         | 105 ++++++++++++++++++
 6 files changed, 535 insertions(+), 1 deletion(-)
 create mode 100644 __init__.py
 create mode 100644 auto_transformers.py
 create mode 100644 requirements.txt
 create mode 100644 result.png
 create mode 100644 test_onnx.py

diff --git a/README.md b/README.md
index e2cfae6..0c5d656 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,149 @@
-# transformers
+# Sentence Embedding with Transformers
+
+*author: [Jael Gu](https://github.com/jaelgu)*
+
+<br />
+
+## Description
+
+A sentence embedding operator generates one embedding vector in numpy.ndarray for each input text.
+The embedding vector represents the semantic meaning of the whole input text.
+This operator is implemented with pre-trained models from [Huggingface Transformers](https://huggingface.co/docs/transformers).
+
+<br />
+
+## Code Example
+
+Use the pre-trained model 'sentence-transformers/paraphrase-albert-small-v2'
+to generate an embedding for the sentence "Hello, world.".
+
+*Write a pipeline with explicit input/output name specifications:*
+
+- **option 1 (towhee>=0.9.0):**
+
+```python
+from towhee.dc2 import pipe, ops, DataCollection
+
+p = (
+    pipe.input('text')
+        .map('text', 'vec',
+             ops.sentence_embedding.transformers(model_name='sentence-transformers/paraphrase-albert-small-v2'))
+        .output('text', 'vec')
+)
+
+DataCollection(p('Hello, world.')).show()
+```
+
+<img src="./result.png" />
+
+- **option 2:**
+
+```python
+import towhee
+
+(
+    towhee.dc['text'](['Hello, world.'])
+        .sentence_embedding.transformers['text', 'vec'](
+            model_name='sentence-transformers/paraphrase-albert-small-v2')
+        .show()
+)
+```
+
+<br />
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***sentence_embedding.transformers(model_name=None, checkpoint_path=None, tokenizer=None, device=None, norm=False)***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string, defaults to None.
+If None, the operator will be initialized without a specified model.
+
+Supported model names: refer to `supported_model_names` below.
+
+***checkpoint_path***: *str*
+
+The path to a local checkpoint, defaults to None.
+If None, the operator will download and load the pretrained model specified by `model_name` from Huggingface transformers.
+
+***tokenizer***: *object*
+
+The tokenizer to tokenize input text, defaults to None.
+If None, the operator will use the default tokenizer for `model_name` from Huggingface transformers.
+
+***device***: *str*
+
+The device to run model inference, defaults to None.
+If None, the operator runs on CPU unless a GPU device id is configured for it.
+
+***norm***: *bool*
+
+Whether to normalize the output embedding, defaults to False.
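+
+For example, a minimal sketch of constructing the operator with a custom tokenizer (here simply a Huggingface `AutoTokenizer` loaded explicitly; any tokenizer with the same call interface should work):
+
+```python
+from towhee import ops
+from transformers import AutoTokenizer
+
+my_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-albert-small-v2')
+op = ops.sentence_embedding.transformers(
+    model_name='sentence-transformers/paraphrase-albert-small-v2',
+    tokenizer=my_tokenizer).get_op()
+```
+
+<br />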
+
+## Interface
+
+The operator takes a piece of text in string (or a list of text strings) as input.
+It loads the tokenizer and pre-trained model by model name,
+and then returns a text embedding in numpy.ndarray.
+
+***\_\_call\_\_(data)***
+
+**Parameters:**
+
+***data***: *Union[str, list]*
+
+The text in string, or a list of text strings.
+
+**Returns**:
+
+*numpy.ndarray or list*
+
+The text embedding (or list of embeddings) extracted by the model.
+If `data` is a string, the operator returns an embedding in numpy.ndarray with shape (dim,).
+If `data` is a list, the operator returns a list of embeddings with the same length as the input list.
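+
+A minimal sketch of calling the operator directly (using the same model as above):
+
+```python
+from towhee import ops
+
+op = ops.sentence_embedding.transformers(
+    model_name='sentence-transformers/paraphrase-albert-small-v2').get_op()
+
+emb = op('Hello, world.')  # numpy.ndarray with shape (dim,)
+embs = op(['Hello, world.', 'Hello, Towhee.'])  # list of 2 embeddings
+```
+
+<br />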
+
+***save_model(model_type='pytorch', output_file='default')***
+
+Save the model to a local path in the specified format.
+
+**Parameters:**
+
+***model_type***: *str*
+
+The format to export the model as, such as 'pytorch', 'torchscript', or 'onnx',
+defaults to 'pytorch'.
+
+***output_file***: *str*
+
+The file path to save the exported model to.
+By default, it saves the model to the `saved` directory under the operator cache.
+
+```python
+from towhee import ops
+
+op = ops.sentence_embedding.transformers(model_name='sentence-transformers/paraphrase-albert-small-v2').get_op()
+op.save_model('onnx', 'test.onnx')
+```
+PosixPath('/Home/.towhee/operators/sentence-embedding/transformers/main/test.onnx')
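+
+The exported ONNX model can then be verified with onnxruntime; below is a minimal sketch following the accuracy check in `test_onnx.py` (it assumes `op` and the 'test.onnx' file from the example above):
+
+```python
+import onnxruntime
+import torch
+
+sess = onnxruntime.InferenceSession('test.onnx',
+                                    providers=onnxruntime.get_available_providers())
+inputs = op.tokenizer('Hello, world.', return_tensors='np')
+outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
+# Mean-pool the token embeddings into one sentence embedding.
+emb = op.post_proc(torch.from_numpy(outs), op.tokenizer('Hello, world.', return_tensors='pt'))
+```
+
+<br />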
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names, or the model names supported for a specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+The model format such as 'pytorch', 'torchscript', or 'onnx', defaults to None.
+
+```python
+from towhee import ops
+
+op = ops.sentence_embedding.transformers().get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+```
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..2cc07d3
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto_transformers import AutoTransformers
+
+
+def transformers(*args, **kwargs):
+    return AutoTransformers(*args, **kwargs)
diff --git a/auto_transformers.py b/auto_transformers.py
new file mode 100644
index 0000000..0a601a2
--- /dev/null
+++ b/auto_transformers.py
@@ -0,0 +1,256 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import os
+import torch
+from pathlib import Path
+from typing import Union
+
+from transformers import AutoModel
+
+from towhee.operator import NNOperator
+from towhee import register
+# from towhee.dc2 import accelerate
+
+import warnings
+import logging
+from transformers import logging as t_logging
+
+log = logging.getLogger('run_op')
+warnings.filterwarnings('ignore')
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+t_logging.set_verbosity_error()
+
+
+# @accelerate
+class Model:
+    """Wrapper that runs the transformers model and returns only the last hidden states."""
+    def __init__(self, model):
+        self.model = model
+
+    def __call__(self, *args, **kwargs):
+        outs = self.model(*args, **kwargs, return_dict=True)
+        return outs['last_hidden_state']
+
+
+@register(output_schema=['vec'])
+class AutoTransformers(NNOperator):
+    """
+    NLP embedding operator that uses a pretrained model from Huggingface transformers.
+    Args:
+        model_name (`str`):
+            The model name to load a pretrained model from transformers.
+        checkpoint_path (`str`):
+            The local checkpoint path.
+        tokenizer (`object`):
+            The tokenizer to tokenize input text as model inputs.
+        device (`str`):
+            The device to run model inference, defaults to None (CPU unless a device id is configured).
+        norm (`bool`):
+            Whether to normalize the output embedding, defaults to False.
+ """ + + def __init__(self, + model_name: str = None, + checkpoint_path: str = None, + tokenizer: object = None, + device: str = None, + norm: bool = False + ): + super().__init__() + self._device = device + self.model_name = model_name + self.user_tokenizer = tokenizer + self.norm = norm + self.checkpoint_path = checkpoint_path + + if self.model_name: + model_list = self.supported_model_names() + # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" + self.model = Model(self._model) + else: + log.warning('The operator is initialized without specified model.') + pass + + def __call__(self, data: Union[str, list]) -> numpy.ndarray: + if isinstance(data, str): + txt = [data] + else: + txt = data + try: + inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device) + except Exception as e: + log.error(f'Fail to tokenize inputs: {e}') + raise e + try: + outs = self.model(**inputs) + except Exception as e: + log.error(f'Invalid input for the model: {self.model_name}') + raise e + outs = self.post_proc(outs, inputs) + if self.norm: + outs = torch.nn.functional.normalize(outs, ) + features = outs.cpu().detach().numpy() + if isinstance(data, str): + features = features.squeeze(0) + else: + features = list(features) + return features + + @property + def _model(self): + model = AutoModel.from_pretrained(self.model_name).to(self.device) + if hasattr(model, 'pooler') and model.pooler: + model.pooler = None + if self.checkpoint_path: + try: + state_dict = torch.load(self.checkpoint_path, map_location=self.device) + model.load_state_dict(state_dict) + except Exception: + log.error(f'Fail to load weights from {self.checkpoint_path}') + model.eval() + return model + + @property + def device(self): + if self._device is None: + if self._device_id < 0: + self._device = torch.device('cpu') + else: + self._device = torch.device(self._device_id) + return self._device + + @property + def model_config(self): + from transformers import AutoConfig + configs = AutoConfig.from_pretrained(self.model_name) + return configs + + @property + def onnx_config(self): + from transformers.onnx.features import FeaturesManager + model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( + self._model, feature='default') + old_config = model_onnx_config(self.model_config) + onnx_config = { + 'inputs': dict(old_config.inputs), + 'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']} + } + return onnx_config + + @property + def tokenizer(self): + from transformers import AutoTokenizer + try: + if self.user_tokenizer: + t = tokenizer + else: + t = AutoTokenizer.from_pretrained(self.model_name) + if not t.pad_token: + t.pad_token = '[PAD]' + except Exception as e: + log.error(f'Fail to load tokenizer.') + raise e + return t + + def post_proc(self, token_embeddings, inputs): + token_embeddings = token_embeddings.to(self.device) + attention_mask = inputs['attention_mask'].to(self.device) + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sentence_embs = torch.sum( + token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + return sentence_embs + + def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): + if output_file == 'default': + output_file = str(Path(__file__).parent) + output_file = os.path.join(output_file, 'saved', model_type) + os.makedirs(output_file, exist_ok=True) + name = 
self.model_name.replace('/', '-') + output_file = os.path.join(output_file, name) + if model_type in ['pytorch', 'torchscript']: + output_file = output_file + '.pt' + elif model_type == 'onnx': + output_file = output_file + '.onnx' + else: + raise AttributeError('Unsupported model_type.') + + dummy_input = 'test sentence' + inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary + if model_type == 'pytorch': + torch.save(self._model, output_file) + elif model_type == 'torchscript': + inputs = list(inputs.values()) + try: + try: + jit_model = torch.jit.script(self._model) + except Exception: + jit_model = torch.jit.trace(self._model, inputs, strict=False) + torch.jit.save(jit_model, output_file) + except Exception as e: + log.error(f'Fail to save as torchscript: {e}.') + raise RuntimeError(f'Fail to save as torchscript: {e}.') + elif model_type == 'onnx': + dynamic_axes = {} + for k, v in self.onnx_config['inputs'].items(): + dynamic_axes[k] = v + for k, v in self.onnx_config['outputs'].items(): + dynamic_axes[k] = v + torch.onnx.export( + self._model, + tuple(inputs.values()), + output_file, + input_names=list(self.onnx_config['inputs'].keys()), + output_names=list(self.onnx_config['outputs'].keys()), + dynamic_axes=dynamic_axes, + opset_version=torch.onnx.constant_folding_opset_versions[-1], + do_constant_folding=True, + ) + # todo: elif format == 'tensorrt': + else: + log.error(f'Unsupported format "{format}".') + return Path(output_file).resolve() + + @property + def supported_formats(self): + onnxes = self.supported_model_names(format='onnx') + if self.model_name in onnxes: + return ['onnx'] + else: + return ['pytorch'] + + @staticmethod + def supported_model_names(format: str = None): + full_list = [ + + ] + full_list.sort() + if format is None: + model_list = full_list + elif format == 'pytorch': + to_remove = [] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) + elif format == 'torchscript': + to_remove = [ + ] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) + elif format == 'onnx': + to_remove = [ + ] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) + # todo: elif format == 'tensorrt': + else: + log.error(f'Invalid format "{format}". 
+        return model_list
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5e12bcc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+transformers
+sentencepiece
+protobuf
+
+towhee
+torch
diff --git a/result.png b/result.png
new file mode 100644
index 0000000000000000000000000000000000000000..7924280d1663c2a0d800539fa84fb70253cde06d
GIT binary patch
literal 5855
[base85-encoded image data]

literal 0
HcmV?d00001

diff --git a/test_onnx.py b/test_onnx.py
new file mode 100644
index 0000000..87dc20e
--- /dev/null
+++ b/test_onnx.py
@@ -0,0 +1,105 @@
+from towhee import ops
+import torch
+import numpy
+import onnx
+import onnxruntime
+
+import os
+from pathlib import Path
+import logging
+import platform
+import psutil
+
+import warnings
+from transformers import logging as t_logging
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+warnings.filterwarnings("ignore")
+t_logging.set_verbosity_error()
+
+# full_models = AutoTransformers.supported_model_names()
+# checked_models = AutoTransformers.supported_model_names(format='onnx')
+# models = [x for x in full_models if x not in checked_models]
+models = ['distilbert-base-cased',
+          'sentence-transformers/paraphrase-albert-small-v2']
+test_txt = 'hello, world.'
+atol = 1e-3
+log_path = 'transformers_onnx.log'
+f = open('onnx.csv', 'w+')
+f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
+
+logger = logging.getLogger('transformers_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
+             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
+
+
+status = None
+for name in models:
+    logger.info(f'***{name}***')
+    saved_name = name.replace('/', '-')
+    onnx_path = f'saved/onnx/{saved_name}.onnx'
+    if status:
+        f.write(','.join(status) + '\n')
+    status = [name] + ['fail'] * 5
+    try:
+        op = ops.sentence_embedding.transformers(model_name=name).get_op()
+        out1 = op(test_txt)
+        logger.info('OP LOADED.')
+        status[1] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO LOAD OP: {e}')
+        continue
+    try:
+        op.save_model(model_type='onnx')
+        logger.info('ONNX SAVED.')
+        status[2] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO SAVE ONNX: {e}')
+        continue
+    try:
+        try:
+            onnx_model = onnx.load(onnx_path)
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            saved_onnx = onnx.load(onnx_path, load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logger.info('ONNX CHECKED.')
+        status[3] = 'success'
+    except Exception as e:
+        logger.error(f'FAIL TO CHECK ONNX: {e}')
+        continue
+    try:
+        sess = onnxruntime.InferenceSession(onnx_path,
+                                            providers=onnxruntime.get_available_providers())
+        inputs = op.tokenizer(test_txt, return_tensors='np')
+        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
+        new_inputs = op.tokenizer(test_txt, return_tensors='pt')
+        out2 = op.post_proc(torch.from_numpy(out2), new_inputs)
+        logger.info('ONNX WORKED.')
+        status[4] = 'success'
+        if numpy.allclose(out1, out2, atol=atol):
+            logger.info('Check accuracy: OK')
+            status[5] = 'success'
+        else:
+            logger.info(f'Check accuracy: failed, diff is larger than atol={atol}.')
+    except Exception as e:
+        logger.error(f'FAIL TO RUN ONNX: {e}')
+        continue
+
+if status:
+    f.write(','.join(status) + '\n')
+f.close()
+
+print('Finished.')