logo
Browse Source

Update model list

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
0bd1bcc127
  1. 2
      README.md
  2. 21
      auto_transformers.py

2
README.md

@ -64,6 +64,7 @@ Supported model names:
- bert-base-cased - bert-base-cased
- bert-large-cased - bert-large-cased
- bert-large-uncased
- bert-base-multilingual-uncased - bert-base-multilingual-uncased
- bert-base-multilingual-cased - bert-base-multilingual-cased
- bert-base-chinese - bert-base-chinese
@ -181,6 +182,7 @@ Supported model names:
</details> </details>
<details><summary>Funnel</summary> <details><summary>Funnel</summary>
- funnel-transformer/small - funnel-transformer/small
- funnel-transformer/small-base - funnel-transformer/small-base
- funnel-transformer/medium - funnel-transformer/medium

21
auto_transformers.py

@ -80,10 +80,11 @@ class AutoTransformers(NNOperator):
if path == 'default': if path == 'default':
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
name = self.model_name.replace('/', '-') name = self.model_name.replace('/', '-')
path = os.path.join(path, name + '.pt')
path = os.path.join(path, name)
inputs = self.tokenizer('[CLS]', return_tensors='pt') inputs = self.tokenizer('[CLS]', return_tensors='pt')
inputs = list(inputs.values())
if format == 'torchscript': if format == 'torchscript':
path = path + '.pt'
inputs = list(inputs.values())
try: try:
try: try:
jit_model = torch.jit.script(self.model) jit_model = torch.jit.script(self.model)
@ -93,11 +94,12 @@ class AutoTransformers(NNOperator):
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onxx':
pass
else: else:
torch.save(self.model, path) torch.save(self.model, path)
def get_model_list():
def get_model_name(self):
full_list = [ full_list = [
"bert-large-uncased", "bert-large-uncased",
"bert-base-cased", "bert-base-cased",
@ -214,6 +216,17 @@ def get_model_list():
"xlnet-base-cased", "xlnet-base-cased",
"xlnet-large-cased", "xlnet-large-cased",
"uw-madison/yoso-4096", "uw-madison/yoso-4096",
"microsoft/deberta-v2-xlarge",
"microsoft/deberta-v2-xxlarge",
"microsoft/deberta-v2-xlarge-mnli",
"microsoft/deberta-v2-xxlarge-mnli",
"flaubert/flaubert_small_cased",
"flaubert/flaubert_base_uncased",
"flaubert/flaubert_base_cased",
"flaubert/flaubert_large_cased",
"camembert-base",
"Musixmatch/umberto-commoncrawl-cased-v1",
"Musixmatch/umberto-wikipedia-uncased-v1",
] ]
full_list.sort() full_list.sort()
return full_list return full_list

Loading…
Cancel
Save