logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

70 lines
2.0 KiB

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""
Clone repo here: https://github.com/facebookresearch/SentEval.git
"""
from __future__ import absolute_import, division, unicode_literals
import sys
import logging
import numpy as np
from towhee import ops
from statistics import mean
import os
import warnings
from transformers import logging as t_logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()
model_name = sys.argv[-1]
op = ops.sentence_embedding.transformers(model_name=model_name, device='cuda:3').get_op()
# op = ops.text_embedding.sentence_transformers(model_name=model_name, device='cuda:3').get_op()
# Set PATHs
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'
# import SentEval
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval
# SentEval prepare and batcher
def prepare(params, samples):
return
def batcher(params, batch):
batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
embeddings = op(batch)
return np.vstack(embeddings)
# Set params for SentEval
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
'tenacity': 5, 'epoch_size': 4}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
if __name__ == "__main__":
se = senteval.engine.SE(params_senteval, batcher, prepare)
# transfer_tasks = ['STSBenchmark']
# transfer_tasks = ['STS16']
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
results = se.eval(transfer_tasks)
p = []
s = []
for t in transfer_tasks:
res = results[t]['all']
p.append(res['pearson']['mean'])
s.append(res['spearman']['mean'])
print('pearson:', mean(p))
print('spearman:', mean(s))