@@ -43,7 +43,7 @@ status = None
 for name in models:
     logger.info(f'***{name}***')
     saved_name = name.replace('/', '-')
-    onnx_path = f'saved/onnx/{saved_name}/model.onnx'
+    onnx_path = f'saved/onnx/{saved_name}.onnx'
     if status:
         f.write(','.join(status) + '\n')
     status = [name] + ['fail'] * 5
@@ -78,7 +78,7 @@ for name in models:
     sess = onnxruntime.InferenceSession(onnx_path,
                                         providers=onnxruntime.get_available_providers())
     inputs = op.tokenizer(test_txt, return_tensors='np')
-    out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs))
+    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
     logger.info('ONNX WORKED.')
     status[4] = 'success'
     if numpy.allclose(out1, out2, atol=atol):