logo
Browse Source

add

Signed-off-by: xujinling <jinling.xu@zilliz.com>
main
xujinling 2 years ago
parent
commit
ad8d628a5c
  1. 104
      README.md
  2. 20
      __init__.py
  3. 74
      get_configs.py
  4. 1
      kinetics_400.json
  5. 107
      video_swin_transformer.py

104
README.md

@ -1,2 +1,104 @@
# video-swin-transformer
# Action Classification with VideoSwinTransformer
Author: Jinling xu
<br />
## Description
An action classification operator generates labels of human activities (with corresponding scores)
and extracts features for the input video.
It transforms the video into frames and loads pre-trained models by model names.
This operator has implemented pre-trained models from [TimeSformer](https://arxiv.org/abs/2102.05095)
and maps vectors with labels.
<br />
## Code Example
Use the pretrained TimeSformer model ('timesformer_k400_8x224')
to classify and generate a vector for the given video path './archery.mp4' ([download](https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4)).
*Write the pipeline in simplified style*:
```python
import towhee
(
towhee.glob('./archery.mp4')
.video_decode.ffmpeg()
.action_classification.video_swin_transformer(model_name='swin_tiny_patch244_window877_kinetics400_1k')
.show()
)
```
<img src="./result1.png" width="800px"/>
<br />
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee
(
towhee.glob['path']('./archery.mp4')
.video_decode.ffmpeg['path', 'frames']()
.action_classification.video_swin_transformer['frames', ('labels', 'scores', 'features')](
model_name='swin_tiny_patch244_window877_kinetics400_1k')
.select['path', 'labels', 'scores', 'features']()
.show(formatter={'path': 'video_path'})
)
```
<img src="./result2.png" width="800px"/>
<br />
## Factory Constructor
Create the operator via the following factory method
***action_classification.timesformer(
model_name='timesformer_k400_8x224', skip_preprocess=False, classmap=None, topk=5)***
**Parameters:**
***model_name***: *str*
​ The name of pre-trained model. Supported model names:
- timesformer_k400_8x224
***skip_preprocess***: *bool*
​ Flag to control whether to skip UniformTemporalSubsample in video transforms, defaults to False.
If set to True, the step of UniformTemporalSubsample will be skipped.
In this case, the user should guarantee that all the input video frames are already reprocessed properly,
and thus can be fed to model directly.
***classmap***: *Dict[str: int]*:
​ Dictionary that maps class names to one hot vectors.
If not given, the operator will load the default class map dictionary.
***topk***: *int*
​ The topk labels & scores to present in result. The default value is 5.
## Interface
A video classification operator generates a list of class labels
and a corresponding vector in numpy.ndarray given a video input data.
**Parameters:**
***video***: *List[towhee.types.VideoFrame]*
​ Input video data should be a list of towhee.types.VideoFrame representing video frames in order.
**Returns**:
***labels, scores, features***: *Tuple(List[str], List[float], numpy.ndarray)*
- labels: predicted class names.
- scores: possibility scores ranking from high to low corresponding to predicted labels.
- features: a video embedding in shape of (768,) representing features extracted by model.

20
__init__.py

@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .video_swin_transformer import VideoSwinTransformer
def video_swin_transformer(model_name: str, modality: str, **kwargs):
return VideoSwinTransformer(model_name, modality, **kwargs)

74
get_configs.py

@ -0,0 +1,74 @@
def configs(model_name):
args = {
'swin_base_patch244_window877_kinetics400_1k':
{'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics400_1k.pth',
'num_classes': 400,
'labels_file_name': 'kinetics_400.json',
'embed_dim': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_size': (2, 4, 4),
'window_size': (8, 7, 7), 'drop_path_rate': 0.4, 'patch_norm': True},
'swin_small_patch244_window877_kinetics400_1k':
{
'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_small_patch244_window877_kinetics400_1k.pth',
'num_classes': 400,
'labels_file_name': 'kinetics_400.json',
'embed_dim': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'patch_size': (2, 4, 4),
'window_size': (8, 7, 7),
'drop_path_rate': 0.4,
'patch_norm': True},
'swin_tiny_patch244_window877_kinetics400_1k':
{
'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_tiny_patch244_window877_kinetics400_1k.pth',
'num_classes': 400,
'labels_file_name': 'kinetics_400.json',
'embed_dim': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'patch_size': (2, 4, 4),
'window_size': (8, 7, 7),
'drop_path_rate': 0.1,
'patch_norm': True},
'swin_base_patch244_window877_kinetics400_22k':
{
'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics400_22k.pth',
'num_classes': 400,
'labels_file_name': 'kinetics_400.json',
'embed_dim': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_size': (2, 4, 4),
'window_size': (8, 7, 7),
'drop_path_rate': 0.4,
'patch_norm': True},
'swin_base_patch244_window877_kinetics600_22k':
{
'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics600_22k.pth',
'num_classes': 600,
'labels_file_name': '',
'embed_dim': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_size': (2, 4, 4),
'window_size': (8, 7, 7), 'drop_path_rate': 0.4, 'patch_norm': True},
'swin_base_patch244_window1677_sthv2':
{
'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window1677_sthv2.pth',
'num_classes': 174,
'labels_file_name': '',
'embed_dim': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_size': (2, 4, 4),
'window_size': (16, 7, 7),
'drop_path_rate': 0.4,
'patch_norm': True},
}
return args[model_name]

1
kinetics_400.json

File diff suppressed because one or more lines are too long

107
video_swin_transformer.py

@ -0,0 +1,107 @@
import logging
import os
import json
from pathlib import Path
from typing import List
import torch
import numpy
from towhee import register
from towhee.operator.base import NNOperator
from towhee.types.video_frame import VideoFrame
from towhee.models.utils.video_transforms import transform_video
from towhee.models.video_swin_transformer import video_swin_transformer
from get_configs import configs
log = logging.getLogger()
@register(output_schema=['labels', 'scores', 'features'])
class VideoSwinTransformer(NNOperator):
"""
Generate a list of class labels given a video input data.
Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).
Args:
model_name (`str`):
Supported model names:
- swin_tiny_patch244_window877_kinetics400_1k
skip_preprocess (`str`):
Flag to skip video transforms.
classmap (`str=None`):
Path of the json file to match class names.
topk (`int=5`):
The number of classification labels to be returned (ordered by possibility from high to low).
"""
def __init__(self,
model_name: str = 'swin_tiny_patch244_window877_kinetics400_1k',
framework: str = 'pytorch',
skip_preprocess: bool = False,
classmap: str = None,
topk: int = 5,
):
super().__init__(framework=framework)
self.model_name = model_name
self.skip_preprocess = skip_preprocess
self.topk = topk
self.model_configs = configs(model_name=self.model_name)
if classmap is None:
class_file = os.path.join(str(Path(__file__).parent), self.model_configs['labels_file_name'])
with open(class_file, 'r') as f:
kinetics_classes = json.load(f)
self.classmap = {}
for k, v in kinetics_classes.items():
self.classmap[v] = str(k).replace('"', '')
else:
self.classmap = classmap
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = video_swin_transformer.VideoSwinTransformer(
pretrained=self.model_configs['pretrained'],
num_classes=self.model_configs['num_classes'],
embed_dim=self.model_configs['embed_dim'],
depths=self.model_configs['depths'],
num_heads=self.model_configs['num_heads'],
patch_size=self.model_configs['patch_size'],
window_size=self.model_configs['window_size'],
drop_path_rate=self.model_configs['drop_path_rate'],
patch_norm=self.model_configs['patch_norm'],
device=self.device)
def __call__(self, video: List[VideoFrame]):
"""
Args:
video (`List[VideoFrame]`):
Video path in string.
Returns:
(labels, scores)
A tuple of lists (labels, scores).
OR emb
Video embedding.
"""
# Convert list of towhee.types.Image to numpy.ndarray in float32
video = numpy.stack([img.astype(numpy.float32)/255. for img in video], axis=0)
assert len(video.shape) == 4
video = video.transpose(3, 0, 1, 2) # twhc -> ctwh
# Transform video data given configs
if self.skip_preprocess:
self.cfg.update(num_frames=None)
data = transform_video(
video=video,
**self.cfg
)
inputs = data.to(self.device)[None, ...]
feats = self.model.forward_features(inputs)
features = feats.to('cpu').squeeze(0).detach().numpy()
outs = self.model.head(feats)
post_act = torch.nn.Softmax(dim=1)
preds = post_act(outs)
pred_scores, pred_classes = preds.topk(k=self.topk)
labels = [self.classmap[int(i)] for i in pred_classes[0]]
scores = [round(float(x), 5) for x in pred_scores[0]]
return labels, scores, features
Loading…
Cancel
Save