diff --git a/test_onnx.py b/test_onnx.py index 5498278..2ded7d9 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -43,12 +43,19 @@ for name in models: logger.info(f'***{name}***') saved_name = name.replace('/', '-') onnx_path = f'saved/onnx/{saved_name}.onnx' + + try: + op = TimmImage(model_name=name, device='cpu') + data = torch.rand((1,) + op.config['input_size']) + except Exception as e: + print(f'***Please re-download model {name}.***') + logger.error(f'Fail to call model: {e}') + continue + if status: f.write(','.join(status) + '\n') status = [name] + ['fail'] * 5 - op = TimmImage(model_name=name, device='cpu') - data = torch.rand((1,) + op.config['input_size']) try: out1 = op.model.forward_features(data).detach().numpy() logger.info('OP LOADED.') diff --git a/timm_image.py b/timm_image.py index b56f3b2..f21f784 100644 --- a/timm_image.py +++ b/timm_image.py @@ -65,7 +65,7 @@ class TimmImage(NNOperator): self.device = device self.model_name = model_name if self.model_name: - self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) + self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) self.model.eval() self.model.to(self.device) @@ -135,7 +135,7 @@ class TimmImage(NNOperator): path, input_names=['input_0'], output_names=['output_0'], - opset_version=12, + opset_version=14, dynamic_axes={ 'input_0': {0: 'batch_size'}, 'output_0': {0: 'batch_size'}