30
30
31
31
class Hub (SPCommunicator ):
32
32
33
+ send_fields = (* SPCommunicator .send_fields , Field .SHUTDOWN , Field .BEST_OBJECTIVE_BOUNDS ,)
34
+ receive_fields = (* SPCommunicator .receive_fields , )
35
+ optional_receive_fields = (* SPCommunicator .optional_receive_fields , Field .OBJECTIVE_INNER_BOUND , Field .OBJECTIVE_OUTER_BOUND , )
36
+
33
37
_hub_algo_best_bound_provider = False
34
38
35
39
def __init__ (self , spbase_object , fullcomm , strata_comm , cylinder_comm , communicators , options = None ):
@@ -85,7 +89,7 @@ def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len:
85
89
to the extension sync_with_spokes function.
86
90
"""
87
91
key = self ._make_key (field , strata_rank )
88
- if key not in self ._locals :
92
+ if key not in self .receive_buffers :
89
93
# if it is not already registered, we need to update the local buffer
90
94
self .extension_recv .add (key )
91
95
## End if
@@ -103,7 +107,7 @@ def register_extension_send_field(self, field: Field, buf_len: int) -> SendArray
103
107
return self .register_send_field (field , buf_len )
104
108
105
109
def is_send_field_registered (self , field : Field ) -> bool :
106
- return field in self ._sends
110
+ return field in self .send_buffers
107
111
108
112
def extension_send_field (self , field : Field , buf : SendArray ):
109
113
"""
@@ -117,7 +121,7 @@ def sync_extension_fields(self):
117
121
Update all registered extension fields. Safe to call even when there are no extension fields.
118
122
"""
119
123
for key in self .extension_recv :
120
- ext_buf = self ._locals [key ]
124
+ ext_buf = self .receive_buffers [key ]
121
125
(field , srank ) = self ._split_key (key )
122
126
ext_buf ._is_new = self .hub_from_spoke (ext_buf , srank , field )
123
127
## End for
@@ -233,7 +237,7 @@ def receive_innerbounds(self):
233
237
logging .debug ("Hub is trying to receive from InnerBounds" )
234
238
for idx in self .innerbound_spoke_indices :
235
239
key = self ._make_key (Field .OBJECTIVE_INNER_BOUND , idx )
236
- recv_buf = self ._locals [key ]
240
+ recv_buf = self .receive_buffers [key ]
237
241
is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_INNER_BOUND )
238
242
if is_new :
239
243
bound = recv_buf [0 ]
@@ -249,7 +253,7 @@ def receive_outerbounds(self):
249
253
logging .debug ("Hub is trying to receive from OuterBounds" )
250
254
for idx in self .outerbound_spoke_indices :
251
255
key = self ._make_key (Field .OBJECTIVE_OUTER_BOUND , idx )
252
- recv_buf = self ._locals [key ]
256
+ recv_buf = self .receive_buffers [key ]
253
257
is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_OUTER_BOUND )
254
258
if is_new :
255
259
bound = recv_buf [0 ]
@@ -320,18 +324,18 @@ def initialize_inner_bound_buffers(self):
320
324
def _populate_boundsout_cache (self , buf ):
321
325
""" Populate a given buffer with the current bounds
322
326
"""
323
- buf [- 3 ] = self .BestOuterBound
324
- buf [- 2 ] = self .BestInnerBound
327
+ buf [0 ] = self .BestOuterBound
328
+ buf [1 ] = self .BestInnerBound
325
329
326
330
def send_boundsout (self ):
327
331
""" Send bounds to the appropriate spokes
328
332
This is called only for spokes which are bounds only.
329
333
w and nonant spokes are passed bounds through the w and nonant buffers
330
334
"""
331
- my_bounds = self .boundsout_send_buffer
335
+ my_bounds = self .send_buffers [ Field . BEST_OBJECTIVE_BOUNDS ]
332
336
self ._populate_boundsout_cache (my_bounds .array ())
333
337
logging .debug ("hub is sending bounds={}" .format (my_bounds ))
334
- self .hub_to_spoke (my_bounds , Field .OBJECTIVE_BOUNDS )
338
+ self .hub_to_spoke (my_bounds , Field .BEST_OBJECTIVE_BOUNDS )
335
339
return
336
340
337
341
def initialize_spoke_indices (self ):
@@ -392,45 +396,7 @@ def initialize_spoke_indices(self):
392
396
393
397
394
398
def register_send_fields (self ):
395
-
396
- self .shutdown = self .register_send_field (Field .SHUTDOWN , 1 )
397
-
398
- required_fields = set ()
399
- for i , spoke in enumerate (self .communicators ):
400
- if i == self .strata_rank :
401
- continue
402
- spoke_class = spoke ["spcomm_class" ]
403
- if hasattr (spoke_class , "converger_spoke_types" ):
404
- for cst in spoke_class .converger_spoke_types :
405
- if cst == ConvergerSpokeType .W_GETTER :
406
- required_fields .add (Field .DUALS )
407
- elif cst == ConvergerSpokeType .NONANT_GETTER :
408
- required_fields .add (Field .NONANT )
409
- elif cst == ConvergerSpokeType .INNER_BOUND or cst == ConvergerSpokeType .OUTER_BOUND :
410
- required_fields .add (Field .OBJECTIVE_BOUNDS )
411
- else :
412
- pass # Intentional no-op
413
- ## End if
414
- ## End for
415
- else :
416
- # Intentional no-op. Non-converger spokes need to register any needed
417
- # fields separately. See the functions `register_extension_recv_field`
418
- # and `register_extension_send_field`.
419
- pass
420
- ## End if
421
- ## End for
422
-
423
- n_nonants = 0
424
- for s in self .opt .local_scenarios .values ():
425
- n_nonants += len (s ._mpisppy_data .nonant_indices )
426
- ## End for
427
-
428
- if Field .DUALS in required_fields :
429
- self .w_send_buffer = self .register_send_field (Field .DUALS , n_nonants )
430
- if Field .NONANT in required_fields :
431
- self .nonant_send_buffer = self .register_send_field (Field .NONANT , n_nonants )
432
- if Field .OBJECTIVE_BOUNDS in required_fields :
433
- self .boundsout_send_buffer = self .register_send_field (Field .OBJECTIVE_BOUNDS , 2 )
399
+ super ().register_send_fields ()
434
400
435
401
# Not all opt classes may have extensions
436
402
if getattr (self .opt , "extensions" , None ) is not None :
@@ -439,7 +405,6 @@ def register_send_fields(self):
439
405
return
440
406
441
407
442
-
443
408
def hub_to_spoke (self , buf : SendArray , field : Field ):
444
409
""" Put the specified values into the specified locally-owned buffer
445
410
for the spoke to pick up.
@@ -534,13 +499,17 @@ def send_terminate(self):
534
499
buffer, so every spoke will see it simultaneously.
535
500
processes (don't need to call them one at a time).
536
501
"""
537
- shutdown = self .shutdown
538
- shutdown [0 ] = 1.0
539
- self .hub_to_spoke (shutdown , Field .SHUTDOWN )
502
+ self .send_buffers [Field .SHUTDOWN ][0 ] = 1.0
503
+ self .hub_to_spoke (self .send_buffers [Field .SHUTDOWN ], Field .SHUTDOWN )
540
504
return
541
505
542
506
543
507
class PHHub (Hub ):
508
+
509
+ send_fields = (* Hub .send_fields , Field .NONANT , Field .DUALS )
510
+ receive_fields = (* Hub .receive_fields ,)
511
+ optional_receive_fields = (* Hub .optional_receive_fields ,)
512
+
544
513
def setup_hub (self ):
545
514
""" Must be called after make_windows(), so that
546
515
the hub knows the sizes of all the spokes windows
@@ -673,8 +642,7 @@ def send_nonants(self):
673
642
"""
674
643
self .opt ._save_nonants ()
675
644
ci = 0 ## index to self.nonant_send_buffer
676
- # my_nonants = self._sends[Field.NONANT]
677
- nonant_send_buffer = self .nonant_send_buffer
645
+ nonant_send_buffer = self .send_buffers [Field .NONANT ]
678
646
for k , s in self .opt .local_scenarios .items ():
679
647
for xvar in s ._mpisppy_data .nonant_indices .values ():
680
648
nonant_send_buffer [ci ] = xvar ._value
@@ -690,7 +658,7 @@ def send_ws(self):
690
658
""" Send dual weights to the appropriate spokes
691
659
"""
692
660
# NOTE: my_ws.array() and self.w_send_buffer should be the same array.
693
- my_ws = self ._sends [Field .DUALS ]
661
+ my_ws = self .send_buffers [Field .DUALS ]
694
662
self .opt ._populate_W_cache (my_ws .array (), padding = 1 )
695
663
logging .debug ("hub is sending Ws={}" .format (my_ws .array ()))
696
664
@@ -701,6 +669,10 @@ def send_ws(self):
701
669
702
670
class LShapedHub (Hub ):
703
671
672
+ send_fields = (* Hub .send_fields , Field .NONANT ,)
673
+ receive_fields = (* Hub .receive_fields ,)
674
+ optional_receive_fields = (* Hub .optional_receive_fields ,)
675
+
704
676
def setup_hub (self ):
705
677
""" Must be called after make_windows(), so that
706
678
the hub knows the sizes of all the spokes windows
@@ -781,7 +753,7 @@ def send_nonants(self):
781
753
TODO: Will likely fail with bundling
782
754
"""
783
755
ci = 0 ## index to self.nonant_send_buffer
784
- nonant_send_buffer = self .nonant_send_buffer
756
+ nonant_send_buffer = self .send_buffers [ Field . NONANT ]
785
757
for k , s in self .opt .local_scenarios .items ():
786
758
nonant_to_root_var_map = s ._mpisppy_model .subproblem_to_root_vars_map
787
759
for xvar in s ._mpisppy_data .nonant_indices .values ():
@@ -797,6 +769,8 @@ def send_nonants(self):
797
769
798
770
class SubgradientHub (PHHub ):
799
771
772
+ # send / receive fields are same as PHHub
773
+
800
774
_hub_algo_best_bound_provider = True
801
775
802
776
def main (self ):
@@ -806,6 +780,8 @@ def main(self):
806
780
807
781
class APHHub (PHHub ):
808
782
783
+ # send / receive fields are same as PHHub
784
+
809
785
def main (self ):
810
786
""" SPComm gets attached by self.__init___; holding APH harmless """
811
787
logger .critical ("aph debug main in hub.py" )
0 commit comments