|
|
@ -76,13 +76,15 @@ class AutoTransformers(NNOperator): |
|
|
|
vec = features.detach().numpy() |
|
|
|
return vec |
|
|
|
|
|
|
|
def save_model(self, format: str = 'default', path: str = 'default'): |
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
|
name = self.model_name.replace('/', '-') |
|
|
|
path = os.path.join(path, name) |
|
|
|
inputs = self.tokenizer('[CLS]', return_tensors='pt') |
|
|
|
if format == 'torchscript': |
|
|
|
if format == 'pytorch': |
|
|
|
torch.save(self.model, path) |
|
|
|
elif format == 'torchscript': |
|
|
|
path = path + '.pt' |
|
|
|
inputs = list(inputs.values()) |
|
|
|
try: |
|
|
@ -95,12 +97,13 @@ class AutoTransformers(NNOperator): |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onxx': |
|
|
|
pass |
|
|
|
pass # todo |
|
|
|
else: |
|
|
|
torch.save(self.model, path) |
|
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|
|
|
|
|
def get_model_name(self): |
|
|
|
full_list = [ |
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
model_list = [ |
|
|
|
"bert-large-uncased", |
|
|
|
"bert-base-cased", |
|
|
|
"bert-large-cased", |
|
|
@ -228,5 +231,59 @@ class AutoTransformers(NNOperator): |
|
|
|
"Musixmatch/umberto-commoncrawl-cased-v1", |
|
|
|
"Musixmatch/umberto-wikipedia-uncased-v1", |
|
|
|
] |
|
|
|
full_list.sort() |
|
|
|
return full_list |
|
|
|
model_list.sort() |
|
|
|
if format is None: |
|
|
|
pass |
|
|
|
elif format == 'pytorch': |
|
|
|
pass |
|
|
|
elif format == 'torchscript': |
|
|
|
model_list.remove( |
|
|
|
'EleutherAI/gpt-j-6B', |
|
|
|
'EleutherAI/gpt-neo-1.3B', |
|
|
|
'allenai/led-base-16384', |
|
|
|
'ctrl', |
|
|
|
'distilgpt2', |
|
|
|
'facebook/bart-large', |
|
|
|
'google/bigbird-pegasus-large-arxiv', |
|
|
|
'google/bigbird-pegasus-large-bigpatent', |
|
|
|
'google/bigbird-pegasus-large-pubmed', |
|
|
|
'google/canine-c', |
|
|
|
'google/canine-s', |
|
|
|
'google/reformer-crime-and-punishment', |
|
|
|
'gpt2', |
|
|
|
'gpt2-large', |
|
|
|
'gpt2-medium', |
|
|
|
'gpt2-xl', |
|
|
|
'microsoft/deberta-base', |
|
|
|
'microsoft/deberta-base-mnli', |
|
|
|
'microsoft/deberta-large', |
|
|
|
'microsoft/deberta-large-mnli', |
|
|
|
'microsoft/deberta-xlarge', |
|
|
|
'microsoft/deberta-xlarge-mnli', |
|
|
|
'openai-gpt', |
|
|
|
'transfo-xl-wt103', |
|
|
|
'uw-madison/yoso-4096', |
|
|
|
'xlm-clm-ende-1024', |
|
|
|
'xlm-clm-enfr-1024', |
|
|
|
'xlm-mlm-100-1280', |
|
|
|
'xlm-mlm-17-1280', |
|
|
|
'xlm-mlm-en-2048', |
|
|
|
'xlm-mlm-ende-1024', |
|
|
|
'xlm-mlm-enfr-1024', |
|
|
|
'xlm-mlm-enro-1024', |
|
|
|
'xlm-mlm-tlm-xnli15-1024', |
|
|
|
'xlm-mlm-xnli15-1024', |
|
|
|
'xlnet-base-cased', |
|
|
|
'xlnet-large-cased', |
|
|
|
'microsoft/deberta-v2-xlarge', |
|
|
|
'microsoft/deberta-v2-xxlarge', |
|
|
|
'microsoft/deberta-v2-xlarge-mnli', |
|
|
|
'microsoft/deberta-v2-xxlarge-mnli', |
|
|
|
'flaubert/flaubert_small_cased', |
|
|
|
'flaubert/flaubert_base_uncased', |
|
|
|
'flaubert/flaubert_base_cased', |
|
|
|
'flaubert/flaubert_large_cased' |
|
|
|
) |
|
|
|
else: # todo: format in {'onnx', 'tensorrt'} |
|
|
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') |
|
|
|
return model_list |
|
|
|