
modify

Signed-off-by: xujinling <jinling.xu@zilliz.com>
Branch: main
xujinling committed 2 years ago
parent · commit 8fd7fd6112

Changed files:
  README.md                   2 lines changed
  get_configs.py             74 lines changed
  video_swin_transformer.py  20 lines changed

README.md (2 lines changed)

@@ -68,8 +68,6 @@ model_name='swin_tiny_patch244_window877_kinetics400_1k', skip_preprocess=False,
 - swin_base_patch244_window877_kinetics400_1k
 - swin_small_patch244_window877_kinetics400_1k
 - swin_base_patch244_window877_kinetics400_22k
-- swin_base_patch244_window877_kinetics600_22k
-- swin_base_patch244_window1677_sthv2
 ***skip_preprocess***: *bool*
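
The hunk above drops two entries from the supported model_name list. A minimal construction sketch, assuming the operator class is imported straight from this repo's video_swin_transformer.py (the constructor keywords model_name, skip_preprocess and topk appear in the operator diff further down; the import path and the topk value are illustrative):

```python
# Hypothetical construction sketch; the keyword names come from the operator's
# __init__ shown later in this diff, the import path is an assumption.
from video_swin_transformer import VideoSwinTransformer

op = VideoSwinTransformer(
    model_name='swin_tiny_patch244_window877_kinetics400_1k',  # one of the README values
    skip_preprocess=False,  # let the operator run its own video transforms
    topk=5,                 # number of predicted labels to keep (illustrative)
)
```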

get_configs.py (deleted, 74 lines)

@@ -1,74 +0,0 @@
-def configs(model_name):
-    args = {
-        'swin_base_patch244_window877_kinetics400_1k': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics400_1k.pth',
-            'num_classes': 400,
-            'labels_file_name': 'kinetics_400.json',
-            'embed_dim': 128,
-            'depths': [2, 2, 18, 2],
-            'num_heads': [4, 8, 16, 32],
-            'patch_size': (2, 4, 4),
-            'window_size': (8, 7, 7),
-            'drop_path_rate': 0.4,
-            'patch_norm': True},
-        'swin_small_patch244_window877_kinetics400_1k': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_small_patch244_window877_kinetics400_1k.pth',
-            'num_classes': 400,
-            'labels_file_name': 'kinetics_400.json',
-            'embed_dim': 96,
-            'depths': [2, 2, 18, 2],
-            'num_heads': [3, 6, 12, 24],
-            'patch_size': (2, 4, 4),
-            'window_size': (8, 7, 7),
-            'drop_path_rate': 0.4,
-            'patch_norm': True},
-        'swin_tiny_patch244_window877_kinetics400_1k': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_tiny_patch244_window877_kinetics400_1k.pth',
-            'num_classes': 400,
-            'labels_file_name': 'kinetics_400.json',
-            'embed_dim': 96,
-            'depths': [2, 2, 6, 2],
-            'num_heads': [3, 6, 12, 24],
-            'patch_size': (2, 4, 4),
-            'window_size': (8, 7, 7),
-            'drop_path_rate': 0.1,
-            'patch_norm': True},
-        'swin_base_patch244_window877_kinetics400_22k': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics400_22k.pth',
-            'num_classes': 400,
-            'labels_file_name': 'kinetics_400.json',
-            'embed_dim': 128,
-            'depths': [2, 2, 18, 2],
-            'num_heads': [4, 8, 16, 32],
-            'patch_size': (2, 4, 4),
-            'window_size': (8, 7, 7),
-            'drop_path_rate': 0.4,
-            'patch_norm': True},
-        'swin_base_patch244_window877_kinetics600_22k': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window877_kinetics600_22k.pth',
-            'num_classes': 600,
-            'labels_file_name': '',
-            'embed_dim': 128,
-            'depths': [2, 2, 18, 2],
-            'num_heads': [4, 8, 16, 32],
-            'patch_size': (2, 4, 4),
-            'window_size': (8, 7, 7),
-            'drop_path_rate': 0.4,
-            'patch_norm': True},
-        'swin_base_patch244_window1677_sthv2': {
-            'pretrained': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window1677_sthv2.pth',
-            'num_classes': 174,
-            'labels_file_name': '',
-            'embed_dim': 128,
-            'depths': [2, 2, 18, 2],
-            'num_heads': [4, 8, 16, 32],
-            'patch_size': (2, 4, 4),
-            'window_size': (16, 7, 7),
-            'drop_path_rate': 0.4,
-            'patch_norm': True},
-    }
-    return args[model_name]
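
With get_configs.py removed, the per-model hyperparameter table above is superseded by towhee's model factory. A minimal sketch of the new path, using only the create_model keywords that appear in the video_swin_transformer.py hunk below; the dummy input layout is an assumption for illustration:

```python
# Sketch of the replacement for the deleted config table. create_model and its
# keywords (model_name, pretrained, device) come from this commit; the factory
# presumably resolves the checkpoint internally, so no URL is needed here.
import torch
from towhee.models.video_swin_transformer import video_swin_transformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = video_swin_transformer.create_model(
    model_name='swin_tiny_patch244_window877_kinetics400_1k',
    pretrained=True,
    device=device,
)
model.eval()

# Assumed input layout: (batch, channels, frames, height, width).
dummy_clip = torch.rand(1, 3, 32, 224, 224, device=device)
with torch.no_grad():
    logits = model(dummy_clip)
```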

video_swin_transformer.py (20 lines changed)

@@ -10,7 +10,6 @@ from towhee.operator.base import NNOperator
 from towhee.types.video_frame import VideoFrame
 from towhee.models.utils.video_transforms import transform_video, get_configs
 from towhee.models.video_swin_transformer import video_swin_transformer
-from .get_configs import configs

 log = logging.getLogger()
@@ -42,7 +41,6 @@ class VideoSwinTransformer(NNOperator):
         self.model_name = model_name
         self.skip_preprocess = skip_preprocess
         self.topk = topk
-        self.model_configs = configs(model_name=self.model_name)
         if classmap is None:
             class_file = os.path.join(str(Path(__file__).parent), self.model_configs['labels_file_name'])
             with open(class_file, 'r') as f:
@@ -54,24 +52,16 @@ class VideoSwinTransformer(NNOperator):
             self.classmap = classmap
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.model = video_swin_transformer.VideoSwinTransformer(
-            pretrained=self.model_configs['pretrained'],
-            num_classes=self.model_configs['num_classes'],
-            embed_dim=self.model_configs['embed_dim'],
-            depths=self.model_configs['depths'],
-            num_heads=self.model_configs['num_heads'],
-            patch_size=self.model_configs['patch_size'],
-            window_size=self.model_configs['window_size'],
-            drop_path_rate=self.model_configs['drop_path_rate'],
-            patch_norm=self.model_configs['patch_norm'],
+        self.model = video_swin_transformer.create_model(model_name=self.model_name,
+                                                          pretrained=True,
                                                           device=self.device)
         self.transform_cfgs = get_configs(
             side_size=224,
             crop_size=224,
-            num_frames=4,
-            mean=[0.48145466, 0.4578275, 0.40821073],
-            std=[0.26862954, 0.26130258, 0.27577711],
+            num_frames=32,
+            mean=[123.675, 116.28, 103.53],
+            std=[58.395, 57.12, 57.375],
         )

     def decoder_video(self, data: List[VideoFrame]):
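
The preprocessing config changes as well: the clip length goes from 4 to 32 frames, and the normalization constants switch from 0-1-scale values to the 0-255-scale mean/std commonly used for Video Swin checkpoints. A sketch of how these transform settings might be applied; the transform_video call pattern and the dummy frames are assumptions, not the operator's exact code:

```python
# Sketch of the updated preprocessing; get_configs and transform_video are the
# helpers imported at the top of video_swin_transformer.py. How transform_video
# consumes the config dict is an assumption here, as is the dummy input.
import numpy as np
from towhee.models.utils.video_transforms import transform_video, get_configs

transform_cfgs = get_configs(
    side_size=224,
    crop_size=224,
    num_frames=32,                   # was 4 before this commit
    mean=[123.675, 116.28, 103.53],  # 0-255 scale (previously 0-1 scale)
    std=[58.395, 57.12, 57.375],
)

# Hypothetical decoded video: 60 RGB frames as a (T, H, W, C) uint8 array.
frames = np.random.randint(0, 256, size=(60, 256, 340, 3), dtype=np.uint8)
clip = transform_video(video=frames, **transform_cfgs)
```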
