import logging
import os
import csv
from pathlib import Path
from typing import List
import torch
import numpy
from towhee import register
from towhee.operator.base import NNOperator
from towhee.types.video_frame import VideoFrame
from towhee.models.utils.video_transforms import get_configs, transform_video
from towhee.models.movinet.movinet import create_model

log = logging.getLogger()


@register(output_schema=['labels', 'scores', 'features'])
class Movinet(NNOperator):
"""
Generate a list of class labels given a video input data.
Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).

    Args:
        model_name (`str`):
            Supported model names:
            - movineta0
            - movineta1
            - movineta2
            - movineta3
            - movineta4
            - movineta5
        causal (`bool=False`):
            Whether to build the causal (streaming) variant of the model.
        skip_preprocess (`bool`):
            Flag to skip video transforms.
        classmap (`dict=None`):
            A dictionary mapping class indices (int) to label names.
            If None, the bundled Kinetics-600 label file is used.
        topk (`int=5`):
            The number of classification labels to return, ordered by score from high to low.
    """
    def __init__(self,
                 model_name: str = 'movineta0',
                 framework: str = 'pytorch',
                 causal: bool = False,
                 skip_preprocess: bool = False,
                 classmap: dict = None,
                 topk: int = 5,
                 ):
        super().__init__(framework=framework)
        self.model_name = model_name
        self.causal = causal
        self.skip_preprocess = skip_preprocess
        self.topk = topk
        self.dataset_name = 'kinetics_600'
        if classmap is None:
            # Load the default label map (class index -> class name) from the bundled
            # Kinetics-600 csv, skipping the header row.
            class_file = os.path.join(str(Path(__file__).parent), self.dataset_name + '.csv')
            self.classmap = {}
            with open(class_file, 'r') as csv_file:
                reader = csv.reader(csv_file)
                for item in reader:
                    if reader.line_num == 1:
                        continue
                    self.classmap[int(item[0])] = item[1]
        else:
            self.classmap = classmap
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = create_model(model_name=model_name, pretrained=True, causal=self.causal, device=self.device)
        # ImageNet normalization statistics used by the video transforms
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        self.transform_cfgs = get_configs(
            side_size=176,
            crop_size=176,
            num_frames=50,
            mean=self.input_mean,
            std=self.input_std,
        )
        self.model.eval()

    def __call__(self, video: List[VideoFrame]):
        """
        Args:
            video (`List[VideoFrame]`):
                Decoded video frames.

        Returns:
            (labels, scores, features)
                The top-k predicted labels, their scores, and the video embedding.
        """
        # Convert the list of towhee.types.VideoFrame to a numpy.ndarray in float32, scaled to [0, 1]
        video = numpy.stack([img.astype(numpy.float32) / 255. for img in video], axis=0)
        assert len(video.shape) == 4
        video = video.transpose(3, 0, 1, 2)  # thwc -> cthw
        # Transform video data given configs
        if self.skip_preprocess:
            self.transform_cfgs.update(num_frames=None)
        data = transform_video(
            video=video,
            **self.transform_cfgs
        )
        inputs = data.to(self.device)[None, ...]  # add batch dimension
        feats = self.model.forward_features(inputs)
        features = feats.to('cpu').squeeze(0).detach().numpy()
        # Classification head followed by softmax over classes
        outs = self.model.head(feats)
        post_act = torch.nn.Softmax(dim=1)
        preds = post_act(outs)
        pred_scores, pred_classes = preds.topk(k=self.topk)
        labels = [self.classmap[int(i)] for i in pred_classes[0]]
        scores = [round(float(x), 5) for x in pred_scores[0]]
        return labels, scores, features
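

# Usage sketch (illustrative, not part of the original operator): in a Towhee pipeline the
# frames passed to __call__ are VideoFrame objects produced by a video-decode operator.
# Because __call__ only relies on ndarray behaviour (astype / stacking), plain HWC uint8
# arrays serve as stand-in frames below; the frame size and count here are assumptions.
if __name__ == '__main__':
    # 50 synthetic RGB frames of shape (height, width, channels)
    fake_frames = [numpy.random.randint(0, 256, (176, 176, 3), dtype=numpy.uint8) for _ in range(50)]
    op = Movinet(model_name='movineta0', topk=5)
    labels, scores, features = op(fake_frames)
    print('labels:', labels)
    print('scores:', scores)
    print('embedding shape:', features.shape)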