logo
rerank
repo-copy-icon

copied

Browse Source

support local checkpoint_path

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
main
ChengZi 2 years ago
parent
commit
862b980643
  1. 9
      rerank.py

9
rerank.py

@ -15,10 +15,13 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Auto
@accelerate @accelerate
class Model: class Model:
def __init__(self, model_name, config, device):
def __init__(self, model_name, checkpoint_path, config, device):
self.device = device self.device = device
self.config = config self.config = config
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
if checkpoint_path:
state_dict = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(state_dict)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
@ -35,12 +38,12 @@ class Model:
class ReRank(NNOperator): class ReRank(NNOperator):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None):
super().__init__() super().__init__()
self._model_name = model_name self._model_name = model_name
self.config = AutoConfig.from_pretrained(model_name) self.config = AutoConfig.from_pretrained(model_name)
self.device = device self.device = device
self.model = Model(model_name, self.config, device)
self.model = Model(model_name, checkpoint_path, self.config, device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.max_length = max_length self.max_length = max_length
self._threshold = threshold self._threshold = threshold

Loading…
Cancel
Save