import logging
import os
import json
from pathlib import Path
from typing import List, Union

import torch
import numpy

from towhee import register
from towhee.types.video_frame import VideoFrame
from towhee.operator.base import NNOperator
from towhee.models.utils.video_transforms import transform_video, get_configs
from towhee.models import action_clip

log = logging.getLogger()


@register(output_schema=['label', 'vec'])
class ActionClip(NNOperator):
"""
Generate a list of class labels given a video input data.
Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).
Args:
model_name (`str`):
Clip model name to be used in ActionClip
weights_path (`str`):
Pretrained model weights
skip_preprocess (`bool=False`):
If or not skip video transforms.
classmap (`str=None`):
Path of the json file to match class names.
topk (`int=5`):
The number of classification labels to be returned (ordered by possibility from high to low).
"""
    def __init__(self,
                 model_name: str = 'clip_vit_b16',
                 weights_path: str = None,
                 skip_preprocess: bool = False,
                 classmap: dict = None,
                 topk: int = 5
                 ):
        super().__init__(framework='pytorch')
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = model_name
        self.skip_preprocess = skip_preprocess
        self.topk = topk
        if classmap is None:
            class_file = os.path.join(str(Path(__file__).parent), 'kinetics_400.json')
            with open(class_file, 'r') as f:
                kinetics_classes = json.load(f)
            # Invert the bundled name -> index map so classmap maps index -> class name.
            self.classmap = {}
            for k, v in kinetics_classes.items():
                self.classmap[v] = str(k).replace('"', '')
        else:
            self.classmap = classmap
        if weights_path is None:
            weights_path = os.path.join(str(Path(__file__).parent), 'saved_model', 'action_' + model_name + '.pth')
        checkpoints = torch.load(weights_path, map_location=self.device)
        self.model = action_clip.create_model(
            clip_model=model_name,
            pretrained=True,
            jit=True,
            checkpoints=checkpoints,
            device=self.device
        )
        # Default transforms: 224x224 frames, 8 sampled frames, CLIP mean/std.
        self.transform_cfgs = get_configs(
            side_size=224,
            crop_size=224,
            num_frames=8,
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )

    def __call__(self, data: Union[List[VideoFrame], List[str]]):
        """
        Args:
            data (`Union[List[VideoFrame], List[str]]`):
                Input video data as a list of frames (text inputs are encoded
                separately via `encode_text`).

        Returns:
            A tuple `(labels, scores, features)`:
            - labels: the topk class names.
            - scores: the topk similarity scores (high to low).
            - features: the video embedding as a `numpy.ndarray`.
        """
        # Convert the list of towhee VideoFrame to a float32 numpy.ndarray scaled to [0, 1]
        video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0)
        assert len(video.shape) == 4
        video = video.transpose(3, 0, 1, 2)  # thwc -> cthw
        if self.skip_preprocess:
            # Keep every input frame instead of sampling down to num_frames
            self.transform_cfgs.update(num_frames=None)
        video = transform_video(
            video=video,
            **self.transform_cfgs
        )
        # Add a batch dim and swap to (batch, time, channel, height, width)
        video = video.to(self.device)[None, ...].transpose(1, 2)
        visual_features = self.encode_video(video).float()
        features = visual_features.to('cpu').squeeze(0).detach().numpy()
        kinetic_classes = list(self.classmap.values())
        if self.model_name in ['clip_vit_b16', 'clip_vit_b32']:
            # Load text features precomputed for the Kinetics-400 class names
            saved_text_features = os.path.join(str(Path(__file__).parent), 'kinetics400_' + self.model_name + '.npz')
            text_features = torch.from_numpy(numpy.load(saved_text_features)['arr_0'])
        else:
            text_features = self.encode_text(kinetic_classes)
        text_features = text_features.float().to(self.device)
        # Each class name is encoded with several prompt templates; see the note below
        num_text_aug = int(text_features.size(0) / len(kinetic_classes))
        similarity = action_clip.get_similarity(text_features, visual_features, num_text_augs=num_text_aug)
        values_k, indices_k = similarity.topk(self.topk, dim=-1)
        labels = [kinetic_classes[int(i)] for i in indices_k[0]]
        scores = [round(float(x), 5) for x in values_k[0]]
        return labels, scores, features
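
    # Note on num_text_aug: ActionClip encodes every class name with several
    # prompt templates ("text augmentations"), so text_features stacks
    # num_text_aug rows per class. With illustrative numbers: 16 templates
    # for the 400 Kinetics classes would give text_features.size(0) == 6400,
    # hence num_text_aug == 16.
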
    def encode_text(self, text: List[str]):
        return self.model.encode_text(text)

    def encode_video(self, video: torch.Tensor):
        return self.model.encode_video(video)
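

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the operator). Since
# __call__ only relies on numpy-style frames, random uint8 arrays stand in
# here for decoded VideoFrame objects; a real pipeline would pass frames
# decoded from an actual video. Shapes and model name below are assumptions.
#
# if __name__ == '__main__':
#     frames = [numpy.random.randint(0, 256, (240, 320, 3), dtype=numpy.uint8)
#               for _ in range(8)]
#     op = ActionClip(model_name='clip_vit_b16', topk=5)
#     labels, scores, vec = op(frames)
#     print(list(zip(labels, scores)), vec.shape)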