-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy path mae.py
58 lines (41 loc) · 1.67 KB
/
mae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
### refer to https://github.com/facebookresearch/mae ###
import copy
import math
import random
import warnings
from argparse import Namespace
from functools import wraps
import torch
import torch.nn.functional as F
from torch import nn
from ssl.base import BaseSelfSupervisedModel
def _get_module_device(module):
    """Return the device on which ``module``'s tensors live.

    Looks at the first parameter; if the module has none, looks at the first
    buffer; if it has neither, returns the CPU device.

    Args:
        module: Any ``nn.Module``.

    Returns:
        ``torch.device`` of the first parameter/buffer found, else CPU.
    """
    # next() with a default avoids leaking a bare StopIteration for
    # parameter-free modules (it would become a RuntimeError inside a generator).
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    return tensor.device if tensor is not None else torch.device("cpu")
class MAE(BaseSelfSupervisedModel):
    """Masked Autoencoder (MAE) self-supervised wrapper around a backbone.

    Adapted from https://github.com/facebookresearch/mae. The backbone itself
    performs masking, encoding and reconstruction; this wrapper sets the mask
    ratio and turns the backbone's reconstruction output into a scalar loss.

    Args:
        backbone: Module exposing ``set_mask_ratio(mask_ratio=...)``, whose
            forward accepts ``pretrain=True`` (returning ``pred, target,
            mask``) and ``global_pool=True``.
        params: Namespace read for ``norm_pix_loss`` and ``mask_ratio`` here;
            it is also passed to the base class.
    """

    def __init__(self, backbone: nn.Module, params: Namespace):
        super().__init__(backbone, params)

        # When True, each target patch is normalised before the MSE.
        self.norm_pix_loss = params.norm_pix_loss
        backbone.set_mask_ratio(mask_ratio=params.mask_ratio)

        # Alias used by compute_ssl_loss and _data_parallel
        # (original note: "for consistency").
        self.online_encoder = backbone

        # Move this wrapper's parameters/buffers onto the backbone's device.
        device = _get_module_device(self.online_encoder)
        self.to(device)

    def _data_parallel(self):
        """Wrap ``online_encoder`` in ``nn.DataParallel``.

        NOTE(review): ``forward_features`` calls ``self.backbone``, which is
        not wrapped here — confirm that is intended.
        """
        self.online_encoder = nn.DataParallel(self.online_encoder)

    def compute_ssl_loss(self, x, _, return_features=False):
        """Return the MAE reconstruction loss averaged over removed patches.

        Args:
            x: Input batch forwarded to ``online_encoder`` with ``pretrain=True``.
            _: Unused positional argument (presumably required by the
                base-class interface).
            return_features: Unused; only the loss is ever returned.

        Returns:
            Scalar tensor: per-patch MSE between ``pred`` and ``target``,
            averaged over patches where ``mask`` is 1. If ``mask`` selects no
            patch (e.g. a mask ratio of 0) the result is 0 instead of NaN.
        """
        # pred: [N, L, p*p*3]
        pred, target, mask = self.online_encoder(x, pretrain=True)

        if self.norm_pix_loss:
            # Normalise every target patch by its own mean and variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        # MSE, then mean over the last dim -> [N, L] loss per patch.
        loss = ((pred - target) ** 2).mean(dim=-1)

        # mask: [N, L], 0 is keep, 1 is remove. The mask is binary, so the
        # clamp only matters when nothing was removed; it turns 0/0 = NaN into 0.
        loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        return loss

    def forward_features(self, x):
        """Return backbone features for ``x`` computed with ``global_pool=True``.

        Only used in train_selfsup_sampling.py. Uses ``self.backbone``
        (presumably set by the base class), not ``online_encoder``.
        """
        output = self.backbone(x, global_pool=True)
        return output