@@ -27,6 +27,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
2727
2828 // Apply data
2929 CeedCallBackend (CeedFree (& impl -> skip_rstr_in ));
30+ CeedCallBackend (CeedFree (& impl -> skip_rstr_out ));
31+ CeedCallBackend (CeedFree (& impl -> apply_add_basis_out ));
3032 for (CeedInt i = 0 ; i < impl -> num_inputs + impl -> num_outputs ; i ++ ) {
3133 CeedCallBackend (CeedVectorDestroy (& impl -> e_vecs [i ]));
3234 }
@@ -96,8 +98,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
9698//------------------------------------------------------------------------------
9799// Setup infields or outfields
98100//------------------------------------------------------------------------------
99- static int CeedOperatorSetupFields_Hip (CeedQFunction qf , CeedOperator op , bool is_input , bool is_at_points , bool * skip_rstr , CeedVector * e_vecs ,
100- CeedVector * q_vecs , CeedInt start_e , CeedInt num_fields , CeedInt Q , CeedInt num_elem ) {
101+ static int CeedOperatorSetupFields_Hip (CeedQFunction qf , CeedOperator op , bool is_input , bool is_at_points , bool * skip_rstr , bool * apply_add_basis ,
102+ CeedVector * e_vecs , CeedVector * q_vecs , CeedInt start_e , CeedInt num_fields , CeedInt Q , CeedInt num_elem ) {
101103 Ceed ceed ;
102104 CeedQFunctionField * qf_fields ;
103105 CeedOperatorField * op_fields ;
@@ -183,7 +185,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
183185 break ;
184186 }
185187 }
186- // Drop duplicate input restrictions
188+ // Drop duplicate restrictions
187189 if (is_input ) {
188190 for (CeedInt i = 0 ; i < num_fields ; i ++ ) {
189191 CeedVector vec_i ;
@@ -198,11 +200,31 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
198200 CeedCallBackend (CeedOperatorFieldGetVector (op_fields [j ], & vec_j ));
199201 CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_fields [j ], & rstr_j ));
200202 if (vec_i == vec_j && rstr_i == rstr_j ) {
201- CeedCallBackend (CeedVectorReferenceCopy (e_vecs [i ], & e_vecs [j ]));
203+ CeedCallBackend (CeedVectorReferenceCopy (e_vecs [i + start_e ], & e_vecs [j + start_e ]));
202204 skip_rstr [j ] = true;
203205 }
204206 }
205207 }
208+ } else {
209+ for (CeedInt i = num_fields - 1 ; i >= 0 ; i -- ) {
210+ CeedVector vec_i ;
211+ CeedElemRestriction rstr_i ;
212+
213+ CeedCallBackend (CeedOperatorFieldGetVector (op_fields [i ], & vec_i ));
214+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_fields [i ], & rstr_i ));
215+ for (CeedInt j = i - 1 ; j >= 0 ; j -- ) {
216+ CeedVector vec_j ;
217+ CeedElemRestriction rstr_j ;
218+
219+ CeedCallBackend (CeedOperatorFieldGetVector (op_fields [j ], & vec_j ));
220+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_fields [j ], & rstr_j ));
221+ if (vec_i == vec_j && rstr_i == rstr_j ) {
222+ CeedCallBackend (CeedVectorReferenceCopy (e_vecs [i + start_e ], & e_vecs [j + start_e ]));
223+ skip_rstr [j ] = true;
224+ apply_add_basis [i ] = true;
225+ }
226+ }
227+ }
206228 }
207229 return CEED_ERROR_SUCCESS ;
208230}
@@ -233,6 +255,8 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
233255 // Allocate
234256 CeedCallBackend (CeedCalloc (num_input_fields + num_output_fields , & impl -> e_vecs ));
235257 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> skip_rstr_in ));
258+ CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> skip_rstr_out ));
259+ CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> apply_add_basis_out ));
236260 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> input_states ));
237261 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> q_vecs_in ));
238262 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> q_vecs_out ));
@@ -242,10 +266,10 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
242266 // Set up infield and outfield e_vecs and q_vecs
243267 // Infields
244268 CeedCallBackend (
245- CeedOperatorSetupFields_Hip (qf , op , true, false, impl -> skip_rstr_in , impl -> e_vecs , impl -> q_vecs_in , 0 , num_input_fields , Q , num_elem ));
269+ CeedOperatorSetupFields_Hip (qf , op , true, false, impl -> skip_rstr_in , NULL , impl -> e_vecs , impl -> q_vecs_in , 0 , num_input_fields , Q , num_elem ));
246270 // Outfields
247- CeedCallBackend (
248- CeedOperatorSetupFields_Hip ( qf , op , false, false, NULL , impl -> e_vecs , impl -> q_vecs_out , num_input_fields , num_output_fields , Q , num_elem ));
271+ CeedCallBackend (CeedOperatorSetupFields_Hip ( qf , op , false, false, impl -> skip_rstr_out , impl -> apply_add_basis_out , impl -> e_vecs , impl -> q_vecs_out ,
272+ num_input_fields , num_output_fields , Q , num_elem ));
249273
250274 CeedCallBackend (CeedOperatorSetSetupDone (op ));
251275 return CEED_ERROR_SUCCESS ;
@@ -430,7 +454,11 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
430454 case CEED_EVAL_DIV :
431455 case CEED_EVAL_CURL :
432456 CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields [i ], & basis ));
433- CeedCallBackend (CeedBasisApply (basis , num_elem , CEED_TRANSPOSE , eval_mode , impl -> q_vecs_out [i ], impl -> e_vecs [i + impl -> num_inputs ]));
457+ if (impl -> apply_add_basis_out [i ]) {
458+ CeedCallBackend (CeedBasisApplyAdd (basis , num_elem , CEED_TRANSPOSE , eval_mode , impl -> q_vecs_out [i ], impl -> e_vecs [i + impl -> num_inputs ]));
459+ } else {
460+ CeedCallBackend (CeedBasisApply (basis , num_elem , CEED_TRANSPOSE , eval_mode , impl -> q_vecs_out [i ], impl -> e_vecs [i + impl -> num_inputs ]));
461+ }
434462 break ;
435463 // LCOV_EXCL_START
436464 case CEED_EVAL_WEIGHT : {
@@ -451,6 +479,7 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
451479 if (eval_mode == CEED_EVAL_NONE ) {
452480 CeedCallBackend (CeedVectorRestoreArray (impl -> e_vecs [i + impl -> num_inputs ], & e_data [i + num_input_fields ]));
453481 }
482+ if (impl -> skip_rstr_out [i ]) continue ;
454483 // Get output vector
455484 CeedCallBackend (CeedOperatorFieldGetVector (op_output_fields [i ], & vec ));
456485 // Restrict
@@ -498,6 +527,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
498527 // Allocate
499528 CeedCallBackend (CeedCalloc (num_input_fields + num_output_fields , & impl -> e_vecs ));
500529 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> skip_rstr_in ));
530+ CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> skip_rstr_out ));
531+ CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> apply_add_basis_out ));
501532 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> input_states ));
502533 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> q_vecs_in ));
503534 CeedCallBackend (CeedCalloc (CEED_FIELD_MAX , & impl -> q_vecs_out ));
@@ -506,11 +537,11 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
506537
507538 // Set up infield and outfield e_vecs and q_vecs
508539 // Infields
509- CeedCallBackend (CeedOperatorSetupFields_Hip (qf , op , true, true, impl -> skip_rstr_in , impl -> e_vecs , impl -> q_vecs_in , 0 , num_input_fields ,
540+ CeedCallBackend (CeedOperatorSetupFields_Hip (qf , op , true, true, impl -> skip_rstr_in , NULL , impl -> e_vecs , impl -> q_vecs_in , 0 , num_input_fields ,
510541 max_num_points , num_elem ));
511542 // Outfields
512- CeedCallBackend (CeedOperatorSetupFields_Hip (qf , op , false, true, NULL , impl -> e_vecs , impl -> q_vecs_out , num_input_fields , num_output_fields ,
513- max_num_points , num_elem ));
543+ CeedCallBackend (CeedOperatorSetupFields_Hip (qf , op , false, true, impl -> skip_rstr_out , impl -> apply_add_basis_out , impl -> e_vecs , impl -> q_vecs_out ,
544+ num_input_fields , num_output_fields , max_num_points , num_elem ));
514545
515546 CeedCallBackend (CeedOperatorSetSetupDone (op ));
516547 return CEED_ERROR_SUCCESS ;
@@ -634,8 +665,13 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
634665 case CEED_EVAL_DIV :
635666 case CEED_EVAL_CURL :
636667 CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields [i ], & basis ));
637- CeedCallBackend (CeedBasisApplyAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem , impl -> q_vecs_out [i ],
638- impl -> e_vecs [i + impl -> num_inputs ]));
668+ if (impl -> apply_add_basis_out [i ]) {
669+ CeedCallBackend (CeedBasisApplyAddAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem ,
670+ impl -> q_vecs_out [i ], impl -> e_vecs [i + impl -> num_inputs ]));
671+ } else {
672+ CeedCallBackend (CeedBasisApplyAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem , impl -> q_vecs_out [i ],
673+ impl -> e_vecs [i + impl -> num_inputs ]));
674+ }
639675 break ;
640676 // LCOV_EXCL_START
641677 case CEED_EVAL_WEIGHT : {
@@ -656,6 +692,7 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
656692 if (eval_mode == CEED_EVAL_NONE ) {
657693 CeedCallBackend (CeedVectorRestoreArray (impl -> e_vecs [i + impl -> num_inputs ], & e_data [i + num_input_fields ]));
658694 }
695+ if (impl -> skip_rstr_out [i ]) continue ;
659696 // Get output vector
660697 CeedCallBackend (CeedOperatorFieldGetVector (op_output_fields [i ], & vec ));
661698 // Restrict
0 commit comments