@@ -1,27 +1,72 @@
+import os
+from pathlib import Path
 from typing import List
+from functools import partial
 
-from torch import nn
-from sentence_transformers import CrossEncoder
+import torch
+
+try:
+    from towhee import accelerate
+except:
+    def accelerate(func):
+        return func
 from towhee.operator import NNOperator
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 
 
+@accelerate
+class Model:
+    def __init__(self, model_name, config, device):
+        self.device = device
+        self.config = config
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
+        self.model.to(self.device)
+        self.model.eval()
+        if self.config.num_labels == 1:
+            self.activation_fct = torch.sigmoid
+        else:
+            self.activation_fct = partial(torch.softmax, dim=1)
+
+    def __call__(self, **features):
+        with torch.no_grad():
+            logits = self.model(**features, return_dict=True).logits
+            scores = self.activation_fct(logits)
+        return scores
+
+
 class ReRank(NNOperator):
-    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None):
+    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
         super().__init__()
         self._model_name = model_name
-        self._model = CrossEncoder(self._model_name, device=device)
-        if self._model.config.num_labels == 1:
-            self._model.default_activation_function = nn.Sigmoid()
+        self.config = AutoConfig.from_pretrained(model_name)
+        self.device = device
+        self.model = Model(model_name, self.config, device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.max_length = max_length
         self._threshold = threshold
 
     def __call__(self, query: str, docs: List):
         if len(docs) == 0:
             return [], []
-        if self._model.config.num_labels > 1:
-            scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1]
+        batch = [(query, doc) for doc in docs]
+        texts = [[] for _ in range(len(batch[0]))]
+
+        for example in batch:
+            for idx, text in enumerate(example):
+                texts[idx].append(text.strip())
+
+        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
+        for name in tokenized:
+            tokenized[name] = tokenized[name].to(self.device)
+
+        scores = self.model(**tokenized).detach().cpu().numpy()
+        if self.config.num_labels == 1:
+            scores = [score[0] for score in scores]
         else:
-            scores = self._model.predict([(query, doc) for doc in docs])
+            scores = scores[:, 1]
+            scores = [score for score in scores]
+
         re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
         if self._threshold is None:
             re_docs = [docs[i] for i in re_ids]
@@ -30,3 +75,42 @@ class ReRank(NNOperator):
             re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
             re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
         return re_docs, re_scores
+
+    @property
+    def _model(self):
+        return self.model.model
+
+    def save_model(self, format: str = 'pytorch', path: str = 'default'):
+        if path == 'default':
+            path = str(Path(__file__).parent)
+            path = os.path.join(path, 'saved', format)
+            os.makedirs(path, exist_ok=True)
+            name = self._model_name.replace('/', '-')
+            path = os.path.join(path, name)
+        if format in ['pytorch',]:
+            path = path + '.pt'
+        elif format == 'onnx':
+            path = path + '.onnx'
+        else:
+            raise AttributeError(f'Invalid format {format}.')
+
+        if format == 'pytorch':
+            torch.save(self._model, path)
+        elif format == 'onnx':
+            from transformers.onnx.features import FeaturesManager
+            from transformers.onnx import export
+            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
+                self._model, feature='default')
+            onnx_config = model_onnx_config(self._model.config)
+            onnx_inputs, onnx_outputs = export(
+                self.tokenizer,
+                self._model,
+                config=onnx_config,
+                opset=13,
+                output=Path(path)
+            )
+        return Path(path).resolve()
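
Usage sketch (not part of the patch): assuming the file above is importable and the default 'cross-encoder/ms-marco-MiniLM-L-6-v2' checkpoint can be fetched, the rewritten operator can be exercised roughly as follows; the query and documents are made-up examples.

    # Hypothetical example: instantiate the operator and rerank two documents.
    op = ReRank(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', threshold=None, device='cpu')
    docs = ['Towhee is a framework for building neural data processing pipelines.',
            'Bananas are rich in potassium.']
    re_docs, re_scores = op('what is towhee?', docs)
    # re_docs holds the input documents sorted by cross-encoder relevance (highest first);
    # with a float threshold instead of None, documents scoring below it are dropped.

Per the save_model method above, op.save_model('onnx') would additionally export the underlying transformer to a 'saved/onnx' directory next to the operator file.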