From 862b980643be6c8391738eaca057064e5636982d Mon Sep 17 00:00:00 2001 From: ChengZi Date: Mon, 24 Jul 2023 17:58:40 +0800 Subject: [PATCH] support local checkpoint_path Signed-off-by: ChengZi --- rerank.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/rerank.py b/rerank.py index 3074313..05e4428 100644 --- a/rerank.py +++ b/rerank.py @@ -15,10 +15,13 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Auto @accelerate class Model: - def __init__(self, model_name, config, device): + def __init__(self, model_name, checkpoint_path, config, device): self.device = device self.config = config self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) + if checkpoint_path: + state_dict = torch.load(checkpoint_path, map_location=device) + self.model.load_state_dict(state_dict) self.model.to(self.device) self.model.eval() @@ -35,12 +38,12 @@ class Model: class ReRank(NNOperator): - def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512): + def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None): super().__init__() self._model_name = model_name self.config = AutoConfig.from_pretrained(model_name) self.device = device - self.model = Model(model_name, self.config, device) + self.model = Model(model_name, checkpoint_path, self.config, device) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_length = max_length self._threshold = threshold