"""Cross-encoder re-ranking operator.

Scores (query, doc) pairs with a HuggingFace sequence-classification
cross-encoder, sorts docs by score, and optionally filters by a threshold.
The model is wrapped in a small ``Model`` class so towhee's ``accelerate``
decorator can optimize it when towhee is available.
"""
import os
from pathlib import Path
from typing import List
from functools import partial

import torch
try:
    from towhee import accelerate
except ImportError:  # towhee acceleration is optional; fall back to a no-op decorator
    def accelerate(func):
        return func
from towhee.operator import NNOperator
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig


@accelerate
class Model:
    """Thin inference wrapper around a sequence-classification model.

    Applies sigmoid for single-label (regression-style) heads and softmax
    over the label dimension otherwise, mirroring sentence-transformers'
    CrossEncoder default activation behavior.
    """

    def __init__(self, model_name, config, device):
        self.device = device
        self.config = config
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.model.to(self.device)
        self.model.eval()
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            # Multi-label head: normalize logits across labels.
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, **features):
        """Run a no-grad forward pass and return activated scores."""
        with torch.no_grad():
            logits = self.model(**features, return_dict=True).logits
            scores = self.activation_fct(logits)
        return scores


class ReRank(NNOperator):
    """Re-rank candidate documents for a query with a cross-encoder.

    Args:
        model_name: HuggingFace model id of the cross-encoder.
        threshold: minimum score to keep a doc; ``None`` keeps all docs.
        device: torch device string (``None`` leaves tensors where they are).
        max_length: tokenizer truncation length for each (query, doc) pair.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6,
                 device: str = None, max_length=512):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold

    def __call__(self, query: str, docs: List):
        """Return (re_docs, re_scores): docs sorted by descending score.

        Docs scoring below ``self._threshold`` are dropped unless the
        threshold is ``None``. Empty input short-circuits to ``([], [])``.
        """
        if len(docs) == 0:
            return [], []

        # Column-ize the (query, doc) pairs: texts[0] = queries, texts[1] = docs.
        batch = [(query, doc) for doc in docs]
        texts = [[] for _ in range(len(batch[0]))]
        for example in batch:
            for idx, text in enumerate(example):
                texts[idx].append(text.strip())

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first',
                                   return_tensors="pt", max_length=self.max_length)
        for name in tokenized:
            tokenized[name] = tokenized[name].to(self.device)

        scores = self.model(**tokenized).detach().cpu().numpy()
        if self.config.num_labels == 1:
            # Single sigmoid output per pair.
            scores = [score[0] for score in scores]
        else:
            # Softmax head: probability of the "relevant" label (column 1).
            scores = list(scores[:, 1])

        # Indices sorted by descending score.
        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        if self._threshold is None:
            re_docs = [docs[i] for i in re_ids]
            re_scores = [scores[i] for i in re_ids]
        else:
            re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
            re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
        return re_docs, re_scores

    @property
    def _model(self):
        # Unwrap the raw transformers model (used by save_model and tests).
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Serialize the underlying model to ``<path>/saved/<format>/<name>``.

        Args:
            format: 'pytorch' (torch.save of the full model) or 'onnx'
                (transformers ONNX export, opset 13).
            path: base directory; 'default' means this file's directory.

        Returns:
            Resolved Path of the written model file.

        Raises:
            AttributeError: if ``format`` is not 'pytorch' or 'onnx'.
        """
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self._model_name.replace('/', '-')
        path = os.path.join(path, name)
        if format == 'pytorch':
            path = path + '.pt'
        elif format == 'onnx':
            path = path + '.onnx'
        else:
            raise AttributeError(f'Invalid format {format}.')

        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'onnx':
            # Imported lazily: the ONNX export stack is only needed here.
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export
            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
            onnx_inputs, onnx_outputs = export(
                self.tokenizer,
                self._model,
                config=onnx_config,
                opset=13,
                output=Path(path)
            )
        return Path(path).resolve()