diff --git a/.gitattributes b/.gitattributes
index ad2c207..6d34772 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,4 +1,3 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 2d38ecc..a9e5375 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,104 @@
-# actionclip
+# Action Classification with ActionClip
+*Author: [Jael Gu](https://github.com/jaelgu)*
+
+<br />
+
+## Description
+
+An action classification operator generates labels of human activities (with corresponding scores)
+and extracts features for the given video.
+It transforms the video into frames and loads a pre-trained model by model name.
+This operator implements pre-trained models from [ActionClip](https://arxiv.org/abs/2109.08472)
+and maps vectors to labels provided by the datasets used for pre-training.
+
+<br />
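+The default label mapping comes from `kinetics_400.json` (added in this patch), which stores
+`class name -> index`; the operator inverts it so a predicted index can be looked up as a name.
+A minimal sketch of that inversion (illustration only, not part of the operator's public API):
+
+```python
+import json
+
+# kinetics_400.json maps (quoted) class names to integer indices;
+# invert it so a predicted index can be turned back into a name.
+with open('kinetics_400.json', 'r') as f:
+    name_to_idx = json.load(f)
+idx_to_name = {v: str(k).replace('"', '') for k, v in name_to_idx.items()}
+print(idx_to_name[5])  # archery
+```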
+
+## Code Example
+
+Use the pre-trained ActionClip model to classify the given video './archery.mp4'
+([download](https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4)) and generate a feature vector for it.
+
+*Write the pipeline in simplified style:*
+
+- Predict labels (default):
+```python
+import towhee
+
+(
+    towhee.glob('./archery.mp4')
+          .video_decode.ffmpeg()
+          .action_classification.actionclip(model_name='clip_vit_b16')
+          .show()
+)
+```
+
+<img src="./result1.png" alt="result1" />
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+(
+    towhee.glob['path']('./archery.mp4')
+          .video_decode.ffmpeg['path', 'frames']()
+          .action_classification.actionclip['frames', ('labels', 'scores', 'features')](model_name='clip_vit_b16')
+          .select['path', 'labels', 'scores', 'features']()
+          .show(formatter={'path': 'video_path'})
+)
+```
+
+<img src="./result2.png" alt="result2" />
+
+<br />
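+To consume results programmatically rather than rendering a table, the same pipeline can end with
+`to_list()` instead of `show()`. A sketch using the DataCollection API shown above (attribute
+names follow the declared outputs; this is an illustrative pattern, not part of this operator):
+
+```python
+import towhee
+
+# Collect result entities into a Python list instead of displaying them.
+res = (
+    towhee.glob['path']('./archery.mp4')
+          .video_decode.ffmpeg['path', 'frames']()
+          .action_classification.actionclip['frames', ('labels', 'scores', 'features')](model_name='clip_vit_b16')
+          .select['labels', 'scores']()
+          .to_list()
+)
+print(res[0].labels, res[0].scores)
+```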
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***action_classification.actionclip(model_name='clip_vit_b16', skip_preprocess=False, classmap=None, topk=5)***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The name of the pre-trained CLIP model.
+
+Supported model names:
+- clip_vit_b16
+- clip_vit_b32
+
+***skip_preprocess***: *bool*
+
+Flag to control whether to skip video transforms, defaults to False.
+If set to True, the step to transform videos will be skipped.
+In this case, the user should guarantee that all the input video frames are already preprocessed properly,
+and thus can be fed to the model directly.
+
+***classmap***: *Dict[int, str]*
+
+Dictionary that maps class indices to class names.
+If not given, the operator will load the default class map dictionary.
+
+***topk***: *int*
+
+The top-k labels & scores to present in the result. The default value is 5.
+
+## Interface
+
+A video classification operator generates a list of class labels
+and a corresponding vector in numpy.ndarray given video input data.
+
+**Parameters:**
+
+***frames***: *List[VideoFrame]*
+
+Video frames in towhee.types.video_frame.VideoFrame.
+
+**Returns**:
+
+***labels, scores, features***: *Tuple(List[str], List[float], numpy.ndarray)*
+
+- labels: predicted class names.
+- scores: probability scores, ranked from high to low, corresponding to the predicted labels.
+- features: a video embedding in shape of (num_features,) representing features extracted by the model.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..ed2a807
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .action_clip import ActionClip
+
+
+def actionclip(**kwargs):
+    return ActionClip(**kwargs)
diff --git a/action_clip.py b/action_clip.py
new file mode 100644
index 0000000..938738f
--- /dev/null
+++ b/action_clip.py
@@ -0,0 +1,123 @@
+import logging
+import os
+import json
+from pathlib import Path
+from typing import List, Union
+
+import torch
+import numpy
+
+from towhee import register
+from towhee.types.video_frame import VideoFrame
+from towhee.operator.base import NNOperator
+from towhee.models.utils.video_transforms import transform_video, get_configs
+from towhee.models import action_clip
+
+log = logging.getLogger()
+
+
+@register(output_schema=['labels', 'scores', 'features'])
+class ActionClip(NNOperator):
+    """
+    Generate a list of class labels given video input data.
+    Default labels are from the [Kinetics400 Dataset](https://deepmind.com/research/open-source/kinetics).
+
+    Args:
+        model_name (`str`):
+            CLIP model name to be used in ActionClip.
+        weights_path (`str`):
+            Path to pretrained model weights.
+        skip_preprocess (`bool=False`):
+            Whether to skip video transforms.
+        classmap (`dict=None`):
+            Dictionary mapping class indices to class names.
+        topk (`int=5`):
+            The number of classification labels to be returned (ordered by score from high to low).
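+
+    Example:
+        A minimal sketch of direct usage (assumes `frames` is a List[VideoFrame]
+        decoded elsewhere, e.g. by towhee's video_decode.ffmpeg operator):
+
+        >>> op = ActionClip(model_name='clip_vit_b16', topk=5)
+        >>> labels, scores, features = op(frames)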
+ """ + def __init__(self, + model_name: str = 'clip_vit_b16', + weights_path: str = None, + skip_preprocess: bool = False, + classmap: dict = None, + topk: int = 5 + ): + super().__init__(framework='pytorch') + self.device = 'cpu' # todo: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model_name = model_name + self.skip_preprocess = skip_preprocess + self.topk = topk + if classmap is None: + class_file = os.path.join(str(Path(__file__).parent), 'kinetics_400.json') + with open(class_file, 'r') as f: + kinetics_classes = json.load(f) + self.classmap = {} + for k, v in kinetics_classes.items(): + self.classmap[v] = str(k).replace('"', '') + else: + self.classmap = classmap + + if weights_path is None: + weights_path = os.path.join(str(Path(__file__).parent), 'saved_model', 'action_' + model_name + '.pth') + checkpoints = torch.load(weights_path, map_location=self.device) + self.model = action_clip.create_model( + clip_model=model_name, + pretrained=True, + jit=True, + checkpoints=checkpoints + ) + + self.transform_cfgs = get_configs( + side_size=224, + crop_size=224, + num_frames=8, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + + def __call__(self, data: Union[List[VideoFrame], List[str]]): + """ + Args: + data (`Union[List[str], List[VideoFrame]]`): + Input video data or text data + + Returns: + - (labels, scores) + A tuple of lists (labels, scores). + - a video embedding + - a list of text embeddings + + """ + # Convert list of towhee.types.Image to numpy.ndarray in float32 + video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0) + assert len(video.shape) == 4 + video = video.transpose(3, 0, 1, 2) # twhc -> ctwh + + if self.skip_preprocess: + self.transform_cfgs.update(num_frames=None) + video = transform_video( + video=video, + **self.transform_cfgs + ) + video = video.to(self.device)[None, ...].transpose(1, 2) + visual_features = self.encode_video(video) + features = visual_features.to('cpu').squeeze(0).detach().numpy() + + kinetic_classes = list(self.classmap.values()) + if self.model_name in ['clip_vit_b16', 'clip_vit_b32']: + saved_text_features = os.path.join(str(Path(__file__).parent), 'kinetics400_' + self.model_name + '.npz') + text_features = torch.from_numpy(numpy.load(saved_text_features)['arr_0']) + else: + text_features = self.encode_text(kinetic_classes) + + num_text_aug = int(text_features.size(0) / len(kinetic_classes)) + similarity = action_clip.get_similarity(text_features, visual_features, num_text_augs=num_text_aug) + values_k, indices_k = similarity.topk(self.topk, dim=-1) + labels = [kinetic_classes[int(i)] for i in indices_k[0]] + scores = [round(float(x), 5) for x in values_k[0]] + return labels, scores, features + + def encode_text(self, text: List[str]): + return self.model.encode_text(text) + + def encode_video(self, video: List[VideoFrame]): + return self.model.encode_video(video) diff --git a/kinetics400_clip_vit_b16.npz b/kinetics400_clip_vit_b16.npz new file mode 100644 index 0000000..be36e77 Binary files /dev/null and b/kinetics400_clip_vit_b16.npz differ diff --git a/kinetics_400.json b/kinetics_400.json new file mode 100644 index 0000000..1d41181 --- /dev/null +++ b/kinetics_400.json @@ -0,0 +1 @@ +{"\"sharpening knives\"": 290, "\"eating ice cream\"": 115, "\"cutting nails\"": 81, "\"changing wheel\"": 53, "\"bench pressing\"": 19, "deadlifting": 88, "\"eating carrots\"": 111, "marching": 192, "\"throwing discus\"": 358, "\"playing flute\"": 231, "\"cooking on 
campfire\"": 72, "\"breading or breadcrumbing\"": 33, "\"playing badminton\"": 218, "\"ripping paper\"": 276, "\"playing saxophone\"": 244, "\"milking cow\"": 197, "\"juggling balls\"": 169, "\"flying kite\"": 130, "capoeira": 43, "\"making jewelry\"": 187, "drinking": 100, "\"playing cymbals\"": 228, "\"cleaning gutters\"": 61, "\"hurling (sport)\"": 161, "\"playing organ\"": 239, "\"tossing coin\"": 361, "wrestling": 395, "\"driving car\"": 103, "headbutting": 150, "\"gymnastics tumbling\"": 147, "\"making bed\"": 186, "abseiling": 0, "\"holding snake\"": 155, "\"rock climbing\"": 278, "\"cooking egg\"": 71, "\"long jump\"": 182, "\"bee keeping\"": 17, "\"trimming or shaving beard\"": 365, "\"cleaning shoes\"": 63, "\"dancing gangnam style\"": 86, "\"catching or throwing softball\"": 50, "\"ice skating\"": 164, "jogging": 168, "\"eating spaghetti\"": 116, "bobsledding": 28, "\"assembling computer\"": 8, "\"playing cricket\"": 227, "\"playing monopoly\"": 238, "\"golf putting\"": 143, "\"making pizza\"": 188, "\"javelin throw\"": 166, "\"peeling potatoes\"": 211, "clapping": 57, "\"brushing hair\"": 36, "\"flipping pancake\"": 129, "\"drinking beer\"": 101, "\"dribbling basketball\"": 99, "\"playing bagpipes\"": 219, "somersaulting": 325, "\"canoeing or kayaking\"": 42, "\"riding unicycle\"": 275, "texting": 355, "\"tasting beer\"": 352, "\"hockey stop\"": 154, "\"playing clarinet\"": 225, "\"waxing legs\"": 389, "\"curling hair\"": 80, "\"running on treadmill\"": 281, "\"tai chi\"": 346, "\"driving tractor\"": 104, "\"shaving legs\"": 293, "\"sharpening pencil\"": 291, "\"making sushi\"": 190, "\"spray painting\"": 327, "situp": 305, "\"playing kickball\"": 237, "\"sticking tongue out\"": 331, "headbanging": 149, "\"folding napkins\"": 132, "\"playing piano\"": 241, "skydiving": 312, "\"dancing charleston\"": 85, "\"ice fishing\"": 163, "tickling": 359, "bandaging": 13, "\"high jump\"": 151, "\"making a sandwich\"": 185, "\"riding mountain bike\"": 271, "\"cutting pineapple\"": 82, "\"feeding goats\"": 125, "\"dancing macarena\"": 87, "\"playing basketball\"": 220, "krumping": 179, "\"high kick\"": 152, "\"balloon blowing\"": 12, "\"playing accordion\"": 217, "\"playing chess\"": 224, "\"hula hooping\"": 159, "\"pushing wheelchair\"": 263, "\"riding camel\"": 268, "\"blowing out candles\"": 27, "\"extinguishing fire\"": 121, "\"using computer\"": 373, "\"jumpstyle dancing\"": 173, "yawning": 397, "writing": 396, "\"jumping into pool\"": 172, "\"doing laundry\"": 96, "\"egg hunting\"": 118, "\"sanding floor\"": 284, "\"moving furniture\"": 200, "\"exercising arm\"": 119, "\"sword fighting\"": 345, "\"sign language interpreting\"": 303, "\"counting money\"": 74, "bartending": 15, "\"cleaning windows\"": 65, "\"blasting sand\"": 23, "\"petting cat\"": 213, "sniffing": 320, "bowling": 31, "\"playing poker\"": 242, "\"taking a shower\"": 347, "\"washing hands\"": 382, "\"water sliding\"": 384, "\"presenting weather forecast\"": 254, "tobogganing": 360, "celebrating": 51, "\"getting a haircut\"": 138, "snorkeling": 321, "\"weaving basket\"": 390, "\"playing squash or racquetball\"": 245, "parasailing": 206, "\"news anchoring\"": 202, "\"belly dancing\"": 18, "windsurfing": 393, "\"braiding hair\"": 32, "\"crossing river\"": 78, "\"laying bricks\"": 181, "\"roller skating\"": 280, "hopscotch": 156, "\"playing trumpet\"": 248, "\"dying hair\"": 108, "\"trimming trees\"": 366, "\"pumping fist\"": 256, "\"playing keyboard\"": 236, "snowboarding": 322, "\"garbage collecting\"": 136, "\"playing 
controller\"": 226, "dodgeball": 94, "\"recording music\"": 266, "\"country line dancing\"": 75, "\"dancing ballet\"": 84, "gargling": 137, "ironing": 165, "\"push up\"": 260, "\"frying vegetables\"": 135, "\"ski jumping\"": 307, "\"mowing lawn\"": 201, "\"getting a tattoo\"": 139, "\"rock scissors paper\"": 279, "cheerleading": 55, "\"using remote controller (not gaming)\"": 374, "\"shaking head\"": 289, "sailing": 282, "\"training dog\"": 363, "hurdling": 160, "\"fixing hair\"": 128, "\"climbing ladder\"": 67, "\"filling eyebrows\"": 126, "\"springboard diving\"": 329, "\"eating watermelon\"": 117, "\"drumming fingers\"": 106, "\"waxing back\"": 386, "\"playing didgeridoo\"": 229, "\"swimming backstroke\"": 339, "\"biking through snow\"": 22, "\"washing feet\"": 380, "\"mopping floor\"": 198, "\"throwing ball\"": 357, "\"eating doughnuts\"": 113, "\"drinking shots\"": 102, "\"tying bow tie\"": 368, "dining": 91, "\"surfing water\"": 337, "\"sweeping floor\"": 338, "\"grooming dog\"": 145, "\"catching fish\"": 47, "\"pumping gas\"": 257, "\"riding or walking with horse\"": 273, "\"massaging person's head\"": 196, "archery": 5, "\"ice climbing\"": 162, "\"playing recorder\"": 243, "\"decorating the christmas tree\"": 89, "\"peeling apples\"": 210, "snowmobiling": 324, "\"playing ukulele\"": 249, "\"eating burger\"": 109, "\"building cabinet\"": 38, "\"stomping grapes\"": 332, "\"drop kicking\"": 105, "\"passing American football (not in game)\"": 209, "applauding": 3, "hugging": 158, "\"eating hotdog\"": 114, "\"pole vault\"": 253, "\"reading newspaper\"": 265, "\"snatch weight lifting\"": 318, "zumba": 399, "\"playing ice hockey\"": 235, "breakdancing": 34, "\"feeding fish\"": 124, "\"shredding paper\"": 300, "\"catching or throwing frisbee\"": 49, "\"exercising with an exercise ball\"": 120, "\"pushing cart\"": 262, "\"swimming butterfly stroke\"": 341, "\"riding scooter\"": 274, "spraying": 328, "\"folding paper\"": 133, "\"golf driving\"": 142, "\"robot dancing\"": 277, "\"bending back\"": 20, "testifying": 354, "\"waxing chest\"": 387, "\"carving pumpkin\"": 46, "\"hitting baseball\"": 153, "\"riding elephant\"": 269, "\"brushing teeth\"": 37, "\"pull ups\"": 255, "\"riding a bike\"": 267, "skateboarding": 306, "\"cleaning pool\"": 62, "\"playing paintball\"": 240, "\"massaging back\"": 193, "\"shoveling snow\"": 299, "\"surfing crowd\"": 336, "unboxing": 371, "faceplanting": 122, "trapezing": 364, "\"swinging legs\"": 343, "hoverboarding": 157, "\"playing violin\"": 250, "\"wrapping present\"": 394, "\"blowing nose\"": 26, "\"kicking field goal\"": 174, "\"picking fruit\"": 214, "\"swinging on something\"": 344, "\"giving or receiving award\"": 140, "\"planting trees\"": 215, "\"water skiing\"": 383, "\"washing dishes\"": 379, "\"punching bag\"": 258, "\"massaging legs\"": 195, "\"throwing axe\"": 356, "\"salsa dancing\"": 283, "bookbinding": 29, "\"tying tie\"": 370, "\"skiing crosscountry\"": 309, "\"shining shoes\"": 295, "\"making snowman\"": 189, "\"front raises\"": 134, "\"doing nails\"": 97, "\"massaging feet\"": 194, "\"playing drums\"": 230, "smoking": 316, "\"punching person (boxing)\"": 259, "cartwheeling": 45, "\"passing American football (in game)\"": 208, "\"shaking hands\"": 288, "plastering": 216, "\"watering plants\"": 385, "kissing": 176, "slapping": 314, "\"playing harmonica\"": 233, "welding": 391, "\"smoking hookah\"": 317, "\"scrambling eggs\"": 285, "\"cooking chicken\"": 70, "\"pushing car\"": 261, "\"opening bottle\"": 203, "\"cooking sausages\"": 73, 
"\"catching or throwing baseball\"": 48, "\"swimming breast stroke\"": 340, "digging": 90, "\"playing xylophone\"": 252, "\"doing aerobics\"": 95, "\"playing trombone\"": 247, "knitting": 178, "\"waiting in line\"": 377, "\"tossing salad\"": 362, "squat": 330, "vault": 376, "\"using segway\"": 375, "\"crawling baby\"": 77, "\"reading book\"": 264, "motorcycling": 199, "barbequing": 14, "\"cleaning floor\"": 60, "\"playing cello\"": 223, "drawing": 98, "auctioning": 9, "\"carrying baby\"": 44, "\"diving cliff\"": 93, "busking": 41, "\"cutting watermelon\"": 83, "\"scuba diving\"": 286, "\"riding mechanical bull\"": 270, "\"making tea\"": 191, "\"playing tennis\"": 246, "crying": 79, "\"dunking basketball\"": 107, "\"cracking neck\"": 76, "\"arranging flowers\"": 7, "\"building shed\"": 39, "\"golf chipping\"": 141, "\"tasting food\"": 353, "\"shaving head\"": 292, "\"answering questions\"": 2, "\"climbing tree\"": 68, "\"skipping rope\"": 311, "kitesurfing": 177, "\"juggling fire\"": 170, "laughing": 180, "paragliding": 205, "\"contact juggling\"": 69, "slacklining": 313, "\"arm wrestling\"": 6, "\"making a cake\"": 184, "\"finger snapping\"": 127, "\"grooming horse\"": 146, "\"opening present\"": 204, "\"tapping pen\"": 351, "singing": 304, "\"shot put\"": 298, "\"cleaning toilet\"": 64, "\"spinning poi\"": 326, "\"setting table\"": 287, "\"tying knot (not on a tie)\"": 369, "\"blowing glass\"": 24, "\"eating chips\"": 112, "\"tap dancing\"": 349, "\"climbing a rope\"": 66, "\"brush painting\"": 35, "\"chopping wood\"": 56, "\"stretching leg\"": 334, "\"petting animal (not cat)\"": 212, "\"baking cookies\"": 11, "\"stretching arm\"": 333, "beatboxing": 16, "jetskiing": 167, "\"bending metal\"": 21, "sneezing": 319, "\"folding clothes\"": 131, "\"sled dog racing\"": 315, "\"tapping guitar\"": 350, "\"bouncing on trampoline\"": 30, "\"waxing eyebrows\"": 388, "\"air drumming\"": 1, "\"kicking soccer ball\"": 175, "\"washing hair\"": 381, "\"riding mule\"": 272, "\"blowing leaves\"": 25, "\"strumming guitar\"": 335, "\"playing cards\"": 222, "snowkiting": 323, "\"playing bass guitar\"": 221, "\"applying cream\"": 4, "\"shooting basketball\"": 296, "\"walking the dog\"": 378, "\"triple jump\"": 367, "\"shearing sheep\"": 294, "\"clay pottery making\"": 58, "\"bungee jumping\"": 40, "\"unloading truck\"": 372, "\"shuffling cards\"": 301, "\"shooting goal (soccer)\"": 297, "\"tango dancing\"": 348, "\"side kick\"": 302, "\"grinding meat\"": 144, "yoga": 398, "\"hammer throw\"": 148, "\"changing oil\"": 52, "\"checking tires\"": 54, "parkour": 207, "\"eating cake\"": 110, "\"skiing slalom\"": 310, "\"juggling soccer ball\"": 171, "whistling": 392, "\"feeding birds\"": 123, "\"playing volleyball\"": 251, "\"swing dancing\"": 342, "\"skiing (not slalom or crosscountry)\"": 308, "lunge": 183, "\"disc golfing\"": 92, "\"clean and jerk\"": 59, "\"playing guitar\"": 232, "\"baby waking up\"": 10, "\"playing harp\"": 234} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5b806a0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +# towhee +# towhee.models +torch +torchvision +scipy diff --git a/result1.png b/result1.png new file mode 100644 index 0000000..e56d125 Binary files /dev/null and b/result1.png differ diff --git a/result2.png b/result2.png new file mode 100644 index 0000000..98a3040 Binary files /dev/null and b/result2.png differ diff --git a/saved_model/action_clip_vit_b16.pth b/saved_model/action_clip_vit_b16.pth new file mode 
100644 index 0000000..720ed87 --- /dev/null +++ b/saved_model/action_clip_vit_b16.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ee06d4810ccfede4c8b586ee494983577027697b3842ff61c3259efa074a6ca +size 75841034
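For a quick local smoke test of the operator added above, it can also be called directly with plain
frame arrays, since `__call__` only treats frames as HxWx3 arrays. A minimal sketch (assumes the LFS
weight file `saved_model/action_clip_vit_b16.pth` has been pulled; random frames yield arbitrary labels):

```python
import numpy

from action_clip import ActionClip

# Eight dummy RGB frames; plain uint8 ndarrays stand in for VideoFrame here.
frames = [numpy.random.randint(0, 256, (360, 640, 3), dtype=numpy.uint8)
          for _ in range(8)]

op = ActionClip(model_name='clip_vit_b16', topk=5)
labels, scores, features = op(frames)
print(labels)           # 5 Kinetics-400 class names
print(scores)           # matching scores, high to low
print(features.shape)   # (num_features,)
```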