Skip to content

Commit c66094c

Browse files
committed
Change operator precision to a more flexible interface
1 parent 767aa77 commit c66094c

24 files changed

+113
-78
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,17 +1572,17 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_b
15721572

15731573
// Compile
15741574
{
1575-
bool is_compile_good = false;
1576-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
1577-
bool use_mixed_precision;
1575+
bool is_compile_good = false;
1576+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
1577+
CeedScalarType precision;
15781578

15791579
// Check for mixed precision
1580-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
1580+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
15811581

15821582
data->thread_1d = T_1d;
1583-
if (use_mixed_precision) {
1584-
CeedCallBackend(
1585-
CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 2, "OP_T_1D", T_1d, "CEED_JIT_MIXED_PRECISION", 1));
1583+
if (precision != CEED_SCALAR_TYPE) {
1584+
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 2, "OP_T_1D", T_1d, "CEED_JIT_PRECISION",
1585+
CeedScalarTypes[precision]));
15861586
} else {
15871587
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module, 1, "OP_T_1D", T_1d));
15881588
}
@@ -2052,18 +2052,18 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
20522052

20532053
// Compile
20542054
{
2055-
bool is_compile_good = false;
2056-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2057-
bool use_mixed_precision;
2055+
bool is_compile_good = false;
2056+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2057+
CeedScalarType precision;
20582058

20592059
// Check for mixed precision
2060-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
2060+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
20612061

20622062
data->thread_1d = T_1d;
2063-
if (use_mixed_precision) {
2063+
if (precision != CEED_SCALAR_TYPE) {
20642064
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
20652065
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 2, "OP_T_1D", T_1d,
2066-
"CEED_JIT_MIXED_PRECISION", 1));
2066+
"CEED_JIT_PRECISION", CeedScalarTypes[precision]));
20672067
} else {
20682068
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good,
20692069
is_full ? &data->module_assemble_full : &data->module_assemble_diagonal, 1, "OP_T_1D", T_1d));
@@ -2642,17 +2642,17 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Cuda_gen(CeedOpera
26422642

26432643
// Compile
26442644
{
2645-
bool is_compile_good = false;
2646-
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2647-
bool use_mixed_precision;
2645+
bool is_compile_good = false;
2646+
const CeedInt T_1d = CeedIntMax(is_all_tensor ? Q_1d : Q, data->max_P_1d);
2647+
CeedScalarType precision;
26482648

26492649
// Check for mixed precision
2650-
CeedCallBackend(CeedOperatorGetMixedPrecision(op, &use_mixed_precision));
2650+
CeedCallBackend(CeedOperatorGetPrecision(op, &precision));
26512651

26522652
data->thread_1d = T_1d;
2653-
if (use_mixed_precision) {
2653+
if (precision != CEED_SCALAR_TYPE) {
26542654
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 2, "OP_T_1D", T_1d,
2655-
"CEED_JIT_MIXED_PRECISION", 1));
2655+
"CEED_JIT_PRECISION", CeedScalarTypes[precision]));
26562656
} else {
26572657
CeedCallBackend(CeedTryCompile_Cuda(ceed, code.str().c_str(), &is_compile_good, &data->module_assemble_qfunction, 1, "OP_T_1D", T_1d));
26582658
}

backends/cuda/ceed-cuda-compile.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,20 @@ static int CeedCompileCore_Cuda(Ceed ceed, const char *source, const bool throw_
5252
// Get kernel specific options, such as kernel constants
5353
if (num_defines > 0) {
5454
char *name;
55-
int val;
5655

5756
for (int i = 0; i < num_defines; i++) {
5857
name = va_arg(args, char *);
59-
val = va_arg(args, int);
60-
code << "#define " << name << " " << val << "\n";
58+
if (!strcmp(name, "CEED_JIT_PRECISION")) {
59+
char *val;
60+
61+
val = va_arg(args, char *);
62+
code << "#define " << name << " " << val << "\n";
63+
} else {
64+
int val;
65+
66+
val = va_arg(args, int);
67+
code << "#define " << name << " " << val << "\n";
68+
}
6169
}
6270
}
6371

backends/hip-gen/ceed-hip-gen-operator-build.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,8 +2483,8 @@ extern "C" int CeedOperatorBuildKernelLinearAssembleQFunction_Hip_gen(CeedOperat
24832483
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[f], &field_size));
24842484
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[f], &eval_mode));
24852485
if (eval_mode == CEED_EVAL_GRAD) {
2486-
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << "dim_in_" << f << "*"
2487-
<< (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
2486+
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*"
2487+
<< "dim_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24882488
} else {
24892489
code << tab << "CeedScalar r_q_in_" << f << "[num_comp_in_" << f << "*" << (is_all_tensor && (max_dim >= 3) ? "Q_1d" : "1") << "] = {0.};\n";
24902490
}

include/ceed-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ struct CeedOperator_private {
381381
bool is_composite;
382382
bool is_at_points;
383383
bool has_restriction;
384-
bool use_mixed_precision;
384+
CeedScalarType precision;
385385
CeedQFunctionAssemblyData qf_assembled;
386386
CeedOperatorAssemblyData op_assembled;
387387
CeedOperator *sub_operators;

include/ceed/ceed-f32.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,18 @@
1818

1919
/// Set base scalar type to FP32. (See CeedScalarType enum in ceed.h for all options.)
2020
#define CEED_SCALAR_TYPE CEED_SCALAR_FP32
21+
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_PRECISION) && (CEED_JIT_PRECISION != CEED_SCALAR_TYPE)
22+
#ifdef CEED_JIT_PRECISION == CEED_SCALAR_FP64
23+
typedef double CeedScalar;
24+
typedef float CeedScalarCPU;
25+
26+
/// Machine epsilon
27+
static const CeedScalar CEED_EPSILON = DBL_EPSILON;
28+
#endif // CEED_JIT_PRECISION
29+
#else
2130
typedef float CeedScalar;
2231
typedef CeedScalar CeedScalarCPU;
2332

2433
/// Machine epsilon
2534
static const CeedScalar CEED_EPSILON = FLT_EPSILON;
35+
#endif

include/ceed/ceed-f64.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
/// Set base scalar type to FP64. (See CeedScalarType enum in ceed.h for all options.)
2020
#define CEED_SCALAR_TYPE CEED_SCALAR_FP64
21-
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_MIXED_PRECISION)
21+
#if defined(CEED_RUNNING_JIT_PASS) && defined(CEED_JIT_PRECISION) && (CEED_JIT_PRECISION != CEED_SCALAR_TYPE)
22+
#if CEED_JIT_PRECISION == CEED_SCALAR_FP32
2223
typedef float CeedScalar;
2324
typedef double CeedScalarCPU;
2425

2526
/// Machine epsilon
2627
static const CeedScalar CEED_EPSILON = FLT_EPSILON;
28+
#endif // CEED_JIT_PRECISION
2729
#else
2830
typedef double CeedScalar;
2931
typedef CeedScalar CeedScalarCPU;

include/ceed/ceed.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ CEED_EXTERN const char *const CeedEvalModes[];
178178
CEED_EXTERN const char *const CeedQuadModes[];
179179
CEED_EXTERN const char *const CeedElemTopologies[];
180180
CEED_EXTERN const char *const CeedContextFieldTypes[];
181+
CEED_EXTERN const char *const CeedScalarTypes[];
181182

182183
CEED_EXTERN int CeedGetPreferredMemType(Ceed ceed, CeedMemType *type);
183184

@@ -427,8 +428,8 @@ CEED_EXTERN int CeedOperatorCheckReady(CeedOperator op);
427428
CEED_EXTERN int CeedOperatorGetActiveVectorLengths(CeedOperator op, CeedSize *input_size, CeedSize *output_size);
428429
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyReuse(CeedOperator op, bool reuse_assembly_data);
429430
CEED_EXTERN int CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(CeedOperator op, bool needs_data_update);
430-
CEED_EXTERN int CeedOperatorSetMixedPrecision(CeedOperator op);
431-
CEED_EXTERN int CeedOperatorGetMixedPrecision(CeedOperator op, bool *use_mixed_precision);
431+
CEED_EXTERN int CeedOperatorSetPrecision(CeedOperator op, CeedScalarType precision);
432+
CEED_EXTERN int CeedOperatorGetPrecision(CeedOperator op, CeedScalarType *precision);
432433
CEED_EXTERN int CeedOperatorLinearAssembleQFunction(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request);
433434
CEED_EXTERN int CeedOperatorLinearAssembleQFunctionBuildOrUpdate(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr,
434435
CeedRequest *request);

include/ceed/types.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ typedef signed char CeedInt8;
118118
/// @ingroup Ceed
119119
typedef enum {
120120
/// Single precision
121-
CEED_SCALAR_FP32,
121+
CEED_SCALAR_FP32 = 0,
122122
/// Double precision
123-
CEED_SCALAR_FP64
123+
CEED_SCALAR_FP64 = 1
124124
} CeedScalarType;
125+
#define CEED_SCALAR_FP32 0
126+
#define CEED_SCALAR_FP64 1
125127
/// Base scalar type for the library to use: change which header is included to change the precision.
126128
#include "ceed-f64.h" // IWYU pragma: export
127129

interface/ceed-operator.c

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -635,43 +635,45 @@ int CeedOperatorIsSetupDone(CeedOperator op, bool *is_setup_done) {
635635
}
636636

637637
/**
638-
@brief Set a `CeedOperator` to use reduced precision for operator application
638+
@brief Set the floating point precision for `CeedOperator` application
639639
640-
@param[in] op `CeedOperator`
640+
@param[in] op `CeedOperator`
641+
@param[in] precision `CeedScalarType` to use for operator application
641642
642643
@return An error code: 0 - success, otherwise - failure
643644
644645
@ref User
645646
**/
646-
int CeedOperatorSetMixedPrecision(CeedOperator op) {
647+
int CeedOperatorSetPrecision(CeedOperator op, CeedScalarType scalar_type) {
647648
bool is_immutable, is_composite, supports_mixed_precision;
648649
Ceed ceed;
649650

650651
CeedCall(CeedOperatorGetCeed(op, &ceed));
651652
CeedCall(CeedOperatorIsImmutable(op, &is_immutable));
652-
CeedCheck(!is_immutable, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetMixedPrecision must be called before operator is finalized");
653+
CeedCheck(!is_immutable, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetPrecision must be called before operator is finalized");
653654
CeedCall(CeedOperatorIsComposite(op, &is_composite));
654-
CeedCheck(!is_composite, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetMixedPrecision should be set on single operators");
655+
CeedCheck(!is_composite, ceed, CEED_ERROR_INCOMPATIBLE, "CeedOperatorSetPrecision should be set on single operators");
655656
CeedCall(CeedGetSupportsMixedPrecision(ceed, &supports_mixed_precision));
656-
CeedCheck(supports_mixed_precision, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement mixed precision operators");
657+
CeedCheck(scalar_type == CEED_SCALAR_TYPE || supports_mixed_precision, ceed, CEED_ERROR_UNSUPPORTED,
658+
"Backend does not implement mixed precision operators");
657659

658-
op->use_mixed_precision = true;
660+
op->precision = scalar_type;
659661
CeedCallBackend(CeedDestroy(&ceed));
660662
return CEED_ERROR_SUCCESS;
661663
}
662664

663665
/**
664666
@brief Get whether a `CeedOperator` is set to use reduced precision for operator application
665667
666-
@param[in] op `CeedOperator`
667-
@param[out] use_mixed_precision Variable to store `CeedQFunction`
668+
@param[in] op `CeedOperator`
669+
@param[out] precision Variable to store operator precision
668670
669671
@return An error code: 0 - success, otherwise - failure
670672
671673
@ref User
672674
**/
673-
int CeedOperatorGetMixedPrecision(CeedOperator op, bool *use_mixed_precision) {
674-
*use_mixed_precision = op->use_mixed_precision;
675+
int CeedOperatorGetPrecision(CeedOperator op, CeedScalarType *precision) {
676+
*precision = op->precision;
675677
return CEED_ERROR_SUCCESS;
676678
}
677679

@@ -809,6 +811,7 @@ int CeedOperatorCreate(Ceed ceed, CeedQFunction qf, CeedQFunction dqf, CeedQFunc
809811
(*op)->ref_count = 1;
810812
(*op)->input_size = -1;
811813
(*op)->output_size = -1;
814+
(*op)->precision = CEED_SCALAR_TYPE;
812815
CeedCall(CeedQFunctionReferenceCopy(qf, &(*op)->qf));
813816
if (dqf && dqf != CEED_QFUNCTION_NONE) CeedCall(CeedQFunctionReferenceCopy(dqf, &(*op)->dqf));
814817
if (dqfT && dqfT != CEED_QFUNCTION_NONE) CeedCall(CeedQFunctionReferenceCopy(dqfT, &(*op)->dqfT));
@@ -853,6 +856,7 @@ int CeedOperatorCreateAtPoints(Ceed ceed, CeedQFunction qf, CeedQFunction dqf, C
853856
(*op)->is_at_points = true;
854857
(*op)->input_size = -1;
855858
(*op)->output_size = -1;
859+
(*op)->precision = CEED_SCALAR_TYPE;
856860
CeedCall(CeedQFunctionReferenceCopy(qf, &(*op)->qf));
857861
if (dqf && dqf != CEED_QFUNCTION_NONE) CeedCall(CeedQFunctionReferenceCopy(dqf, &(*op)->dqf));
858862
if (dqfT && dqfT != CEED_QFUNCTION_NONE) CeedCall(CeedQFunctionReferenceCopy(dqfT, &(*op)->dqfT));

interface/ceed-types.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,8 @@ const char *const CeedFESpaces[] = {
6464
[CEED_FE_SPACE_HDIV] = "H(div) space",
6565
[CEED_FE_SPACE_HCURL] = "H(curl) space",
6666
};
67+
68+
const char *const CeedScalarTypes[] = {
69+
[CEED_SCALAR_FP32] = "CEED_SCALAR_FP32",
70+
[CEED_SCALAR_FP64] = "CEED_SCALAR_FP32",
71+
};

0 commit comments

Comments
 (0)