Skip to content

Commit ee3e55a

Browse files
committed
fixing welford combine for spectra calculation
1 parent 2577771 commit ee3e55a

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

data_process/data_process_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,13 @@ def welford_combine(stats1, stats2):
132132
if (s_b["counts"].ndim != 0) and (s_a["counts"].ndim != s_a["values"].ndim):
133133
n_a = s_a["counts"][None, :, None, None]
134134
n_b = s_b["counts"][None, :, None, None]
135+
reshape = True
135136
else:
136137
n_a = s_a["counts"]
137138
n_b = s_b["counts"]
139+
reshape = False
140+
141+
# combined counts
138142
n_ab = n_a + n_b
139143

140144
if s_a["type"] == "min":
@@ -157,7 +161,10 @@ def welford_combine(stats1, stats2):
157161
], dim=0
158162
).contiguous()
159163

160-
stats[k] = {"counts": n_ab.reshape(-1),
164+
if reshape:
165+
n_ab = n_ab.reshape(-1)
166+
167+
stats[k] = {"counts": n_ab,
161168
"type": s_a["type"],
162169
"values": values}
163170

data_process/get_spectra.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@
3636
from wb2_helpers import DistributedProgressBar
3737
from data_process_helpers import welford_combine, collective_reduce, binary_reduce
3838

39+
@torch.compile(fullgraph=True)
3940
def compute_powerspectrum(x, sht):
40-
coeffs = sht(x).abs().pow(2)
41+
coeffs = torch.square(torch.abs(sht(x)))
4142
coeffs[..., 1:] *= 2.0
42-
power_spectrum = coeffs.sum(dim=-1)
43+
power_spectrum = torch.sum(coeffs, dim=-1)
4344

4445
return power_spectrum
4546

0 commit comments

Comments
 (0)