from typing import Generator, NamedTuple
from functools import partial, reduce
import math
import logging
import av
import numpy as np
from towhee.types.image import Image
from towhee.operator.base import PyOperator
VideoOutput = NamedTuple("Outputs", [("image", Image), ("TIMESTAMP", int)])
logger = logging.getLogger()
class SAMPLE_TYPE:
    UNIFORM_TEMPORAL_SUBSAMPLE = 'uniform_temporal_subsample'
class VideoDecoder(PyOperator):
    '''
    VideoDecoder

    Decodes a video file and yields its frames as images in RGB format.
    '''

    def __init__(self, start_time=None, end_time=None, sample_type=None, args=None) -> None:
        super().__init__()
        self._start_time = start_time
        self._end_time = end_time
        self._sample_type = sample_type
        self._args = args if args is not None else {}
    def get_sample(self, stream):
        # Choose the frame-sampling strategy based on the configured sample type.
        if self._sample_type is None:
            return self._no_sample
        elif self._sample_type.lower() == SAMPLE_TYPE.UNIFORM_TEMPORAL_SUBSAMPLE:
            duration = VideoDecoder.get_video_duration(stream)
            end_time = self._end_time if self._end_time is not None and self._end_time <= duration else duration
            start_time = self._start_time if self._start_time is not None else 0
            nums = int(stream.rate * (end_time - start_time))
            return partial(self._uniform_temporal_subsample, total_frames=nums)
        else:
            raise RuntimeError('Unknown sample type: %s' % self._sample_type)
    def _no_sample(self, frame_iter):
        # Pass frames through unchanged, stopping at end_time if one was given.
        if self._end_time is None:
            yield from frame_iter
        else:
            for frame in frame_iter:
                if frame.time < self._end_time:
                    yield frame
                else:
                    break
    def _uniform_temporal_subsample(self, frame_iter, total_frames):
        num_samples = self._args.get('num_samples')
        if num_samples is None:
            raise RuntimeError('uniform_temporal_subsample requires the num_samples arg')
        # Pick num_samples frame indices evenly spaced across total_frames.
        indices = np.linspace(0, total_frames - 1, num_samples).astype('int')
        cur_index = 0
        count = 0
        for frame in frame_iter:
            if cur_index >= len(indices):
                return
            # A frame may be yielded more than once when num_samples exceeds total_frames.
            while cur_index < len(indices) and indices[cur_index] <= count:
                cur_index += 1
                yield frame
            count += 1
    @staticmethod
    def _decode(video, container, start_time):
        # Seek close to start_time (if given) before decoding, then skip any
        # decoded frames that still precede it.
        if start_time is not None:
            start_offset = int(math.floor(start_time * (1 / video.time_base)))
        else:
            start_offset = 0
        seek_offset = max(start_offset - 1, 0)
        try:
            container.seek(seek_offset, any_frame=False, backward=True, stream=video)
        except av.AVError as e:
            logger.error('Seek to start_time: %s sec failed, the offset is %s, errors: %s',
                         start_time, seek_offset, str(e))
            raise RuntimeError('Failed to seek the video stream') from e
        for frame in container.decode(video):
            if start_time is not None and frame.time < start_time:
                continue
            yield frame
    @staticmethod
    def get_video_duration(video):
        # Prefer the stream's duration field; fall back to the DURATION metadata
        # tag (an "HH:MM:SS.xxx" string) if the field is missing.
        if video.duration is not None:
            return float(video.duration * video.time_base)
        elif video.metadata.get('DURATION') is not None:
            time_str = video.metadata['DURATION']
            return reduce(lambda x, y: float(x) * 60 + float(y), time_str.split(':'))
        else:
            return None
    def __call__(self, video_path: str) -> Generator:
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            width = stream.width
            height = stream.height
            channel = 3
            image_format = 'RGB'
            frame_gen = VideoDecoder._decode(stream, container, self._start_time)
            sample_function = self.get_sample(stream)
            for frame in sample_function(frame_gen):
                timestamp = int(frame.time * 1000)
                ndarray = frame.to_ndarray(format='rgb24')
                img = Image(ndarray.tobytes(), width, height, channel, image_format, None, key_frame=frame.key_frame)
                yield VideoOutput(img, timestamp)
# if __name__ == '__main__':
#     video_path = "/home/junjie.jiangjjj/workspace/video/[The Rock] [1996] [Trailer] [#2]-16-l-rO5B64.mkv"
#     video_path1 = "/home/junjie.jiangjjj/workspace/video/'Eagle Eye' Trailer (2008)-_wkqo_Rd3_Q.mp4"
#     video_path2 = "/home/junjie.jiangjjj/workspace/video/2001 - A Space Odyssey - Trailer [1968] HD-Z2UWOeBcsJI.webm"
#     video_path3 = "/home/zhangchen/zhangchen_workspace/dataset/MSRVTT/msrvtt_data/MSRVTT_Videos/video9991.mp4"
#
#     def d(video_path):
#         d = VideoDecoder(10, 11, 'uniform_temporal_subsample', {'num_samples': 30})
#         fs = d(video_path)
#         for f in fs:
#             print(f.TIMESTAMP)
#
#     d(video_path1)
#     print('#' * 100)
#     d(video_path2)
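# A minimal usage sketch (hypothetical path below; assumes towhee and PyAV are
# installed): decode seconds 10-11 of a video, uniformly sample 30 frames from
# that range, and print each frame's timestamp in milliseconds.
#
# if __name__ == '__main__':
#     decoder = VideoDecoder(start_time=10, end_time=11,
#                            sample_type='uniform_temporal_subsample',
#                            args={'num_samples': 30})
#     for output in decoder('/path/to/video.mp4'):
#         print(output.TIMESTAMP)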