Skip to content

Commit 873a330

Browse files
committed
Change operator precision to a more flexible interface
1 parent 767aa77 commit 873a330

21 files changed

+89
-72
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,17 +1572,17 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
15721572

15731573
// Compile
15741574
{
1575-
bool is_compile_good = false;
1576-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
1577-
bool use_mixed_precision;
1575+
bool is_compile_good = false;
1576+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
1577+
CeedScalarType precision;
15781578

15791579
// Check for mixed precision
1580-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
1580+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
15811581

15821582
data->thread_1d = T_1d;
1583-
if (use_mixed_precision) {
1584-
CeedCallBackend(
1585-
CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 2, "OP_T_1D", T_1d, "CEED_JIT_MIXED_PRECISION", 1));
1583+
if (precision) {
1584+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 2, "OP_T_1D", T_1d, "CEED_JIT_PRECISION",
1585+
(CeedInt)precision));
15861586
} else {
15871587
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "OP_T_1D", T_1d));
15881588
}
@@ -2052,18 +2052,18 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20522052

20532053
// Compile
20542054
{
2055-
bool is_compile_good = false;
2056-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2057-
bool use_mixed_precision;
2055+
bool is_compile_good = false;
2056+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2057+
CeedScalarType precision;
20582058

20592059
// Check for mixed precision
2060-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
2060+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
20612061

20622062
data->thread_1d = T_1d;
2063-
if (use_mixed_precision) {
2063+
if (precision) {
20642064
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
20652065
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 2, "OP_T_1D", T_1d,
2066-
"CEED_JIT_MIXED_PRECISION", 1));
2066+
"CEED_JIT_PRECISION", (CeedInt)precision));
20672067
} else {
20682068
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
20692069
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 1, "OP_T_1D", T_1d));
@@ -2642,17 +2642,17 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26422642

26432643
// Compile
26442644
{
2645-
bool is_compile_good = false;
2646-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2647-
bool use_mixed_precision;
2645+
bool is_compile_good = false;
2646+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2647+
CeedScalarType precision;
26482648

26492649
// Check for mixed precision
2650-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
2650+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
26512651

26522652
data->thread_1d = T_1d;
2653-
if (use_mixed_precision) {
2653+
if (precision) {
26542654
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 2, "OP_T_1D", T_1d,
2655-
"CEED_JIT_MIXED_PRECISION", 1));
2655+
"CEED_JIT_PRECISION", (CeedInt)precision));
26562656
} else {
26572657
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 1, "OP_T_1D", T_1d));
26582658
}

backends/hip-gen/ceed-hip-gen-operator-build.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,8 +2483,8 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(CeedOperat
24832483
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[f], &field_size));
24842484
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[f], &eval_mode));
24852485
if (eval_mode == CEED_EVAL_GRAD) {
2486-
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << "dim_in_" << f << "*"
2487-
<< (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
2486+
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*"
2487+
<< "dim_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24882488
} else {
24892489
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24902490
}

include/ceed-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ struct CeedOperator_private {
381381
bool is_composite;
382382
bool is_at_points;
383383
bool has_restriction;
384-
bool use_mixed_precision;
384+
CeedScalarType precision;
385385
CeedQFunctionAssemblyData qf_assembled;
386386
CeedOperatorAssemblyData op_assembled;
387387
CeedOperator *sub_operators;

include/ceed/ceed-f32.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,18 @@
1818

1919
/// Set base scalar type to FP32. (See CeedScalarType enum in ceed.h for all options.)
2020
#define CEED_SCALAR_TYPE CEED_SCALAR_FP32
21+
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_PRECISION) && (CEED_JIT_PRECISION != CEED_SCALAR_TYPE)
22+
#if CEED_JIT_PRECISION == CEED_SCALAR_FP64
23+
typedef double CeedScalar;
24+
typedef float CeedScalarCPU;
25+
26+
/// Machine epsilon
27+
static const CeedScalar CEED_EPSILON = DBL_EPSILON;
28+
#endif // CEED_JIT_PRECISION
29+
#else
2130
typedef float CeedScalar;
2231
typedef CeedScalar CeedScalarCPU;
2332

2433
/// Machine epsilon
2534
static const CeedScalar CEED_EPSILON = FLT_EPSILON;
35+
#endif

include/ceed/ceed-f64.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
/// Set base scalar type to FP64. (See CeedScalarType enum in ceed.h for all options.)
2020
#define CEED_SCALAR_TYPE CEED_SCALAR_FP64
21-
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_MIXED_PRECISION)
21+
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_PRECISION) && (CEED_JIT_PRECISION != CEED_SCALAR_TYPE)
22+
#if CEED_JIT_PRECISION == CEED_SCALAR_FP32
2223
typedef float CeedScalar;
2324
typedef double CeedScalarCPU;
2425

2526
/// Machine epsilon
2627
static const CeedScalar CEED_EPSILON = FLT_EPSILON;
28+
#endif // CEED_JIT_PRECISION
2729
#else
2830
typedef double CeedScalar;
2931
typedef CeedScalar CeedScalarCPU;

include/ceed/ceed.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,8 @@ CEED_EXTERN int CeedOperatorCheckReady(CeedOperator op);
427427
CEED_EXTERN int CeedOperatorGetActiveVectorLengths(CeedOperator op, CeedSize *input_size, CeedSize *output_size);
428428
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyReuse(CeedOperator op, bool reuse_assembly_data);
429429
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(CeedOperator op, bool needs_data_update);
430-
CEED_EXTERN int CeedOperatorSetMixedPrecision(CeedOperator op);
431-
CEED_EXTERN int CeedOperatorGetMixedPrecision(CeedOperator op, bool *use_mixed_precision);
430+
CEED_EXTERN int CeedOperatorSetPrecision(CeedOperator op, CeedScalarType precision);
431+
CEED_EXTERN int CeedOperatorGetPrecision(CeedOperator op, CeedScalarType *precision);
432432
CEED_EXTERN int CeedOperatorLinearAssembleQFunction(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request);
433433
CEED_EXTERN int CeedOperatorLinearAssembleQFunctionBuildOrUpdate(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr,
434434
CeedRequest *request);

interface/ceed-operator.c

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -637,41 +637,43 @@ int CeedOperatorIsSetupDone(CeedOperator op, bool *is_setup_done) {
637637
/**
638638
@brief Set the precision (`CeedScalarType`) a `CeedOperator` uses for operator application
639639
640-
@param[in] op `CeedOperator`
640+
@param[in] op `CeedOperator`
641+
@param[in] scalar_type `CeedScalarType` to use for operator application
641642
642643
@return An error code: 0 - success, otherwise - failure
643644
644645
@ref User
645646
**/
646-
int CeedOperatorSetMixedPrecision(CeedOperator op) {
647+
int CeedOperatorSetPrecision(CeedOperator op, CeedScalarType scalar_type) {
647648
bool is_immutable, is_composite, supports_mixed_precision;
648649
Ceed ceed;
649650

650651
CeedCall(CeedOperatorGetCeed(op, &ceed));
651652
CeedCall(CeedOperatorIsImmutable(op, &is_immutable));
652-
CeedCheck(!is_immutable, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetMixedPrecision must be called before operator is finalized");
653+
CeedCheck(!is_immutable, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetPrecision must be called before operator is finalized");
653654
CeedCall(CeedOperatorIsComposite(op, &is_composite));
654-
CeedCheck(!is_composite, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetMixedPrecision should be set on single operators");
655+
CeedCheck(!is_composite, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetPrecision should be set on single operators");
655656
CeedCall(CeedGetSupportsMixedPrecision(ceed, &supports_mixed_precision));
656-
CeedCheck(supports_mixed_precision, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement mixed precision operators");
657+
CeedCheck(scalar_type == CEED_SCALAR_TYPE || supports_mixed_precision, ceed, CEED_ERROR_UNSUPPORTED,
658+
"Backend does not implement mixed precision operators");
657659

658-
op->use_mixed_precision = true;
660+
op->precision = scalar_type;
659661
CeedCallBackend(CeedDestroy(&ceed));
660662
return CEED_ERROR_SUCCESS;
661663
}
662664

663665
/**
664666
@brief Get the precision (`CeedScalarType`) a `CeedOperator` is set to use for operator application
665667
666-
@param[in] op `CeedOperator`
667-
@param[out] use_mixed_precision Variable to store `CeedQFunction`
668+
@param[in] op `CeedOperator`
669+
@param[out] precision Variable to store operator precision
668670
669671
@return An error code: 0 - success, otherwise - failure
670672
671673
@ref User
672674
**/
673-
int CeedOperatorGetMixedPrecision(CeedOperator op, bool *use_mixed_precision) {
674-
*use_mixed_precision = op->use_mixed_precision;
675+
int CeedOperatorGetPrecision(CeedOperator op, CeedScalarType *precision) {
676+
*precision = op->precision;
675677
return CEED_ERROR_SUCCESS;
676678
}
677679

tests/t502-operator-mixed.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ int main(int argc, char **argv) {
7070
CeedOperatorSetField(op_setup, "weight", CEED_ELEMRESTRICTION_NONE, basis_x, CEED_VECTOR_NONE);
7171
CeedOperatorSetField(op_setup, "dx", elem_restriction_x, basis_x, CEED_VECTOR_ACTIVE);
7272
CeedOperatorSetField(op_setup, "rho", elem_restriction_q_data, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE);
73-
CeedOperatorSetMixedPrecision(op_setup);
73+
CeedOperatorSetPrecision(op_setup, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
7474

7575
CeedOperatorCreate(ceed, qf_mass, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &op_mass);
7676
CeedOperatorSetField(op_mass, "rho", elem_restriction_q_data, CEED_BASIS_NONE, q_data);
7777
CeedOperatorSetField(op_mass, "u", elem_restriction_u, basis_u, CEED_VECTOR_ACTIVE);
7878
CeedOperatorSetField(op_mass, "v", elem_restriction_u, basis_u, CEED_VECTOR_ACTIVE);
79-
CeedOperatorSetMixedPrecision(op_mass);
79+
CeedOperatorSetPrecision(op_mass, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
8080

8181
CeedOperatorApply(op_setup, x, q_data, CEED_REQUEST_IMMEDIATE);
8282

tests/t503-operator-mixed.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ int main(int argc, char **argv) {
7171
CeedOperatorSetField(op_setup, "weight", CEED_ELEMRESTRICTION_NONE, basis_x, CEED_VECTOR_NONE);
7272
CeedOperatorSetField(op_setup, "dx", elem_restriction_x, basis_x, x);
7373
CeedOperatorSetField(op_setup, "rho", elem_restriction_q_data, CEED_BASIS_NONE, q_data);
74-
CeedOperatorSetMixedPrecision(op_setup);
74+
CeedOperatorSetPrecision(op_setup, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
7575

7676
CeedOperatorCreate(ceed, qf_mass, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &op_mass);
7777
CeedOperatorSetField(op_mass, "rho", elem_restriction_q_data, CEED_BASIS_NONE, q_data);
7878
CeedOperatorSetField(op_mass, "u", elem_restriction_u, basis_u, u);
7979
CeedOperatorSetField(op_mass, "v", elem_restriction_u, basis_u, v);
80-
CeedOperatorSetMixedPrecision(op_mass);
80+
CeedOperatorSetPrecision(op_mass, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
8181

8282
// Note - It is atypical to use only passive fields; this test is intended
8383
// as a test for all passive input modes rather than as an example.

tests/t505-operator-mixed.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ int main(int argc, char **argv) {
6969
CeedOperatorSetField(op_setup, "weight", CEED_ELEMRESTRICTION_NONE, basis_x, CEED_VECTOR_NONE);
7070
CeedOperatorSetField(op_setup, "dx", elem_restriction_x, basis_x, CEED_VECTOR_ACTIVE);
7171
CeedOperatorSetField(op_setup, "rho", elem_restriction_q_data, CEED_BASIS_NONE, CEED_VECTOR_ACTIVE);
72-
CeedOperatorSetMixedPrecision(op_setup);
72+
CeedOperatorSetPrecision(op_setup, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
7373

7474
CeedOperatorCreate(ceed, qf_mass, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &op_mass);
7575
CeedOperatorSetField(op_mass, "rho", elem_restriction_q_data, CEED_BASIS_NONE, q_data);
7676
CeedOperatorSetField(op_mass, "u", elem_restriction_u, basis_u, CEED_VECTOR_ACTIVE);
7777
CeedOperatorSetField(op_mass, "v", elem_restriction_u, basis_u, CEED_VECTOR_ACTIVE);
78-
CeedOperatorSetMixedPrecision(op_mass);
78+
CeedOperatorSetPrecision(op_mass, CEED_SCALAR_TYPE == CEED_SCALAR_FP32 ? CEED_SCALAR_FP64 : CEED_SCALAR_FP32);
7979

8080
CeedOperatorApply(op_setup, x, q_data, CEED_REQUEST_IMMEDIATE);
8181

0 commit comments

Comments
 (0)