Skip to content

Commit 1a0ceb8

Browse files
committed
bring back explicit window creation / distruction
1 parent a549c7d commit 1a0ceb8

File tree

3 files changed

+42
-10
lines changed

3 files changed

+42
-10
lines changed

mpisppy/cylinders/spcommunicator.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,12 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
146146
# on the problem data
147147
self._field_lengths = FieldLengths(self.opt)
148148

149+
self.window = None
150+
149151
# attach the SPCommunicator to
150152
# the SPBase object
151153
self.opt.spcomm = self
152154

153-
self.register_send_fields()
154-
155-
self._make_windows()
156-
self._create_field_rank_mappings()
157-
158-
self.register_receive_fields()
159-
160155
return
161156

162157
def _make_key(self, field: Field, origin: int):
@@ -273,13 +268,38 @@ def hub_finalize(self):
273268
def allreduce_or(self, val):
274269
return self.opt.allreduce_or(val)
275270

276-
def _make_windows(self) -> None:
271+
def make_windows(self) -> None:
272+
""" Make MPI windows: blocking call for all ranks in `strata_comm`.
273+
"""
274+
275+
if self.window is not None:
276+
return
277+
278+
self.register_send_fields()
277279

278280
window_spec = self._build_window_spec()
279281
self.window = SPWindow(window_spec, self.strata_comm)
280282

283+
self._create_field_rank_mappings()
284+
self.register_receive_fields()
285+
281286
return
282287

288+
def free_windows(self) -> None:
289+
""" Free MPI windows: blocking call for all ranks in `strata_comm`.
290+
"""
291+
292+
if self.window is None:
293+
return
294+
295+
self.receive_buffers = {}
296+
self.send_buffers = {}
297+
self.receive_field_spcomms = {}
298+
299+
self.window.free()
300+
301+
self.window = None
302+
283303
def is_send_field_registered(self, field: Field) -> bool:
284304
return field in self.send_buffers
285305

mpisppy/cylinders/spwindow.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import numpy.typing as nptyping
1414

1515
import enum
16-
import weakref
1716

1817
import pyomo.environ as pyo
1918

@@ -112,7 +111,6 @@ def __init__(self, my_fields: dict, strata_comm: MPI.Comm, field_order=None):
112111
self.buffer_length = total_buffer_length
113112
self.window = MPI.Win.Allocate(window_size_bytes, MPI.DOUBLE.size, comm=strata_comm)
114113
# ensure the memory allocated for the window is freed
115-
self._window_finalizer = weakref.finalize(self, self.window.Free)
116114
self.buff = np.ndarray(dtype="d", shape=(total_buffer_length,), buffer=self.window.tomemory())
117115
self.buff[:] = np.nan
118116

@@ -126,6 +124,17 @@ def __init__(self, my_fields: dict, strata_comm: MPI.Comm, field_order=None):
126124

127125
return
128126

127+
def free(self):
128+
if self.window is not None:
129+
self.window.Free()
130+
self.buff = None
131+
self.buffer_layout = None
132+
self.buffer_length = 0
133+
self.window = None
134+
self.strata_buffer_layouts = None
135+
self.window = None
136+
return
137+
129138
#### Functions ####
130139
def get(self, dest: nptyping.ArrayLike, strata_rank: int, field: Field):
131140

mpisppy/spin_the_wheel.py

+3
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def run(self, comm_world=None):
127127
spcomm = sp_class(opt, fullcomm, strata_comm, cylinder_comm,
128128
communicator_list, **sp_kwargs)
129129

130+
spcomm.make_windows()
131+
130132
# Run main()
131133
if strata_rank == 0:
132134
spcomm.setup_hub()
@@ -148,6 +150,7 @@ def run(self, comm_world=None):
148150

149151
## give the hub the chance to catch new values
150152
spcomm.hub_finalize()
153+
spcomm.free_windows()
151154

152155
fullcomm.Barrier()
153156
global_toc("Finalize Complete")

0 commit comments

Comments
 (0)