You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
156 lines
5.4 KiB
156 lines
5.4 KiB
import os
|
|
from pathlib import Path
|
|
from typing import List
|
|
from functools import partial
|
|
|
|
import torch
|
|
try:
    from towhee import accelerate
except ImportError:
    # Only a missing/unimportable towhee should trigger the fallback; a bare
    # ``except:`` would also swallow KeyboardInterrupt, SystemExit and genuine
    # errors raised while importing towhee.
    def accelerate(func):
        """No-op stand-in for ``towhee.accelerate``.

        Returns the decorated object unchanged so ``@accelerate`` can still be
        applied when towhee's accelerate hook is unavailable.
        """
        return func
|
|
from towhee.operator import NNOperator
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
|
|
|
|
|
@accelerate
class Model:
    """Thin wrapper around a ``transformers`` sequence-classification model.

    The weights are loaded, moved to ``device`` and put in evaluation mode.
    Calling the wrapper moves every argument to the same device, runs one
    forward pass and returns the raw logits.
    """

    def __init__(self, model_name, checkpoint_path, config, device):
        """Load the model.

        Args:
            model_name: model id/path used when ``checkpoint_path`` is falsy.
            checkpoint_path: optional checkpoint loaded instead of ``model_name``.
            config: config object forwarded to ``from_pretrained``.
            device: target passed to ``.to()`` for the model and every input.
        """
        self.device = device
        self.config = config
        # A checkpoint, when given, takes precedence over the model id.
        source = checkpoint_path if checkpoint_path else model_name
        self.model = AutoModelForSequenceClassification.from_pretrained(source, config=self.config)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Run a forward pass and return ``outputs.logits``.

        Every positional and keyword argument must support ``.to(device)``
        (e.g. the tensors a tokenizer returns with ``return_tensors="pt"``).
        """
        args = [x.to(self.device) for x in args]
        kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
        # Pure inference: do not build an autograd graph, so intermediate
        # activations are not retained (saves memory and time).
        with torch.no_grad():
            outs = self.model(*args, **kwargs, return_dict=True)
        return outs.logits
|
|
|
|
|
|
class ReRank(NNOperator):
    """Re-rank candidate documents against a query with a cross-encoder.

    Each ``(query, doc)`` pair is scored by a sequence-classification model.
    Documents are returned in descending score order and are optionally
    filtered by ``threshold``.

    Args:
        model_name: model id; also used to load the config and tokenizer.
        threshold: minimum score a document must reach to be returned;
            ``None`` disables filtering.
        device: passed unchanged to the ``Model`` wrapper.
        max_length: maximum tokenized length of a pair (longer pairs are
            truncated with ``truncation='longest_first'``).
        checkpoint_path: optional weights loaded instead of ``model_name``.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, checkpoint_path, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold
        # One output logit -> sigmoid score; several labels -> softmax over
        # the label axis (post_proc then picks one column).
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, query: str, docs: List):
        """Score ``docs`` against ``query`` and return them best-first.

        Returns:
            ``(docs, scores)``: two lists sorted by descending score. When a
            threshold is set, entries scoring below it are dropped. Empty
            ``docs`` yields ``([], [])``.
        """
        if len(docs) == 0:
            return [], []

        # The tokenizer takes the two sides of the pairs as parallel lists.
        queries = [query.strip()] * len(docs)
        passages = [doc.strip() for doc in docs]
        tokenized = self.tokenizer(queries, passages, padding=True, truncation='longest_first',
                                   return_tensors="pt", max_length=self.max_length)

        logits = self.model(**tokenized)
        scores = self.post_proc(logits)

        # sorted() is stable, so equally-scored docs keep their input order.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if self._threshold is not None:
            order = [i for i in order if scores[i] >= self._threshold]
        return [docs[i] for i in order], [scores[i] for i in order]

    def post_proc(self, logits):
        """Convert raw logits into one Python float score per pair.

        Single-logit models use the sigmoid of that logit. Otherwise the
        softmax probability at label index 1 is used.
        """
        probs = self.activation_fct(logits).detach().cpu().numpy()
        if self.config.num_labels == 1:
            return [float(row[0]) for row in probs]
        # NOTE(review): assumes label index 1 means "relevant" -- confirm per model.
        return [float(p) for p in probs[:, 1]]

    @property
    def _model(self):
        # The underlying transformers module held by the Model wrapper.
        return self.model.model

    @property
    def supported_formats(self):
        return ['onnx']

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Export the underlying model and return the resolved file path.

        Args:
            format: ``'pytorch'`` (``torch.save`` of the whole module, ``.pt``)
                or ``'onnx'`` (``.onnx``, opset 13).
            path: root directory; ``'default'`` means this file's directory.
                Output goes to ``<path>/saved/<format>/<model name>.<ext>``.

        Raises:
            AttributeError: if ``format`` is not supported.
        """
        suffixes = {'pytorch': '.pt', 'onnx': '.onnx'}
        if format not in suffixes:
            # Validate before touching the filesystem so an invalid format
            # does not leave an empty ``saved/<format>`` directory behind.
            raise AttributeError(f'Invalid format {format}.')

        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self._model_name.replace('/', '-')
        path = os.path.join(path, name) + suffixes[format]

        if format == 'pytorch':
            torch.save(self._model, path)
        else:
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export

            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
            # Module.to() moves the live model in place. Remember its device so
            # the operator keeps working there after the CPU-based export.
            original_device = next(self._model.parameters()).device
            try:
                export(
                    self.tokenizer,
                    self._model.to('cpu'),
                    config=onnx_config,
                    opset=13,
                    output=Path(path)
                )
            finally:
                self._model.to(original_device)
        return Path(path).resolve()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: for each reranker, rank three toy documents against
    # a toy query, print the result, then export the model to ONNX.
    checkpoints = (
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    )

    for checkpoint in checkpoints:
        print(f'\n{checkpoint}')
        reranker = ReRank(checkpoint, threshold=0)
        ranking = reranker('abc', ['123', 'ABC', 'ABCabc'])
        print(ranking)
        reranker.save_model('onnx')
|