diff --git a/README.md b/README.md
index 8abd1ee..1b3c305 100644
--- a/README.md
+++ b/README.md
@@ -1,76 +1,96 @@
# Audio Classification with PANNS
-*Author: Jael Gu*
+*Author: [Jael Gu](https://github.com/jaelgu)*
+
-## Desription
+## Description
The audio classification operator classifies the given audio data with 527 labels from the large-scale [AudioSet dataset](https://research.google.com/audioset/).
The pre-trained model used here is from the paper **PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition** ([paper link](https://arxiv.org/abs/1912.10211)).
-```python
-import numpy as np
-from towhee import ops
+
+
+## Code Example
+
+Predict labels and generate embeddings given the audio path "test.wav".
+
+ *Write the pipeline in simplified style*:
-audio_classifier = ops.audio_classification.panns()
+```python
+import towhee
+
+(
+ towhee.glob('test.wav')
+ .audio_decode.ffmpeg()
+ .runas_op(func=lambda x:[y[0] for y in x])
+ .audio_classification.panns()
+ .show()
+)
+```
-# Path or url as input
-tags, audio_embedding = audio_classifier("/audio/path/or/url/")
+*Write the same pipeline with explicit input/output name specifications:*
-# Audio data as input
-audio_data = np.zeros((2, 441344))
-sample_rate = 44100
-tags, audio_embedding = audio_classifier(audio_data, sample_rate)
+```python
+import towhee
+
+(
+ towhee.glob['path']('test.wav')
+ .audio_decode.ffmpeg['path', 'frames']()
+ .runas_op['frames', 'frames'](func=lambda x:[y[0] for y in x])
+ .audio_classification.panns['frames', ('labels', 'scores', 'vec')]()
+ .show()
+)
```
+
+
## Factory Constructor
Create the operator via the following factory method
-***ops.audio_classification.panns()***
-
+***audio_classification.panns(weights_path=None, framework='pytorch', sample_rate=32000, topk=5)***
-## Interface
+**Parameters:**
-Given an audio (file path, link, or waveform),
-the audio classification operator generates a list of labels
-and a vector in numpy.ndarray.
+*weights_path: str*
+The path to model weights. If None, it will load default model weights.
-**Parameters:**
+*framework: str*
- None.
+The framework of the model implementation.
+Default value is "pytorch" since the model is implemented in PyTorch.
+*sample_rate: int*
-**Returns**: *numpy.ndarray*
+The target sample rate to which audio data is converted, defaults to 32000.
- labels [(tag, score)], audio embedding in shape (2048,).
+*topk: int*
+The number of labels and corresponding scores to return, sorted by probability from high to low.
+Default value is 5.
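+
+As a quick sketch, reusing the explicit pipeline from the code example above (and assuming the default pretrained weights can be downloaded), these parameters are passed when the operator is added to a pipeline, e.g. to return only the top 3 labels:
+
+```python
+import towhee
+
+(
+    towhee.glob['path']('test.wav')
+    .audio_decode.ffmpeg['path', 'frames']()
+    .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+    .audio_classification.panns['frames', ('labels', 'scores', 'vec')](topk=3)
+    .show()
+)
+```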
+
-## Code Example
+## Interface
-Generate embeddings for the audio "test.wav".
+The audio classification operator takes a list of towhee audio frames as input. It returns the predicted labels with scores and an audio embedding in numpy.ndarray.
- *Write the pipeline in simplified style*:
+**Parameters:**
-```python
-from towhee import dc
+*data: List[towhee.types.audio_frame.AudioFrame]*
-dc.glob('test.wav')
- .audio_classification.panns()
- .show()
-```
+Input audio data is a list of towhee audio frames.
+The input data should represent an audio clip longer than 2s.
-*Write a same pipeline with explicit inputs/outputs name specifications:*
-```python
-from towhee import dc
+**Returns**:
-dc.glob['path']('test.wav')
- .audio_classification.panns['path', 'vecs']()
- .select('vecs')
- .show()
-```
+*labels, scores, vec: Tuple(List[str], List[float], numpy.ndarray)*
+- labels: a list of the topk labels predicted by the model.
+- scores: a list of scores corresponding to the labels, each representing a prediction probability.
+- vec: an audio embedding generated by the model, with shape (2048,).
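+
+For reference, a minimal sketch of calling the operator directly, where `frames` stands for a hypothetical `List[AudioFrame]` produced by an audio decoder such as `audio_decode.ffmpeg`:
+
+```python
+from towhee import ops
+
+# `frames` is assumed to be decoded towhee audio frames of a clip longer than 2s.
+op = ops.audio_classification.panns(topk=5)
+labels, scores, vec = op(frames)
+
+print(labels)     # 5 predicted label strings
+print(scores)     # 5 scores matching the labels, high to low
+print(vec.shape)  # (2048,)
+```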
diff --git a/__init__.py b/__init__.py
index 5a86889..b577e93 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,5 +15,5 @@
from .panns import Panns
-def panns(weights_path: str = None,):
- return Panns(weights_path=weights_path)
+def panns(**kwargs):
+ return Panns(**kwargs)
diff --git a/panns.py b/panns.py
index f98228a..9f05061 100644
--- a/panns.py
+++ b/panns.py
@@ -17,73 +17,86 @@ import warnings
import os
import numpy
-from typing import Union
+import resampy
+from typing import List
import torch
-import torchaudio
from panns_inference import AudioTagging, labels
from towhee.operator.base import NNOperator
from towhee import register
+from towhee.types.audio_frame import AudioFrame
warnings.filterwarnings('ignore')
log = logging.getLogger()
-@register(output_schema=['label', 'vec'])
+@register(output_schema=['label', 'score', 'vec'])
class Panns(NNOperator):
"""
Built on top of [panns_inference](https://github.com/qiuqiangkong/panns_inference).
"""
- def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
+ def __init__(self,
+ weights_path: str = None,
+ framework: str = 'pytorch',
+ sample_rate: int = 32000,
+ topk: int = 5):
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.sample_rate = sample_rate
+ self.topk = topk
+ # checkpoint_path=None will download model weights with default url
+ # 'https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1'
self.tagger = AudioTagging(checkpoint_path=weights_path, device=self.device)
self.model = self.tagger.model
- # self.model.eval()
- # self.model.to(self.device)
-
- def __call__(self, audio: Union[str, numpy.ndarray], sample_rate: int = None, top_k: int = 5) -> numpy.ndarray:
- if isinstance(audio, str):
- source = os.path.abspath(audio)
- audio_wav, sr = torchaudio.load(source)
- elif isinstance(audio, numpy.ndarray):
- sr = sample_rate
- audio_wav = torch.tensor(audio).to(torch.float32)
-
- if audio_wav.shape[0] == 2:
- audio_wav = torch.mean(audio_wav, dim=0)
- elif audio_wav.shape[0] == 1:
- audio_wav = audio_wav.squeeze(0)
-
- _sr = 32000
- if sr != _sr:
- transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=_sr)
- audio_tensors = transform(audio_wav)
-
- audio_tensors = audio_tensors[None, :]
- clipwise_output, embedding = self.tagger.inference(audio_tensors)
+ self.model.eval()
+ self.model.to(self.device)
+
+ def __call__(self, data: List[AudioFrame]):
+ sr = data[0].sample_rate
+ layout = data[0].layout
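+        # Stereo frames carry interleaved samples: reshape into 2 channels and average them to mono.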
+ if layout == 'stereo':
+ frames = [frame.reshape(-1, 2) for frame in data]
+ audio = numpy.vstack(frames)
+ audio = numpy.mean(audio, axis=1)
+ else:
+ audio = numpy.hstack(data)
+
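+        # Scale integer PCM samples to floats in [-1.0, 1.0] and resample to the rate the model expects.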
+ audio = self.int2float(audio).astype('float32')
+ if sr != self.sample_rate:
+ audio = resampy.resample(audio, sr, self.sample_rate)
+
+ audio = torch.from_numpy(audio)[None, :]
+ clipwise_output, embedding = self.tagger.inference(audio)
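+        # Sort class probabilities in descending order and keep the topk labels with their scores.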
sorted_indexes = numpy.argsort(clipwise_output[0])[::-1]
tags = []
- for k in range(top_k):
+ scores = []
+ for k in range(self.topk):
tag = numpy.array(labels)[sorted_indexes[k]]
score = clipwise_output[0][sorted_indexes[k]]
- tags.append((tag, round(score, 2)))
-
- return tags, embedding.squeeze(0)
-
-
-# if __name__ == '__main__':
-# encoder = Panns()
-#
-# audio_path = '/audio/path/or/link'
-# tags, vecs = encoder(audio_path)
-#
-# # audio_data = numpy.zeros((2, 441344))
-# # sample_rate = 44100
-# # tags, vecs = encoder(audio_data, sample_rate)
-# print(tags, vecs.shape)
+ tags.append(tag)
+ scores.append(round(score, 4))
+
+ return tags, scores, embedding.squeeze(0)
+
+ def int2float(self, wav: numpy.ndarray, dtype: str = 'float64'):
+ """
+ Convert audio data from int to float.
+        Integer input is scaled into the range [-1.0, 1.0); floating-point input is simply cast to the target dtype.
+ The output dtype is controlled by the parameter `dtype`, defaults to 'float64'.
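+        For example, int16 input maps 32767 to 32767 / 32768 (just below 1.0) and -32768 to -1.0.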
+
+ The code is inspired by https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
+ """
+ dtype = numpy.dtype(dtype)
+ assert dtype.kind == 'f'
+ if wav.dtype.kind in 'iu':
+ ii = numpy.iinfo(wav.dtype)
+ abs_max = 2 ** (ii.bits - 1)
+ offset = ii.min + abs_max
+ return (wav.astype(dtype) - offset) / abs_max
+ else:
+ return wav.astype(dtype)
diff --git a/requirements.txt b/requirements.txt
index 769e156..8ac3f1f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
panns_inference
-torchaudio
+resampy
torch
-towhee
\ No newline at end of file
+towhee>=0.7.0
\ No newline at end of file