rerank

refactor to support onnx

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Branch: main
ChengZi committed 2 years ago
Parent commit: 87ea1727e1
1 changed file: rerank.py (104 lines changed)

rerank.py

@@ -1,27 +1,72 @@
+import os
+from pathlib import Path
 from typing import List
+from functools import partial
 
-from torch import nn
-from sentence_transformers import CrossEncoder
+import torch
+try:
+    from towhee import accelerate
+except:
+    def accelerate(func):
+        return func
 
 from towhee.operator import NNOperator
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+
+
+@accelerate
+class Model:
+    def __init__(self, model_name, config, device):
+        self.device = device
+        self.config = config
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
+        self.model.to(self.device)
+        self.model.eval()
+        if self.config.num_labels == 1:
+            self.activation_fct = torch.sigmoid
+        else:
+            self.activation_fct = partial(torch.softmax, dim=1)
+
+    def __call__(self, **features):
+        with torch.no_grad():
+            logits = self.model(**features, return_dict=True).logits
+            scores = self.activation_fct(logits)
+        return scores
 
 
 class ReRank(NNOperator):
-    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None):
+    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512):
         super().__init__()
         self._model_name = model_name
-        self._model = CrossEncoder(self._model_name, device=device)
-        if self._model.config.num_labels == 1:
-            self._model.default_activation_function = nn.Sigmoid()
+        self.config = AutoConfig.from_pretrained(model_name)
+        self.device = device
+        self.model = Model(model_name, self.config, device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.max_length = max_length
         self._threshold = threshold
 
     def __call__(self, query: str, docs: List):
         if len(docs) == 0:
             return [], []
-        if self._model.config.num_labels > 1:
-            scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1]
+        batch = [(query, doc) for doc in docs]
+        texts = [[] for _ in range(len(batch[0]))]
+        for example in batch:
+            for idx, text in enumerate(example):
+                texts[idx].append(text.strip())
+        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
+        for name in tokenized:
+            tokenized[name] = tokenized[name].to(self.device)
+        scores = self.model(**tokenized).detach().cpu().numpy()
+        if self.config.num_labels == 1:
+            scores = [score[0] for score in scores]
         else:
-            scores = self._model.predict([(query, doc) for doc in docs])
+            scores = scores[:, 1]
+            scores = [score for score in scores]
         re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
 
         if self._threshold is None:
             re_docs = [docs[i] for i in re_ids]
@@ -30,3 +75,42 @@ class ReRank(NNOperator):
             re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
             re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
         return re_docs, re_scores
+
+    @property
+    def _model(self):
+        return self.model.model
+
+    def save_model(self, format: str = 'pytorch', path: str = 'default'):
+        if path == 'default':
+            path = str(Path(__file__).parent)
+        path = os.path.join(path, 'saved', format)
+        os.makedirs(path, exist_ok=True)
+        name = self._model_name.replace('/', '-')
+        path = os.path.join(path, name)
+        if format in ['pytorch',]:
+            path = path + '.pt'
+        elif format == 'onnx':
+            path = path + '.onnx'
+        else:
+            raise AttributeError(f'Invalid format {format}.')
+
+        if format == 'pytorch':
+            torch.save(self._model, path)
+        elif format == 'onnx':
+            from transformers.onnx.features import FeaturesManager
+            from transformers.onnx import export
+            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
+                self._model, feature='default')
+            onnx_config = model_onnx_config(self._model.config)
+            onnx_inputs, onnx_outputs = export(
+                self.tokenizer,
+                self._model,
+                config=onnx_config,
+                opset=13,
+                output=Path(path)
+            )
+        return Path(path).resolve()
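
For context on how the refactored operator fits together, here is a minimal usage sketch. It is illustrative only: it assumes the file above is importable as `rerank`, the query and documents are made up, and the constructor arguments simply mirror the defaults shown in the diff.

# Illustrative sketch, not part of this commit.
from rerank import ReRank  # assumes the rerank.py above is on the import path

op = ReRank(model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', threshold=0.6, device='cpu')

query = 'What is Towhee?'
docs = [
    'Towhee is an open-source machine learning pipeline framework.',
    'The weather is nice today.',
]

# Scores every (query, doc) pair with the cross-encoder, sorts by score,
# and drops results below the threshold.
re_docs, re_scores = op(query, docs)
print(re_docs, re_scores)

# New in this commit: export the underlying Hugging Face model to ONNX.
onnx_path = op.save_model(format='onnx')
print(onnx_path)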
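The ONNX artifact produced by save_model(format='onnx') could then be checked with onnxruntime along these lines. onnxruntime is not used by this commit, and the sigmoid at the end only mirrors Model.activation_fct for single-label checkpoints, so treat this as an assumption-laden sanity check rather than repo code.

# Illustrative ONNX sanity check; assumes `onnx_path` comes from save_model above.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-6-v2')
session = ort.InferenceSession(str(onnx_path))

encoded = tokenizer(
    ['What is Towhee?'],
    ['Towhee is an open-source machine learning pipeline framework.'],
    padding=True, truncation='longest_first', max_length=512, return_tensors='np')

# Feed only the inputs the exported graph actually declares.
input_names = [i.name for i in session.get_inputs()]
feed = {name: encoded[name] for name in input_names if name in encoded}
logits = session.run(None, feed)[0]

# For a single-label cross-encoder the relevance score is sigmoid(logit),
# matching Model.activation_fct in the diff above.
score = 1.0 / (1.0 + np.exp(-logits[0][0]))
print(score)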
