Skip to content

Commit f006bd8

Browse files
authored
[Scheduler] omp parallelize LB (#1723)
1 parent ec46a45 commit f006bd8

2 files changed

Lines changed: 48 additions & 42 deletions

File tree

src/shamrock/include/shamrock/scheduler/loadbalance/LoadBalanceStrategy.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ namespace shamrock::scheduler::details {
4141
u64 index;
4242
i32 new_owner;
4343

44+
LoadBalancedTile() = default;
45+
4446
LoadBalancedTile(TileWithLoad<Torder, Tweight> in, u64 inindex)
4547
: ordering_val(in.ordering_val), load_value(in.load_value), index(inindex) {}
4648
};
@@ -76,9 +78,10 @@ namespace shamrock::scheduler::details {
7678
using LBTile = TileWithLoad<Torder, Tweight>;
7779
using LBTileResult = details::LoadBalancedTile<Torder, Tweight>;
7880

79-
std::vector<LBTileResult> res;
81+
std::vector<LBTileResult> res(lb_vector.size());
82+
#pragma omp parallel for
8083
for (u64 i = 0; i < lb_vector.size(); i++) {
81-
res.push_back(LBTileResult{lb_vector[i], i});
84+
res[i] = LBTileResult{lb_vector[i], i};
8285
}
8386

8487
// apply the ordering
@@ -94,15 +97,18 @@ namespace shamrock::scheduler::details {
9497

9598
double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;
9699

97-
for (LBTileResult &tile : res) {
100+
#pragma omp parallel for
101+
for (u64 i = 0; i < res.size(); i++) {
102+
LBTileResult &tile = res[i];
98103
tile.new_owner
99104
= (target_datacnt == 0)
100105
? 0
101106
: sycl::clamp(
102107
i32(tile.accumulated_load_value / target_datacnt), 0, wsize - 1);
103108
}
104109

105-
if (shamcomm::world_rank() == 0) {
110+
if (shamcomm::world_rank() == 0
111+
&& shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
106112
for (LBTileResult t : res) {
107113
shamlog_debug_ln(
108114
"HilbertLoadBalance",
@@ -141,9 +147,10 @@ namespace shamrock::scheduler::details {
141147
using LBTile = TileWithLoad<Torder, Tweight>;
142148
using LBTileResult = details::LoadBalancedTile<Torder, Tweight>;
143149

144-
std::vector<LBTileResult> res;
150+
std::vector<LBTileResult> res(lb_vector.size());
151+
#pragma omp parallel for
145152
for (u64 i = 0; i < lb_vector.size(); i++) {
146-
res.push_back(LBTileResult{lb_vector[i], i});
153+
res[i] = LBTileResult{lb_vector[i], i};
147154
}
148155

149156
// apply the ordering
@@ -160,15 +167,18 @@ namespace shamrock::scheduler::details {
160167

161168
double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;
162169

163-
for (LBTileResult &tile : res) {
170+
#pragma omp parallel for
171+
for (u64 i = 0; i < res.size(); i++) {
172+
LBTileResult &tile = res[i];
164173
tile.new_owner
165174
= (target_datacnt == 0)
166175
? 0
167176
: sycl::clamp(
168177
i32(tile.accumulated_load_value / target_datacnt), 0, wsize - 1);
169178
}
170179

171-
if (shamcomm::world_rank() == 0) {
180+
if (shamcomm::world_rank() == 0
181+
&& shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
172182
for (LBTileResult t : res) {
173183
shamlog_debug_ln(
174184
"HilbertLoadBalance",

src/shamrock/src/scheduler/HilbertLoadBalance.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,16 @@
2525

2626
inline void apply_node_patch_packing(
2727
std::vector<shamrock::patch::Patch> &global_patch_list, std::vector<i32> &new_owner_table) {
28-
using namespace shamrock::patch;
29-
sycl::buffer<i32> new_owner(new_owner_table.data(), new_owner_table.size());
30-
sycl::buffer<Patch> patch_buf(global_patch_list.data(), global_patch_list.size());
3128

32-
sycl::range<1> range{global_patch_list.size()};
33-
34-
// pack nodes
35-
shamsys::instance::get_alt_queue()
36-
.submit([&](sycl::handler &cgh) {
37-
auto ptch = patch_buf.get_access<sycl::access::mode::read>(cgh);
38-
// auto pdt = dt_buf.get_access<sycl::access::mode::read>(cgh);
39-
auto chosen_node = new_owner.get_access<sycl::access::mode::write>(cgh);
40-
41-
cgh.parallel_for(range, [=](sycl::item<1> item) {
42-
u64 i = (u64) item.get_id(0);
43-
44-
if (ptch[i].pack_node_index != u64_max) {
45-
chosen_node[i] = chosen_node[ptch[i].pack_node_index];
46-
}
47-
});
48-
})
49-
.wait();
29+
// Note that there seems to be a data race here
30+
// However this should never happen, as a packing index will only point toward a patch without
31+
// packing. As such the data we are accessing should never be modified during this loop.
32+
#pragma omp parallel for
33+
for (size_t i = 0; i < global_patch_list.size(); i++) {
34+
if (global_patch_list[i].pack_node_index != u64_max) {
35+
new_owner_table[i] = new_owner_table[global_patch_list[i].pack_node_index];
36+
}
37+
}
5038
}
5139

5240
namespace shamrock::scheduler {
@@ -102,17 +90,17 @@ namespace shamrock::scheduler {
10290

10391
// TODO add bool for optional print verbosity
10492
// std::cout << i << " : " << old_owner << " -> " << new_owner << std::endl;
93+
if (new_owner != old_owner) {
10594

106-
using ChangeOp = LoadBalancingChangeList::ChangeOp;
95+
using ChangeOp = LoadBalancingChangeList::ChangeOp;
10796

108-
ChangeOp op;
109-
op.patch_idx = i;
110-
op.patch_id = global_patch_list[i].id_patch;
111-
op.rank_owner_new = new_owner;
112-
op.rank_owner_old = old_owner;
113-
op.tag_comm = tags_it_node[old_owner];
97+
ChangeOp op;
98+
op.patch_idx = i;
99+
op.patch_id = global_patch_list[i].id_patch;
100+
op.rank_owner_new = new_owner;
101+
op.rank_owner_old = old_owner;
102+
op.tag_comm = tags_it_node[old_owner];
114103

115-
if (new_owner != old_owner) {
116104
change_list.change_ops.push_back(op);
117105
tags_it_node[old_owner]++;
118106
}
@@ -126,23 +114,31 @@ namespace shamrock::scheduler {
126114
f64 avg = 0;
127115
f64 var = 0;
128116

129-
for (i32 nid = 0; nid < shamcomm::world_size(); nid++) {
117+
i32 world_size = shamcomm::world_size();
118+
119+
#pragma omp parallel for reduction(min : min) reduction(max : max) reduction(+ : avg)
120+
for (i32 nid = 0; nid < world_size; nid++) {
130121
f64 val = load_per_node[nid];
131122
min = sycl::fmin(min, val);
132123
max = sycl::fmax(max, val);
133124
avg += val;
125+
}
134126

135-
if (shamcomm::world_rank() == 0) {
127+
if (shamcomm::world_rank() == 0
128+
&& shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
129+
for (i32 nid = 0; nid < world_size; nid++) {
136130
shamlog_debug_ln(
137131
"HilbertLoadBalance", "node :", nid, "load :", load_per_node[nid]);
138132
}
139133
}
140-
avg /= shamcomm::world_size();
141-
for (i32 nid = 0; nid < shamcomm::world_size(); nid++) {
134+
avg /= world_size;
135+
136+
#pragma omp parallel for reduction(+ : var)
137+
for (i32 nid = 0; nid < world_size; nid++) {
142138
f64 val = load_per_node[nid];
143139
var += (val - avg) * (val - avg);
144140
}
145-
var /= shamcomm::world_size();
141+
var /= world_size;
146142

147143
if (shamcomm::world_rank() == 0) {
148144
std::string str = "Loadbalance stats : \n";

0 commit comments

Comments
 (0)