| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -22,6 +22,7 @@ from typing import List | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import resampy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import onnxruntime | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.operator.base import NNOperator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee import register | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -43,8 +44,9 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 params: dict = None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 checkpoint_path: str = None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 framework: str = 'pytorch'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 model_path: str = None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 framework: str = 'pytorch', | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 ): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__(framework=framework) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if params is None: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -52,29 +54,31 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.params = params | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        log.info('Loading model...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if model_path is None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = str(Path(__file__).parent) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if model_path.endswith('.onnx'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log.warning('Using onnx.') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model = onnxruntime.InferenceSession(model_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            state_dict = torch.load(model_path, map_location=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if isinstance(state_dict, torch.nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.model = state_dict | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                dim = self.params['dim'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                h = self.params['h'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                u = self.params['u'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                f_bin = self.params['n_mels'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                n_seg = int(self.params['segment_size'] * self.params['sample_rate']) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        log.info('Creating model...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                log.info('Creating model with parameters...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.model = NNFp( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    dim=dim, h=h, u=u, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    in_f=f_bin, in_t=t, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    fuller=self.params['fuller'], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    activation=self.params['activation'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ).to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        log.info('Loading weights...') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if checkpoint_path is None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = str(Path(__file__).parent) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            checkpoint_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        state_dict = torch.load(checkpoint_path, map_location=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if isinstance(state_dict, torch.nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model = state_dict | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.model.load_state_dict(state_dict) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        log.info('Model is loaded.') | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -82,8 +86,15 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        audio_tensors = self.preprocess(data).to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # print(audio_tensors.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if isinstance(self.model, onnxruntime.InferenceSession): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else audio_tensors.cpu().numpy() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ort_inputs = {self.model.get_inputs()[0].name: audio_numpy} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            outs = self.model.run(None, ort_inputs)[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            features = self.model(audio_tensors) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return features.detach().cpu().numpy() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            outs = features.detach().cpu().numpy() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return outs | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def preprocess(self, frames: List[AudioFrame]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sr = frames[0].sample_rate | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -137,7 +148,7 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return wav.astype(dtype) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def save_model(self, format: str='pytorch', path: str = 'default'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def save_model(self, format: str = 'pytorch', path: str = 'default'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if path == 'default': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = str(Path(__file__).parent) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = os.path.join(path, 'saved', format) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -156,6 +167,9 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    jit_model = torch.jit.script(self.model) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                except Exception: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    log.warning( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        'Failed to directly export as torchscript.' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        'Using dummy input in shape of %s now.', dummy_input.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    jit_model = torch.jit.trace(self.model, dummy_input, strict=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.jit.save(jit_model, path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -181,3 +195,9 @@ class NNFingerprint(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # todo: elif format == 'tensorrt': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log.error(f'Unsupported format "{format}".') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def input_schema(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return [(AudioFrame, (1024,))] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def output_schema(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return [(numpy.ndarray, (-1, self.params['dim']))] | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
					 | 
				
				 | 
				
					
  |