Commit 6dfa465

adding test for euclidean distributed instance norm
1 parent 5320369 commit 6dfa465

File tree

1 file changed: +145 -7 lines

tests/distributed/tests_distributed_layers.py

Lines changed: 145 additions & 7 deletions
@@ -16,11 +16,10 @@
 
 import os
 import unittest
-
-from torch.nn.modules.container import T
 from parameterized import parameterized
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 
@@ -29,15 +28,12 @@
 
 from makani.utils import comm
 from makani.utils import functions as fn
-from physicsnemo.distributed.utils import split_tensor_along_dim
-from physicsnemo.distributed.mappings import gather_from_parallel_region, scatter_to_parallel_region, \
-    reduce_from_parallel_region
 
 from makani.mpu.mappings import init_gradient_reduction_hooks
 
 # layer norm imports
 from makani.models.common.layer_norm import GeometricInstanceNormS2
-from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2
+from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2, DistributedInstanceNorm2d
 
 from distributed_helpers import split_helper, gather_helper
 
@@ -246,6 +242,149 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
         self.assertTrue(err.item() <= tol)
 
 
+    @parameterized.expand(
+        [
+            [256, 512, 32, 8, 1e-5, True],
+            [181, 360, 1, 10, 1e-5, True],
+            [256, 512, 32, 8, 1e-5, False],
+            [181, 360, 1, 10, 1e-5, False],
+        ],
+        skip_on_empty=True,
+    )
+    def test_distributed_instance_norm_2d(self, nlat, nlon, batch_size, num_chan, tol, affine, verbose=True):
+        B, C, H, W = batch_size, num_chan, nlat, nlon
+
+        self._init_seed(333)
+
+        # create local (serial) instance norm - using PyTorch's standard InstanceNorm2d
+        norm_local = nn.InstanceNorm2d(
+            num_features=C,
+            eps=1e-5,
+            affine=affine,
+            track_running_stats=False,
+        ).to(self.device)
+
+        # create distributed instance norm
+        norm_dist = DistributedInstanceNorm2d(
+            num_features=C,
+            eps=1e-5,
+            affine=affine,
+        ).to(self.device)
+
+        # set up gradient reduction hooks for distributed version if affine=True
+        if affine:
+            norm_dist = init_gradient_reduction_hooks(
+                norm_dist,
+                device=self.device,
+                reduction_buffer_count=1,
+                broadcast_buffers=False,
+                find_unused_parameters=False,
+                gradient_as_bucket_view=True,
+                static_graph=True,
+                verbose=False,
+            )
+            norm_dist_handle = norm_dist.module
+        else:
+            norm_dist_handle = norm_dist
+
+        # make sure weights are the same if affine=True
+        if affine:
+            with torch.no_grad():
+                norm_dist_handle.weight.copy_(norm_local.weight)
+                norm_dist_handle.bias.copy_(norm_local.bias)
+
+        # input
+        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
+
+        #############################################################
+        # local (serial) transform
+        #############################################################
+        # FWD pass
+        inp_full.requires_grad = True
+        out_full = norm_local(inp_full)
+
+        # create grad for backward
+        with torch.no_grad():
+            # create full grad
+            ograd_full = torch.randn_like(out_full)
+
+        # BWD pass
+        out_full.backward(ograd_full)
+        igrad_full = inp_full.grad.clone()
+
+        if affine:
+            wgrad_full = norm_local.weight.grad.clone()
+            bgrad_full = norm_local.bias.grad.clone()
+
+        #############################################################
+        # distributed transform
+        #############################################################
+        # FWD pass
+        inp_local = self._split_helper(inp_full, hdim=-2, wdim=-1)
+        inp_local.requires_grad = True
+        out_local = norm_dist(inp_local)
+
+        # BWD pass
+        ograd_local = self._split_helper(ograd_full, hdim=-2, wdim=-1)
+        out_local.backward(ograd_local)
+        igrad_local = inp_local.grad.clone()
+
+        if affine:
+            wgrad_local = norm_dist_handle.weight.grad.clone()
+            bgrad_local = norm_dist_handle.bias.grad.clone()
+
+        #############################################################
+        # evaluate FWD pass
+        #############################################################
+        with torch.no_grad():
+            out_gather_full = self._gather_helper(out_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(out_gather_full, out_full)
+            if verbose and (self.world_rank == 0):
+                print(f"InstanceNorm2d forward relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate input grads
+        #############################################################
+        with torch.no_grad():
+            igrad_gather_full = self._gather_helper(igrad_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(igrad_gather_full, igrad_full)
+            if verbose and (self.world_rank == 0):
+                print(f"InstanceNorm2d input grad relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate weight and bias grads
+        #############################################################
+        # weight gradients should be the same across all processes
+        if affine:
+            with torch.no_grad():
+                wgrad_gather_list = [torch.empty_like(wgrad_local) for _ in range(self.world_size)]
+                wgrad_gather_list[self.world_rank] = wgrad_local
+                dist.all_gather(wgrad_gather_list, wgrad_local, group=None)
+                errs = []
+                for wgrad_gather_full in wgrad_gather_list:
+                    errs.append(fn.relative_error(wgrad_gather_full, wgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"InstanceNorm2d weight grad relative error: {err.item()}")
+                self.assertTrue(err.item() <= tol)
+
+        # bias gradients should be the same across all processes
+        if affine:
+            with torch.no_grad():
+                bgrad_gather_list = [torch.empty_like(bgrad_local) for _ in range(self.world_size)]
+                bgrad_gather_list[self.world_rank] = bgrad_local
+                dist.all_gather(bgrad_gather_list, bgrad_local, group=None)
+                errs = []
+                for bgrad_gather_full in bgrad_gather_list:
+                    errs.append(fn.relative_error(bgrad_gather_full, bgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"InstanceNorm2d bias grad relative error: {err.item()}")
+                self.assertTrue(err.item() <= tol)
+
+
     @parameterized.expand(
         [
             [181, 360, 1, 4, 1e-5, "equiangular", True],
@@ -303,7 +442,6 @@ def test_distributed_geometric_instance_norm_s2(self, nlat, nlon, batch_size, nu
                 static_graph=True,
                 verbose=False,
             )
-            norm_dist_handle = norm_dist.module
 
         #make sure weights are the same if affine=True
         if affine:
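
The new test asserts that DistributedInstanceNorm2d applied to a spatially sharded input reproduces nn.InstanceNorm2d applied to the full tensor. As a rough, illustrative sketch of the underlying idea (not makani's implementation; the function name and n_spatial_total parameter below are made up for illustration), per-channel instance statistics over a sharded spatial domain can be formed by all-reducing partial sums across ranks. The sketch assumes an already-initialized default process group and covers the forward statistics only; the backward checks in the test additionally require gradients to flow correctly through the distributed reduction, which the layer under test is expected to handle.

# Illustrative sketch only -- not makani's DistributedInstanceNorm2d.
# Assumes torch.distributed is initialized and every rank holds a
# (B, C, H_local, W_local) shard of the same (B, C, H, W) tensor.
import torch
import torch.distributed as dist

def sharded_instance_norm_fwd(x_local: torch.Tensor, n_spatial_total: int, eps: float = 1e-5) -> torch.Tensor:
    # per-channel partial sums over the local spatial shard
    s = x_local.sum(dim=(-2, -1), keepdim=True)          # (B, C, 1, 1)
    sq = x_local.pow(2).sum(dim=(-2, -1), keepdim=True)  # (B, C, 1, 1)
    # combine shards: global sums over the full H x W domain
    dist.all_reduce(s)   # defaults to SUM over the default (WORLD) group
    dist.all_reduce(sq)
    mean = s / n_spatial_total
    var = sq / n_spatial_total - mean.pow(2)              # biased variance, matching nn.InstanceNorm2d
    # normalize the local shard with the global statistics
    return (x_local - mean) / torch.sqrt(var + eps)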
