From 37eadfa85822be55dd5d70558cec94e9eeddc66f Mon Sep 17 00:00:00 2001 From: ChengZi Date: Tue, 25 Jul 2023 16:50:43 +0800 Subject: [PATCH] optimizer for supporting local checkpoint_path Signed-off-by: ChengZi --- rerank.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rerank.py b/rerank.py index 05e4428..510e81a 100644 --- a/rerank.py +++ b/rerank.py @@ -18,10 +18,10 @@ class Model: def __init__(self, model_name, checkpoint_path, config, device): self.device = device self.config = config - self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) if checkpoint_path: - state_dict = torch.load(checkpoint_path, map_location=device) - self.model.load_state_dict(state_dict) + self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, config=self.config) + else: + self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) self.model.to(self.device) self.model.eval()