Skip to content

Commit b179910

Browse files
authored
fixing distributed instance norm (#33)
* fixing distributed instance norm * adding test for Euclidean distributed instance norm * cleaning up imports
1 parent 1302047 commit b179910

File tree

4 files changed

+351
-43
lines changed

4 files changed

+351
-43
lines changed

makani/models/common/layer_norm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,23 @@ def __init__(
5858

5959
# we only need the weights
6060
self.quadrature = GridQuadrature(
61-
quadrature_rule, img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, normalize=True, pole_mask=pole_mask, distributed=False
61+
quadrature_rule,
62+
img_shape=img_shape,
63+
crop_shape=crop_shape,
64+
crop_offset=crop_offset,
65+
normalize=True,
66+
pole_mask=pole_mask,
67+
distributed=False
6268
)
6369

6470
def forward(self, x: torch.Tensor) -> torch.Tensor:
6571

6672
# extract shapes
6773
B, C, H, W = x.shape
6874

75+
xtype = x.dtype
6976
with amp.autocast(device_type="cuda", enabled=False):
70-
dtype = x.dtype
71-
x = x.float()
77+
x = x.to(torch.float32)
7278

7379
# compute var and mean
7480
mean = self.quadrature(x)
@@ -79,9 +85,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7985
mean = mean.reshape(B, C, 1, 1)
8086

8187
# convert types
82-
x = x.to(dtype)
83-
mean = mean.to(dtype)
84-
var = var.to(dtype)
88+
x = x.to(xtype)
89+
mean = mean.to(xtype)
90+
var = var.to(xtype)
8591

8692
# apply the normalization
8793
if self.affine:

makani/mpu/layer_norm.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,8 @@ def _welford_kernel(vars: torch.Tensor, means: torch.Tensor, counts: torch.Tenso
6666
# use Welford's algorithm to accumulate them into a single mean and variance
6767
for i in range(1, means.shape[0]):
6868
delta = means[i, ...] - mean
69+
mean = mean + delta * counts[i, ...] / (count + counts[i, ...])
6970
m2 = m2 + m2s[i, ...] + delta**2 * count * counts[i, ...] / (count + counts[i, ...])
70-
if i == 1:
71-
mean = (mean * count + means[i, ...] * counts[i, ...]) / (count + counts[i, ...])
72-
else:
73-
mean = mean + delta * counts[i, ...] / (count + counts[i, ...])
7471

7572
# update the current count
7673
count = count + counts[i, ...]
@@ -122,7 +119,7 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
122119
"""Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
123120

124121
# extract shapes
125-
B, C, H, W = x.shape
122+
B, C, _, _ = x.shape
126123

127124
# those have the shapes [B, C]
128125
var, mean = torch.var_mean(x, dim=(-2, -1), unbiased=False, keepdim=False)
@@ -141,9 +138,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
141138

142139
def forward(self, x: torch.Tensor) -> torch.Tensor:
143140

141+
xtype = x.dtype
144142
with amp.autocast(device_type="cuda", enabled=False):
145-
dtype = x.dtype
146-
x = x.float()
143+
x = x.to(torch.float32)
147144

148145
# start by computing std and mean
149146
var, mean = self._stats_welford(x)
@@ -152,9 +149,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
152149
mean = copy_to_parallel_region(mean, "spatial")
153150
var = copy_to_parallel_region(var, "spatial")
154151

155-
x = x.to(dtype)
156-
mean = mean.to(dtype)
157-
var = var.to(dtype)
152+
x = x.to(xtype)
153+
mean = mean.to(xtype)
154+
var = var.to(xtype)
158155

159156
# apply the normalization
160157
if self.affine:
@@ -188,7 +185,13 @@ def __init__(
188185

189186
# we only need the weights
190187
quad_weight = GridQuadrature(
191-
quadrature_rule, img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, normalize=True, pole_mask=pole_mask, distributed=True
188+
quadrature_rule,
189+
img_shape=img_shape,
190+
crop_shape=crop_shape,
191+
crop_offset=crop_offset,
192+
normalize=True,
193+
pole_mask=pole_mask,
194+
distributed=True
192195
).quad_weight
193196

194197
self.register_buffer("quad_weight", quad_weight, persistent=False)
@@ -197,12 +200,12 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
197200
"""Computes the statistics locally, then uses the Welford online algorithm to reduce them"""
198201

199202
# extract shapes
200-
B, C, H, W = x.shape
203+
B, C, _, _ = x.shape
201204

202205
# compute var, mean locally: those have the shapes [B, C]
203-
mean = torch.sum(x * self.quad_weight, dim=(-2, -1), keepdim=False)
204-
var = torch.sum(torch.square(x - mean.reshape(B, C, 1, 1)) * self.quad_weight, dim=(-2, -1), keepdim=False)
205206
count = torch.tile(torch.sum(self.quad_weight, dim=(-2, -1), keepdim=False), (B, C))
207+
mean = torch.sum(x * self.quad_weight, dim=(-2, -1), keepdim=False) / count
208+
var = torch.sum(torch.square(x - mean.reshape(B, C, 1, 1)) * self.quad_weight, dim=(-2, -1), keepdim=False) / count
206209

207210
# compute welford variance
208211
var, mean, _ = distributed_welford_variance(var, mean, count, "spatial")
@@ -215,9 +218,9 @@ def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
215218

216219
def forward(self, x: torch.Tensor) -> torch.Tensor:
217220

221+
xtype = x.dtype
218222
with amp.autocast(device_type="cuda", enabled=False):
219-
dtype = x.dtype
220-
x = x.float()
223+
x = x.to(torch.float32)
221224

222225
# start by computing std and mean
223226
var, mean = self._stats_welford(x)
@@ -226,9 +229,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
226229
mean = copy_to_parallel_region(mean, "spatial")
227230
var = copy_to_parallel_region(var, "spatial")
228231

229-
x = x.to(dtype)
230-
mean = mean.to(dtype)
231-
var = var.to(dtype)
232+
x = x.to(xtype)
233+
mean = mean.to(xtype)
234+
var = var.to(xtype)
232235

233236
# apply the normalization
234237
if self.affine:

makani/utils/metrics/functions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Optional, Tuple, List
17-
from dataclasses import dataclass
16+
from typing import Optional, Tuple
1817

19-
import math
2018
import torch
2119

22-
from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature
2320
from makani.utils import comm
2421
from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region
2522
from physicsnemo.distributed.utils import split_tensor_along_dim

0 commit comments

Comments
 (0)