logo
rerank
repo-copy-icon

copied

Browse Source

Refactor to support TensorRT (TRT) acceleration

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
main
ChengZi 2 years ago
parent
commit
5feecc6f6d
  1. 67
      rerank.py

67
rerank.py

@ -21,16 +21,17 @@ class Model:
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
self.model.to(self.device)
self.model.eval()
if self.config.num_labels == 1:
self.activation_fct = torch.sigmoid
else:
self.activation_fct = partial(torch.softmax, dim=1)
def __call__(self, **features):
    """Score tokenized inputs with the wrapped classification model.

    The keyword tensors in ``features`` are forwarded to ``self.model`` with
    gradient tracking disabled, and the resulting logits are passed through
    ``self.activation_fct`` before being returned.
    """
    with torch.no_grad():
        outputs = self.model(**features, return_dict=True)
        return self.activation_fct(outputs.logits)
def __call__(self, *args, **kwargs):
    """Run the wrapped model on the given inputs and return its raw logits.

    Every positional and keyword argument is moved to ``self.device`` with
    ``.to`` before the forward pass, so each one must provide a ``to``
    method. No activation is applied here.

    Returns:
        The ``logits`` attribute of the model output.
    """
    device = self.device
    moved_args = [arg.to(device) for arg in args]
    moved_kwargs = {name: value.to(device) for name, value in kwargs.items()}
    # Inference only: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        outputs = self.model(*moved_args, **moved_kwargs, return_dict=True)
    return outputs.logits
class ReRank(NNOperator):
@ -43,6 +44,10 @@ class ReRank(NNOperator):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.max_length = max_length
self._threshold = threshold
if self.config.num_labels == 1:
self.activation_fct = torch.sigmoid
else:
self.activation_fct = partial(torch.softmax, dim=1)
def __call__(self, query: str, docs: List):
if len(docs) == 0:
@ -60,12 +65,8 @@ class ReRank(NNOperator):
for name in tokenized:
tokenized[name] = tokenized[name].to(self.device)
scores = self.model(**tokenized).detach().cpu().numpy()
if self.config.num_labels == 1:
scores = [score[0] for score in scores]
else:
scores = scores[:, 1]
scores = [score for score in scores]
logits = self.model(**tokenized)
scores = self.post_proc(logits)
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
if self._threshold is None:
@ -77,6 +78,16 @@ class ReRank(NNOperator):
return re_docs, re_scores
def post_proc(self, logits):
    """Convert raw model logits into a flat list of scores.

    ``self.activation_fct`` is applied to ``logits`` and the result is
    detached, moved to the CPU and converted to a NumPy array. When the model
    config declares a single label, column 0 of each row is taken; otherwise
    column 1 is taken.

    Args:
        logits: model output tensor, indexed as (row, label) below.

    Returns:
        A list holding one score per row of ``logits``.
    """
    activated = self.activation_fct(logits).detach().cpu().numpy()
    # Single-label heads have only column 0; multi-label heads use column 1.
    column = 0 if self.config.num_labels == 1 else 1
    return list(activated[:, column])
@property
def _model(self):
    """Return the ``model`` attribute of the object stored in ``self.model``."""
    wrapper = self.model
    return wrapper.model
@ -114,3 +125,29 @@ class ReRank(NNOperator):
output=Path(path)
)
return Path(path).resolve()
if __name__ == '__main__':
    # Manual smoke run: for each checkpoint below, build the operator, rerank
    # a tiny fixed example, print the result, then export with 'onnx'.
    model_name_list = [
        # cross-encoder checkpoints
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        # nboost checkpoints
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        # other publishers
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    ]

    for model_name in model_name_list:
        print('\n' + model_name)
        reranker = ReRank(model_name, threshold=0)
        print(reranker('abc', ['123', 'ABC', 'ABCabc']))
        reranker.save_model('onnx')
Loading…
Cancel
Save