diff --git a/test.py b/test.py index 1a3419d..b230402 100644 --- a/test.py +++ b/test.py @@ -23,13 +23,13 @@ out0 = op.get_op().model(audio) # Test Pytorch op.get_op().save_model(format='pytorch') -op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt') +op = ops.audio_embedding.nnfp(model_path='./saved/pytorch/nnfp.pt') out1 = op.get_op().model(audio) assert ((out0 == out1).all()) # Test Torchscript op.get_op().save_model(format='torchscript') -op = ops.audio_embedding.nnfp(checkpoint_path='./saved/torchscript/nnfp.pt') +op = ops.audio_embedding.nnfp(model_path='./saved/torchscript/nnfp.pt') out2 = op.get_op().model(audio) assert ((out0 == out2).all())