Browse Source
Fix onnx
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
11 additions and
4 deletions
-
test_onnx.py
-
timm_image.py
|
@ -43,12 +43,19 @@ for name in models: |
|
|
logger.info(f'***{name}***') |
|
|
logger.info(f'***{name}***') |
|
|
saved_name = name.replace('/', '-') |
|
|
saved_name = name.replace('/', '-') |
|
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
|
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f'***Please re-download model {name}.***') |
|
|
|
|
|
logger.error(f'Fail to call model: {e}') |
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
if status: |
|
|
if status: |
|
|
f.write(','.join(status) + '\n') |
|
|
f.write(','.join(status) + '\n') |
|
|
status = [name] + ['fail'] * 5 |
|
|
status = [name] + ['fail'] * 5 |
|
|
|
|
|
|
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
|
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
|
|
|
try: |
|
|
try: |
|
|
out1 = op.model.forward_features(data).detach().numpy() |
|
|
out1 = op.model.forward_features(data).detach().numpy() |
|
|
logger.info('OP LOADED.') |
|
|
logger.info('OP LOADED.') |
|
|
|
@ -65,7 +65,7 @@ class TimmImage(NNOperator): |
|
|
self.device = device |
|
|
self.device = device |
|
|
self.model_name = model_name |
|
|
self.model_name = model_name |
|
|
if self.model_name: |
|
|
if self.model_name: |
|
|
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) |
|
|
|
|
|
|
|
|
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
self.model.to(self.device) |
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
@ -135,7 +135,7 @@ class TimmImage(NNOperator): |
|
|
path, |
|
|
path, |
|
|
input_names=['input_0'], |
|
|
input_names=['input_0'], |
|
|
output_names=['output_0'], |
|
|
output_names=['output_0'], |
|
|
opset_version=12, |
|
|
|
|
|
|
|
|
opset_version=14, |
|
|
dynamic_axes={ |
|
|
dynamic_axes={ |
|
|
'input_0': {0: 'batch_size'}, |
|
|
'input_0': {0: 'batch_size'}, |
|
|
'output_0': {0: 'batch_size'} |
|
|
'output_0': {0: 'batch_size'} |
|
|