diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..5a86889 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .panns import Panns + + +def panns(weights_path: str = None,): + return Panns(weights_path=weights_path) diff --git a/panns.py b/panns.py new file mode 100644 index 0000000..f98228a --- /dev/null +++ b/panns.py @@ -0,0 +1,89 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import warnings + +import os +import numpy +from typing import Union + +import torch +import torchaudio + +from panns_inference import AudioTagging, labels + +from towhee.operator.base import NNOperator +from towhee import register + + +warnings.filterwarnings('ignore') +log = logging.getLogger() + + +@register(output_schema=['label', 'vec']) +class Panns(NNOperator): + """ + Built on top of [panns_inference](https://github.com/qiuqiangkong/panns_inference). + """ + + def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: + super().__init__(framework=framework) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.tagger = AudioTagging(checkpoint_path=weights_path, device=self.device) + self.model = self.tagger.model + # self.model.eval() + # self.model.to(self.device) + + def __call__(self, audio: Union[str, numpy.ndarray], sample_rate: int = None, top_k: int = 5) -> numpy.ndarray: + if isinstance(audio, str): + source = os.path.abspath(audio) + audio_wav, sr = torchaudio.load(source) + elif isinstance(audio, numpy.ndarray): + sr = sample_rate + audio_wav = torch.tensor(audio).to(torch.float32) + + if audio_wav.shape[0] == 2: + audio_wav = torch.mean(audio_wav, dim=0) + elif audio_wav.shape[0] == 1: + audio_wav = audio_wav.squeeze(0) + + _sr = 32000 + if sr != _sr: + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=_sr) + audio_tensors = transform(audio_wav) + + audio_tensors = audio_tensors[None, :] + clipwise_output, embedding = self.tagger.inference(audio_tensors) + + sorted_indexes = numpy.argsort(clipwise_output[0])[::-1] + tags = [] + for k in range(top_k): + tag = numpy.array(labels)[sorted_indexes[k]] + score = clipwise_output[0][sorted_indexes[k]] + tags.append((tag, round(score, 2))) + + return tags, embedding.squeeze(0) + + +# if __name__ == '__main__': +# encoder = Panns() +# +# audio_path = '/audio/path/or/link' +# tags, vecs = encoder(audio_path) +# +# # audio_data = numpy.zeros((2, 441344)) +# # sample_rate = 44100 +# # tags, vecs = encoder(audio_data, sample_rate) +# print(tags, vecs.shape) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..97a1728 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +panns_inference +torchaudio +torch \ No newline at end of file