panns
copied
3 changed files with 111 additions and 0 deletions
@ -0,0 +1,19 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .panns import Panns |
|||
|
|||
|
|||
def panns(weights_path: str = None,):
    """
    Create a `Panns` operator.

    Args:
        weights_path (str): value forwarded unchanged to `Panns` as its
            `weights_path` keyword argument (defaults to `None`).

    Returns:
        The newly constructed `Panns` instance.
    """
    operator = Panns(weights_path=weights_path)
    return operator
@ -0,0 +1,89 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import logging |
|||
import warnings |
|||
|
|||
import os |
|||
import numpy |
|||
from typing import Union |
|||
|
|||
import torch |
|||
import torchaudio |
|||
|
|||
from panns_inference import AudioTagging, labels |
|||
|
|||
from towhee.operator.base import NNOperator |
|||
from towhee import register |
|||
|
|||
|
|||
# NOTE(review): this hides every warning for the whole process as soon as the
# module is imported. Kept unchanged because existing users may depend on the
# quiet output; consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

# Use a module-scoped logger instead of the root logger so records emitted
# here carry this module's name and can be filtered or configured separately.
log = logging.getLogger(__name__)
|||
|
|||
|
|||
@register(output_schema=['label', 'vec'])
class Panns(NNOperator):
    """
    Audio tagging and embedding operator.

    Built on top of [panns_inference](https://github.com/qiuqiangkong/panns_inference).

    Args:
        weights_path (str): checkpoint path handed to `AudioTagging` as
            `checkpoint_path`; `None` is passed through unchanged.
        framework (str): framework name forwarded to `NNOperator`.
    """

    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
        super().__init__(framework=framework)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tagger = AudioTagging(checkpoint_path=weights_path, device=self.device)
        self.model = self.tagger.model

    def __call__(self, audio: Union[str, numpy.ndarray], sample_rate: int = None, top_k: int = 5) -> tuple:
        """
        Tag an audio clip and return its embedding.

        Args:
            audio: path of an audio file to read with `torchaudio.load`, or an
                array of samples. A 2-D array is averaged over its first axis,
                i.e. it is assumed to be laid out as (channels, samples), the
                same layout the file-loading branch is handled with.
            sample_rate: sample rate of `audio` when it is an array. Ignored
                when `audio` is a path (the rate read from the file is used).
            top_k: how many of the highest-scoring labels to return. Capped at
                the number of scores the tagger produces.

        Returns:
            A `(tags, embedding)` tuple. `tags` is a list of `(label, score)`
            pairs ordered from highest to lowest score, each score rounded to
            two decimals. `embedding` is the second output of
            `AudioTagging.inference` with its leading axis squeezed away.

        Raises:
            TypeError: if `audio` is neither a `str` nor a `numpy.ndarray`.
            ValueError: if `audio` is an array and `sample_rate` is `None`.
        """
        if isinstance(audio, str):
            waveform, sr = torchaudio.load(os.path.abspath(audio))
        elif isinstance(audio, numpy.ndarray):
            if sample_rate is None:
                raise ValueError('sample_rate is required when audio is a numpy.ndarray')
            sr = sample_rate
            waveform = torch.tensor(audio).to(torch.float32)
        else:
            raise TypeError(
                'audio must be a file path or a numpy.ndarray, got {}'.format(type(audio).__name__)
            )

        # Collapse multi-channel audio to a single channel. For one channel the
        # mean is the channel itself, so this also covers the mono case.
        if waveform.dim() == 2:
            waveform = torch.mean(waveform, dim=0)

        # Resample only when needed; audio already at the target rate is used
        # as-is (previously this path left the inference input undefined).
        target_sr = 32000
        if sr != target_sr:
            resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
            waveform = resample(waveform)

        # Add the leading batch axis before running inference.
        clipwise_output, embedding = self.tagger.inference(waveform[None, :])

        scores = clipwise_output[0]
        ranked = numpy.argsort(scores)[::-1]
        # Clamp so an oversized top_k returns every label instead of raising,
        # and a negative top_k still yields an empty list.
        count = max(0, min(top_k, len(ranked)))
        label_array = numpy.array(labels)
        tags = [(label_array[idx], round(scores[idx], 2)) for idx in ranked[:count]]

        return tags, embedding.squeeze(0)
|||
|
|||
|
|||
# if __name__ == '__main__': |
|||
# encoder = Panns() |
|||
# |
|||
# audio_path = '/path/to/local/audio/file' |
|||
# tags, vecs = encoder(audio_path) |
|||
# |
|||
# # audio_data = numpy.zeros((2, 441344)) |
|||
# # sample_rate = 44100 |
|||
# # tags, vecs = encoder(audio_data, sample_rate) |
|||
# print(tags, vecs.shape) |
@ -0,0 +1,3 @@ |
|||
panns_inference |
|||
torchaudio |
|||
torch |
Loading…
Reference in new issue