Skip to content

Commit 60f5833

Browse files
committed
add greenmim infer
1 parent c951bb6 commit 60f5833

File tree

11 files changed

+1265
-4
lines changed

11 files changed

+1265
-4
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# model settings
2+
img_size = 224
3+
patch_size = 4
4+
5+
model = dict(
6+
type='GreenMIM',
7+
data_preprocessor=dict(
8+
mean=[123.675, 116.28, 103.53],
9+
std=[58.395, 57.12, 57.375],
10+
bgr_to_rgb=True),
11+
backbone=dict(
12+
type='GreenMIMSwinTransformer',
13+
arch='B',
14+
img_size=img_size,
15+
patch_size=patch_size,
16+
drop_path_rate=0.0,
17+
stage_cfgs=dict(block_cfgs=dict(window_size=7))),
18+
neck=dict(type='GreenMIMNeck', in_channels=3, encoder_stride=32, img_size=img_size, patch_size=patch_size),
19+
head=dict(
20+
type='GreenMIMHead',
21+
patch_size=patch_size,
22+
norm_pix_loss=False,
23+
loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
_base_ = [
2+
'../_base_/models/greenmim_swin-base.py',
3+
'../_base_/datasets/imagenet_mae.py',
4+
'../_base_/schedules/adamw_coslr-200e_in1k.py',
5+
'../_base_/default_runtime.py',
6+
]
7+
8+
# dataset 16 GPUs x 128
9+
train_dataloader = dict(batch_size=128, num_workers=16)
10+
11+
# optimizer wrapper
12+
optimizer = dict(
13+
type='AdamW', lr=2e-4 * 2048 / 512, betas=(0.9, 0.999), eps=1e-8)
14+
optim_wrapper = dict(
15+
type='AmpOptimWrapper',
16+
optimizer=optimizer,
17+
clip_grad=dict(max_norm=5.0),
18+
paramwise_cfg=dict(
19+
custom_keys={
20+
'norm': dict(decay_mult=0.0),
21+
'bias': dict(decay_mult=0.0),
22+
'absolute_pos_embed': dict(decay_mult=0.),
23+
'relative_position_bias_table': dict(decay_mult=0.)
24+
}))
25+
26+
# learning rate scheduler
27+
param_scheduler = [
28+
dict(
29+
type='LinearLR',
30+
start_factor=1e-6 / 2e-4,
31+
by_epoch=True,
32+
begin=0,
33+
end=10,
34+
convert_to_iter_based=True),
35+
dict(
36+
type='CosineAnnealingLR',
37+
T_max=90,
38+
eta_min=1e-5 * 2048 / 512,
39+
by_epoch=True,
40+
begin=10,
41+
end=100,
42+
convert_to_iter_based=True)
43+
]
44+
45+
# schedule
46+
train_cfg = dict(max_epochs=100)
47+
48+
# runtime
49+
default_hooks = dict(logger=dict(type='LoggerHook', interval=100))

mmselfsup/models/algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
from .simmim import SimMIM
2222
from .simsiam import SimSiam
2323
from .swav import SwAV
24+
from .greenmim import GreenMIM
2425

2526
__all__ = [
2627
'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
2728
'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
2829
'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
29-
'MixMIM'
30+
'MixMIM', 'GreenMIM'
3031
]
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, List, Optional, Tuple
3+
4+
import torch
5+
from mmengine.structures import BaseDataElement
6+
7+
from mmselfsup.registry import MODELS
8+
from mmselfsup.structures import SelfSupDataSample
9+
from .base import BaseModel
10+
11+
@MODELS.register_module()
12+
class GreenMIM(BaseModel):
13+
"""GreenMIM.
14+
15+
Implementation of `GreenMIM: Green Hierarchical Vision Transformer for Masked Image Modeling
16+
<https://arxiv.org/abs/2205.13515>`_.
17+
"""
18+
19+
def extract_feat(self,
20+
inputs: List[torch.Tensor],
21+
data_samples: Optional[List[SelfSupDataSample]] = None,
22+
**kwarg) -> Tuple[torch.Tensor]:
23+
"""The forward function to extract features from neck.
24+
25+
Args:
26+
inputs (List[torch.Tensor]): The input images.
27+
28+
Returns:
29+
Tuple[torch.Tensor]: Neck outputs.
30+
"""
31+
latent, mask, ids_restore = self.backbone(inputs[0])
32+
pred = self.neck(latent, ids_restore)
33+
self.mask = mask
34+
return pred
35+
36+
def reconstruct(self,
37+
features: torch.Tensor,
38+
data_samples: Optional[List[SelfSupDataSample]] = None,
39+
**kwargs) -> SelfSupDataSample:
40+
"""The function is for image reconstruction.
41+
42+
Args:
43+
features (torch.Tensor): The input images.
44+
data_samples (List[SelfSupDataSample]): All elements required
45+
during the forward function.
46+
47+
Returns:
48+
SelfSupDataSample: The prediction from model.
49+
"""
50+
mean = kwargs['mean']
51+
std = kwargs['std']
52+
features = features * std + mean
53+
54+
pred = self.head.unpatchify(features)
55+
pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
56+
57+
mask = self.mask.detach()
58+
mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
59+
3) # (N, H*W, p*p*3)
60+
mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping
61+
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
62+
63+
results = SelfSupDataSample()
64+
results.mask = BaseDataElement(**dict(value=mask))
65+
results.pred = BaseDataElement(**dict(value=pred))
66+
67+
return results
68+
69+
def patchify(self, imgs, patch_size):
70+
"""
71+
imgs: (N, 3, H, W)
72+
x: (N, L, patch_size**2 *3)
73+
"""
74+
p = patch_size
75+
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
76+
77+
h = w = imgs.shape[2] // p
78+
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
79+
x = torch.einsum('nchpwq->nhwpqc', x)
80+
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
81+
return x
82+
83+
def loss(self, inputs: List[torch.Tensor],
84+
data_samples: List[SelfSupDataSample],
85+
**kwargs) -> Dict[str, torch.Tensor]:
86+
"""The forward function in training.
87+
88+
Args:
89+
inputs (List[torch.Tensor]): The input images.
90+
data_samples (List[SelfSupDataSample]): All elements required
91+
during the forward function.
92+
93+
Returns:
94+
Dict[str, torch.Tensor]: A dictionary of loss components.
95+
"""
96+
# ids_restore: the same as that in original repo, which is used
97+
# to recover the original order of tokens in decoder.
98+
latent, mask, ids_restore = self.backbone(inputs[0])
99+
pred = self.neck(latent, ids_restore)
100+
target = self.patchify(inputs[0], self.backbone.final_patch_size)
101+
loss = self.head(pred, target, mask)
102+
losses = dict(loss=loss)
103+
return losses

mmselfsup/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from .resnet import ResNet, ResNetSobel, ResNetV1d
1010
from .resnext import ResNeXt
1111
from .simmim_swin import SimMIMSwinTransformer
12+
from .greenmim import GreenMIMSwinTransformer
1213

1314
__all__ = [
1415
'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT',
1516
'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT',
16-
'MixMIMTransformerPretrain'
17+
'MixMIMTransformerPretrain', 'GreenMIMSwinTransformer'
1718
]

0 commit comments

Comments
 (0)