Browse Source
Fix run.py
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
1 additions and
2 deletions
-
benchmark/run.py
|
@ -118,7 +118,6 @@ elif args.format == 'onnx': |
|
|
collection_name = collection_name + '_onnx' |
|
|
collection_name = collection_name + '_onnx' |
|
|
if not os.path.exists(onnx_path): |
|
|
if not os.path.exists(onnx_path): |
|
|
onnx_path = op.save_model(format='onnx') |
|
|
onnx_path = op.save_model(format='onnx') |
|
|
|
|
|
|
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
|
|
|
|
|
@ -126,7 +125,7 @@ elif args.format == 'onnx': |
|
|
@towhee.register |
|
|
@towhee.register |
|
|
def run_onnx(img): |
|
|
def run_onnx(img): |
|
|
img = img.convert('RGB') |
|
|
img = img.convert('RGB') |
|
|
img = op.tfms(img).cpu().detach().numpy() |
|
|
|
|
|
|
|
|
img = op.tfms(img).unsqueeze(0).cpu().detach().numpy() |
|
|
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] |
|
|
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] |
|
|
if len(features.shape) == 4: |
|
|
if len(features.shape) == 4: |
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|