diff --git a/README.md b/README.md index d608772..3b74229 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,84 @@ -# tsm +# Video Classification with TSM +*Author: [Xinyu Ge](https://github.com/gexy185)* + +
+ +## Description + +A video classification operator generates labels (and corresponding scores) and extracts features for the input video. +It transforms the video into frames and loads pre-trained models by model names. +This operator has implemented pre-trained models from [TSM](https://arxiv.org/abs/1811.08383) +and maps vectors with labels provided by datasets used for pre-training. + +
+ +## Code Example + +Use the pretrained ActionClip model to classify and generate a vector for the given video path './archery.mp4' +([download](https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4)). + + *Write the pipeline in simplified style*: + +- Predict labels (default): +```python +import towhee + +( + towhee.glob('./archery.mp4') + .video_decode.ffmpeg() + .video_classification.tsm( + model_name='tsm_k400_r50_seg8', topk=5) + .show() +) +``` +
+ +## Factory Constructor + +Create the operator via the following factory method + +***video_classification.tsm( +model_name='tsm_k400_r50_seg8', skip_preprocess=False, classmap=None, topk=5)*** + +**Parameters:** + +​ ***model_name***: *str* + +​ The name of pre-trained clip model. + +​ Supported model names: +- tsm_k400_r50_seg8 + +​ ***skip_preprocess***: *bool* + +​ Flag to control whether to skip video transforms, defaults to False. +If set to True, the step to transform videos will be skipped. +In this case, the user should guarantee that all the input video frames are already reprocessed properly, +and thus can be fed to model directly. + +​ ***classmap***: *Dict[str: int]*: + +​ Dictionary that maps class names to one hot vectors. +If not given, the operator will load the default class map dictionary. + +​ ***topk***: *int* + +​ The topk labels & scores to present in result. The default value is 5. + +## Interface + +A video classification operator generates a list of class labels +and a corresponding vector in numpy.ndarray given a video input data. + +**Parameters:** + +​ ***video***: *Union[str, numpy.ndarray]* + +​ Input video data using local path in string or video frames in ndarray. + + +**Returns**: *(list, list)* + +​ A tuple of (labels, scores), +which contains lists of predicted class names and corresponding scores. diff --git a/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth b/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth new file mode 100644 index 0000000..24047ce Binary files /dev/null and b/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth differ diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..6ffc213 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tsm import Tsm + + +def tsm(**kwargs): + return Tsm(**kwargs) diff --git a/archery.mp4 b/archery.mp4 new file mode 100644 index 0000000..4a724d6 Binary files /dev/null and b/archery.mp4 differ diff --git a/kinetics_400.json b/kinetics_400.json new file mode 100644 index 0000000..1d41181 --- /dev/null +++ b/kinetics_400.json @@ -0,0 +1 @@ +{"\"sharpening knives\"": 290, "\"eating ice cream\"": 115, "\"cutting nails\"": 81, "\"changing wheel\"": 53, "\"bench pressing\"": 19, "deadlifting": 88, "\"eating carrots\"": 111, "marching": 192, "\"throwing discus\"": 358, "\"playing flute\"": 231, "\"cooking on campfire\"": 72, "\"breading or breadcrumbing\"": 33, "\"playing badminton\"": 218, "\"ripping paper\"": 276, "\"playing saxophone\"": 244, "\"milking cow\"": 197, "\"juggling balls\"": 169, "\"flying kite\"": 130, "capoeira": 43, "\"making jewelry\"": 187, "drinking": 100, "\"playing cymbals\"": 228, "\"cleaning gutters\"": 61, "\"hurling (sport)\"": 161, "\"playing organ\"": 239, "\"tossing coin\"": 361, "wrestling": 395, "\"driving car\"": 103, "headbutting": 150, "\"gymnastics tumbling\"": 147, "\"making bed\"": 186, "abseiling": 0, "\"holding snake\"": 155, "\"rock climbing\"": 278, "\"cooking egg\"": 71, "\"long jump\"": 182, "\"bee keeping\"": 17, "\"trimming or shaving beard\"": 365, "\"cleaning shoes\"": 63, "\"dancing gangnam style\"": 86, "\"catching or throwing softball\"": 50, "\"ice skating\"": 164, "jogging": 168, "\"eating spaghetti\"": 116, "bobsledding": 28, "\"assembling computer\"": 8, "\"playing cricket\"": 227, "\"playing monopoly\"": 238, "\"golf putting\"": 143, "\"making pizza\"": 188, "\"javelin throw\"": 166, "\"peeling potatoes\"": 211, "clapping": 57, "\"brushing hair\"": 36, "\"flipping pancake\"": 129, "\"drinking beer\"": 101, "\"dribbling basketball\"": 99, "\"playing bagpipes\"": 219, "somersaulting": 325, "\"canoeing or kayaking\"": 42, "\"riding unicycle\"": 275, "texting": 355, "\"tasting beer\"": 352, "\"hockey stop\"": 154, "\"playing clarinet\"": 225, "\"waxing legs\"": 389, "\"curling hair\"": 80, "\"running on treadmill\"": 281, "\"tai chi\"": 346, "\"driving tractor\"": 104, "\"shaving legs\"": 293, "\"sharpening pencil\"": 291, "\"making sushi\"": 190, "\"spray painting\"": 327, "situp": 305, "\"playing kickball\"": 237, "\"sticking tongue out\"": 331, "headbanging": 149, "\"folding napkins\"": 132, "\"playing piano\"": 241, "skydiving": 312, "\"dancing charleston\"": 85, "\"ice fishing\"": 163, "tickling": 359, "bandaging": 13, "\"high jump\"": 151, "\"making a sandwich\"": 185, "\"riding mountain bike\"": 271, "\"cutting pineapple\"": 82, "\"feeding goats\"": 125, "\"dancing macarena\"": 87, "\"playing basketball\"": 220, "krumping": 179, "\"high kick\"": 152, "\"balloon blowing\"": 12, "\"playing accordion\"": 217, "\"playing chess\"": 224, "\"hula hooping\"": 159, "\"pushing wheelchair\"": 263, "\"riding camel\"": 268, "\"blowing out candles\"": 27, "\"extinguishing fire\"": 121, "\"using computer\"": 373, "\"jumpstyle dancing\"": 173, "yawning": 397, "writing": 396, "\"jumping into pool\"": 172, "\"doing laundry\"": 96, "\"egg hunting\"": 118, "\"sanding floor\"": 284, "\"moving furniture\"": 200, "\"exercising arm\"": 119, "\"sword fighting\"": 345, "\"sign language interpreting\"": 303, "\"counting money\"": 74, "bartending": 15, "\"cleaning windows\"": 65, "\"blasting sand\"": 23, "\"petting cat\"": 213, "sniffing": 320, "bowling": 31, "\"playing poker\"": 242, "\"taking a shower\"": 347, "\"washing hands\"": 382, "\"water sliding\"": 384, "\"presenting weather forecast\"": 254, "tobogganing": 360, "celebrating": 51, "\"getting a haircut\"": 138, "snorkeling": 321, "\"weaving basket\"": 390, "\"playing squash or racquetball\"": 245, "parasailing": 206, "\"news anchoring\"": 202, "\"belly dancing\"": 18, "windsurfing": 393, "\"braiding hair\"": 32, "\"crossing river\"": 78, "\"laying bricks\"": 181, "\"roller skating\"": 280, "hopscotch": 156, "\"playing trumpet\"": 248, "\"dying hair\"": 108, "\"trimming trees\"": 366, "\"pumping fist\"": 256, "\"playing keyboard\"": 236, "snowboarding": 322, "\"garbage collecting\"": 136, "\"playing controller\"": 226, "dodgeball": 94, "\"recording music\"": 266, "\"country line dancing\"": 75, "\"dancing ballet\"": 84, "gargling": 137, "ironing": 165, "\"push up\"": 260, "\"frying vegetables\"": 135, "\"ski jumping\"": 307, "\"mowing lawn\"": 201, "\"getting a tattoo\"": 139, "\"rock scissors paper\"": 279, "cheerleading": 55, "\"using remote controller (not gaming)\"": 374, "\"shaking head\"": 289, "sailing": 282, "\"training dog\"": 363, "hurdling": 160, "\"fixing hair\"": 128, "\"climbing ladder\"": 67, "\"filling eyebrows\"": 126, "\"springboard diving\"": 329, "\"eating watermelon\"": 117, "\"drumming fingers\"": 106, "\"waxing back\"": 386, "\"playing didgeridoo\"": 229, "\"swimming backstroke\"": 339, "\"biking through snow\"": 22, "\"washing feet\"": 380, "\"mopping floor\"": 198, "\"throwing ball\"": 357, "\"eating doughnuts\"": 113, "\"drinking shots\"": 102, "\"tying bow tie\"": 368, "dining": 91, "\"surfing water\"": 337, "\"sweeping floor\"": 338, "\"grooming dog\"": 145, "\"catching fish\"": 47, "\"pumping gas\"": 257, "\"riding or walking with horse\"": 273, "\"massaging person's head\"": 196, "archery": 5, "\"ice climbing\"": 162, "\"playing recorder\"": 243, "\"decorating the christmas tree\"": 89, "\"peeling apples\"": 210, "snowmobiling": 324, "\"playing ukulele\"": 249, "\"eating burger\"": 109, "\"building cabinet\"": 38, "\"stomping grapes\"": 332, "\"drop kicking\"": 105, "\"passing American football (not in game)\"": 209, "applauding": 3, "hugging": 158, "\"eating hotdog\"": 114, "\"pole vault\"": 253, "\"reading newspaper\"": 265, "\"snatch weight lifting\"": 318, "zumba": 399, "\"playing ice hockey\"": 235, "breakdancing": 34, "\"feeding fish\"": 124, "\"shredding paper\"": 300, "\"catching or throwing frisbee\"": 49, "\"exercising with an exercise ball\"": 120, "\"pushing cart\"": 262, "\"swimming butterfly stroke\"": 341, "\"riding scooter\"": 274, "spraying": 328, "\"folding paper\"": 133, "\"golf driving\"": 142, "\"robot dancing\"": 277, "\"bending back\"": 20, "testifying": 354, "\"waxing chest\"": 387, "\"carving pumpkin\"": 46, "\"hitting baseball\"": 153, "\"riding elephant\"": 269, "\"brushing teeth\"": 37, "\"pull ups\"": 255, "\"riding a bike\"": 267, "skateboarding": 306, "\"cleaning pool\"": 62, "\"playing paintball\"": 240, "\"massaging back\"": 193, "\"shoveling snow\"": 299, "\"surfing crowd\"": 336, "unboxing": 371, "faceplanting": 122, "trapezing": 364, "\"swinging legs\"": 343, "hoverboarding": 157, "\"playing violin\"": 250, "\"wrapping present\"": 394, "\"blowing nose\"": 26, "\"kicking field goal\"": 174, "\"picking fruit\"": 214, "\"swinging on something\"": 344, "\"giving or receiving award\"": 140, "\"planting trees\"": 215, "\"water skiing\"": 383, "\"washing dishes\"": 379, "\"punching bag\"": 258, "\"massaging legs\"": 195, "\"throwing axe\"": 356, "\"salsa dancing\"": 283, "bookbinding": 29, "\"tying tie\"": 370, "\"skiing crosscountry\"": 309, "\"shining shoes\"": 295, "\"making snowman\"": 189, "\"front raises\"": 134, "\"doing nails\"": 97, "\"massaging feet\"": 194, "\"playing drums\"": 230, "smoking": 316, "\"punching person (boxing)\"": 259, "cartwheeling": 45, "\"passing American football (in game)\"": 208, "\"shaking hands\"": 288, "plastering": 216, "\"watering plants\"": 385, "kissing": 176, "slapping": 314, "\"playing harmonica\"": 233, "welding": 391, "\"smoking hookah\"": 317, "\"scrambling eggs\"": 285, "\"cooking chicken\"": 70, "\"pushing car\"": 261, "\"opening bottle\"": 203, "\"cooking sausages\"": 73, "\"catching or throwing baseball\"": 48, "\"swimming breast stroke\"": 340, "digging": 90, "\"playing xylophone\"": 252, "\"doing aerobics\"": 95, "\"playing trombone\"": 247, "knitting": 178, "\"waiting in line\"": 377, "\"tossing salad\"": 362, "squat": 330, "vault": 376, "\"using segway\"": 375, "\"crawling baby\"": 77, "\"reading book\"": 264, "motorcycling": 199, "barbequing": 14, "\"cleaning floor\"": 60, "\"playing cello\"": 223, "drawing": 98, "auctioning": 9, "\"carrying baby\"": 44, "\"diving cliff\"": 93, "busking": 41, "\"cutting watermelon\"": 83, "\"scuba diving\"": 286, "\"riding mechanical bull\"": 270, "\"making tea\"": 191, "\"playing tennis\"": 246, "crying": 79, "\"dunking basketball\"": 107, "\"cracking neck\"": 76, "\"arranging flowers\"": 7, "\"building shed\"": 39, "\"golf chipping\"": 141, "\"tasting food\"": 353, "\"shaving head\"": 292, "\"answering questions\"": 2, "\"climbing tree\"": 68, "\"skipping rope\"": 311, "kitesurfing": 177, "\"juggling fire\"": 170, "laughing": 180, "paragliding": 205, "\"contact juggling\"": 69, "slacklining": 313, "\"arm wrestling\"": 6, "\"making a cake\"": 184, "\"finger snapping\"": 127, "\"grooming horse\"": 146, "\"opening present\"": 204, "\"tapping pen\"": 351, "singing": 304, "\"shot put\"": 298, "\"cleaning toilet\"": 64, "\"spinning poi\"": 326, "\"setting table\"": 287, "\"tying knot (not on a tie)\"": 369, "\"blowing glass\"": 24, "\"eating chips\"": 112, "\"tap dancing\"": 349, "\"climbing a rope\"": 66, "\"brush painting\"": 35, "\"chopping wood\"": 56, "\"stretching leg\"": 334, "\"petting animal (not cat)\"": 212, "\"baking cookies\"": 11, "\"stretching arm\"": 333, "beatboxing": 16, "jetskiing": 167, "\"bending metal\"": 21, "sneezing": 319, "\"folding clothes\"": 131, "\"sled dog racing\"": 315, "\"tapping guitar\"": 350, "\"bouncing on trampoline\"": 30, "\"waxing eyebrows\"": 388, "\"air drumming\"": 1, "\"kicking soccer ball\"": 175, "\"washing hair\"": 381, "\"riding mule\"": 272, "\"blowing leaves\"": 25, "\"strumming guitar\"": 335, "\"playing cards\"": 222, "snowkiting": 323, "\"playing bass guitar\"": 221, "\"applying cream\"": 4, "\"shooting basketball\"": 296, "\"walking the dog\"": 378, "\"triple jump\"": 367, "\"shearing sheep\"": 294, "\"clay pottery making\"": 58, "\"bungee jumping\"": 40, "\"unloading truck\"": 372, "\"shuffling cards\"": 301, "\"shooting goal (soccer)\"": 297, "\"tango dancing\"": 348, "\"side kick\"": 302, "\"grinding meat\"": 144, "yoga": 398, "\"hammer throw\"": 148, "\"changing oil\"": 52, "\"checking tires\"": 54, "parkour": 207, "\"eating cake\"": 110, "\"skiing slalom\"": 310, "\"juggling soccer ball\"": 171, "whistling": 392, "\"feeding birds\"": 123, "\"playing volleyball\"": 251, "\"swing dancing\"": 342, "\"skiing (not slalom or crosscountry)\"": 308, "lunge": 183, "\"disc golfing\"": 92, "\"clean and jerk\"": 59, "\"playing guitar\"": 232, "\"baby waking up\"": 10, "\"playing harp\"": 234} \ No newline at end of file diff --git a/tsm.py b/tsm.py new file mode 100644 index 0000000..cf9de9c --- /dev/null +++ b/tsm.py @@ -0,0 +1,108 @@ +import logging +import os +import json +from pathlib import Path +from typing import List + +import torch +import numpy + +from towhee import register +from towhee.operator.base import NNOperator +from towhee.types.video_frame import VideoFrame +from towhee.models.utils.video_transforms import get_configs, transform_video +from towhee.models.tsm.tsm import create_model + +log = logging.getLogger() + + +@register(output_schema=['labels', 'scores', 'features']) +class Tsm(NNOperator): + """ + Generate a list of class labels given a video input data. + Default labels are from [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics). + Args: + model_name (`str`): + Supported model names: + - tsm_k400_r50_seg8 + skip_preprocess (`str`): + Flag to skip video transforms. + predict (`bool`): + Flag to control whether predict labels. If False, then return video embedding. + classmap (`str=None`): + Path of the json file to match class names. + topk (`int=5`): + The number of classification labels to be returned (ordered by possibility from high to low). + """ + def __init__(self, + model_name: str = 'tsm_k400_r50_seg8', + framework: str = 'pytorch', + skip_preprocess: bool = False, + classmap: str = None, + topk: int = 5, + ): + super().__init__(framework=framework) + self.model_name = model_name + self.skip_preprocess = skip_preprocess + self.topk = topk + if 'k400' in model_name: + self.dataset_name = 'kinetics_400' + if classmap is None: + class_file = os.path.join(str(Path(__file__).parent), self.dataset_name+'.json') + with open(class_file, "r") as f: + kinetics_classes = json.load(f) + self.classmap = {} + for k, v in kinetics_classes.items(): + self.classmap[v] = str(k).replace('"', '') + else: + self.classmap = classmap + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if model_name == 'tsm_k400_r50_seg8': + self.weights_path = os.path.join(str(Path(__file__).parent), 'TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth') + self.model = create_model(model_name=model_name, pretrained=True, weights_path=self.weights_path, device=self.device) + self.transform_cfgs = get_configs( + side_size=224, + crop_size=224, + num_frames=8, + mean=self.model.input_mean, + std=self.model.input_std, + ) + + def __call__(self, video: List[VideoFrame]): + """ + Args: + video (`List[VideoFrame]`): + Video path in string. + + Returns: + (labels, scores) + A tuple of lists (labels, scores). + OR emb + Video embedding. + """ + # Convert list of towhee.types.Image to numpy.ndarray in float32 + video = numpy.stack([img.astype(numpy.float32)/255. for img in video], axis=0) + assert len(video.shape) == 4 + video = video.transpose(3, 0, 1, 2) # twhc -> ctwh + + # Transform video data given configs + if self.skip_preprocess: + self.cfg.update(num_frames=None) + + data = transform_video( + video=video, + **self.transform_cfgs + ) + inputs = data.to(self.device)[None, ...] + + feats = self.model.forward_features(inputs) + features = feats.to('cpu').squeeze(0).detach().numpy() + + outs = self.model(feats) + post_act = torch.nn.Softmax(dim=1) + preds = post_act(outs) + pred_scores, pred_classes = preds.topk(k=self.topk) + labels = [self.classmap[int(i)] for i in pred_classes[0]] + scores = [round(float(x), 5) for x in pred_scores[0]] + + return labels, scores, features