Skip to content

Commit 5320369

Browse files
committed
fixing distributed instance norm
1 parent 1302047 commit 5320369

File tree

3 files changed

+209
-36
lines changed

3 files changed

+209
-36
lines changed

makani/models/common/layer_norm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,23 @@ def __init__(
5858

5959
# we only need the weights
6060
self.quadrature = GridQuadrature(
61-
quadrature_rule, img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, normalize=True, pole_mask=pole_mask, distributed=False
61+
quadrature_rule,
62+
img_shape=img_shape,
63+
crop_shape=crop_shape,
64+
crop_offset=crop_offset,
65+
normalize=True,
66+
pole_mask=pole_mask,
67+
distributed=False
6268
)
6369

6470
def forward(self, x: torch.Tensor) -> torch.Tensor:
6571

6672
# extract shapes
6773
B, C, H, W = x.shape
6874

75+
xtype = x.dtype
6976
with amp.autocast(device_type="cuda", enabled=False):
70-
dtype = x.dtype
71-
x = x.float()
77+
x = x.to(torch.float32)
7278

7379
# compute var and mean
7480
mean = self.quadrature(x)
@@ -79,9 +85,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7985
mean = mean.reshape(B, C, 1, 1)
8086

8187
# convert types
82-
x = x.to(dtype)
83-
mean = mean.to(dtype)
84-
var = var.to(dtype)
88+
x = x.to(xtype)
89+
mean = mean.to(xtype)
90+
var = var.to(xtype)
8591

8692
# apply the normalization
8793
if self.affine:

makani/mpu/layer_norm.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,8 @@ def _welford_kernel(vars: torch.Tensor, means: torch.Tensor, counts: torch.Tenso
6666
# use Welford's algorithm to accumulate them into a single mean and variance
6767
for i in range(1, means.shape[0]):
6868
delta = means[i, ...] - mean
69+
mean = mean + delta * counts[i, ...] / (count + counts[i, ...])
6970
m2 = m2 + m2s[i, ...] + delta**2 * count * counts[i, ...] / (count + counts[i, ...])
70-
if i == 1:
71-
mean = (mean * count + means[i, ...] * counts[i, ...]) / (count + counts[i, ...])
72-
else:
73-
mean = mean + delta * counts[i, ...] / (count + counts[i, ...])
7471

7572
# update the current count
7673
count = count + counts[i, ...]
@@ -122,7 +119,7 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
122119
"""Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
123120

124121
# extract shapes
125-
B, C, H, W = x.shape
122+
B, C, _, _ = x.shape
126123

127124
# those have the shapes [B, C]
128125
var, mean = torch.var_mean(x, dim=(-2, -1), unbiased=False, keepdim=False)
@@ -141,9 +138,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
141138

142139
def forward(self, x: torch.Tensor) -> torch.Tensor:
143140

141+
xtype = x.dtype
144142
with amp.autocast(device_type="cuda", enabled=False):
145-
dtype = x.dtype
146-
x = x.float()
143+
x = x.to(torch.float32)
147144

148145
# start by computing std and mean
149146
var, mean = self._stats_welford(x)
@@ -152,9 +149,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
152149
mean = copy_to_parallel_region(mean, "spatial")
153150
var = copy_to_parallel_region(var, "spatial")
154151

155-
x = x.to(dtype)
156-
mean = mean.to(dtype)
157-
var = var.to(dtype)
152+
x = x.to(xtype)
153+
mean = mean.to(xtype)
154+
var = var.to(xtype)
158155

159156
# apply the normalization
160157
if self.affine:
@@ -188,7 +185,13 @@ def __init__(
188185

189186
# we only need the weights
190187
quad_weight = GridQuadrature(
191-
quadrature_rule, img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, normalize=True, pole_mask=pole_mask, distributed=True
188+
quadrature_rule,
189+
img_shape=img_shape,
190+
crop_shape=crop_shape,
191+
crop_offset=crop_offset,
192+
normalize=True,
193+
pole_mask=pole_mask,
194+
distributed=True
192195
).quad_weight
193196

194197
self.register_buffer("quad_weight", quad_weight, persistent=False)
@@ -197,12 +200,12 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
197200
"""Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
198201

199202
# extract shapes
200-
B, C, H, W = x.shape
203+
B, C, _, _ = x.shape
201204

202205
# compute var, mean locally: those have the shapes [B, C]
203-
mean = torch.sum(x * self.quad_weight, dim=(-2, -1), keepdim=False)
204-
var = torch.sum(torch.square(x - mean.reshape(B, C, 1, 1)) * self.quad_weight, dim=(-2, -1), keepdim=False)
205206
count = torch.tile(torch.sum(self.quad_weight, dim=(-2, -1), keepdim=False), (B, C))
207+
mean = torch.sum(x * self.quad_weight, dim=(-2, -1), keepdim=False) / count
208+
var = torch.sum(torch.square(x - mean.reshape(B, C, 1, 1)) * self.quad_weight, dim=(-2, -1), keepdim=False) / count
206209

207210
# compute welford variance
208211
var, mean, _ = distributed_welford_variance(var, mean, count, "spatial")
@@ -215,9 +218,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
215218

216219
def forward(self, x: torch.Tensor) -> torch.Tensor:
217220

221+
xtype = x.dtype
218222
with amp.autocast(device_type="cuda", enabled=False):
219-
dtype = x.dtype
220-
x = x.float()
223+
x = x.to(torch.float32)
221224

222225
# start by computing std and mean
223226
var, mean = self._stats_welford(x)
@@ -226,9 +229,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
226229
mean = copy_to_parallel_region(mean, "spatial")
227230
var = copy_to_parallel_region(var, "spatial")
228231

229-
x = x.to(dtype)
230-
mean = mean.to(dtype)
231-
var = var.to(dtype)
232+
x = x.to(xtype)
233+
mean = mean.to(xtype)
234+
var = var.to(xtype)
232235

233236
# apply the normalization
234237
if self.affine:

tests/distributed/tests_distributed_layers.py

Lines changed: 175 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import os
1818
import unittest
19+
20+
from torch.nn.modules.container import T
1921
from parameterized import parameterized
2022

2123
import torch
@@ -33,6 +35,10 @@
3335

3436
from makani.mpu.mappings import init_gradient_reduction_hooks
3537

38+
# layer norm imports
39+
from makani.models.common.layer_norm import GeometricInstanceNormS2
40+
from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2
41+
3642
from distributed_helpers import split_helper, gather_helper
3743

3844
class TestDistributedLayers(unittest.TestCase):
@@ -96,14 +102,17 @@ def _gather_helper(self, tensor, hdim=-2, wdim=-1):
96102
return tensor_gather
97103

98104

99-
@parameterized.expand([
100-
[256, 512, 256, 512, 32, 8, 1e-5],
101-
[181, 360, 181, 360, 1, 10, 1e-5],
102-
[256, 512, 128, 256, 32, 8, 1e-5],
103-
[181, 360, 91, 180, 1, 10, 1e-5],
104-
[128, 256, 256, 512, 32, 8, 1e-5],
105-
[ 91, 180, 181, 360, 1, 10, 1e-5],
106-
])
105+
@parameterized.expand(
106+
[
107+
[180, 360, 256, 512, 32, 8, 1e-5],
108+
[181, 360, 181, 360, 1, 10, 1e-5],
109+
[180, 360, 128, 256, 32, 8, 1e-5],
110+
[181, 360, 91, 180, 1, 10, 1e-5],
111+
[128, 256, 256, 512, 32, 8, 1e-5],
112+
[ 91, 180, 181, 360, 1, 10, 1e-5],
113+
],
114+
skip_on_empty=True,
115+
)
107116
def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, tol, verbose=True):
108117
B, C, Hi, Wi, Ho, Wo = batch_size, num_chan, nlat_in, nlon_in, nlat_out, nlon_out
109118

@@ -146,7 +155,7 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
146155
reduction_buffer_count=1,
147156
broadcast_buffers=False,
148157
find_unused_parameters=False,
149-
gradient_as_bucket_view=True,
158+
gradient_as_bucket_view=True,
150159
static_graph=True,
151160
verbose=False,
152161
)
@@ -158,7 +167,6 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
158167
spect_conv_dist.module.bias.copy_(spect_conv_local.bias)
159168

160169
# input
161-
self._init_seed(444)
162170
inp_full = torch.randn((B, C, Hi, Wi), dtype=torch.float32, device=self.device)
163171

164172
#############################################################
@@ -169,7 +177,6 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
169177
out_full, _ = spect_conv_local(inp_full)
170178

171179
# create grad for backward
172-
self._init_seed(555)
173180
with torch.no_grad():
174181
# create full grad
175182
ograd_full = torch.randn_like(out_full)
@@ -237,6 +244,163 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
237244
if verbose and (self.world_rank == 0):
238245
print(f"final relative error of bias gradients: {err.item()}")
239246
self.assertTrue(err.item() <= tol)
247+
248+
249+
@parameterized.expand(
    [
        [181, 360, 1, 4, 1e-5, "equiangular", True],
        [181, 360, 1, 4, 1e-5, "equiangular", False],
        [180, 360, 1, 10, 1e-5, "legendre-gauss", True],
        [180, 360, 1, 10, 1e-5, "legendre-gauss", False],
    ],
    skip_on_empty=True,
)
def test_distributed_geometric_instance_norm_s2(self, nlat, nlon, batch_size, num_chan, tol, grid_type, affine, verbose=True):
    """Compare DistributedGeometricInstanceNormS2 against the serial GeometricInstanceNormS2.

    Both layers are built with identical settings and (if ``affine``) identical
    weight/bias. The serial layer runs on the full input; the distributed layer
    runs on the slice returned by ``self._split_helper``. Outputs and input
    gradients are gathered with ``self._gather_helper`` and compared against the
    serial results; weight and bias gradients are all-gathered and compared on
    every rank.

    Args:
        nlat, nlon: spatial size (H, W) of the input; also used as image and crop shape.
        batch_size: batch dimension B of the random input.
        num_chan: channel dimension C, passed to the layers as ``num_features``.
        tol: upper bound on the relative error of every comparison.
        grid_type: quadrature grid name forwarded to both layers.
        affine: whether the layers carry learnable weight and bias.
        verbose: if True, rank 0 prints each relative error.
    """
    B, C, H, W = batch_size, num_chan, nlat, nlon

    # set up layer norm parameters: no cropping, no pole masking
    norm_kwargs = dict(
        img_shape=(H, W),
        crop_shape=(H, W),
        crop_offset=(0, 0),
        grid_type=grid_type,
        pole_mask=0,
        num_features=C,
        eps=1e-5,
        affine=affine,
    )

    self._init_seed(333)

    # create local (serial) layer norm
    norm_local = GeometricInstanceNormS2(**norm_kwargs).to(self.device)

    # create distributed layer norm
    norm_dist = DistributedGeometricInstanceNormS2(**norm_kwargs).to(self.device)

    # handle to the bare distributed module; only the affine variant gets wrapped
    # for gradient reduction, and the wrapper exposes the module as `.module`
    norm_dist_handle = norm_dist
    if affine:
        norm_dist = init_gradient_reduction_hooks(
            norm_dist,
            device=self.device,
            reduction_buffer_count=1,
            broadcast_buffers=False,
            find_unused_parameters=False,
            gradient_as_bucket_view=True,
            static_graph=True,
            verbose=False,
        )
        norm_dist_handle = norm_dist.module

        # make sure weights are the same if affine=True
        with torch.no_grad():
            norm_dist_handle.weight.copy_(norm_local.weight)
            norm_dist_handle.bias.copy_(norm_local.bias)

    # input
    inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

    #############################################################
    # local (serial) transform
    #############################################################
    # FWD pass
    inp_full.requires_grad = True
    out_full = norm_local(inp_full)

    # create grad for backward
    with torch.no_grad():
        # create full grad
        ograd_full = torch.randn_like(out_full)

    # BWD pass
    out_full.backward(ograd_full)
    igrad_full = inp_full.grad.clone()

    if affine:
        wgrad_full = norm_local.weight.grad.clone()
        bgrad_full = norm_local.bias.grad.clone()

    #############################################################
    # distributed transform
    #############################################################
    # FWD pass
    inp_local = self._split_helper(inp_full, hdim=-2, wdim=-1)
    inp_local.requires_grad = True
    out_local = norm_dist(inp_local)

    # BWD pass
    ograd_local = self._split_helper(ograd_full, hdim=-2, wdim=-1)
    out_local.backward(ograd_local)
    igrad_local = inp_local.grad.clone()

    if affine:
        wgrad_local = norm_dist_handle.weight.grad.clone()
        bgrad_local = norm_dist_handle.bias.grad.clone()

    #############################################################
    # evaluate FWD pass
    #############################################################
    with torch.no_grad():
        out_gather_full = self._gather_helper(out_local, hdim=-2, wdim=-1)
        err = fn.relative_error(out_gather_full, out_full)
    if verbose and (self.world_rank == 0):
        print(f"GeometricInstanceNormS2 forward relative error: {err.item()}")
    self.assertLessEqual(err.item(), tol)

    #############################################################
    # evaluate input grads
    #############################################################
    with torch.no_grad():
        igrad_gather_full = self._gather_helper(igrad_local, hdim=-2, wdim=-1)
        err = fn.relative_error(igrad_gather_full, igrad_full)
    if verbose and (self.world_rank == 0):
        print(f"GeometricInstanceNormS2 input grad relative error: {err.item()}")
    self.assertLessEqual(err.item(), tol)

    #############################################################
    # evaluate weight and bias grads
    #############################################################
    # weight and bias gradients should be the same across all processes
    if affine:
        param_grads = (
            ("weight", wgrad_local, wgrad_full),
            ("bias", bgrad_local, bgrad_full),
        )
        for pname, grad_local, grad_full in param_grads:
            with torch.no_grad():
                # all_gather fills every slot, including this rank's own, so the
                # output buffers are kept separate from the input tensor
                grad_gather_list = [torch.empty_like(grad_local) for _ in range(self.world_size)]
                dist.all_gather(grad_gather_list, grad_local, group=None)
                errs = [fn.relative_error(grad_gathered, grad_full) for grad_gathered in grad_gather_list]
                err = torch.mean(torch.stack(errs, dim=0))
            if verbose and (self.world_rank == 0):
                print(f"GeometricInstanceNormS2 {pname} grad relative error: {err.item()}")
            self.assertLessEqual(err.item(), tol)
240404

241405

242406
if __name__ == '__main__':

0 commit comments

Comments
 (0)