@@ -28,6 +28,8 @@ static int CeedOperatorDestroy_Cuda_gen(CeedOperator op) {
2828 CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
2929 CeedCallBackend (CeedOperatorGetData (op , & impl ));
3030 if (impl -> module ) CeedCallCuda (ceed , cuModuleUnload (impl -> module ));
31+ if (impl -> module_assemble_full ) CeedCallCuda (ceed , cuModuleUnload (impl -> module_assemble_full ));
32+ if (impl -> module_assemble_diagonal ) CeedCallCuda (ceed , cuModuleUnload (impl -> module_assemble_diagonal ));
3133 if (impl -> points .num_per_elem ) CeedCallCuda (ceed , cudaFree ((void * * )impl -> points .num_per_elem ));
3234 CeedCallBackend (CeedFree (& impl ));
3335 CeedCallBackend (CeedDestroy (& ceed ));
@@ -333,11 +335,173 @@ static int CeedOperatorApplyAddComposite_Cuda_gen(CeedOperator op, CeedVector in
333335 return CEED_ERROR_SUCCESS ;
334336}
335337
338+ //------------------------------------------------------------------------------
339+ // AtPoints diagonal assembly
340+ //------------------------------------------------------------------------------
341+ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen (CeedOperator op , CeedVector assembled , CeedRequest * request ) {
342+ Ceed ceed ;
343+ CeedOperator_Cuda_gen * data ;
344+
345+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
346+ CeedCallBackend (CeedOperatorGetData (op , & data ));
347+
348+ // Build the assembly kernel
349+ if (!data -> assemble_diagonal && !data -> use_assembly_fallback ) {
350+ bool is_build_good = false;
351+ CeedInt num_active_bases_in , num_active_bases_out ;
352+ CeedOperatorAssemblyData assembly_data ;
353+
354+ CeedCallBackend (CeedOperatorGetOperatorAssemblyData (op , & assembly_data ));
355+ CeedCallBackend (
356+ CeedOperatorAssemblyDataGetEvalModes (assembly_data , & num_active_bases_in , NULL , NULL , NULL , & num_active_bases_out , NULL , NULL , NULL , NULL ));
357+ if (num_active_bases_in == num_active_bases_out ) {
358+ CeedCallBackend (CeedOperatorBuildKernel_Cuda_gen (op , & is_build_good ));
359+ if (is_build_good ) CeedCallBackend (CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Cuda_gen (op , & is_build_good ));
360+ }
361+ if (!is_build_good ) data -> use_assembly_fallback = true;
362+ }
363+
364+ // Try assembly
365+ if (!data -> use_assembly_fallback ) {
366+ bool is_run_good = true;
367+ Ceed_Cuda * cuda_data ;
368+ CeedInt num_elem , num_input_fields , num_output_fields ;
369+ CeedEvalMode eval_mode ;
370+ CeedScalar * assembled_array ;
371+ CeedQFunctionField * qf_input_fields , * qf_output_fields ;
372+ CeedQFunction_Cuda_gen * qf_data ;
373+ CeedQFunction qf ;
374+ CeedOperatorField * op_input_fields , * op_output_fields ;
375+
376+ CeedCallBackend (CeedGetData (ceed , & cuda_data ));
377+ CeedCallBackend (CeedOperatorGetQFunction (op , & qf ));
378+ CeedCallBackend (CeedQFunctionGetData (qf , & qf_data ));
379+ CeedCallBackend (CeedOperatorGetNumElements (op , & num_elem ));
380+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & op_input_fields , & num_output_fields , & op_output_fields ));
381+ CeedCallBackend (CeedQFunctionGetFields (qf , NULL , & qf_input_fields , NULL , & qf_output_fields ));
382+
383+ // Input vectors
384+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
385+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
386+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
387+ data -> fields .inputs [i ] = NULL ;
388+ } else {
389+ bool is_active ;
390+ CeedVector vec ;
391+
392+ // Get input vector
393+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
394+ is_active = vec == CEED_VECTOR_ACTIVE ;
395+ if (is_active ) data -> fields .inputs [i ] = NULL ;
396+ else CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> fields .inputs [i ]));
397+ CeedCallBackend (CeedVectorDestroy (& vec ));
398+ }
399+ }
400+
401+ // Point coordinates
402+ {
403+ CeedVector vec ;
404+
405+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
406+ CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> points .coords ));
407+ CeedCallBackend (CeedVectorDestroy (& vec ));
408+
409+ // Points per elem
410+ if (num_elem != data -> points .num_elem ) {
411+ CeedInt * points_per_elem ;
412+ const CeedInt num_bytes = num_elem * sizeof (CeedInt );
413+ CeedElemRestriction rstr_points = NULL ;
414+
415+ data -> points .num_elem = num_elem ;
416+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , & rstr_points , NULL ));
417+ CeedCallBackend (CeedCalloc (num_elem , & points_per_elem ));
418+ for (CeedInt e = 0 ; e < num_elem ; e ++ ) {
419+ CeedInt num_points_elem ;
420+
421+ CeedCallBackend (CeedElemRestrictionGetNumPointsInElement (rstr_points , e , & num_points_elem ));
422+ points_per_elem [e ] = num_points_elem ;
423+ }
424+ if (data -> points .num_per_elem ) CeedCallCuda (ceed , cudaFree ((void * * )data -> points .num_per_elem ));
425+ CeedCallCuda (ceed , cudaMalloc ((void * * )& data -> points .num_per_elem , num_bytes ));
426+ CeedCallCuda (ceed , cudaMemcpy ((void * )data -> points .num_per_elem , points_per_elem , num_bytes , cudaMemcpyHostToDevice ));
427+ CeedCallBackend (CeedElemRestrictionDestroy (& rstr_points ));
428+ CeedCallBackend (CeedFree (& points_per_elem ));
429+ }
430+ }
431+
432+ // Get context data
433+ CeedCallBackend (CeedQFunctionGetInnerContextData (qf , CEED_MEM_DEVICE , & qf_data -> d_c ));
434+
435+ // Assembly array
436+ CeedCallBackend (CeedVectorGetArray (assembled , CEED_MEM_DEVICE , & assembled_array ));
437+
438+ // Assemble diagonal
439+ void * opargs [] = {(void * )& num_elem , & qf_data -> d_c , & data -> indices , & data -> fields , & data -> B , & data -> G , & data -> W , & data -> points , & assembled_array };
440+ int max_threads_per_block , min_grid_size , grid ;
441+
442+ CeedCallCuda (ceed , cuOccupancyMaxPotentialBlockSize (& min_grid_size , & max_threads_per_block , data -> op , dynamicSMemSize , 0 , 0x10000 ));
443+ int block [3 ] = {data -> thread_1d , (data -> dim == 1 ? 1 : data -> thread_1d ), -1 };
444+
445+ CeedCallBackend (BlockGridCalculate (num_elem , min_grid_size / cuda_data -> device_prop .multiProcessorCount , 1 ,
446+ cuda_data -> device_prop .maxThreadsDim [2 ], cuda_data -> device_prop .warpSize , block , & grid ));
447+ CeedInt shared_mem = block [0 ] * block [1 ] * block [2 ] * sizeof (CeedScalar );
448+
449+ CeedCallBackend (
450+ CeedTryRunKernelDimShared_Cuda (ceed , data -> assemble_diagonal , NULL , grid , block [0 ], block [1 ], block [2 ], shared_mem , & is_run_good , opargs ));
451+
452+ // Restore input arrays
453+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
454+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
455+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
456+ } else {
457+ bool is_active ;
458+ CeedVector vec ;
459+
460+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
461+ is_active = vec == CEED_VECTOR_ACTIVE ;
462+ if (!is_active ) CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> fields .inputs [i ]));
463+ CeedCallBackend (CeedVectorDestroy (& vec ));
464+ }
465+ }
466+
467+ // Restore point coordinates
468+ {
469+ CeedVector vec ;
470+
471+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
472+ CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> points .coords ));
473+ CeedCallBackend (CeedVectorDestroy (& vec ));
474+ }
475+
476+ // Restore context data
477+ CeedCallBackend (CeedQFunctionRestoreInnerContextData (qf , & qf_data -> d_c ));
478+
479+ // Restore assembly array
480+ CeedCallBackend (CeedVectorRestoreArray (assembled , & assembled_array ));
481+
482+ // Cleanup
483+ CeedCallBackend (CeedQFunctionDestroy (& qf ));
484+ if (!is_run_good ) data -> use_assembly_fallback = true;
485+ }
486+ CeedCallBackend (CeedDestroy (& ceed ));
487+
488+ // Fallback, if needed
489+ if (data -> use_assembly_fallback ) {
490+ CeedOperator op_fallback ;
491+
492+ CeedDebug256 (CeedOperatorReturnCeed (op ), CEED_DEBUG_COLOR_SUCCESS , "Falling back to /gpu/cuda/ref CeedOperator" );
493+ CeedCallBackend (CeedOperatorGetFallback (op , & op_fallback ));
494+ CeedCallBackend (CeedOperatorLinearAssembleAddDiagonal (op_fallback , assembled , request ));
495+ return CEED_ERROR_SUCCESS ;
496+ }
497+ return CEED_ERROR_SUCCESS ;
498+ }
499+
336500//------------------------------------------------------------------------------
337501// Create operator
338502//------------------------------------------------------------------------------
339503int CeedOperatorCreate_Cuda_gen (CeedOperator op ) {
340- bool is_composite ;
504+ bool is_composite , is_at_points ;
341505 Ceed ceed ;
342506 CeedOperator_Cuda_gen * impl ;
343507
@@ -350,6 +514,11 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
350514 } else {
351515 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda_gen ));
352516 }
517+ CeedCall (CeedOperatorIsAtPoints (op , & is_at_points ));
518+ if (is_at_points ) {
519+ CeedCallBackend (
520+ CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddDiagonal" , CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen ));
521+ }
353522 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda_gen ));
354523 CeedCallBackend (CeedDestroy (& ceed ));
355524 return CEED_ERROR_SUCCESS ;
0 commit comments