Skip to content

Commit cbe1113

Browse files
committed
adding method to update nonant bounds; call that method where appropriate
1 parent 0e680a9 commit cbe1113

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

Diff for: mpisppy/cylinders/fwph_spoke.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def sync(self):
2727
# Tell the hub about the most recent bound
2828
self.bound = self.opt._local_bound
2929

30+
# Update the nonant bounds, if possible
31+
self.update_nonant_bounds()
32+
3033
def finalize(self):
3134
# The FWPH spoke can call "finalize" before it
3235
# even starts doing anything, so its possible

Diff for: mpisppy/cylinders/hub.py

+3
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def sync(self):
311311
self.update_receive_buffers()
312312
self.update_outerbounds()
313313
self.update_innerbounds()
314+
self.update_nonant_bounds()
314315
if self.opt.extensions is not None:
315316
self.opt.extobject.sync_with_spokes()
316317

@@ -321,6 +322,7 @@ def sync_bounds(self):
321322
self.update_receive_buffers()
322323
self.update_outerbounds()
323324
self.update_innerbounds()
325+
self.update_nonant_bounds()
324326
self.send_boundsout()
325327

326328
def sync_extensions(self):
@@ -424,6 +426,7 @@ def sync(self, send_nonants=True):
424426
self.update_receive_buffers()
425427
self.update_innerbounds()
426428
self.update_outerbounds()
429+
self.update_nonant_bounds()
427430
# in case LShaped ever gets extensions
428431
if getattr(self.opt, "extensions", None) is not None:
429432
self.opt.extobject.sync_with_spokes()

Diff for: mpisppy/cylinders/lagrangian_bounder.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def lagrangian_prep(self):
1919
self.opt._create_solvers()
2020

2121
def lagrangian(self, need_solution=True):
22+
# update the nonant bounds, if possible, for a tighter relaxation
23+
self.update_nonant_bounds()
2224
verbose = self.opt.options['verbose']
2325
# This is sort of a hack, but might help folks:
2426
if "ipopt" in self.opt.options["solver_name"]:

Diff for: mpisppy/cylinders/spcommunicator.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class SPCommunicator:
116116
or expects to receive from another SPCommunicator object.
117117
"""
118118
send_fields = ()
119-
receive_fields = ()
119+
receive_fields = (Field.NONANT_LOWER_BOUNDS, Field.NONANT_UPPER_BOUNDS,)
120120

121121
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
122122
self.fullcomm = fullcomm
@@ -138,7 +138,7 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
138138
# Common fields for spokes and hubs
139139
self.receive_buffers = {}
140140
self.send_buffers = {}
141-
# key: Field, value: list of (strata_rank, SPComm) with that Field
141+
# key: Field, value: list of (strata_rank, SPComm, buffer) with that Field
142142
self.receive_field_spcomms = {}
143143

144144
# setup FieldLengths which calculates
@@ -378,3 +378,27 @@ def update_receive_buffers(self):
378378
self.get_receive_buffer(recv_buf, field, rank)
379379
## End for
380380
return
381+
382+
def update_nonant_bounds(self):
383+
""" update the bounds on the nonanticipative variables based on
384+
Field.NONANT_LOWER_BOUNDS and Field.NONANT_UPPER_BOUNDS. The lower and
385+
upper bound buffers should be up-to-date, which can be done by calling
386+
`SPCommunicator.update_receive_buffers`.
387+
"""
388+
_INF = float("inf")
389+
for _, _, recv_buf in self.receive_field_spcomms[Field.NONANT_LOWER_BOUNDS]:
390+
for s in self.opt.local_scenarios.items():
391+
for ci, (ndn_i, xvar) in enumerate(s._mpisppy_data.nonant_indices.items()):
392+
xvarlb = xvar.lb
393+
if xvarlb is None:
394+
xvarlb = -_INF
395+
if recv_buf[ci] > xvarlb:
396+
xvar.lb = recv_buf[ci]
397+
for _, _, recv_buf in self.receive_field_spcomms[Field.NONANT_UPPER_BOUNDS]:
398+
for s in self.opt.local_scenarios.items():
399+
for ci, (ndn_i, xvar) in enumerate(s._mpisppy_data.nonant_indices.items()):
400+
xvarub = xvar.ub
401+
if xvarub is None:
402+
xvarub = _INF
403+
if recv_buf[ci] < xvarub:
404+
xvar.ub = recv_buf[ci]

0 commit comments

Comments
 (0)