pytorchvideo
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
180 lines
5.8 KiB
180 lines
5.8 KiB
3 years ago
|
import logging
|
||
|
import os
|
||
|
import json
|
||
|
from pathlib import Path
|
||
|
from typing import List, Union, Iterable, Callable
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import numpy
|
||
|
|
||
|
from towhee import register
|
||
|
from towhee.types import VideoFrame
|
||
|
from towhee.operator.base import NNOperator
|
||
|
from towhee.models.utils.video_transforms import transform_video
|
||
|
|
||
|
# Module-scoped logger (named after this module rather than the root logger),
# so log records can be filtered/configured per operator.
log = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
@register(output_schema=['labels', 'scores', 'features'])
class PytorchVideo(NNOperator):
    """
    Classify a video into labels and extract a video embedding.

    Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).

    Args:
        model_name (`str='x3d_xs'`):
            The pretrained model name from torch hub.
            Supported model names:
                - c2d_r50
                - i3d_r50
                - slow_r50
                - slowfast_r50
                - slowfast_r101
                - x3d_xs
                - x3d_s
                - x3d_m
                - mvit_base_16x4
                - mvit_base_32x3
        framework (`str='pytorch'`):
            Framework name forwarded to ``NNOperator``.
        skip_preprocess (`bool=False`):
            Flag to skip video transforms. When True, ``transform_video`` is still
            called, but with ``num_frames=None``.
        classmap (`str` or `dict`, default None):
            Either the path of a JSON file mapping class names to class indices
            (the same format as the bundled ``kinetics_400.json``), or an
            already-built ``{index: label}`` mapping that is used as-is.
            When None, the bundled Kinetics-400 file is loaded.
        topk (`int=5`):
            The number of classification labels to be returned (ordered by
            possibility from high to low). Clamped to the number of classes
            the model predicts.
    """

    def __init__(
        self,
        model_name: str = 'x3d_xs',
        framework: str = 'pytorch',
        skip_preprocess: bool = False,
        classmap: Union[str, dict] = None,
        topk: int = 5,
    ) -> None:
        super().__init__(framework=framework)
        self.model_name = model_name
        self.skip_preprocess = skip_preprocess
        self.topk = topk

        if classmap is None:
            classmap = os.path.join(str(Path(__file__).parent), 'kinetics_400.json')
        if isinstance(classmap, (str, os.PathLike)):
            # A file path: load it exactly like the bundled default file.
            self.classmap = self._load_classmap(classmap)
        else:
            # Caller supplied a ready-made {index: label} mapping; use it unchanged.
            self.classmap = classmap

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model = torch.hub.load('facebookresearch/pytorchvideo', model=model_name, pretrained=True)
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _load_classmap(class_file) -> dict:
        """
        Load a ``{label: index}`` JSON file and invert it into ``{index: label}``.

        Double quotes embedded in the label strings are stripped, and indices are
        normalized to ``int`` so they match the ``int(...)`` lookups in ``__call__``.
        """
        with open(class_file, 'r') as f:
            name_to_idx = json.load(f)
        return {int(v): str(k).replace('"', '') for k, v in name_to_idx.items()}

    def __call__(self, frames: List[VideoFrame]):
        """
        Args:
            frames (`List[VideoFrame]`):
                Video frames in towhee.types.video_frame.VideoFrame.

        Returns:
            labels (`List[str]`):
                Top-k class labels, most likely first.
            scores (`List[float]`):
                Softmax probabilities for ``labels``, rounded to 5 decimals.
            features (`numpy.ndarray`):
                A video embedding.
        """
        # Stack frames along a new leading axis as float32 scaled by 1/255
        # (assumes 8-bit pixel values).
        video = numpy.stack([img.astype(numpy.float32) / 255. for img in frames], axis=0)
        assert len(video.shape) == 4
        # Move the last axis to the front; assuming channel-last frames this is
        # (T, H, W, C) -> (C, T, H, W).
        video = video.transpose(3, 0, 1, 2)

        if self.skip_preprocess:
            data = transform_video(
                video=video,
                model_name=self.model_name,
                num_frames=None
            )
        else:
            data = transform_video(
                video=video,
                model_name=self.model_name
            )

        if self.model_name.startswith('slowfast'):
            # slowfast models consume a two-element list of tensors; batch each one.
            inputs = [data[0].to(self.device)[None, ...], data[1].to(self.device)[None, ...]]
        else:
            inputs = data.to(self.device)[None, ...]

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            feats, outs = self.new_forward(inputs)
        features = feats.to('cpu').squeeze(0).detach().numpy()

        preds = torch.softmax(outs, dim=1)
        # Asking for more classes than exist would make topk raise.
        k = min(self.topk, preds.shape[-1])
        pred_scores, pred_classes = preds.topk(k=k)
        labels = [self.classmap[int(i)] for i in pred_classes[0]]
        scores = [round(float(x), 5) for x in pred_scores[0]]
        return labels, scores, features

    def new_forward(self, x: Union[torch.Tensor, list]):
        """
        Run the model once while capturing an intermediate layer's output as the embedding.

        For x3d/mvit models the first child of the last block is hooked; for all
        other models the second-to-last block is hooked.

        Args:
            x (`Union[torch.Tensor, list]`):
                tensor or list of input video after transforms

        Returns:
            A tuple ``(features, outs)``: the captured layer output (globally
            average-pooled if 5-D) flattened to 1-D, and the model's raw outputs.
        """
        blocks = list(self.model.children())
        if len(blocks) == 1:
            # The model wraps its stages in a single container child; index into
            # that container instead (presumably an nn.ModuleList — verify per model).
            blocks = blocks[0]
        if self.model_name.startswith(('x3d', 'mvit')):
            sub_blocks = list(blocks[-1].children())
            extractor = FeatureExtractor(self.model, sub_blocks, layer=0)
        else:
            extractor = FeatureExtractor(self.model, blocks, layer=-2)
        features, outs = extractor(x)
        if features.dim() == 5:
            # Average-pool the three trailing axes down to size 1.
            features = nn.AdaptiveAvgPool3d(1)(features)
        return features.flatten(), outs

    def get_model_name(self):
        """Return the supported torch hub model names, sorted alphabetically."""
        return sorted([
            'c2d_r50',
            'i3d_r50',
            'slow_r50',
            'slowfast_r50',
            'slowfast_r101',
            'x3d_xs',
            'x3d_s',
            'x3d_m',
            'mvit_base_16x4',
            'mvit_base_32x3',
        ])
|
||
|
|
||
|
|
||
|
class FeatureExtractor(nn.Module):
    """
    Capture the output of one sub-module during a forward pass of ``model``.

    A forward hook is registered on ``blocks[layer]`` at construction time and is
    removed after the first call to ``forward`` (even if that call raises), so each
    instance is intended to be used for a single forward pass.

    Args:
        model (`nn.Module`):
            The full model to run.
        blocks (`List[nn.Module]`):
            Indexable sequence of sub-modules of ``model``.
        layer (`int`):
            Index into ``blocks`` selecting the module whose output is captured.
    """

    def __init__(self, model: nn.Module, blocks: List[nn.Module], layer: int):
        super().__init__()
        self.model = model
        # Set by the forward hook to the hooked module's output.
        self.features = None

        target_layer = blocks[layer]
        self.handler = target_layer.register_forward_hook(self.save_outputs_hook())

    def save_outputs_hook(self) -> Callable:
        """Return a forward-hook callback that stores the module output on ``self.features``."""
        def fn(_, __, output):
            self.features = output
        return fn

    def forward(self, x):
        """
        Run ``model`` on ``x`` and return ``(captured_features, model_outputs)``.

        The hook is removed in a ``finally`` clause so a failing forward pass does not
        leave a stale hook attached to the target module.
        """
        try:
            outs = self.model(x)
        finally:
            self.handler.remove()
        return self.features, outs
|