From 2697c15870d88156b312b95985929a729b9af3a8 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Wed, 18 Jan 2023 14:29:14 +0800
Subject: [PATCH] Update model list

Signed-off-by: Jael Gu
---
 auto_transformers.py | 111 ++++++++++++++++++-------------------------
 test_onnx.py         |   4 +-
 2 files changed, 48 insertions(+), 67 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index 763be88..995b14e 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -86,7 +86,10 @@ class AutoTransformers(NNOperator):
             self.device = device
         else:
             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.model_name = self.map_model_names(model_name)
+        if model_name in s_list:
+            self.model_name = 'sentence-transformers/' + model_name
+        else:
+            self.model_name = model_name
         self.norm = norm
         self.checkpoint_path = checkpoint_path
 
@@ -225,62 +228,20 @@ class AutoTransformers(NNOperator):
 
     @staticmethod
     def supported_model_names(format: str = None):
-        full_list = [
-            'sentence-t5-xxl',
-            'sentence-t5-xl',
-            'sentence-t5-large',
-            'text-embedding-ada-002',
-            'text-similarity-davinci-001',
-            'text-similarity-babbage-001',
-            'text-similarity-curie-001',
-            'paraphrase-mpnet-base-v2',
-            'text-similarity-ada-001',
-            'gtr-t5-xxl',
-            'gtr-t5-large',
-            'gtr-t5-xl',
-            'paraphrase-multilingual-mpnet-base-v2',
-            'paraphrase-distilroberta-base-v2',
-            'all-mpnet-base-v1',
-            'all-roberta-large-v1',
-            'all-mpnet-base-v2',
-            'all-MiniLM-L12-v2',
-            'all-distilroberta-v1',
-            'all-MiniLM-L12-v1',
-            'gtr-t5-base',
-            'paraphrase-multilingual-MiniLM-L12-v2',
-            'paraphrase-MiniLM-L12-v2',
-            'all-MiniLM-L6-v1',
-            'paraphrase-TinyBERT-L6-v2',
-            'all-MiniLM-L6-v2',
-            'paraphrase-albert-small-v2',
-            'multi-qa-mpnet-base-cos-v1',
-            'paraphrase-MiniLM-L3-v2',
-            'multi-qa-distilbert-cos-v1',
-            'msmarco-distilbert-base-v4',
-            'multi-qa-mpnet-base-dot-v1',
-            'msmarco-distilbert-base-tas-b',
-            'distiluse-base-multilingual-cased-v2',
-            'multi-qa-distilbert-dot-v1',
-            'multi-qa-MiniLM-L6-cos-v1',
-            'distiluse-base-multilingual-cased-v1',
-            'msmarco-bert-base-dot-v5',
-            'paraphrase-MiniLM-L6-v2',
-            'multi-qa-MiniLM-L6-dot-v1',
-            'msmarco-distilbert-dot-v5',
-            'bert-base-nli-mean-tokens',
+        add_models = [
+            'bert-base-uncased',
+            'bert-large-uncased',
             'bert-large-uncased-whole-word-masking',
-            'average_word_embeddings_komninos',
             'distilbert-base-uncased',
-            'average_word_embeddings_glove.6B.300d',
-            'dpr-ctx_encoder-multiset-base',
-            'dpr-ctx_encoder-single-nq-base',
-            'microsoft/deberta-xlarge',
+            'facebook/dpr-ctx_encoder-multiset-base',
+            'facebook/dpr-ctx_encoder-single-nq-base',
             'facebook/bart-large',
-            'bert-base-uncased',
-            'microsoft/deberta-xlarge-mnli',
             'gpt2-xl',
-            'bert-large-uncased'
+            'microsoft/deberta-xlarge',
+            'microsoft/deberta-xlarge-mnli',
+            'msmarco-distilbert-base-v4',
         ]
+        full_list = s_list + add_models
         full_list.sort()
         if format is None:
             model_list = full_list
@@ -301,16 +262,36 @@ class AutoTransformers(NNOperator):
             log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
         return model_list
 
-    @staticmethod
-    def map_model_names(name):
-        req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html")
-        data = req.text
-        default_sbert = []
-        for line in data.split('\r\n'):
-            line = line.replace(' ', '')
-            if line.startswith('"name":'):
-                n = line.split(':')[-1].replace('"', '').replace(',', '')
-                default_sbert.append(n)
-        if name in default_sbert:
-            name = 'sentence-transformers/' + name
-        return name
+
+s_list = [
+    'paraphrase-MiniLM-L3-v2',
+    'paraphrase-MiniLM-L6-v2',
+    'paraphrase-MiniLM-L12-v2',
+    'paraphrase-distilroberta-base-v2',
+    'paraphrase-TinyBERT-L6-v2',
+    'paraphrase-mpnet-base-v2',
+    'paraphrase-albert-small-v2',
+    'paraphrase-multilingual-mpnet-base-v2',
+    'paraphrase-multilingual-MiniLM-L12-v2',
+    'distiluse-base-multilingual-cased-v1',
+    'distiluse-base-multilingual-cased-v2',
+    'all-distilroberta-v1',
+    'all-MiniLM-L6-v1',
+    'all-MiniLM-L6-v2',
+    'all-MiniLM-L12-v1',
+    'all-MiniLM-L12-v2',
+    'all-mpnet-base-v1',
+    'all-mpnet-base-v2',
+    'all-roberta-large-v1',
+    'multi-qa-MiniLM-L6-dot-v1',
+    'multi-qa-MiniLM-L6-cos-v1',
+    'multi-qa-distilbert-dot-v1',
+    'multi-qa-distilbert-cos-v1',
+    'multi-qa-mpnet-base-dot-v1',
+    'multi-qa-mpnet-base-cos-v1',
+    'msmarco-distilbert-dot-v5',
+    'msmarco-bert-base-dot-v5',
+    'msmarco-distilbert-base-tas-b',
+    'bert-base-nli-mean-tokens',
+    'msmarco-distilbert-base-v4'
+]
diff --git a/test_onnx.py b/test_onnx.py
index 87dc20e..c209feb 100644
--- a/test_onnx.py
+++ b/test_onnx.py
@@ -49,8 +49,6 @@ logger.debug(f'cpu: {psutil.cpu_count()}')
 status = None
 for name in models:
     logger.info(f'***{name}***')
-    saved_name = name.replace('/', '-')
-    onnx_path = f'saved/onnx/{saved_name}.onnx'
     if status:
         f.write(','.join(status) + '\n')
     status = [name] + ['fail'] * 5
@@ -62,6 +60,8 @@ for name in models:
     except Exception as e:
         logger.error(f'FAIL TO LOAD OP: {e}')
         continue
+    saved_name = op.model_name.replace('/', '-')
+    onnx_path = f'saved/onnx/{saved_name}.onnx'
     try:
         op.save_model(model_type='onnx')
         logger.info('ONNX SAVED.')
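
Reviewer note: the sketch below reproduces, as a standalone script, the name resolution this patch moves into AutoTransformers.__init__: a membership check against the static s_list replaces the removed map_model_names scrape of sbert.net. The resolve_model_name helper and the abbreviated s_list here are illustrative assumptions, not part of the patch.

# Standalone sketch of the patched lookup (assumed helper name; s_list truncated).
s_list = [
    'paraphrase-MiniLM-L3-v2',
    'all-MiniLM-L6-v2',
    'msmarco-distilbert-base-v4',
    # ... remaining sentence-transformers checkpoints listed in the patch
]

def resolve_model_name(model_name: str) -> str:
    """Prefix known sentence-transformers checkpoints; pass other names through."""
    if model_name in s_list:
        return 'sentence-transformers/' + model_name
    return model_name

if __name__ == '__main__':
    # A sentence-transformers checkpoint gains the hub namespace ...
    print(resolve_model_name('all-MiniLM-L6-v2'))   # sentence-transformers/all-MiniLM-L6-v2
    # ... while a plain Hugging Face model id is returned unchanged.
    print(resolve_model_name('bert-base-uncased'))  # bert-base-uncased

Compared with the removed map_model_names, names are resolved from a list bundled with the operator, so constructing the operator no longer issues an HTTP request to sbert.net.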