logo
rerank
repo-copy-icon

copied

Readme
Files and versions

157 lines
5.3 KiB

import os
from pathlib import Path
from typing import List
from functools import partial
import torch
try:
    from towhee import accelerate
except ImportError:
    # When `towhee.accelerate` is unavailable, fall back to a no-op decorator so
    # decorated objects are used unchanged. Only ImportError is caught: any other
    # failure while importing towhee should surface instead of being hidden.
    def accelerate(func):
        """No-op stand-in for ``towhee.accelerate``; returns ``func`` unchanged."""
        return func
from towhee.operator import NNOperator
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
@accelerate
class Model:
    """Load a sequence-classification checkpoint and run inference with it.

    Every tensor passed to ``__call__`` (positionally or by keyword) is moved to
    ``device`` before the forward pass, and only the logits are returned.

    Args:
        model_name: Checkpoint name/path given to
            ``AutoModelForSequenceClassification.from_pretrained``.
        config: Model config passed through to ``from_pretrained``.
        device: Value passed to ``.to()`` for the model and every input tensor.
    """

    def __init__(self, model_name, config, device):
        self.device = device
        self.config = config
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.model.to(self.device)
        # Inference only: put layers such as dropout into evaluation mode.
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Run a forward pass and return the ``logits`` tensor.

        Args:
            *args: Tensors forwarded positionally to the model.
            **kwargs: Tensors forwarded by name (e.g. tokenizer output).

        Returns:
            ``outs.logits`` from the model's ``return_dict=True`` output.
        """
        new_args = [x.to(self.device) for x in args]
        new_kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
        # Pure inference: don't record an autograd graph (saves memory and time).
        with torch.no_grad():
            outs = self.model(*new_args, **new_kwargs, return_dict=True)
        return outs.logits
class ReRank(NNOperator):
    """Re-rank documents against a query with a cross-encoder classifier.

    Args:
        model_name: Checkpoint of a sequence-classification model, loaded with
            ``AutoConfig`` / ``AutoTokenizer`` and the ``Model`` wrapper.
        threshold: Minimum score a document must reach to be returned. ``None``
            disables filtering and every document is returned.
        device: Passed as-is to ``.to()`` for the model and the tokenized inputs.
        max_length: Maximum token length of each (query, doc) pair; longer pairs
            are truncated with the ``'longest_first'`` strategy.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold
        # A single-logit head is squashed with sigmoid. Multi-logit heads are
        # softmax-normalised over the label axis, and post_proc reads column 1.
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, query: str, docs: List):
        """Score every document against ``query`` and return them best-first.

        Args:
            query: The search query.
            docs: Candidate documents (strings; each is ``.strip()``-ed).

        Returns:
            A ``(docs, scores)`` tuple sorted by descending score. When a
            threshold is set, entries scoring below it are dropped. Empty
            ``docs`` yields ``([], [])``.
        """
        if len(docs) == 0:
            return [], []
        # Cross-encoder input: the (stripped) query paired with each (stripped) doc.
        queries = [query.strip()] * len(docs)
        passages = [doc.strip() for doc in docs]
        tokenized = self.tokenizer(queries, passages, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
        inputs = {name: tensor.to(self.device) for name, tensor in tokenized.items()}
        logits = self.model(**inputs)
        scores = self.post_proc(logits)
        # sorted() is stable, so equal scores keep their original document order.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if self._threshold is not None:
            order = [i for i in order if scores[i] >= self._threshold]
        return [docs[i] for i in order], [scores[i] for i in order]

    def post_proc(self, logits):
        """Convert raw logits into one Python ``float`` score per (query, doc) pair."""
        probs = self.activation_fct(logits).detach().cpu().numpy()
        if self.config.num_labels == 1:
            # One logit per row: take its sigmoid value.
            return [float(row[0]) for row in probs]
        # NOTE(review): assumes label 1 is the "relevant" class for multi-label heads.
        return [float(p) for p in probs[:, 1]]

    @property
    def _model(self):
        """The underlying Hugging Face model held by the ``Model`` wrapper."""
        return self.model.model

    @property
    def supported_formats(self):
        """Export formats advertised by this operator."""
        return ['onnx']

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Export the underlying Hugging Face model to disk.

        Args:
            format: ``'pytorch'`` (whole module via ``torch.save``) or ``'onnx'``.
            path: Base directory. ``'default'`` means the directory of this file.
                The file is written to
                ``<path>/saved/<format>/<model_name with '/' -> '-'>.<ext>``.

        Returns:
            The resolved ``Path`` of the written file.

        Raises:
            AttributeError: If ``format`` is not supported (kept for backward
                compatibility with existing callers).
        """
        # Validate before touching the filesystem so a bad format leaves no
        # empty directory behind.
        if format == 'pytorch':
            suffix = '.pt'
        elif format == 'onnx':
            suffix = '.onnx'
        else:
            raise AttributeError(f'Invalid format {format}.')

        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self._model_name.replace('/', '-')
        path = os.path.join(path, name + suffix)

        if format == 'pytorch':
            torch.save(self._model, path)
        else:
            from transformers.onnx import export
            from transformers.onnx.features import FeaturesManager
            # The wrapped model is an AutoModelForSequenceClassification. Export it
            # with the 'sequence-classification' feature so the ONNX output is
            # declared as 'logits' with only the batch axis dynamic. With
            # feature='default', the base-model 'last_hidden_state'
            # (batch, sequence) signature would be declared instead.
            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='sequence-classification')
            onnx_config = model_onnx_config(self._model.config)
            export(
                self.tokenizer,
                self._model,
                config=onnx_config,
                opset=13,
                output=Path(path),
            )
        return Path(path).resolve()
if __name__ == '__main__':
    # Smoke run: for each known re-ranking checkpoint, score a toy query,
    # print the (docs, scores) result, then export the model to ONNX.
    checkpoints = (
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    )
    for checkpoint in checkpoints:
        print(f'\n{checkpoint}')
        reranker = ReRank(checkpoint, threshold=0)
        print(reranker('abc', ['123', 'ABC', 'ABCabc']))
        reranker.save_model('onnx')