From b2519607766daac8d038931150a283192c3e5e08 Mon Sep 17 00:00:00 2001
From: Jael Gu <mengjia.gu@zilliz.com>
Date: Mon, 6 Feb 2023 19:09:32 +0800
Subject: [PATCH] Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
---
 README.md        | 32 +++++++++-----------------------
 panns.py         |  7 ++++---
 requirements.txt |  2 +-
 3 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index e9b0963..b6130ec 100644
--- a/README.md
+++ b/README.md
@@ -15,33 +15,19 @@ The pre-trained model used here is from the paper **PANNs: Large-Scale Pretraine
 
 Predict labels and generate embeddings given the audio path "test.wav".
 
- *Write the pipeline in simplified style*:
+*Write a pipeline with explicit inputs/outputs name specifications:*
 
 ```python
-import towhee
-
-(
-    towhee.glob('test.wav')
-          .audio_decode.ffmpeg()
-          .runas_op(func=lambda x:[y[0] for y in x])
-          .audio_classification.panns()
-          .show()
-)
-```
-
-*Write a same pipeline with explicit inputs/outputs name specifications:*
+from towhee.dc2 import pipe, ops, DataCollection
 
-```python
-import towhee
-
-(
-    towhee.glob['path']('test.wav')
-          .audio_decode.ffmpeg['path', 'frames']()
-          .runas_op['frames', 'frames'](func=lambda x:[y[0] for y in x])
-          .audio_classification.panns['frames', ('labels', 'scores', 'vec')]()
-          .select['path', 'labels', 'scores', 'vec']()
-          .show()
+p = (
+    pipe.input('path')
+        .map('path', 'frame', ops.audio_decode.ffmpeg())
+        .map('frame', ('labels', 'scores', 'vec'), ops.audio_classification.panns())
+        .output('path', 'labels', 'scores', 'vec')
 )
+
+DataCollection(p('./test.wav')).show()
 ```
 <img src="./result.png" width="800px"/>
 
diff --git a/panns.py b/panns.py
index 9f05061..07e8877 100644
--- a/panns.py
+++ b/panns.py
@@ -17,10 +17,10 @@ import warnings
 
 import os
 import numpy
-import resampy
 from typing import List
 
 import torch
+import torchaudio
 
 from panns_inference import AudioTagging, labels
 
@@ -67,9 +67,10 @@ class Panns(NNOperator):
 
         audio = self.int2float(audio).astype('float32')
         if sr != self.sample_rate:
-            audio = resampy.resample(audio, sr, self.sample_rate)
+            resampler = torchaudio.transforms.Resample(sr, self.sample_rate, dtype=audio.dtype)
+            audio = resampler(audio)
 
-        audio = torch.from_numpy(audio)[None, :]
+        audio = audio[None, :]
         clipwise_output, embedding = self.tagger.inference(audio)
 
         sorted_indexes = numpy.argsort(clipwise_output[0])[::-1]
diff --git a/requirements.txt b/requirements.txt
index 8ac3f1f..2c3a4fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 panns_inference
-resampy
 torch
+torchaudio
 towhee>=0.7.0
\ No newline at end of file