diff --git a/test.py b/test.py index 62ab8a4..1a3419d 100644 --- a/test.py +++ b/test.py @@ -7,6 +7,7 @@ import numpy import onnx import onnxruntime +device = 'cuda' if torch.cuda.is_available() else 'cpu' def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() @@ -14,7 +15,7 @@ def to_numpy(tensor): # decode = ops.audio_decode.ffmpeg() # audio = [x[0] for x in decode('path/to/audio.wav')] -audio = torch.rand(10, 256, 32) +audio = torch.rand(10, 256, 32).to(device) op = ops.audio_embedding.nnfp() out0 = op.get_op().model(audio)