From 93ed7f5e5e5ec10d7496f9af7606203651f4ac65 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 26 Dec 2022 16:12:34 +0800 Subject: [PATCH] Fix model name issue Signed-off-by: Jael Gu --- auto_transformers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index 6fca0b8..4250a40 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -81,12 +81,11 @@ class AutoTransformers(NNOperator): if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device - - model_list = self.supported_model_names() - assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" self.model_name = model_name if self.model_name: + model_list = self.supported_model_names() + assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" self.model = Model( model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) if tokenizer is None: