@@ -66,11 +66,8 @@ def _welford_kernel(vars: torch.Tensor, means: torch.Tensor, counts: torch.Tenso
6666 # use Welford's algorithm to accumulate them into a single mean and variance
6767 for i in range (1 , means .shape [0 ]):
6868 delta = means [i , ...] - mean
69+ mean = mean + delta * counts [i , ...] / (count + counts [i , ...])
6970 m2 = m2 + m2s [i , ...] + delta ** 2 * count * counts [i , ...] / (count + counts [i , ...])
70- if i == 1 :
71- mean = (mean * count + means [i , ...] * counts [i , ...]) / (count + counts [i , ...])
72- else :
73- mean = mean + delta * counts [i , ...] / (count + counts [i , ...])
7471
7572 # update the current count
7673 count = count + counts [i , ...]
@@ -122,7 +119,7 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
122119 """Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
123120
124121 # extract shapes
125- B , C , H , W = x .shape
122+ B , C , _ , _ = x .shape
126123
127124 # those have the shapes [B, C]
128125 var , mean = torch .var_mean (x , dim = (- 2 , - 1 ), unbiased = False , keepdim = False )
@@ -141,9 +138,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
141138
142139 def forward (self , x : torch .Tensor ) -> torch .Tensor :
143140
141+ xtype = x .dtype
144142 with amp .autocast (device_type = "cuda" , enabled = False ):
145- dtype = x .dtype
146- x = x .float ()
143+ x = x .to (torch .float32 )
147144
148145 # start by computing var and mean
149146 var , mean = self ._stats_welford (x )
@@ -152,9 +149,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
152149 mean = copy_to_parallel_region (mean , "spatial" )
153150 var = copy_to_parallel_region (var , "spatial" )
154151
155- x = x .to (dtype )
156- mean = mean .to (dtype )
157- var = var .to (dtype )
152+ x = x .to (xtype )
153+ mean = mean .to (xtype )
154+ var = var .to (xtype )
158155
159156 # apply the normalization
160157 if self .affine :
@@ -188,7 +185,13 @@ def __init__(
188185
189186 # we only need the weights
190187 quad_weight = GridQuadrature (
191- quadrature_rule , img_shape = img_shape , crop_shape = crop_shape , crop_offset = crop_offset , normalize = True , pole_mask = pole_mask , distributed = True
188+ quadrature_rule ,
189+ img_shape = img_shape ,
190+ crop_shape = crop_shape ,
191+ crop_offset = crop_offset ,
192+ normalize = True ,
193+ pole_mask = pole_mask ,
194+ distributed = True
192195 ).quad_weight
193196
194197 self .register_buffer ("quad_weight" , quad_weight , persistent = False )
@@ -197,12 +200,12 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
197200 """Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
198201
199202 # extract shapes
200- B , C , H , W = x .shape
203+ B , C , _ , _ = x .shape
201204
202205 # compute var, mean locally: those have the shapes [B, C]
203- mean = torch .sum (x * self .quad_weight , dim = (- 2 , - 1 ), keepdim = False )
204- var = torch .sum (torch .square (x - mean .reshape (B , C , 1 , 1 )) * self .quad_weight , dim = (- 2 , - 1 ), keepdim = False )
205206 count = torch .tile (torch .sum (self .quad_weight , dim = (- 2 , - 1 ), keepdim = False ), (B , C ))
207+ mean = torch .sum (x * self .quad_weight , dim = (- 2 , - 1 ), keepdim = False ) / count
208+ var = torch .sum (torch .square (x - mean .reshape (B , C , 1 , 1 )) * self .quad_weight , dim = (- 2 , - 1 ), keepdim = False ) / count
206209
207210 # compute welford variance
208211 var , mean , _ = distributed_welford_variance (var , mean , count , "spatial" )
@@ -215,9 +218,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
215218
216219 def forward (self , x : torch .Tensor ) -> torch .Tensor :
217220
221+ xtype = x .dtype
218222 with amp .autocast (device_type = "cuda" , enabled = False ):
219- dtype = x .dtype
220- x = x .float ()
223+ x = x .to (torch .float32 )
221224
222225 # start by computing var and mean
223226 var , mean = self ._stats_welford (x )
@@ -226,9 +229,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
226229 mean = copy_to_parallel_region (mean , "spatial" )
227230 var = copy_to_parallel_region (var , "spatial" )
228231
229- x = x .to (dtype )
230- mean = mean .to (dtype )
231- var = var .to (dtype )
232+ x = x .to (xtype )
233+ mean = mean .to (xtype )
234+ var = var .to (xtype )
232235
233236 # apply the normalization
234237 if self .affine :
0 commit comments