From 0bd1bcc127536ca47ee8f2863b7a6711f4d51765 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 9 Jun 2022 15:35:41 +0800 Subject: [PATCH] Update model list Signed-off-by: Jael Gu --- README.md | 2 + auto_transformers.py | 259 +++++++++++++++++++++++-------------------- 2 files changed, 138 insertions(+), 123 deletions(-) diff --git a/README.md b/README.md index d2c459c..e901cea 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Supported model names: - bert-base-cased - bert-large-cased + - bert-large-uncased - bert-base-multilingual-uncased - bert-base-multilingual-cased - bert-base-chinese @@ -181,6 +182,7 @@ Supported model names:
Funnel + - funnel-transformer/small - funnel-transformer/small-base - funnel-transformer/medium diff --git a/auto_transformers.py b/auto_transformers.py index 46ab490..8f51d30 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -80,10 +80,11 @@ class AutoTransformers(NNOperator): if path == 'default': path = str(Path(__file__).parent) name = self.model_name.replace('/', '-') - path = os.path.join(path, name + '.pt') + path = os.path.join(path, name) inputs = self.tokenizer('[CLS]', return_tensors='pt') - inputs = list(inputs.values()) if format == 'torchscript': + path = path + '.pt' + inputs = list(inputs.values()) try: try: jit_model = torch.jit.script(self.model) @@ -93,127 +94,139 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') + elif format == 'onxx': + pass else: torch.save(self.model, path) - -def get_model_list(): - full_list = [ - "bert-large-uncased", - "bert-base-cased", - "bert-large-cased", - "bert-base-multilingual-uncased", - "bert-base-multilingual-cased", - "bert-base-chinese", - "bert-base-german-cased", - "bert-large-uncased-whole-word-masking", - "bert-large-cased-whole-word-masking", - "bert-large-uncased-whole-word-masking-finetuned-squad", - "bert-large-cased-whole-word-masking-finetuned-squad", - "bert-base-cased-finetuned-mrpc", - "bert-base-german-dbmdz-cased", - "bert-base-german-dbmdz-uncased", - "cl-tohoku/bert-base-japanese-whole-word-masking", - "cl-tohoku/bert-base-japanese-char", - "cl-tohoku/bert-base-japanese-char-whole-word-masking", - "TurkuNLP/bert-base-finnish-cased-v1", - "TurkuNLP/bert-base-finnish-uncased-v1", - "wietsedv/bert-base-dutch-cased", - "google/bigbird-roberta-base", - "google/bigbird-roberta-large", - "google/bigbird-base-trivia-itc", - "albert-base-v1", - "albert-large-v1", - "albert-xlarge-v1", - "albert-xxlarge-v1", - "albert-base-v2", - "albert-large-v2", - "albert-xlarge-v2", - "albert-xxlarge-v2", - "facebook/bart-large", - "google/bert_for_seq_generation_L-24_bbc_encoder", - "google/bigbird-pegasus-large-arxiv", - "google/bigbird-pegasus-large-pubmed", - "google/bigbird-pegasus-large-bigpatent", - "google/canine-s", - "google/canine-c", - "YituTech/conv-bert-base", - "YituTech/conv-bert-medium-small", - "YituTech/conv-bert-small", - "ctrl", - "microsoft/deberta-base", - "microsoft/deberta-large", - "microsoft/deberta-xlarge", - "microsoft/deberta-base-mnli", - "microsoft/deberta-large-mnli", - "microsoft/deberta-xlarge-mnli", - "distilbert-base-uncased", - "distilbert-base-uncased-distilled-squad", - "distilbert-base-cased", - "distilbert-base-cased-distilled-squad", - "distilbert-base-german-cased", - "distilbert-base-multilingual-cased", - "distilbert-base-uncased-finetuned-sst-2-english", - "google/electra-small-generator", - "google/electra-base-generator", - "google/electra-large-generator", - "google/electra-small-discriminator", - "google/electra-base-discriminator", - "google/electra-large-discriminator", - "google/fnet-base", - "google/fnet-large", - "facebook/wmt19-ru-en", - "funnel-transformer/small", - "funnel-transformer/small-base", - "funnel-transformer/medium", - "funnel-transformer/medium-base", - "funnel-transformer/intermediate", - "funnel-transformer/intermediate-base", - "funnel-transformer/large", - "funnel-transformer/large-base", - "funnel-transformer/xlarge-base", - "funnel-transformer/xlarge", - "gpt2", - "gpt2-medium", - "gpt2-large", - "gpt2-xl", - "distilgpt2", - "EleutherAI/gpt-neo-1.3B", - "EleutherAI/gpt-j-6B", - "kssteven/ibert-roberta-base", - "allenai/led-base-16384", - "google/mobilebert-uncased", - "microsoft/mpnet-base", - "uw-madison/nystromformer-512", - "openai-gpt", - "google/reformer-crime-and-punishment", - "tau/splinter-base", - "tau/splinter-base-qass", - "tau/splinter-large", - "tau/splinter-large-qass", - "squeezebert/squeezebert-uncased", - "squeezebert/squeezebert-mnli", - "squeezebert/squeezebert-mnli-headless", - "transfo-xl-wt103", - "xlm-mlm-en-2048", - "xlm-mlm-ende-1024", - "xlm-mlm-enfr-1024", - "xlm-mlm-enro-1024", - "xlm-mlm-tlm-xnli15-1024", - "xlm-mlm-xnli15-1024", - "xlm-clm-enfr-1024", - "xlm-clm-ende-1024", - "xlm-mlm-17-1280", - "xlm-mlm-100-1280", - "xlm-roberta-base", - "xlm-roberta-large", - "xlm-roberta-large-finetuned-conll02-dutch", - "xlm-roberta-large-finetuned-conll02-spanish", - "xlm-roberta-large-finetuned-conll03-english", - "xlm-roberta-large-finetuned-conll03-german", - "xlnet-base-cased", - "xlnet-large-cased", - "uw-madison/yoso-4096", - ] - full_list.sort() - return full_list + def get_model_name(self): + full_list = [ + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + "google/bigbird-roberta-base", + "google/bigbird-roberta-large", + "google/bigbird-base-trivia-itc", + "albert-base-v1", + "albert-large-v1", + "albert-xlarge-v1", + "albert-xxlarge-v1", + "albert-base-v2", + "albert-large-v2", + "albert-xlarge-v2", + "albert-xxlarge-v2", + "facebook/bart-large", + "google/bert_for_seq_generation_L-24_bbc_encoder", + "google/bigbird-pegasus-large-arxiv", + "google/bigbird-pegasus-large-pubmed", + "google/bigbird-pegasus-large-bigpatent", + "google/canine-s", + "google/canine-c", + "YituTech/conv-bert-base", + "YituTech/conv-bert-medium-small", + "YituTech/conv-bert-small", + "ctrl", + "microsoft/deberta-base", + "microsoft/deberta-large", + "microsoft/deberta-xlarge", + "microsoft/deberta-base-mnli", + "microsoft/deberta-large-mnli", + "microsoft/deberta-xlarge-mnli", + "distilbert-base-uncased", + "distilbert-base-uncased-distilled-squad", + "distilbert-base-cased", + "distilbert-base-cased-distilled-squad", + "distilbert-base-german-cased", + "distilbert-base-multilingual-cased", + "distilbert-base-uncased-finetuned-sst-2-english", + "google/electra-small-generator", + "google/electra-base-generator", + "google/electra-large-generator", + "google/electra-small-discriminator", + "google/electra-base-discriminator", + "google/electra-large-discriminator", + "google/fnet-base", + "google/fnet-large", + "facebook/wmt19-ru-en", + "funnel-transformer/small", + "funnel-transformer/small-base", + "funnel-transformer/medium", + "funnel-transformer/medium-base", + "funnel-transformer/intermediate", + "funnel-transformer/intermediate-base", + "funnel-transformer/large", + "funnel-transformer/large-base", + "funnel-transformer/xlarge-base", + "funnel-transformer/xlarge", + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + "EleutherAI/gpt-neo-1.3B", + "EleutherAI/gpt-j-6B", + "kssteven/ibert-roberta-base", + "allenai/led-base-16384", + "google/mobilebert-uncased", + "microsoft/mpnet-base", + "uw-madison/nystromformer-512", + "openai-gpt", + "google/reformer-crime-and-punishment", + "tau/splinter-base", + "tau/splinter-base-qass", + "tau/splinter-large", + "tau/splinter-large-qass", + "squeezebert/squeezebert-uncased", + "squeezebert/squeezebert-mnli", + "squeezebert/squeezebert-mnli-headless", + "transfo-xl-wt103", + "xlm-mlm-en-2048", + "xlm-mlm-ende-1024", + "xlm-mlm-enfr-1024", + "xlm-mlm-enro-1024", + "xlm-mlm-tlm-xnli15-1024", + "xlm-mlm-xnli15-1024", + "xlm-clm-enfr-1024", + "xlm-clm-ende-1024", + "xlm-mlm-17-1280", + "xlm-mlm-100-1280", + "xlm-roberta-base", + "xlm-roberta-large", + "xlm-roberta-large-finetuned-conll02-dutch", + "xlm-roberta-large-finetuned-conll02-spanish", + "xlm-roberta-large-finetuned-conll03-english", + "xlm-roberta-large-finetuned-conll03-german", + "xlnet-base-cased", + "xlnet-large-cased", + "uw-madison/yoso-4096", + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", + "flaubert/flaubert_small_cased", + "flaubert/flaubert_base_uncased", + "flaubert/flaubert_base_cased", + "flaubert/flaubert_large_cased", + "camembert-base", + "Musixmatch/umberto-commoncrawl-cased-v1", + "Musixmatch/umberto-wikipedia-uncased-v1", + ] + full_list.sort() + return full_list