|
@ -15,10 +15,13 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Auto |
|
|
|
|
|
|
|
|
@accelerate |
|
|
@accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, config, device): |
|
|
|
|
|
|
|
|
def __init__(self, model_name, checkpoint_path, config, device): |
|
|
self.device = device |
|
|
self.device = device |
|
|
self.config = config |
|
|
self.config = config |
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) |
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) |
|
|
|
|
|
if checkpoint_path: |
|
|
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
|
|
self.model.load_state_dict(state_dict) |
|
|
self.model.to(self.device) |
|
|
self.model.to(self.device) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
|
|
|
|
|
@ -35,12 +38,12 @@ class Model: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReRank(NNOperator): |
|
|
class ReRank(NNOperator): |
|
|
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512): |
|
|
|
|
|
|
|
|
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self._model_name = model_name |
|
|
self._model_name = model_name |
|
|
self.config = AutoConfig.from_pretrained(model_name) |
|
|
self.config = AutoConfig.from_pretrained(model_name) |
|
|
self.device = device |
|
|
self.device = device |
|
|
self.model = Model(model_name, self.config, device) |
|
|
|
|
|
|
|
|
self.model = Model(model_name, checkpoint_path, self.config, device) |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
self.max_length = max_length |
|
|
self.max_length = max_length |
|
|
self._threshold = threshold |
|
|
self._threshold = threshold |
|
|