Skip to content

Commit bdd4742

Browse files
authored
Merge pull request #1673 from CEED/jeremy/use-work-vecs
GPU Operators use work vectors
2 parents 01ecf5b + 96093a6 commit bdd4742

File tree

15 files changed

+1014
-812
lines changed

15 files changed

+1014
-812
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,11 +839,11 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op) {
839839
CeedElemRestriction rstr_i;
840840

841841
if (is_ordered[i]) continue;
842-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
843842
field_rstr_in_buffer[i] = i;
844843
is_ordered[i] = true;
845844
input_field_order[curr_index] = i;
846845
curr_index++;
846+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
847847
if (vec_i == CEED_VECTOR_NONE) continue; // CEED_EVAL_WEIGHT
848848
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
849849
for (CeedInt j = i + 1; j < num_input_fields; j++) {

backends/cuda-ref/ceed-cuda-ref-basis.c

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn
4040

4141
// Clear v for transpose operation
4242
if (is_transpose && !apply_add) {
43+
CeedInt num_comp, q_comp, num_nodes, num_qpts;
4344
CeedSize length;
4445

45-
CeedCallBackend(CeedVectorGetLength(v, &length));
46+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
47+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
48+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
49+
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
50+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
4651
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
4752
}
4853
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
@@ -206,9 +211,14 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons
206211

207212
// Clear v for transpose operation
208213
if (is_transpose && !apply_add) {
214+
CeedInt num_comp, q_comp, num_nodes;
209215
CeedSize length;
210216

211-
CeedCallBackend(CeedVectorGetLength(v, &length));
217+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
218+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
219+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
220+
length =
221+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
212222
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
213223
}
214224

@@ -283,9 +293,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con
283293

284294
// Clear v for transpose operation
285295
if (is_transpose && !apply_add) {
296+
CeedInt num_comp, q_comp;
286297
CeedSize length;
287298

288-
CeedCallBackend(CeedVectorGetLength(v, &length));
299+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
300+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
301+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
289302
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
290303
}
291304

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 469 additions & 383 deletions
Large diffs are not rendered by default.

backends/cuda-ref/ceed-cuda-ref.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,14 @@ typedef struct {
131131
} CeedOperatorAssemble_Cuda;
132132

133133
typedef struct {
134-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs;
134+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
135135
uint64_t *input_states; // State tracking for passive inputs
136-
CeedVector *e_vecs; // E-vectors, inputs followed by outputs
137-
CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator
138-
CeedVector *q_vecs_out; // Output Q-vectors needed to apply operator
136+
CeedVector *e_vecs_in, *e_vecs_out;
137+
CeedVector *q_vecs_in, *q_vecs_out;
139138
CeedInt num_inputs, num_outputs;
140139
CeedInt num_active_in, num_active_out;
140+
CeedInt *input_field_order, *output_field_order;
141+
CeedSize max_active_e_vec_len;
141142
CeedInt max_num_points;
142143
CeedInt *num_points;
143144
CeedVector *qf_active_in, point_coords_elem;

backends/cuda-shared/ceed-cuda-shared-basis.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,14 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad
312312

313313
// Clear v for transpose operation
314314
if (is_transpose && !apply_add) {
315+
CeedInt num_comp, q_comp, num_nodes;
315316
CeedSize length;
316317

317-
CeedCallBackend(CeedVectorGetLength(v, &length));
318+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
319+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
320+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
321+
length =
322+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
318323
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
319324
}
320325

backends/hip-gen/ceed-hip-gen-operator-build.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,11 +848,11 @@ extern "C" int CeedOperatorBuildKernel_Hip_gen(CeedOperator op) {
848848
CeedElemRestriction rstr_i;
849849

850850
if (is_ordered[i]) continue;
851-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
852851
field_rstr_in_buffer[i] = i;
853852
is_ordered[i] = true;
854853
input_field_order[curr_index] = i;
855854
curr_index++;
855+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
856856
if (vec_i == CEED_VECTOR_NONE) continue; // CEED_EVAL_WEIGHT
857857
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
858858
for (CeedInt j = i + 1; j < num_input_fields; j++) {

backends/hip-ref/ceed-hip-ref-basis.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@ static int CeedBasisApplyCore_Hip(CeedBasis basis, bool apply_add, const CeedInt
3939

4040
// Clear v for transpose operation
4141
if (is_transpose && !apply_add) {
42+
CeedInt num_comp, q_comp, num_nodes, num_qpts;
4243
CeedSize length;
4344

44-
CeedCallBackend(CeedVectorGetLength(v, &length));
45+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
46+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
47+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
48+
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
49+
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
4550
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
4651
}
4752
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
@@ -204,9 +209,14 @@ static int CeedBasisApplyAtPointsCore_Hip(CeedBasis basis, bool apply_add, const
204209

205210
// Clear v for transpose operation
206211
if (is_transpose && !apply_add) {
212+
CeedInt num_comp, q_comp, num_nodes;
207213
CeedSize length;
208214

209-
CeedCallBackend(CeedVectorGetLength(v, &length));
215+
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
216+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
217+
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
218+
length =
219+
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
210220
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
211221
}
212222

0 commit comments

Comments
 (0)