@@ -293,17 +293,53 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,
293293 return CEED_ERROR_SUCCESS ;
294294}
295295
296+ static int CeedOperatorApplyAddComposite_Cuda_gen (CeedOperator op , CeedVector input_vec , CeedVector output_vec , CeedRequest * request ) {
297+ bool is_run_good [CEED_COMPOSITE_MAX ] = {false};
298+ CeedInt num_suboperators ;
299+ const CeedScalar * input_arr = NULL ;
300+ CeedScalar * output_arr = NULL ;
301+ CeedOperator * sub_operators ;
302+
303+ CeedCall (CeedCompositeOperatorGetNumSub (op , & num_suboperators ));
304+ CeedCall (CeedCompositeOperatorGetSubList (op , & sub_operators ));
305+ if (input_vec != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorGetArrayRead (input_vec , CEED_MEM_DEVICE , & input_arr ));
306+ if (output_vec != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorGetArray (output_vec , CEED_MEM_DEVICE , & output_arr ));
307+ for (CeedInt i = 0 ; i < num_suboperators ; i ++ ) {
308+ CeedCallBackend (CeedOperatorApplyAddCore_Cuda_gen (sub_operators [i ], NULL , input_arr , output_arr , & is_run_good [i ], request ));
309+ }
310+ if (input_vec != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorRestoreArrayRead (input_vec , & input_arr ));
311+ if (output_vec != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorRestoreArray (output_vec , & output_arr ));
312+
313+ // Fallback on unsuccessful run
314+ for (CeedInt i = 0 ; i < num_suboperators ; i ++ ) {
315+ if (!is_run_good [i ]) {
316+ CeedOperator op_fallback ;
317+
318+ CeedDebug256 (CeedOperatorReturnCeed (op ), CEED_DEBUG_COLOR_SUCCESS , "Falling back to /gpu/cuda/ref CeedOperator" );
319+ CeedCallBackend (CeedOperatorGetFallback (sub_operators [i ], & op_fallback ));
320+ CeedCallBackend (CeedOperatorApplyAdd (op_fallback , input_vec , output_vec , request ));
321+ }
322+ }
323+ return CEED_ERROR_SUCCESS ;
324+ }
325+
296326//------------------------------------------------------------------------------
297327// Create operator
298328//------------------------------------------------------------------------------
299329int CeedOperatorCreate_Cuda_gen (CeedOperator op ) {
330+ bool is_composite ;
300331 Ceed ceed ;
301332 CeedOperator_Cuda_gen * impl ;
302333
303334 CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
304335 CeedCallBackend (CeedCalloc (1 , & impl ));
305336 CeedCallBackend (CeedOperatorSetData (op , impl ));
306- CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda_gen ));
337+ CeedCall (CeedOperatorIsComposite (op , & is_composite ));
338+ if (is_composite ) {
339+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAddComposite" , CeedOperatorApplyAddComposite_Cuda_gen ));
340+ } else {
341+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda_gen ));
342+ }
307343 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda_gen ));
308344 CeedCallBackend (CeedDestroy (& ceed ));
309345 return CEED_ERROR_SUCCESS ;
0 commit comments