diff --git a/mmselfsup/datasets/transforms/formatting.py b/mmselfsup/datasets/transforms/formatting.py index 3ceedf3a7..a7818a4ee 100644 --- a/mmselfsup/datasets/transforms/formatting.py +++ b/mmselfsup/datasets/transforms/formatting.py @@ -68,9 +68,20 @@ def transform(self, if not isinstance(img, List): img = [img] for i, img_ in enumerate(img): - if len(img_.shape) < 3: - img_ = np.expand_dims(img_, -1) - img_ = np.ascontiguousarray(img_.transpose(2, 0, 1)) + # to handle the single channel image + img_ = np.expand_dims(img_, -1) \ + if len(img_.shape) == 2 else img_ + + if len(img_.shape) == 3: + img_ = np.ascontiguousarray(img_.transpose(2, 0, 1)) + elif len(img_.shape) == 5: + # for video data from mmaction with the shape + # (M, C, T, H, W), M = num_crops x num_clips + img_ = img_ + else: + raise ValueError('img should be 2, 3 or 5 dimensional, ' + f'instead of {len(img_.shape)} dimensional.') + img[i] = to_tensor(img_) packed_results['inputs'] = img diff --git a/mmselfsup/models/losses/reconstruction_loss.py b/mmselfsup/models/losses/reconstruction_loss.py index 7b4e045df..19bb413d1 100644 --- a/mmselfsup/models/losses/reconstruction_loss.py +++ b/mmselfsup/models/losses/reconstruction_loss.py @@ -54,14 +54,13 @@ def forward(self, """ loss = self.penalty(pred, target) - # if the dim of the loss is 3, take the average of the loss - # along the last dim - if len(loss.shape) == 3: - loss = loss.mean(dim=-1) - if mask is None: loss = loss.mean() else: + # if the dim of the loss is 3, take the average of the loss + # along the last dim + if len(loss.shape) == 3: + loss = loss.mean(dim=-1) loss = (loss * mask).sum() / mask.sum() / self.channel return loss diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py index 1d6b7f4ba..d10a65d46 100644 --- a/mmselfsup/models/utils/__init__.py +++ b/mmselfsup/models/utils/__init__.py @@ -4,13 +4,15 @@ RelativeLocDataPreprocessor, RotationPredDataPreprocessor, SelfSupDataPreprocessor, - TwoNormDataPreprocessor) + TwoNormDataPreprocessor, + VideoMAEDataPreprocessor) from .ema import CosineEMA from .extractor import Extractor from .gather_layer import GatherLayer from .multi_pooling import MultiPooling from .multi_prototypes import MultiPrototypes -from .position_embedding import build_2d_sincos_position_embedding +from .position_embedding import (build_1d_sincos_position_embedding, + build_2d_sincos_position_embedding) from .sobel import Sobel from .transformer_blocks import (CAETransformerRegressorLayer, MultiheadAttention, @@ -26,9 +28,10 @@ __all__ = [ 'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes', 'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention', - 'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA', - 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor', - 'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm', - 'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor', - 'PromptTransformerEncoderLayer', 'build_clip_model' + 'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder', + 'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor', + 'RotationPredDataPreprocessor', 'CAEDataPreprocessor', + 'VideoMAEDataPreprocessor', 'ResLayerExtraNorm', 'NormEMAVectorQuantizer', + 'TwoNormDataPreprocessor', 'PromptTransformerEncoderLayer', + 'build_clip_model', 'build_1d_sincos_position_embedding' ] diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py index 4901cd3df..e86e336d9
100644 --- a/mmselfsup/models/utils/data_preprocessor.py +++ b/mmselfsup/models/utils/data_preprocessor.py @@ -2,7 +2,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -from mmengine.model import ImgDataPreprocessor +from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor from mmselfsup.registry import MODELS @@ -290,3 +290,76 @@ def forward( ] return batch_inputs, batch_data_samples + + +@MODELS.register_module() +class VideoMAEDataPreprocessor(BaseDataPreprocessor): + """""" + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + format_shape: str = 'NCHW') -> None: + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.format_shape = format_shape + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + if self.format_shape == 'NCHW': + normalizer_shape = (-1, 1, 1) + elif self.format_shape == 'NCTHW' or self.format_shape == 'NCTVM': + normalizer_shape = (-1, 1, 1, 1) + else: + raise ValueError(f'Invalid format shape: {format_shape}') + + self.register_buffer( + 'mean', + torch.tensor(mean, dtype=torch.float32).view(normalizer_shape), + False) + self.register_buffer( + 'std', + torch.tensor(std, dtype=torch.float32).view(normalizer_shape), + False) + else: + self._enable_normalize = False + + def forward(self, data: dict, training: bool = False): + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # ------ To RGB ------ + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :] + for batch_input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :, :] + for batch_input in batch_inputs + ] + else: + raise ValueError(f'Invalid format shape: {self.format_shape}') + + # -- Normalization --- + if self._enable_normalize: + batch_inputs = [(batch_input - self.mean) / self.std + for batch_input in batch_inputs] + else: + batch_inputs = [ + batch_input.to(torch.float32) for batch_input in batch_inputs + ] + + return batch_inputs, batch_data_samples diff --git a/mmselfsup/models/utils/position_embedding.py b/mmselfsup/models/utils/position_embedding.py index e77cc9956..87a91f155 100644 --- a/mmselfsup/models/utils/position_embedding.py +++ b/mmselfsup/models/utils/position_embedding.py @@ -56,3 +56,32 @@ def build_2d_sincos_position_embedding( pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) return pos_emb + + +def build_1d_sincos_position_embedding( + num_patches: int, + embed_dims: int, + temperature: Optional[int] = 10000.) -> torch.Tensor: + """The function is to build 1d position embedding for model to obtain the + position information of the input patches. + + Sinusoid encoding is a kind of relative position encoding method came from + `Attention Is All You Need`_. + + Args: + num_patches (int): The number of the input patches. + embed_dims (int): The dimension of the embedding vector. + temperature (int, optional): The temperature parameter. Defaults to + 10000. 
+ """ + vector = torch.arange(embed_dims, dtype=torch.float64) + vector = (vector - vector % 2) / embed_dims + vector = torch.pow(temperature, -vector).view(1, -1) + + sinusoid_table = torch.arange(num_patches).view(-1, 1) * vector + sinusoid_table[:, 0::2].sin_() # dim 2i + sinusoid_table[:, 1::2].cos_() # dim 2i+1 + + sinusoid_table = sinusoid_table.to(torch.float32) + + return sinusoid_table diff --git a/mmselfsup/utils/__init__.py b/mmselfsup/utils/__init__.py index ab46b5286..d7c333e2c 100644 --- a/mmselfsup/utils/__init__.py +++ b/mmselfsup/utils/__init__.py @@ -7,10 +7,11 @@ from .gather import concat_all_gather from .misc import get_model from .setup_env import register_all_modules +from .typing import * # noqa: F401, F403 __all__ = [ 'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp', 'dist_forward_collect', 'nondist_forward_collect', 'collect_env', - 'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather', - 'register_all_modules', 'get_model' + 'distributed_sinkhorn', 'concat_all_gather', 'register_all_modules', + 'get_model' ] diff --git a/mmselfsup/utils/typing.py b/mmselfsup/utils/typing.py new file mode 100644 index 000000000..15120c7f8 --- /dev/null +++ b/mmselfsup/utils/typing.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Collecting some commonly used type hint in mmselfsup.""" +from typing import Optional, Union + +from mmengine.config import ConfigDict + +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] diff --git a/projects/videomae/configs/_base_/datasets/k400_videomae.py b/projects/videomae/configs/_base_/datasets/k400_videomae.py new file mode 100644 index 000000000..ee0d50977 --- /dev/null +++ b/projects/videomae/configs/_base_/datasets/k400_videomae.py @@ -0,0 +1,53 @@ +# dataset settings + +dataset_type = 'mmaction.VideoDataset' +data_root = 'data/kinetics400/videos_train' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' + +file_client_args = dict( + io_backend='petrel', + path_mapping=dict( + {'data/kinetics400': 's3://openmmlab/datasets/action/Kinetics400'})) + +# file_client_args = dict(io_backend='disk') +train_pipeline = [ + dict(type='mmaction.DecordInit', **file_client_args), + dict( + type='mmaction.SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=1), + dict(type='mmaction.DecordDecode'), + dict( + type='mmaction.MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1), + dict(type='mmaction.Resize', scale=(224, 224), keep_ratio=False), + dict(type='mmaction.FormatShape', input_format='NCTHW'), + dict( + type='VideoMAEMaskGenerator', + input_size=(16, 224, 224), + patch_size=16, + tubelet_size=2, + mask_ratio=0.9, + mask_mode='tube'), + dict( + type='PackSelfSupInputs', + key='imgs', + algorithm_keys=['mask'], + meta_keys=['img_shape', 'label']) +] + +train_dataloader = dict( + batch_size=32, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) diff --git a/projects/videomae/configs/_base_/models/videomae_vit-small-p16.py b/projects/videomae/configs/_base_/models/videomae_vit-small-p16.py new file mode 100644 index 000000000..0cedcfc2e --- /dev/null +++ b/projects/videomae/configs/_base_/models/videomae_vit-small-p16.py @@ -0,0 +1,40 @@ +# model settings +model = 
dict( + type='VideoMAE', + data_preprocessor=dict( + type='VideoMAEDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + format_shape='NCTHW'), + backbone=dict( + type='VideoMAEViT', + img_size=224, + embed_dims=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + num_frames=16, + norm_cfg=dict(type='LN', eps=1e-6), + patch_size=16, + mask_ratio=0.9), + neck=dict( + type='VideoMAEPretrainDecoder', + img_size=224, + num_frames=16, + num_classes=1536, + num_heads=3, + input_dims=384, + embed_dims=192, + patch_size=16, + depth=4, + ), + head=dict( + type='VideoMAEPretrainHead', + norm_pix=True, + patch_size=16, + loss=dict(type='PixelReconstructionLoss', criterion='L2')), + init_cfg=[ + dict(type='Xavier', distribution='uniform', layer='Linear'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]) diff --git a/projects/videomae/configs/_base_/schedules/adamw_coslr-100e_in1k.py b/projects/videomae/configs/_base_/schedules/adamw_coslr-100e_in1k.py new file mode 100644 index 000000000..7ab03a869 --- /dev/null +++ b/projects/videomae/configs/_base_/schedules/adamw_coslr-100e_in1k.py @@ -0,0 +1,19 @@ +# optimizer_wrapper +optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', T_max=160, by_epoch=True, begin=40, end=200) +] + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200) diff --git a/projects/videomae/configs/videomae/README.md b/projects/videomae/configs/videomae/README.md new file mode 100644 index 000000000..e943b9d45 --- /dev/null +++ b/projects/videomae/configs/videomae/README.md @@ -0,0 +1,5 @@ +# VideoMAE + +> [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) + + diff --git a/projects/videomae/configs/videomae/recognition/mmaction2_default_runtime.py b/projects/videomae/configs/videomae/recognition/mmaction2_default_runtime.py new file mode 100644 index 000000000..5aac5a67e --- /dev/null +++ b/projects/videomae/configs/videomae/recognition/mmaction2_default_runtime.py @@ -0,0 +1,26 @@ +# default_scope = 'mmaction' + +default_hooks = dict( + runtime_info=dict(type='mmaction.RuntimeInfoHook'), + timer=dict(type='mmaction.IterTimerHook'), + logger=dict(type='mmaction.LoggerHook', interval=20, ignore_last=False), + param_scheduler=dict(type='mmaction.ParamSchedulerHook'), + checkpoint=dict( + type='mmaction.CheckpointHook', interval=1, save_best='auto'), + sampler_seed=dict(type='mmaction.DistSamplerSeedHook'), + sync_buffers=dict(type='mmaction.SyncBuffersHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) + +log_processor = dict( + type='mmaction.LogProcessor', window_size=20, by_epoch=True) + +vis_backends = [dict(type='mmaction.LocalVisBackend')] +visualizer = dict(type='mmaction.ActionVisualizer', vis_backends=vis_backends) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-ft_16x4x1_kinetics-400.py b/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-ft_16x4x1_kinetics-400.py new file mode 100644 index 000000000..08501abae 
--- /dev/null +++ b/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-ft_16x4x1_kinetics-400.py @@ -0,0 +1,149 @@ +_base_ = ['./vit-base-p16_videomae-k400-pre_16x4x1_kinetics-400.py'] + +custom_imports = dict( + imports=['mmaction'], + # imports=['mmaction.datasets.transforms'], + allow_failed_imports=False) + +# model settings +model = dict( + type='mmaction.Recognizer3D', + backbone=dict( + type='mmaction.VisionTransformer', + img_size=224, + patch_size=16, + embed_dims=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + num_frames=16, + norm_cfg=dict(type='LN', eps=1e-6)), + cls_head=dict( + type='mmaction.TimeSformerHead', + num_classes=400, + in_channels=384, + average_clips='prob', + loss_cls=dict(type='mmaction.CrossEntropyLoss')), + data_preprocessor=dict( + type='mmaction.ActionDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + format_shape='NCTHW')) + +# dataset settings +dataset_type = 'mmaction.VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' + +file_client_args = dict( + io_backend='petrel', + path_mapping=dict( + {'data/kinetics400': 's3://openmmlab/datasets/action/Kinetics400'})) + +train_pipeline = [ + dict(type='mmaction.DecordInit', **file_client_args), + dict( + type='mmaction.SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=1), + dict(type='mmaction.DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='RandomResizedCrop'), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackActionInputs') +] + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='mmaction.DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_pipeline = [ + dict(type='mmaction.DecordInit', **file_client_args), + dict( + type='mmaction.SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=5, + test_mode=True), + dict(type='mmaction.DecordDecode'), + dict(type='mmaction.Resize', scale=(-1, 224)), + dict(type='mmaction.ThreeCrop', crop_size=224), + dict(type='mmaction.FormatShape', input_format='NCTHW'), + dict(type='mmaction.PackActionInputs') +] + +val_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='mmaction.DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='mmaction.DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +val_evaluator = dict(type='mmaction.AccMetric') +test_evaluator = dict(type='mmaction.AccMetric') + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=30, val_begin=1, val_interval=3) +val_cfg = dict(type='mmaction.ValLoop') +test_cfg = dict(type='mmaction.TestLoop') + +optim_wrapper = dict( + optimizer=dict( + type='SGD', 
lr=0.005, momentum=0.9, weight_decay=1e-4, nesterov=True), + paramwise_cfg=dict( + custom_keys={ + '.backbone.cls_token': dict(decay_mult=0.0), + '.backbone.pos_embed': dict(decay_mult=0.0), + '.backbone.time_embed': dict(decay_mult=0.0) + }), + clip_grad=dict(max_norm=40, norm_type=2)) + +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=15, + by_epoch=True, + milestones=[5, 10], + gamma=0.1) +] + +default_hooks = dict( + checkpoint=dict(interval=3, max_keep_ckpts=5), logger=dict(interval=100)) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (8 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=64) diff --git a/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-pre_16x4x1_kinetics-400.py b/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-pre_16x4x1_kinetics-400.py new file mode 100644 index 000000000..73414697f --- /dev/null +++ b/projects/videomae/configs/videomae/recognition/vit-base-p16_videomae-k400-pre_16x4x1_kinetics-400.py @@ -0,0 +1,75 @@ +_base_ = ['./mmaction2_default_runtime.py'] + +# _base_ = ['mmaction::_base_/default_runtime.py'] + +# custom_imports = dict( +# imports=['mmaction.datasets.transforms'], +# allow_failed_imports=False) + +custom_imports = dict(imports=['mmaction'], allow_failed_imports=False) + +# model settings +model = dict( + type='mmaction.Recognizer3D', + backbone=dict( + type='mmaction.VisionTransformer', + img_size=224, + patch_size=16, + embed_dims=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + num_frames=16, + norm_cfg=dict(type='LN', eps=1e-6)), + cls_head=dict( + type='mmaction.TimeSformerHead', + num_classes=400, + in_channels=768, + average_clips='prob', + loss_cls=dict(type='mmaction.CrossEntropyLoss')), + data_preprocessor=dict( + type='mmaction.ActionDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + format_shape='NCTHW')) + +# dataset settings +dataset_type = 'mmaction.VideoDataset' +data_root_val = 'data/kinetics400/videos_val' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' + +file_client_args = dict( + io_backend='petrel', + path_mapping=dict( + {'data/kinetics400': 's3://openmmlab/datasets/action/Kinetics400'})) + +test_pipeline = [ + dict(type='mmaction.DecordInit', **file_client_args), + dict( + type='mmaction.SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=5, + test_mode=True), + dict(type='mmaction.DecordDecode'), + dict(type='mmaction.Resize', scale=(-1, 224)), + dict(type='mmaction.ThreeCrop', crop_size=224), + dict(type='mmaction.FormatShape', input_format='NCTHW'), + dict(type='mmaction.PackActionInputs') +] + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='mmaction.DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +test_evaluator = dict(type='mmaction.AccMetric') +test_cfg = dict(type='mmaction.TestLoop') diff --git a/projects/videomae/configs/videomae/videomae_vit-small-p16_16xb64-amp-coslr-800e_k400.py b/projects/videomae/configs/videomae/videomae_vit-small-p16_16xb64-amp-coslr-800e_k400.py new file mode 100644 index 000000000..ef8d05e1a --- /dev/null +++ b/projects/videomae/configs/videomae/videomae_vit-small-p16_16xb64-amp-coslr-800e_k400.py @@ -0,0 +1,56 @@ 
+_base_ = [ + '../_base_/models/videomae_vit-small-p16.py', + '../_base_/datasets/k400_videomae.py', + '../_base_/schedules/adamw_coslr-100e_in1k.py', + 'mmselfsup::selfsup/_base_/default_runtime.py', +] + +custom_imports = dict( + imports=['models', 'datasets', 'mmaction.datasets.transforms'], + allow_failed_imports=False) + +# dataset 2 * 8 * 64 = 1024 +train_dataloader = dict(batch_size=64, num_workers=8) +# optimizer wrapper +optimizer = dict( + type='AdamW', lr=1.5e-4 * 1024 / 256, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=optimizer, + paramwise_cfg=dict( + custom_keys={ + # 'ln': dict(decay_mult=0.0), + # 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'decoder_pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True) +] + +train_cfg = dict(max_epochs=800) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) +resume = True diff --git a/projects/videomae/configs/videomae/videomae_vit-small-p16_8xb64-amp-coslr-200e_k400.py b/projects/videomae/configs/videomae/videomae_vit-small-p16_8xb64-amp-coslr-200e_k400.py new file mode 100644 index 000000000..5610e744e --- /dev/null +++ b/projects/videomae/configs/videomae/videomae_vit-small-p16_8xb64-amp-coslr-200e_k400.py @@ -0,0 +1,56 @@ +_base_ = [ + '../_base_/models/videomae_vit-small-p16.py', + '../_base_/datasets/k400_videomae.py', + '../_base_/schedules/adamw_coslr-100e_in1k.py', + 'mmselfsup::selfsup/_base_/default_runtime.py', +] + +custom_imports = dict( + imports=['models', 'datasets', 'mmaction.datasets.transforms'], + allow_failed_imports=False) + +# dataset 8 x 128 = 1024 +train_dataloader = dict(batch_size=128, num_workers=8) +# optimizer wrapper +optimizer = dict( + type='AdamW', lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=optimizer, + paramwise_cfg=dict( + custom_keys={ + # 'ln': dict(decay_mult=0.0), + # 'bias': dict(decay_mult=0.0), + 'pos_embed': dict(decay_mult=0.), + 'mask_token': dict(decay_mult=0.), + 'decoder_pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) 
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=160, + by_epoch=True, + begin=40, + end=200, + convert_to_iter_based=True) +] + +train_cfg = dict(max_epochs=200) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) +resume = True diff --git a/projects/videomae/datasets/__init__.py b/projects/videomae/datasets/__init__.py new file mode 100644 index 000000000..bab072ae0 --- /dev/null +++ b/projects/videomae/datasets/__init__.py @@ -0,0 +1 @@ +from .transforms import * # noqa: F401,F403 diff --git a/projects/videomae/datasets/transforms/__init__.py b/projects/videomae/datasets/transforms/__init__.py new file mode 100644 index 000000000..32c4b261f --- /dev/null +++ b/projects/videomae/datasets/transforms/__init__.py @@ -0,0 +1,3 @@ +from .processing import VideoMAEMaskGenerator + +__all__ = ['VideoMAEMaskGenerator'] diff --git a/projects/videomae/datasets/transforms/processing.py b/projects/videomae/datasets/transforms/processing.py new file mode 100644 index 000000000..135d3af6c --- /dev/null +++ b/projects/videomae/datasets/transforms/processing.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmcv.transforms import BaseTransform + +from mmselfsup.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class VideoMAEMaskGenerator(BaseTransform): + """Generate mask for VideoMAE.""" + + def __init__(self, + input_size: tuple, + mask_ratio: float = 0.75, + patch_size: int = 16, + tubelet_size: int = 2, + mask_mode: str = 'tube') -> None: + self.input_size = input_size + self.mask_ratio = mask_ratio + self.mask_mode = mask_mode + + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + num_frames, height, width = input_size + + self.num_tubelets = num_frames // tubelet_size + + self.num_patches_frame = (height // patch_size) * (width // patch_size) + self.num_patches_video = self.num_tubelets * self.num_patches_frame + + self.num_masks_frame = int(mask_ratio * self.num_patches_frame) + self.num_masks_video = self.num_tubelets * self.num_patches_frame + + def transform(self, results: dict) -> dict: + if self.mask_mode == 'random': + # TODO: add random mask + pass + elif self.mask_mode == 'tube': + mask_frame = np.hstack([ + np.zeros(self.num_patches_frame - self.num_masks_frame), + np.ones(self.num_masks_frame) + ]) + + np.random.shuffle(mask_frame) + mask_video = np.tile(mask_frame, (self.num_tubelets, 1)).flatten() + + else: + raise NotImplementedError + + results.update({'mask': mask_video}) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(input_size={self.input_size}, ' + repr_str += f'mask_ratio={self.mask_ratio}, ' + repr_str += f'mask_mode={self.mask_mode}, ' + repr_str += f'patch_size={self.patch_size}, ' + repr_str += f'tubelet_size={self.tubelet_size},)' + return repr_str diff --git a/projects/videomae/models/__init__.py b/projects/videomae/models/__init__.py new file mode 100644 index 000000000..27ab31fb0 --- /dev/null +++ b/projects/videomae/models/__init__.py @@ -0,0 +1,4 @@ +from .algorithms import * # noqa: F401,F403 +from .backbones import * # noqa: F401,F403 +from .head import * # noqa: F401,F403 +from 
.neck import * # noqa: F401,F403 diff --git a/projects/videomae/models/algorithms/__init__.py b/projects/videomae/models/algorithms/__init__.py new file mode 100644 index 000000000..df573f273 --- /dev/null +++ b/projects/videomae/models/algorithms/__init__.py @@ -0,0 +1,3 @@ +from .videomae import VideoMAE + +__all__ = ['VideoMAE'] diff --git a/projects/videomae/models/algorithms/videomae.py b/projects/videomae/models/algorithms/videomae.py new file mode 100644 index 000000000..593c85c93 --- /dev/null +++ b/projects/videomae/models/algorithms/videomae.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch + +from mmselfsup.models.algorithms import BaseModel +from mmselfsup.registry import MODELS +from mmselfsup.structures import SelfSupDataSample + + +@MODELS.register_module() +class VideoMAE(BaseModel): + """VideoMAE algorithm. + + Implementation of `VideoMAE: Masked Autoencoders are Data-Efficient + Learners for Self-Supervised Video Pre-Training + `_ + """ + + def extract_feat(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwarg) -> Tuple[torch.Tensor]: + """""" + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + video_latent = self.backbone(inputs[0], mask) + feat = self.neck(video_latent[0]) + return feat + + def loss(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training.""" + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + # normalized with RGB mean and std + video = inputs[0].squeeze(1) + + # change the mask from the float to bool type + mask = mask.to(torch.bool) + # encoder part + video = self.backbone(video, mask) + # decoder part + video_rec = self.neck(video, mask) + # criterion part + # recover the unnormlized video to [0, 1] + target = inputs[0].squeeze( + 1) * self.data_preprocessor.std + self.data_preprocessor.mean + target = target / 255.0 + loss = self.head(video_rec, target, mask) + losses = dict(loss=loss) + + return losses diff --git a/projects/videomae/models/backbones/__init__.py b/projects/videomae/models/backbones/__init__.py new file mode 100644 index 000000000..783032671 --- /dev/null +++ b/projects/videomae/models/backbones/__init__.py @@ -0,0 +1,3 @@ +from .videomae_vit import VideoMAEViT + +__all__ = ['VideoMAEViT'] diff --git a/projects/videomae/models/backbones/videomae_vit.py b/projects/videomae/models/backbones/videomae_vit.py new file mode 100644 index 000000000..19b19fcb1 --- /dev/null +++ b/projects/videomae/models/backbones/videomae_vit.py @@ -0,0 +1,363 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmengine.model import BaseModule, ModuleList +from mmengine.utils import to_2tuple +from torch import Tensor, nn + +from mmselfsup.models.utils import build_1d_sincos_position_embedding +from mmselfsup.registry import MODELS +from mmselfsup.utils import ConfigType, OptConfigType + + +class Attention(BaseModule): + """Multi-head Self-attention. + + Args: + embed_dims (int): Dimensions of embedding. + num_heads (int): Number of parallel attention heads. + qkv_bias (bool): If True, add a learnable bias to q and v. + Defaults to True. 
+ qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0. + drop_rate (float): Dropout ratio of output. Defaults to 0. + init_cfg (dict or ConfigDict, optional): The Config + for initialization. Defaults to None. + """ + + def __init__(self, + embed_dims: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop_rate: float = 0., + drop_rate: float = 0., + init_cfg: OptConfigType = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + + self.scale = qk_scale or head_embed_dims**-0.5 + + if qkv_bias: + self._init_qv_bias() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(drop_rate) + + def _init_qv_bias(self) -> None: + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def forward(self, x: Tensor) -> Tensor: + """Defines the computation performed at every call. + + Args: + x (Tensor): The input data with size of (B, N, C). + Returns: + Tensor: The output of the attention block, same size as inputs. + """ + B, N, C = x.shape + + if hasattr(self, 'q_bias'): + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class VideoMAEBlock(BaseModule): + """The basic block in the Vision Transformer. + + Args: + embed_dims (int): Dimensions of embedding. + num_heads (int): Number of parallel attention heads. + mlp_ratio (int): The ratio between the hidden layer and the + input layer in the FFN. Defaults to 4. + qkv_bias (bool): If True, add a learnable bias to q and v. + Defaults to True. + qk_scale (float): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + drop_rate (float): Dropout ratio of output. Defaults to 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0. + drop_path_rate (float): Dropout ratio of the residual branch. + Defaults to 0. + init_values (float): Value to init the multiplier of the + residual branch. Defaults to 0. + act_cfg (dict or ConfigDict): Config for activation layer in FFN. + Defaults to `dict(type='GELU')`. + norm_cfg (dict or ConfigDict): Config for norm layers. + Defaults to `dict(type='LN', eps=1e-6)`. + init_cfg (dict or ConfigDict, optional): The Config + for initialization. Defaults to None. 
+ """ + + def __init__(self, + embed_dims: int, + num_heads: int, + mlp_ratio: int = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: float = 0.0, + act_cfg: ConfigType = dict(type='GELU'), + norm_cfg: ConfigType = dict(type='LN', eps=1e-6), + init_cfg: OptConfigType = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = Attention( + embed_dims, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate) + + self.drop_path = nn.Identity() + if drop_path_rate > 0.: + self.drop_path = DropPath(drop_path_rate) + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + mlp_hidden_dim = int(embed_dims * mlp_ratio) + self.mlp = FFN( + embed_dims=embed_dims, + feedforward_channels=mlp_hidden_dim, + act_cfg=act_cfg, + ffn_drop=drop_rate, + add_identity=False) + + self._init_gammas(init_values, embed_dims) + + def _init_gammas(self, init_values: float, dim: int) -> None: + if type(init_values) == float and init_values > 0: + self.gamma_1 = nn.Parameter( + init_values * torch.ones(dim), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones(dim), requires_grad=True) + + def forward(self, x: Tensor) -> Tensor: + """Defines the computation performed at every call. + + Args: + x (Tensor): The input data with size of (B, N, C). + Returns: + Tensor: The output of the transformer block, same size as inputs. + """ + if hasattr(self, 'gamma_1'): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +@MODELS.register_module() +class VideoMAEViT(BaseModule): + """Vision Transformer with support for patch or hybrid CNN input stage. An + impl of `VideoMAE: Masked Autoencoders are Data-Efficient Learners for + Self-Supervised Video Pre-Training `_ + + Args: + img_size (int or tuple): Size of input image. + Defaults to 224. + patch_size (int): Spatial size of one patch. Defaults to 16. + in_channels (int): The number of channels of he input. + Defaults to 3. + embed_dims (int): Dimensions of embedding. Defaults to 768. + depth (int): number of blocks in the transformer. + Defaults to 12. + num_heads (int): Number of parallel attention heads in + TransformerCoder. Defaults to 12. + mlp_ratio (int): The ratio between the hidden layer and the + input layer in the FFN. Defaults to 4. + qkv_bias (bool): If True, add a learnable bias to q and v. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + drop_rate (float): Dropout ratio of output. Defaults to 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Defaults to 0. + drop_path_rate (float): Dropout ratio of the residual branch. + Defaults to 0. + norm_cfg (dict or Configdict): Config for norm layers. + Defaults to `dict(type='LN', eps=1e-6)`. + init_values (float): Value to init the multiplier of the residual + branch. Defaults to 0. + use_learnable_pos_emb (bool): If True, use learnable positional + embedding, othersize use sinusoid encoding. Defaults to False. + num_frames (int): Number of frames in the video. Defaults to 16. + tubelet_size (int): Temporal size of one patch. Defaults to 2. 
+ use_mean_pooling (bool): If True, take the mean pooling over all + positions. Defaults to True. + pretrained (str, optional): Name of pretrained model. Default: None. + mask_rato (float): The ratio of masked tokens. Defaults to 0.75. + mask_type (str): The type of masked tokens. + Defaults to 'random'. choices=['random', 'tube'] + init_cfg (dict or list[dict]): Initialization config dict. Defaults to + ``[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ]``. + """ + + def __init__(self, + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dims: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: int = 4., + qkv_bias: bool = True, + qk_scale: int = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: ConfigType = dict(type='LN', eps=1e-6), + init_values: int = 0., + use_learnable_pos_emb: bool = False, + num_frames: int = 16, + tubelet_size: int = 2, + pretrained: Optional[str] = None, + mask_ratio: float = 0.75, + mask_type: str = 'random', + init_cfg: Optional[Union[Dict, List[Dict]]] = [ + dict( + type='TruncNormal', layer='Linear', std=0.02, + bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + **kwargs) -> None: + + if pretrained: + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + super().__init__(init_cfg=init_cfg) + + patch_size = to_2tuple(patch_size) + img_size = to_2tuple(img_size) + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv3d', + kernel_size=(tubelet_size, ) + patch_size, + stride=(tubelet_size, ) + patch_size, + padding=(0, 0, 0), + dilation=(1, 1, 1)) + + num_patches = (img_size[1] // patch_size[1]) * \ + (img_size[0] // patch_size[0]) * \ + (num_frames // tubelet_size) + + if use_learnable_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dims)) + nn.init.trunc_normal_(self.pos_embed, std=.02) + else: + # sine-cosine positional embeddings is on the way + pos_embed = build_1d_sincos_position_embedding( + num_patches, embed_dims) + self.register_buffer('pos_embed', pos_embed) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = ModuleList([ + VideoMAEBlock( + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + init_values=init_values) for i in range(depth) + ]) + + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + self.mask_ratio = mask_ratio + self.mask_type = mask_type + + def video_masking(self, x: torch.Tensor, mask_ratio: float, + mask_type: str) -> torch.Tensor: + """Mask the video feature. + + Args: + x (torch.Tensor): The video feature. + mask_ratio (float): The ratio of masked tokens. + mask_type (str): The type of masked tokens. + choices=['random', 'tube'] + Returns: + torch.Tensor: The masked video feature. + """ + assert mask_type in ['random', 'tube'], \ + f"mask_type must be one of ['random', 'tube'], but got {mask_type}" + if mask_type == 'random': + return self.random_masking(x, mask_ratio) + else: + return self.tube_masking(x, mask_ratio) + + def forward(self, x: Tensor, mask: Tensor) -> Tensor: + """Defines the computation performed at every call. + + Args: + x (Tensor): The input data. 
+ Returns: + Tensor: The feature of the input + samples extracted by the backbone. + """ + x = self.patch_embed(x)[0] + + x = x + self.pos_embed + x = self.pos_drop(x) + + B, _, C = x.shape + assert mask is not None + + # use ~mask to indicate the visible tokens + x = x[~mask].reshape(B, -1, C) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x diff --git a/projects/videomae/models/head/__init__.py b/projects/videomae/models/head/__init__.py new file mode 100644 index 000000000..1df56d936 --- /dev/null +++ b/projects/videomae/models/head/__init__.py @@ -0,0 +1,3 @@ +from .videomae_head import VideoMAEPretrainHead + +__all__ = ['VideoMAEPretrainHead'] diff --git a/projects/videomae/models/head/videomae_head.py b/projects/videomae/models/head/videomae_head.py new file mode 100644 index 000000000..525c05d9b --- /dev/null +++ b/projects/videomae/models/head/videomae_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class VideoMAEPretrainHead(BaseModule): + """Head for VideoMAE pre-training: computes the pixel reconstruction loss on the masked tubelets.""" + + def __init__(self, + loss: dict, + norm_pix: bool = False, + tubelet_size: int = 2, + patch_size: int = 16) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.tubelet_size = tubelet_size + self.loss = MODELS.build(loss) + + def patchify(self, video: torch.Tensor) -> torch.Tensor: + """Split a batch of videos into non-overlapping tubelet patches. + + Args: + video (torch.Tensor): A batch of videos, of shape + B x C x T x H x W. C is the channel, T is the temporal length + + Returns: + torch.Tensor: Patchified videos of shape (B, L, ts * p * p, 3), where L is the number of tubelet tokens. + """ + p = self.patch_size + ts = self.tubelet_size + B, C, T, H, W = video.shape + assert H == W and H % p == 0 and C == 3 + # number of patches in height and width + h = w = H // p + # number of tubelets in the temporal dimension + t = T // ts + + # video shape (B, 3, T, H, W) + x = video.reshape(shape=(B, 3, t, ts, h, p, w, p)) + # 'b c ts hp wq->b t hw spq c' + x = torch.einsum('bctshpwq->bthwspqc', x) + # (B, num_token, num_pixel_per_token, 3) + x = x.reshape(shape=(B, t * h * w, ts * p * p, 3)) + return x + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + target = self.patchify(target) + if self.norm_pix: + # normalize the target video per token and channel, different from MAE + mean = target.mean(dim=-2, keepdim=True) + var = target.var(dim=-2, unbiased=True, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + B, T, L, C = target.shape + target = target.view(B, T, L * C) + return target + + def forward(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + target = self.construct_target(target) + B, _, C = target.shape + target = target[mask].reshape(B, -1, C) + loss = self.loss(pred, target) + + return loss diff --git a/projects/videomae/models/neck/__init__.py b/projects/videomae/models/neck/__init__.py new file mode 100644 index 000000000..6079ffcab --- /dev/null +++ b/projects/videomae/models/neck/__init__.py @@ -0,0 +1,3 @@ +from .videomae_neck import VideoMAEPretrainDecoder + +__all__ = ['VideoMAEPretrainDecoder'] diff --git a/projects/videomae/models/neck/videomae_neck.py b/projects/videomae/models/neck/videomae_neck.py new file mode 100644 index 000000000..cff6d4bb8 --- /dev/null +++ b/projects/videomae/models/neck/videomae_neck.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from mmengine.utils import to_2tuple + +from mmselfsup.models.utils import build_1d_sincos_position_embedding +from mmselfsup.registry import MODELS +from ..backbones.videomae_vit import VideoMAEBlock + + +@MODELS.register_module() +class VideoMAEPretrainDecoder(BaseModule): + """Decoder for VideoMAE Pre-training. + + Some of the code is borrowed from ``. # noqa + """ + + def __init__(self, + num_patches: int = 196, + patch_size: int = 16, + img_size: int = 224, + num_classes: int = 768, + input_dims: int = 768, + embed_dims: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: int = 4, + qkv_bias: bool = True, + qkv_scale: Optional[float] = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_value: Optional[float] = None, + tubelet_size: int = 2, + num_frames: int = 16, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + + patch_sizes = to_2tuple(patch_size) + img_sizes = to_2tuple(img_size) + + num_patches = (img_sizes[1] // patch_sizes[1]) * \ + (img_sizes[0] // patch_sizes[0]) * \ + (num_frames // tubelet_size) + + # used to convert the dim of features from encoder to the dim + # compatible with that of decoder + self.decoder_embed_layer = nn.Linear( + input_dims, embed_dims, bias=False) + + decoder_pos_embed = build_1d_sincos_position_embedding( + num_patches, embed_dims) + self.register_buffer('decoder_pos_embed', decoder_pos_embed) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self.num_classes = num_classes + assert num_classes == 3 * tubelet_size * patch_size**2 + self.embed_dims = embed_dims + self.patch_size = patch_size + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.ModuleList([ + VideoMAEBlock( + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qkv_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + init_values=init_value) for i in range(depth) + ]) + + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + self.head = nn.Linear(embed_dims, num_classes) \ + if num_classes > 0 else nn.Identity() + + def init_weights(self): + super().init_weights() + trunc_normal_(self.mask_token, std=.02) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Forward function.""" + # linear transformation to adapt the channel dimension + x = self.decoder_embed_layer(x) + B, _, C = x.shape + + # expand the position embedding to the size of B + expand_pos_embed = self.decoder_pos_embed.expand(x.shape[0], -1, -1) + # position embedding of the visible tokens + visible_pe = expand_pos_embed[~mask].reshape(B, -1, C) + # position embedding of the masked tokens + masked_pe = expand_pos_embed[mask].reshape(B, -1, C) + return_token_num = masked_pe.shape[1] + + x = torch.cat([x + visible_pe, self.mask_token + masked_pe], dim=1) + + for blk in self.blocks: + x = blk(x) + + # only conduct the pixel prediction on the masked tokens + x = x[:, -return_token_num:] if return_token_num > 0 else x + + x = self.head(self.norm(x)) + + return x diff --git a/projects/videomae/tools/ckpt_tree.py b/projects/videomae/tools/ckpt_tree.py new file mode 100644 index
000000000..b78500bd4 --- /dev/null +++ b/projects/videomae/tools/ckpt_tree.py @@ -0,0 +1,185 @@ +import argparse +import math +from pathlib import Path + +import torch +from rich.console import Console + +console = Console() + +prog_description = """\ +Draw the state dict tree. +""" + + +def parse_args(): + parser = argparse.ArgumentParser(description=prog_description) + parser.add_argument( + 'path', + type=Path, + help='The path of the checkpoint or model config to draw.') + parser.add_argument('--depth', type=int, help='The max depth to draw.') + parser.add_argument( + '--full-name', + action='store_true', + help='Whether to print the full name of the key.') + parser.add_argument( + '--shape', + action='store_true', + help='Whether to print the shape of the parameter.') + parser.add_argument( + '--state-key', + type=str, + help='The key of the state dict in the checkpoint.') + parser.add_argument( + '--number', + action='store_true', + help='Mark all parameters and their index number.') + parser.add_argument( + '--node', + type=str, + help='Show the sub-tree of a node, like "backbone.layers".') + args = parser.parse_args() + return args + + +def ckpt_to_state_dict(checkpoint, key=None): + if key is not None: + state_dict = checkpoint[key] + elif 'state_dict' in checkpoint: + # try mmcls style + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + elif isinstance(next(iter(checkpoint.values())), torch.Tensor): + # try native style + state_dict = checkpoint + else: + raise KeyError('Please specify the key of state ' + f'dict from {list(checkpoint.keys())}.') + return state_dict + + +class StateDictTree: + + def __init__(self, key='', value=None): + self.children = {} + self.key: str = key + self.value = value + + def add_parameter(self, key, value): + keys = key.split('.', 1) + if len(keys) == 1: + self.children[key] = StateDictTree(key, value) + elif keys[0] in self.children: + self.children[keys[0]].add_parameter(keys[1], value) + else: + node = StateDictTree(keys[0]) + node.add_parameter(keys[1], value) + self.children[keys[0]] = node + + def __getitem__(self, key: str): + return self.children[key] + + def __repr__(self) -> str: + with console.capture() as capture: + for line in self.iter_tree(): + console.print(line) + return capture.get() + + def __len__(self): + return len(self.children) + + def draw_tree(self, + max_depth=None, + full_name=False, + with_shape=False, + with_value=False): + for line in self.iter_tree( + max_depth=max_depth, + full_name=full_name, + with_shape=with_shape, + with_value=with_value, + ): + console.print(line, highlight=False) + + def iter_tree( + self, + lead='', + prefix='', + max_depth=None, + full_name=False, + with_shape=False, + with_value=False, + ): + if self.value is None: + key_str = f'[blue]{self.key}[/]' + elif with_shape: + key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}' + elif with_value: + key_str = f'[green]{self.key}[/] {self.value}' + else: + key_str = f'[green]{self.key}[/]' + + yield lead + prefix + key_str + + lead = lead.replace('├─', '│ ') + lead = lead.replace('└─', ' ') + if self.key and full_name: + prefix = f'{prefix}{self.key}.' 
+ + if max_depth == 0: + return + elif max_depth is not None: + max_depth -= 1 + + for i, child in enumerate(self.children.values()): + level_lead = '├─' if i < len(self.children) - 1 else '└─' + yield from child.iter_tree( + lead=f'{lead}{level_lead} ', + prefix=prefix, + max_depth=max_depth, + full_name=full_name, + with_shape=with_shape, + with_value=with_value) + + +def main(): + args = parse_args() + if args.path.suffix in ['.json', '.py', '.yml']: + from mmcls.apis import init_model + from mmengine.runner import get_state_dict + model = init_model(args.path, device='cpu') + state_dict = get_state_dict(model) + else: + ckpt = torch.load(args.path, map_location='cpu') + state_dict = ckpt_to_state_dict(ckpt, args.state_key) + + root = StateDictTree() + for k, v in state_dict.items(): + root.add_parameter(k, v) + + para_index = 0 + mark_width = math.floor(math.log(len(state_dict), 10) + 1) + if args.node is not None: + for key in args.node.split('.'): + root = root[key] + + for line in root.iter_tree( + max_depth=args.depth, + full_name=args.full_name, + with_shape=args.shape, + ): + if not args.number: + mark = '' + # A hack method to determine whether a line is parameter. + elif '[green]' in line: + mark = f'[red]({str(para_index).ljust(mark_width)})[/]' + para_index += 1 + else: + mark = ' ' * (mark_width + 2) + console.print(mark + line, highlight=False) + + +if __name__ == '__main__': + main() diff --git a/projects/videomae/tools/videomae_to_mmselfsup.py b/projects/videomae/tools/videomae_to_mmselfsup.py new file mode 100644 index 000000000..aa5c9bd28 --- /dev/null +++ b/projects/videomae/tools/videomae_to_mmselfsup.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_videomae(ckpt): + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('encoder.'): + new_k = k.replace('encoder.', 'backbone.') + elif k.startswith('decoder'): + new_k = k.replace('decoder.', 'neck.') + elif k.startswith('encoder_to_decoder'): + new_k = k.replace('encoder_to_decoder.', + 'neck.decoder_embed_layer.') + elif k.startswith('mask_token'): + new_k = 'neck.' + k + + # second round + if 'patch_embed.proj.' in new_k: + new_k = new_k.replace('patch_embed.proj.', + 'patch_embed.projection.') + + if 'mlp.fc1' in new_k: + new_k = new_k.replace('mlp.fc1', 'mlp.layers.0.0') + + if 'mlp.fc2' in new_k: + new_k = new_k.replace('mlp.fc2', 'mlp.layers.1') + + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained clip models to mmcls style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. 
+ parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + weight = convert_videomae(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main() diff --git a/tools/benchmarks/mmaction2/dist_test.sh b/tools/benchmarks/mmaction2/dist_test.sh new file mode 100644 index 000000000..936652602 --- /dev/null +++ b/tools/benchmarks/mmaction2/dist_test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -x + +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +# Arguments starting from the forth one are captured by ${@:4} +python -m torch.distributed.launch --nnodes=$NNODES --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/benchmarks/mmaction2/mim_dist_test.sh b/tools/benchmarks/mmaction2/mim_dist_test.sh new file mode 100644 index 000000000..870eded99 --- /dev/null +++ b/tools/benchmarks/mmaction2/mim_dist_test.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -x + +CFG=$1 +CHECKPOINT=$2 +GPUS=${GPUS:-8} +PY_ARGS=${@:3} + + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +mim test mmaction \ + $CFG \ + --checkpoint $CHECKPOINT \ + --launcher pytorch \ + -G $GPUS \ + --cfg-options $PY_ARGS diff --git a/tools/benchmarks/mmaction2/mim_dist_train.sh b/tools/benchmarks/mmaction2/mim_dist_train.sh new file mode 100644 index 000000000..a91c8698d --- /dev/null +++ b/tools/benchmarks/mmaction2/mim_dist_train.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -x + +CFG=$1 +PRETRAIN=$2 # pretrained model +GPUS=$3 +PY_ARGS=${@:4} + +WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)" + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ + +mim train mmaction $CFG \ + --launcher pytorch -G $GPUS \ + --work-dir $WORK_DIR \ + --cfg-options model.backbone.init_cfg.type=Pretrained \ + model.backbone.init_cfg.checkpoint=$PRETRAIN \ + model.backbone.init_cfg.prefix="backbone." \ + $PY_ARGS diff --git a/tools/benchmarks/mmaction2/test.py b/tools/benchmarks/mmaction2/test.py new file mode 100644 index 000000000..a6428a929 --- /dev/null +++ b/tools/benchmarks/mmaction2/test.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from mmaction.utils import register_all_modules +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +print('mmaction2 test ') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMAction2 test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--dump', + type=str, + help='dump predictions to a pickle file for offline evaluation') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. 
If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--show-dir', + help='directory where the visualization images will be saved.') + parser.add_argument( + '--show', + action='store_true', + help='whether to display the prediction results in a window.') + parser.add_argument( + '--interval', + type=int, + default=1, + help='visualize per interval samples.') + parser.add_argument( + '--wait-time', + type=float, + default=2, + help='display time of every window. (second)') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def merge_args(cfg, args): + """Merge CLI arguments to config.""" + # -------------------- visualization -------------------- + if args.show or (args.show_dir is not None): + assert 'visualization' in cfg.default_hooks, \ + 'VisualizationHook is not set in the `default_hooks` field of ' \ + 'config. Please set `visualization=dict(type="VisualizationHook")`' + + cfg.default_hooks.visualization.enable = True + cfg.default_hooks.visualization.show = args.show + cfg.default_hooks.visualization.wait_time = args.wait_time + cfg.default_hooks.visualization.out_dir = args.show_dir + cfg.default_hooks.visualization.interval = args.interval + + # -------------------- Dump predictions -------------------- + if args.dump is not None: + assert args.dump.endswith(('.pkl', '.pickle')), \ + 'The dump file must be a pkl file.' + dump_metric = dict(type='DumpResults', out_file_path=args.dump) + if isinstance(cfg.test_evaluator, (list, tuple)): + cfg.test_evaluator = list(cfg.test_evaluator) + cfg.test_evaluator.append(dump_metric) + else: + cfg.test_evaluator = [cfg.test_evaluator, dump_metric] + + return cfg + + +def main(): + args = parse_args() + + # register all modules in mmaction2 into the registries + # do not init the default scope here because it will be init in the runner + register_all_modules(init_default_scope=False) + + # load config + cfg = Config.fromfile(args.config) + cfg = merge_args(cfg, args) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main()
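Note on how the pieces in this patch fit together: below is a minimal sketch (not part of the patch) of the tube masking produced by VideoMAEMaskGenerator, the visible-token gather done in VideoMAEViT.forward, and the masked-target selection done in VideoMAEPretrainHead, assuming the shapes used in the configs above (16 frames, tubelet size 2, 16x16 patches on a 224x224 crop, mask ratio 0.9). Tensor names such as `tokens` and `target` are illustrative only.

import torch

# Shapes taken from the configs above: 16 frames, tubelet size 2,
# 16x16 spatial patches on a 224x224 crop, mask ratio 0.9.
T, H, W, p, ts = 16, 224, 224, 16, 2
num_tubelets = T // ts                          # 8
patches_per_frame = (H // p) * (W // p)         # 196
num_tokens = num_tubelets * patches_per_frame   # 1568
mask_ratio = 0.9

# Tube masking: one random spatial mask, repeated for every tubelet,
# mirroring the 'tube' branch of VideoMAEMaskGenerator.
num_masked = int(mask_ratio * patches_per_frame)                 # 176
frame_mask = torch.hstack([
    torch.zeros(patches_per_frame - num_masked),
    torch.ones(num_masked),
])
frame_mask = frame_mask[torch.randperm(patches_per_frame)]
mask = frame_mask.repeat(num_tubelets).to(torch.bool)            # (1568,)

# Encoder side: only the visible tokens are kept, as in VideoMAEViT.forward.
B, C = 2, 384
tokens = torch.randn(B, num_tokens, C)
visible = tokens[:, ~mask]              # (B, 160, 384)

# Head side: the reconstruction target keeps only the masked tokens; each
# token is a flattened tubelet of ts * p * p * 3 = 1536 pixel values.
target = torch.randn(B, num_tokens, ts * p * p * 3)
masked_target = target[:, mask]         # (B, 1408, 1536)

print(visible.shape, masked_target.shape)

The last dimension of the masked target, 3 * tubelet_size * patch_size**2 = 1536, is the same value asserted against num_classes in VideoMAEPretrainDecoder, which is why the decoder in videomae_vit-small-p16.py is configured with num_classes=1536.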