|
|
@ -86,7 +86,10 @@ class AutoTransformers(NNOperator): |
|
|
|
self.device = device |
|
|
|
else: |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model_name = self.map_model_names(model_name) |
|
|
|
if model_name in s_list: |
|
|
|
self.model_name = 'sentence-transformers/' + model_name |
|
|
|
else: |
|
|
|
self.model_name = model_name |
|
|
|
self.norm = norm |
|
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
|
|
|
@ -225,62 +228,20 @@ class AutoTransformers(NNOperator): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
full_list = [ |
|
|
|
'sentence-t5-xxl', |
|
|
|
'sentence-t5-xl', |
|
|
|
'sentence-t5-large', |
|
|
|
'text-embedding-ada-002', |
|
|
|
'text-similarity-davinci-001', |
|
|
|
'text-similarity-babbage-001', |
|
|
|
'text-similarity-curie-001', |
|
|
|
'paraphrase-mpnet-base-v2', |
|
|
|
'text-similarity-ada-001', |
|
|
|
'gtr-t5-xxl', |
|
|
|
'gtr-t5-large', |
|
|
|
'gtr-t5-xl', |
|
|
|
'paraphrase-multilingual-mpnet-base-v2', |
|
|
|
'paraphrase-distilroberta-base-v2', |
|
|
|
'all-mpnet-base-v1', |
|
|
|
'all-roberta-large-v1', |
|
|
|
'all-mpnet-base-v2', |
|
|
|
'all-MiniLM-L12-v2', |
|
|
|
'all-distilroberta-v1', |
|
|
|
'all-MiniLM-L12-v1', |
|
|
|
'gtr-t5-base', |
|
|
|
'paraphrase-multilingual-MiniLM-L12-v2', |
|
|
|
'paraphrase-MiniLM-L12-v2', |
|
|
|
'all-MiniLM-L6-v1', |
|
|
|
'paraphrase-TinyBERT-L6-v2', |
|
|
|
'all-MiniLM-L6-v2', |
|
|
|
'paraphrase-albert-small-v2', |
|
|
|
'multi-qa-mpnet-base-cos-v1', |
|
|
|
'paraphrase-MiniLM-L3-v2', |
|
|
|
'multi-qa-distilbert-cos-v1', |
|
|
|
'msmarco-distilbert-base-v4', |
|
|
|
'multi-qa-mpnet-base-dot-v1', |
|
|
|
'msmarco-distilbert-base-tas-b', |
|
|
|
'distiluse-base-multilingual-cased-v2', |
|
|
|
'multi-qa-distilbert-dot-v1', |
|
|
|
'multi-qa-MiniLM-L6-cos-v1', |
|
|
|
'distiluse-base-multilingual-cased-v1', |
|
|
|
'msmarco-bert-base-dot-v5', |
|
|
|
'paraphrase-MiniLM-L6-v2', |
|
|
|
'multi-qa-MiniLM-L6-dot-v1', |
|
|
|
'msmarco-distilbert-dot-v5', |
|
|
|
'bert-base-nli-mean-tokens', |
|
|
|
add_models = [ |
|
|
|
'bert-base-uncased', |
|
|
|
'bert-large-uncased', |
|
|
|
'bert-large-uncased-whole-word-masking', |
|
|
|
'average_word_embeddings_komninos', |
|
|
|
'distilbert-base-uncased', |
|
|
|
'average_word_embeddings_glove.6B.300d', |
|
|
|
'dpr-ctx_encoder-multiset-base', |
|
|
|
'dpr-ctx_encoder-single-nq-base', |
|
|
|
'microsoft/deberta-xlarge', |
|
|
|
'facebook/dpr-ctx_encoder-multiset-base', |
|
|
|
'facebook/dpr-ctx_encoder-single-nq-base', |
|
|
|
'facebook/bart-large', |
|
|
|
'bert-base-uncased', |
|
|
|
'microsoft/deberta-xlarge-mnli', |
|
|
|
'gpt2-xl', |
|
|
|
'bert-large-uncased' |
|
|
|
'microsoft/deberta-xlarge', |
|
|
|
'microsoft/deberta-xlarge-mnli', |
|
|
|
'msmarco-distilbert-base-v4', |
|
|
|
] |
|
|
|
full_list = s_list + add_models |
|
|
|
full_list.sort() |
|
|
|
if format is None: |
|
|
|
model_list = full_list |
|
|
@ -301,16 +262,36 @@ class AutoTransformers(NNOperator): |
|
|
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') |
|
|
|
return model_list |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def map_model_names(name): |
|
|
|
req = requests.get("https://www.sbert.net/_static/html/models_en_sentence_embeddings.html") |
|
|
|
data = req.text |
|
|
|
default_sbert = [] |
|
|
|
for line in data.split('\r\n'): |
|
|
|
line = line.replace(' ', '') |
|
|
|
if line.startswith('"name":'): |
|
|
|
n = line.split(':')[-1].replace('"', '').replace(',', '') |
|
|
|
default_sbert.append(n) |
|
|
|
if name in default_sbert: |
|
|
|
name = 'sentence-transformers/' + name |
|
|
|
return name |
|
|
|
|
|
|
|
s_list = [ |
|
|
|
'paraphrase-MiniLM-L3-v2', |
|
|
|
'paraphrase-MiniLM-L6-v2', |
|
|
|
'paraphrase-MiniLM-L12-v2', |
|
|
|
'paraphrase-distilroberta-base-v2', |
|
|
|
'paraphrase-TinyBERT-L6-v2', |
|
|
|
'paraphrase-mpnet-base-v2', |
|
|
|
'paraphrase-albert-small-v2', |
|
|
|
'paraphrase-multilingual-mpnet-base-v2', |
|
|
|
'paraphrase-multilingual-MiniLM-L12-v2', |
|
|
|
'distiluse-base-multilingual-cased-v1', |
|
|
|
'distiluse-base-multilingual-cased-v2', |
|
|
|
'all-distilroberta-v1', |
|
|
|
'all-MiniLM-L6-v1', |
|
|
|
'all-MiniLM-L6-v2', |
|
|
|
'all-MiniLM-L12-v1', |
|
|
|
'all-MiniLM-L12-v2', |
|
|
|
'all-mpnet-base-v1', |
|
|
|
'all-mpnet-base-v2', |
|
|
|
'all-roberta-large-v1', |
|
|
|
'multi-qa-MiniLM-L6-dot-v1', |
|
|
|
'multi-qa-MiniLM-L6-cos-v1', |
|
|
|
'multi-qa-distilbert-dot-v1', |
|
|
|
'multi-qa-distilbert-cos-v1', |
|
|
|
'multi-qa-mpnet-base-dot-v1', |
|
|
|
'multi-qa-mpnet-base-cos-v1', |
|
|
|
'msmarco-distilbert-dot-v5', |
|
|
|
'msmarco-bert-base-dot-v5', |
|
|
|
'msmarco-distilbert-base-tas-b', |
|
|
|
'bert-base-nli-mean-tokens', |
|
|
|
'msmarco-distilbert-base-v4' |
|
|
|
] |
|
|
|