-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathkfac_utils.py
More file actions
194 lines (155 loc) · 5.89 KB
/
kfac_utils.py
File metadata and controls
194 lines (155 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import torch
import torch.nn as nn
import torch.nn.functional as F
def try_contiguous(x):
    """Return a contiguous version of *x*, copying only when necessary.

    If *x* is already contiguous it is returned unchanged (same object).
    """
    return x if x.is_contiguous() else x.contiguous()
def _extract_patches(x, kernel_size, stride, padding):
"""
:param x: The input feature maps. (batch_size, in_c, h, w)
:param kernel_size: the kernel size of the conv filter (tuple of two elements)
:param stride: the stride of conv operation (tuple of two elements)
:param padding: number of paddings. be a tuple of two elements
:return: (batch_size, out_h, out_w, in_c*kh*kw)
"""
if padding[0] + padding[1] > 0:
x = F.pad(x, (padding[1], padding[1], padding[0],
padding[0])).data # Actually check dims
x = x.unfold(2, kernel_size[0], stride[0])
x = x.unfold(3, kernel_size[1], stride[1])
x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
x = x.view(
x.size(0), x.size(1), x.size(2),
x.size(3) * x.size(4) * x.size(5))
return x
def update_running_stat(aa, m_aa, stat_decay):
    """Exponential moving average, computed fully in place:

        m_aa <- stat_decay * m_aa + (1 - stat_decay) * aa

    Factored as three in-place ops so no temporary the size of m_aa
    is allocated (memory saving); m_aa is mutated, nothing is returned.
    """
    m_aa.mul_(stat_decay / (1 - stat_decay))
    m_aa.add_(aa)
    m_aa.mul_(1 - stat_decay)
class ComputeMatGrad:
    """Per-sample gradient of a layer's weight matrix (bias folded in).

    Dispatches on the layer type; supports nn.Linear and nn.Conv2d.
    """

    @classmethod
    def __call__(cls, input, grad_output, layer):
        if isinstance(layer, nn.Linear):
            return cls.linear(input, grad_output, layer)
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(input, grad_output, layer)
        raise NotImplementedError

    @staticmethod
    def linear(input, grad_output, layer):
        """
        :param input: batch_size * input_dim
        :param grad_output: batch_size * output_dim
        :param layer: [nn.module] output_dim * input_dim
        :return: batch_size * output_dim * (input_dim + [1 if with bias])
        """
        with torch.no_grad():
            acts = input
            if layer.bias is not None:
                # Append a ones column so the bias gradient rides along.
                ones = acts.new(acts.size(0), 1).fill_(1)
                acts = torch.cat([acts, ones], 1)
            # Per-sample outer product: (b, out, 1) @ (b, 1, in[+1]).
            return torch.bmm(grad_output.unsqueeze(2), acts.unsqueeze(1))

    @staticmethod
    def conv2d(input, grad_output, layer):
        """
        :param input: batch_size * in_c * in_h * in_w
        :param grad_output: batch_size * out_c * h * w
        :param layer: nn.module batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
        :return:
        """
        with torch.no_grad():
            acts = _extract_patches(input, layer.kernel_size,
                                    layer.stride, layer.padding)
            acts = acts.view(-1, acts.size(-1))  # (b*hw, in_c*kh*kw)
            grads = grad_output.transpose(1, 2).transpose(2, 3)
            grads = try_contiguous(grads).view(
                grads.size(0), -1, grads.size(-1))  # (b, hw, out_c)
            if layer.bias is not None:
                ones = acts.new(acts.size(0), 1).fill_(1)
                acts = torch.cat([acts, ones], 1)
            acts = acts.view(grads.size(0), -1, acts.size(-1))  # (b, hw, d[+1])
            # Sum the per-location outer products over the spatial axis.
            return torch.einsum('abm,abn->amn', (grads, acts))
class ComputeCovA:
    """Covariance of a layer's input activations (the "A" factor in K-FAC).

    Returns None for unsupported layer types so callers can skip them.
    """

    @classmethod
    def compute_cov_a(cls, a, layer):
        # Thin alias kept for API compatibility with call sites.
        return cls.__call__(a, layer)

    @classmethod
    def __call__(cls, a, layer):
        if isinstance(layer, nn.Linear):
            return cls.linear(a, layer)
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(a, layer)
        # Other layer types are not factored; signal with None.
        return None

    @staticmethod
    def conv2d(a, layer):
        batch_size = a.size(0)
        patches = _extract_patches(a, layer.kernel_size,
                                   layer.stride, layer.padding)
        spatial_size = patches.size(1) * patches.size(2)
        flat = patches.view(-1, patches.size(-1))
        if layer.bias is not None:
            # Ones column folds the bias into the same factor.
            flat = torch.cat([flat, flat.new(flat.size(0), 1).fill_(1)], 1)
        # Average over spatial locations (divides the ones column too,
        # matching the original behavior).
        flat = flat / spatial_size
        return flat.t() @ (flat / batch_size)

    @staticmethod
    def linear(a, layer):
        # a: batch_size * in_dim
        batch_size = a.size(0)
        if layer.bias is not None:
            a = torch.cat([a, a.new(a.size(0), 1).fill_(1)], 1)
        return a.t() @ (a / batch_size)
class ComputeCovG:
    """Covariance of a layer's output gradients (the "G" factor in K-FAC).

    Returns None for unsupported layer types so callers can skip them.
    """

    @classmethod
    def compute_cov_g(cls, g, layer, batch_averaged=False):
        """
        :param g: gradient
        :param layer: the corresponding layer
        :param batch_averaged: if the gradient is already averaged with the batch size?
        :return:
        """
        return cls.__call__(g, layer, batch_averaged)

    @classmethod
    def __call__(cls, g, layer, batch_averaged):
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(g, layer, batch_averaged)
        if isinstance(layer, nn.Linear):
            return cls.linear(g, layer, batch_averaged)
        return None

    @staticmethod
    def conv2d(g, layer, batch_averaged):
        # g: batch_size * n_filters * out_h * out_w
        # n_filters plays the role of the output dim of a Linear layer.
        spatial_size = g.size(2) * g.size(3)
        batch_size = g.size(0)
        flat = try_contiguous(g.transpose(1, 2).transpose(2, 3))
        flat = flat.view(-1, flat.size(-1))
        if batch_averaged:
            # Undo the mean-over-batch baked into the gradients.
            flat = flat * batch_size
        flat = flat * spatial_size
        return flat.t() @ (flat / flat.size(0))

    @staticmethod
    def linear(g, layer, batch_averaged):
        # g: batch_size * out_dim
        batch_size = g.size(0)
        scaled = g * batch_size if batch_averaged else g / batch_size
        return g.t() @ scaled
if __name__ == '__main__':
    # Placeholder smoke tests — intentionally empty stubs; nothing is
    # executed when this module is run directly.

    def test_ComputeCovA():
        pass

    def test_ComputeCovG():
        pass