Skip to content

Commit aa10918

Browse files
allow theta=pi in s2_distribute (#7)
1 parent 6c69f89 commit aa10918

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

mpipartition/spherical_partition/s2_distribute.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def s2_distribute(
6060

6161
# verify data is normalized
6262
assert np.all(data[theta_key] >= 0)
63-
assert np.all(data[theta_key] < np.pi)
63+
assert np.all(data[theta_key] <= np.pi)
6464
assert np.all(data[phi_key] >= 0)
6565
assert np.all(data[phi_key] < 2 * np.pi)
6666

@@ -76,6 +76,7 @@ def s2_distribute(
7676
(data[theta_key] - partition.theta_cap) // partition.ring_dtheta
7777
).astype(np.int32) + 1
7878
ring_idx = np.clip(ring_idx, 0, len(partition.ring_segments) + 1)
79+
ring_idx[data[theta_key] == np.pi] -= 1 # handle cases where theta == pi
7980

8081
phi_idx = np.zeros_like(ring_idx, dtype=np.int32)
8182
mask_is_on_ring = (ring_idx > 0) & (ring_idx <= len(partition.ring_segments))
@@ -156,7 +157,11 @@ def s2_distribute(
156157

157158
if validate_home:
158159
assert np.all(data_new[theta_key] >= partition.theta_extent[0])
159-
assert np.all(data_new[theta_key] < partition.theta_extent[1])
160+
if partition.theta_extent[1] < np.pi:
161+
assert np.all(data_new[theta_key] < partition.theta_extent[1])
162+
else:
163+
# bottom cap, we allow theta == pi
164+
assert np.all(data_new[theta_key] <= partition.theta_extent[1])
160165
assert np.all(data_new[phi_key] >= partition.phi_extent[0])
161166
assert np.all(data_new[phi_key] < partition.phi_extent[1])
162167

0 commit comments

Comments
 (0)