diff --git a/README.md b/README.md index 6b45e09..52f72db 100644 --- a/README.md +++ b/README.md @@ -356,3 +356,4 @@ Get a list of all supported model names or supported model names for specified m ***format***: *str* ​ The model format such as 'pytorch', 'torchscript'. + diff --git a/auto_transformers.py b/auto_transformers.py index ce5b15c..90dc590 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -45,7 +45,7 @@ class AutoTransformers(NNOperator): self.model = AutoModel.from_pretrained(model_name) self.model.eval() except Exception as e: - model_list = self.get_model_name() + model_list = self.supported_model_names() if model_name not in model_list: log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") else: @@ -103,7 +103,7 @@ class AutoTransformers(NNOperator): @staticmethod def supported_model_names(format: str = None): - model_list = [ + full_list = [ "bert-large-uncased", "bert-base-cased", "bert-large-cased", @@ -231,13 +231,15 @@ class AutoTransformers(NNOperator): "Musixmatch/umberto-commoncrawl-cased-v1", "Musixmatch/umberto-wikipedia-uncased-v1", ] - model_list.sort() + full_list.sort() if format is None: - pass + model_list = full_list elif format == 'pytorch': - pass + to_remove = [] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) elif format == 'torchscript': - model_list.remove( + to_remove = [ 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neo-1.3B', 'allenai/led-base-16384', @@ -283,7 +285,9 @@ class AutoTransformers(NNOperator): 'flaubert/flaubert_base_uncased', 'flaubert/flaubert_base_cased', 'flaubert/flaubert_large_cased' - ) + ] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) else: # todo: format in {'onnx', 'tensorrt'} log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') return model_list