diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 1432e82..2bb88c2 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -46,9 +46,12 @@ class NNFingerprint(NNOperator): params: dict = None, model_path: str = None, framework: str = 'pytorch', + device: str = None ): super().__init__(framework=framework) - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = device if params is None: self.params = default_params else: @@ -87,7 +90,9 @@ class NNFingerprint(NNOperator): log.info('Model is loaded.') def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: - audio_tensors = self.preprocess(data).to(self.device) + audio_tensors = self.preprocess(data) + if audio_tensors.device != self.device: + audio_tensors = audio_tensors.to(self.device) # print(audio_tensors.shape) if isinstance(self.model, onnxruntime.InferenceSession): audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \