logo
Browse Source

Add supported_model_names

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
2bc56f2cd6
  1. 29
      README.md
  2. 73
      auto_transformers.py

29
README.md

@ -316,10 +316,11 @@ The operator takes a piece of text in string as input.
It loads tokenizer and pre-trained model using model name.
and then return text embedding in ndarray.
***__call__(txt)***
**Parameters:**
***text***: *str*
***txt***: *str*
​ The text in string.
@ -329,3 +330,29 @@ and then return text embedding in ndarray.
*numpy.ndarray*
​ The text embedding extracted by model.
***save_model(format='pytorch', path='default')***
Save model to local with specified format.
**Parameters:**
***format***: *str*
​ The format of saved model, defaults to 'pytorch'.
***format***: *path*
​ The path where model is saved to. By default, it will save model to the operator directory.
***supported_model_names(format=None)***
Get a list of all supported model names or supported model names for specified model format.
**Parameters:**
***format***: *str*
​ The model format such as 'pytorch', 'torchscript'.

73
auto_transformers.py

@ -76,13 +76,15 @@ class AutoTransformers(NNOperator):
vec = features.detach().numpy()
return vec
def save_model(self, format: str = 'default', path: str = 'default'):
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
inputs = self.tokenizer('[CLS]', return_tensors='pt')
if format == 'torchscript':
if format == 'pytorch':
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
inputs = list(inputs.values())
try:
@ -95,12 +97,13 @@ class AutoTransformers(NNOperator):
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onxx':
pass
pass # todo
else:
torch.save(self.model, path)
log.error(f'Unsupported format "{format}".')
def get_model_name(self):
full_list = [
@staticmethod
def supported_model_names(format: str = None):
model_list = [
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
@ -228,5 +231,59 @@ class AutoTransformers(NNOperator):
"Musixmatch/umberto-commoncrawl-cased-v1",
"Musixmatch/umberto-wikipedia-uncased-v1",
]
full_list.sort()
return full_list
model_list.sort()
if format is None:
pass
elif format == 'pytorch':
pass
elif format == 'torchscript':
model_list.remove(
'EleutherAI/gpt-j-6B',
'EleutherAI/gpt-neo-1.3B',
'allenai/led-base-16384',
'ctrl',
'distilgpt2',
'facebook/bart-large',
'google/bigbird-pegasus-large-arxiv',
'google/bigbird-pegasus-large-bigpatent',
'google/bigbird-pegasus-large-pubmed',
'google/canine-c',
'google/canine-s',
'google/reformer-crime-and-punishment',
'gpt2',
'gpt2-large',
'gpt2-medium',
'gpt2-xl',
'microsoft/deberta-base',
'microsoft/deberta-base-mnli',
'microsoft/deberta-large',
'microsoft/deberta-large-mnli',
'microsoft/deberta-xlarge',
'microsoft/deberta-xlarge-mnli',
'openai-gpt',
'transfo-xl-wt103',
'uw-madison/yoso-4096',
'xlm-clm-ende-1024',
'xlm-clm-enfr-1024',
'xlm-mlm-100-1280',
'xlm-mlm-17-1280',
'xlm-mlm-en-2048',
'xlm-mlm-ende-1024',
'xlm-mlm-enfr-1024',
'xlm-mlm-enro-1024',
'xlm-mlm-tlm-xnli15-1024',
'xlm-mlm-xnli15-1024',
'xlnet-base-cased',
'xlnet-large-cased',
'microsoft/deberta-v2-xlarge',
'microsoft/deberta-v2-xxlarge',
'microsoft/deberta-v2-xlarge-mnli',
'microsoft/deberta-v2-xxlarge-mnli',
'flaubert/flaubert_small_cased',
'flaubert/flaubert_base_uncased',
'flaubert/flaubert_base_cased',
'flaubert/flaubert_large_cased'
)
else: # todo: format in {'onnx', 'tensorrt'}
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list

Loading…
Cancel
Save