diff --git a/auto_transformers.py b/auto_transformers.py index aca2eba..9ae436d 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -239,7 +239,6 @@ class AutoTransformers(NNOperator): 'gpt2-xl', 'microsoft/deberta-xlarge', 'microsoft/deberta-xlarge-mnli', - 'msmarco-distilbert-base-v4', ] full_list = s_list + add_models full_list.sort() @@ -254,7 +253,7 @@ class AutoTransformers(NNOperator): assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) elif format == 'onnx': - to_remove = [] + to_remove = ['gpt2-xl'] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) # todo: elif format == 'tensorrt': diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000..d10880b --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,15 @@ +# Evaluation + +## Model performance in sentence similarity + +1. Download SentEval & test data +```bash +git clone https://github.com/facebookresearch/SentEval.git +cd SentEval/data/downstream +./get_transfer_data.bash +``` + +2. Run test script +```bash +python transformers_test.py MODEL_NAME +``` \ No newline at end of file diff --git a/benchmark/evaluate.py b/benchmark/evaluate.py new file mode 100644 index 0000000..d9489a0 --- /dev/null +++ b/benchmark/evaluate.py @@ -0,0 +1,71 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Clone GenSen repo here: https://github.com/Maluuba/gensen.git +And follow instructions for loading the model used in batcher +""" + +from __future__ import absolute_import, division, unicode_literals + +import sys +import logging +import numpy as np +from towhee import ops +from statistics import mean + +import os +import warnings +from transformers import logging as t_logging + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +warnings.filterwarnings("ignore") +t_logging.set_verbosity_error() + +model_name = sys.argv[-1] +op = ops.sentence_embedding.transformers(model_name=model_name, device='cuda:3').get_op() +# op = ops.text_embedding.sentence_transformers(model_name=model_name, device='cuda:3').get_op() + +# Set PATHs +PATH_TO_SENTEVAL = '../' +PATH_TO_DATA = '../data' + +# import SentEval +sys.path.insert(0, PATH_TO_SENTEVAL) +import senteval + +# SentEval prepare and batcher +def prepare(params, samples): + return + +def batcher(params, batch): + batch = [' '.join(sent) if sent != [] else '.' for sent in batch] + embeddings = op(batch) + return np.vstack(embeddings) + +# Set params for SentEval +params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} +params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, + 'tenacity': 5, 'epoch_size': 4} + +# Set up logger +logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) + +if __name__ == "__main__": + se = senteval.engine.SE(params_senteval, batcher, prepare) + # transfer_tasks = ['STSBenchmark'] + # transfer_tasks = ['STS16'] + transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16'] + results = se.eval(transfer_tasks) + p = [] + s = [] + for t in transfer_tasks: + res = results[t]['all'] + p.append(res['pearson']['mean']) + s.append(res['spearman']['mean']) + print('pearson:', mean(p)) + print('spearman:', mean(s)) diff --git a/benchmark/qps_test.py b/benchmark/qps_test.py new file mode 100644 index 0000000..2221b27 --- /dev/null +++ b/benchmark/qps_test.py @@ -0,0 +1,132 @@ +import towhee +from towhee.dc2 import AutoPipes, AutoConfig +from towhee import triton_client, ops + +import onnxruntime +import numpy +import torch +from statistics import mean + +import time +import argparse + +import os +import re +import warnings +import logging +from transformers import logging as t_logging + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +warnings.filterwarnings("ignore") +t_logging.set_verbosity_error() + +parser = argparse.ArgumentParser() +parser.add_argument('--model', required=True, type=str) +parser.add_argument('--pipe', action='store_true') +parser.add_argument('--triton', action='store_true') +parser.add_argument('--onnx', action='store_true') +parser.add_argument('--atol', type=float, default=1e-3) +parser.add_argument('--num', type=int, default=100) +parser.add_argument('--device', type=int, default=-1) + +args = parser.parse_args() + +device = 'cuda:' + str(args.device) if args.device >= 0 else 'cpu' +model_name = args.model +# model_name = 'paraphrase-albert-small-v2' +# model_name = 'all-MiniLM-L6-v2' +# model_name = 'all-mpnet-base-v2' +# model_name = 'distilbert-base-uncased' + +# p = ( +# pipe.input('text') +# .map('text', 'vec', ops.sentence_embedding.transformers(model_name=model_name, device='cuda:3')) +# .output('vec') +# ) + +conf = AutoConfig.load_config('sentence_embedding') +conf.model = model_name +conf.device = args.device +p = AutoPipes.pipeline('sentence_embedding', conf) + +text = 'Hello, world.' +out1 = p(text).get()[0] +print('Pipe: OK') + +if args.num and args.pipe: + qps = [] + for _ in range(10): + start = time.time() + # p([text] * args.num) + p.batch([text] * args.num) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Pipe qps:', mean(qps)) + +if args.triton: + client = triton_client.Client(url='172.16.70.4:8101') + out2 = client(text)[0][0] + print('Triton: OK') + + if numpy.allclose(out1, out2, atol=args.atol): + print('Check accuracy: OK') + else: + max_diff = numpy.abs(out1 - out2).max() + min_diff = numpy.abs(out1 - out2).min() + mean_diff = numpy.abs(out1 - out2).mean() + print(f'Check accuracy: atol is larger than {args.atol}.') + print(f'Maximum absolute difference is {max_diff}.') + print(f'Minimum absolute difference is {min_diff}.') + print(f'Mean difference is {mean_diff}.') + + if args.num: + qps = [] + for _ in range(10): + start = time.time() + client.batch([text] * args.num) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Triton qps:', mean(qps)) + +if args.onnx: + op = ops.sentence_embedding.transformers(model_name=model_name, device='cpu').get_op() + # if not os.path.exists('test.onnx'): + op.save_model('onnx', 'test.onnx') + sess = onnxruntime.InferenceSession('test.onnx', + providers=['CUDAExecutionProvider']) + inputs = op.tokenizer([text], padding=True, truncation=True, return_tensors='np') + # inputs = {} + # for k, v in tokens.items(): + # if k in op.onnx_config['inputs'].keys(): + # inputs[k] = v + out3 = sess.run(None, input_feed=dict(inputs)) + torch_inputs = {} + for k, v in inputs.items(): + torch_inputs[k] = torch.from_numpy(v).to(device) + out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy() + print('Onnx: OK') + if numpy.allclose(out1, out3, atol=args.atol): + print('Check accuracy: OK') + else: + max_diff = numpy.abs(out1 - out3).max() + min_diff = numpy.abs(out1 - out3).min() + mean_diff = numpy.abs(out1 - out3).mean() + print(f'Check accuracy: atol is larger than {args.atol}.') + print(f'Maximum absolute difference is {max_diff}.') + print(f'Minimum absolute difference is {min_diff}.') + print(f'Mean difference is {mean_diff}.') + + if args.num: + qps = [] + for _ in range(10): + start = time.time() + for _ in range(args.num): + tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt') + outs = sess.run(None, input_feed=dict(inputs)) + op.post_proc(torch.from_numpy(outs[0]).to(device), torch_inputs) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Onnx qps:', mean(qps)) diff --git a/benchmark/test_client.py b/benchmark/test_client.py new file mode 100644 index 0000000..bb6aa3b --- /dev/null +++ b/benchmark/test_client.py @@ -0,0 +1,24 @@ +from towhee import triton_client +import sys +import time + +num = int(sys.argv[-1]) +client = triton_client.Client(url='172.16.70.4:8101') + +data = 'hello' +# data = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.' + +res = client(data) +print(res[0][0].shape) + +res = client([data]) +print(len(res[0][0]), res[0][0][0].shape) + +# Or run data with batch +start = time.time() +res = client.batch([data] * num, batch_size=8) +end = time.time() +print(num / (end-start)) +# print(len(res), res[0][0][0].shape) + +client.close()