logo
Browse Source

Fix onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
cf57a2215c
  1. 11
      test_onnx.py
  2. 4
      timm_image.py

11
test_onnx.py

@ -43,12 +43,19 @@ for name in models:
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
try:
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
except Exception as e:
print(f'***Please re-download model {name}.***')
logger.error(f'Fail to call model: {e}')
continue
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
try:
out1 = op.model.forward_features(data).detach().numpy()
logger.info('OP LOADED.')

4
timm_image.py

@ -65,7 +65,7 @@ class TimmImage(NNOperator):
self.device = device
self.model_name = model_name
if self.model_name:
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
@ -135,7 +135,7 @@ class TimmImage(NNOperator):
path,
input_names=['input_0'],
output_names=['output_0'],
opset_version=12,
opset_version=14,
dynamic_axes={
'input_0': {0: 'batch_size'},
'output_0': {0: 'batch_size'}

Loading…
Cancel
Save