diff --git a/panns.py b/panns.py index 76635bb..e79c671 100644 --- a/panns.py +++ b/panns.py @@ -74,7 +74,8 @@ class Panns(NNOperator): resampler = torchaudio.transforms.Resample(sr, self.sample_rate, dtype=audio.dtype) audio = resampler(audio) - audio = audio[None, :] + if len(audio.shape) == 1: + audio = audio[None, :] clipwise_output, embedding = self.tagger.inference(audio) sorted_indexes = numpy.argsort(clipwise_output[0])[::-1]