diff --git a/rerank.py b/rerank.py index 05e4428..510e81a 100644 --- a/rerank.py +++ b/rerank.py @@ -18,10 +18,10 @@ class Model: def __init__(self, model_name, checkpoint_path, config, device): self.device = device self.config = config - self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) if checkpoint_path: - state_dict = torch.load(checkpoint_path, map_location=device) - self.model.load_state_dict(state_dict) + self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, config=self.config) + else: + self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) self.model.to(self.device) self.model.eval()