2121int CeedBasisApply_Cuda (CeedBasis basis , const CeedInt num_elem , CeedTransposeMode t_mode , CeedEvalMode eval_mode , CeedVector u , CeedVector v ) {
2222 Ceed ceed ;
2323 CeedInt Q_1d , dim ;
24- const CeedInt transpose = t_mode == CEED_TRANSPOSE ;
24+ const CeedInt is_transpose = t_mode == CEED_TRANSPOSE ;
2525 const int max_block_size = 32 ;
2626 const CeedScalar * d_u ;
2727 CeedScalar * d_v ;
@@ -30,13 +30,13 @@ int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMo
3030 CeedCallBackend (CeedBasisGetCeed (basis , & ceed ));
3131 CeedCallBackend (CeedBasisGetData (basis , & data ));
3232
33- // Read vectors
33+ // Get read/write access to u, v
3434 if (u != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorGetArrayRead (u , CEED_MEM_DEVICE , & d_u ));
3535 else CeedCheck (eval_mode == CEED_EVAL_WEIGHT , ceed , CEED_ERROR_BACKEND , "An input vector is required for this CeedEvalMode" );
3636 CeedCallBackend (CeedVectorGetArrayWrite (v , CEED_MEM_DEVICE , & d_v ));
3737
3838 // Clear v for transpose operation
39- if (transpose ) {
39+ if (is_transpose ) {
4040 CeedSize length ;
4141
4242 CeedCallBackend (CeedVectorGetLength (v , & length ));
@@ -48,13 +48,13 @@ int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMo
4848 // Basis action
4949 switch (eval_mode ) {
5050 case CEED_EVAL_INTERP : {
51- void * interp_args [] = {(void * )& num_elem , (void * )& transpose , & data -> d_interp_1d , & d_u , & d_v };
51+ void * interp_args [] = {(void * )& num_elem , (void * )& is_transpose , & data -> d_interp_1d , & d_u , & d_v };
5252 const CeedInt block_size = CeedIntMin (CeedIntPow (Q_1d , dim ), max_block_size );
5353
5454 CeedCallBackend (CeedRunKernel_Cuda (ceed , data -> Interp , num_elem , block_size , interp_args ));
5555 } break ;
5656 case CEED_EVAL_GRAD : {
57- void * grad_args [] = {(void * )& num_elem , (void * )& transpose , & data -> d_interp_1d , & data -> d_grad_1d , & d_u , & d_v };
57+ void * grad_args [] = {(void * )& num_elem , (void * )& is_transpose , & data -> d_interp_1d , & data -> d_grad_1d , & d_u , & d_v };
5858 const CeedInt block_size = max_block_size ;
5959
6060 CeedCallBackend (CeedRunKernel_Cuda (ceed , data -> Grad , num_elem , block_size , grad_args ));
@@ -66,24 +66,19 @@ int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTransposeMo
6666
6767 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Weight , num_elem , block_size_x , block_size_y , 1 , weight_args ));
6868 } break ;
69+ case CEED_EVAL_NONE : /* handled separately below */
70+ break ;
6971 // LCOV_EXCL_START
70- // Evaluate the divergence to/from the quadrature points
7172 case CEED_EVAL_DIV :
72- return CeedError (ceed , CEED_ERROR_BACKEND , "CEED_EVAL_DIV not supported" );
73- // Evaluate the curl to/from the quadrature points
7473 case CEED_EVAL_CURL :
75- return CeedError (ceed , CEED_ERROR_BACKEND , "CEED_EVAL_CURL not supported" );
76- // Take no action, BasisApply should not have been called
77- case CEED_EVAL_NONE :
78- return CeedError (ceed , CEED_ERROR_BACKEND , "CEED_EVAL_NONE does not make sense in this context" );
74+ return CeedError (ceed , CEED_ERROR_BACKEND , "%s not supported" , CeedEvalModes [eval_mode ]);
7975 // LCOV_EXCL_STOP
8076 }
8177
82- // Restore vectors
83- if (eval_mode != CEED_EVAL_WEIGHT ) {
84- CeedCallBackend (CeedVectorRestoreArrayRead (u , & d_u ));
85- }
78+ // Restore vectors, cover CEED_EVAL_NONE
8679 CeedCallBackend (CeedVectorRestoreArray (v , & d_v ));
80+ if (eval_mode == CEED_EVAL_NONE ) CeedCallBackend (CeedVectorSetArray (v , CEED_MEM_DEVICE , CEED_COPY_VALUES , (CeedScalar * )d_u ));
81+ if (eval_mode != CEED_EVAL_WEIGHT ) CeedCallBackend (CeedVectorRestoreArrayRead (u , & d_u ));
8782 return CEED_ERROR_SUCCESS ;
8883}
8984
@@ -94,7 +89,7 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
9489 CeedVector v ) {
9590 Ceed ceed ;
9691 CeedInt num_nodes , num_qpts ;
97- const CeedInt transpose = t_mode == CEED_TRANSPOSE ;
92+ const CeedInt is_transpose = t_mode == CEED_TRANSPOSE ;
9893 const int elems_per_block = 1 ;
9994 const int grid = CeedDivUpInt (num_elem , elems_per_block );
10095 const CeedScalar * d_u ;
@@ -106,14 +101,13 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
106101 CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis , & num_qpts ));
107102 CeedCallBackend (CeedBasisGetNumNodes (basis , & num_nodes ));
108103
109- // Read vectors
110- if (eval_mode != CEED_EVAL_WEIGHT ) {
111- CeedCallBackend (CeedVectorGetArrayRead (u , CEED_MEM_DEVICE , & d_u ));
112- }
104+ // Get read/write access to u, v
105+ if (u != CEED_VECTOR_NONE ) CeedCallBackend (CeedVectorGetArrayRead (u , CEED_MEM_DEVICE , & d_u ));
106+ else CeedCheck (eval_mode == CEED_EVAL_WEIGHT , ceed , CEED_ERROR_BACKEND , "An input vector is required for this CeedEvalMode" );
113107 CeedCallBackend (CeedVectorGetArrayWrite (v , CEED_MEM_DEVICE , & d_v ));
114108
115109 // Clear v for transpose operation
116- if (transpose ) {
110+ if (is_transpose ) {
117111 CeedSize length ;
118112
119113 CeedCallBackend (CeedVectorGetLength (v , & length ));
@@ -124,39 +118,39 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
124118 switch (eval_mode ) {
125119 case CEED_EVAL_INTERP : {
126120 void * interp_args [] = {(void * )& num_elem , & data -> d_interp , & d_u , & d_v };
127- const int block_size_x = transpose ? num_nodes : num_qpts ;
121+ const int block_size_x = is_transpose ? num_nodes : num_qpts ;
128122
129- if (transpose ) {
123+ if (is_transpose ) {
130124 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> InterpTranspose , grid , block_size_x , 1 , elems_per_block , interp_args ));
131125 } else {
132126 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Interp , grid , block_size_x , 1 , elems_per_block , interp_args ));
133127 }
134128 } break ;
135129 case CEED_EVAL_GRAD : {
136130 void * grad_args [] = {(void * )& num_elem , & data -> d_grad , & d_u , & d_v };
137- const int block_size_x = transpose ? num_nodes : num_qpts ;
131+ const int block_size_x = is_transpose ? num_nodes : num_qpts ;
138132
139- if (transpose ) {
133+ if (is_transpose ) {
140134 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> DerivTranspose , grid , block_size_x , 1 , elems_per_block , grad_args ));
141135 } else {
142136 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Deriv , grid , block_size_x , 1 , elems_per_block , grad_args ));
143137 }
144138 } break ;
145139 case CEED_EVAL_DIV : {
146140 void * div_args [] = {(void * )& num_elem , & data -> d_div , & d_u , & d_v };
147- const int block_size_x = transpose ? num_nodes : num_qpts ;
141+ const int block_size_x = is_transpose ? num_nodes : num_qpts ;
148142
149- if (transpose ) {
143+ if (is_transpose ) {
150144 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> DerivTranspose , grid , block_size_x , 1 , elems_per_block , div_args ));
151145 } else {
152146 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Deriv , grid , block_size_x , 1 , elems_per_block , div_args ));
153147 }
154148 } break ;
155149 case CEED_EVAL_CURL : {
156150 void * curl_args [] = {(void * )& num_elem , & data -> d_curl , & d_u , & d_v };
157- const int block_size_x = transpose ? num_nodes : num_qpts ;
151+ const int block_size_x = is_transpose ? num_nodes : num_qpts ;
158152
159- if (transpose ) {
153+ if (is_transpose ) {
160154 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> DerivTranspose , grid , block_size_x , 1 , elems_per_block , curl_args ));
161155 } else {
162156 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Deriv , grid , block_size_x , 1 , elems_per_block , curl_args ));
@@ -167,18 +161,14 @@ int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt num_elem, CeedTr
167161
168162 CeedCallBackend (CeedRunKernelDim_Cuda (ceed , data -> Weight , grid , num_qpts , 1 , elems_per_block , weight_args ));
169163 } break ;
170- // LCOV_EXCL_START
171- // Take no action, BasisApply should not have been called
172- case CEED_EVAL_NONE :
173- return CeedError (ceed , CEED_ERROR_BACKEND , "CEED_EVAL_NONE does not make sense in this context" );
174- // LCOV_EXCL_STOP
164+ case CEED_EVAL_NONE : /* handled separately below */
165+ break ;
175166 }
176167
177- // Restore vectors
178- if (eval_mode != CEED_EVAL_WEIGHT ) {
179- CeedCallBackend (CeedVectorRestoreArrayRead (u , & d_u ));
180- }
168+ // Restore vectors, cover CEED_EVAL_NONE
181169 CeedCallBackend (CeedVectorRestoreArray (v , & d_v ));
170+ if (eval_mode == CEED_EVAL_NONE ) CeedCallBackend (CeedVectorSetArray (v , CEED_MEM_DEVICE , CEED_COPY_VALUES , (CeedScalar * )d_u ));
171+ if (eval_mode != CEED_EVAL_WEIGHT ) CeedCallBackend (CeedVectorRestoreArrayRead (u , & d_u ));
182172 return CEED_ERROR_SUCCESS ;
183173}
184174
0 commit comments