|
@ -67,7 +67,13 @@ class Model: |
|
|
log.info('Model is loaded.') |
|
|
log.info('Model is loaded.') |
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
def __call__(self, *args, **kwargs): |
|
|
return self.model(*args, **kwargs) |
|
|
|
|
|
|
|
|
new_args = [] |
|
|
|
|
|
new_kwargs = {} |
|
|
|
|
|
for x in new_args: |
|
|
|
|
|
x = x.to(self.device) |
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
|
|
new_kwargs[k] = v.to(self.device) |
|
|
|
|
|
return self.model(*new_args, **new_kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vecs']) |
|
|
@register(output_schema=['vecs']) |
|
@ -103,8 +109,6 @@ class NNFingerprint(NNOperator): |
|
|
|
|
|
|
|
|
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: |
|
|
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: |
|
|
audio_tensors = self.preprocess(data) |
|
|
audio_tensors = self.preprocess(data) |
|
|
if audio_tensors.device != self.device: |
|
|
|
|
|
audio_tensors = audio_tensors.to(self.device) |
|
|
|
|
|
# print(audio_tensors.shape) |
|
|
# print(audio_tensors.shape) |
|
|
features = self.model(audio_tensors) |
|
|
features = self.model(audio_tensors) |
|
|
outs = features.detach().cpu().numpy() |
|
|
outs = features.detach().cpu().numpy() |
|
@ -141,7 +145,7 @@ class NNFingerprint(NNOperator): |
|
|
wav = preprocess_wav(audio, |
|
|
wav = preprocess_wav(audio, |
|
|
segment_size=int(self.params['sample_rate'] * self.params['segment_size']), |
|
|
segment_size=int(self.params['sample_rate'] * self.params['segment_size']), |
|
|
hop_size=int(self.params['sample_rate'] * self.params['hop_size']), |
|
|
hop_size=int(self.params['sample_rate'] * self.params['hop_size']), |
|
|
frame_shift_mul=self.params['frame_shift_mul']).to(self.device) |
|
|
|
|
|
|
|
|
frame_shift_mul=self.params['frame_shift_mul']) |
|
|
wav = wav.to(torch.float32) |
|
|
wav = wav.to(torch.float32) |
|
|
mel = MelSpec(sample_rate=self.params['sample_rate'], |
|
|
mel = MelSpec(sample_rate=self.params['sample_rate'], |
|
|
window_length=self.params['window_length'], |
|
|
window_length=self.params['window_length'], |
|
@ -151,7 +155,7 @@ class NNFingerprint(NNOperator): |
|
|
n_mels=self.params['n_mels'], |
|
|
n_mels=self.params['n_mels'], |
|
|
naf_mode=self.params['naf_mode'], |
|
|
naf_mode=self.params['naf_mode'], |
|
|
mel_log=self.params['mel_log'], |
|
|
mel_log=self.params['mel_log'], |
|
|
spec_norm=self.params['spec_norm']).to(self.device) |
|
|
|
|
|
|
|
|
spec_norm=self.params['spec_norm']) |
|
|
wav = mel(wav) |
|
|
wav = mel(wav) |
|
|
return wav |
|
|
return wav |
|
|
|
|
|
|
|
|