|
|
@ -21,16 +21,17 @@ class Model: |
|
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) |
|
|
|
self.model.to(self.device) |
|
|
|
self.model.eval() |
|
|
|
if self.config.num_labels == 1: |
|
|
|
self.activation_fct = torch.sigmoid |
|
|
|
else: |
|
|
|
self.activation_fct = partial(torch.softmax, dim=1) |
|
|
|
|
|
|
|
def __call__(self, **features):
    """Run the underlying sequence-classification model on pre-tokenized
    inputs and return activated scores.

    Args:
        **features: tokenizer output tensors (e.g. input_ids, attention_mask)
            forwarded directly to the wrapped model.

    Returns:
        Tensor of scores: sigmoid-activated for single-label heads,
        softmax over dim=1 otherwise (``self.activation_fct`` is chosen
        at construction time from ``config.num_labels``).
    """
    # Inference only — no autograd graph is needed for scoring.
    with torch.no_grad():
        raw_logits = self.model(**features, return_dict=True).logits
    return self.activation_fct(raw_logits)
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
    """Move every tensor argument onto the model's device, run the model,
    and return its raw logits.

    Args:
        *args: positional tensor inputs; each must support ``.to(device)``.
        **kwargs: keyword tensor inputs (e.g. tokenizer output); each must
            support ``.to(device)``.

    Returns:
        The ``logits`` tensor from the model's output (``return_dict=True``).
    """
    # Comprehensions replace the original append loops (same order, same calls).
    device_args = [x.to(self.device) for x in args]
    device_kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
    outs = self.model(*device_args, **device_kwargs, return_dict=True)
    return outs.logits
|
|
|
|
|
|
|
|
|
|
|
class ReRank(NNOperator): |
|
|
@ -43,6 +44,10 @@ class ReRank(NNOperator): |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
self.max_length = max_length |
|
|
|
self._threshold = threshold |
|
|
|
if self.config.num_labels == 1: |
|
|
|
self.activation_fct = torch.sigmoid |
|
|
|
else: |
|
|
|
self.activation_fct = partial(torch.softmax, dim=1) |
|
|
|
|
|
|
|
def __call__(self, query: str, docs: List): |
|
|
|
if len(docs) == 0: |
|
|
@ -60,12 +65,8 @@ class ReRank(NNOperator): |
|
|
|
for name in tokenized: |
|
|
|
tokenized[name] = tokenized[name].to(self.device) |
|
|
|
|
|
|
|
scores = self.model(**tokenized).detach().cpu().numpy() |
|
|
|
if self.config.num_labels == 1: |
|
|
|
scores = [score[0] for score in scores] |
|
|
|
else: |
|
|
|
scores = scores[:, 1] |
|
|
|
scores = [score for score in scores] |
|
|
|
logits = self.model(**tokenized) |
|
|
|
scores = self.post_proc(logits) |
|
|
|
|
|
|
|
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) |
|
|
|
if self._threshold is None: |
|
|
@ -77,6 +78,16 @@ class ReRank(NNOperator): |
|
|
|
return re_docs, re_scores |
|
|
|
|
|
|
|
|
|
|
|
def post_proc(self, logits):
    """Turn raw model logits into a plain Python list of relevance scores.

    Applies ``self.activation_fct`` (sigmoid or softmax, chosen from
    ``config.num_labels`` at construction), detaches, moves to CPU and
    converts to numpy, then flattens to one score per input row.

    Args:
        logits: raw output tensor from the model.

    Returns:
        list of per-document scores (column 0 for a single-label head,
        column 1 — the "relevant" class probability — otherwise).
    """
    activated = self.activation_fct(logits).detach().cpu().numpy()
    if self.config.num_labels == 1:
        # Regression-style head: one value per row.
        return [row[0] for row in activated]
    # Classification head: take the positive-class column.
    return list(activated[:, 1])
|
|
|
|
|
|
|
|
|
|
|
@property
def _model(self):
    # Expose the inner torch module held by the `Model` wrapper
    # (self.model is the wrapper; self.model.model is the HF model —
    # presumably consumed by the export/save path; confirm with caller).
    return self.model.model
|
|
@ -114,3 +125,29 @@ class ReRank(NNOperator): |
|
|
|
output=Path(path) |
|
|
|
) |
|
|
|
return Path(path).resolve() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke-test ReRank across a set of public MS MARCO cross-encoder
    # checkpoints: score three toy docs against one query, then export ONNX.
    checkpoints = (
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    )
    for name in checkpoints:
        print('\n' + name)
        op = ReRank(name, threshold=0)
        res = op('abc', ['123', 'ABC', 'ABCabc'])
        print(res)
        # NOTE(review): original indentation was lost in the patch; assumed
        # the ONNX export runs once per checkpoint — confirm placement.
        op.save_model('onnx')