Skip to content

Commit 4851317

Browse files
committed
cleaning up nonant var cache
1 parent 4c54d73 commit 4851317

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

Diff for: mpisppy/opt/fwph.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def fw_prep(self):
8282
self.PH_Prep(attach_duals=True, attach_prox=False)
8383
self._output_header()
8484
self._attach_MIP_vars()
85+
self._cache_nonant_var_swap_mip()
8586

8687
trivial_bound = self.Iter0()
8788
secs = time.time() - self.t0
@@ -102,11 +103,9 @@ def fw_prep(self):
102103
self._attach_MIP_QP_maps()
103104
self._set_QP_objective()
104105
self._initialize_QP_var_values()
106+
self._cache_nonant_var_swap_qp()
105107
self._setup_shared_column_generation()
106108

107-
self._QP_nonants = {}
108-
self._MIP_nonants = {}
109-
110109
number_initial_column_tries = self.options.get("FW_initialization_attempts", 20)
111110
if self.FW_options["FW_iter_limit"] == 1 and number_initial_column_tries < 1:
112111
global_toc(f"{self.__class__.__name__}: Warning: FWPH needs an initial shared column if FW_iter_limit == 1. Increasing FW_iter_limit to 2 to ensure convergence")
@@ -198,6 +197,7 @@ def fwph_main(self, finalize=True):
198197
secs = time.time() - self.t0
199198
self._output(self._local_bound, best_bound, diff, secs)
200199

200+
201201
# add a shared column
202202
shared_columns = self.options.get("FWPH_shared_columns_per_iteration", 1)
203203
if shared_columns > 0:
@@ -907,6 +907,7 @@ def _set_QP_objective(self):
907907

908908
def _cache_nonant_var_swap_mip(self):
909909
""" cache the lists used for the nonant var swap """
910+
self._MIP_nonants = {}
910911

911912
# MIP nonants
912913
for k, s in self.local_scenarios.items():
@@ -923,6 +924,22 @@ def _cache_nonant_var_swap_mip(self):
923924
def _cache_nonant_var_swap_qp(self):
924925
""" cache the lists used for the nonant var swap """
925926

927+
for (name, model) in self.local_subproblems.items():
928+
scens = model.scen_list if self.bundling else [name]
929+
for scenario_name in scens:
930+
scenario = self.local_scenarios[scenario_name]
931+
num_nonant_vars = scenario._mpisppy_data.nlens
932+
node_list = scenario._mpisppy_node_list
933+
for node in node_list:
934+
node.nonant_vardata_list = [
935+
self.local_QP_subproblems[name].xr[node.name,i]
936+
if self.bundling else
937+
self.local_QP_subproblems[name].x[node.name,i]
938+
for i in range(num_nonant_vars[node.name])]
939+
self._attach_nonant_indices()
940+
941+
self._QP_nonants = {}
942+
926943
# QP nonants
927944
for k, s in self.local_scenarios.items():
928945
nonant_vardata_lists = {}
@@ -935,6 +952,8 @@ def _cache_nonant_var_swap_qp(self):
935952
"all_surrogate_nonants" : s._mpisppy_data.all_surrogate_nonants,
936953
}
937954

955+
self._swap_nonant_vars_back()
956+
938957
def _swap_nonant_vars(self):
939958
''' Change the pointers in
940959
scenario._mpisppy_node_list[i].nonant_vardata_list
@@ -952,28 +971,11 @@ def _swap_nonant_vars(self):
952971
953972
Updates nonant_vardata_list but NOT nonant_list.
954973
'''
955-
if not self._QP_nonants:
956-
self._cache_nonant_var_swap_mip()
957-
for (name, model) in self.local_subproblems.items():
958-
scens = model.scen_list if self.bundling else [name]
959-
for scenario_name in scens:
960-
scenario = self.local_scenarios[scenario_name]
961-
num_nonant_vars = scenario._mpisppy_data.nlens
962-
node_list = scenario._mpisppy_node_list
963-
for node in node_list:
964-
node.nonant_vardata_list = [
965-
self.local_QP_subproblems[name].xr[node.name,i]
966-
if self.bundling else
967-
self.local_QP_subproblems[name].x[node.name,i]
968-
for i in range(num_nonant_vars[node.name])]
969-
self._attach_nonant_indices()
970-
self._cache_nonant_var_swap_qp()
971-
else:
972-
for s, nonant_data in self._QP_nonants.items():
973-
for node in s._mpisppy_node_list:
974-
node.nonant_vardata_list = nonant_data["nonant_vardata_lists"][node.name]
975-
s._mpisppy_data.nonant_indices = nonant_data["nonant_indices"]
976-
s._mpisppy_data.all_surrogate_nonants = nonant_data["all_surrogate_nonants"]
974+
for s, nonant_data in self._QP_nonants.items():
975+
for node in s._mpisppy_node_list:
976+
node.nonant_vardata_list = nonant_data["nonant_vardata_lists"][node.name]
977+
s._mpisppy_data.nonant_indices = nonant_data["nonant_indices"]
978+
s._mpisppy_data.all_surrogate_nonants = nonant_data["all_surrogate_nonants"]
977979

978980
def _swap_nonant_vars_back(self):
979981
''' Swap variables back, in case they're needed somewhere else.

0 commit comments

Comments
 (0)