 
 import os
 import unittest
 from parameterized import parameterized
 
 import torch
 
 from makani.mpu.mappings import init_gradient_reduction_hooks
 
+# layer norm imports
+from makani.models.common.layer_norm import GeometricInstanceNormS2
+from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2
+
 from distributed_helpers import split_helper, gather_helper
 
 class TestDistributedLayers(unittest.TestCase):
@@ -96,14 +102,17 @@ def _gather_helper(self, tensor, hdim=-2, wdim=-1):
         return tensor_gather
 
 
-    @parameterized.expand([
-        [256, 512, 256, 512, 32, 8, 1e-5],
-        [181, 360, 181, 360, 1, 10, 1e-5],
-        [256, 512, 128, 256, 32, 8, 1e-5],
-        [181, 360, 91, 180, 1, 10, 1e-5],
-        [128, 256, 256, 512, 32, 8, 1e-5],
-        [ 91, 180, 181, 360, 1, 10, 1e-5],
-    ])
+    @parameterized.expand(
+        [
+            [180, 360, 256, 512, 32, 8, 1e-5],
+            [181, 360, 181, 360, 1, 10, 1e-5],
+            [180, 360, 128, 256, 32, 8, 1e-5],
+            [181, 360, 91, 180, 1, 10, 1e-5],
+            [128, 256, 256, 512, 32, 8, 1e-5],
+            [ 91, 180, 181, 360, 1, 10, 1e-5],
+        ],
+        skip_on_empty=True,
+    )
     def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, tol, verbose=True):
         B, C, Hi, Wi, Ho, Wo = batch_size, num_chan, nlat_in, nlon_in, nlat_out, nlon_out
 
@@ -146,7 +155,7 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
             reduction_buffer_count=1,
             broadcast_buffers=False,
             find_unused_parameters=False,
-            gradient_as_bucket_view=True,
+            gradient_as_bucket_view=True,
             static_graph=True,
             verbose=False,
         )
@@ -158,7 +167,6 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
             spect_conv_dist.module.bias.copy_(spect_conv_local.bias)
 
         # input
-        self._init_seed(444)
         inp_full = torch.randn((B, C, Hi, Wi), dtype=torch.float32, device=self.device)
 
         #############################################################
@@ -169,7 +177,6 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
         out_full, _ = spect_conv_local(inp_full)
 
         # create grad for backward
-        self._init_seed(555)
         with torch.no_grad():
             # create full grad
             ograd_full = torch.randn_like(out_full)
@@ -237,6 +244,163 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b
         if verbose and (self.world_rank == 0):
             print(f"final relative error of bias gradients: {err.item()}")
         self.assertTrue(err.item() <= tol)
+
+
+    @parameterized.expand(
+        [
+            [181, 360, 1, 4, 1e-5, "equiangular", True],
+            [181, 360, 1, 4, 1e-5, "equiangular", False],
+            [180, 360, 1, 10, 1e-5, "legendre-gauss", True],
+            [180, 360, 1, 10, 1e-5, "legendre-gauss", False],
+        ],
+        skip_on_empty=True,
+    )
+    def test_distributed_geometric_instance_norm_s2(self, nlat, nlon, batch_size, num_chan, tol, grid_type, affine, verbose=True):
+        B, C, H, W = batch_size, num_chan, nlat, nlon
+
+        # set up layer norm parameters
+        img_shape = (H, W)
+        crop_shape = (H, W)
+        crop_offset = (0, 0)
+        pole_mask = 0
+        eps = 1e-5
+
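+        # seed all ranks identically so the serial and distributed models get matching parameters and inputs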
+        self._init_seed(333)
+
+        # create local (serial) layer norm
+        norm_local = GeometricInstanceNormS2(
+            img_shape=img_shape,
+            crop_shape=crop_shape,
+            crop_offset=crop_offset,
+            grid_type=grid_type,
+            pole_mask=pole_mask,
+            num_features=C,
+            eps=eps,
+            affine=affine,
+        ).to(self.device)
+
+        # create distributed layer norm
+        norm_dist = DistributedGeometricInstanceNormS2(
+            img_shape=img_shape,
+            crop_shape=crop_shape,
+            crop_offset=crop_offset,
+            grid_type=grid_type,
+            pole_mask=pole_mask,
+            num_features=C,
+            eps=eps,
+            affine=affine,
+        ).to(self.device)
+
+        # set up gradient reduction hooks for distributed version
+        if affine:
+            norm_dist = init_gradient_reduction_hooks(
+                norm_dist,
+                device=self.device,
+                reduction_buffer_count=1,
+                broadcast_buffers=False,
+                find_unused_parameters=False,
+                gradient_as_bucket_view=True,
+                static_graph=True,
+                verbose=False,
+            )
+            norm_dist_handle = norm_dist.module
+
+        # make sure weights are the same if affine=True
+        if affine:
+            with torch.no_grad():
+                norm_dist.module.weight.copy_(norm_local.weight)
+                norm_dist.module.bias.copy_(norm_local.bias)
+
+        # input
+        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
+
+        #############################################################
+        # local (serial) transform
+        #############################################################
+        # FWD pass
+        inp_full.requires_grad = True
+        out_full = norm_local(inp_full)
+
+        # create grad for backward
+        with torch.no_grad():
+            # create full grad
+            ograd_full = torch.randn_like(out_full)
+
+        # BWD pass
+        out_full.backward(ograd_full)
+        igrad_full = inp_full.grad.clone()
+
+        if affine:
+            wgrad_full = norm_local.weight.grad.clone()
+            bgrad_full = norm_local.bias.grad.clone()
+
+        #############################################################
+        # distributed transform
+        #############################################################
+        # FWD pass
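+        # shard the full input over the spatial (H, W) dimensions for this rank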
+        inp_local = self._split_helper(inp_full, hdim=-2, wdim=-1)
+        inp_local.requires_grad = True
+        out_local = norm_dist(inp_local)
+
+        # BWD pass
+        ograd_local = self._split_helper(ograd_full, hdim=-2, wdim=-1)
+        out_local.backward(ograd_local)
+        igrad_local = inp_local.grad.clone()
+
+        if affine:
+            wgrad_local = norm_dist.module.weight.grad.clone()
+            bgrad_local = norm_dist.module.bias.grad.clone()
+
+        #############################################################
+        # evaluate FWD pass
+        #############################################################
+        with torch.no_grad():
+            out_gather_full = self._gather_helper(out_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(out_gather_full, out_full)
+            if verbose and (self.world_rank == 0):
+                print(f"GeometricInstanceNormS2 forward relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate input grads
+        #############################################################
+        with torch.no_grad():
+            igrad_gather_full = self._gather_helper(igrad_local, hdim=-2, wdim=-1)
+            err = fn.relative_error(igrad_gather_full, igrad_full)
+            if verbose and (self.world_rank == 0):
+                print(f"GeometricInstanceNormS2 input grad relative error: {err.item()}")
+            self.assertTrue(err.item() <= tol)
+
+        #############################################################
+        # evaluate weight and bias grads
+        #############################################################
+        # weight gradients should be the same across all processes
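+        # gather each rank's copy of the weight gradient and compare it against the serial reference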
+        if affine:
+            with torch.no_grad():
+                wgrad_gather_list = [torch.empty_like(wgrad_local) for _ in range(self.world_size)]
+                wgrad_gather_list[self.world_rank] = wgrad_local
+                dist.all_gather(wgrad_gather_list, wgrad_local, group=None)
+                errs = []
+                for wgrad_gather_full in wgrad_gather_list:
+                    errs.append(fn.relative_error(wgrad_gather_full, wgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"GeometricInstanceNormS2 weight grad relative error: {err.item()}")
+                self.assertTrue(err.item() <= tol)
+
+        # bias gradients should be the same across all processes
+        if affine:
+            with torch.no_grad():
+                bgrad_gather_list = [torch.empty_like(bgrad_local) for _ in range(self.world_size)]
+                bgrad_gather_list[self.world_rank] = bgrad_local
+                dist.all_gather(bgrad_gather_list, bgrad_local, group=None)
+                errs = []
+                for bgrad_gather_full in bgrad_gather_list:
+                    errs.append(fn.relative_error(bgrad_gather_full, bgrad_full))
+                err = torch.mean(torch.stack(errs, dim=0))
+                if verbose and (self.world_rank == 0):
+                    print(f"GeometricInstanceNormS2 bias grad relative error: {err.item()}")
+                self.assertTrue(err.item() <= tol)
 
 
 if __name__ == '__main__':