logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

181 lines
5.8 KiB

import logging
import os
import json
from pathlib import Path
from typing import List, Union, Iterable, Callable
import torch
from torch import nn
import numpy
from towhee import register
from towhee.types import VideoFrame
from towhee.operator.base import NNOperator
from towhee.models.utils.video_transforms import transform_video
log = logging.getLogger()
@register(output_schema=['labels', 'scores', 'features'])
class PytorchVideo(NNOperator):
    """
    Generate a list of class labels given a video input data.
    Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).

    Args:
        model_name (`str`):
            The pretrained model name from torch hub.
            Supported model names:
            - c2d_r50
            - i3d_r50
            - slow_r50
            - slowfast_r50
            - slowfast_r101
            - x3d_xs
            - x3d_s
            - x3d_m
            - mvit_base_16x4
            - mvit_base_32x3
        framework (`str='pytorch'`):
            Framework name forwarded to the `NNOperator` base class.
        skip_preprocess (`bool=False`):
            Flag to skip video transforms. When set, `transform_video` is called
            with `num_frames=None`.
        classmap (`dict=None`):
            The dictionary maps integer class indices to label strings.
            When None, it is built from the `kinetics_400.json` file located
            next to this module.
        topk (`int=5`):
            The number of classification labels to be returned (ordered by possibility from high to low).
            Values larger than the number of model classes are clamped.
    """

    def __init__(
            self,
            model_name: str = 'x3d_xs',
            framework: str = 'pytorch',
            skip_preprocess: bool = False,
            classmap: dict = None,
            topk: int = 5,
    ) -> None:
        super().__init__(framework=framework)
        self.model_name = model_name
        self.skip_preprocess = skip_preprocess
        self.topk = topk
        if classmap is None:
            class_file = Path(__file__).parent / 'kinetics_400.json'
            # Explicit encoding so loading does not depend on the platform default.
            with open(class_file, 'r', encoding='utf-8') as f:
                kinetics_classes = json.load(f)
            # The file is keyed by label; invert it so predictions (integer
            # indices) can be looked up, and drop stray double quotes in labels.
            self.classmap = {v: str(k).replace('"', '') for k, v in kinetics_classes.items()}
        else:
            self.classmap = classmap
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = torch.hub.load('facebookresearch/pytorchvideo', model=model_name, pretrained=True)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, frames: List[VideoFrame]):
        """
        Classify a video clip and extract its embedding.

        Args:
            frames (`List[VideoFrame]`):
                Video frames in towhee.types.video_frame.VideoFrame.

        Returns:
            labels, scores:
                A tuple of lists (labels, scores), ordered from the most to the
                least likely class.
            video embedding:
                A video embedding in numpy.ndarray.
        """
        # Stack the frames into one float32 array scaled from [0, 255] to [0, 1].
        video = numpy.stack([img.astype(numpy.float32) / 255. for img in frames], axis=0)
        assert len(video.shape) == 4, 'expected stacked frames to form a 4-dimensional array'
        # Move the last axis (presumably channels) in front of the frame axis.
        video = video.transpose(3, 0, 1, 2)

        if self.skip_preprocess:
            data = transform_video(
                video=video,
                model_name=self.model_name,
                num_frames=None
            )
        else:
            data = transform_video(
                video=video,
                model_name=self.model_name
            )

        # SlowFast models take two tensors; every input gets a leading batch axis.
        if self.model_name.startswith('slowfast'):
            inputs = [data[0].to(self.device)[None, ...], data[1].to(self.device)[None, ...]]
        else:
            inputs = data.to(self.device)[None, ...]

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            feats, outs = self.new_forward(inputs)
            preds = torch.softmax(outs, dim=1)

        features = feats.to('cpu').squeeze(0).detach().numpy()
        # Never ask for more classes than the model predicts.
        k = min(self.topk, preds.shape[1])
        pred_scores, pred_classes = preds.topk(k=k)
        labels = [self.classmap[int(i)] for i in pred_classes[0]]
        scores = [round(float(x), 5) for x in pred_scores[0]]
        return labels, scores, features

    def new_forward(self, x: Union[torch.Tensor, list]):
        """
        Run the model once and capture an intermediate layer output as embedding.

        For `x3d*` and `mvit*` models the captured layer is the first child of
        the model's last block; for all other models it is the second last block.

        Args:
            x (`Union[torch.Tensor, list]`):
                tensor or list of input video after transforms

        Returns:
            A tuple (features, outs): the flattened captured layer output and
            the raw model output.
        """
        blocks = list(self.model.children())
        # A model with a single child: index into that child instead.
        if len(blocks) == 1:
            blocks = blocks[0]

        if self.model_name.startswith(('x3d', 'mvit')):
            sub_blocks = list(blocks[-1].children())
            extractor = FeatureExtractor(self.model, sub_blocks, layer=0)
        else:
            extractor = FeatureExtractor(self.model, blocks, layer=-2)

        features, outs = extractor(x)
        # 5-dimensional feature maps are average-pooled down to size 1 on the
        # last three axes before flattening.
        if features.dim() == 5:
            global_pool = nn.AdaptiveAvgPool3d(1)
            features = global_pool(features)
        return features.flatten(), outs

    def get_model_name(self):
        """Return the supported model names as an alphabetically sorted list."""
        return sorted([
            'c2d_r50',
            'i3d_r50',
            'slow_r50',
            'slowfast_r50',
            'slowfast_r101',
            'x3d_xs',
            'x3d_s',
            'x3d_m',
            'mvit_base_16x4',
            'mvit_base_32x3',
        ])
class FeatureExtractor(nn.Module):
    """
    Wrap a model and capture the output of one of its layers during a forward pass.

    A forward hook is registered on `blocks[layer]` at construction time and is
    removed after the first call to `forward`, so an instance is meant to be
    used for a single forward pass.

    Args:
        model (`nn.Module`):
            The model to run.
        blocks (`List[nn.Module]`):
            Indexable collection of modules belonging to `model`.
        layer (`int`):
            Index into `blocks` selecting the layer whose output is captured.
    """

    def __init__(self, model: nn.Module, blocks: List[nn.Module], layer: int):
        super().__init__()
        self.model = model
        self.features = None
        target_layer = blocks[layer]
        self.handler = target_layer.register_forward_hook(self.save_outputs_hook())

    def save_outputs_hook(self) -> Callable:
        """Build the forward hook that stores the hooked layer's output on this instance."""
        def fn(_, __, output):
            self.features = output
        return fn

    def forward(self, x):
        """
        Run the wrapped model on `x`.

        Returns:
            A tuple (features, outs): the captured layer output and the model output.
        """
        try:
            outs = self.model(x)
        finally:
            # Always detach the hook, even when the model raises; otherwise it
            # would stay registered on the shared model.
            self.handler.remove()
        return self.features, outs