|
|
@ -80,10 +80,11 @@ class AutoTransformers(NNOperator): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
|
name = self.model_name.replace('/', '-') |
|
|
|
path = os.path.join(path, name + '.pt') |
|
|
|
path = os.path.join(path, name) |
|
|
|
inputs = self.tokenizer('[CLS]', return_tensors='pt') |
|
|
|
inputs = list(inputs.values()) |
|
|
|
if format == 'torchscript': |
|
|
|
path = path + '.pt' |
|
|
|
inputs = list(inputs.values()) |
|
|
|
try: |
|
|
|
try: |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
@ -93,11 +94,12 @@ class AutoTransformers(NNOperator): |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onxx': |
|
|
|
pass |
|
|
|
else: |
|
|
|
torch.save(self.model, path) |
|
|
|
|
|
|
|
|
|
|
|
def get_model_list(): |
|
|
|
def get_model_name(self): |
|
|
|
full_list = [ |
|
|
|
"bert-large-uncased", |
|
|
|
"bert-base-cased", |
|
|
@ -214,6 +216,17 @@ def get_model_list(): |
|
|
|
"xlnet-base-cased", |
|
|
|
"xlnet-large-cased", |
|
|
|
"uw-madison/yoso-4096", |
|
|
|
"microsoft/deberta-v2-xlarge", |
|
|
|
"microsoft/deberta-v2-xxlarge", |
|
|
|
"microsoft/deberta-v2-xlarge-mnli", |
|
|
|
"microsoft/deberta-v2-xxlarge-mnli", |
|
|
|
"flaubert/flaubert_small_cased", |
|
|
|
"flaubert/flaubert_base_uncased", |
|
|
|
"flaubert/flaubert_base_cased", |
|
|
|
"flaubert/flaubert_large_cased", |
|
|
|
"camembert-base", |
|
|
|
"Musixmatch/umberto-commoncrawl-cased-v1", |
|
|
|
"Musixmatch/umberto-wikipedia-uncased-v1", |
|
|
|
] |
|
|
|
full_list.sort() |
|
|
|
return full_list |
|
|
|