diff --git a/__init__.py b/__init__.py
index 0cf02a5..26c8e44 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,5 +1,7 @@
 from .audio_decoder_ffmpeg import AudioDecoderFFmpeg
 
 
-def ffmpeg(batch_size=-1):
-    return AudioDecoderFFmpeg(batch_size)
+def ffmpeg(batch_size=-1, **kwargs):
+    # batch_size stays positional so existing ffmpeg(16)-style calls keep
+    # working; the newer options (sample_rate, layout) pass through by keyword.
+    return AudioDecoderFFmpeg(batch_size, **kwargs)
diff --git a/audio_decoder_ffmpeg.py b/audio_decoder_ffmpeg.py
index d8237bd..40e2e47 100644
--- a/audio_decoder_ffmpeg.py
+++ b/audio_decoder_ffmpeg.py
@@ -11,16 +11,47 @@ class AudioDecoderFFmpeg(PyOperator):
     """
     """
 
-    def __init__(self, batch_size=-1) -> None:
+    def __init__(self, batch_size=-1, sample_rate=None, layout=None) -> None:
+        """
+        Args:
+            batch_size: when <= 0, ``__call__`` yields one AudioFrame per
+                decoded frame.
+            sample_rate: target sample rate handed to the resampler.
+            layout: target channel layout handed to the resampler.
+
+        Resampling is enabled when either sample_rate or layout is set;
+        the one left unset then falls back to 8000 Hz / 'mono'.
+        """
         super().__init__()
         self._batch_size = batch_size
+        self._sample_rate = sample_rate
+        self._layout = layout
 
+    def _iter_frames(self, container, stream, resampler):
+        """Yield decoded frames, resampled when a resampler is given."""
+        for frame in container.decode(stream):
+            if resampler is None:
+                yield frame
+            else:
+                # resample() returns a list that may hold zero or several
+                # frames, so forward all of them instead of indexing [0].
+                yield from resampler.resample(frame)
+
     def __call__(self, audio_path: str):
         frames = []
         in_container = av.open(audio_path)
-        stream = in_container.streams.get(audio=0)[0]
+        stream = in_container.streams.get(audio=0)[0]
+        if self._sample_rate or self._layout:
+            resampler = av.AudioResampler(
+                format=av.AudioFormat(stream.format.name).packed,
+                layout=self._layout if self._layout else 'mono',
+                rate=self._sample_rate if self._sample_rate else 8000
+            )
+        else:
+            resampler = None
+
         if self._batch_size <= 0:
-            for frame in in_container.decode(stream):
+            for frame in self._iter_frames(in_container, stream, resampler):
                 timestamp = int(frame.time * 1000)
                 sample_rate = frame.sample_rate
                 layout = frame.layout.name
@@ -28,6 +59,6 @@ class AudioDecoderFFmpeg(PyOperator):
                 yield AudioFrame(ndarray, sample_rate, timestamp, layout)
         else:
-            for frame in in_container.decode(stream):
+            for frame in self._iter_frames(in_container, stream, resampler):
                 timestamp = int(frame.time * 1000)
                 sample_rate = frame.sample_rate
                 layout = frame.layout.name