# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
from typing import List

import numpy
import torch
import torchaudio
from panns_inference import AudioTagging, labels

from towhee import register
from towhee.operator.base import NNOperator
from towhee.types.audio_frame import AudioFrame

warnings.filterwarnings('ignore')
log = logging.getLogger()


@register(output_schema=['label', 'score', 'vec'])
class Panns(NNOperator):
    """
    Audio tagging and embedding operator built on top of
    [panns_inference](https://github.com/qiuqiangkong/panns_inference).

    Given a list of audio frames, it returns the top-k predicted tags,
    their scores, and the clip-level embedding vector.
    """

    def __init__(self,
                 weights_path: str = None,
                 framework: str = 'pytorch',
                 sample_rate: int = 32000,
                 device: str = None,
                 topk: int = 5):
        super().__init__(framework=framework)
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sample_rate = sample_rate
        self.topk = topk
        # checkpoint_path=None downloads the default model weights from
        # 'https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1'
        self.tagger = AudioTagging(checkpoint_path=weights_path, device=self.device)
        self.model = self.tagger.model
        self.model.eval()

    def __call__(self, data: List[AudioFrame]):
        sr = data[0].sample_rate
        layout = data[0].layout
        if layout == 'stereo':
            # Interleaved stereo samples: reshape each frame to (n, 2),
            # stack all frames, and average the two channels down to mono.
            frames = [frame.reshape(-1, 2) for frame in data]
            audio = numpy.vstack(frames)
            audio = numpy.mean(audio, axis=1)
        else:
            audio = numpy.hstack(data)
        audio = self.int2float(audio).astype('float32')
        if sr != self.sample_rate:
            # Resample to the rate the PANNs checkpoint expects (32 kHz by
            # default), then convert back to numpy so both code paths hand
            # the tagger the same array type.
            audio = torch.from_numpy(audio)
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate, dtype=audio.dtype)
            audio = resampler(audio).numpy()
        if len(audio.shape) == 1:
            # The tagger expects a batch dimension: (batch, samples).
            audio = audio[None, :]
        clipwise_output, embedding = self.tagger.inference(audio)
        # Rank the AudioSet classes by predicted score, descending.
        sorted_indexes = numpy.argsort(clipwise_output[0])[::-1]
        tags = []
        scores = []
        for k in range(self.topk):
            tag = labels[sorted_indexes[k]]
            score = clipwise_output[0][sorted_indexes[k]]
            tags.append(tag)
            scores.append(round(score, 4))
        return tags, scores, embedding.squeeze(0)

    def int2float(self, wav: numpy.ndarray, dtype: str = 'float64'):
        """
        Convert audio samples from integer to floating point.

        Integer input is rescaled to [-1.0, 1.0); non-integer input is cast
        to `dtype` unchanged. The output dtype is controlled by the `dtype`
        parameter and defaults to 'float64'.

        Inspired by
        https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
        """
        dtype = numpy.dtype(dtype)
        assert dtype.kind == 'f'

        if wav.dtype.kind in 'iu':
            # Shift asymmetric integer ranges so zero maps to 0.0, then
            # scale by the magnitude of the most negative representable value.
            ii = numpy.iinfo(wav.dtype)
            abs_max = 2 ** (ii.bits - 1)
            offset = ii.min + abs_max
            return (wav.astype(dtype) - offset) / abs_max
        return wav.astype(dtype)
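

# ----------------------------------------------------------------------------
# A minimal smoke-test sketch (not part of the operator). It assumes
# AudioFrame can be constructed as AudioFrame(samples, sample_rate=...,
# layout=...), matching towhee.types.audio_frame in recent versions; adjust
# the frame construction if your installed version differs. Running it
# downloads the default Cnn14 checkpoint (~300 MB) on first use.
if __name__ == '__main__':
    # One second of a 440 Hz sine wave as int16 mono samples, so the sketch
    # is self-contained and exercises both int2float and the resample path
    # (16 kHz input vs. the model's 32 kHz).
    sr = 16000
    t = numpy.arange(sr, dtype=numpy.float64) / sr
    samples = (numpy.sin(2 * numpy.pi * 440 * t) * 32767).astype(numpy.int16)
    frame = AudioFrame(samples, sample_rate=sr, layout='mono')

    op = Panns(topk=3)
    tags, scores, vec = op([frame])
    print('top tags:', list(zip(tags, scores)))
    print('embedding shape:', vec.shape)  # Cnn14 embeddings are 2048-d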