@@ -1200,64 +1200,64 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12001200 const int Nb = u0_cons_pack.GetDim (5 );
12011201 const int Nk = ku + 1 - kl;
12021202 const int NbNk = Nb * Nk;
1203- auto x2flux = KOKKOS_LAMBDA (parthenon::team_mbr_t member) {
1204- const int b = member.league_rank () / Nk;
1205- const int k = member.league_rank () % Nk + kl;
1206- const auto &prim = u0_prim_pack (b);
1207- parthenon::ScratchPad2D<Real> wl (member.team_scratch (scratch_level),
1208- num_scratch_vars, nx1);
1209- parthenon::ScratchPad2D<Real> wr (member.team_scratch (scratch_level),
1210- num_scratch_vars, nx1);
1211- parthenon::ScratchPad2D<Real> wlb (member.team_scratch (scratch_level),
1212- num_scratch_vars, nx1);
1213- parthenon::ScratchPad2D<Real> flxl (member.team_scratch (scratch_level),
1203+ for (scratch_level = 0 ; scratch_level < 2 ; scratch_level++) {
1204+ auto x2flux = KOKKOS_LAMBDA (parthenon::team_mbr_t member) {
1205+ const int b = member.league_rank () / Nk;
1206+ const int k = member.league_rank () % Nk + kl;
1207+ const auto &prim = u0_prim_pack (b);
1208+ parthenon::ScratchPad2D<Real> wl (member.team_scratch (scratch_level),
12141209 num_scratch_vars, nx1);
1215- parthenon::ScratchPad2D<Real> flxr (member.team_scratch (scratch_level),
1210+ parthenon::ScratchPad2D<Real> wr (member.team_scratch (scratch_level),
12161211 num_scratch_vars, nx1);
1217- for (int j = jb.s - 1 ; j <= jb.e + 1 ; ++j) {
1218- // reconstruct L/R states at j
1219- Reconstruct<recon, X2DIR >(member, k, j, il, iu, prim, wlb, wr);
1220- // Sync all threads in the team so that scratch memory is consistent
1221- member.team_barrier ();
1222-
1223- if (j > jb.s - 1 ) {
1224- riemann.Solve (member, il, iu, IV2 , wl, wr, flxr, eos, c_h);
1212+ parthenon::ScratchPad2D<Real> wlb (member.team_scratch (scratch_level),
1213+ num_scratch_vars, nx1);
1214+ parthenon::ScratchPad2D<Real> flxl (member.team_scratch (scratch_level),
1215+ num_scratch_vars, nx1);
1216+ parthenon::ScratchPad2D<Real> flxr (member.team_scratch (scratch_level),
1217+ num_scratch_vars, nx1);
1218+ for (int j = jb.s - 1 ; j <= jb.e + 1 ; ++j) {
1219+ // reconstruct L/R states at j
1220+ Reconstruct<recon, X2DIR >(member, k, j, il, iu, prim, wlb, wr);
1221+ // Sync all threads in the team so that scratch memory is consistent
12251222 member.team_barrier ();
1226- if (j > jb.s ) {
1227- const auto &coords = u0_cons_pack.GetCoords (b);
1228- // Now directly update
1229- const int Nv = u0_cons_pack.GetDim (4 );
1230- const int Ni = iu - il + 1 ;
1231- const int NvNi = Nv * Ni;
1232- auto tvr = Kokkos::TeamVectorRange (member, u0_cons_pack.GetDim (4 ), NvNi);
1233- Kokkos::parallel_for (tvr, [&](const int idx) {
1234- const int v = idx / Ni;
1235- const int i = idx % Ni + il;
1236- const auto du = -(coords.FaceArea <X2DIR >(k, j, i) * flxr (v, i) -
1237- coords.FaceArea <X2DIR >(k, j - 1 , i) * flxl (v, i)) /
1238- coords.CellVolume (k, j - 1 , i);
1239-
1240- // WARNING: this is specific to the VL2 integrator
1241- u0_cons_pack (b, v, k, j - 1 , i) +=
1242- // gam0 * u0_cons_pack(b, v, k, j, i) +
1243- // gam1 * u1_cons_pack(b, v, k, j, i) +
1244- beta_dt * du;
1245- });
1223+
1224+ if (j > jb.s - 1 ) {
1225+ riemann.Solve (member, il, iu, IV2 , wl, wr, flxr, eos, c_h);
12461226 member.team_barrier ();
1227+ if (j > jb.s ) {
1228+ const auto &coords = u0_cons_pack.GetCoords (b);
1229+ // Now directly update
1230+ const int Nv = u0_cons_pack.GetDim (4 );
1231+ const int Ni = iu - il + 1 ;
1232+ const int NvNi = Nv * Ni;
1233+ auto tvr = Kokkos::TeamVectorRange (member, u0_cons_pack.GetDim (4 ), NvNi);
1234+ Kokkos::parallel_for (tvr, [&](const int idx) {
1235+ const int v = idx / Ni;
1236+ const int i = idx % Ni + il;
1237+ const auto du = -(coords.FaceArea <X2DIR >(k, j, i) * flxr (v, i) -
1238+ coords.FaceArea <X2DIR >(k, j - 1 , i) * flxl (v, i)) /
1239+ coords.CellVolume (k, j - 1 , i);
1240+
1241+ // WARNING: this is specific to the VL2 integrator
1242+ u0_cons_pack (b, v, k, j - 1 , i) +=
1243+ // gam0 * u0_cons_pack(b, v, k, j, i) +
1244+ // gam1 * u1_cons_pack(b, v, k, j, i) +
1245+ beta_dt * du;
1246+ });
1247+ member.team_barrier ();
1248+ }
12471249 }
1248- }
12491250
1250- // swap the arrays for the next step
1251- auto *tmp = wl.data ();
1252- wl.assign_data (wlb.data ());
1253- wlb.assign_data (tmp);
1254- tmp = flxr.data ();
1255- flxr.assign_data (flxl.data ());
1256- flxl.assign_data (tmp);
1257- }
1258- };
1251+ // swap the arrays for the next step
1252+ auto *tmp = wl.data ();
1253+ wl.assign_data (wlb.data ());
1254+ wlb.assign_data (tmp);
1255+ tmp = flxr.data ();
1256+ flxr.assign_data (flxl.data ());
1257+ flxl.assign_data (tmp);
1258+ }
1259+ };
12591260
1260- for (scratch_level = 0 ; scratch_level < 2 ; scratch_level++) {
12611261 std::array<int , 7 > team_sizes = {32 , 64 , 128 , 160 , 192 , 224 , 256 };
12621262 for (int t = 0 ; t < 8 ; t++) {
12631263 int team_size;
@@ -1276,6 +1276,14 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12761276 if (team_size > team_size_max) {
12771277 continue ;
12781278 }
1279+ const auto scratch_size_max =
1280+ parthenon::team_policy (DevExecSpace (), NbNk, 1 )
1281+ .set_scratch_size (scratch_level, Kokkos::PerTeam (scratch_size_in_bytes))
1282+ .scratch_size_max (scratch_level);
1283+ if (scratch_size_in_bytes > scratch_size_max && scratch_level == 0 ) {
1284+ continue ;
1285+ }
1286+
12791287 parthenon::team_policy policy (DevExecSpace (), NbNk, team_size);
12801288
12811289 Kokkos::parallel_for (" x2 flux" + suffix + " TVR scratch " +
@@ -1289,6 +1297,8 @@ TaskStatus CalculateFluxes(MeshData<Real> *u0_data, MeshData<Real> *u1_data,
12891297 }
12901298 // --------------------------------------------------------------------------------------
12911299 // k-direction
1300+ scratch_level =
1301+ pkg->Param <int >(" scratch_level" ); // 0 is actual scratch (tiny); 1 is HBM
12921302 if (pmb->pmy_mesh ->ndim >= 3 ) {
12931303 // set the loop limits
12941304 il = ib.s - 1 , iu = ib.e + 1 , jl = jb.s - 1 , ju = jb.e + 1 ;
0 commit comments