logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

157 lines
5.8 KiB

from typing import Generator, NamedTuple
from functools import partial, reduce
import math
import logging
import av
import numpy as np
from towhee.types.video_frame import VideoFrame
from towhee.operator.base import PyOperator
logger = logging.getLogger()
class SAMPLE_TYPE:
    """Names of the frame-sampling strategies accepted by ``VideoDecoder``."""
    # Pick ``num_samples`` frames spread evenly over the decoded window.
    UNIFORM_TEMPORAL_SUBSAMPLE = 'uniform_temporal_subsample'
    # Pick roughly one frame every ``time_step`` seconds.
    TIME_STEP_SAMPLE = 'time_step_sample'
class VideoDecoder(PyOperator):
    '''
    VideoDecoder
    Return images with RGB format.

    Decodes a video file with PyAV and yields ``VideoFrame`` objects,
    optionally restricted to a ``[start_time, end_time]`` window (seconds)
    and sub-sampled according to ``sample_type``.

    Args:
        start_time: decoding start position in seconds; ``None`` means 0.
        end_time: decoding end position in seconds; ``None`` means the
            full video duration.
        sample_type: one of the ``SAMPLE_TYPE`` values, or ``None`` to
            yield every decoded frame in the window.
        args: extra sampler arguments, e.g. ``{'time_step': 1}`` for
            time-step sampling or ``{'num_samples': 16}`` for uniform
            temporal subsampling.
    '''

    def __init__(self, start_time=None, end_time=None, sample_type=None, args=None) -> None:
        super().__init__()
        self._start_time = start_time if start_time is not None else 0
        self._end_time = end_time
        self._sample_type = sample_type
        self._args = args if args is not None else {}

    def get_sample(self, stream, duration):
        """Return the frame-sampling generator matching ``self._sample_type``.

        Args:
            stream: the PyAV video stream (used for its frame ``rate``).
            duration: the video duration in seconds, used to clamp ``end_time``.

        Raises:
            RuntimeError: if ``sample_type`` is not a known ``SAMPLE_TYPE``.
        """
        if self._sample_type is None:
            return self._no_sample
        # Clamp the requested window to the actual video duration.
        start_time = self._start_time if self._start_time is not None else 0
        end_time = self._end_time if self._end_time is not None and self._end_time <= duration else duration
        sample_type = self._sample_type.lower()
        if sample_type == SAMPLE_TYPE.UNIFORM_TEMPORAL_SUBSAMPLE:
            # Estimated number of frames inside the window.
            nums = int(stream.rate * (end_time - start_time))
            return partial(self._uniform_temporal_subsample, total_frames=nums)
        if sample_type == SAMPLE_TYPE.TIME_STEP_SAMPLE:
            return partial(self._time_step_sample, start_time=start_time, end_time=end_time)
        # BUG FIX: original message misspelled "Unkown".
        raise RuntimeError('Unknown sample type: %s' % self._sample_type)

    def _no_sample(self, frame_iter):
        """Yield every frame, stopping once ``end_time`` is passed."""
        if self._end_time is None:
            yield from frame_iter
        else:
            for frame in frame_iter:
                # BUG FIX: the original evaluated ``frame.time < self._end_time``
                # as a bare expression (a no-op) and yielded every frame; the
                # end-of-window check now actually stops iteration.
                if frame.time > self._end_time:
                    break
                yield frame

    def _time_step_sample(self, frame_iter, start_time, end_time):
        """Yield roughly one frame every ``time_step`` seconds.

        Raises:
            RuntimeError: if ``self._args`` does not contain ``time_step``.
        """
        time_step = self._args.get('time_step')
        if time_step is None:
            raise RuntimeError('time_step_sample sample lost args time_step')
        time_index = start_time
        for frame in frame_iter:
            # BUG FIX: compare against the clamped ``end_time`` argument;
            # the original used ``self._end_time``, which may be None and
            # would raise TypeError on comparison.
            if time_index >= end_time:
                break
            if frame.time >= time_index:
                time_index += time_step
                yield frame

    def _uniform_temporal_subsample(self, frame_iter, total_frames):
        """Yield ``num_samples`` frames spread uniformly over ``total_frames``.

        When ``num_samples`` exceeds ``total_frames``, ``np.linspace``
        produces duplicate indices and the same frame is yielded multiple
        times, so exactly ``num_samples`` frames are always produced.

        Raises:
            RuntimeError: if ``self._args`` does not contain ``num_samples``.
        """
        num_samples = self._args.get('num_samples')
        if num_samples is None:
            raise RuntimeError('uniform_temporal_subsample lost args num_samples')
        indexes = np.linspace(0, total_frames - 1, num_samples).astype('int')
        cur_index = 0
        for count, frame in enumerate(frame_iter):
            if cur_index >= len(indexes):
                return
            # Yield the frame once per matching (possibly duplicated) index.
            while cur_index < len(indexes) and indexes[cur_index] <= count:
                cur_index += 1
                yield frame

    @staticmethod
    def _decdoe(video, container, start_time):
        """Seek close to ``start_time`` and yield decoded frames from there.

        NOTE(review): the name looks like a typo for ``_decode``; kept
        unchanged for backward compatibility.

        Raises:
            RuntimeError: if seeking to ``start_time`` fails.
        """
        if start_time is not None:
            # Convert seconds into stream time-base units for container.seek.
            start_offset = int(math.floor(start_time * (1 / video.time_base)))
        else:
            # BUG FIX: also normalize start_time so the frame-time comparison
            # below cannot raise ``TypeError: '<' not supported`` with None.
            start_time = 0
            start_offset = 0
        # Seek slightly before the target so the wanted frame is not skipped.
        seek_offset = max(start_offset - 1, 0)
        try:
            container.seek(seek_offset, any_frame=False, backward=True, stream=video)
        except av.AVError as e:
            logger.error('Seek to start_time: %s sec failed, the offset is %s, errors: %s'
                         % (start_time, seek_offset, str(e)))
            raise RuntimeError from e
        for frame in container.decode(video):
            # Seeking lands on a keyframe at/before start_time; drop the
            # frames that precede the requested window.
            if frame.time < start_time:
                continue
            yield frame

    def __call__(self, video_path: str) -> Generator:
        """Decode ``video_path`` and yield sampled RGB ``VideoFrame`` objects."""
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            # container.duration is in microseconds (AV_TIME_BASE units).
            duration = float(container.duration) / 1000000
            image_format = 'RGB'
            frame_gen = VideoDecoder._decdoe(stream, container, self._start_time)
            sample_function = self.get_sample(stream, duration)
            for frame in sample_function(frame_gen):
                timestamp = int(frame.time * 1000)  # milliseconds
                ndarray = frame.to_ndarray(format='rgb24')
                img = VideoFrame(ndarray, image_format, timestamp, frame.key_frame)
                yield img
# if __name__ == '__main__':
# video_path = "/home/junjie.jiangjjj/workspace/video/[The Rock] [1996] [Trailer] [#2]-16-l-rO5B64.mkv"
# video_path1 = "/home/junjie.jiangjjj/workspace/video/'Eagle Eye' Trailer (2008)-_wkqo_Rd3_Q.mp4"
# video_path2 = "/home/junjie.jiangjjj/workspace/video/2001 - A Space Odyssey - Trailer [1968] HD-Z2UWOeBcsJI.webm"
# # video_path3 = "/home/zhangchen/zhangchen_workspace/dataset/MSRVTT/msrvtt_data/MSRVTT_Videos/video9991.mp4"
# video_path3 = "/home/junjie.jiangjjj/e2adc784b83446ae775f698b9d17c9fd392b2f75.flv"
# def d(video_path):
# d = VideoDecoder(10, 17, 'time_step_sample', {'time_step': 1})
# fs = d(video_path)
# for f in fs:
# print(f.mode, f.key_frame, f.timestamp)
# d(video_path)
# # print('#' * 100)
# # with av.open(video_path) as container:
# # print(container.duration)
# # stream = container.streams.video[0]
# # print(stream.time_base)