7
7
# full copyright and license information.
8
8
###############################################################################
9
9
10
- import numpy as np
11
10
import abc
12
11
import logging
13
12
import mpisppy .log
14
- from mpisppy .opt .aph import APH
15
13
16
- from mpisppy import MPI
17
14
from mpisppy .cylinders .spcommunicator import RecvArray , SendArray , SPCommunicator
18
15
from math import inf
19
16
@@ -83,7 +80,6 @@ def current_iteration(self):
83
80
def main (self ):
84
81
pass
85
82
86
-
87
83
def register_extension_recv_field (self , field : Field , strata_rank : int , buf_len : int = - 1 ) -> RecvArray :
88
84
"""
89
85
Register an extensions interest in the given field from the given spoke. The hub
@@ -125,15 +121,14 @@ def sync_extension_fields(self):
125
121
for key in self .extension_recv :
126
122
ext_buf = self .receive_buffers [key ]
127
123
(field , srank ) = self ._split_key (key )
128
- ext_buf ._is_new = self .hub_from_spoke (ext_buf , srank , field )
124
+ ext_buf ._is_new = self .get_receive_buffer (ext_buf , field , srank )
129
125
## End for
130
126
return
131
127
132
128
def clear_latest_chars (self ):
133
129
self .latest_ib_char = None
134
130
self .latest_ob_char = None
135
131
136
-
137
132
def compute_gaps (self ):
138
133
""" Compute the current absolute and relative gaps,
139
134
using the current self.BestInnerBound and self.BestOuterBound
@@ -157,7 +152,6 @@ def compute_gaps(self):
157
152
rel_gap = float ("inf" )
158
153
return abs_gap , rel_gap
159
154
160
-
161
155
def get_update_string (self ):
162
156
if self .latest_ib_char is None and \
163
157
self .latest_ob_char is None :
@@ -236,7 +230,7 @@ def receive_innerbounds(self):
236
230
"""
237
231
logging .debug ("Hub is trying to receive from InnerBounds" )
238
232
for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_INNER_BOUND ]:
239
- is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_INNER_BOUND )
233
+ is_new = self .get_receive_buffer (recv_buf , Field .OBJECTIVE_INNER_BOUND , idx )
240
234
if is_new :
241
235
bound = recv_buf [0 ]
242
236
logging .debug ("!! new InnerBound to opt {}" .format (bound ))
@@ -250,7 +244,7 @@ def receive_outerbounds(self):
250
244
"""
251
245
logging .debug ("Hub is trying to receive from OuterBounds" )
252
246
for idx , cls , recv_buf in self .receive_field_spcomms [Field .OBJECTIVE_OUTER_BOUND ]:
253
- is_new = self .hub_from_spoke (recv_buf , idx , Field .OBJECTIVE_OUTER_BOUND )
247
+ is_new = self .get_receive_buffer (recv_buf , Field .OBJECTIVE_OUTER_BOUND , idx )
254
248
if is_new :
255
249
bound = recv_buf [0 ]
256
250
logging .debug ("!! new OuterBound to opt {}" .format (bound ))
@@ -264,7 +258,7 @@ def OuterBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
264
258
self .latest_ob_char = char
265
259
self .last_ob_idx = 0
266
260
else :
267
- self .latest_ib_char = cls .converger_spoke_char
261
+ self .latest_ob_char = cls .converger_spoke_char
268
262
self .last_ob_idx = idx
269
263
return new_bound
270
264
else :
@@ -326,7 +320,6 @@ def register_receive_fields(self):
326
320
327
321
return
328
322
329
-
330
323
def register_send_fields (self ):
331
324
super ().register_send_fields ()
332
325
@@ -336,63 +329,6 @@ def register_send_fields(self):
336
329
337
330
return
338
331
339
- def hub_from_spoke (self ,
340
- buf : RecvArray ,
341
- spoke_num : int ,
342
- field : Field ,
343
- ):
344
- """ spoke_num is the rank in the strata_comm, so it is 1-based not 0-based
345
-
346
- Returns:
347
- is_new (bool): Indicates whether the "gotten" values are new,
348
- based on the write_id.
349
- """
350
- buf ._is_new = self ._hub_from_spoke (buf .array (), spoke_num , field , buf .id ())
351
- if buf .is_new ():
352
- buf ._pull_id ()
353
- return buf .is_new ()
354
-
355
- def _hub_from_spoke (self ,
356
- values : np .typing .NDArray ,
357
- spoke_num : int ,
358
- field : Field ,
359
- last_write_id : int ,
360
- ):
361
- """ spoke_num is the rank in the strata_comm, so it is 1-based not 0-based
362
-
363
- Returns:
364
- is_new (bool): Indicates whether the "gotten" values are new,
365
- based on the write_id.
366
- """
367
- # so the window in each rank gets read at approximately the same time,
368
- # and so has the same write_id
369
- if not isinstance (self .opt , APH ):
370
- self .cylinder_comm .Barrier ()
371
- ## End if
372
- self .window .get (values , spoke_num , field )
373
-
374
- if isinstance (self .opt , APH ):
375
- # # reverting part of changes from Ben getting rid of spoke sleep DLW jan 2023
376
- if values [- 1 ] > last_write_id :
377
- return True
378
- else :
379
- new_id = int (values [- 1 ])
380
- local_val = np .array ((new_id ,), 'i' )
381
- sum_ids = np .zeros (1 , 'i' )
382
- self .cylinder_comm .Allreduce ((local_val , MPI .INT ),
383
- (sum_ids , MPI .INT ),
384
- op = MPI .SUM )
385
- if new_id != sum_ids [0 ] / self .cylinder_comm .size :
386
- return False
387
- ## End if
388
- if new_id > last_write_id or new_id < 0 :
389
- return True
390
- ## End if
391
- ## End if
392
-
393
- return False
394
-
395
-
396
332
def send_terminate (self ):
397
333
""" Send an array of zeros with a -1 appended to the
398
334
end to indicate termination. This function puts to the local
@@ -614,6 +550,15 @@ def main(self):
614
550
logger .critical ("aph debug main in hub.py" )
615
551
self .opt .APH_main (spcomm = self , finalize = False )
616
552
553
+ # overwrite the default behavior of this method for APH
554
+ def get_receive_buffer (self ,
555
+ buf : RecvArray ,
556
+ field : Field ,
557
+ origin : int = - 1 ,
558
+ synchronize : bool = False ,
559
+ ):
560
+ return super ().get_receive_buffer (buf , field , origin , synchronize )
561
+
617
562
def finalize (self ):
618
563
""" does PH.post_loops, returns Eobj """
619
564
# NOTE: APH_main does NOT pass in extensions
0 commit comments