clmr
              
                 
                
            
          copied
			You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            136 lines
          
        
        
          
            4.7 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            136 lines
          
        
        
          
            4.7 KiB
          
        
        
      | # Copyright 2021 Zilliz. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| #     http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Copyright 2021 Zilliz. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| #     http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| 
 | |
| 
 | |
| import os | |
| import sys | |
| import logging | |
| from pathlib import Path | |
| from typing import List | |
| 
 | |
| import torch | |
| import torchaudio | |
| import numpy | |
| 
 | |
| from towhee.operator import NNOperator | |
| from towhee import register | |
| from towhee.types.audio_frame import AudioFrame | |
| 
 | |
| sys.path.append(str(Path(__file__).parent)) | |
| from clmr_checkpoint import load_encoder_checkpoint | |
| from sample_cnn import SampleCNN | |
| 
 | |
# Module-level logger named after this module (rather than the root logger), so
# its output can be filtered/configured independently; records still propagate
# to the root logger's handlers.
log = logging.getLogger(__name__)
| 
 | |
| 
 | |
@register(output_schema=['vecs'])
class ClmrMagnatagatune(NNOperator):
    """
    Audio embedding operator built on a pretrained CLMR SampleCNN encoder.

    The encoder is constructed with nine stride-3 blocks and weights loaded from
    ``clmr_checkpoint_10000.pt`` next to this file. The encoder's last child module
    is dropped (presumably the output head — TODO confirm against SampleCNN), and
    the final layer of the last convolutional block is removed as well, so the
    operator returns intermediate convolutional features.

    Args:
        framework: Framework name forwarded to ``NNOperator``; defaults to "pytorch".
    """

    # Sample rate the encoder is fed; input audio is resampled to it when needed.
    _TARGET_SR = 22050
    # Samples per window passed to the encoder (3 ** 10). Presumably chosen so the
    # stride-3 layers reduce each window to length 1 — TODO confirm.
    _WINDOW = 59049

    def __init__(self, framework: str = "pytorch") -> None:
        super().__init__(framework=framework)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        weight_path = os.path.join(str(Path(__file__).parent),
                                   'clmr_checkpoint_10000.pt')
        state_dict = load_encoder_checkpoint(weight_path, 1)
        encoder = SampleCNN(strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], supervised=False, out_dim=1)
        encoder.load_state_dict(state_dict)

        # Drop the encoder's last child module, keeping only the feature extractor.
        new_encoder = torch.nn.Sequential(*(list(encoder.children())[:-1]))
        # Keep blocks 0-9 of the feature extractor unchanged, and block 10 without
        # its final layer.
        layers = list(new_encoder[0][:10].children())
        layers.append(torch.nn.Sequential(*list(new_encoder[0][10].children())[:-1]))
        self.model = torch.nn.Sequential(*layers)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
        """
        Embed a decoded audio clip.

        Args:
            data: Decoded audio frames. The sample rate and channel layout are read
                from ``data[0]`` and assumed to apply to every frame.

        Returns:
            Encoder features for each full ``_WINDOW``-sample window, with all
            singleton dimensions squeezed. Note that when the clip yields only one
            window the batch dimension is squeezed away too. A clip shorter than
            one window is zero-padded to a single window.
        """
        sr = data[0].sample_rate
        layout = data[0].layout
        if layout == 'stereo':
            # De-interleave each frame into (n, 2) channel pairs.
            samples = numpy.vstack([frame.reshape(-1, 2) for frame in data])
            # Normalize integer PCM *before* averaging the channels: numpy.mean
            # returns floats, which int2float would then cast without rescaling.
            audio = self.int2float(samples, dtype='float32').mean(axis=1)
        else:
            audio = numpy.hstack(data)
            if audio.ndim != 1:
                audio = audio.squeeze()
            audio = self.int2float(audio, dtype='float32')
        audio = torch.from_numpy(audio)

        if sr != self._TARGET_SR:
            resampler = torchaudio.transforms.Resample(sr, self._TARGET_SR, dtype=audio.dtype)
            audio = resampler(audio)

        with torch.no_grad():
            # Only full windows are embedded; a trailing partial window is dropped.
            windows = [w for w in torch.split(audio, self._WINDOW) if len(w) == self._WINDOW]
            if not windows:
                # Clip shorter than one window: zero-pad it instead of letting
                # torch.vstack fail on an empty list.
                windows = [torch.nn.functional.pad(audio, (0, self._WINDOW - len(audio)))]
            batch = torch.vstack(windows).unsqueeze(dim=1).to(self.device)
            features = self.model(batch).squeeze()

        return features.detach().cpu().numpy()

    def int2float(self, wav: numpy.ndarray, dtype: str = 'float64') -> numpy.ndarray:
        """
        Convert audio samples to floating point.

        Integer input is rescaled to [-1.0, 1.0): values are shifted by the type's
        midpoint (zero for signed types) and divided by 2 ** (bits - 1). Any other
        input is cast to ``dtype`` unchanged.

        The code is inspired by https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py

        Args:
            wav: Audio samples of any shape.
            dtype: Output floating-point dtype; defaults to 'float64'.

        Returns:
            A new array of dtype ``dtype`` with the same shape as ``wav``.

        Raises:
            ValueError: If ``dtype`` is not a floating-point dtype.
        """
        dtype = numpy.dtype(dtype)
        if dtype.kind != 'f':
            raise ValueError(f'int2float expects a floating-point dtype, got {dtype}')
        if wav.dtype.kind not in 'iu':
            return wav.astype(dtype)
        info = numpy.iinfo(wav.dtype)
        abs_max = 2 ** (info.bits - 1)
        # Unsigned PCM is centred on abs_max; signed PCM has offset 0.
        offset = info.min + abs_max
        # Scale in float64 so wide integer types keep precision, then cast.
        return ((wav.astype(numpy.float64) - offset) / abs_max).astype(dtype)
| 
 | |
| 
 | |
| # if __name__ == "__main__": | |
| #     import towhee | |
| # | |
| #     audio_path = "path/to/audio.wav" | |
| #     frames = towhee.glob(audio_path).audio_decode.ffmpeg(99999).flatten()[0] | |
| # | |
| #     encoder = ClmrMagnatagatune() | |
| #     vec = encoder(frames) | |
| # | |
| #     print(vec.shape)
 | 
