transformers
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
133 lines
4.4 KiB
133 lines
4.4 KiB
import towhee
|
|
from towhee.dc2 import AutoPipes, AutoConfig
|
|
from towhee import triton_client, ops
|
|
|
|
import onnxruntime
|
|
import numpy
|
|
import torch
|
|
from statistics import mean
|
|
|
|
import time
|
|
import argparse
|
|
|
|
import os
|
|
import re
|
|
import warnings
|
|
import logging
|
|
from transformers import logging as t_logging
|
|
|
|
# Keep the benchmark output readable: lower TensorFlow's C++ log verbosity via
# its environment variable, ignore Python warnings, and restrict the
# transformers logger to errors only.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
warnings.filterwarnings(action="ignore")
t_logging.set_verbosity_error()
|
|
|
|
# Command-line interface of the benchmark script. Flags, types and defaults
# are unchanged; descriptions are provided so that `--help` documents how each
# option is used further down in the script.
parser = argparse.ArgumentParser(
    description='Run a sentence-embedding model through the Towhee pipeline and '
                'optionally check/benchmark it against a Triton server and ONNX Runtime.',
)
parser.add_argument('--model', required=True, type=str,
                    help='model name passed to the sentence_embedding pipeline and operator')
parser.add_argument('--pipe', action='store_true',
                    help='measure the QPS of the local pipeline')
parser.add_argument('--triton', action='store_true',
                    help='check and benchmark the Triton server at localhost:8000')
parser.add_argument('--onnx', action='store_true',
                    help='export the model to ONNX, then check and benchmark it with ONNX Runtime')
parser.add_argument('--atol', type=float, default=1e-3,
                    help='absolute tolerance when comparing embeddings with the pipeline output '
                         '(default: 1e-3)')
parser.add_argument('--num', type=int, default=100,
                    help='number of texts per timed round; 0 skips all QPS measurements '
                         '(default: 100)')
parser.add_argument('--device', type=int, default=-1,
                    help='device index passed to the pipeline config; a negative value means CPU '
                         '(default: -1)')

args = parser.parse_args()
|
|
|
|
# Device string derived from --device: 'cuda:<index>' for a non-negative
# index, otherwise 'cpu'.
device = f'cuda:{args.device}' if args.device >= 0 else 'cpu'

# Model under test, taken from --model. Earlier hard-coded choices are kept
# below for reference (disabled):
# model_name = 'paraphrase-albert-small-v2'
# model_name = 'all-MiniLM-L6-v2'
# model_name = 'all-mpnet-base-v2'
# model_name = 'distilbert-base-uncased'
model_name = args.model
|
|
|
|
# Hand-assembled alternative to the AutoPipes pipeline, kept for reference
# (disabled):
# p = (
#     pipe.input('text')
#         .map('text', 'vec', ops.sentence_embedding.transformers(model_name=model_name, device='cuda:3'))
#         .output('vec')
# )

# Load the stock 'sentence_embedding' configuration, point it at the requested
# model and device, and build the pipeline from it.
conf = AutoConfig.load_config('sentence_embedding')
conf.model = model_name
conf.device = args.device
p = AutoPipes.pipeline('sentence_embedding', conf)

# Run one sample sentence through the pipeline. `out1` is the reference
# embedding that the Triton and ONNX outputs are compared against.
text = 'Hello, world.'
pipe_result = p(text)
out1 = pipe_result.get()[0]
print('Pipe: OK')
|
|
|
|
# Local pipeline throughput: run 10 timed rounds, each submitting args.num
# copies of the sample text in one batch call, and report the mean QPS.
if args.pipe and args.num:
    qps = []
    for _ in range(10):
        start = time.time()
        # p([text] * args.num)
        p.batch([text] * args.num)
        elapsed = time.time() - start
        qps.append(args.num / elapsed)
    print('Pipe qps:', mean(qps))
|
|
|
|
# Triton check: query the server at localhost:8000 with the sample text,
# compare its embedding with the local pipeline output (out1), then
# optionally measure the server's throughput.
if args.triton:
    client = triton_client.Client(url='localhost:8000')
    response = client(text)
    out2 = response[0][0]
    print('Triton: OK')

    # Accuracy: report element-wise difference statistics when the two
    # embeddings are not within the requested absolute tolerance.
    if not numpy.allclose(out1, out2, atol=args.atol):
        abs_diff = numpy.abs(out1 - out2)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')
    else:
        print('Check accuracy: OK')

    # Throughput: mean QPS over 10 timed rounds of args.num texts each.
    if args.num:
        qps = []
        for _ in range(10):
            start = time.time()
            client.batch([text] * args.num)
            elapsed = time.time() - start
            qps.append(args.num / elapsed)
        print('Triton qps:', mean(qps))
|
|
|
|
# ONNX check: export the operator's model to ONNX, run it with ONNX Runtime,
# compare the post-processed embedding with the local pipeline output (out1),
# then optionally measure end-to-end throughput.
if args.onnx:
    op = ops.sentence_embedding.transformers(model_name=model_name, device='cpu').get_op()
    # The model is re-exported on every run; the existence check is disabled.
    # if not os.path.exists('test.onnx'):
    op.save_model('onnx', 'test.onnx')

    sess = onnxruntime.InferenceSession(
        'test.onnx',
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    # NOTE(review): the session is pinned to the CUDA provider even when
    # --device is negative (device 0 is used in that case), whereas the
    # pipeline receives the negative index as-is -- confirm this is intended.
    sess.set_providers(['CUDAExecutionProvider'], [{'device_id': 0 if args.device < 0 else args.device}])

    # Tokenize to NumPy for ONNX Runtime and mirror the same arrays as torch
    # tensors for the operator's post-processing step. Every tokenizer output
    # is fed to the session, so their names must match the exported graph's
    # inputs.
    inputs = op.tokenizer([text], padding=True, truncation=True, return_tensors='np')
    torch_inputs = {key: torch.from_numpy(value) for key, value in inputs.items()}
    raw_outputs = sess.run(None, input_feed=dict(inputs))
    out3 = op.post_proc(torch.from_numpy(raw_outputs[0]), torch_inputs).cpu().detach().numpy()
    print('Onnx: OK')

    # Accuracy: report element-wise difference statistics when the two
    # embeddings are not within the requested absolute tolerance.
    if numpy.allclose(out1, out3, atol=args.atol):
        print('Check accuracy: OK')
    else:
        abs_diff = numpy.abs(out1 - out3)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    # Throughput: mean QPS over 10 timed rounds of args.num single-text
    # requests, each timed end to end (tokenize -> ONNX run -> post-process).
    if args.num:
        qps = []
        for _trial in range(10):
            start = time.time()
            for _request in range(args.num):
                # Feed the tensors produced by THIS tokenization. Previously
                # the result was tokenized as 'pt' and discarded, and the
                # inputs prepared before the loop were reused instead.
                tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='np')
                outs = sess.run(None, input_feed=dict(tokens))
                token_tensors = {key: torch.from_numpy(value) for key, value in tokens.items()}
                op.post_proc(torch.from_numpy(outs[0]), token_tensors)
            end = time.time()
            qps.append(args.num / (end - start))
        print('Onnx qps:', mean(qps))
|