sbert
copied
4 changed files with 245 additions and 0 deletions
@ -0,0 +1,33 @@ |
|||
# Evaluation |
|||
|
|||
## Model performance in sentence similarity |
|||
|
|||
1. Download SentEval & test data |
|||
```bash |
|||
git clone https://github.com/facebookresearch/SentEval.git |
|||
cd SentEval/data/downstream |
|||
./get_transfer_data.bash |
|||
``` |
|||
|
|||
2. Run test script |
|||
```bash |
|||
python evaluate.py MODEL_NAME |
|||
``` |
|||
|
|||
## QPS Test |
|||
|
|||
Please note that `qps_test.py` uses: |
|||
- `localhost:8000`: to connect triton client |
|||
- `'Hello, world.'`: as test sentence
|||
|
|||
```bash |
|||
python qps_test.py --model sentence-t5-base --pipe --onnx --triton --num 100
|||
``` |
|||
|
|||
**Args:** |
|||
- `--model`: mandatory, string, model name |
|||
- `--pipe`: optional, on/off flag to enable qps test for pipe |
|||
- `--onnx`: optional, on/off flag to enable qps test for onnx |
|||
- `--triton`: optional, on/off flag to enable qps for triton (please make sure that triton client is ready) |
|||
- `--num`: optional, integer, defaults to 100, batch size in each loop (10 loops in total) |
|||
- `--device`: optional, int, defaults to -1, cuda index or use cpu when -1 |
@ -0,0 +1,69 @@ |
|||
# Copyright (c) 2017-present, Facebook, Inc. |
|||
# All rights reserved. |
|||
# |
|||
# This source code is licensed under the license found in the |
|||
# LICENSE file in the root directory of this source tree. |
|||
# |
|||
|
|||
""" |
|||
Clone repo here: https://github.com/facebookresearch/SentEval.git |
|||
""" |
|||
|
|||
from __future__ import absolute_import, division, unicode_literals |
|||
|
|||
import sys |
|||
import logging |
|||
import numpy as np |
|||
from towhee import ops |
|||
from statistics import mean |
|||
|
|||
import os |
|||
import warnings |
|||
from transformers import logging as t_logging |
|||
|
|||
# Silence noisy framework logging before loading any models.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

# Load the sbert embedding operator for the model named on the command line.
model_name = sys.argv[-1]
op = ops.sentence_embedding.sbert(model_name=model_name, device='cpu').get_op()
# op = ops.text_embedding.sentence_transformers(model_name=model_name, device='cuda:3').get_op()

# Paths to the cloned SentEval repo and its downloaded task data
# (see the README: the script is expected to run from inside SentEval/).
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'

# Make the cloned SentEval repo importable, then import it.
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval
|||
|
|||
# SentEval hooks: `prepare` runs once per task before evaluation.
def prepare(params, samples):
    """No per-task preparation is needed for a fixed pretrained encoder."""
    return
|||
|
|||
def batcher(params, batch):
    """Embed one SentEval batch (list of token lists) into a 2-D array.

    Empty sentences are replaced with '.' so the encoder always receives
    non-empty text; rows are stacked into a single numpy matrix.
    """
    sentences = [' '.join(tokens) if tokens != [] else '.' for tokens in batch]
    vectors = op(sentences)
    return np.vstack(vectors)
|||
|
|||
# SentEval configuration: 10-fold CV with a small logistic-regression probe.
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
}
params_senteval['classifier'] = {
    'nhid': 0,           # linear probe (no hidden layer)
    'optim': 'adam',
    'batch_size': 64,
    'tenacity': 5,
    'epoch_size': 4,
}

# Verbose logging so per-task progress is visible during evaluation.
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
|||
|
|||
if __name__ == "__main__": |
|||
se = senteval.engine.SE(params_senteval, batcher, prepare) |
|||
# transfer_tasks = ['STSBenchmark'] |
|||
transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16'] |
|||
results = se.eval(transfer_tasks) |
|||
print(results) |
|||
p = [] |
|||
s = [] |
|||
for t in transfer_tasks: |
|||
p.append(res['pearson']['mean']) |
|||
s.append(res['spearman']['mean']) |
|||
print('pearson:', mean(p)) |
|||
print('spearman:', mean(s)) |
@ -0,0 +1,123 @@ |
|||
from towhee import AutoPipes, AutoConfig, pipe, ops, triton_client |
|||
|
|||
import onnxruntime |
|||
import numpy |
|||
import torch |
|||
from statistics import mean |
|||
|
|||
import time |
|||
import argparse |
|||
|
|||
import os |
|||
import re |
|||
import warnings |
|||
import logging |
|||
from transformers import logging as t_logging |
|||
|
|||
# Silence noisy framework logging before any model work happens.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

# Command-line flags; see the README for a description of each one.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--pipe', action='store_true')
parser.add_argument('--triton', action='store_true')
parser.add_argument('--onnx', action='store_true')
parser.add_argument('--atol', type=float, default=1e-3)
parser.add_argument('--num', type=int, default=100)
parser.add_argument('--device', type=int, default=-1)

args = parser.parse_args()
|||
|
|||
# Resolve the torch device string: -1 selects CPU, otherwise a CUDA index.
device = 'cpu' if args.device < 0 else 'cuda:' + str(args.device)
model_name = args.model

# Minimal towhee pipeline: raw text in -> sbert embedding out.
p = (
    pipe.input('text')
        .map('text', 'vec', ops.sentence_embedding.sbert(model_name=model_name, device=device))
        .output('vec')
)

# conf = AutoConfig.load_config('sentence_embedding')
# conf.model = model_name
# conf.device = args.device
# p = AutoPipes.pipeline('sentence_embedding', conf)

# Smoke-test the pipeline once; `out1` is the reference embedding that the
# triton and onnx paths are compared against below.
text = 'Hello, world.'
out1 = p(text).get()[0]
print('Pipe: OK')
|||
|
|||
if args.num and args.pipe:
    # Time 10 batches of `num` sentences through the local pipeline and
    # report the mean throughput (batch construction stays inside the timed
    # region, matching the other QPS sections).
    rates = []
    for _ in range(10):
        t0 = time.time()
        # p([text] * args.num)
        p.batch([text] * args.num)
        t1 = time.time()
        rates.append(args.num / (t1 - t0))
    print('Pipe qps:', mean(rates))
|||
|
|||
if args.triton:
    # Talk to a triton server assumed to be serving this model locally.
    client = triton_client.Client(url='localhost:8000')
    out2 = client(text)[0][0]
    print('Triton: OK')

    # Compare the triton embedding against the local pipeline reference.
    if numpy.allclose(out1, out2, atol=args.atol):
        print('Check accuracy: OK')
    else:
        diff = numpy.abs(out1 - out2)
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {diff.max()}.')
        print(f'Minimum absolute difference is {diff.min()}.')
        print(f'Mean difference is {diff.mean()}.')

    if args.num:
        # Throughput over 10 timed batches of `num` sentences.
        rates = []
        for _ in range(10):
            t0 = time.time()
            client.batch([text] * args.num)
            t1 = time.time()
            rates.append(args.num / (t1 - t0))
        print('Triton qps:', mean(rates))
|||
|
|||
if args.onnx:
    # Export the operator to ONNX and validate it with onnxruntime.
    op = ops.sentence_embedding.sbert(model_name=model_name, device=device).get_op()
    # if not os.path.exists('test.onnx'):
    op.save_model('onnx', 'test.onnx')

    # BUG FIX: the original code only assigned `providers` on the CPU path,
    # so any GPU run (--device >= 0) raised NameError before the session was
    # created. Pick the execution provider up front instead.
    if device == 'cpu':
        providers = ['CPUExecutionProvider']
    else:
        providers = [('CUDAExecutionProvider', {'device_id': args.device})]
    sess = onnxruntime.InferenceSession('test.onnx', providers=providers)

    # Tokenize once and convert torch tensors to numpy for onnxruntime.
    inputs = op.tokenize([text])
    for k, v in inputs.items():
        inputs[k] = v.cpu().detach().numpy()
    out3 = sess.run(None, input_feed=inputs)
    print('Onnx: OK')

    # Compare the ONNX output against the local pipeline reference.
    if numpy.allclose(out1, out3, atol=args.atol):
        print('Check accuracy: OK')
    else:
        max_diff = numpy.abs(out1 - out3).max()
        min_diff = numpy.abs(out1 - out3).min()
        mean_diff = numpy.abs(out1 - out3).mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # Throughput: each timed loop tokenizes + runs `num` single sentences.
        qps = []
        for _ in range(10):
            start = time.time()
            for _ in range(args.num):
                tokens = op.tokenize([text])
                # BUG FIX: the original ran the session on the stale `inputs`
                # dict and threw `tokens` away, so tokenization cost was timed
                # but its output never reached the model. Feed the fresh
                # tokens (converted to numpy) instead.
                feed = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
                outs = sess.run(None, input_feed=feed)
            end = time.time()
            qps.append(args.num / (end - start))
        print('Onnx qps:', mean(qps))
@ -0,0 +1,20 @@ |
|||
# Standalone triton throughput check.
# Usage: python <script> NUM  (NUM = number of sentences to send)
from towhee import triton_client
import sys
import time

num = int(sys.argv[-1])
data = 'Hello, world.'
client = triton_client.Client('localhost:8000')

# Warm-up request so model loading is excluded from the timing below.
client.batch([data])
print('client: ok')

time.sleep(5)

print('test...')
start = time.time()
client.batch([data] * num, batch_size=8)
end = time.time()

elapsed = end - start
print(f'duration: {elapsed}')
print(f'qps: {num / elapsed}')
Loading…
Reference in new issue