logo
rerank
repo-copy-icon

copied

You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

116 lines
4.0 KiB

import os
from pathlib import Path
from typing import List
from functools import partial
import torch
# `accelerate` is optional: when it cannot be imported from towhee, fall back
# to a decorator that leaves the decorated object untouched.
try:
    from towhee import accelerate
except ImportError:
    def accelerate(func):
        """No-op stand-in for `towhee.accelerate`; returns `func` unchanged."""
        return func
from towhee.operator import NNOperator
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
@accelerate
class Model:
    """Pretrained sequence-classification model plus its score activation.

    Args:
        model_name: identifier passed to
            `AutoModelForSequenceClassification.from_pretrained`.
        config: model configuration; its `num_labels` selects the activation.
        device: target passed to the model's `.to(...)`.
    """

    def __init__(self, model_name, config, device):
        self.device = device
        self.config = config

        network = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
        network.to(device)
        network.eval()
        self.model = network

        # One label -> sigmoid on the single logit; otherwise softmax over
        # the label dimension (dim=1).
        single_logit = config.num_labels == 1
        self.activation_fct = torch.sigmoid if single_logit else partial(torch.softmax, dim=1)

    def __call__(self, **features):
        """Run the model on `features` without gradients and return activated logits."""
        with torch.no_grad():
            output = self.model(**features, return_dict=True)
            return self.activation_fct(output.logits)
class ReRank(NNOperator):
    """Operator that scores (query, document) pairs and reorders the documents.

    Args:
        model_name: identifier used to load the config, model and tokenizer.
        threshold: minimum score a document must reach to be returned;
            `None` disables filtering.
        device: target the model and the tokenized tensors are moved to.
        max_length: maximum token length passed to the tokenizer.
    """

    def __init__(
        self,
        model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        threshold: float = 0.6,
        device: str = None,
        max_length: int = 512,
    ):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold

    def __call__(self, query: str, docs: List):
        """Score every document against `query` and return them best-first.

        Args:
            query: the query text.
            docs: document texts to rank.

        Returns:
            A `(docs, scores)` pair of equal-length lists ordered by
            descending score. When a threshold is set, entries scoring below
            it are dropped. Empty `docs` yields `([], [])`.
        """
        if len(docs) == 0:
            return [], []

        # The tokenizer takes the pair as two parallel lists: the (repeated)
        # query and the documents.
        stripped_query = query.strip()
        queries = [stripped_query] * len(docs)
        documents = [doc.strip() for doc in docs]

        tokenized = self.tokenizer(
            queries,
            documents,
            padding=True,
            truncation='longest_first',
            return_tensors="pt",
            max_length=self.max_length,
        )
        for name in tokenized:
            tokenized[name] = tokenized[name].to(self.device)

        raw = self.model(**tokenized).detach().cpu().numpy()
        # Single-label models expose the score in column 0; otherwise column 1
        # is used (presumably the "relevant" class - verify against the model).
        column = 0 if self.config.num_labels == 1 else 1
        scores = list(raw[:, column])

        # Sort indices rather than pairs so ties keep their input order and
        # documents are never compared with each other.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        if self._threshold is not None:
            order = [i for i in order if scores[i] >= self._threshold]

        return [docs[i] for i in order], [scores[i] for i in order]

    @property
    def _model(self):
        """The underlying model held by the `Model` wrapper."""
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Serialize the underlying model and return the resolved output path.

        Args:
            format: `'pytorch'` (saved with `torch.save`) or `'onnx'`
                (exported with `transformers.onnx.export`, opset 13).
            path: output file; `'default'` writes to
                `<this file's directory>/saved/<format>/<model name>` with a
                `.pt` or `.onnx` suffix, `/` in the model name replaced by `-`.

        Returns:
            The resolved `Path` of the written file.

        Raises:
            AttributeError: if `format` is not `'pytorch'` or `'onnx'`.
        """
        # Validate before touching the filesystem so a bad format neither
        # leaves a stray directory behind nor returns a path never written.
        if format not in ('pytorch', 'onnx'):
            raise AttributeError(f'Invalid format {format}.')

        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # source; confirm an explicit `path` is meant to be used verbatim.
        if path == 'default':
            suffix = '.pt' if format == 'pytorch' else '.onnx'
            out_dir = os.path.join(str(Path(__file__).parent), 'saved', format)
            os.makedirs(out_dir, exist_ok=True)
            path = os.path.join(out_dir, self._model_name.replace('/', '-')) + suffix

        if format == 'pytorch':
            torch.save(self._model, path)
        else:
            # Imported lazily so the ONNX tooling is only needed for export.
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export

            # NOTE(review): the 'default' feature is requested here - confirm
            # the exported output naming is what consumers expect.
            _, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            export(
                self.tokenizer,
                self._model,
                config=onnx_config_cls(self._model.config),
                opset=13,
                output=Path(path),
            )
        return Path(path).resolve()