@@ -89,10 +89,13 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) {
8989 CeedCallSycl (ceed, sycl::free (impl->diag ->d_interp_out , sycl_data->sycl_context ));
9090 CeedCallSycl (ceed, sycl::free (impl->diag ->d_grad_in , sycl_data->sycl_context ));
9191 CeedCallSycl (ceed, sycl::free (impl->diag ->d_grad_out , sycl_data->sycl_context ));
92- CeedCallBackend (CeedElemRestrictionDestroy (&impl->diag ->point_block_diag_rstr ));
9392
9493 CeedCallBackend (CeedVectorDestroy (&impl->diag ->elem_diag ));
9594 CeedCallBackend (CeedVectorDestroy (&impl->diag ->point_block_elem_diag ));
95+ CeedCallBackend (CeedElemRestrictionDestroy (&impl->diag ->diag_rstr ));
96+ CeedCallBackend (CeedElemRestrictionDestroy (&impl->diag ->point_block_diag_rstr ));
97+ CeedCallBackend (CeedBasisDestroy (&impl->diag ->basis_in ));
98+ CeedCallBackend (CeedBasisDestroy (&impl->diag ->basis_out ));
9699 }
97100 CeedCallBackend (CeedFree (&impl->diag ));
98101
@@ -115,7 +118,7 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
115118 Ceed ceed;
116119 CeedSize q_size;
117120 bool is_strided, skip_restriction;
118- CeedInt dim, size;
121+ CeedInt size;
119122 CeedOperatorField *op_fields;
120123 CeedQFunctionField *qf_fields;
121124
@@ -133,7 +136,6 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
133136 CeedEvalMode eval_mode;
134137 CeedVector vec;
135138 CeedElemRestriction elem_rstr;
136- CeedBasis basis;
137139
138140 CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_fields[i], &eval_mode));
139141
@@ -183,20 +185,21 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
183185 CeedCallBackend (CeedVectorCreate (ceed, q_size, &q_vecs[i]));
184186 break ;
185187 case CEED_EVAL_GRAD:
186- CeedCallBackend (CeedOperatorFieldGetBasis (op_fields[i], &basis));
187188 CeedCallBackend (CeedQFunctionFieldGetSize (qf_fields[i], &size));
188- CeedCallBackend (CeedBasisGetDimension (basis, &dim));
189- CeedCallBackend (CeedBasisDestroy (&basis));
190189 q_size = (CeedSize)num_elem * Q * size;
191190 CeedCallBackend (CeedVectorCreate (ceed, q_size, &q_vecs[i]));
192191 break ;
193- case CEED_EVAL_WEIGHT: // Only on input fields
194- CeedCallBackend (CeedOperatorFieldGetBasis (op_fields[i], &basis));
192+ case CEED_EVAL_WEIGHT: {
193+ CeedBasis basis;
194+
195+ // Note: only on input fields
195196 q_size = (CeedSize)num_elem * Q;
196197 CeedCallBackend (CeedVectorCreate (ceed, q_size, &q_vecs[i]));
198+ CeedCallBackend (CeedOperatorFieldGetBasis (op_fields[i], &basis));
197199 CeedCallBackend (CeedBasisApply (basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
198200 CeedCallBackend (CeedBasisDestroy (&basis));
199201 break ;
202+ }
200203 case CEED_EVAL_DIV:
201204 break ; // TODO: Not implemented
202205 case CEED_EVAL_CURL:
@@ -463,8 +466,8 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
463466 // Restrict
464467 CeedCallBackend (CeedOperatorFieldGetVector (op_output_fields[i], &vec));
465468 is_active = vec == CEED_VECTOR_ACTIVE;
466- CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
467469 if (is_active) vec = out_vec;
470+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields[i], &elem_rstr));
468471 CeedCallBackend (CeedElemRestrictionApply (elem_rstr, CEED_TRANSPOSE, impl->e_vecs [i + impl->num_e_in ], vec, request));
469472 if (!is_active) CeedCallBackend (CeedVectorDestroy (&vec));
470473 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
@@ -637,6 +640,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
637640 Ceed_Sycl *sycl_data;
638641 CeedInt num_input_fields, num_output_fields, num_eval_mode_in = 0 , num_comp = 0 , dim = 1 , num_eval_mode_out = 0 ;
639642 CeedEvalMode *eval_mode_in = NULL , *eval_mode_out = NULL ;
643+ CeedElemRestriction rstr_in = NULL , rstr_out = NULL ;
640644 CeedBasis basis_in = NULL , basis_out = NULL ;
641645 CeedQFunctionField *qf_fields;
642646 CeedQFunction qf;
@@ -655,14 +659,19 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
655659
656660 CeedCallBackend (CeedOperatorFieldGetVector (op_fields[i], &vec));
657661 if (vec == CEED_VECTOR_ACTIVE) {
658- CeedEvalMode eval_mode;
659- CeedBasis basis;
662+ CeedEvalMode eval_mode;
663+ CeedElemRestriction elem_rstr;
664+ CeedBasis basis;
660665
666+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_fields[i], &elem_rstr));
667+ if (!rstr_in) CeedCallBackend (CeedElemRestrictionReferenceCopy (elem_rstr, &rstr_in));
668+ CeedCheck (rstr_in == elem_rstr, ceed, CEED_ERROR_BACKEND, " Backend does not implement multi-field non-composite operator diagonal assembly" );
669+ CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
661670 CeedCallBackend (CeedOperatorFieldGetBasis (op_fields[i], &basis));
662- CeedCheck (!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND,
663- " Backend does not implement operator diagonal assembly with multiple active bases" );
664671 if (!basis_in) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_in));
672+ CeedCheck (basis_in == basis, ceed, CEED_ERROR_BACKEND, " Backend does not implement operator diagonal assembly with multiple active bases" );
665673 CeedCallBackend (CeedBasisDestroy (&basis));
674+ CeedCallBackend (CeedBasisGetDimension (basis_in, &dim));
666675 CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_fields[i], &eval_mode));
667676 switch (eval_mode) {
668677 case CEED_EVAL_NONE:
@@ -684,6 +693,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
684693 }
685694 CeedCallBackend (CeedVectorDestroy (&vec));
686695 }
696+ CeedCallBackend (CeedElemRestrictionDestroy (&rstr_in));
687697
688698 // Determine active output basis
689699 CeedCallBackend (CeedOperatorGetFields (op, NULL , NULL , NULL , &op_fields));
@@ -693,26 +703,30 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
693703
694704 CeedCallBackend (CeedOperatorFieldGetVector (op_fields[i], &vec));
695705 if (vec == CEED_VECTOR_ACTIVE) {
696- CeedEvalMode eval_mode;
697- CeedBasis basis;
706+ CeedEvalMode eval_mode;
707+ CeedElemRestriction elem_rstr;
708+ CeedBasis basis;
698709
710+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_fields[i], &elem_rstr));
711+ if (!rstr_out) CeedCallBackend (CeedElemRestrictionReferenceCopy (elem_rstr, &rstr_out));
712+ CeedCheck (rstr_out == elem_rstr, ceed, CEED_ERROR_BACKEND, " Backend does not implement multi-field non-composite operator diagonal assembly" );
713+ CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
699714 CeedCallBackend (CeedOperatorFieldGetBasis (op_fields[i], &basis));
700- CeedCheck (!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
701- " Backend does not implement operator diagonal assembly with multiple active bases" );
702715 if (!basis_out) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_out));
716+ CeedCheck (basis_out == basis, ceed, CEED_ERROR_BACKEND, " Backend does not implement operator diagonal assembly with multiple active bases" );
703717 CeedCallBackend (CeedBasisDestroy (&basis));
704718 CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_fields[i], &eval_mode));
705719 switch (eval_mode) {
706720 case CEED_EVAL_NONE:
707721 case CEED_EVAL_INTERP:
708- CeedCallBackend (CeedRealloc (num_eval_mode_in + 1 , &eval_mode_in ));
709- eval_mode_in[num_eval_mode_in ] = eval_mode;
710- num_eval_mode_in += 1 ;
722+ CeedCallBackend (CeedRealloc (num_eval_mode_out + 1 , &eval_mode_out ));
723+ eval_mode_out[num_eval_mode_out ] = eval_mode;
724+ num_eval_mode_out += 1 ;
711725 break ;
712726 case CEED_EVAL_GRAD:
713- CeedCallBackend (CeedRealloc (num_eval_mode_in + dim, &eval_mode_in ));
714- for (CeedInt d = 0 ; d < dim; d++) eval_mode_in[num_eval_mode_in + d] = eval_mode;
715- num_eval_mode_in += dim;
727+ CeedCallBackend (CeedRealloc (num_eval_mode_out + dim, &eval_mode_out ));
728+ for (CeedInt d = 0 ; d < dim; d++) eval_mode_out[num_eval_mode_out + d] = eval_mode;
729+ num_eval_mode_out += dim;
716730 break ;
717731 case CEED_EVAL_WEIGHT:
718732 case CEED_EVAL_DIV:
@@ -729,8 +743,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
729743 CeedCallBackend (CeedCalloc (1 , &impl->diag ));
730744 CeedOperatorDiag_Sycl *diag = impl->diag ;
731745
732- diag->basis_in = basis_in ;
733- diag->basis_out = basis_out ;
746+ CeedCallBackend ( CeedBasisReferenceCopy (basis_in, & diag->basis_in )) ;
747+ CeedCallBackend ( CeedBasisReferenceCopy (basis_out, & diag->basis_out )) ;
734748 diag->h_eval_mode_in = eval_mode_in;
735749 diag->h_eval_mode_out = eval_mode_out;
736750 diag->num_eval_mode_in = num_eval_mode_in;
@@ -740,6 +754,7 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
740754 CeedInt num_nodes, num_qpts;
741755 CeedCallBackend (CeedBasisGetNumNodes (basis_in, &num_nodes));
742756 CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_in, &num_qpts));
757+ CeedCallBackend (CeedBasisGetNumComponents (basis_in, &num_comp));
743758 diag->num_nodes = num_nodes;
744759 diag->num_qpts = num_qpts;
745760 diag->num_comp = num_comp;
@@ -801,13 +816,12 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
801816 copy_events.push_back (eval_mode_out_copy);
802817
803818 // Restriction
804- {
805- CeedElemRestriction rstr_out;
819+ CeedCallBackend ( CeedElemRestrictionReferenceCopy (rstr_out, &diag-> diag_rstr ));
820+ CeedCallBackend ( CeedElemRestrictionDestroy (& rstr_out)) ;
806821
807- CeedCallBackend (CeedOperatorGetActiveElemRestrictions (op, NULL , &rstr_out));
808- diag->diag_rstr = rstr_out;
809- CeedCallBackend (CeedElemRestrictionDestroy (&rstr_out));
810- }
822+ // Cleanup
823+ CeedCallBackend (CeedBasisDestroy (&basis_in));
824+ CeedCallBackend (CeedBasisDestroy (&basis_out));
811825
812826 // Wait for all copies to complete and handle exceptions
813827 CeedCallSycl (ceed, sycl::event::wait_and_throw (copy_events));
@@ -1020,16 +1034,16 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
10201034 CeedElemRestriction elem_rstr;
10211035 CeedBasis basis;
10221036
1023- CeedCallBackend (CeedOperatorFieldGetBasis (input_fields[i], &basis));
1024- CeedCheck (!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, " Backend does not implement operator assembly with multiple active bases" );
1025- if (!basis_in) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_in));
1026- CeedCallBackend (CeedBasisGetDimension (basis, &dim));
1027- CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis, &num_qpts));
1028- CeedCallBackend (CeedBasisDestroy (&basis));
10291037 CeedCallBackend (CeedOperatorFieldGetElemRestriction (input_fields[i], &elem_rstr));
10301038 if (!rstr_in) CeedCallBackend (CeedElemRestrictionReferenceCopy (elem_rstr, &rstr_in));
10311039 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
10321040 CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in, &elem_size));
1041+ CeedCallBackend (CeedOperatorFieldGetBasis (input_fields[i], &basis));
1042+ if (!basis_in) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_in));
1043+ CeedCheck (basis_in == basis, ceed, CEED_ERROR_BACKEND, " Backend does not implement operator assembly with multiple active bases" );
1044+ CeedCallBackend (CeedBasisDestroy (&basis));
1045+ CeedCallBackend (CeedBasisGetDimension (basis_in, &dim));
1046+ CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_in, &num_qpts));
10331047 CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_fields[i], &eval_mode));
10341048 if (eval_mode != CEED_EVAL_NONE) {
10351049 CeedCallBackend (CeedRealloc (num_B_in_mats_to_load + 1 , &eval_mode_in));
@@ -1058,14 +1072,14 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
10581072 CeedElemRestriction elem_rstr;
10591073 CeedBasis basis;
10601074
1061- CeedCallBackend (CeedOperatorFieldGetBasis (output_fields[i], &basis));
1062- CeedCheck (!basis_out || basis_out == basis, ceed, CEED_ERROR_BACKEND,
1063- " Backend does not implement operator assembly with multiple active bases" );
1064- if (!basis_out) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_out));
1065- CeedCallBackend (CeedBasisDestroy (&basis));
10661075 CeedCallBackend (CeedOperatorFieldGetElemRestriction (output_fields[i], &elem_rstr));
10671076 if (!rstr_out) CeedCallBackend (CeedElemRestrictionReferenceCopy (elem_rstr, &rstr_out));
1077+ CeedCheck (rstr_out == rstr_in, ceed, CEED_ERROR_BACKEND, " Backend does not implement multi-field non-composite operator assembly" );
10681078 CeedCallBackend (CeedElemRestrictionDestroy (&elem_rstr));
1079+ CeedCallBackend (CeedOperatorFieldGetBasis (output_fields[i], &basis));
1080+ if (!basis_out) CeedCallBackend (CeedBasisReferenceCopy (basis, &basis_out));
1081+ CeedCheck (basis_out == basis, ceed, CEED_ERROR_BACKEND, " Backend does not implement operator assembly with multiple active bases" );
1082+ CeedCallBackend (CeedBasisDestroy (&basis));
10691083 CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_fields[i], &eval_mode));
10701084 if (eval_mode != CEED_EVAL_NONE) {
10711085 CeedCallBackend (CeedRealloc (num_B_out_mats_to_load + 1 , &eval_mode_out));
0 commit comments