From 257075a76bbc8b3d1f587a52d5096a581a348624 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 9 Jun 2022 14:19:42 +0800 Subject: [PATCH] Add files Signed-off-by: Jael Gu --- README.md | 112 ++++++++++++++++++++++++++++- __init__.py | 19 +++++ kinetics_400.json | 1 + pytorchvideo.py | 179 ++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 ++ 5 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 __init__.py create mode 100644 kinetics_400.json create mode 100644 pytorchvideo.py create mode 100644 requirements.txt diff --git a/README.md b/README.md index d2bf495..afc91ef 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,112 @@ -# pytorchvideo +# Video Classification with Pytorchvideo +*Author: [Jael Gu](https://github.com/jaelgu)* + +
+ +## Description + +A video classification operator is able to predict labels (and corresponding scores) +and extracts features given the input video. +It preprocesses video frames with video transforms and then loads pre-trained models by model names. +This operator has implemented pre-trained models from [Pytorchvideo](https://github.com/facebookresearch/pytorchvideo) +and maps vectors with labels provided by the [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics). + +
+ +## Code Example + +Use the pretrained SLOWFAST model ('slowfast_r50') +to classify and generate a vector for the given video path './archery.mp4' ([download](https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4)). + + *Write the pipeline in simplified style*: + +```python +import towhee + +( + towhee.glob('./archery.mp4') + .video_decode.ffmpeg() + .video_classification.pytorchvideo(model_name='slowfast_r50') + .to_list() +) +``` + +*Write a same pipeline with explicit inputs/outputs name specifications:* + +```python +import towhee + +( + towhee.glob['path']('./archery.mp4') + .video_decode.ffmpeg['path', 'frames']() + .video_classification.pytorchvideo['frames', ('labels', 'scores', 'features')]( + model_name='slowfast_r50') + .select['labels', 'scores', 'features']() + .show() +) +``` + + +
+ +## Factory Constructor + +Create the operator via the following factory method + +***video_classification.pytorchvideo( +model_name='x3d_xs', skip_preprocess=False, classmap=None, topk=5)*** + +**Parameters:** + +​ ***model_name***: *str* + +​ The name of pre-trained model from pytorchvideo hub. + +​ Supported model names: +- c2d_r50 +- i3d_r50 +- slow_r50 +- slowfast_r50 +- slowfast_r101 +- x3d_xs +- x3d_s +- x3d_m +- mvit_base_16x4 +- mvit_base_32x3 + +​ ***skip_preprocess***: *bool* + +​ Flag to control whether to skip UniformTemporalSubsample in video transforms, defaults to False. +If set to True, the step of UniformTemporalSubsample will be skipped. +In this case, the user should guarantee that all the input video frames are already reprocessed properly, +and thus can be fed to model directly. + +​ ***classmap***: *Dict[str: int]*: + +​ Dictionary that maps class names to one hot vectors. +If not given, the operator will load the default class map dictionary. + +​ ***topk***: *int* + +​ The topk labels & scores to present in result. The default value is 5. + +## Interface + +Given a video data, the video classification operator predicts a list of class labels +and generates a video embedding in numpy.ndarray. + +**Parameters:** + +​ ***frames***: *List[VideoFrame]* + +​ Video frames in towhee.types.video_frame.VideoFrame. + + +**Returns**: + +​ ***labels, scores, features***: *Tuple(List[str], List[float], numpy.ndarray)* + +- labels: predicted class names. +- scores: possibility scores ranking from high to low corresponding to predicted labels. +- features: a video embedding in shape of (num_features,) representing features extracted by model. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..9d4bb48 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pytorchvideo import PytorchVideo + + +def pytorchvideo(**kwargs): + return PytorchVideo(**kwargs) diff --git a/kinetics_400.json b/kinetics_400.json new file mode 100644 index 0000000..1d41181 --- /dev/null +++ b/kinetics_400.json @@ -0,0 +1 @@ +{"\"sharpening knives\"": 290, "\"eating ice cream\"": 115, "\"cutting nails\"": 81, "\"changing wheel\"": 53, "\"bench pressing\"": 19, "deadlifting": 88, "\"eating carrots\"": 111, "marching": 192, "\"throwing discus\"": 358, "\"playing flute\"": 231, "\"cooking on campfire\"": 72, "\"breading or breadcrumbing\"": 33, "\"playing badminton\"": 218, "\"ripping paper\"": 276, "\"playing saxophone\"": 244, "\"milking cow\"": 197, "\"juggling balls\"": 169, "\"flying kite\"": 130, "capoeira": 43, "\"making jewelry\"": 187, "drinking": 100, "\"playing cymbals\"": 228, "\"cleaning gutters\"": 61, "\"hurling (sport)\"": 161, "\"playing organ\"": 239, "\"tossing coin\"": 361, "wrestling": 395, "\"driving car\"": 103, "headbutting": 150, "\"gymnastics tumbling\"": 147, "\"making bed\"": 186, "abseiling": 0, "\"holding snake\"": 155, "\"rock climbing\"": 278, "\"cooking egg\"": 71, "\"long jump\"": 182, "\"bee keeping\"": 17, "\"trimming or shaving beard\"": 365, "\"cleaning shoes\"": 63, "\"dancing gangnam style\"": 86, "\"catching or throwing softball\"": 50, "\"ice skating\"": 164, "jogging": 168, "\"eating spaghetti\"": 116, "bobsledding": 28, "\"assembling computer\"": 8, "\"playing cricket\"": 227, "\"playing monopoly\"": 238, "\"golf putting\"": 143, "\"making pizza\"": 188, "\"javelin throw\"": 166, "\"peeling potatoes\"": 211, "clapping": 57, "\"brushing hair\"": 36, "\"flipping pancake\"": 129, "\"drinking beer\"": 101, "\"dribbling basketball\"": 99, "\"playing bagpipes\"": 219, "somersaulting": 325, "\"canoeing or kayaking\"": 42, "\"riding unicycle\"": 275, "texting": 355, "\"tasting beer\"": 352, "\"hockey stop\"": 154, "\"playing clarinet\"": 225, "\"waxing legs\"": 389, "\"curling hair\"": 80, "\"running on treadmill\"": 281, "\"tai chi\"": 346, "\"driving tractor\"": 104, "\"shaving legs\"": 293, "\"sharpening pencil\"": 291, "\"making sushi\"": 190, "\"spray painting\"": 327, "situp": 305, "\"playing kickball\"": 237, "\"sticking tongue out\"": 331, "headbanging": 149, "\"folding napkins\"": 132, "\"playing piano\"": 241, "skydiving": 312, "\"dancing charleston\"": 85, "\"ice fishing\"": 163, "tickling": 359, "bandaging": 13, "\"high jump\"": 151, "\"making a sandwich\"": 185, "\"riding mountain bike\"": 271, "\"cutting pineapple\"": 82, "\"feeding goats\"": 125, "\"dancing macarena\"": 87, "\"playing basketball\"": 220, "krumping": 179, "\"high kick\"": 152, "\"balloon blowing\"": 12, "\"playing accordion\"": 217, "\"playing chess\"": 224, "\"hula hooping\"": 159, "\"pushing wheelchair\"": 263, "\"riding camel\"": 268, "\"blowing out candles\"": 27, "\"extinguishing fire\"": 121, "\"using computer\"": 373, "\"jumpstyle dancing\"": 173, "yawning": 397, "writing": 396, "\"jumping into pool\"": 172, "\"doing laundry\"": 96, "\"egg hunting\"": 118, "\"sanding floor\"": 284, "\"moving furniture\"": 200, "\"exercising arm\"": 119, "\"sword fighting\"": 345, "\"sign language interpreting\"": 303, "\"counting money\"": 74, "bartending": 15, "\"cleaning windows\"": 65, "\"blasting sand\"": 23, "\"petting cat\"": 213, "sniffing": 320, "bowling": 31, "\"playing poker\"": 242, "\"taking a shower\"": 347, "\"washing hands\"": 382, "\"water sliding\"": 384, "\"presenting weather forecast\"": 254, "tobogganing": 360, "celebrating": 51, "\"getting a haircut\"": 138, "snorkeling": 321, "\"weaving basket\"": 390, "\"playing squash or racquetball\"": 245, "parasailing": 206, "\"news anchoring\"": 202, "\"belly dancing\"": 18, "windsurfing": 393, "\"braiding hair\"": 32, "\"crossing river\"": 78, "\"laying bricks\"": 181, "\"roller skating\"": 280, "hopscotch": 156, "\"playing trumpet\"": 248, "\"dying hair\"": 108, "\"trimming trees\"": 366, "\"pumping fist\"": 256, "\"playing keyboard\"": 236, "snowboarding": 322, "\"garbage collecting\"": 136, "\"playing controller\"": 226, "dodgeball": 94, "\"recording music\"": 266, "\"country line dancing\"": 75, "\"dancing ballet\"": 84, "gargling": 137, "ironing": 165, "\"push up\"": 260, "\"frying vegetables\"": 135, "\"ski jumping\"": 307, "\"mowing lawn\"": 201, "\"getting a tattoo\"": 139, "\"rock scissors paper\"": 279, "cheerleading": 55, "\"using remote controller (not gaming)\"": 374, "\"shaking head\"": 289, "sailing": 282, "\"training dog\"": 363, "hurdling": 160, "\"fixing hair\"": 128, "\"climbing ladder\"": 67, "\"filling eyebrows\"": 126, "\"springboard diving\"": 329, "\"eating watermelon\"": 117, "\"drumming fingers\"": 106, "\"waxing back\"": 386, "\"playing didgeridoo\"": 229, "\"swimming backstroke\"": 339, "\"biking through snow\"": 22, "\"washing feet\"": 380, "\"mopping floor\"": 198, "\"throwing ball\"": 357, "\"eating doughnuts\"": 113, "\"drinking shots\"": 102, "\"tying bow tie\"": 368, "dining": 91, "\"surfing water\"": 337, "\"sweeping floor\"": 338, "\"grooming dog\"": 145, "\"catching fish\"": 47, "\"pumping gas\"": 257, "\"riding or walking with horse\"": 273, "\"massaging person's head\"": 196, "archery": 5, "\"ice climbing\"": 162, "\"playing recorder\"": 243, "\"decorating the christmas tree\"": 89, "\"peeling apples\"": 210, "snowmobiling": 324, "\"playing ukulele\"": 249, "\"eating burger\"": 109, "\"building cabinet\"": 38, "\"stomping grapes\"": 332, "\"drop kicking\"": 105, "\"passing American football (not in game)\"": 209, "applauding": 3, "hugging": 158, "\"eating hotdog\"": 114, "\"pole vault\"": 253, "\"reading newspaper\"": 265, "\"snatch weight lifting\"": 318, "zumba": 399, "\"playing ice hockey\"": 235, "breakdancing": 34, "\"feeding fish\"": 124, "\"shredding paper\"": 300, "\"catching or throwing frisbee\"": 49, "\"exercising with an exercise ball\"": 120, "\"pushing cart\"": 262, "\"swimming butterfly stroke\"": 341, "\"riding scooter\"": 274, "spraying": 328, "\"folding paper\"": 133, "\"golf driving\"": 142, "\"robot dancing\"": 277, "\"bending back\"": 20, "testifying": 354, "\"waxing chest\"": 387, "\"carving pumpkin\"": 46, "\"hitting baseball\"": 153, "\"riding elephant\"": 269, "\"brushing teeth\"": 37, "\"pull ups\"": 255, "\"riding a bike\"": 267, "skateboarding": 306, "\"cleaning pool\"": 62, "\"playing paintball\"": 240, "\"massaging back\"": 193, "\"shoveling snow\"": 299, "\"surfing crowd\"": 336, "unboxing": 371, "faceplanting": 122, "trapezing": 364, "\"swinging legs\"": 343, "hoverboarding": 157, "\"playing violin\"": 250, "\"wrapping present\"": 394, "\"blowing nose\"": 26, "\"kicking field goal\"": 174, "\"picking fruit\"": 214, "\"swinging on something\"": 344, "\"giving or receiving award\"": 140, "\"planting trees\"": 215, "\"water skiing\"": 383, "\"washing dishes\"": 379, "\"punching bag\"": 258, "\"massaging legs\"": 195, "\"throwing axe\"": 356, "\"salsa dancing\"": 283, "bookbinding": 29, "\"tying tie\"": 370, "\"skiing crosscountry\"": 309, "\"shining shoes\"": 295, "\"making snowman\"": 189, "\"front raises\"": 134, "\"doing nails\"": 97, "\"massaging feet\"": 194, "\"playing drums\"": 230, "smoking": 316, "\"punching person (boxing)\"": 259, "cartwheeling": 45, "\"passing American football (in game)\"": 208, "\"shaking hands\"": 288, "plastering": 216, "\"watering plants\"": 385, "kissing": 176, "slapping": 314, "\"playing harmonica\"": 233, "welding": 391, "\"smoking hookah\"": 317, "\"scrambling eggs\"": 285, "\"cooking chicken\"": 70, "\"pushing car\"": 261, "\"opening bottle\"": 203, "\"cooking sausages\"": 73, "\"catching or throwing baseball\"": 48, "\"swimming breast stroke\"": 340, "digging": 90, "\"playing xylophone\"": 252, "\"doing aerobics\"": 95, "\"playing trombone\"": 247, "knitting": 178, "\"waiting in line\"": 377, "\"tossing salad\"": 362, "squat": 330, "vault": 376, "\"using segway\"": 375, "\"crawling baby\"": 77, "\"reading book\"": 264, "motorcycling": 199, "barbequing": 14, "\"cleaning floor\"": 60, "\"playing cello\"": 223, "drawing": 98, "auctioning": 9, "\"carrying baby\"": 44, "\"diving cliff\"": 93, "busking": 41, "\"cutting watermelon\"": 83, "\"scuba diving\"": 286, "\"riding mechanical bull\"": 270, "\"making tea\"": 191, "\"playing tennis\"": 246, "crying": 79, "\"dunking basketball\"": 107, "\"cracking neck\"": 76, "\"arranging flowers\"": 7, "\"building shed\"": 39, "\"golf chipping\"": 141, "\"tasting food\"": 353, "\"shaving head\"": 292, "\"answering questions\"": 2, "\"climbing tree\"": 68, "\"skipping rope\"": 311, "kitesurfing": 177, "\"juggling fire\"": 170, "laughing": 180, "paragliding": 205, "\"contact juggling\"": 69, "slacklining": 313, "\"arm wrestling\"": 6, "\"making a cake\"": 184, "\"finger snapping\"": 127, "\"grooming horse\"": 146, "\"opening present\"": 204, "\"tapping pen\"": 351, "singing": 304, "\"shot put\"": 298, "\"cleaning toilet\"": 64, "\"spinning poi\"": 326, "\"setting table\"": 287, "\"tying knot (not on a tie)\"": 369, "\"blowing glass\"": 24, "\"eating chips\"": 112, "\"tap dancing\"": 349, "\"climbing a rope\"": 66, "\"brush painting\"": 35, "\"chopping wood\"": 56, "\"stretching leg\"": 334, "\"petting animal (not cat)\"": 212, "\"baking cookies\"": 11, "\"stretching arm\"": 333, "beatboxing": 16, "jetskiing": 167, "\"bending metal\"": 21, "sneezing": 319, "\"folding clothes\"": 131, "\"sled dog racing\"": 315, "\"tapping guitar\"": 350, "\"bouncing on trampoline\"": 30, "\"waxing eyebrows\"": 388, "\"air drumming\"": 1, "\"kicking soccer ball\"": 175, "\"washing hair\"": 381, "\"riding mule\"": 272, "\"blowing leaves\"": 25, "\"strumming guitar\"": 335, "\"playing cards\"": 222, "snowkiting": 323, "\"playing bass guitar\"": 221, "\"applying cream\"": 4, "\"shooting basketball\"": 296, "\"walking the dog\"": 378, "\"triple jump\"": 367, "\"shearing sheep\"": 294, "\"clay pottery making\"": 58, "\"bungee jumping\"": 40, "\"unloading truck\"": 372, "\"shuffling cards\"": 301, "\"shooting goal (soccer)\"": 297, "\"tango dancing\"": 348, "\"side kick\"": 302, "\"grinding meat\"": 144, "yoga": 398, "\"hammer throw\"": 148, "\"changing oil\"": 52, "\"checking tires\"": 54, "parkour": 207, "\"eating cake\"": 110, "\"skiing slalom\"": 310, "\"juggling soccer ball\"": 171, "whistling": 392, "\"feeding birds\"": 123, "\"playing volleyball\"": 251, "\"swing dancing\"": 342, "\"skiing (not slalom or crosscountry)\"": 308, "lunge": 183, "\"disc golfing\"": 92, "\"clean and jerk\"": 59, "\"playing guitar\"": 232, "\"baby waking up\"": 10, "\"playing harp\"": 234} \ No newline at end of file diff --git a/pytorchvideo.py b/pytorchvideo.py new file mode 100644 index 0000000..69ce409 --- /dev/null +++ b/pytorchvideo.py @@ -0,0 +1,179 @@ +import logging +import os +import json +from pathlib import Path +from typing import List, Union, Iterable, Callable + +import torch +from torch import nn +import numpy + +from towhee import register +from towhee.types import VideoFrame +from towhee.operator.base import NNOperator +from towhee.models.utils.video_transforms import transform_video + +log = logging.getLogger() + + +@register(output_schema=['labels', 'scores', 'features']) +class PytorchVideo(NNOperator): + """ + Generate a list of class labels given a video input data. + Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics). + + Args: + model_name (`str`): + The pretrained model name from torch hub. + Supported model names: + - c2d_r50 + - i3d_r50 + - slow_r50 + - slowfast_r50 + - slowfast_r101 + - x3d_xs + - x3d_s + - x3d_m + - mvit_base_16x4 + - mvit_base_32x3 + skip_preprocess (`str`): + Flag to skip video transforms. + classmap (`str=None`): + Path of the json file to match class names. + topk (`int=5`): + The number of classification labels to be returned (ordered by possibility from high to low). + """ + + def __init__( + self, + model_name: str = 'x3d_xs', + framework: str = 'pytorch', + skip_preprocess: bool = False, + classmap: str = None, + topk: int = 5, + ) -> None: + super().__init__(framework=framework) + self.model_name = model_name + self.skip_preprocess = skip_preprocess + self.topk = topk + if classmap is None: + class_file = os.path.join(str(Path(__file__).parent), 'kinetics_400.json') + with open(class_file, 'r') as f: + kinetics_classes = json.load(f) + self.classmap = {} + for k, v in kinetics_classes.items(): + self.classmap[v] = str(k).replace('"', '') + else: + self.classmap = classmap + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + self.model = torch.hub.load('facebookresearch/pytorchvideo', model=model_name, pretrained=True) + self.model.eval() + self.model.to(self.device) + + def __call__(self, frames: List[VideoFrame]): + """ + Args: + frames (`List[VideoFrame]`): + Video frames in towhee.types.video_frame.VideoFrame. + + Returns: + labels, scores: + A tuple of lists (labels, scores). + video embedding: + A video embedding in numpy.ndarray. + """ + # Convert list of towhee.types.Image to numpy.ndarray in float32 + video = numpy.stack([img.astype(numpy.float32) / 255. for img in frames], axis=0) + assert len(video.shape) == 4 + video = video.transpose(3, 0, 1, 2) # twhc -> ctwh + + if self.skip_preprocess: + data = transform_video( + video=video, + model_name=self.model_name, + num_frames=None + ) + else: + data = transform_video( + video=video, + model_name=self.model_name + ) + if self.model_name.startswith('slowfast'): + inputs = [data[0].to(self.device)[None, ...], data[1].to(self.device)[None, ...]] + else: + inputs = data.to(self.device)[None, ...] + + feats, outs = self.new_forward(inputs) + features = feats.to('cpu').squeeze(0).detach().numpy() + + post_act = torch.nn.Softmax(dim=1) + preds = post_act(outs) + pred_scores, pred_classes = preds.topk(k=self.topk) + labels = [self.classmap[int(i)] for i in pred_classes[0]] + scores = [round(float(x), 5) for x in pred_scores[0]] + return labels, scores, features + + def new_forward(self, x: Union[torch.Tensor, list]): + """ + Generate embeddings returned by the second last hidden layer. + + Args: + x (`Union[torch.Tensor, list]`): + tensor or list of input video after transforms + + Returns: + Tensor of layer outputs. + """ + blocks = list(self.model.children()) + if len(blocks) == 1: + blocks = blocks[0] + if self.model_name.startswith('x3d'): + sub_blocks = list(blocks[-1].children()) + extractor = FeatureExtractor(self.model, sub_blocks, layer=0) + elif self.model_name.startswith('mvit'): + sub_blocks = list(blocks[-1].children()) + extractor = FeatureExtractor(self.model, sub_blocks, layer=0) + else: + extractor = FeatureExtractor(self.model, blocks, layer=-2) + features, outs = extractor(x) + if features.dim() == 5: + global_pool = nn.AdaptiveAvgPool3d(1) + features = global_pool(features) + return features.flatten(), outs + + def get_model_name(self): + full_list = [ + 'c2d_r50', + 'i3d_r50', + 'slow_r50', + 'slowfast_r50', + 'slowfast_r101', + 'x3d_xs', + 'x3d_s', + 'x3d_m', + 'mvit_base_16x4', + 'mvit_base_32x3' + ] + full_list.sort() + return full_list + + +class FeatureExtractor(nn.Module): + def __init__(self, model: nn.Module, blocks: List[nn.Module], layer: int): + super().__init__() + self.model = model + self.features = None + + target_layer = blocks[layer] + self.handler = target_layer.register_forward_hook(self.save_outputs_hook()) + + def save_outputs_hook(self) -> Callable: + def fn(_, __, output): + self.features = output + return fn + + def forward(self, x): + outs = self.model(x) + self.handler.remove() + return self.features, outs diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8fef7f2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +# torch>=1.8.0 +# torchvision>=0.9.0 +# pytorchvideo +# towhee>=0.6.0