diff --git a/nn_fingerprint.py b/nn_fingerprint.py
index d71f6d7..46e6c7d 100644
--- a/nn_fingerprint.py
+++ b/nn_fingerprint.py
@@ -67,7 +67,13 @@ class Model:
         log.info('Model is loaded.')
 
     def __call__(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        new_args = []
+        new_kwargs = {}
+        for x in args:
+            new_args.append(x.to(self.device))
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        return self.model(*new_args, **new_kwargs)
 
 
 @register(output_schema=['vecs'])
@@ -103,8 +109,6 @@ class NNFingerprint(NNOperator):
 
     def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
         audio_tensors = self.preprocess(data)
-        if audio_tensors.device != self.device:
-            audio_tensors = audio_tensors.to(self.device)
         # print(audio_tensors.shape)
         features = self.model(audio_tensors)
         outs = features.detach().cpu().numpy()
@@ -141,7 +145,7 @@ class NNFingerprint(NNOperator):
         wav = preprocess_wav(audio,
                              segment_size=int(self.params['sample_rate'] * self.params['segment_size']),
                              hop_size=int(self.params['sample_rate'] * self.params['hop_size']),
-                             frame_shift_mul=self.params['frame_shift_mul']).to(self.device)
+                             frame_shift_mul=self.params['frame_shift_mul'])
         wav = wav.to(torch.float32)
         mel = MelSpec(sample_rate=self.params['sample_rate'],
                       window_length=self.params['window_length'],
@@ -151,7 +155,7 @@ class NNFingerprint(NNOperator):
                       n_mels=self.params['n_mels'],
                       naf_mode=self.params['naf_mode'],
                       mel_log=self.params['mel_log'],
-                      spec_norm=self.params['spec_norm']).to(self.device)
+                      spec_norm=self.params['spec_norm'])
         wav = mel(wav)
         return wav
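
A minimal sketch of the behaviour this patch introduces, for illustration only (DummyNet and the constructor below are assumptions, not code from the repository): Model.__call__ now moves every positional and keyword tensor argument onto self.device before forwarding, which is why the explicit .to(self.device) calls in NNFingerprint.__call__ and preprocess() are removed.

import torch

class DummyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2

class Model:
    def __init__(self, device):
        self.device = device
        self.model = DummyNet().to(device)

    def __call__(self, *args, **kwargs):
        # Mirror of the patched wrapper: relocate all inputs to the model device.
        new_args = [x.to(self.device) for x in args]
        new_kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
        return self.model(*new_args, **new_kwargs)

model = Model('cuda' if torch.cuda.is_available() else 'cpu')
out = model(torch.randn(4, 8))  # caller passes a CPU tensor; the wrapper moves it
print(out.device)               # matches model.device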