logo
Browse Source

Fix run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
fa1e717980
  1. 1
      benchmark/run.py

1
benchmark/run.py

@ -129,6 +129,7 @@ elif args.format == 'onnx':
img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
if len(features.shape) == 4:
features = torch.from_numpy(features)
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features)
outs = features.squeeze(0)

Loading…
Cancel
Save