import os
from pathlib import Path
from typing import List
from functools import partial

import torch

try:
    from towhee import accelerate
except ImportError:
    # Fall back to a no-op decorator when towhee's acceleration module is unavailable.
    def accelerate(func):
        return func

from towhee.operator import NNOperator
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig


@accelerate
class Model:
    """Thin wrapper around a sequence-classification cross-encoder."""
    def __init__(self, model_name, config, device):
        self.device = device
        self.config = config
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.model.to(self.device)
        self.model.eval()
        # A single label means the model emits one relevance logit (sigmoid);
        # otherwise treat the output as class logits (softmax over classes).
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, **features):
        with torch.no_grad():
            logits = self.model(**features, return_dict=True).logits
            scores = self.activation_fct(logits)
        return scores


class ReRank(NNOperator):
    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        # Documents scoring below the threshold are dropped; pass threshold=None to keep all.
        self._threshold = threshold

    def __call__(self, query: str, docs: List):
        if len(docs) == 0:
            return [], []

        # Pair the query with every candidate document, then split the pairs into
        # two parallel lists so the tokenizer receives text / text_pair batches.
        batch = [(query, doc) for doc in docs]
        texts = [[] for _ in range(len(batch[0]))]

        for example in batch:
            for idx, text in enumerate(example):
                texts[idx].append(text.strip())

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self.device)

        scores = self.model(**tokenized).detach().cpu().numpy()
        if self.config.num_labels == 1:
            # Single-label model: each row holds one relevance score.
            scores = [score[0] for score in scores]
        else:
            # Multi-label model: use the probability of the positive class.
            scores = scores[:, 1]
            scores = [score for score in scores]

        # Sort document indices by score, highest first, then apply the threshold.
        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        if self._threshold is None:
            re_docs = [docs[i] for i in re_ids]
            re_scores = [scores[i] for i in re_ids]
        else:
            re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
            re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
        return re_docs, re_scores

    @property
    def _model(self):
        return self.model.model

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        # Resolve the output path, then export either the raw PyTorch module or an ONNX graph.
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self._model_name.replace('/', '-')
        path = os.path.join(path, name)
        if format in ['pytorch']:
            path = path + '.pt'
        elif format == 'onnx':
            path = path + '.onnx'
        else:
            raise AttributeError(f'Invalid format {format}.')

        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'onnx':
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export
            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
            onnx_inputs, onnx_outputs = export(
                self.tokenizer,
                self._model,
                config=onnx_config,
                opset=13,
                output=Path(path)
            )
        return Path(path).resolve()
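

# A minimal usage sketch, assuming the pretrained cross-encoder weights can be
# downloaded from the Hugging Face hub. The operator is called directly with a
# query and a list of candidate documents and returns (re_docs, re_scores)
# ordered by relevance score:
#
#     op = ReRank(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', threshold=None)
#     re_docs, re_scores = op('what is towhee',
#                             ['Towhee is a framework for building data processing pipelines.',
#                              'The weather is nice today.'])
#
# With threshold=None every document is returned; with the default 0.6 only
# documents whose score reaches the threshold are kept.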