From 2604d10a698b1d2e56565693b47c8285fbed12d6 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 16 Aug 2022 11:42:19 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- nn_fingerprint.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 1432e82..2bb88c2 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -46,9 +46,12 @@ class NNFingerprint(NNOperator): params: dict = None, model_path: str = None, framework: str = 'pytorch', + device: str = None ): super().__init__(framework=framework) - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device if params is None: self.params = default_params else: @@ -87,7 +90,9 @@ class NNFingerprint(NNOperator): log.info('Model is loaded.') def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: - audio_tensors = self.preprocess(data).to(self.device) + audio_tensors = self.preprocess(data) + if audio_tensors.device != self.device: + audio_tensors = audio_tensors.to(self.device) # print(audio_tensors.shape) if isinstance(self.model, onnxruntime.InferenceSession): audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \