Skip to content

Commit 6da9377

Browse files
committed
Fix scratch level 0 bounds check
1 parent d43f89d commit 6da9377

1 file changed

Lines changed: 61 additions & 51 deletions

File tree

src/hydro/hydro.cpp

Lines changed: 61 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -1200,64 +1200,64 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12001200
const int Nb = u0_cons_pack.GetDim(5);
12011201
const int Nk = ku + 1 - kl;
12021202
const int NbNk = Nb * Nk;
1203-
auto x2flux = KOKKOS_LAMBDA(parthenon::team_mbr_t member) {
1204-
const int b = member.league_rank() / Nk;
1205-
const int k = member.league_rank() % Nk + kl;
1206-
const auto &prim = u0_prim_pack(b);
1207-
parthenon::ScratchPad2D<Real> wl(member.team_scratch(scratch_level),
1208-
num_scratch_vars, nx1);
1209-
parthenon::ScratchPad2D<Real> wr(member.team_scratch(scratch_level),
1210-
num_scratch_vars, nx1);
1211-
parthenon::ScratchPad2D<Real> wlb(member.team_scratch(scratch_level),
1212-
num_scratch_vars, nx1);
1213-
parthenon::ScratchPad2D<Real> flxl(member.team_scratch(scratch_level),
1203+
for (scratch_level = 0; scratch_level < 2; scratch_level++) {
1204+
auto x2flux = KOKKOS_LAMBDA(parthenon::team_mbr_t member) {
1205+
const int b = member.league_rank() / Nk;
1206+
const int k = member.league_rank() % Nk + kl;
1207+
const auto &prim = u0_prim_pack(b);
1208+
parthenon::ScratchPad2D<Real> wl(member.team_scratch(scratch_level),
12141209
num_scratch_vars, nx1);
1215-
parthenon::ScratchPad2D<Real> flxr(member.team_scratch(scratch_level),
1210+
parthenon::ScratchPad2D<Real> wr(member.team_scratch(scratch_level),
12161211
num_scratch_vars, nx1);
1217-
for (int j = jb.s - 1; j <= jb.e + 1; ++j) {
1218-
// reconstruct L/R states at j
1219-
Reconstruct<recon, X2DIR>(member, k, j, il, iu, prim, wlb, wr);
1220-
// Sync all threads in the team so that scratch memory is consistent
1221-
member.team_barrier();
1222-
1223-
if (j > jb.s - 1) {
1224-
riemann.Solve(member, il, iu, IV2, wl, wr, flxr, eos, c_h);
1212+
parthenon::ScratchPad2D<Real> wlb(member.team_scratch(scratch_level),
1213+
num_scratch_vars, nx1);
1214+
parthenon::ScratchPad2D<Real> flxl(member.team_scratch(scratch_level),
1215+
num_scratch_vars, nx1);
1216+
parthenon::ScratchPad2D<Real> flxr(member.team_scratch(scratch_level),
1217+
num_scratch_vars, nx1);
1218+
for (int j = jb.s - 1; j <= jb.e + 1; ++j) {
1219+
// reconstruct L/R states at j
1220+
Reconstruct<recon, X2DIR>(member, k, j, il, iu, prim, wlb, wr);
1221+
// Sync all threads in the team so that scratch memory is consistent
12251222
member.team_barrier();
1226-
if (j > jb.s) {
1227-
const auto &coords = u0_cons_pack.GetCoords(b);
1228-
// Now directly update
1229-
const int Nv = u0_cons_pack.GetDim(4);
1230-
const int Ni = iu - il + 1;
1231-
const int NvNi = Nv * Ni;
1232-
auto tvr = Kokkos::TeamVectorRange(member, u0_cons_pack.GetDim(4), NvNi);
1233-
Kokkos::parallel_for(tvr, [&](const int idx) {
1234-
const int v = idx / Ni;
1235-
const int i = idx % Ni + il;
1236-
const auto du = -(coords.FaceArea<X2DIR>(k, j, i) * flxr(v, i) -
1237-
coords.FaceArea<X2DIR>(k, j - 1, i) * flxl(v, i)) /
1238-
coords.CellVolume(k, j - 1, i);
1239-
1240-
// WARNING: this is specific to the VL2 integrator
1241-
u0_cons_pack(b, v, k, j - 1, i) +=
1242-
// gam0 * u0_cons_pack(b, v, k, j, i) +
1243-
// gam1 * u1_cons_pack(b, v, k, j, i) +
1244-
beta_dt * du;
1245-
});
1223+
1224+
if (j > jb.s - 1) {
1225+
riemann.Solve(member, il, iu, IV2, wl, wr, flxr, eos, c_h);
12461226
member.team_barrier();
1227+
if (j > jb.s) {
1228+
const auto &coords = u0_cons_pack.GetCoords(b);
1229+
// Now directly update
1230+
const int Nv = u0_cons_pack.GetDim(4);
1231+
const int Ni = iu - il + 1;
1232+
const int NvNi = Nv * Ni;
1233+
auto tvr = Kokkos::TeamVectorRange(member, u0_cons_pack.GetDim(4), NvNi);
1234+
Kokkos::parallel_for(tvr, [&](const int idx) {
1235+
const int v = idx / Ni;
1236+
const int i = idx % Ni + il;
1237+
const auto du = -(coords.FaceArea<X2DIR>(k, j, i) * flxr(v, i) -
1238+
coords.FaceArea<X2DIR>(k, j - 1, i) * flxl(v, i)) /
1239+
coords.CellVolume(k, j - 1, i);
1240+
1241+
// WARNING: this is specific to the VL2 integrator
1242+
u0_cons_pack(b, v, k, j - 1, i) +=
1243+
// gam0 * u0_cons_pack(b, v, k, j, i) +
1244+
// gam1 * u1_cons_pack(b, v, k, j, i) +
1245+
beta_dt * du;
1246+
});
1247+
member.team_barrier();
1248+
}
12471249
}
1248-
}
12491250

1250-
// swap the arrays for the next step
1251-
auto *tmp = wl.data();
1252-
wl.assign_data(wlb.data());
1253-
wlb.assign_data(tmp);
1254-
tmp = flxr.data();
1255-
flxr.assign_data(flxl.data());
1256-
flxl.assign_data(tmp);
1257-
}
1258-
};
1251+
// swap the arrays for the next step
1252+
auto *tmp = wl.data();
1253+
wl.assign_data(wlb.data());
1254+
wlb.assign_data(tmp);
1255+
tmp = flxr.data();
1256+
flxr.assign_data(flxl.data());
1257+
flxl.assign_data(tmp);
1258+
}
1259+
};
12591260

1260-
for (scratch_level = 0; scratch_level < 2; scratch_level++) {
12611261
std::array<int, 7> team_sizes = {32, 64, 128, 160, 192, 224, 256};
12621262
for (int t = 0; t < 8; t++) {
12631263
int team_size;
@@ -1276,6 +1276,14 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12761276
if (team_size > team_size_max) {
12771277
continue;
12781278
}
1279+
const auto scratch_size_max =
1280+
parthenon::team_policy(DevExecSpace(), NbNk, 1)
1281+
.set_scratch_size(scratch_level, Kokkos::PerTeam(scratch_size_in_bytes))
1282+
.scratch_size_max(scratch_level);
1283+
if (scratch_size_in_bytes > scratch_size_max && scratch_level == 0) {
1284+
continue;
1285+
}
1286+
12791287
parthenon::team_policy policy(DevExecSpace(), NbNk, team_size);
12801288

12811289
Kokkos::parallel_for("x2 flux" + suffix + " TVR scratch " +
@@ -1289,6 +1297,8 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12891297
}
12901298
//--------------------------------------------------------------------------------------
12911299
// k-direction
1300+
scratch_level =
1301+
pkg->Param<int>("scratch_level"); // 0 is actual scratch (tiny); 1 is HBM
12921302
if (pmb->pmy_mesh->ndim >= 3) {
12931303
// set the loop limits
12941304
il = ib.s - 1, iu = ib.e + 1, jl = jb.s - 1, ju = jb.e + 1;

0 commit comments

Comments
 (0)