sbert
copied
4 changed files with 245 additions and 0 deletions
@ -0,0 +1,33 @@ |
|||||
|
# Evaluation |
||||
|
|
||||
|
## Model performance in sentence similarity |
||||
|
|
||||
|
1. Download SentEval & test data |
||||
|
```bash |
||||
|
git clone https://github.com/facebookresearch/SentEval.git |
||||
|
cd SentEval/data/downstream |
||||
|
./get_transfer_data.bash |
||||
|
``` |
||||
|
|
||||
|
2. Run test script |
||||
|
```bash |
||||
|
python evaluate.py MODEL_NAME |
||||
|
``` |
||||
|
|
||||
|
## QPS Test |
||||
|
|
||||
|
Please note that `qps_test.py` uses: |
||||
|
- `localhost:8000`: to connect triton client |
||||
|
- `'Hello, world.'`: as test sentence |
||||
|
|
||||
|
```bash |
||||
|
python qps_test.py --model sentence-t5-base --pipe --onnx --triton --num 100 |
||||
|
``` |
||||
|
|
||||
|
**Args:** |
||||
|
- `--model`: mandatory, string, model name |
||||
|
- `--pipe`: optional, on/off flag to enable qps test for pipe |
||||
|
- `--onnx`: optional, on/off flag to enable qps test for onnx |
||||
|
- `--triton`: optional, on/off flag to enable qps for triton (please make sure that triton client is ready) |
||||
|
- `--num`: optional, integer, defaults to 100, batch size in each loop (10 loops in total) |
||||
|
- `--device`: optional, int, defaults to -1, cuda index or use cpu when -1 |
@ -0,0 +1,69 @@ |
|||||
|
# Copyright (c) 2017-present, Facebook, Inc. |
||||
|
# All rights reserved. |
||||
|
# |
||||
|
# This source code is licensed under the license found in the |
||||
|
# LICENSE file in the root directory of this source tree. |
||||
|
# |
||||
|
|
||||
|
""" |
||||
|
Clone repo here: https://github.com/facebookresearch/SentEval.git |
||||
|
""" |
||||
|
|
||||
|
from __future__ import absolute_import, division, unicode_literals |
||||
|
|
||||
|
import sys |
||||
|
import logging |
||||
|
import numpy as np |
||||
|
from towhee import ops |
||||
|
from statistics import mean |
||||
|
|
||||
|
import os |
||||
|
import warnings |
||||
|
from transformers import logging as t_logging |
||||
|
|
||||
|
# Silence TensorFlow / Python / transformers warnings so SentEval's own
# logging output stays readable.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

# Model under test is the last CLI argument, e.g. `python evaluate.py MODEL_NAME`.
model_name = sys.argv[-1]
op = ops.sentence_embedding.sbert(model_name=model_name, device='cpu').get_op()
# op = ops.text_embedding.sentence_transformers(model_name=model_name, device='cuda:3').get_op()

# Set PATHs
# NOTE(review): this script is expected to run from inside the cloned
# SentEval repository (see module docstring); data must already be
# downloaded via SentEval/data/downstream/get_transfer_data.bash.
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'

# import SentEval (path insert must precede the import)
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval
||||
|
|
||||
|
# SentEval prepare and batcher |
||||
|
def prepare(params, samples):
    """SentEval `prepare` hook.

    No corpus-level preprocessing (e.g. vocabulary building) is needed for
    a pretrained sentence encoder, so this is intentionally a no-op.
    """
    return None
||||
|
|
||||
|
def batcher(params, batch):
    """SentEval `batcher` hook: embed a batch of tokenized sentences.

    Each element of `batch` is a list of tokens; empty sentences are
    replaced by '.' so the encoder never receives empty text.  Returns a
    2-D array with one embedding row per sentence.
    """
    sentences = []
    for sent in batch:
        sentences.append('.' if sent == [] else ' '.join(sent))
    embeddings = op(sentences)
    return np.vstack(embeddings)
||||
|
|
||||
|
# Set params for SentEval
# Standard SentEval configuration: 10-fold cross-validation; the probing
# classifier has no hidden layer (nhid=0), i.e. logistic regression on top
# of the frozen sentence embeddings.
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                 'tenacity': 5, 'epoch_size': 4}

# Set up logger (DEBUG so SentEval reports per-task progress)
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
||||
|
|
||||
|
if __name__ == "__main__":
    # Run the STS 2012-2016 similarity benchmarks and report the average
    # Pearson / Spearman correlation across tasks.
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    # transfer_tasks = ['STSBenchmark']
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
    results = se.eval(transfer_tasks)
    print(results)

    p = []
    s = []
    for t in transfer_tasks:
        # BUG FIX: the original loop referenced an undefined name `res`
        # (NameError).  For STS tasks SentEval stores the aggregate scores
        # under results[task]['all'].
        res = results[t]['all']
        p.append(res['pearson']['mean'])
        s.append(res['spearman']['mean'])
    print('pearson:', mean(p))
    print('spearman:', mean(s))
@ -0,0 +1,123 @@ |
|||||
|
from towhee import AutoPipes, AutoConfig, pipe, ops, triton_client |
||||
|
|
||||
|
import onnxruntime |
||||
|
import numpy |
||||
|
import torch |
||||
|
from statistics import mean |
||||
|
|
||||
|
import time |
||||
|
import argparse |
||||
|
|
||||
|
import os |
||||
|
import re |
||||
|
import warnings |
||||
|
import logging |
||||
|
from transformers import logging as t_logging |
||||
|
|
||||
|
# Silence TensorFlow / Python / transformers warnings so benchmark output
# stays readable.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

# CLI (see README):
#   --model   mandatory, model name
#   --pipe / --onnx / --triton   on/off flags enabling each QPS test
#   --atol    tolerance used by the accuracy cross-checks
#   --num     batch size in each timed loop (10 loops in total)
#   --device  cuda index, or CPU when -1
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--pipe', action='store_true')
parser.add_argument('--triton', action='store_true')
parser.add_argument('--onnx', action='store_true')
parser.add_argument('--atol', type=float, default=1e-3)
parser.add_argument('--num', type=int, default=100)
parser.add_argument('--device', type=int, default=-1)

args = parser.parse_args()

device = 'cuda:' + str(args.device) if args.device >= 0 else 'cpu'
model_name = args.model
||||
|
|
||||
|
# Build a minimal towhee pipeline: text -> sbert sentence embedding.
p = (
    pipe.input('text')
    .map('text', 'vec', ops.sentence_embedding.sbert(model_name=model_name, device=device))
    .output('vec')
)

# conf = AutoConfig.load_config('sentence_embedding')
# conf.model = model_name
# conf.device = args.device
# p = AutoPipes.pipeline('sentence_embedding', conf)

# Smoke test: embed one sentence; `out1` becomes the reference vector for
# the ONNX / Triton accuracy checks below.
text = 'Hello, world.'
out1 = p(text).get()[0]
print('Pipe: OK')
||||
|
|
||||
|
if args.num and args.pipe:
    # Throughput benchmark for the local pipeline: 10 timed rounds of
    # `args.num` sentences each; report the mean queries-per-second.
    rates = []
    payload = [text] * args.num
    for _ in range(10):
        t0 = time.time()
        p.batch(payload)
        elapsed = time.time() - t0
        rates.append(args.num / elapsed)
    print('Pipe qps:', mean(rates))
||||
|
|
||||
|
if args.triton:
    # NOTE(review): assumes a Triton server hosting the exported model is
    # already listening on localhost:8000 (see README) — verify before running.
    client = triton_client.Client(url='localhost:8000')
    out2 = client(text)[0][0]
    print('Triton: OK')

    # Cross-check the Triton embedding against the local pipeline result.
    if numpy.allclose(out1, out2, atol=args.atol):
        print('Check accuracy: OK')
    else:
        max_diff = numpy.abs(out1 - out2).max()
        min_diff = numpy.abs(out1 - out2).min()
        mean_diff = numpy.abs(out1 - out2).mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # 10 timed rounds of `args.num` requests each; report mean QPS.
        qps = []
        for _ in range(10):
            start = time.time()
            client.batch([text] * args.num)
            end = time.time()
            q = args.num / (end - start)
            qps.append(q)
        print('Triton qps:', mean(qps))
||||
|
|
||||
|
if args.onnx:
    # Export the operator to ONNX, validate ONNX Runtime output against the
    # pipeline's reference vector `out1`, then optionally benchmark QPS.
    op = ops.sentence_embedding.sbert(model_name=model_name, device=device).get_op()
    # if not os.path.exists('test.onnx'):
    op.save_model('onnx', 'test.onnx')

    # BUG FIX: the session must exist on every path.  Previously it was
    # only created in the CPU branch, so any CUDA run crashed with a
    # NameError before `set_providers` could be called.
    sess = onnxruntime.InferenceSession('test.onnx',
                                        providers=['CPUExecutionProvider'])
    if device != 'cpu':
        sess.set_providers(['CUDAExecutionProvider'], [{'device_id': args.device}])

    # Tokenize once and convert torch tensors to numpy for onnxruntime.
    inputs = op.tokenize([text])
    for k, v in inputs.items():
        inputs[k] = v.cpu().detach().numpy()
    out3 = sess.run(None, input_feed=inputs)
    print('Onnx: OK')

    if numpy.allclose(out1, out3, atol=args.atol):
        print('Check accuracy: OK')
    else:
        max_diff = numpy.abs(out1 - out3).max()
        min_diff = numpy.abs(out1 - out3).min()
        mean_diff = numpy.abs(out1 - out3).mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # 10 timed rounds; each round runs `args.num` single-sentence
        # inferences (tokenize + session run) and records QPS.
        qps = []
        for _ in range(10):
            start = time.time()
            for _ in range(args.num):
                # BUG FIX: feed the freshly tokenized tensors.  The old loop
                # built `tokens` but then reused the stale `inputs` dict, so
                # every iteration re-ran the same cached feed.
                tokens = op.tokenize([text])
                feed = {k: v.cpu().detach().numpy() for k, v in tokens.items()}
                sess.run(None, input_feed=feed)
            end = time.time()
            q = args.num / (end - start)
            qps.append(q)
        print('Onnx qps:', mean(qps))
@ -0,0 +1,20 @@ |
|||||
|
from towhee import triton_client |
||||
|
import sys |
||||
|
import time |
||||
|
|
||||
|
# Standalone Triton throughput probe.  The number of requests comes from
# the last CLI argument, e.g. `python qps_test.py 1000`.
num = int(sys.argv[-1])
data = 'Hello, world.'
# NOTE(review): assumes a Triton server is reachable on localhost:8000.
client = triton_client.Client('localhost:8000')

# warm up (first request pays model-load / connection cost)
client.batch([data])
print('client: ok')

# Let the server settle before timing.
time.sleep(5)

print('test...')
start = time.time()
client.batch([data] * num, batch_size=8)
end = time.time()
print(f'duration: {end - start}')
print(f'qps: {num / (end - start)}')
Loading…
Reference in new issue