logo
Browse Source

Update model list

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
2697c15870
  1. 111
      auto_transformers.py
  2. 4
      test_onnx.py

111
auto_transformers.py

@ -86,7 +86,10 @@ class AutoTransformers(NNOperator):
self.device = device self.device = device
else: else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = self.map_model_names(model_name)
if model_name in s_list:
self.model_name = 'sentence-transformers/' + model_name
else:
self.model_name = model_name
self.norm = norm self.norm = norm
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
@ -225,62 +228,20 @@ class AutoTransformers(NNOperator):
@staticmethod @staticmethod
def supported_model_names(format: str = None): def supported_model_names(format: str = None):
full_list = [
'sentence-t5-xxl',
'sentence-t5-xl',
'sentence-t5-large',
'text-embedding-ada-002',
'text-similarity-davinci-001',
'text-similarity-babbage-001',
'text-similarity-curie-001',
'paraphrase-mpnet-base-v2',
'text-similarity-ada-001',
'gtr-t5-xxl',
'gtr-t5-large',
'gtr-t5-xl',
'paraphrase-multilingual-mpnet-base-v2',
'paraphrase-distilroberta-base-v2',
'all-mpnet-base-v1',
'all-roberta-large-v1',
'all-mpnet-base-v2',
'all-MiniLM-L12-v2',
'all-distilroberta-v1',
'all-MiniLM-L12-v1',
'gtr-t5-base',
'paraphrase-multilingual-MiniLM-L12-v2',
'paraphrase-MiniLM-L12-v2',
'all-MiniLM-L6-v1',
'paraphrase-TinyBERT-L6-v2',
'all-MiniLM-L6-v2',
'paraphrase-albert-small-v2',
'multi-qa-mpnet-base-cos-v1',
'paraphrase-MiniLM-L3-v2',
'multi-qa-distilbert-cos-v1',
'msmarco-distilbert-base-v4',
'multi-qa-mpnet-base-dot-v1',
'msmarco-distilbert-base-tas-b',
'distiluse-base-multilingual-cased-v2',
'multi-qa-distilbert-dot-v1',
'multi-qa-MiniLM-L6-cos-v1',
'distiluse-base-multilingual-cased-v1',
'msmarco-bert-base-dot-v5',
'paraphrase-MiniLM-L6-v2',
'multi-qa-MiniLM-L6-dot-v1',
'msmarco-distilbert-dot-v5',
'bert-base-nli-mean-tokens',
add_models = [
'bert-base-uncased',
'bert-large-uncased',
'bert-large-uncased-whole-word-masking', 'bert-large-uncased-whole-word-masking',
'average_word_embeddings_komninos',
'distilbert-base-uncased', 'distilbert-base-uncased',
'average_word_embeddings_glove.6B.300d',
'dpr-ctx_encoder-multiset-base',
'dpr-ctx_encoder-single-nq-base',
'microsoft/deberta-xlarge',
'facebook/dpr-ctx_encoder-multiset-base',
'facebook/dpr-ctx_encoder-single-nq-base',
'facebook/bart-large', 'facebook/bart-large',
'bert-base-uncased',
'microsoft/deberta-xlarge-mnli',
'gpt2-xl', 'gpt2-xl',
'bert-large-uncased'
'microsoft/deberta-xlarge',
'microsoft/deberta-xlarge-mnli',
'msmarco-distilbert-base-v4',
] ]
full_list = s_list + add_models
full_list.sort() full_list.sort()
if format is None: if format is None:
model_list = full_list model_list = full_list
@ -301,16 +262,36 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list return model_list
@staticmethod
def map_model_names(name):
    """Prefix *name* with ``sentence-transformers/`` when sbert.net lists it.

    Downloads the model table published on sbert.net, collects every
    ``"name": ...`` entry found in it, and checks *name* against that set.

    Args:
        name: Model name as passed in by the caller.

    Returns:
        ``'sentence-transformers/' + name`` if *name* appears in the
        downloaded table, otherwise *name* unchanged.

    Raises:
        requests.exceptions.RequestException: If the page cannot be fetched
            (connection failure, or no response within the timeout).
    """
    # A timeout keeps a stalled connection from blocking the caller forever.
    response = requests.get(
        "https://www.sbert.net/_static/html/models_en_sentence_embeddings.html",
        timeout=30,
    )

    default_sbert = set()
    # splitlines() copes with both '\n' and '\r\n' endings; splitting on the
    # literal '\r\n' finds no entries at all when the page is served with '\n'.
    for raw_line in response.text.splitlines():
        compact = raw_line.replace(' ', '')
        if compact.startswith('"name":'):
            # Keep the text after the last ':' and drop the quotes and the
            # trailing comma that surround the value.
            value = compact.split(':')[-1].replace('"', '').replace(',', '')
            default_sbert.add(value)

    if name in default_sbert:
        return 'sentence-transformers/' + name
    return name
# Model names that receive the 'sentence-transformers/' prefix when selected;
# also merged with `add_models` to build the supported-model list.
s_list = [
    # paraphrase-*
    'paraphrase-MiniLM-L3-v2',
    'paraphrase-MiniLM-L6-v2',
    'paraphrase-MiniLM-L12-v2',
    'paraphrase-distilroberta-base-v2',
    'paraphrase-TinyBERT-L6-v2',
    'paraphrase-mpnet-base-v2',
    'paraphrase-albert-small-v2',
    'paraphrase-multilingual-mpnet-base-v2',
    'paraphrase-multilingual-MiniLM-L12-v2',
    # distiluse-*
    'distiluse-base-multilingual-cased-v1',
    'distiluse-base-multilingual-cased-v2',
    # all-*
    'all-distilroberta-v1',
    'all-MiniLM-L6-v1',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v1',
    'all-MiniLM-L12-v2',
    'all-mpnet-base-v1',
    'all-mpnet-base-v2',
    'all-roberta-large-v1',
    # multi-qa-*
    'multi-qa-MiniLM-L6-dot-v1',
    'multi-qa-MiniLM-L6-cos-v1',
    'multi-qa-distilbert-dot-v1',
    'multi-qa-distilbert-cos-v1',
    'multi-qa-mpnet-base-dot-v1',
    'multi-qa-mpnet-base-cos-v1',
    # msmarco-* and others
    'msmarco-distilbert-dot-v5',
    'msmarco-bert-base-dot-v5',
    'msmarco-distilbert-base-tas-b',
    'bert-base-nli-mean-tokens',
    'msmarco-distilbert-base-v4',
]

4
test_onnx.py

@ -49,8 +49,6 @@ logger.debug(f'cpu: {psutil.cpu_count()}')
status = None status = None
for name in models: for name in models:
logger.info(f'***{name}***') logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
if status: if status:
f.write(','.join(status) + '\n') f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5 status = [name] + ['fail'] * 5
@ -62,6 +60,8 @@ for name in models:
except Exception as e: except Exception as e:
logger.error(f'FAIL TO LOAD OP: {e}') logger.error(f'FAIL TO LOAD OP: {e}')
continue continue
saved_name = op.model_name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
try: try:
op.save_model(model_type='onnx') op.save_model(model_type='onnx')
logger.info('ONNX SAVED.') logger.info('ONNX SAVED.')

Loading…
Cancel
Save