Skip to content

Commit 700701c

Browse files
committed
sycl - fix regresions
1 parent e3553ef commit 700701c

File tree

2 files changed

+60
-46
lines changed

2 files changed

+60
-46
lines changed

backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
274274
// Get elem_size, eval_mode, num_comp
275275
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
276276
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
277-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
278277
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
279278
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
279+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
280280

281281
// Set field constants
282282
if (eval_mode != CEED_EVAL_WEIGHT) {
@@ -334,9 +334,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
334334
// Get elem_size, eval_mode, num_comp
335335
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
336336
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
337-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
338337
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
339338
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
339+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
340340

341341
// Set field constants
342342
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
@@ -401,8 +401,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
401401
// Get elem_size, eval_mode, num_comp
402402
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
403403
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
404-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
405404
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
405+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
406406

407407
// Restriction
408408
if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
@@ -677,8 +677,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
677677
// Get elem_size, eval_mode, num_comp
678678
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
679679
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
680-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
681680
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
681+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
682682
// Basis action
683683
code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
684684
switch (eval_mode) {

backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,13 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) {
8989
CeedCallSycl(ceed, sycl::free(impl->diag->d_interp_out, sycl_data->sycl_context));
9090
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_in, sycl_data->sycl_context));
9191
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_out, sycl_data->sycl_context));
92-
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
9392

9493
CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
9594
CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
95+
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
96+
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
97+
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_in));
98+
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_out));
9699
}
97100
CeedCallBackend(CeedFree(&impl->diag));
98101

@@ -115,7 +118,7 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
115118
Ceed ceed;
116119
CeedSize q_size;
117120
bool is_strided, skip_restriction;
118-
CeedInt dim, size;
121+
CeedInt size;
119122
CeedOperatorField *op_fields;
120123
CeedQFunctionField *qf_fields;
121124

@@ -133,7 +136,6 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
133136
CeedEvalMode eval_mode;
134137
CeedVector vec;
135138
CeedElemRestriction elem_rstr;
136-
CeedBasis basis;
137139

138140
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
139141

@@ -183,20 +185,21 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
183185
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
184186
break;
185187
case CEED_EVAL_GRAD:
186-
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
187188
CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
188-
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
189-
CeedCallBackend(CeedBasisDestroy(&basis));
190189
q_size = (CeedSize)num_elem * Q * size;
191190
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
192191
break;
193-
case CEED_EVAL_WEIGHT: // Only on input fields
194-
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
192+
case CEED_EVAL_WEIGHT: {
193+
CeedBasis basis;
194+
195+
// Note: only on input fields
195196
q_size = (CeedSize)num_elem * Q;
196197
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
198+
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
197199
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
198200
CeedCallBackend(CeedBasisDestroy(&basis));
199201
break;
202+
}
200203
case CEED_EVAL_DIV:
201204
break; // TODO: Not implemented
202205
case CEED_EVAL_CURL:
@@ -463,8 +466,8 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
463466
// Restrict
464467
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
465468
is_active = vec == CEED_VECTOR_ACTIVE;
466-
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
467469
if (is_active) vec = out_vec;
470+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
468471
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_e_in], vec, request));
469472
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
470473
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
@@ -637,6 +640,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
637640
Ceed_Sycl *sycl_data;
638641
CeedInt num_input_fields, num_output_fields, num_eval_mode_in = 0, num_comp = 0, dim = 1, num_eval_mode_out = 0;
639642
CeedEvalMode *eval_mode_in = NULL, *eval_mode_out = NULL;
643+
CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
640644
CeedBasis basis_in = NULL, basis_out = NULL;
641645
CeedQFunctionField *qf_fields;
642646
CeedQFunction qf;
@@ -655,14 +659,19 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
655659

656660
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
657661
if (vec == CEED_VECTOR_ACTIVE) {
658-
CeedEvalMode eval_mode;
659-
CeedBasis basis;
662+
CeedEvalMode eval_mode;
663+
CeedElemRestriction elem_rstr;
664+
CeedBasis basis;
660665

666+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
667+
if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in));
668+
CeedCheck(rstr_in == elem_rstr, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly");
669+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
661670
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
662-
CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND,
663-
"Backend does not implement operator diagonal assembly with multiple active bases");
664671
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
672+
CeedCheck(basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator diagonal assembly with multiple active bases");
665673
CeedCallBackend(CeedBasisDestroy(&basis));
674+
CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
666675
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
667676
switch (eval_mode) {
668677
case CEED_EVAL_NONE:
@@ -684,6 +693,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
684693
}
685694
CeedCallBackend(CeedVectorDestroy(&vec));
686695
}
696+
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_in));
687697

688698
// Determine active output basis
689699
CeedCallBackend(CeedOperatorGetFields(op, NULL, NULL, NULL, &op_fields));
@@ -693,26 +703,30 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
693703

694704
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec));
695705
if (vec == CEED_VECTOR_ACTIVE) {
696-
CeedEvalMode eval_mode;
697-
CeedBasis basis;
706+
CeedEvalMode eval_mode;
707+
CeedElemRestriction elem_rstr;
708+
CeedBasis basis;
698709

710+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));
711+
if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out));
712+
CeedCheck(rstr_out == elem_rstr, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator diagonal assembly");
713+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
699714
CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis));
700-
CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
701-
"Backend does not implement operator diagonal assembly with multiple active bases");
702715
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
716+
CeedCheck(basis_out == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator diagonal assembly with multiple active bases");
703717
CeedCallBackend(CeedBasisDestroy(&basis));
704718
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
705719
switch (eval_mode) {
706720
case CEED_EVAL_NONE:
707721
case CEED_EVAL_INTERP:
708-
CeedCallBackend(CeedRealloc(num_eval_mode_in + 1, &eval_mode_in));
709-
eval_mode_in[num_eval_mode_in] = eval_mode;
710-
num_eval_mode_in += 1;
722+
CeedCallBackend(CeedRealloc(num_eval_mode_out + 1, &eval_mode_out));
723+
eval_mode_out[num_eval_mode_out] = eval_mode;
724+
num_eval_mode_out += 1;
711725
break;
712726
case CEED_EVAL_GRAD:
713-
CeedCallBackend(CeedRealloc(num_eval_mode_in + dim, &eval_mode_in));
714-
for (CeedInt d = 0; d < dim; d++) eval_mode_in[num_eval_mode_in + d] = eval_mode;
715-
num_eval_mode_in += dim;
727+
CeedCallBackend(CeedRealloc(num_eval_mode_out + dim, &eval_mode_out));
728+
for (CeedInt d = 0; d < dim; d++) eval_mode_out[num_eval_mode_out + d] = eval_mode;
729+
num_eval_mode_out += dim;
716730
break;
717731
case CEED_EVAL_WEIGHT:
718732
case CEED_EVAL_DIV:
@@ -729,8 +743,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
729743
CeedCallBackend(CeedCalloc(1, &impl->diag));
730744
CeedOperatorDiag_Sycl *diag = impl->diag;
731745

732-
diag->basis_in = basis_in;
733-
diag->basis_out = basis_out;
746+
CeedCallBackend(CeedBasisReferenceCopy(basis_in, &diag->basis_in));
747+
CeedCallBackend(CeedBasisReferenceCopy(basis_out, &diag->basis_out));
734748
diag->h_eval_mode_in = eval_mode_in;
735749
diag->h_eval_mode_out = eval_mode_out;
736750
diag->num_eval_mode_in = num_eval_mode_in;
@@ -740,6 +754,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
740754
CeedInt num_nodes, num_qpts;
741755
CeedCallBackend(CeedBasisGetNumNodes(basis_in, &num_nodes));
742756
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
757+
CeedCallBackend(CeedBasisGetNumComponents(basis_in, &num_comp));
743758
diag->num_nodes = num_nodes;
744759
diag->num_qpts = num_qpts;
745760
diag->num_comp = num_comp;
@@ -801,13 +816,12 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
801816
copy_events.push_back(eval_mode_out_copy);
802817

803818
// Restriction
804-
{
805-
CeedElemRestriction rstr_out;
819+
CeedCallBackend(CeedElemRestrictionReferenceCopy(rstr_out, &diag->diag_rstr));
820+
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
806821

807-
CeedCallBackend(CeedOperatorGetActiveElemRestrictions(op, NULL, &rstr_out));
808-
diag->diag_rstr = rstr_out;
809-
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
810-
}
822+
// Cleanup
823+
CeedCallBackend(CeedBasisDestroy(&basis_in));
824+
CeedCallBackend(CeedBasisDestroy(&basis_out));
811825

812826
// Wait for all copies to complete and handle exceptions
813827
CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
@@ -1020,16 +1034,16 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
10201034
CeedElemRestriction elem_rstr;
10211035
CeedBasis basis;
10221036

1023-
CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
1024-
CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
1025-
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
1026-
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
1027-
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
1028-
CeedCallBackend(CeedBasisDestroy(&basis));
10291037
CeedCallBackend(CeedOperatorFieldGetElemRestriction(input_fields[i], &elem_rstr));
10301038
if (!rstr_in) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_in));
10311039
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
10321040
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size));
1041+
CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
1042+
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
1043+
CeedCheck(basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
1044+
CeedCallBackend(CeedBasisDestroy(&basis));
1045+
CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
1046+
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
10331047
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
10341048
if (eval_mode != CEED_EVAL_NONE) {
10351049
CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in));
@@ -1058,14 +1072,14 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
10581072
CeedElemRestriction elem_rstr;
10591073
CeedBasis basis;
10601074

1061-
CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis));
1062-
CeedCheck(!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
1063-
"Backend does not implement operator assembly with multiple active bases");
1064-
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
1065-
CeedCallBackend(CeedBasisDestroy(&basis));
10661075
CeedCallBackend(CeedOperatorFieldGetElemRestriction(output_fields[i], &elem_rstr));
10671076
if (!rstr_out) CeedCallBackend(CeedElemRestrictionReferenceCopy(elem_rstr, &rstr_out));
1077+
CeedCheck(rstr_out == rstr_in, ceed, CEED_ERROR_BACKEND, "Backend does not implement multi-field non-composite operator assembly");
10681078
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1079+
CeedCallBackend(CeedOperatorFieldGetBasis(output_fields[i], &basis));
1080+
if (!basis_out) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_out));
1081+
CeedCheck(basis_out == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
1082+
CeedCallBackend(CeedBasisDestroy(&basis));
10691083
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
10701084
if (eval_mode != CEED_EVAL_NONE) {
10711085
CeedCallBackend(CeedRealloc(num_B_out_mats_to_load + 1, &eval_mode_out));

0 commit comments

Comments
 (0)