logo
rerank
repo-copy-icon

copied

You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

156 lines
5.3 KiB

import os
from pathlib import Path
from typing import List
from functools import partial
import torch
try:
    from towhee import accelerate
except ImportError:
    # towhee does not expose `accelerate` here: fall back to a no-op
    # decorator so that `@accelerate` below still works.
    def accelerate(func):
        """Return *func* unchanged (stand-in for ``towhee.accelerate``)."""
        return func
from towhee.operator import NNOperator
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
@accelerate
class Model:
    """Thin inference wrapper around a Hugging Face sequence-classification model.

    The model is loaded with ``AutoModelForSequenceClassification``, moved to
    ``device`` and switched to eval mode. Calling the instance runs a forward
    pass and returns only the logits.

    Args:
        model_name: Model id or path handed to ``from_pretrained``.
        config: Model configuration passed through to ``from_pretrained``.
        device: Target passed to ``.to()`` for the model and for every input.
    """

    def __init__(self, model_name, config, device):
        self.device = device
        self.config = config
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Move all inputs to ``self.device``, run the model, return ``outs.logits``."""
        new_args = [x.to(self.device) for x in args]
        new_kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
        # Inference only: skip autograd bookkeeping so no graph is kept alive.
        with torch.no_grad():
            outs = self.model(*new_args, **new_kwargs, return_dict=True)
        return outs.logits
class ReRank(NNOperator):
    """Re-rank documents against a query with a cross-encoder model.

    Every ``(query, doc)`` pair is tokenized together and scored by the
    wrapped sequence-classification model. Documents are returned in
    descending score order; those scoring below ``threshold`` are dropped.

    Args:
        model_name: Model id or path handed to the ``from_pretrained`` loaders.
        threshold: Minimum score required for a document to be returned.
            ``None`` disables filtering.
        device: Target passed to ``.to()`` for the model and tokenized inputs.
        max_length: ``max_length`` forwarded to the tokenizer.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold
        # A single-label head is squashed with a sigmoid; a multi-label head
        # is normalised with a softmax across the label dimension.
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, query: str, docs: List):
        """Score ``docs`` against ``query`` and return them best-first.

        Args:
            query: Query text.
            docs: Candidate document texts.

        Returns:
            A ``(docs, scores)`` pair of equally long lists sorted by
            descending score. Both are empty when ``docs`` is empty.
        """
        if len(docs) == 0:
            return [], []

        # The tokenizer takes two parallel lists: the (repeated) query and the docs.
        stripped_query = query.strip()
        queries = [stripped_query] * len(docs)
        passages = [doc.strip() for doc in docs]
        tokenized = self.tokenizer(queries, passages, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
        for name in tokenized:
            tokenized[name] = tokenized[name].to(self.device)

        logits = self.model(**tokenized)
        scores = self.post_proc(logits)

        ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        if self._threshold is not None:
            ranked = [i for i in ranked if scores[i] >= self._threshold]
        re_docs = [docs[i] for i in ranked]
        re_scores = [scores[i] for i in ranked]
        return re_docs, re_scores

    def post_proc(self, logits):
        """Turn raw logits into a list of Python floats, one per input pair.

        With a single label the sigmoid output is used directly; otherwise
        column 1 of the softmax output is taken as the score.
        """
        scores = self.activation_fct(logits).detach().cpu().numpy()
        if self.config.num_labels == 1:
            scores = [float(score[0]) for score in scores]
        else:
            scores = scores[:, 1]
            scores = [float(score) for score in scores]
        return scores

    @property
    def _model(self):
        """The underlying model object held by the ``Model`` wrapper."""
        return self.model.model

    @property
    def supported_formats(self):
        """Formats reported as supported by this operator."""
        return ['onnx']

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Serialize the underlying model to ``<path>/saved/<format>/<name><ext>``.

        Args:
            format: ``'pytorch'`` (``torch.save`` to ``.pt``) or ``'onnx'``
                (``transformers.onnx.export`` to ``.onnx``).
            path: Base directory; ``'default'`` means the directory of this file.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: If ``format`` is not one of the supported values.
        """
        suffixes = {'pytorch': '.pt', 'onnx': '.onnx'}
        # Reject unknown formats before touching the filesystem so that an
        # invalid call does not leave an empty `saved/<format>` directory behind.
        if format not in suffixes:
            raise AttributeError(f'Invalid format {format}.')

        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self._model_name.replace('/', '-')
        path = os.path.join(path, name) + suffixes[format]

        if format == 'pytorch':
            torch.save(self._model, path)
        else:
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export
            # NOTE(review): feature='default' is requested although the model is
            # loaded via AutoModelForSequenceClassification — confirm intended.
            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
            export(
                self.tokenizer,
                self._model,
                config=onnx_config,
                opset=13,
                output=Path(path)
            )
        return Path(path).resolve()
if __name__ == '__main__':
    # Smoke run: for each checkpoint, rank three toy documents with the
    # threshold set to 0, print the result and export the model to ONNX.
    checkpoints = (
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    )
    for checkpoint in checkpoints:
        print(f'\n{checkpoint}')
        reranker = ReRank(checkpoint, threshold=0)
        print(reranker('abc', ['123', 'ABC', 'ABCabc']))
        reranker.save_model('onnx')