logo
Browse Source

Add supported model names

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
860c1d47de
  1. 4
      README.md
  2. 83
      auto_transformers.py

4
README.md

@ -63,7 +63,9 @@ Create the operator via the following factory method:
The model name in string, defaults to None.
If None, the operator will be initialized without specified model.
Supported model names: refer to `supported_model_names` below.
Supported model names: NLP transformers models listed in [Huggingface Models](https://huggingface.co/models).
Please note that only models listed in `supported_model_names` are tested.
You can refer to [Towhee Pipeline]() for benchmark.
***checkpoint_path***: *str*

83
auto_transformers.py

@ -73,19 +73,19 @@ class AutoTransformers(NNOperator):
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = self.map_model_names(model_name)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
self.norm = norm
self.checkpoint_path = checkpoint_path
if self.model_name:
model_list = self.supported_model_names()
# model_list = self.supported_model_names()
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(self._model)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
else:
log.warning('The operator is initialized without specified model.')
pass
@ -205,14 +205,6 @@ class AutoTransformers(NNOperator):
log.error(f'Unsupported format "{format}".')
return Path(output_file).resolve()
@property
def supported_formats(self):
onnxes = self.supported_model_names(format='onnx')
if self.model_name in onnxes:
return ['onnx']
else:
return ['pytorch']
@property
def supported_formats(self):
return ['onnx']
@ -220,7 +212,60 @@ class AutoTransformers(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
full_list = [
'sentence-t5-xxl',
'sentence-t5-xl',
'sentence-t5-large',
'text-embedding-ada-002',
'text-similarity-davinci-001',
'text-similarity-babbage-001',
'text-similarity-curie-001',
'paraphrase-mpnet-base-v2',
'text-similarity-ada-001',
'gtr-t5-xxl',
'gtr-t5-large',
'gtr-t5-xl',
'paraphrase-multilingual-mpnet-base-v2',
'paraphrase-distilroberta-base-v2',
'all-mpnet-base-v1',
'all-roberta-large-v1',
'all-mpnet-base-v2',
'all-MiniLM-L12-v2',
'all-distilroberta-v1',
'all-MiniLM-L12-v1',
'gtr-t5-base',
'paraphrase-multilingual-MiniLM-L12-v2',
'paraphrase-MiniLM-L12-v2',
'all-MiniLM-L6-v1',
'paraphrase-TinyBERT-L6-v2',
'all-MiniLM-L6-v2',
'paraphrase-albert-small-v2',
'multi-qa-mpnet-base-cos-v1',
'paraphrase-MiniLM-L3-v2',
'multi-qa-distilbert-cos-v1',
'msmarco-distilbert-base-v4',
'multi-qa-mpnet-base-dot-v1',
'msmarco-distilbert-base-tas-b',
'distiluse-base-multilingual-cased-v2',
'multi-qa-distilbert-dot-v1',
'multi-qa-MiniLM-L6-cos-v1',
'distiluse-base-multilingual-cased-v1',
'msmarco-bert-base-dot-v5',
'paraphrase-MiniLM-L6-v2',
'multi-qa-MiniLM-L6-dot-v1',
'msmarco-distilbert-dot-v5',
'bert-base-nli-mean-tokens',
'bert-large-uncased-whole-word-masking',
'average_word_embeddings_komninos',
'distilbert-base-uncased',
'average_word_embeddings_glove.6B.300d',
'dpr-ctx_encoder-multiset-base',
'dpr-ctx_encoder-single-nq-base',
'microsoft/deberta-xlarge',
'facebook/bart-large',
'bert-base-uncased',
'microsoft/deberta-xlarge-mnli',
'gpt2-xl',
'bert-large-uncased'
]
full_list.sort()
if format is None:
@ -230,13 +275,11 @@ class AutoTransformers(NNOperator):
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'torchscript':
to_remove = [
]
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = [
]
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'tensorrt':

Loading…
Cancel
Save