From c6eabd756e7e4a7798aed0af6f1d9cc1fc619e35 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 21 Jun 2022 14:54:44 +0800 Subject: [PATCH] Add list of supported_model_names(format='onnx') Signed-off-by: Jael Gu --- auto_transformers.py | 52 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/auto_transformers.py b/auto_transformers.py index 95832d6..c3f1120 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -321,6 +321,56 @@ class AutoTransformers(NNOperator): ] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) - else: # todo: format in {'onnx', 'tensorrt'} + elif format == 'onnx': + to_remove = [ + 'albert-xlarge-v1', + 'albert-xlarge-v2', + 'albert-xxlarge-v1', + 'albert-xxlarge-v2', + 'allenai/led-base-16384', + 'ctrl', + 'distilgpt2', + 'EleutherAI/gpt-j-6B', + 'EleutherAI/gpt-neo-1.3B', + 'funnel-transformer/intermediate', + 'funnel-transformer/large', + 'funnel-transformer/medium', + 'funnel-transformer/small', + 'funnel-transformer/xlarge', + 'google/bigbird-pegasus-large-arxiv', + 'google/bigbird-pegasus-large-bigpatent', + 'google/bigbird-pegasus-large-pubmed', + 'google/canine-c', + 'google/canine-s', + 'google/fnet-base', + 'google/fnet-large', + 'google/reformer-crime-and-punishment', + 'gpt2', + 'gpt2-large', + 'gpt2-medium', + 'gpt2-xl', + 'microsoft/deberta-v2-xlarge', + 'microsoft/deberta-v2-xlarge-mnli', + 'microsoft/deberta-v2-xxlarge', + 'microsoft/deberta-v2-xxlarge-mnli', + 'microsoft/deberta-xlarge', + 'microsoft/deberta-xlarge-mnli', + 'openai-gpt', + 'transfo-xl-wt103', + 'uw-madison/yoso-4096', + 'xlm-mlm-100-1280', + 'xlm-mlm-17-1280', + 'xlm-mlm-en-2048', + 'xlm-roberta-large', + 'xlm-roberta-large-finetuned-conll02-dutch', + 'xlm-roberta-large-finetuned-conll02-spanish', + 'xlm-roberta-large-finetuned-conll03-english', + 'xlm-roberta-large-finetuned-conll03-german', + 'xlnet-base-cased', + 'xlnet-large-cased' + ] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) + else: # todo: format in {'tensorrt'} log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') return model_list