|
 
 import os
 import unittest
-
-from torch.nn.modules.container import T
 from parameterized import parameterized
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 
 
 from makani.utils import comm
 from makani.utils import functions as fn
-from physicsnemo.distributed.utils import split_tensor_along_dim
-from physicsnemo.distributed.mappings import gather_from_parallel_region, scatter_to_parallel_region, \
-    reduce_from_parallel_region
 
 from makani.mpu.mappings import init_gradient_reduction_hooks
 
 # layer norm imports
 from makani.models.common.layer_norm import GeometricInstanceNormS2
-from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2
+from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2, DistributedInstanceNorm2d
 
 from distributed_helpers import split_helper, gather_helper
 
@@ -246,6 +242,149 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
         self.assertTrue(err.item() <= tol)
 
 
+    @parameterized.expand(
+        [
+            [256, 512, 32, 8, 1e-5, True],
+            [181, 360, 1, 10, 1e-5, True],
+            [256, 512, 32, 8, 1e-5, False],
+            [181, 360, 1, 10, 1e-5, False],
+        ],
+        skip_on_empty=True,
+    )
+    def test_distributed_instance_norm_2d(self, nlat, nlon, batch_size, num_chan, tol, affine, verbose=True):
+        B, C, H, W = batch_size, num_chan, nlat, nlon
+
+        self._init_seed(333)
+
+        # create local (serial) instance norm - using PyTorch's standard InstanceNorm2d
+        norm_local = nn.InstanceNorm2d(
+            num_features=C,
+            eps=1e-5,
+            affine=affine,
+            track_running_stats=False,
+        ).to(self.device)
+
+        # create distributed instance norm
+        norm_dist = DistributedInstanceNorm2d(
+            num_features=C,
+            eps=1e-5,
+            affine=affine,
+        ).to(self.device)
+
+        # set up gradient reduction hooks for distributed version if affine=True
+        if affine:
+            norm_dist = init_gradient_reduction_hooks(
+                norm_dist,
+                device=self.device,
+                reduction_buffer_count=1,
+                broadcast_buffers=False,
+                find_unused_parameters=False,
+                gradient_as_bucket_view=True,
+                static_graph=True,
+                verbose=False,
+            )
+            norm_dist_handle = norm_dist.module
+        else:
+            norm_dist_handle = norm_dist
+
+        # make sure weights are the same if affine=True
+        if affine:
+            with torch.no_grad():
+                norm_dist_handle.weight.copy_(norm_local.weight)
+                norm_dist_handle.bias.copy_(norm_local.bias)
+
+        # input
+        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
+
+        #############################################################
+        # local (serial) transform
+        #############################################################
+        # FWD pass
+        inp_full.requires_grad = True
+        out_full = norm_local(inp_full)
+
+        # create grad for backward
+        with torch.no_grad():
+            # create full grad
+            ograd_full = torch.randn_like(out_full)
+
+        # BWD pass
+        out_full.backward(ograd_full)
+        igrad_full = inp_full.grad.clone()
+
+        if affine:
+            wgrad_full = norm_local.weight.grad.clone()
+            bgrad_full = norm_local.bias.grad.clone()
+
+        #############################################################
+        # distributed transform
+        #############################################################
+        # FWD pass
+        inp_local = self._split_helper(inp_full, hdim=-2, wdim=-1)
+        inp_local.requires_grad = True
+        out_local = norm_dist(inp_local)
+
+        # BWD pass
+        ograd_local = self._split_helper(ograd_full, hdim=-2, wdim=-1)
+        out_local.backward(ograd_local)
+        igrad_local = inp_local.grad.clone()
+
+        if affine:
+            wgrad_local = norm_dist_handle.weight.grad.clone()
+            bgrad_local = norm_dist_handle.bias.grad.clone()
+
+        #############################################################
+        # evaluate FWD pass
+        #############################################################
+        with torch.no_grad():
+            out_gather_full = self._gather_helper(out_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(out_gather_full, out_full)
+            if verbose and (self.world_rank == 0):
+                print(f"InstanceNorm2d forward relative error: {err.item()}")
+        self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate input grads
+        #############################################################
+        with torch.no_grad():
+            igrad_gather_full = self._gather_helper(igrad_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(igrad_gather_full, igrad_full)
+            if verbose and (self.world_rank == 0):
+                print(f"InstanceNorm2d input grad relative error: {err.item()}")
+        self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate weight and bias grads
+        #############################################################
+        # weight gradients should be the same across all processes
+        if affine:
+            with torch.no_grad():
+                wgrad_gather_list = [torch.empty_like(wgrad_local) for _ in range(self.world_size)]
+                wgrad_gather_list[self.world_rank] = wgrad_local
+                dist.all_gather(wgrad_gather_list, wgrad_local, group=None)
+                errs = []
+                for wgrad_gather_full in wgrad_gather_list:
+                    errs.append(fn.relative_error(wgrad_gather_full, wgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"InstanceNorm2d weight grad relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+        # bias gradients should be the same across all processes
+        if affine:
+            with torch.no_grad():
+                bgrad_gather_list = [torch.empty_like(bgrad_local) for _ in range(self.world_size)]
+                bgrad_gather_list[self.world_rank] = bgrad_local
+                dist.all_gather(bgrad_gather_list, bgrad_local, group=None)
+                errs = []
+                for bgrad_gather_full in bgrad_gather_list:
+                    errs.append(fn.relative_error(bgrad_gather_full, bgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"InstanceNorm2d bias grad relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+
     @parameterized.expand(
         [
             [181, 360, 1, 4, 1e-5, "equiangular", True],
@@ -303,7 +442,6 @@ def test_distributed_geometric_instance_norm_s2(self, nlat, nlon, batch_size, nu |
             static_graph=True,
             verbose=False,
         )
-        norm_dist_handle = norm_dist.module
 
         #make sure weights are the same if affine=True
         if affine: