logo
Browse Source

Update device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
88fce5b9fa
  1. 14
      nn_fingerprint.py

14
nn_fingerprint.py

@ -67,7 +67,13 @@ class Model:
log.info('Model is loaded.') log.info('Model is loaded.')
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
new_args = []
new_kwargs = {}
for x in new_args:
x = x.to(self.device)
for k, v in kwargs.items():
new_kwargs[k] = v.to(self.device)
return self.model(*new_args, **new_kwargs)
@register(output_schema=['vecs']) @register(output_schema=['vecs'])
@ -103,8 +109,6 @@ class NNFingerprint(NNOperator):
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data) audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape) # print(audio_tensors.shape)
features = self.model(audio_tensors) features = self.model(audio_tensors)
outs = features.detach().cpu().numpy() outs = features.detach().cpu().numpy()
@ -141,7 +145,7 @@ class NNFingerprint(NNOperator):
wav = preprocess_wav(audio, wav = preprocess_wav(audio,
segment_size=int(self.params['sample_rate'] * self.params['segment_size']), segment_size=int(self.params['sample_rate'] * self.params['segment_size']),
hop_size=int(self.params['sample_rate'] * self.params['hop_size']), hop_size=int(self.params['sample_rate'] * self.params['hop_size']),
frame_shift_mul=self.params['frame_shift_mul']).to(self.device)
frame_shift_mul=self.params['frame_shift_mul'])
wav = wav.to(torch.float32) wav = wav.to(torch.float32)
mel = MelSpec(sample_rate=self.params['sample_rate'], mel = MelSpec(sample_rate=self.params['sample_rate'],
window_length=self.params['window_length'], window_length=self.params['window_length'],
@ -151,7 +155,7 @@ class NNFingerprint(NNOperator):
n_mels=self.params['n_mels'], n_mels=self.params['n_mels'],
naf_mode=self.params['naf_mode'], naf_mode=self.params['naf_mode'],
mel_log=self.params['mel_log'], mel_log=self.params['mel_log'],
spec_norm=self.params['spec_norm']).to(self.device)
spec_norm=self.params['spec_norm'])
wav = mel(wav) wav = mel(wav)
return wav return wav

Loading…
Cancel
Save