logo
Browse Source

Fix run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
ea091310ca
  1. 3
      benchmark/run.py

3
benchmark/run.py

@ -118,7 +118,6 @@ elif args.format == 'onnx':
collection_name = collection_name + '_onnx' collection_name = collection_name + '_onnx'
if not os.path.exists(onnx_path): if not os.path.exists(onnx_path):
onnx_path = op.save_model(format='onnx') onnx_path = op.save_model(format='onnx')
sess = onnxruntime.InferenceSession(onnx_path, sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers()) providers=onnxruntime.get_available_providers())
@ -126,7 +125,7 @@ elif args.format == 'onnx':
@towhee.register @towhee.register
def run_onnx(img): def run_onnx(img):
img = img.convert('RGB') img = img.convert('RGB')
img = op.tfms(img).cpu().detach().numpy()
img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
if len(features.shape) == 4: if len(features.shape) == 4:
global_pool = nn.AdaptiveAvgPool2d(1) global_pool = nn.AdaptiveAvgPool2d(1)

Loading…
Cancel
Save