import logging

from towhee.operator.base import PyOperator

from cpu_decode import PyAVDecode

logger = logging.getLogger()

try:
    from gpu_decode import VPFDecode
except Exception:
    logger.error('Failed to import the GPU decoder, falling back to CPU decode')
    VPFDecode = PyAVDecode


class SAMPLE_TYPE:
    UNIFORM_TEMPORAL_SUBSAMPLE = 'uniform_temporal_subsample'
    TIME_STEP_SAMPLE = 'time_step_sample'


class VideoDecoder(PyOperator):
    '''
    VideoDecoder

    Decode a video file and yield its frames as RGB images.

    Args:
        gpu_id: index of the GPU used for GPU decoding.
        start_time: start of the segment to decode (defaults to 0).
        end_time: end of the segment to decode; later frames are dropped.
        sample_type: 'uniform_temporal_subsample' or 'time_step_sample'.
        args: extra sampling arguments, e.g. {'time_step': 1} for time_step_sample.
    '''

    def __init__(self, gpu_id=0, start_time=None, end_time=None, sample_type=None, args=None) -> None:
        super().__init__()
        self._gpu_id = gpu_id
        self._start_time = start_time if start_time is not None else 0
        # Convert end_time to the same unit as frame timestamps (milliseconds).
        self._end_time = end_time * 1000 if end_time is not None else None
        self._sample_type = sample_type.lower() if sample_type else None
        self._args = args if args is not None else {}

    def _gpu_decode(self, video_path):
        yield from VPFDecode(video_path, self._gpu_id, self._start_time).decode()

    def _cpu_decode(self, video_path):
        yield from PyAVDecode(video_path, self._start_time).decode()

    def _gpu_time_step_decode(self, video_path, time_step):
        yield from VPFDecode(video_path, self._gpu_id, self._start_time, time_step).time_step_decode()

    def _cpu_time_step_decode(self, video_path, time_step):
        yield from PyAVDecode(video_path, self._start_time, time_step).time_step_decode()

    def decode(self, video_path: str):
        # Try GPU decoding first and fall back to CPU decoding if it fails.
        try:
            yield from self._gpu_decode(video_path)
        except RuntimeError:
            logger.warning('GPU decode failed (only h264, h265 and vp9 are supported), falling back to CPU decode')
            yield from self._cpu_decode(video_path)

    def time_step_decode(self, video_path, time_step):
        try:
            yield from self._gpu_time_step_decode(video_path, time_step)
        except RuntimeError:
            logger.warning('GPU decode failed (only h264, h265 and vp9 are supported), falling back to CPU decode')
            yield from self._cpu_time_step_decode(video_path, time_step)

    def _filter(self, frames):
        # Stop yielding once a frame's timestamp passes the configured end time.
        for f in frames:
            if self._end_time and f.timestamp > self._end_time:
                break
            yield f

    def __call__(self, video_path: str):
        if self._sample_type is None:
            yield from self._filter(self.decode(video_path))
        elif self._sample_type == SAMPLE_TYPE.TIME_STEP_SAMPLE:
            time_step = self._args.get('time_step')
            if time_step is None:
                raise RuntimeError('time_step_sample requires a `time_step` value in args')
            yield from self._filter(self.time_step_decode(video_path, time_step))
        elif self._sample_type == SAMPLE_TYPE.UNIFORM_TEMPORAL_SUBSAMPLE:
            # Declared as a sample type but not implemented in this version of the operator.
            raise NotImplementedError('uniform_temporal_subsample is not supported yet')
        else:
            raise RuntimeError('Unknown sample type, only supports: [%s|%s]' % (SAMPLE_TYPE.TIME_STEP_SAMPLE, SAMPLE_TYPE.UNIFORM_TEMPORAL_SUBSAMPLE))
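

if __name__ == '__main__':
    # Minimal usage sketch of the operator defined above. It assumes a local file
    # named 'example.mp4' and importable cpu_decode/gpu_decode modules; the file
    # name and the printed field are illustrative, not part of the operator itself.
    decoder = VideoDecoder(sample_type='time_step_sample', args={'time_step': 1}, end_time=10)
    for frame in decoder('example.mp4'):
        # Each yielded frame is an RGB image carrying a millisecond timestamp.
        print(frame.timestamp)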