|
|
@ -45,7 +45,7 @@ class AutoTransformers(NNOperator): |
|
|
|
self.model = AutoModel.from_pretrained(model_name) |
|
|
|
self.model.eval() |
|
|
|
except Exception as e: |
|
|
|
model_list = self.get_model_name() |
|
|
|
model_list = self.supported_model_names() |
|
|
|
if model_name not in model_list: |
|
|
|
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") |
|
|
|
else: |
|
|
@ -103,7 +103,7 @@ class AutoTransformers(NNOperator): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
model_list = [ |
|
|
|
full_list = [ |
|
|
|
"bert-large-uncased", |
|
|
|
"bert-base-cased", |
|
|
|
"bert-large-cased", |
|
|
@ -231,13 +231,15 @@ class AutoTransformers(NNOperator): |
|
|
|
"Musixmatch/umberto-commoncrawl-cased-v1", |
|
|
|
"Musixmatch/umberto-wikipedia-uncased-v1", |
|
|
|
] |
|
|
|
model_list.sort() |
|
|
|
full_list.sort() |
|
|
|
if format is None: |
|
|
|
pass |
|
|
|
model_list = full_list |
|
|
|
elif format == 'pytorch': |
|
|
|
pass |
|
|
|
to_remove = [] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'torchscript': |
|
|
|
model_list.remove( |
|
|
|
to_remove = [ |
|
|
|
'EleutherAI/gpt-j-6B', |
|
|
|
'EleutherAI/gpt-neo-1.3B', |
|
|
|
'allenai/led-base-16384', |
|
|
@ -283,7 +285,9 @@ class AutoTransformers(NNOperator): |
|
|
|
'flaubert/flaubert_base_uncased', |
|
|
|
'flaubert/flaubert_base_cased', |
|
|
|
'flaubert/flaubert_large_cased' |
|
|
|
) |
|
|
|
] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
else: # todo: format in {'onnx', 'tensorrt'} |
|
|
|
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') |
|
|
|
return model_list |
|
|
|