Skip to content

Commit 474fba4

Browse files
committed
wip
1 parent 94ff9af commit 474fba4

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
7474

7575
CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
7676
if (basis != CEED_BASIS_NONE) {
77+
bool is_field_tensor;
7778
CeedInt field_dim = 0, field_P = 0, field_P_1d = 0, field_Q = 0, field_Q_1d = 0;
7879

7980
// Check if 3D
@@ -84,14 +85,15 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
8485
// Collect P, P_1d, Q, and Q_1d
8586
CeedCallBackend(CeedBasisGetNumNodes(basis, &field_P));
8687
*max_P = CeedIntMax(*max_P, field_P);
87-
if (*is_all_tensor) {
88+
CeedCallBackend(CeedBasisIsTensor(basis, &is_field_tensor));
89+
if (is_field_tensor) {
8890
CeedCallBackend(CeedBasisGetNumNodes1D(basis, &field_P_1d));
8991
*max_P_1d = CeedIntMax(*max_P_1d, field_P_1d);
9092
}
9193
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &field_Q));
9294
CeedCheck(*Q == 0 || field_Q == *Q, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
9395
*Q = field_Q;
94-
if (*is_all_tensor) {
96+
if (is_field_tensor) {
9597
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &field_Q_1d));
9698
CeedCheck(*Q_1d == 0 || field_Q_1d == *Q_1d, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
9799
*Q_1d = field_Q_1d;
@@ -104,6 +106,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
104106

105107
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
106108
if (basis != CEED_BASIS_NONE) {
109+
bool is_field_tensor;
107110
CeedInt field_dim = 0, field_P = 0, field_P_1d = 0, field_Q = 0, field_Q_1d = 0;
108111

109112
// Check if 3D
@@ -114,14 +117,15 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie
114117
// Collect P, P_1d, Q, and Q_1d
115118
CeedCallBackend(CeedBasisGetNumNodes(basis, &field_P));
116119
*max_P = CeedIntMax(*max_P, field_P);
117-
if (*is_all_tensor) {
120+
CeedCallBackend(CeedBasisIsTensor(basis, &is_field_tensor));
121+
if (is_field_tensor) {
118122
CeedCallBackend(CeedBasisGetNumNodes1D(basis, &field_P_1d));
119123
*max_P_1d = CeedIntMax(*max_P_1d, field_P_1d);
120124
}
121125
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &field_Q));
122126
CeedCheck(*Q == 0 || field_Q == *Q, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
123127
*Q = field_Q;
124-
if (*is_all_tensor) {
128+
if (is_field_tensor) {
125129
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &field_Q_1d));
126130
CeedCheck(*Q_1d == 0 || field_Q_1d == *Q_1d, ceed, CEED_ERROR_BACKEND, "Quadrature spaces must be compatible");
127131
*Q_1d = field_Q_1d;
@@ -1471,8 +1475,9 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
14711475
// Compile
14721476
{
14731477
bool is_compile_good = false;
1478+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
14741479

1475-
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "T_1D", CeedIntMax(Q_1d, data->max_P_1d)));
1480+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "T_1D", T_1d));
14761481
if (is_compile_good) {
14771482
*is_good_build = true;
14781483
CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, operator_name.c_str(), &data->op));

0 commit comments

Comments
 (0)