Skip to content

Commit 354029d

Browse files
committed
gpu - only overwrite portion of basis target used
1 parent 1945a8d commit 354029d

File tree

5 files changed

+55
-18
lines changed

5 files changed

+55
-18
lines changed

backends/cuda-ref/ceed-cuda-ref-basis.c

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn
4040

4141
// Clear v for transpose operation
4242
if (is_transpose && !apply_add) {
43+
CeedInt num_comp, q_comp, num_nodes, num_qpts;
4344
CeedSize length;
4445

45-
CeedCallBackend(CeedVectorGetLength(v, &length));
46+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
47+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
48+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
49+
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
50+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
4651
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
4752
}
4853
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
@@ -206,9 +211,14 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons
206211

207212
// Clear v for transpose operation
208213
if (is_transpose && !apply_add) {
214+
CeedInt num_comp, q_comp, num_nodes;
209215
CeedSize length;
210216

211-
CeedCallBackend(CeedVectorGetLength(v, &length));
217+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
218+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
219+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
220+
length =
221+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
212222
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
213223
}
214224

@@ -283,9 +293,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con
283293

284294
// Clear v for transpose operation
285295
if (is_transpose && !apply_add) {
296+
CeedInt num_comp, q_comp;
286297
CeedSize length;
287298

288-
CeedCallBackend(CeedVectorGetLength(v, &length));
299+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
300+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
301+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
289302
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
290303
}
291304

backends/cuda-shared/ceed-cuda-shared-basis.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,14 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad
312312

313313
// Clear v for transpose operation
314314
if (is_transpose && !apply_add) {
315+
CeedInt num_comp, q_comp, num_nodes;
315316
CeedSize length;
316317

317-
CeedCallBackend(CeedVectorGetLength(v, &length));
318+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
319+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
320+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
321+
length =
322+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
318323
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
319324
}
320325

backends/hip-ref/ceed-hip-ref-basis.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@ static int CeedBasisApplyCore_Hip(CeedBasis basis, bool apply_add, const CeedInt
3939

4040
// Clear v for transpose operation
4141
if (is_transpose && !apply_add) {
42+
CeedInt num_comp, q_comp, num_nodes, num_qpts;
4243
CeedSize length;
4344

44-
CeedCallBackend(CeedVectorGetLength(v, &length));
45+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
46+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
47+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
48+
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
49+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
4550
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
4651
}
4752
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
@@ -204,9 +209,14 @@ static int CeedBasisApplyAtPointsCore_Hip(CeedBasis basis, bool apply_add, const
204209

205210
// Clear v for transpose operation
206211
if (is_transpose && !apply_add) {
212+
CeedInt num_comp, q_comp, num_nodes;
207213
CeedSize length;
208214

209-
CeedCallBackend(CeedVectorGetLength(v, &length));
215+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
216+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
217+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
218+
length =
219+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
210220
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
211221
}
212222

backends/hip-shared/ceed-hip-shared-basis.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,14 @@ static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add
371371

372372
// Clear v for transpose operation
373373
if (is_transpose && !apply_add) {
374+
CeedInt num_comp, q_comp, num_nodes;
374375
CeedSize length;
375376

376-
CeedCallBackend(CeedVectorGetLength(v, &length));
377+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
378+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
379+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
380+
length =
381+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
377382
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
378383
}
379384

interface/ceed-basis.c

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,10 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co
333333

334334
// Check compatibility coordinates vector
335335
for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i];
336-
CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
336+
CeedCheck((x_length >= (CeedSize)total_num_points * (CeedSize)dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
337337
"Length of reference coordinate vector incompatible with basis dimension and number of points."
338338
" Found reference coordinate vector of length %" CeedSize_FMT ", not of length %" CeedSize_FMT ".",
339-
x_length, total_num_points * dim);
339+
x_length, (CeedSize)total_num_points * (CeedSize)dim);
340340

341341
// Check CEED_EVAL_WEIGHT only on CEED_NOTRANSPOSE
342342
CeedCheck(eval_mode != CEED_EVAL_WEIGHT || t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_UNSUPPORTED,
@@ -346,13 +346,16 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co
346346
bool has_good_dims = true;
347347
switch (eval_mode) {
348348
case CEED_EVAL_INTERP:
349-
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp || v_length >= num_elem * num_nodes * num_comp)) ||
350-
(t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp || u_length >= num_elem * num_nodes * num_comp)));
349+
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp ||
350+
v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) ||
351+
(t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp ||
352+
u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)));
351353
break;
352354
case CEED_EVAL_GRAD:
353-
has_good_dims =
354-
((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp * dim || v_length >= num_elem * num_nodes * num_comp)) ||
355-
(t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp * dim || u_length >= num_elem * num_nodes * num_comp)));
355+
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim ||
356+
v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) ||
357+
(t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim ||
358+
u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)));
356359
break;
357360
case CEED_EVAL_WEIGHT:
358361
has_good_dims = t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points);
@@ -1822,12 +1825,13 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp
18221825
case CEED_EVAL_GRAD:
18231826
case CEED_EVAL_DIV:
18241827
case CEED_EVAL_CURL:
1825-
has_good_dims =
1826-
((t_mode == CEED_TRANSPOSE && u_length >= num_elem * num_comp * num_qpts * q_comp && v_length >= num_elem * num_comp * num_nodes) ||
1827-
(t_mode == CEED_NOTRANSPOSE && v_length >= num_elem * num_qpts * num_comp * q_comp && u_length >= num_elem * num_comp * num_nodes));
1828+
has_good_dims = ((t_mode == CEED_TRANSPOSE && u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_qpts * (CeedSize)q_comp &&
1829+
v_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes) ||
1830+
(t_mode == CEED_NOTRANSPOSE && v_length >= (CeedSize)num_elem * (CeedSize)num_qpts * (CeedSize)num_comp * (CeedSize)q_comp &&
1831+
u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes));
18281832
break;
18291833
case CEED_EVAL_WEIGHT:
1830-
has_good_dims = v_length >= num_elem * num_qpts;
1834+
has_good_dims = v_length >= (CeedSize)num_elem * (CeedSize)num_qpts;
18311835
break;
18321836
}
18331837
CeedCheck(has_good_dims, ceed, CEED_ERROR_DIMENSION, "Input/output vectors too short for basis and evaluation mode");

0 commit comments

Comments
 (0)