16
16
from mpisppy import MPI
17
17
from mpisppy .cylinders .spcommunicator import RecvArray , SendArray , SPCommunicator
18
18
from math import inf
19
- from mpisppy .cylinders .spoke import ConvergerSpokeType
20
19
21
20
from mpisppy import global_toc
22
21
@@ -51,6 +50,8 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
51
50
52
51
self .extension_recv = set ()
53
52
53
+ self .initialize_bound_values ()
54
+
54
55
return
55
56
56
57
@abc .abstractmethod
@@ -233,14 +234,12 @@ def receive_innerbounds(self):
233
234
(but should be harmless to call if there are none)
234
235
"""
235
236
logging .debug ("Hub is trying to receive from InnerBounds" )
236
- for idx in self .innerbound_spoke_indices :
237
- key = self ._make_key (Field .OBJECTIVE_INNER_BOUND , idx )
238
- recv_buf = self .receive_buffers [key ]
237
+ for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_INNER_BOUND ]:
239
238
is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_INNER_BOUND )
240
239
if is_new :
241
240
bound = recv_buf [0 ]
242
241
logging .debug ("!! new InnerBound to opt {}" .format (bound ))
243
- self .BestInnerBound = self .InnerBoundUpdate (bound , idx )
242
+ self .BestInnerBound = self .InnerBoundUpdate (bound , cls , idx )
244
243
logging .debug ("ph back from InnerBounds" )
245
244
246
245
def receive_outerbounds (self ):
@@ -249,37 +248,35 @@ def receive_outerbounds(self):
249
248
(but should be harmless to call if there are none)
250
249
"""
251
250
logging .debug ("Hub is trying to receive from OuterBounds" )
252
- for idx in self .outerbound_spoke_indices :
253
- key = self ._make_key (Field .OBJECTIVE_OUTER_BOUND , idx )
254
- recv_buf = self .receive_buffers [key ]
251
+ for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_OUTER_BOUND ]:
255
252
is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_OUTER_BOUND )
256
253
if is_new :
257
254
bound = recv_buf [0 ]
258
255
logging .debug ("!! new OuterBound to opt {}" .format (bound ))
259
- self .BestOuterBound = self .OuterBoundUpdate (bound , idx )
256
+ self .BestOuterBound = self .OuterBoundUpdate (bound , cls , idx )
260
257
logging .debug ("ph back from OuterBounds" )
261
258
262
- def OuterBoundUpdate (self , new_bound , idx = None , char = '*' ):
259
+ def OuterBoundUpdate (self , new_bound , cls = None , idx = None , char = '*' ):
263
260
current_bound = self .BestOuterBound
264
261
if self ._outer_bound_update (new_bound , current_bound ):
265
- if idx is None :
262
+ if cls is None :
266
263
self .latest_ob_char = char
267
264
self .last_ob_idx = 0
268
265
else :
269
- self .latest_ob_char = self . outerbound_spoke_chars [ idx ]
266
+ self .latest_ib_char = cls . converger_spoke_char
270
267
self .last_ob_idx = idx
271
268
return new_bound
272
269
else :
273
270
return current_bound
274
271
275
- def InnerBoundUpdate (self , new_bound , idx = None , char = '*' ):
272
+ def InnerBoundUpdate (self , new_bound , cls = None , idx = None , char = '*' ):
276
273
current_bound = self .BestInnerBound
277
274
if self ._inner_bound_update (new_bound , current_bound ):
278
- if idx is None :
275
+ if cls is None :
279
276
self .latest_ib_char = char
280
277
self .last_ib_idx = 0
281
278
else :
282
- self .latest_ib_char = self . innerbound_spoke_chars [ idx ]
279
+ self .latest_ib_char = cls . converger_spoke_char
283
280
self .last_ib_idx = idx
284
281
return new_bound
285
282
else :
@@ -297,28 +294,6 @@ def initialize_bound_values(self):
297
294
self ._inner_bound_update = lambda new , old : (new > old )
298
295
self ._outer_bound_update = lambda new , old : (new < old )
299
296
300
- def initialize_outer_bound_buffers (self ):
301
- """ Initialize outer bound receive buffers
302
- """
303
- self .outerbound_receive_buffers = dict ()
304
- for idx in self .outerbound_spoke_indices :
305
- self .outerbound_receive_buffers [idx ] = self .register_recv_field (
306
- Field .OBJECTIVE_OUTER_BOUND , idx , 1 ,
307
- )
308
- ## End for
309
- return
310
-
311
- def initialize_inner_bound_buffers (self ):
312
- """ Initialize inner bound receive buffers
313
- """
314
- self .innerbound_receive_buffers = dict ()
315
- for idx in self .innerbound_spoke_indices :
316
- self .innerbound_receive_buffers [idx ] = self .register_recv_field (
317
- Field .OBJECTIVE_INNER_BOUND , idx , 1
318
- )
319
- ## End for
320
- return
321
-
322
297
def _populate_boundsout_cache (self , buf ):
323
298
""" Populate a given buffer with the current bounds
324
299
"""
@@ -327,62 +302,26 @@ def _populate_boundsout_cache(self, buf):
327
302
328
303
def send_boundsout (self ):
329
304
""" Send bounds to the appropriate spokes
330
- This is called only for spokes which are bounds only.
331
- w and nonant spokes are passed bounds through the w and nonant buffers
332
305
"""
333
306
my_bounds = self .send_buffers [Field .BEST_OBJECTIVE_BOUNDS ]
334
307
self ._populate_boundsout_cache (my_bounds .array ())
335
308
logging .debug ("hub is sending bounds={}" .format (my_bounds ))
336
309
self .hub_to_spoke (my_bounds , Field .BEST_OBJECTIVE_BOUNDS )
337
310
return
338
311
339
- def initialize_spoke_indices (self ):
312
+ def register_receive_fields (self ):
340
313
""" Figure out what types of spokes we have,
341
314
and sort them into the appropriate classes.
342
315
343
316
Note:
344
317
Some spokes may be multiple types (e.g. outerbound and nonant),
345
318
though not all combinations are supported.
346
319
"""
347
- self .outerbound_spoke_indices = set ()
348
- self .innerbound_spoke_indices = set ()
349
- self .nonant_spoke_indices = set ()
350
- self .w_spoke_indices = set ()
351
-
352
- self .outerbound_spoke_chars = dict ()
353
- self .innerbound_spoke_chars = dict ()
354
-
355
- for (i , spoke ) in enumerate (self .communicators ):
356
- if i == self .strata_rank :
357
- continue
358
- spoke_class = spoke ["spcomm_class" ]
359
- if hasattr (spoke_class , "converger_spoke_types" ):
360
- for cst in spoke_class .converger_spoke_types :
361
- if cst == ConvergerSpokeType .OUTER_BOUND :
362
- self .outerbound_spoke_indices .add (i )
363
- self .outerbound_spoke_chars [i ] = spoke_class .converger_spoke_char
364
- elif cst == ConvergerSpokeType .INNER_BOUND :
365
- self .innerbound_spoke_indices .add (i )
366
- self .innerbound_spoke_chars [i ] = spoke_class .converger_spoke_char
367
- elif cst == ConvergerSpokeType .W_GETTER :
368
- self .w_spoke_indices .add (i )
369
- elif cst == ConvergerSpokeType .NONANT_GETTER :
370
- self .nonant_spoke_indices .add (i )
371
- else :
372
- raise RuntimeError (f"Unrecognized converger_spoke_type { cst } " )
373
-
374
- else : ##this isn't necessarily wrong, i.e., cut generators
375
- logger .debug (f"Spoke class { spoke_class } not recognized by hub" )
376
-
377
- # all _BoundSpoke spokes get hub bounds so we determine which spokes
378
- # are "bounds only"
379
- self .bounds_only_indices = \
380
- (self .outerbound_spoke_indices | self .innerbound_spoke_indices ) - \
381
- (self .w_spoke_indices | self .nonant_spoke_indices )
320
+ super ().register_receive_fields ()
382
321
383
322
# Not all opt classes may have extensions
384
323
if getattr (self .opt , "extensions" , None ) is not None :
385
- self .opt .extobject .initialize_spoke_indices ()
324
+ self .opt .extobject .register_receive_fields ()
386
325
387
326
return
388
327
@@ -511,31 +450,14 @@ def setup_hub(self):
511
450
"Cannot call setup_hub before memory windows are constructed"
512
451
)
513
452
514
- self .initialize_spoke_indices ()
515
- self .initialize_bound_values ()
516
-
517
- self .initialize_outer_bound_buffers ()
518
- self .initialize_inner_bound_buffers ()
519
-
520
- ## Do some checking for things we currently don't support
521
- if len (self .outerbound_spoke_indices & self .innerbound_spoke_indices ) > 0 :
522
- raise RuntimeError (
523
- "A Spoke providing both inner and outer "
524
- "bounds is currently unsupported"
525
- )
526
- if len (self .w_spoke_indices & self .nonant_spoke_indices ) > 0 :
527
- raise RuntimeError (
528
- "A Spoke needing both Ws and nonants is currently unsupported"
529
- )
530
-
531
453
## Generate some warnings if nothing is giving bounds
532
- if not self .outerbound_spoke_indices :
454
+ if not self .receive_field_spcomms [ Field . OBJECTIVE_OUTER_BOUND ] :
533
455
logger .warn (
534
456
"No OuterBound Spokes defined, this converger "
535
457
"will not cause the hub to terminate"
536
458
)
537
459
538
- if not self .innerbound_spoke_indices :
460
+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
539
461
logger .warn (
540
462
"No InnerBound Spokes defined, this converger "
541
463
"will not cause the hub to terminate"
@@ -578,7 +500,7 @@ def is_converged(self):
578
500
if self .opt .best_bound_obj_val is not None :
579
501
self .BestOuterBound = self .OuterBoundUpdate (self .opt .best_bound_obj_val )
580
502
581
- if not self .innerbound_spoke_indices :
503
+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
582
504
if self .opt ._PHIter == 1 :
583
505
logger .warning (
584
506
"PHHub cannot compute convergence without "
@@ -591,7 +513,7 @@ def is_converged(self):
591
513
592
514
return False
593
515
594
- if not self .outerbound_spoke_indices :
516
+ if not self .receive_field_spcomms [ Field . OBJECTIVE_OUTER_BOUND ] :
595
517
if self .opt ._PHIter == 1 and not self ._hub_algo_best_bound_provider :
596
518
global_toc (
597
519
"Without outer bound spokes, no progress "
@@ -660,24 +582,8 @@ def setup_hub(self):
660
582
"Cannot call setup_hub before memory windows are constructed"
661
583
)
662
584
663
- self .initialize_spoke_indices ()
664
- self .initialize_bound_values ()
665
-
666
- self .initialize_outer_bound_buffers ()
667
- self .initialize_inner_bound_buffers ()
668
-
669
- ## Do some checking for things we currently
670
- ## do not support
671
- if self .w_spoke_indices :
672
- raise RuntimeError ("LShaped hub does not compute dual weights (Ws)" )
673
- if len (self .outerbound_spoke_indices & self .innerbound_spoke_indices ) > 0 :
674
- raise RuntimeError (
675
- "A Spoke providing both inner and outer "
676
- "bounds is currently unsupported"
677
- )
678
-
679
585
## Generate some warnings if nothing is giving bounds
680
- if not self .innerbound_spoke_indices :
586
+ if not self .receive_field_spcomms [ Field . OBJECTIVE_INNER_BOUND ] :
681
587
logger .warn (
682
588
"No InnerBound Spokes defined, this converger "
683
589
"will not cause the hub to terminate"
0 commit comments