# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import warnings
from typing import List, Tuple, Union

import numpy
import torch
import torchaudio
from panns_inference import AudioTagging, labels

from towhee.operator.base import NNOperator
from towhee import register

warnings.filterwarnings('ignore')
log = logging.getLogger()


@register(output_schema=['label', 'vec'])
class Panns(NNOperator):
    """
    Audio tagging and embedding operator built on top of
    [panns_inference](https://github.com/qiuqiangkong/panns_inference).

    Given an audio file path or a raw waveform array, it returns the top-k
    predicted AudioSet tags and the PANNs audio embedding.
    """
    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
        super().__init__(framework=framework)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # AudioTagging loads the checkpoint and moves the model to `device`,
        # so no extra `eval()`/`to()` calls are needed here.
        self.tagger = AudioTagging(checkpoint_path=weights_path, device=self.device)
        self.model = self.tagger.model

    def __call__(self,
                 audio: Union[str, numpy.ndarray],
                 sample_rate: int = None,
                 top_k: int = 5) -> Tuple[List[Tuple[str, float]], numpy.ndarray]:
        # Load the waveform: either from a file path, or from a raw array,
        # in which case `sample_rate` must be provided.
        if isinstance(audio, str):
            source = os.path.abspath(audio)
            audio_wav, sr = torchaudio.load(source)
        elif isinstance(audio, numpy.ndarray):
            if sample_rate is None:
                raise ValueError('`sample_rate` is required when `audio` is an array.')
            sr = sample_rate
            audio_wav = torch.tensor(audio).to(torch.float32)
        else:
            raise TypeError(f'Invalid audio type: {type(audio)}')

        # Downmix (channels, samples) input to a mono 1-D waveform.
        if audio_wav.dim() == 2:
            if audio_wav.shape[0] > 1:
                audio_wav = torch.mean(audio_wav, dim=0)
            else:
                audio_wav = audio_wav.squeeze(0)

        # PANNs models expect 32 kHz audio; resample only if needed.
        _sr = 32000
        if sr != _sr:
            transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=_sr)
            audio_wav = transform(audio_wav)

        # Add a batch dimension: (samples,) -> (1, samples).
        audio_tensors = audio_wav[None, :]
        clipwise_output, embedding = self.tagger.inference(audio_tensors)

        # Collect the top-k tags by clipwise score.
        sorted_indexes = numpy.argsort(clipwise_output[0])[::-1]
        tags = []
        for k in range(top_k):
            tag = labels[sorted_indexes[k]]
            score = clipwise_output[0][sorted_indexes[k]]
            tags.append((tag, round(float(score), 2)))
        return tags, embedding.squeeze(0)


# if __name__ == '__main__':
#     encoder = Panns()
#
#     audio_path = '/audio/path/or/link'
#     tags, vecs = encoder(audio_path)
#
#     # audio_data = numpy.zeros((2, 441344))
#     # sample_rate = 44100
#     # tags, vecs = encoder(audio_data, sample_rate)
#     print(tags, vecs.shape)
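
# --- Usage sketch (illustrative, not part of the operator) -----------------
# The embedding returned by `Panns` is a plain 1-D numpy vector, so it can be
# compared directly, e.g. with cosine similarity. This minimal sketch follows
# the file's convention of a commented-out demo; 'a.wav' and 'b.wav' are
# hypothetical placeholder paths, not files shipped with this operator.
#
# if __name__ == '__main__':
#     encoder = Panns()
#     _, vec_a = encoder('a.wav')
#     _, vec_b = encoder('b.wav')
#     # Cosine similarity between the two PANNs embeddings.
#     cos = numpy.dot(vec_a, vec_b) / (numpy.linalg.norm(vec_a) * numpy.linalg.norm(vec_b))
#     print(f'cosine similarity: {cos:.3f}')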