logo
Browse Source

Fix onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
cf57a2215c
  1. 11
      test_onnx.py
  2. 4
      timm_image.py

11
test_onnx.py

@ -43,12 +43,19 @@ for name in models:
logger.info(f'***{name}***') logger.info(f'***{name}***')
saved_name = name.replace('/', '-') saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx' onnx_path = f'saved/onnx/{saved_name}.onnx'
try:
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
except Exception as e:
print(f'***Please re-download model {name}.***')
logger.error(f'Fail to call model: {e}')
continue
if status: if status:
f.write(','.join(status) + '\n') f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5 status = [name] + ['fail'] * 5
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
try: try:
out1 = op.model.forward_features(data).detach().numpy() out1 = op.model.forward_features(data).detach().numpy()
logger.info('OP LOADED.') logger.info('OP LOADED.')

4
timm_image.py

@ -65,7 +65,7 @@ class TimmImage(NNOperator):
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
if self.model_name: if self.model_name:
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
@ -135,7 +135,7 @@ class TimmImage(NNOperator):
path, path,
input_names=['input_0'], input_names=['input_0'],
output_names=['output_0'], output_names=['output_0'],
opset_version=12,
opset_version=14,
dynamic_axes={ dynamic_axes={
'input_0': {0: 'batch_size'}, 'input_0': {0: 'batch_size'},
'output_0': {0: 'batch_size'} 'output_0': {0: 'batch_size'}

Loading…
Cancel
Save