logo
Browse Source

Add supported model names

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
860c1d47de
  1. 4
      README.md
  2. 83
      auto_transformers.py

4
README.md

@ -63,7 +63,9 @@ Create the operator via the following factory method:
The model name in string, defaults to None. The model name in string, defaults to None.
If None, the operator will be initialized without specified model. If None, the operator will be initialized without specified model.
Supported model names: refer to `supported_model_names` below.
Supported model names: NLP transformer models listed in [Hugging Face Models](https://huggingface.co/models).
Please note that only models listed in `supported_model_names` are tested.
Refer to [Towhee Pipeline]() for benchmark results.
***checkpoint_path***: *str* ***checkpoint_path***: *str*

83
auto_transformers.py

@ -73,19 +73,19 @@ class AutoTransformers(NNOperator):
else: else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = self.map_model_names(model_name) self.model_name = self.map_model_names(model_name)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
self.norm = norm self.norm = norm
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
if self.model_name: if self.model_name:
model_list = self.supported_model_names()
# model_list = self.supported_model_names()
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(self._model) self.model = Model(self._model)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
else: else:
log.warning('The operator is initialized without specified model.') log.warning('The operator is initialized without specified model.')
pass pass
@ -205,14 +205,6 @@ class AutoTransformers(NNOperator):
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')
return Path(output_file).resolve() return Path(output_file).resolve()
@property
def supported_formats(self):
onnxes = self.supported_model_names(format='onnx')
if self.model_name in onnxes:
return ['onnx']
else:
return ['pytorch']
@property @property
def supported_formats(self): def supported_formats(self):
return ['onnx'] return ['onnx']
@ -220,7 +212,60 @@ class AutoTransformers(NNOperator):
@staticmethod @staticmethod
def supported_model_names(format: str = None): def supported_model_names(format: str = None):
full_list = [ full_list = [
'sentence-t5-xxl',
'sentence-t5-xl',
'sentence-t5-large',
'text-embedding-ada-002',
'text-similarity-davinci-001',
'text-similarity-babbage-001',
'text-similarity-curie-001',
'paraphrase-mpnet-base-v2',
'text-similarity-ada-001',
'gtr-t5-xxl',
'gtr-t5-large',
'gtr-t5-xl',
'paraphrase-multilingual-mpnet-base-v2',
'paraphrase-distilroberta-base-v2',
'all-mpnet-base-v1',
'all-roberta-large-v1',
'all-mpnet-base-v2',
'all-MiniLM-L12-v2',
'all-distilroberta-v1',
'all-MiniLM-L12-v1',
'gtr-t5-base',
'paraphrase-multilingual-MiniLM-L12-v2',
'paraphrase-MiniLM-L12-v2',
'all-MiniLM-L6-v1',
'paraphrase-TinyBERT-L6-v2',
'all-MiniLM-L6-v2',
'paraphrase-albert-small-v2',
'multi-qa-mpnet-base-cos-v1',
'paraphrase-MiniLM-L3-v2',
'multi-qa-distilbert-cos-v1',
'msmarco-distilbert-base-v4',
'multi-qa-mpnet-base-dot-v1',
'msmarco-distilbert-base-tas-b',
'distiluse-base-multilingual-cased-v2',
'multi-qa-distilbert-dot-v1',
'multi-qa-MiniLM-L6-cos-v1',
'distiluse-base-multilingual-cased-v1',
'msmarco-bert-base-dot-v5',
'paraphrase-MiniLM-L6-v2',
'multi-qa-MiniLM-L6-dot-v1',
'msmarco-distilbert-dot-v5',
'bert-base-nli-mean-tokens',
'bert-large-uncased-whole-word-masking',
'average_word_embeddings_komninos',
'distilbert-base-uncased',
'average_word_embeddings_glove.6B.300d',
'dpr-ctx_encoder-multiset-base',
'dpr-ctx_encoder-single-nq-base',
'microsoft/deberta-xlarge',
'facebook/bart-large',
'bert-base-uncased',
'microsoft/deberta-xlarge-mnli',
'gpt2-xl',
'bert-large-uncased'
] ]
full_list.sort() full_list.sort()
if format is None: if format is None:
@ -230,13 +275,11 @@ class AutoTransformers(NNOperator):
assert set(to_remove).issubset(set(full_list)) assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove)) model_list = list(set(full_list) - set(to_remove))
elif format == 'torchscript': elif format == 'torchscript':
to_remove = [
]
to_remove = []
assert set(to_remove).issubset(set(full_list)) assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove)) model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx': elif format == 'onnx':
to_remove = [
]
to_remove = []
assert set(to_remove).issubset(set(full_list)) assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove)) model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':

Loading…
Cancel
Save