timm
copied
Jael Gu
2 years ago
4 changed files with 179 additions and 10 deletions
@ -0,0 +1,131 @@ |
|||||
|
import towhee |
||||
|
from towhee.dc2 import pipe, ops |
||||
|
from towhee import triton_client |
||||
|
|
||||
|
import onnxruntime |
||||
|
import numpy |
||||
|
import torch |
||||
|
from statistics import mean |
||||
|
|
||||
|
import time |
||||
|
import argparse |
||||
|
|
||||
|
import os |
||||
|
import re |
||||
|
import warnings |
||||
|
import logging |
||||
|
from transformers import logging as t_logging |
||||
|
|
||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
||||
|
warnings.filterwarnings("ignore") |
||||
|
t_logging.set_verbosity_error() |
||||
|
|
||||
|
parser = argparse.ArgumentParser() |
||||
|
parser.add_argument('--model', required=True, type=str) |
||||
|
parser.add_argument('--pipe', action='store_true') |
||||
|
parser.add_argument('--triton', action='store_true') |
||||
|
parser.add_argument('--onnx', action='store_true') |
||||
|
parser.add_argument('--atol', type=float, default=1e-3) |
||||
|
parser.add_argument('--num', type=int, default=100) |
||||
|
parser.add_argument('--device', type=str, default='cpu') |
||||
|
args = parser.parse_args() |
||||
|
|
||||
|
model_name = args.model |
||||
|
# model_name = 'resnet50' |
||||
|
# model_name = 'vgg16' |
||||
|
# model_name = 'deit3_base_patch16_224' |
||||
|
# model_name = 'deit_tiny_patch16_224' |
||||
|
# model_name = 'deit_base_distilled_patch16_224' |
||||
|
# model_name = 'convnext_base' |
||||
|
# model_name = 'vit_base_patch16_224' |
||||
|
# model_name = 'tf_efficientnet_b5' |
||||
|
|
||||
|
|
||||
|
p = ( |
||||
|
pipe.input('url') |
||||
|
.map('url', 'img', ops.image_decode.cv2_rgb()) |
||||
|
.map('img', 'vec', ops.image_embedding.timm(model_name=model_name, device=args.device)) |
||||
|
.output('vec') |
||||
|
) |
||||
|
|
||||
|
data = '../towhee.jpeg' |
||||
|
out1 = p(data).get()[0] |
||||
|
print('Pipe: OK') |
||||
|
|
||||
|
if args.num and args.pipe: |
||||
|
qps = [] |
||||
|
for _ in range(10): |
||||
|
start = time.time() |
||||
|
p.batch([data] * args.num) |
||||
|
# for _ in range(args.num): |
||||
|
# p(data) |
||||
|
end = time.time() |
||||
|
q = args.num / (end - start) |
||||
|
qps.append(q) |
||||
|
print('Pipe qps:', mean(qps)) |
||||
|
|
||||
|
if args.triton: |
||||
|
client = triton_client.Client(url='localhost:8000') |
||||
|
out2 = client(data)[0][0][0] |
||||
|
print('Triton: OK') |
||||
|
|
||||
|
if numpy.allclose(out1, out2, atol=args.atol): |
||||
|
print('Check accuracy: OK') |
||||
|
else: |
||||
|
max_diff = numpy.abs(out1 - out2).max() |
||||
|
min_diff = numpy.abs(out1 - out2).min() |
||||
|
mean_diff = numpy.abs(out1 - out2).mean() |
||||
|
print(f'Check accuracy: atol is larger than {args.atol}.') |
||||
|
print(f'Maximum absolute difference is {max_diff}.') |
||||
|
print(f'Minimum absolute difference is {min_diff}.') |
||||
|
print(f'Mean difference is {mean_diff}.') |
||||
|
|
||||
|
if args.num: |
||||
|
qps = [] |
||||
|
for _ in range(10): |
||||
|
start = time.time() |
||||
|
client.batch([data] * args.num) |
||||
|
end = time.time() |
||||
|
q = args.num / (end - start) |
||||
|
qps.append(q) |
||||
|
print('Triton qps:', mean(qps)) |
||||
|
|
||||
|
if args.onnx: |
||||
|
op = ops.image_embedding.timm(model_name=model_name, device='cpu').get_op() |
||||
|
decoder = ops.image_decode.cv2_rgb().get_op() |
||||
|
# if not os.path.exists('test.onnx'): |
||||
|
op.save_model('onnx', 'test.onnx') |
||||
|
sess = onnxruntime.InferenceSession('test.onnx', |
||||
|
providers=['CUDAExecutionProvider']) |
||||
|
inputs = decoder(data) |
||||
|
inputs = op.convert_img(inputs) |
||||
|
inputs = op.tfms(inputs).unsqueeze(0) |
||||
|
out3 = sess.run(None, input_feed={'input_0': inputs.cpu().detach().numpy()})[0] |
||||
|
op.device = 'cuda' if args.device != 'cpu' else 'cpu' |
||||
|
out3 = op.post_proc(torch.from_numpy(out3)).cpu().detach().numpy() |
||||
|
print('Onnx: OK') |
||||
|
if numpy.allclose(out1, out3, atol=args.atol): |
||||
|
print('Check accuracy: OK') |
||||
|
else: |
||||
|
max_diff = numpy.abs(out1 - out3).max() |
||||
|
min_diff = numpy.abs(out1 - out3).min() |
||||
|
mean_diff = numpy.abs(out1 - out3).mean() |
||||
|
print(f'Check accuracy: atol is larger than {args.atol}.') |
||||
|
print(f'Maximum absolute difference is {max_diff}.') |
||||
|
print(f'Minimum absolute difference is {min_diff}.') |
||||
|
print(f'Mean difference is {mean_diff}.') |
||||
|
|
||||
|
if args.num: |
||||
|
qps = [] |
||||
|
for _ in range(10): |
||||
|
start = time.time() |
||||
|
for _ in range(args.num): |
||||
|
inputs = decoder(data) |
||||
|
inputs = op.convert_img(inputs) |
||||
|
inputs = op.tfms(inputs).unsqueeze(0) |
||||
|
outs = sess.run(None, input_feed={'input_0': inputs.cpu().detach().numpy()})[0] |
||||
|
outs = op.post_proc(torch.from_numpy(outs)) |
||||
|
end = time.time() |
||||
|
q = args.num / (end - start) |
||||
|
qps.append(q) |
||||
|
print('Onnx qps:', mean(qps)) |
Loading…
Reference in new issue