diff --git a/benchmark/run.py b/benchmark/run.py index b28f436..49cf48f 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -49,7 +49,7 @@ query_data = query_data[:query_size] if query_size else query_data # Warm up print('Warming up...') op = ops.image_embedding.timm(model_name=model_name, device=device).get_op() -dummy_input = numpy.random.randn(1, 3, op.config['input_size'][-2], op.config['input_size'][-1]) +dummy_input = numpy.random.randn(op.config['input_size'][0]) dim = op(dummy_input).shape[0] print(f'output dim: {dim}') @@ -120,6 +120,7 @@ elif args.format == 'onnx': @towhee.register def run_onnx(img): + img = img.convert('RGB') img = op.tfms(img).cpu().detach().numpy() features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] if len(features.shape) == 4: