Skip to content

Commit 1945a8d

Browse files
committed
gpu - use cached work vectors across operators
1 parent c08fdf4 commit 1945a8d

File tree

8 files changed

+454
-342
lines changed

8 files changed

+454
-342
lines changed

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 222 additions & 160 deletions
Large diffs are not rendered by default.

backends/cuda-ref/ceed-cuda-ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ typedef struct {
131131
} CeedOperatorAssemble_Cuda;
132132

133133
typedef struct {
134-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs;
134+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
135135
uint64_t *input_states; // State tracking for passive inputs
136136
CeedVector *e_vecs_in, *e_vecs_out;
137137
CeedVector *q_vecs_in, *q_vecs_out;

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 225 additions & 165 deletions
Large diffs are not rendered by default.

backends/hip-ref/ceed-hip-ref.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ typedef struct {
135135
} CeedOperatorAssemble_Hip;
136136

137137
typedef struct {
138-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs;
138+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
139139
uint64_t *input_states; // State tracking for passive inputs
140140
CeedVector *e_vecs_in, *e_vecs_out;
141141
CeedVector *q_vecs_in, *q_vecs_out;

interface/ceed-basis.c

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,6 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co
331331
if (x_ref != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(x_ref, &x_length));
332332
if (u != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(u, &u_length));
333333

334-
// Check compatibility of topological and geometrical dimensions
335-
CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0) || (t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0) ||
336-
(eval_mode == CEED_EVAL_WEIGHT),
337-
ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions and number of points");
338-
339334
// Check compatibility coordinates vector
340335
for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i];
341336
CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
@@ -1819,11 +1814,6 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp
18191814
CeedCall(CeedVectorGetLength(v, &v_length));
18201815
if (u) CeedCall(CeedVectorGetLength(u, &u_length));
18211816

1822-
// Check compatibility of topological and geometrical dimensions
1823-
CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0 && u_length % num_qpts == 0) ||
1824-
(t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0 && v_length % num_qpts == 0),
1825-
ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions");
1826-
18271817
// Check vector lengths to prevent out of bounds issues
18281818
bool has_good_dims = true;
18291819
switch (eval_mode) {

interface/ceed-vector.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
862862
CeedCall(CeedVectorGetLength(w, &length_w));
863863
CeedCall(CeedVectorGetLength(x, &length_x));
864864
CeedCall(CeedVectorGetLength(y, &length_y));
865-
CeedCheck(length_w == length_x && length_w == length_y, ceed, CEED_ERROR_UNSUPPORTED,
865+
CeedCheck(length_x >= length_w && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
866866
"Cannot multiply vectors of different lengths."
867867
" x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
868868
length_x, length_y);

tests/junit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def check_required_failure(self, test: str, spec: TestSpec, resource: str, stder
172172
elif test_id in ['t215']:
173173
fail_str = 'Cannot destroy CeedElemRestriction, a process has read access to the offset data'
174174
elif test_id in ['t303']:
175-
fail_str = 'Length of input/output vectors incompatible with basis dimensions'
175+
fail_str = 'Input/output vectors too short for basis and evaluation mode'
176176
elif test_id in ['t408']:
177177
fail_str = 'CeedQFunctionContextGetData(): Cannot grant CeedQFunctionContext data access, a process has read access'
178178
elif test_id in ['t409'] and contains_any(resource, ['memcheck']):

tests/t303-basis.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/// @file
2-
/// Test checking BasisApply input/output vectors compatibility with basis dimensions
3-
/// \test Test checking BasisApply input/output vectors compatibility with basis dimensions
2+
/// Test checking BasisApply input/output vectors compatibility with basis
3+
/// \test Test checking BasisApply input/output vectors compatibility with basis
44

55
//TESTARGS(only="cpu") {ceed_resource}
66
#include <ceed.h>
@@ -15,7 +15,7 @@ int main(int argc, char **argv) {
1515
CeedInit(argv[1], &ceed);
1616

1717
CeedVectorCreate(ceed, len, &u);
18-
CeedVectorCreate(ceed, len + 1, &v);
18+
CeedVectorCreate(ceed, len - 1, &v);
1919

2020
CeedBasisCreateTensorH1Lagrange(ceed, dim, num_comp, p, q, CEED_GAUSS, &basis);
2121

0 commit comments

Comments
 (0)