logo
Browse Source

Debug

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
7082875b8b
  1. 1
      README.md
  2. 18
      auto_transformers.py

1
README.md

@ -356,3 +356,4 @@ Get a list of all supported model names or supported model names for specified m
***format***: *str*
​ The model format such as 'pytorch', 'torchscript'.

18
auto_transformers.py

@ -45,7 +45,7 @@ class AutoTransformers(NNOperator):
self.model = AutoModel.from_pretrained(model_name)
self.model.eval()
except Exception as e:
model_list = self.get_model_name()
model_list = self.supported_model_names()
if model_name not in model_list:
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
else:
@ -103,7 +103,7 @@ class AutoTransformers(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
model_list = [
full_list = [
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
@ -231,13 +231,15 @@ class AutoTransformers(NNOperator):
"Musixmatch/umberto-commoncrawl-cased-v1",
"Musixmatch/umberto-wikipedia-uncased-v1",
]
model_list.sort()
full_list.sort()
if format is None:
pass
model_list = full_list
elif format == 'pytorch':
pass
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'torchscript':
model_list.remove(
to_remove = [
'EleutherAI/gpt-j-6B',
'EleutherAI/gpt-neo-1.3B',
'allenai/led-base-16384',
@ -283,7 +285,9 @@ class AutoTransformers(NNOperator):
'flaubert/flaubert_base_uncased',
'flaubert/flaubert_base_cased',
'flaubert/flaubert_large_cased'
)
]
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
else: # todo: format in {'onnx', 'tensorrt'}
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
return model_list

Loading…
Cancel
Save