@@ -53,31 +53,37 @@ void AllGather::initdone(int num) {
5353}
5454
5555void AllGather::init (long int * result, long int * data, CkCallback cb) {
56- this ->cb = cb;
56+ this ->lib_done_callback = cb;
57+ zero_copy_callback = CkCallback (CkIndex_AllGather::local_buff_done (NULL ), CkArrayIndex1D (thisIndex), thisProxy);
58+ dum_dum = CkCallback (CkCallback::ignore);
5759 this ->store = result;
5860 this ->data = data;
5961 int cnt = 1 ;
6062 CkCallback cbinitdone (CkReductionTarget (AllGather, initdone), thisProxy (0 ));
6163 contribute (sizeof (int ), &cnt, CkReduction::sum_int, cbinitdone);
6264}
6365
66+ void AllGather::local_buff_done (CkDataMsg *m) {
67+ numRecvMsg++;
68+ if (numRecvMsg == n - 1 ) {
69+ lib_done_callback.send (msg);
70+ }
71+ }
72+
6473void AllGather::startGather () {
6574 switch (type) {
6675 case allGatherType::ALL_GATHER_DEFAULT: {
6776 for (int i = 0 ; i < k; i++) {
6877 store[k * thisIndex + i] = data[i];
6978 }
70- numDefaultMsg++ ;
79+ CkNcpyBuffer src (data, k* sizeof ( long int ), dum_dum, CK_BUFFER_UNREG) ;
7180#ifdef TIMESTAMP
7281 thisProxy[(thisIndex + 1 ) % n].recvDefault (
73- thisIndex, data, k , (timeStamp + alpha + beta * k * 8 ));
82+ thisIndex, src , (timeStamp + alpha + beta * k * 8 ));
7483 timeStamp += alpha;
7584#else
76- thisProxy[(thisIndex + 1 ) % n].recvDefault (thisIndex, data, k , 0.0 );
85+ thisProxy[(thisIndex + 1 ) % n].recvDefault (thisIndex, src , 0.0 );
7786#endif
78- if (numDefaultMsg == n) {
79- cb.send (msg);
80- }
8187 } break ;
8288 case allGatherType::ALL_GATHER_HYPERCUBE: {
8389 hyperCubeIndx.push_back (thisIndex);
@@ -90,74 +96,60 @@ void AllGather::startGather() {
9096 for (int i = 0 ; i < k; i++) {
9197 store[k * thisIndex + i] = data[i];
9298 }
93- numAccFloodMsg++;
9499 recvFloodMsg[thisIndex] = true ;
100+ CkNcpyBuffer src (data, k*sizeof (long int ), dum_dum, CK_BUFFER_UNREG);
95101 for (int i = 0 ; i < n; i++) {
96102 if (graph[thisIndex][i] == 1 ) {
97103#ifdef TIMESTAMP
98- thisProxy (i).Flood (thisIndex, data, k ,
104+ thisProxy (i).Flood (thisIndex, src ,
99105 (timeStamp + alpha + beta * k * 8 ));
100106 timeStamp += alpha;
101107#else
102- thisProxy (i).Flood (thisIndex, data, k , 0.0 );
108+ thisProxy (i).Flood (thisIndex, src , 0.0 );
103109#endif
104110 }
105111 }
106- if (numAccFloodMsg == n) {
107- cb.send (msg);
108- }
109112 } break ;
110113 }
111114}
112115
113- void AllGather::recvDefault (int sender, long int data[], int _,
114- double recvTime) {
115- numDefaultMsg++;
116- for (int i = 0 ; i < k; i++) {
117- store[k * sender + i] = data[i];
118- }
116+ void AllGather::recvDefault (int sender, CkNcpyBuffer src, double recvTime) {
117+ CkNcpyBuffer dst (store + sender * k, k * sizeof (long int ), zero_copy_callback, CK_BUFFER_UNREG);
118+ dst.get (src);
119119#ifdef TIMESTAMP
120120 timeStamp = std::max (recvTime, timeStamp);
121121#endif
122122 if (((thisIndex + 1 ) % n) != sender) {
123123#ifdef TIMESTAMP
124124 thisProxy[(thisIndex + 1 ) % n].recvDefault (
125- sender, data, k , (timeStamp + alpha + beta * k * 8 ));
125+ sender, src , (timeStamp + alpha + beta * k * 8 ));
126126 timeStamp += alpha;
127127#else
128- thisProxy[(thisIndex + 1 ) % n].recvDefault (sender, data, k , 0.0 );
128+ thisProxy[(thisIndex + 1 ) % n].recvDefault (sender, src , 0.0 );
129129#endif
130130 }
131- if (numDefaultMsg == n) {
132- cb.send (msg);
133- }
134131}
135132
136- void AllGather::Flood (int sender, long int data[], int _ , double recvTime) {
133+ void AllGather::Flood (int sender, CkNcpyBuffer src , double recvTime) {
137134 if (recvFloodMsg[sender]) {
138135 return ;
139136 }
140- numAccFloodMsg++;
141137 recvFloodMsg[sender] = true ;
142- for (int i = 0 ; i < k; i++) {
143- store[k * sender + i] = data[i];
144- }
138+ CkNcpyBuffer dst (store + sender * k, k * sizeof (long int ), zero_copy_callback, CK_BUFFER_UNREG);
139+ dst.get (src);
145140#ifdef TIMESTAMP
146141 timeStamp = std::max (recvTime, timeStamp);
147142#endif
148143 for (int i = 0 ; i < n; i++) {
149144 if (graph[thisIndex][i] == 1 and i != sender) {
150145#ifdef TIMESTAMP
151- thisProxy (i).Flood (sender, data, k , (timeStamp + alpha + beta * k * 8 ));
146+ thisProxy (i).Flood (sender, src , (timeStamp + alpha + beta * k * 8 ));
152147 timeStamp += alpha;
153148#else
154- thisProxy (i).Flood (sender, data, k , 0.0 );
149+ thisProxy (i).Flood (sender, src , 0.0 );
155150#endif
156151 }
157152 }
158- if (numAccFloodMsg == n) {
159- cb.send (msg);
160- }
161153}
162154
163155#include " allGather.def.h"
0 commit comments