2222#include " ../cuda/ceed-cuda-compile.h"
2323#include " ceed-cuda-gen.h"
2424
25+ struct FieldReuse_Cuda {
26+ CeedInt index;
27+ bool is_input;
28+ CeedEvalMode eval_mode;
29+ };
30+
2531// ------------------------------------------------------------------------------
2632// Determine type of operator
2733// ------------------------------------------------------------------------------
@@ -127,8 +133,8 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
127133// Setup fields
128134// ------------------------------------------------------------------------------
129135static int CeedOperatorBuildKernelFieldData_Cuda_gen (std::ostringstream &code, CeedOperator_Cuda_gen *data, CeedInt i, CeedOperatorField op_field,
130- CeedQFunctionField qf_field, CeedInt Q_1d, bool is_input , bool is_tensor, bool is_at_points ,
131- bool use_3d_slices) {
136+ CeedQFunctionField qf_field, FieldReuse_Cuda field_reuse, CeedInt Q_1d , bool is_input ,
137+ bool is_tensor, bool is_at_points, bool use_3d_slices) {
132138 std::string var_suffix = (is_input ? " _in_" : " _out_" ) + std::to_string (i);
133139 std::string P_name = (is_tensor ? " P_1d" : " P" ) + var_suffix, Q_name = is_tensor ? " Q_1d" : " Q" ;
134140 std::string option_name = (is_input ? " inputs" : " outputs" );
@@ -138,6 +144,9 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
138144 CeedBasis_Cuda_shared *basis_data;
139145 CeedBasis basis;
140146
147+ // Field reuse info
148+ bool use_previous_field = field_reuse.index != -1 ;
149+
141150 code << " // -- " << (is_input ? " Input" : " Output" ) << " field " << i << " \n " ;
142151
143152 // Get field data
@@ -188,8 +197,14 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
188197 if (is_input) data->B .inputs [i] = basis_data->d_interp_1d ;
189198 else data->B .outputs [i] = basis_data->d_interp_1d ;
190199 }
191- code << " __shared__ CeedScalar s_B" << var_suffix << " [" << P_name << " *" << Q_name << " ];\n " ;
192- code << " LoadMatrix<" << P_name << " , " << Q_name << " >(data, B." << option_name << " [" << i << " ], s_B" << var_suffix << " );\n " ;
200+ if (use_previous_field) {
201+ std::string reuse_var = " s_B" + ((field_reuse.is_input ? " _in_" : " _out_" ) + std::to_string (field_reuse.index ));
202+
203+ code << " CeedScalar *s_B" << var_suffix << " = " << reuse_var << " ;\n " ;
204+ } else {
205+ code << " __shared__ CeedScalar s_B" << var_suffix << " [" << P_name << " *" << Q_name << " ];\n " ;
206+ code << " LoadMatrix<" << P_name << " , " << Q_name << " >(data, B." << option_name << " [" << i << " ], s_B" << var_suffix << " );\n " ;
207+ }
193208 break ;
194209 case CEED_EVAL_GRAD:
195210 if (is_at_points) {
@@ -214,27 +229,51 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
214229 else data->B .outputs [i] = basis_data->d_interp_1d ;
215230 }
216231 if (is_tensor) {
217- code << " __shared__ CeedScalar s_B" << var_suffix << " [" << P_name << " *" << Q_name << " ];\n " ;
218- code << " LoadMatrix<" << P_name << " , " << Q_name << " >(data, B." << option_name << " [" << i << " ], s_B" << var_suffix << " );\n " ;
232+ if (use_previous_field) {
233+ std::string reuse_var = " s_B" + ((field_reuse.is_input ? " _in_" : " _out_" ) + std::to_string (field_reuse.index ));
234+
235+ code << " CeedScalar *s_B" << var_suffix << " = " << reuse_var << " ;\n " ;
236+ } else {
237+ code << " __shared__ CeedScalar s_B" << var_suffix << " [" << P_name << " *" << Q_name << " ];\n " ;
238+ code << " LoadMatrix<" << P_name << " , " << Q_name << " >(data, B." << option_name << " [" << i << " ], s_B" << var_suffix << " );\n " ;
239+ }
219240 }
220241 if (is_at_points) break ; // No G mat for AtPoints
221242 if (use_3d_slices) {
222243 if (is_input) data->G .inputs [i] = basis_data->d_collo_grad_1d ;
223244 else data->G .outputs [i] = basis_data->d_collo_grad_1d ;
224- code << " __shared__ CeedScalar s_G" << var_suffix << " [" << Q_name << " *" << Q_name << " ];\n " ;
225- code << " LoadMatrix<" << Q_name << " , " << Q_name << " >(data, G." << option_name << " [" << i << " ], s_G" << var_suffix << " );\n " ;
245+ if (use_previous_field && field_reuse.eval_mode == CEED_EVAL_GRAD) {
246+ std::string reuse_var = " s_G" + ((field_reuse.is_input ? " _in_" : " _out_" ) + std::to_string (field_reuse.index ));
247+
248+ code << " CeedScalar *s_G" << var_suffix << " = " << reuse_var << " ;\n " ;
249+ } else {
250+ code << " __shared__ CeedScalar s_G" << var_suffix << " [" << Q_name << " *" << Q_name << " ];\n " ;
251+ code << " LoadMatrix<" << Q_name << " , " << Q_name << " >(data, G." << option_name << " [" << i << " ], s_G" << var_suffix << " );\n " ;
252+ }
226253 } else {
227254 bool has_collo_grad = basis_data->d_collo_grad_1d ;
228255
229256 if (is_input) data->G .inputs [i] = has_collo_grad ? basis_data->d_collo_grad_1d : basis_data->d_grad_1d ;
230257 else data->G .outputs [i] = has_collo_grad ? basis_data->d_collo_grad_1d : basis_data->d_grad_1d ;
231258 if (has_collo_grad) {
232- code << " __shared__ CeedScalar s_G" << var_suffix << " [" << Q_name << " *" << Q_name << " ];\n " ;
233- code << " LoadMatrix<" << Q_name << " , " << Q_name << " >(data, G." << option_name << " [" << i << " ], s_G" << var_suffix << " );\n " ;
259+ if (use_previous_field && field_reuse.eval_mode == CEED_EVAL_GRAD) {
260+ std::string reuse_var = " s_G" + ((field_reuse.is_input ? " _in_" : " _out_" ) + std::to_string (field_reuse.index ));
261+
262+ code << " CeedScalar *s_G" << var_suffix << " = " << reuse_var << " ;\n " ;
263+ } else {
264+ code << " __shared__ CeedScalar s_G" << var_suffix << " [" << Q_name << " *" << Q_name << " ];\n " ;
265+ code << " LoadMatrix<" << Q_name << " , " << Q_name << " >(data, G." << option_name << " [" << i << " ], s_G" << var_suffix << " );\n " ;
266+ }
234267 } else {
235- code << " __shared__ CeedScalar s_G" << var_suffix << " [" << P_name << " *" << Q_name << (is_tensor ? " " : " *dim" ) << " ];\n " ;
236- code << " LoadMatrix<" << P_name << " , " << Q_name << (is_tensor ? " " : " *dim" ) << " >(data, G." << option_name << " [" << i << " ], s_G"
237- << var_suffix << " );\n " ;
268+ if (use_previous_field && field_reuse.eval_mode == CEED_EVAL_GRAD) {
269+ std::string reuse_var = " s_G" + ((field_reuse.is_input ? " _in_" : " _out_" ) + std::to_string (field_reuse.index ));
270+
271+ code << " CeedScalar *s_G" << var_suffix << " = " << reuse_var << " ;\n " ;
272+ } else {
273+ code << " __shared__ CeedScalar s_G" << var_suffix << " [" << P_name << " *" << Q_name << (is_tensor ? " " : " *dim" ) << " ];\n " ;
274+ code << " LoadMatrix<" << P_name << " , " << Q_name << (is_tensor ? " " : " *dim" ) << " >(data, G." << option_name << " [" << i << " ], s_G"
275+ << var_suffix << " );\n " ;
276+ }
238277 }
239278 }
240279 break ;
@@ -955,7 +994,7 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
955994
956995 CeedCallBackend (CeedBasisIsTensor (basis, &is_tensor));
957996 is_all_tensor = is_all_tensor && is_tensor;
958- is_all_nontensor = is_all_not_tensor && !is_tensor;
997+ is_all_nontensor = is_all_nontensor && !is_tensor;
959998 CeedCallBackend (CeedBasisGetCeed (basis, &basis_ceed));
960999 CeedCallBackend (CeedGetResource (basis_ceed, &resource));
9611000 CeedCallBackend (CeedGetResourceRoot (basis_ceed, resource, " :" , &resource_root));
@@ -1138,16 +1177,116 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
11381177 code << " data.t_id = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;\n " ;
11391178 code << " data.slice = slice + data.t_id_z*T_1D" << ((!is_tensor || dim == 1 ) ? " " : " *T_1D" ) << " ;\n " ;
11401179
1180+ // -- Determine input mat reuse
1181+ FieldReuse_Cuda input_matrix_reuse[CEED_FIELD_MAX];
1182+
1183+ for (CeedInt i = 0 ; i < num_input_fields; i++) {
1184+ input_matrix_reuse[i].index = -1 ;
1185+ }
1186+ for (CeedInt i = 0 ; i < num_input_fields; i++) {
1187+ CeedEvalMode eval_mode_i;
1188+ CeedBasis basis_i;
1189+
1190+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields[i], &eval_mode_i));
1191+ if (eval_mode_i == CEED_EVAL_WEIGHT) continue ;
1192+ CeedCallBackend (CeedOperatorFieldGetBasis (op_input_fields[i], &basis_i));
1193+ for (CeedInt j = 0 ; (input_matrix_reuse[i].index == -1 ) && (j < i); j++) {
1194+ CeedEvalMode eval_mode_j;
1195+ CeedBasis basis_j;
1196+
1197+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields[j], &eval_mode_j));
1198+ if (eval_mode_j == CEED_EVAL_WEIGHT) continue ;
1199+ CeedCallBackend (CeedOperatorFieldGetBasis (op_input_fields[j], &basis_j));
1200+ if (basis_i == basis_j) {
1201+ if (is_tensor) {
1202+ input_matrix_reuse[i].index = j;
1203+ input_matrix_reuse[i].is_input = true ;
1204+ input_matrix_reuse[i].eval_mode = eval_mode_j;
1205+ } else {
1206+ // For non-tensor can only re-use with the same eval mode
1207+ if (eval_mode_i == eval_mode_j) {
1208+ input_matrix_reuse[i].index = j;
1209+ input_matrix_reuse[i].is_input = true ;
1210+ input_matrix_reuse[i].eval_mode = eval_mode_j;
1211+ }
1212+ }
1213+ }
1214+ CeedCallBackend (CeedBasisDestroy (&basis_j));
1215+ }
1216+ CeedCallBackend (CeedBasisDestroy (&basis_i));
1217+ }
1218+
1219+ // -- Determine output mat reuse
1220+ FieldReuse_Cuda output_matrix_reuse[CEED_FIELD_MAX];
1221+
1222+ for (CeedInt i = 0 ; i < num_output_fields; i++) {
1223+ output_matrix_reuse[i].index = -1 ;
1224+ }
1225+ for (CeedInt i = 0 ; i < num_output_fields; i++) {
1226+ CeedEvalMode eval_mode_i;
1227+ CeedBasis basis_i;
1228+
1229+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_output_fields[i], &eval_mode_i));
1230+ CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields[i], &basis_i));
1231+ for (CeedInt j = 0 ; (output_matrix_reuse[i].index == -1 ) && (j < num_input_fields); j++) {
1232+ CeedEvalMode eval_mode_j;
1233+ CeedBasis basis_j;
1234+
1235+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields[j], &eval_mode_j));
1236+ if (eval_mode_j == CEED_EVAL_WEIGHT) continue ;
1237+ CeedCallBackend (CeedOperatorFieldGetBasis (op_input_fields[j], &basis_j));
1238+ if (basis_i == basis_j) {
1239+ if (is_tensor) {
1240+ output_matrix_reuse[i].index = j;
1241+ output_matrix_reuse[i].is_input = true ;
1242+ output_matrix_reuse[i].eval_mode = eval_mode_j;
1243+ } else {
1244+ // For non-tensor can only re-use with the same eval mode
1245+ if (eval_mode_i == eval_mode_j) {
1246+ output_matrix_reuse[i].index = j;
1247+ output_matrix_reuse[i].is_input = true ;
1248+ output_matrix_reuse[i].eval_mode = eval_mode_j;
1249+ }
1250+ }
1251+ }
1252+ CeedCallBackend (CeedBasisDestroy (&basis_j));
1253+ }
1254+ for (CeedInt j = 0 ; (output_matrix_reuse[i].index == -1 ) && (j < i); j++) {
1255+ CeedEvalMode eval_mode_j;
1256+ CeedBasis basis_j;
1257+
1258+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_output_fields[j], &eval_mode_j));
1259+ if (eval_mode_j == CEED_EVAL_WEIGHT) continue ;
1260+ CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields[j], &basis_j));
1261+ if (basis_i == basis_j) {
1262+ if (is_tensor) {
1263+ output_matrix_reuse[i].index = j;
1264+ output_matrix_reuse[i].is_input = false ;
1265+ output_matrix_reuse[i].eval_mode = eval_mode_j;
1266+ } else {
1267+ // For non-tensor can only re-use with the same eval mode
1268+ if (eval_mode_i == eval_mode_j) {
1269+ output_matrix_reuse[i].index = j;
1270+ output_matrix_reuse[i].is_input = false ;
1271+ output_matrix_reuse[i].eval_mode = eval_mode_j;
1272+ }
1273+ }
1274+ }
1275+ CeedCallBackend (CeedBasisDestroy (&basis_j));
1276+ }
1277+ CeedCallBackend (CeedBasisDestroy (&basis_i));
1278+ }
1279+
11411280 // Initialize constants, and matrices B and G
11421281 code << " \n // Input field constants and basis data\n " ;
11431282 for (CeedInt i = 0 ; i < num_input_fields; i++) {
1144- CeedCallBackend (CeedOperatorBuildKernelFieldData_Cuda_gen (code, data, i, op_input_fields[i], qf_input_fields[i], Q_1d, true , is_tensor ,
1145- is_at_points, use_3d_slices));
1283+ CeedCallBackend (CeedOperatorBuildKernelFieldData_Cuda_gen (code, data, i, op_input_fields[i], qf_input_fields[i], input_matrix_reuse[i], Q_1d ,
1284+ true , is_tensor, is_at_points, use_3d_slices));
11461285 }
11471286 code << " \n // Output field constants and basis data\n " ;
11481287 for (CeedInt i = 0 ; i < num_output_fields; i++) {
1149- CeedCallBackend (CeedOperatorBuildKernelFieldData_Cuda_gen (code, data, i, op_output_fields[i], qf_output_fields[i], Q_1d, false , is_tensor ,
1150- is_at_points, use_3d_slices));
1288+ CeedCallBackend (CeedOperatorBuildKernelFieldData_Cuda_gen (code, data, i, op_output_fields[i], qf_output_fields[i], output_matrix_reuse[i], Q_1d ,
1289+ false , is_tensor, is_at_points, use_3d_slices));
11511290 }
11521291
11531292 // Loop over all elements
0 commit comments