7
7
# full copyright and license information.
8
8
###############################################################################
9
9
10
- import numpy as np
11
10
import abc
12
11
import time
13
12
import os
14
13
import math
15
14
16
- from mpisppy import MPI
17
- from mpisppy .cylinders .spcommunicator import RecvArray , SPCommunicator
15
+ from mpisppy .cylinders .spcommunicator import SPCommunicator
18
16
from mpisppy .cylinders .spwindow import Field
19
17
20
18
@@ -31,48 +29,6 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
31
29
32
30
return
33
31
34
- def spoke_from_hub (self ,
35
- buf : RecvArray ,
36
- field : Field ,
37
- ):
38
- buf ._is_new = self ._spoke_from_hub (buf .array (), field , buf .id ())
39
- if buf .is_new ():
40
- buf ._pull_id ()
41
- return buf .is_new ()
42
-
43
- def _spoke_from_hub (self ,
44
- values : np .typing .NDArray ,
45
- field : Field ,
46
- last_write_id : int
47
- ):
48
- """
49
- """
50
-
51
- self .cylinder_comm .Barrier ()
52
- self .window .get (values , 0 , field )
53
-
54
- # On rare occasions a NaN is seen...
55
- new_id = int (values [- 1 ]) if not math .isnan (values [- 1 ]) else 0
56
- local_val = np .array ((new_id ,- new_id ), 'i' )
57
- max_min_ids = np .zeros (2 , 'i' )
58
- self .cylinder_comm .Allreduce ((local_val , MPI .INT ),
59
- (max_min_ids , MPI .INT ),
60
- op = MPI .MAX )
61
-
62
- max_id = max_min_ids [0 ]
63
- min_id = - max_min_ids [1 ]
64
- # NOTE: we only proceed if all the ranks agree
65
- # on the ID
66
- if max_id != min_id :
67
- return False
68
-
69
- assert max_id == min_id == new_id
70
-
71
- if new_id > last_write_id or new_id < 0 :
72
- return True
73
-
74
- return False
75
-
76
32
def _got_kill_signal (self ):
77
33
shutdown_buf = self .receive_buffers [self ._make_key (Field .SHUTDOWN , 0 )]
78
34
if shutdown_buf .is_new ():
@@ -103,8 +59,7 @@ def update_receive_buffers(self):
103
59
for (key , recv_buf ) in self .receive_buffers .items ():
104
60
field , rank = self ._split_key (key )
105
61
# The below code will need to be updated for spoke to spoke communication
106
- assert (rank == 0 )
107
- self .spoke_from_hub (recv_buf , field )
62
+ self .get_receive_buffer (recv_buf , field , rank )
108
63
## End for
109
64
return
110
65
0 commit comments