Commit 6e8d699

Add typing and data-consistency checks when distributing/overloading arrays (#8)
* Add all2all_iterations parameter to distribute functions for improved data handling
* add type checking and linting to github actions
* github actions: update action versions
* add tests for multiple iterations
* assert consistent data across ranks in distribute
* assert consistent data for overload and exchange
1 parent 2893f00 commit 6e8d699

23 files changed (+394 / -294 lines)
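The headline change is the new all2all_iterations keyword on the distribution routines: instead of moving all particles in a single MPI_Alltoallv call, the exchange can be split into several smaller rounds, which helps when one collective would exceed message-size or memory limits. A minimal usage sketch follows; it is not part of this commit, and it assumes the usual distribute(partition, box_size, data, coord_keys) calling convention with illustrative toy data.

# Hedged sketch (not from this commit): distributing randomly placed particles
# with the new all2all_iterations keyword. Assumes the usual
# distribute(partition, box_size, data, coord_keys) calling convention.
# Run under MPI, e.g.: mpirun -n 4 python demo.py
import numpy as np
from mpipartition import Partition, distribute

partition = Partition()
rng = np.random.default_rng(seed=partition.rank)

n_local = 100_000
data = {
    "x": rng.random(n_local),
    "y": rng.random(n_local),
    "z": rng.random(n_local),
    "id": partition.rank * n_local + np.arange(n_local, dtype=np.int64),
}

# Split the exchange into 4 smaller all-to-all rounds; with the default
# all2all_iterations=1 the behavior matches the previous implementation.
data = distribute(partition, 1.0, data, ("x", "y", "z"), all2all_iterations=4)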

.flake8

Lines changed: 0 additions & 6 deletions
This file was deleted.

.github/workflows/pypi.yml

Lines changed: 4 additions & 4 deletions
@@ -14,15 +14,15 @@ jobs:
     steps:

       - name: check out
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
           lfs: true

       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
-          python-version: 3.8
+          python-version: 3.11

       - name: Install MPI
         uses: mpi4py/setup-mpi@v1
@@ -38,7 +38,7 @@ jobs:

       - name: Load cached venv
         id: cached-poetry-dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v4
         with:
           path: .venv
           key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}

.github/workflows/sphinx.yml

Lines changed: 4 additions & 4 deletions
@@ -13,15 +13,15 @@ jobs:
     steps:

       - name: check out
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
           lfs: true

       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
-          python-version: 3.8
+          python-version: 3.11

       - name: Install MPI
         uses: mpi4py/setup-mpi@v1
@@ -37,7 +37,7 @@ jobs:

       - name: Load cached venv
         id: cached-poetry-dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v4
         with:
           path: .venv
           key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}

.github/workflows/tests.yml

Lines changed: 10 additions & 4 deletions
@@ -13,7 +13,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.9"
+          python-version: "3.11"

       - name: Install MPI
         uses: mpi4py/setup-mpi@v1
@@ -29,17 +29,23 @@ jobs:

       - name: Load cached venv
         id: cached-poetry-dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v4
         with:
           path: .venv
           key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}

       - name: Install dependencies
         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
-        run: poetry install --no-interaction --no-root
+        run: poetry install --no-interaction --no-root --all-extras

       - name: Install library
-        run: poetry install --no-interaction
+        run: poetry install --no-interaction --all-extras
+
+      - name: MyPy Type checking
+        run: poetry run mypy --config-file mypy.ini --show-error-codes
+
+      - name: Ruff Linting
+        run: poetry run ruff check

       - name: Run MPI tests 2 ranks
         run: |

docs/conf.py

Lines changed: 11 additions & 6 deletions
@@ -17,14 +17,16 @@
 # relative to the documentation root, use os.path.abspath to make it
 # absolute, like shown here.
 #
-import os, sys, shutil, subprocess
+import os
+import sys
 import re
 from pathlib import Path
+from typing import TYPE_CHECKING

 DIR = Path(__file__).parent.resolve()
 sys.path.insert(0, os.path.abspath(".."))

-import mpipartition
+import mpipartition  # noqa: E402

 # -- General configuration ---------------------------------------------

@@ -128,7 +130,7 @@

 # -- Options for LaTeX output ------------------------------------------

-latex_elements = {
+latex_elements: dict[str, str] = {
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
@@ -181,8 +183,11 @@
     ),
 ]

+if TYPE_CHECKING:
+    from sphinx.application import Sphinx

-def prepare(app):
+
+def prepare(app: Sphinx) -> None:
     with open(DIR.parent / "README.rst") as f:
         contents = f.read()

@@ -194,11 +199,11 @@ def prepare(app):
         f.write(contents)


-def clean_up(app, exception):
+def clean_up(app: Sphinx, exception: Exception) -> None:
     (DIR / "readme.rst").unlink()


-def setup(app):
+def setup(app: Sphinx) -> None:
     app.add_css_file("css/custom.css")
     # Copy the readme in
     app.connect("builder-inited", prepare)

mpipartition/_send_home.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+from __future__ import annotations
+import numpy as np
+from typing import TYPE_CHECKING
+from mpi4py import MPI
+import sys
+
+if TYPE_CHECKING:
+    from .partition import Partition
+    from .spherical_partition import S2Partition
+
+ParticleDataT = dict[str, np.ndarray]
+
+
+def distribute_dataset_by_home(
+    partition: Partition | S2Partition,
+    data: ParticleDataT,
+    home_idx: np.ndarray,
+    *,
+    verbose: int = 0,
+    verify_count: bool = True,
+    all2all_iterations: int = 1,
+) -> ParticleDataT:
+    total_to_send = len(home_idx)
+    nperiteration = total_to_send // all2all_iterations
+    data_new_list: list[ParticleDataT] = []
+
+    # Some general assertions that every rank has valid data
+    keys = list(data.keys())
+    keys_0 = partition.comm.bcast(keys, root=0)
+    assert len(keys) == len(keys_0)
+    assert all(k in keys_0 for k in keys)
+    dtype_string = "".join(data[k].dtype.char for k in keys_0)
+    dtype_string_0 = partition.comm.bcast(dtype_string, root=0)
+    assert dtype_string == dtype_string_0
+
+    for i in range(all2all_iterations):
+        start_idx = i * nperiteration
+        end_idx = (
+            (i + 1) * nperiteration if i < all2all_iterations - 1 else total_to_send
+        )
+        if partition.rank == 0 and verbose > 0:
+            print(f" - Distributing particles iteration {i + 1}/{all2all_iterations}")
+        _data = {k: v[start_idx:end_idx] for k, v in data.items()}
+        _home_idx = home_idx[start_idx:end_idx]
+        _data = _distribute_dataset_by_home(
+            partition,
+            _data,
+            _home_idx,
+            keys_0,
+            verbose=verbose,
+            verify_count=verify_count,
+        )
+        data_new_list.append(_data)
+    # concatenate the data
+    data_new = {k: np.concatenate([d[k] for d in data_new_list]) for k in data.keys()}
+    return data_new
+
+
+def _distribute_dataset_by_home(
+    partition: Partition | S2Partition,
+    data: ParticleDataT,
+    home_idx: np.ndarray,
+    keys: list[str],
+    *,
+    verbose: int = 0,
+    verify_count: bool = True,
+) -> ParticleDataT:
+    total_to_send = len(home_idx)
+    for d in data.values():
+        assert len(d) == total_to_send, "All data arrays must have the same length"
+
+    # sort by rank
+    s = np.argsort(home_idx)
+    home_idx = home_idx[s]
+
+    # offsets and counts
+    send_displacements = np.searchsorted(home_idx, np.arange(partition.nranks))
+    send_displacements = send_displacements.astype(np.int32)
+    send_counts = np.append(send_displacements[1:], total_to_send) - send_displacements
+    send_counts = send_counts.astype(np.int32)
+
+    # announce to each rank how many objects will be sent
+    recv_counts = np.empty_like(send_counts)
+    partition.comm.Alltoall(send_counts, recv_counts)
+    recv_displacements = np.insert(np.cumsum(recv_counts)[:-1], 0, 0)
+
+    # number of objects that this rank will receive
+    total_to_receive = np.sum(recv_counts)
+
+    # debug message
+    if verbose > 1:
+        for i in range(partition.nranks):
+            if partition.rank == i:
+                print(f"Distribute Debug Rank {i}")
+                print(f" - rank has {total_to_send} particles")
+                print(f" - rank receives {total_to_receive} particles")
+                print(f" - send_counts: {send_counts}")
+                print(f" - send_displacements: {send_displacements}")
+                print(f" - recv_counts: {recv_counts}")
+                print(f" - recv_displacements: {recv_displacements}")
+                print("", flush=True)
+            partition.comm.Barrier()
+
+    # send data all-to-all, each array individually
+    data_new = {k: np.empty(total_to_receive, dtype=data[k].dtype) for k in data.keys()}
+
+    for k in keys:
+        d = data[k][s]
+        s_msg = [d, (send_counts, send_displacements), d.dtype.char]
+        r_msg = [data_new[k], (recv_counts, recv_displacements), d.dtype.char]
+        partition.comm.Alltoallv(s_msg, r_msg)
+
+    if verify_count:
+        key0 = keys[0]
+        local_counts = np.array([len(data[key0]), len(data_new[key0])], dtype=np.int64)
+        global_counts = np.empty_like(local_counts)
+        partition.comm.Reduce(local_counts, global_counts, op=MPI.SUM, root=0)
+        if partition.rank == 0 and global_counts[0] != global_counts[1]:
+            print(
+                f"Error in distribute: particle count during distribute was not "
+                f"maintained ({global_counts[0]} -> {global_counts[1]})",
+                file=sys.stderr,
+                flush=True,
+            )
+            partition.comm.Abort()
+
+    return data_new
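The routing bookkeeping in _distribute_dataset_by_home sorts particles by destination rank and then derives send_counts and send_displacements with np.searchsorted. Below is a small MPI-free sketch of that counting step; the toy home_idx values are made up for illustration.

import numpy as np

# Toy example: 7 local particles destined for 4 ranks (values are made up).
nranks = 4
home_idx = np.array([2, 0, 3, 2, 0, 1, 2], dtype=np.int32)

# Sort particles by destination rank, as _distribute_dataset_by_home does.
s = np.argsort(home_idx)
home_idx_sorted = home_idx[s]

# First occurrence of each destination rank -> send displacements;
# differences of consecutive displacements -> per-destination counts.
send_displacements = np.searchsorted(home_idx_sorted, np.arange(nranks)).astype(np.int32)
send_counts = (
    np.append(send_displacements[1:], len(home_idx)) - send_displacements
).astype(np.int32)

print(send_displacements)  # [0 2 3 6]
print(send_counts)         # [2 1 3 1]

These count and displacement arrays are exactly what Alltoall/Alltoallv consume, so a rank that has nothing to send to a given destination simply contributes a zero count.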

mpipartition/distribute.py

Lines changed: 18 additions & 58 deletions
@@ -1,11 +1,12 @@
 import sys
-from typing import List, Mapping, Union
+from typing import List, Union

 import numpy as np

-from .partition import MPI, Partition
+from .partition import Partition
+from ._send_home import distribute_dataset_by_home

-ParticleDataT = Mapping[str, np.ndarray]
+ParticleDataT = dict[str, np.ndarray]


 def distribute(
@@ -16,6 +17,7 @@ def distribute(
     *,
     verbose: Union[bool, int] = False,
     verify_count: bool = True,
+    all2all_iterations: int = 1,
 ) -> ParticleDataT:
     """Distribute data among MPI ranks according to data position and volume partition

@@ -46,6 +48,10 @@ def distribute(
     verify_count:
         If True, make sure that total number of objects is conserved

+    all2all_iterations:
+        The number of iterations to use for the all-to-all communication.
+        This is useful for large datasets, where MPI_Alltoallv may fail
+
     Returns
     -------
     data: ParticleDataT
@@ -59,7 +65,7 @@ def distribute(
     if nranks == 1:
         return data

-    rank = partition.rank
+    # rank = partition.rank
     comm = partition.comm
     dimensions = partition.dimensions
     ranklist = np.array(partition.ranklist)
@@ -92,59 +98,13 @@ def distribute(
         # there are no particles on this rank
         home_idx = np.empty(0, dtype=np.int32)

-    # sort by rank
-    s = np.argsort(home_idx)
-    home_idx = home_idx[s]
-
-    # offsets and counts
-    send_displacements = np.searchsorted(home_idx, np.arange(nranks))
-    send_displacements = send_displacements.astype(np.int32)
-    send_counts = np.append(send_displacements[1:], total_to_send) - send_displacements
-    send_counts = send_counts.astype(np.int32)
-
-    # announce to each rank how many objects will be sent
-    recv_counts = np.empty_like(send_counts)
-    comm.Alltoall(send_counts, recv_counts)
-    recv_displacements = np.insert(np.cumsum(recv_counts)[:-1], 0, 0)
-
-    # number of objects that this rank will receive
-    total_to_receive = np.sum(recv_counts)
-
-    # debug message
-    if verbose > 1:
-        for i in range(nranks):
-            if rank == i:
-                print(f"Distribute Debug Rank {i}")
-                print(f" - rank has {total_to_send} particles")
-                print(f" - rank receives {total_to_receive} particles")
-                print(f" - send_counts: {send_counts}")
-                print(f" - send_displacements: {send_displacements}")
-                print(f" - recv_counts: {recv_counts}")
-                print(f" - recv_displacements: {recv_displacements}")
-                print(f"", flush=True)
-            comm.Barrier()
-
-    # send data all-to-all, each array individually
-    data_new = {k: np.empty(total_to_receive, dtype=data[k].dtype) for k in data.keys()}
-
-    for k in data.keys():
-        d = data[k][s]
-        s_msg = [d, (send_counts, send_displacements), d.dtype.char]
-        r_msg = [data_new[k], (recv_counts, recv_displacements), d.dtype.char]
-        comm.Alltoallv(s_msg, r_msg)
-
-    if verify_count:
-        local_counts = np.array(
-            [len(data[coord_keys[0]]), len(data_new[coord_keys[0]])], dtype=np.int64
-        )
-        global_counts = np.empty_like(local_counts)
-        comm.Reduce(local_counts, global_counts, op=MPI.SUM, root=0)
-        if rank == 0 and global_counts[0] != global_counts[1]:
-            print(
-                f"Error in distribute: particle count during distribute was not maintained ({global_counts[0]} -> {global_counts[1]})",
-                file=sys.stderr,
-                flush=True,
-            )
-            comm.Abort()
+    data_new = distribute_dataset_by_home(
+        partition,
+        data,
+        home_idx=home_idx,
+        verbose=verbose,
+        verify_count=verify_count,
+        all2all_iterations=all2all_iterations,
+    )

     return data_new
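The new consistency checks broadcast the key list and dtype signature from rank 0 and assert that every other rank matches, so callers must provide the same fields with the same dtypes on all ranks, including ranks that currently hold zero particles. A hedged sketch of the calling convention this implies (field names, dtypes, and the box_size argument are illustrative assumptions, not part of this commit):

import numpy as np
from mpipartition import Partition, distribute

partition = Partition()
rng = np.random.default_rng(seed=partition.rank)

# A rank with no local particles must still provide the same keys and dtypes,
# otherwise the bcast-based assertions in distribute_dataset_by_home fail.
n_local = 0 if partition.rank == 0 else 1_000
data = {
    "x": rng.random(n_local),                    # float64 on every rank
    "y": rng.random(n_local),
    "z": rng.random(n_local),
    "mass": np.ones(n_local, dtype=np.float32),  # float32 on every rank
}

data = distribute(partition, 1.0, data, ("x", "y", "z"))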
