|
|
@ -80,10 +80,11 @@ class AutoTransformers(NNOperator): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
|
name = self.model_name.replace('/', '-') |
|
|
|
path = os.path.join(path, name + '.pt') |
|
|
|
path = os.path.join(path, name) |
|
|
|
inputs = self.tokenizer('[CLS]', return_tensors='pt') |
|
|
|
inputs = list(inputs.values()) |
|
|
|
if format == 'torchscript': |
|
|
|
path = path + '.pt' |
|
|
|
inputs = list(inputs.values()) |
|
|
|
try: |
|
|
|
try: |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
@ -93,127 +94,139 @@ class AutoTransformers(NNOperator): |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onxx': |
|
|
|
pass |
|
|
|
else: |
|
|
|
torch.save(self.model, path) |
|
|
|
|
|
|
|
|
|
|
|
def get_model_list():
    """Return the hard-coded list of model names, sorted ascending.

    The names look like Hugging Face Hub model identifiers (some carry an
    ``org/`` prefix); that is an assumption, so confirm against the caller.

    Returns:
        A new list of strings on every call, ordered by Python's default
        string comparison, so names starting with an uppercase letter come
        before lowercase ones.
    """
    # sorted() builds a fresh list, so callers may mutate the result freely.
    return sorted([
        "bert-large-uncased",
        "bert-base-cased",
        "bert-large-cased",
        "bert-base-multilingual-uncased",
        "bert-base-multilingual-cased",
        "bert-base-chinese",
        "bert-base-german-cased",
        "bert-large-uncased-whole-word-masking",
        "bert-large-cased-whole-word-masking",
        "bert-large-uncased-whole-word-masking-finetuned-squad",
        "bert-large-cased-whole-word-masking-finetuned-squad",
        "bert-base-cased-finetuned-mrpc",
        "bert-base-german-dbmdz-cased",
        "bert-base-german-dbmdz-uncased",
        "cl-tohoku/bert-base-japanese-whole-word-masking",
        "cl-tohoku/bert-base-japanese-char",
        "cl-tohoku/bert-base-japanese-char-whole-word-masking",
        "TurkuNLP/bert-base-finnish-cased-v1",
        "TurkuNLP/bert-base-finnish-uncased-v1",
        "wietsedv/bert-base-dutch-cased",
        "google/bigbird-roberta-base",
        "google/bigbird-roberta-large",
        "google/bigbird-base-trivia-itc",
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "facebook/bart-large",
        "google/bert_for_seq_generation_L-24_bbc_encoder",
        "google/bigbird-pegasus-large-arxiv",
        "google/bigbird-pegasus-large-pubmed",
        "google/bigbird-pegasus-large-bigpatent",
        "google/canine-s",
        "google/canine-c",
        "YituTech/conv-bert-base",
        "YituTech/conv-bert-medium-small",
        "YituTech/conv-bert-small",
        "ctrl",
        "microsoft/deberta-base",
        "microsoft/deberta-large",
        "microsoft/deberta-xlarge",
        "microsoft/deberta-base-mnli",
        "microsoft/deberta-large-mnli",
        "microsoft/deberta-xlarge-mnli",
        "distilbert-base-uncased",
        "distilbert-base-uncased-distilled-squad",
        "distilbert-base-cased",
        "distilbert-base-cased-distilled-squad",
        "distilbert-base-german-cased",
        "distilbert-base-multilingual-cased",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "google/electra-small-generator",
        "google/electra-base-generator",
        "google/electra-large-generator",
        "google/electra-small-discriminator",
        "google/electra-base-discriminator",
        "google/electra-large-discriminator",
        "google/fnet-base",
        "google/fnet-large",
        "facebook/wmt19-ru-en",
        "funnel-transformer/small",
        "funnel-transformer/small-base",
        "funnel-transformer/medium",
        "funnel-transformer/medium-base",
        "funnel-transformer/intermediate",
        "funnel-transformer/intermediate-base",
        "funnel-transformer/large",
        "funnel-transformer/large-base",
        "funnel-transformer/xlarge-base",
        "funnel-transformer/xlarge",
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "distilgpt2",
        "EleutherAI/gpt-neo-1.3B",
        "EleutherAI/gpt-j-6B",
        "kssteven/ibert-roberta-base",
        "allenai/led-base-16384",
        "google/mobilebert-uncased",
        "microsoft/mpnet-base",
        "uw-madison/nystromformer-512",
        "openai-gpt",
        "google/reformer-crime-and-punishment",
        "tau/splinter-base",
        "tau/splinter-base-qass",
        "tau/splinter-large",
        "tau/splinter-large-qass",
        "squeezebert/squeezebert-uncased",
        "squeezebert/squeezebert-mnli",
        "squeezebert/squeezebert-mnli-headless",
        "transfo-xl-wt103",
        "xlm-mlm-en-2048",
        "xlm-mlm-ende-1024",
        "xlm-mlm-enfr-1024",
        "xlm-mlm-enro-1024",
        "xlm-mlm-tlm-xnli15-1024",
        "xlm-mlm-xnli15-1024",
        "xlm-clm-enfr-1024",
        "xlm-clm-ende-1024",
        "xlm-mlm-17-1280",
        "xlm-mlm-100-1280",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "uw-madison/yoso-4096",
    ])
|
|
|
def get_model_name(self):
    """Return the hard-coded list of model names, sorted ascending.

    Despite the singular name, this returns the whole list rather than one
    name. ``self`` is accepted but never read. The names look like Hugging
    Face Hub model identifiers (some carry an ``org/`` prefix); that is an
    assumption, so confirm against the caller.

    Returns:
        A new list of strings on every call, ordered by Python's default
        string comparison, so names starting with an uppercase letter come
        before lowercase ones.
    """
    # sorted() builds a fresh list, so callers may mutate the result freely.
    return sorted([
        "bert-large-uncased",
        "bert-base-cased",
        "bert-large-cased",
        "bert-base-multilingual-uncased",
        "bert-base-multilingual-cased",
        "bert-base-chinese",
        "bert-base-german-cased",
        "bert-large-uncased-whole-word-masking",
        "bert-large-cased-whole-word-masking",
        "bert-large-uncased-whole-word-masking-finetuned-squad",
        "bert-large-cased-whole-word-masking-finetuned-squad",
        "bert-base-cased-finetuned-mrpc",
        "bert-base-german-dbmdz-cased",
        "bert-base-german-dbmdz-uncased",
        "cl-tohoku/bert-base-japanese-whole-word-masking",
        "cl-tohoku/bert-base-japanese-char",
        "cl-tohoku/bert-base-japanese-char-whole-word-masking",
        "TurkuNLP/bert-base-finnish-cased-v1",
        "TurkuNLP/bert-base-finnish-uncased-v1",
        "wietsedv/bert-base-dutch-cased",
        "google/bigbird-roberta-base",
        "google/bigbird-roberta-large",
        "google/bigbird-base-trivia-itc",
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "facebook/bart-large",
        "google/bert_for_seq_generation_L-24_bbc_encoder",
        "google/bigbird-pegasus-large-arxiv",
        "google/bigbird-pegasus-large-pubmed",
        "google/bigbird-pegasus-large-bigpatent",
        "google/canine-s",
        "google/canine-c",
        "YituTech/conv-bert-base",
        "YituTech/conv-bert-medium-small",
        "YituTech/conv-bert-small",
        "ctrl",
        "microsoft/deberta-base",
        "microsoft/deberta-large",
        "microsoft/deberta-xlarge",
        "microsoft/deberta-base-mnli",
        "microsoft/deberta-large-mnli",
        "microsoft/deberta-xlarge-mnli",
        "distilbert-base-uncased",
        "distilbert-base-uncased-distilled-squad",
        "distilbert-base-cased",
        "distilbert-base-cased-distilled-squad",
        "distilbert-base-german-cased",
        "distilbert-base-multilingual-cased",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "google/electra-small-generator",
        "google/electra-base-generator",
        "google/electra-large-generator",
        "google/electra-small-discriminator",
        "google/electra-base-discriminator",
        "google/electra-large-discriminator",
        "google/fnet-base",
        "google/fnet-large",
        "facebook/wmt19-ru-en",
        "funnel-transformer/small",
        "funnel-transformer/small-base",
        "funnel-transformer/medium",
        "funnel-transformer/medium-base",
        "funnel-transformer/intermediate",
        "funnel-transformer/intermediate-base",
        "funnel-transformer/large",
        "funnel-transformer/large-base",
        "funnel-transformer/xlarge-base",
        "funnel-transformer/xlarge",
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "distilgpt2",
        "EleutherAI/gpt-neo-1.3B",
        "EleutherAI/gpt-j-6B",
        "kssteven/ibert-roberta-base",
        "allenai/led-base-16384",
        "google/mobilebert-uncased",
        "microsoft/mpnet-base",
        "uw-madison/nystromformer-512",
        "openai-gpt",
        "google/reformer-crime-and-punishment",
        "tau/splinter-base",
        "tau/splinter-base-qass",
        "tau/splinter-large",
        "tau/splinter-large-qass",
        "squeezebert/squeezebert-uncased",
        "squeezebert/squeezebert-mnli",
        "squeezebert/squeezebert-mnli-headless",
        "transfo-xl-wt103",
        "xlm-mlm-en-2048",
        "xlm-mlm-ende-1024",
        "xlm-mlm-enfr-1024",
        "xlm-mlm-enro-1024",
        "xlm-mlm-tlm-xnli15-1024",
        "xlm-mlm-xnli15-1024",
        "xlm-clm-enfr-1024",
        "xlm-clm-ende-1024",
        "xlm-mlm-17-1280",
        "xlm-mlm-100-1280",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "uw-madison/yoso-4096",
        "microsoft/deberta-v2-xlarge",
        "microsoft/deberta-v2-xxlarge",
        "microsoft/deberta-v2-xlarge-mnli",
        "microsoft/deberta-v2-xxlarge-mnli",
        "flaubert/flaubert_small_cased",
        "flaubert/flaubert_base_uncased",
        "flaubert/flaubert_base_cased",
        "flaubert/flaubert_large_cased",
        "camembert-base",
        "Musixmatch/umberto-commoncrawl-cased-v1",
        "Musixmatch/umberto-wikipedia-uncased-v1",
    ])
|
|
|