Skip to content

Commit 7f5c4c0

Browse files
authored
Merge branch 'main' into main
2 parents 3f64dd3 + 6d1f13f commit 7f5c4c0

12 files changed

+190
-103
lines changed

Diff for: .github/workflows/test_pr_and_main.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ jobs:
407407
run: |
408408
cd mpisppy/tests
409409
python test_gradient_rho.py
410-
python test_w_writer.py
410+
python test_xbar_w_reader_writer.py
411411
412412
test-headers:
413413
name: header test

Diff for: examples/farmer/farmer_cylinders.py

+2
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def main():
116116
ph_converger=ph_converger,
117117
rho_setter = rho_setter)
118118

119+
hub_dict['opt_kwargs']['options']['cfg'] = cfg
120+
119121
if cfg.primal_dual_converger:
120122
hub_dict['opt_kwargs']['options']\
121123
['primal_dual_converger_options'] = {

Diff for: examples/generic_cylinders.bash

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
SOLVER="cplex"
55
SPB=1
66

7+
echo "^^^ hub only with w-writer (smoke) ^^^"
8+
python -m mpi4py ../mpisppy/generic_cylinders.py --module-name farmer/farmer --num-scens 3 --solver-name ${SOLVER} --max-iterations 10 --max-solver-threads 4 --default-rho 1 --W-writer --W-fname w_values.csv
9+
10+
echo "^^^ hub only with w-reader (smoke) ^^^"
11+
python -m mpi4py ../mpisppy/generic_cylinders.py --module-name farmer/farmer --num-scens 3 --solver-name ${SOLVER} --max-iterations 10 --max-solver-threads 4 --default-rho 1 --W-reader --init-W-fname w_values.csv
12+
713
echo "^^^ Multi-stage AirCond ^^^"
814
mpiexec -np 3 python -m mpi4py ../mpisppy/generic_cylinders.py --module-name mpisppy.tests.examples.aircond --branching-factors "3 3 3" --solver-name ${SOLVER} --max-iterations 10 --max-solver-threads 4 --default-rho 1 --lagrangian --xhatxbar --rel-gap 0.01 --solution-base-name aircond_nonants
915
# --xhatshuffle --stag2EFsolvern

Diff for: mpisppy/cylinders/hub.py

+20
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,26 @@ def sync(self):
603603
def sync_with_spokes(self):
604604
self.sync()
605605

606+
def sync_bounds(self):
607+
if self.has_outerbound_spokes:
608+
self.receive_outerbounds()
609+
if self.has_innerbound_spokes:
610+
self.receive_innerbounds()
611+
if self.has_bounds_only_spokes:
612+
self.send_boundsout()
613+
614+
def sync_extensions(self):
615+
if self.opt.extensions is not None:
616+
self.opt.extobject.sync_with_spokes()
617+
618+
def sync_nonants(self):
619+
if self.has_nonant_spokes:
620+
self.send_nonants()
621+
622+
def sync_Ws(self):
623+
if self.has_w_spokes:
624+
self.send_ws()
625+
606626
def is_converged(self):
607627
if self.opt.best_bound_obj_val is not None:
608628
self.BestOuterBound = self.OuterBoundUpdate(self.opt.best_bound_obj_val)

Diff for: mpisppy/generic_cylinders.py

+50-3
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,31 @@
1414
import shutil
1515
import numpy as np
1616
import pyomo.environ as pyo
17+
import pyomo.common.config as pyofig
18+
1719
from mpisppy.spin_the_wheel import WheelSpinner
20+
1821
import mpisppy.utils.cfg_vanilla as vanilla
1922
import mpisppy.utils.config as config
2023
import mpisppy.utils.sputils as sputils
24+
2125
from mpisppy.convergers.norm_rho_converger import NormRhoConverger
2226
from mpisppy.convergers.primal_dual_converger import PrimalDualConverger
27+
2328
from mpisppy.extensions.extension import MultiExtension
2429
from mpisppy.extensions.fixer import Fixer
2530
from mpisppy.extensions.mipgapper import Gapper
2631
from mpisppy.extensions.gradient_extension import Gradient_extension
2732
from mpisppy.extensions.scenario_lpfiles import Scenario_lpfiles
33+
34+
from mpisppy.utils.wxbarwriter import WXBarWriter
35+
from mpisppy.utils.wxbarreader import WXBarReader
36+
2837
import mpisppy.utils.solver_spec as solver_spec
38+
2939
from mpisppy import global_toc
3040
from mpisppy import MPI
3141

32-
3342
def _parse_args(m):
3443
# m is the model file module
3544
cfg = config.Config()
@@ -82,6 +91,18 @@ def _parse_args(m):
8291
cfg.coeff_rho_args()
8392
cfg.sensi_rho_args()
8493
cfg.reduced_costs_rho_args()
94+
95+
cfg.add_to_config("user_defined_extensions",
96+
description="Space-delimited module names for user extensions",
97+
domain=pyofig.ListOf(str),
98+
default=None)
99+
# TBD - think about adding directory for json options files
100+
101+
cfg.add_to_config("hub_and_spoke_dict_callback",
102+
description="[FOR EXPERTS ONLY] Module that contains the function hub_and_spoke_dict_callback that will be passed the hubdict and list of spokedicts prior to spin-the-wheel (last chance for intervention)",
103+
domain=str,
104+
default=None)
105+
85106
cfg.parse_command_line(f"mpi-sppy for {cfg.module_name}")
86107

87108
cfg.checker() # looks for inconsistencies
@@ -163,7 +184,11 @@ def _do_decomp(module, cfg, scenario_creator, scenario_creator_kwargs, scenario_
163184
rho_setter = rho_setter,
164185
all_nodenames = all_nodenames,
165186
)
166-
187+
188+
# the intent of the following is to transition to strictly
189+
# cfg-based option passing, as opposed to dictionary-based processing.
190+
hub_dict['opt_kwargs']['options']['cfg'] = cfg
191+
167192
# Extend and/or correct the vanilla dictionary
168193
ext_classes = list()
169194
# TBD: add cross_scenario_cuts, which also needs a cylinder
@@ -198,6 +223,24 @@ def _do_decomp(module, cfg, scenario_creator, scenario_creator_kwargs, scenario_
198223
if cfg.scenario_lpfiles:
199224
ext_classes.append(Scenario_lpfiles)
200225

226+
if cfg.W_and_xbar_reader:
227+
ext_classes.append(WXBarReader)
228+
229+
if cfg.W_and_xbar_writer:
230+
ext_classes.append(WXBarWriter)
231+
232+
if cfg.user_defined_extensions is not None:
233+
for ext_name in cfg.user_defined_extensions:
234+
module = sputils.module_name_to_module(ext_name)
235+
vanilla.extension_adder(module)
236+
# grab JSON for this module's option dictionary
237+
json_filename = ext_name+".json"
238+
if os.path.exists(json_filename):
239+
ext_options= json.load(json_filename)
240+
hub_dict['opt_kwargs']['options'][ext_name] = ext_options
241+
else:
242+
raise RuntimeError(f"JSON options file {json_filename} for user defined extension not found")
243+
201244
if cfg.sep_rho:
202245
vanilla.add_sep_rho(hub_dict, cfg)
203246

@@ -322,7 +365,11 @@ def _do_decomp(module, cfg, scenario_creator, scenario_creator_kwargs, scenario_
322365
list_of_spoke_dict.append(xhatxbar_spoke)
323366
if cfg.reduced_costs:
324367
list_of_spoke_dict.append(reduced_costs_spoke)
325-
368+
369+
# if the user dares, let them mess with the hubdict prior to solve
370+
if cfg.hub_and_spoke_dict_callback is not None:
371+
module = sputils.module_name_to_module(cfg.hub_and_spoke_dict_callback)
372+
module.hub_and_spoke_dict_callback(hub_dict, list_of_spoke_dict)
326373

327374
wheel = WheelSpinner(hub_dict, list_of_spoke_dict)
328375
wheel.spin()

Diff for: mpisppy/phbase.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ class PHBase(mpisppy.spopt.SPOpt):
239239
Function to set rho values throughout the PH algorithm.
240240
variable_probability (callable, optional):
241241
Function to set variable specific probabilities.
242-
cfg (config object, optional?) controls (mainly from user)
243-
(Maybe this should move up to spbase)
244242
245243
"""
246244
def __init__(
@@ -936,7 +934,10 @@ def _vb(msg):
936934
if self._can_update_best_bound():
937935
self.best_bound_obj_val = self.trivial_bound
938936

939-
if self.spcomm is not None:
937+
if hasattr(self.spcomm, "sync_nonants"):
938+
self.spcomm.sync_nonants()
939+
self.spcomm.sync_extensions()
940+
elif hasattr(self.spcomm, "sync"):
940941
self.spcomm.sync()
941942

942943
if have_extensions:
@@ -1003,7 +1004,7 @@ def iterk_loop(self):
10031004
self.conv = None
10041005

10051006
max_iterations = int(self.options["PHIterLimit"])
1006-
if self.spcomm is not None:
1007+
if hasattr(self.spcomm, "is_converged"):
10071008
# print a screen trace for iteration 0
10081009
if self.spcomm.is_converged():
10091010
global_toc("Cylinder convergence", self.cylinder_rank == 0)
@@ -1023,6 +1024,9 @@ def iterk_loop(self):
10231024
self.Update_W(verbose)
10241025
#global_toc('Rank: {} - After Update_W'.format(self.cylinder_rank), True)
10251026

1027+
if hasattr(self.spcomm, "sync_Ws"):
1028+
self.spcomm.sync_Ws()
1029+
10261030
if smoothed:
10271031
self.Update_z(verbose)
10281032

@@ -1067,11 +1071,15 @@ def iterk_loop(self):
10671071
if have_extensions:
10681072
self.extobject.enditer()
10691073

1070-
if self.spcomm is not None:
1071-
self.spcomm.sync()
1074+
if hasattr(self.spcomm, "sync_nonants"):
1075+
self.spcomm.sync_nonants()
1076+
self.spcomm.sync_bounds()
1077+
self.spcomm.sync_extensions()
10721078
if self.spcomm.is_converged():
10731079
global_toc("Cylinder convergence", self.cylinder_rank == 0)
10741080
break
1081+
elif hasattr(self.spcomm, "sync"):
1082+
self.spcomm.sync()
10751083

10761084
if have_extensions:
10771085
self.extobject.enditer_after_sync()

Diff for: mpisppy/tests/test_gradient_rho.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def _create_ph_farmer(self):
5757
scenario_creator_kwargs = farmer.kw_creator(self.cfg)
5858
beans = (self.cfg, scenario_creator, scenario_denouement, all_scenario_names)
5959
hub_dict = vanilla.ph_hub(*beans, scenario_creator_kwargs=scenario_creator_kwargs)
60+
hub_dict['opt_kwargs']['options']['cfg'] = self.cfg
6061
list_of_spoke_dict = list()
6162
wheel = WheelSpinner(hub_dict, list_of_spoke_dict)
6263
wheel.spin()

Diff for: mpisppy/tests/test_w_writer.py renamed to mpisppy/tests/test_xbar_w_reader_writer.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,17 @@
2222
import mpisppy.tests.examples.farmer as farmer
2323
from mpisppy.spin_the_wheel import WheelSpinner
2424
from mpisppy.tests.utils import get_solver
25-
from mpisppy.utils.wxbarwriter import WXBarWriter
26-
from mpisppy.utils.wxbarreader import WXBarReader
25+
import mpisppy.utils.wxbarreader as wxbarreader
26+
import mpisppy.utils.wxbarwriter as wxbarwriter
2727

28-
29-
__version__ = 0.1
28+
__version__ = 0.2
3029

3130
solver_available,solver_name, persistent_available, persistent_solver_name= get_solver()
3231

3332
def _create_cfg():
3433
cfg = config.Config()
34+
wxbarreader.add_options_to_config(cfg)
35+
wxbarwriter.add_options_to_config(cfg)
3536
cfg.add_branching_factors()
3637
cfg.num_scens_required()
3738
cfg.popular_args()
@@ -43,7 +44,7 @@ def _create_cfg():
4344

4445
#*****************************************************************************
4546

46-
class Test_w_writer_farmer(unittest.TestCase):
47+
class Test_xbar_w_reader_writer_farmer(unittest.TestCase):
4748
""" Test the gradient code using farmer."""
4849

4950
def _create_ph_farmer(self, ph_extensions=None, max_iter=100):
@@ -59,14 +60,15 @@ def _create_ph_farmer(self, ph_extensions=None, max_iter=100):
5960
self.cfg.max_iterations = max_iter
6061
beans = (self.cfg, scenario_creator, scenario_denouement, all_scenario_names)
6162
hub_dict = vanilla.ph_hub(*beans, scenario_creator_kwargs=scenario_creator_kwargs, ph_extensions=ph_extensions)
62-
if ph_extensions==WXBarWriter: #tbd
63-
hub_dict['opt_kwargs']['options']["W_and_xbar_writer"] = {"Wcsvdir": "Wdir"}
64-
hub_dict['opt_kwargs']['options']['W_fname'] = self.temp_w_file_name
65-
hub_dict['opt_kwargs']['options']['Xbar_fname'] = self.temp_xbar_file_name
66-
if ph_extensions==WXBarReader:
67-
hub_dict['opt_kwargs']['options']["W_and_xbar_reader"] = {"Wcsvdir": "Wdir"}
68-
hub_dict['opt_kwargs']['options']['init_W_fname'] = self.w_file_name
69-
hub_dict['opt_kwargs']['options']['init_Xbar_fname'] = self.xbar_file_name
63+
hub_dict['opt_kwargs']['options']['cfg'] = self.cfg
64+
if ph_extensions==wxbarwriter.WXBarWriter:
65+
self.cfg.W_and_xbar_writer = True
66+
self.cfg.W_fname = self.temp_w_file_name
67+
self.cfg.Xbar_fname = self.temp_xbar_file_name
68+
if ph_extensions==wxbarreader.WXBarReader:
69+
self.cfg.W_and_xbar_reader = True
70+
self.cfg.init_W_fname = self.w_file_name
71+
self.cfg.init_Xbar_fname = self.xbar_file_name
7072
list_of_spoke_dict = list()
7173
wheel = WheelSpinner(hub_dict, list_of_spoke_dict)
7274
wheel.spin()
@@ -79,7 +81,7 @@ def setUp(self):
7981
self.ph_object = None
8082

8183
def test_wwriter(self):
82-
self.ph_object = self._create_ph_farmer(ph_extensions=WXBarWriter, max_iter=5)
84+
self.ph_object = self._create_ph_farmer(ph_extensions=wxbarwriter.WXBarWriter, max_iter=5)
8385
with open(self.temp_w_file_name, 'r') as f:
8486
read = csv.reader(f)
8587
rows = list(read)
@@ -88,7 +90,7 @@ def test_wwriter(self):
8890
os.remove(self.temp_w_file_name)
8991

9092
def test_xbarwriter(self):
91-
self.ph_object = self._create_ph_farmer(ph_extensions=WXBarWriter, max_iter=5)
93+
self.ph_object = self._create_ph_farmer(ph_extensions=wxbarwriter.WXBarWriter, max_iter=5)
9294
with open(self.temp_xbar_file_name, 'r') as f:
9395
read = csv.reader(f)
9496
rows = list(read)
@@ -97,15 +99,15 @@ def test_xbarwriter(self):
9799
os.remove(self.temp_xbar_file_name)
98100

99101
def test_wreader(self):
100-
self.ph_object = self._create_ph_farmer(ph_extensions=WXBarReader, max_iter=1)
102+
self.ph_object = self._create_ph_farmer(ph_extensions=wxbarreader.WXBarReader, max_iter=1)
101103
for sname, scenario in self.ph_object.local_scenarios.items():
102104
if sname == 'scen0':
103105
self.assertAlmostEqual(scenario._mpisppy_model.W[("ROOT", 1)]._value, 70.84705093609978)
104106
if sname == 'scen1':
105107
self.assertAlmostEqual(scenario._mpisppy_model.W[("ROOT", 0)]._value, -41.104251445950844)
106108

107109
def test_xbarreader(self):
108-
self.ph_object = self._create_ph_farmer(ph_extensions=WXBarReader, max_iter=1)
110+
self.ph_object = self._create_ph_farmer(ph_extensions=wxbarreader.WXBarReader, max_iter=1)
109111
for sname, scenario in self.ph_object.local_scenarios.items():
110112
if sname == 'scen0':
111113
self.assertAlmostEqual(scenario._mpisppy_model.xbars[("ROOT", 1)]._value, 274.2239371483933)

Diff for: mpisppy/utils/config.py

+5-25
Original file line numberDiff line numberDiff line change
@@ -964,32 +964,12 @@ def tracking_args(self):
964964
domain=int,
965965
default=0)
966966

967-
968967
def wxbar_read_write_args(self):
969-
self.add_to_config("init_W_fname",
970-
description="Path of initial W file (default None)",
971-
domain=str,
972-
default=None)
973-
self.add_to_config("init_Xbar_fname",
974-
description="Path of initial Xbar file (default None)",
975-
domain=str,
976-
default=None)
977-
self.add_to_config("init_separate_W_files",
978-
description="If True, W is read from separate files (default False)",
979-
domain=bool,
980-
default=False)
981-
self.add_to_config("W_fname",
982-
description="Path of final W file (default None)",
983-
domain=str,
984-
default=None)
985-
self.add_to_config("Xbar_fname",
986-
description="Path of final Xbar file (default None)",
987-
domain=str,
988-
default=None)
989-
self.add_to_config("separate_W_files",
990-
description="If True, writes W to separate files (default False)",
991-
domain=bool,
992-
default=False)
968+
import mpisppy.utils.wxbarreader as wxbarreader
969+
wxbarreader.add_options_to_config(self)
970+
971+
import mpisppy.utils.wxbarwriter as wxbarwriter
972+
wxbarwriter.add_options_to_config(self)
993973

994974
def proper_bundle_config(self):
995975
self.add_to_config('pickle_bundles_dir',

0 commit comments

Comments
 (0)