diff --git a/benchmark/run.py b/benchmark/run.py index 69ed532..4b26e99 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -129,6 +129,7 @@ elif args.format == 'onnx': img = op.tfms(img).unsqueeze(0).cpu().detach().numpy() features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] if len(features.shape) == 4: + features = torch.from_numpy(features) global_pool = nn.AdaptiveAvgPool2d(1) features = global_pool(features) outs = features.squeeze(0)