2222import mpisppy .tests .examples .farmer as farmer
2323from mpisppy .spin_the_wheel import WheelSpinner
2424from mpisppy .tests .utils import get_solver
25- from mpisppy .utils .wxbarwriter import WXBarWriter
26- from mpisppy .utils .wxbarreader import WXBarReader
25+ import mpisppy .utils .wxbarreader as wxbarreader
26+ import mpisppy .utils .wxbarwriter as wxbarwriter
2727
28-
29- __version__ = 0.1
28+ __version__ = 0.2
3029
3130solver_available ,solver_name , persistent_available , persistent_solver_name = get_solver ()
3231
3332def _create_cfg ():
3433 cfg = config .Config ()
34+ wxbarreader .add_options_to_config (cfg )
35+ wxbarwriter .add_options_to_config (cfg )
3536 cfg .add_branching_factors ()
3637 cfg .num_scens_required ()
3738 cfg .popular_args ()
@@ -43,7 +44,7 @@ def _create_cfg():
4344
4445#*****************************************************************************
4546
46- class Test_w_writer_farmer (unittest .TestCase ):
47+ class Test_xbar_w_reader_writer_farmer (unittest .TestCase ):
4748 """ Test the gradient code using farmer."""
4849
4950 def _create_ph_farmer (self , ph_extensions = None , max_iter = 100 ):
@@ -59,14 +60,15 @@ def _create_ph_farmer(self, ph_extensions=None, max_iter=100):
5960 self .cfg .max_iterations = max_iter
6061 beans = (self .cfg , scenario_creator , scenario_denouement , all_scenario_names )
6162 hub_dict = vanilla .ph_hub (* beans , scenario_creator_kwargs = scenario_creator_kwargs , ph_extensions = ph_extensions )
62- if ph_extensions == WXBarWriter : #tbd
63- hub_dict ['opt_kwargs' ]['options' ]["W_and_xbar_writer" ] = {"Wcsvdir" : "Wdir" }
64- hub_dict ['opt_kwargs' ]['options' ]['W_fname' ] = self .temp_w_file_name
65- hub_dict ['opt_kwargs' ]['options' ]['Xbar_fname' ] = self .temp_xbar_file_name
66- if ph_extensions == WXBarReader :
67- hub_dict ['opt_kwargs' ]['options' ]["W_and_xbar_reader" ] = {"Wcsvdir" : "Wdir" }
68- hub_dict ['opt_kwargs' ]['options' ]['init_W_fname' ] = self .w_file_name
69- hub_dict ['opt_kwargs' ]['options' ]['init_Xbar_fname' ] = self .xbar_file_name
63+ hub_dict ['opt_kwargs' ]['options' ]['cfg' ] = self .cfg
64+ if ph_extensions == wxbarwriter .WXBarWriter :
65+ self .cfg .W_and_xbar_writer = True
66+ self .cfg .W_fname = self .temp_w_file_name
67+ self .cfg .Xbar_fname = self .temp_xbar_file_name
68+ if ph_extensions == wxbarreader .WXBarReader :
69+ self .cfg .W_and_xbar_reader = True
70+ self .cfg .init_W_fname = self .w_file_name
71+ self .cfg .init_Xbar_fname = self .xbar_file_name
7072 list_of_spoke_dict = list ()
7173 wheel = WheelSpinner (hub_dict , list_of_spoke_dict )
7274 wheel .spin ()
@@ -79,7 +81,7 @@ def setUp(self):
7981 self .ph_object = None
8082
8183 def test_wwriter (self ):
82- self .ph_object = self ._create_ph_farmer (ph_extensions = WXBarWriter , max_iter = 5 )
84+ self .ph_object = self ._create_ph_farmer (ph_extensions = wxbarwriter . WXBarWriter , max_iter = 5 )
8385 with open (self .temp_w_file_name , 'r' ) as f :
8486 read = csv .reader (f )
8587 rows = list (read )
@@ -88,7 +90,7 @@ def test_wwriter(self):
8890 os .remove (self .temp_w_file_name )
8991
9092 def test_xbarwriter (self ):
91- self .ph_object = self ._create_ph_farmer (ph_extensions = WXBarWriter , max_iter = 5 )
93+ self .ph_object = self ._create_ph_farmer (ph_extensions = wxbarwriter . WXBarWriter , max_iter = 5 )
9294 with open (self .temp_xbar_file_name , 'r' ) as f :
9395 read = csv .reader (f )
9496 rows = list (read )
@@ -97,15 +99,15 @@ def test_xbarwriter(self):
9799 os .remove (self .temp_xbar_file_name )
98100
99101 def test_wreader (self ):
100- self .ph_object = self ._create_ph_farmer (ph_extensions = WXBarReader , max_iter = 1 )
102+ self .ph_object = self ._create_ph_farmer (ph_extensions = wxbarreader . WXBarReader , max_iter = 1 )
101103 for sname , scenario in self .ph_object .local_scenarios .items ():
102104 if sname == 'scen0' :
103105 self .assertAlmostEqual (scenario ._mpisppy_model .W [("ROOT" , 1 )]._value , 70.84705093609978 )
104106 if sname == 'scen1' :
105107 self .assertAlmostEqual (scenario ._mpisppy_model .W [("ROOT" , 0 )]._value , - 41.104251445950844 )
106108
107109 def test_xbarreader (self ):
108- self .ph_object = self ._create_ph_farmer (ph_extensions = WXBarReader , max_iter = 1 )
110+ self .ph_object = self ._create_ph_farmer (ph_extensions = wxbarreader . WXBarReader , max_iter = 1 )
109111 for sname , scenario in self .ph_object .local_scenarios .items ():
110112 if sname == 'scen0' :
111113 self .assertAlmostEqual (scenario ._mpisppy_model .xbars [("ROOT" , 1 )]._value , 274.2239371483933 )
0 commit comments