Skip to content

Commit 58bf523

Browse files
committed
Make use of the new strict_float intrinsics for the fast math functions.
1 parent 5ee7c6a commit 58bf523

8 files changed

Lines changed: 83 additions & 20 deletions

File tree

src/CodeGen_LLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ void CodeGen_LLVM::init_codegen(const std::string &name) {
408408
module->addModuleFlag(llvm::Module::Warning, "halide_mabi", MDString::get(*context, mabi()));
409409
module->addModuleFlag(llvm::Module::Warning, "halide_use_pic", use_pic() ? 1 : 0);
410410
module->addModuleFlag(llvm::Module::Warning, "halide_use_large_code_model", llvm_large_code_model ? 1 : 0);
411-
module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", any_strict_float);
411+
module->addModuleFlag(llvm::Module::Warning, "halide_per_instruction_fast_math_flags", any_strict_float ? 1 : 0);
412412
if (effective_vscale != 0) {
413413
module->addModuleFlag(llvm::Module::Warning, "halide_effective_vscale", effective_vscale);
414414
}
@@ -498,6 +498,7 @@ CodeGen_LLVM::ScopedFastMath::~ScopedFastMath() {
498498

499499
std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
500500
any_strict_float = input.any_strict_float();
501+
debug(2) << "Module: any_strict_float = " << any_strict_float << "\n";
501502

502503
init_codegen(input.name());
503504

src/FastMathFunctions.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,15 @@ Expr eval_poly_horner(const std::vector<double> &coefs, const Expr &x) {
9797
}
9898

9999
inline std::pair<Expr, Expr> two_sum(const Expr &a, const Expr &b) {
100-
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
101-
Expr x = strict_float(a + b);
102-
Expr z = strict_float(x - a);
103-
Expr y = strict_float(strict_float(a - strict_float(x - z)) + strict_float(b - z));
100+
Expr x = strict_add(a, b);
101+
Expr z = strict_sub(x, a);
102+
Expr y = strict_add(strict_sub(a, strict_sub(x, z)), strict_sub(b, z));
104103
return {x, y};
105104
}
106105

107106
inline std::pair<Expr, Expr> two_prod(const Expr &a, const Expr &b) {
108-
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
109-
Expr x = strict_float(a * b);
107+
Expr x = strict_mul(a, b);
108+
// TODO(mcourteaux): replace with proper strict_float fma intrinsic op.
110109
Expr y = (a * b - x); // No strict float, so let's hope it gets compiled as FMA.
111110
return {x, y};
112111
}
@@ -176,8 +175,7 @@ Expr fast_sin(const Expr &x_full, ApproximationPrecision precision) {
176175
Expr pi_over_two_minus_x = make_const(type, PI_OVER_TWO) - x;
177176
if (type == Float(32) && precision.optimized_for == ApproximationPrecision::MULPE) {
178177
auto [hi, lo] = split_float(PI_OVER_TWO);
179-
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
180-
pi_over_two_minus_x = strict_float(make_const(type, hi) - x) + make_const(type, lo);
178+
pi_over_two_minus_x = strict_sub(make_const(type, hi), x) + make_const(type, lo);
181179
}
182180
x = select(mirror, pi_over_two_minus_x, x);
183181

@@ -210,7 +208,7 @@ Expr fast_cos(const Expr &x_full, ApproximationPrecision precision) {
210208
if (type == Float(32) && precision.optimized_for == ApproximationPrecision::MULPE) {
211209
auto [hi, lo] = split_float(PI_OVER_TWO);
212210
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
213-
pi_over_two_minus_x = strict_float(strict_float(make_const(type, hi) - x) + make_const(type, lo));
211+
pi_over_two_minus_x = strict_add(strict_sub(make_const(type, hi), x), make_const(type, lo));
214212
}
215213
x = select(mirror, pi_over_two_minus_x, x);
216214

@@ -238,8 +236,7 @@ Expr fast_tan(const Expr &x_full, ApproximationPrecision precision) {
238236
Expr x = x_full - k_real * make_const(type, PI);
239237
if (type == Float(32) && precision.optimized_for == ApproximationPrecision::MULPE) {
240238
auto [pi_hi, pi_lo] = split_float(PI);
241-
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
242-
x = strict_float(strict_float(x_full - k_real * make_const(type, pi_hi)) - (k_real * make_const(type, pi_lo)));
239+
x = strict_sub((x_full - k_real * make_const(type, pi_hi)), (k_real * make_const(type, pi_lo)));
243240
}
244241

245242
// When polynomial: x is assumed to be reduced to [-pi/2, pi/2]!
@@ -250,11 +247,11 @@ Expr fast_tan(const Expr &x_full, ApproximationPrecision precision) {
250247
Expr use_cotan = abs_x > make_const(type, PI / 4.0);
251248
Expr pi_over_two_minus_abs_x;
252249
if (type == Float(64)) {
250+
// TODO(mcourteaux): We could do split floats here too.
253251
pi_over_two_minus_abs_x = make_const(type, PI_OVER_TWO) - abs_x;
254252
} else if (type == Float(32)) { // We want to do this trick always, because we invert later.
255253
auto [hi, lo] = split_float(PI_OVER_TWO);
256-
// TODO(mcourteaux): replace with proper strict_float intrinsic ops.
257-
pi_over_two_minus_abs_x = strict_float(make_const(type, hi) - abs_x) + make_const(type, lo);
254+
pi_over_two_minus_abs_x = strict_sub(make_const(type, hi), abs_x) + make_const(type, lo);
258255
}
259256
Expr arg = select(use_cotan, pi_over_two_minus_abs_x, abs_x);
260257

src/IROperator.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,6 +2670,29 @@ Expr strict_float(const Expr &e) {
26702670
return strictify_float(e);
26712671
}
26722672

2673+
inline Expr strict_float_op(const Expr &a, const Expr &b, Call::IntrinsicOp op) {
2674+
user_assert(a.type() == b.type()) << "strict_float ops should be done on equal types.";
2675+
user_assert(a.type().is_float()) << "strict_float ops should be done on floating point types.";
2676+
return Call::make(a.type(), op, {a, b}, Call::CallType::PureIntrinsic);
2677+
}
2678+
2679+
#define impl_strict_op(x) \
2680+
Expr strict_##x(const Expr &a, const Expr &b) { \
2681+
return strict_float_op(a, b, Call::strict_##x); \
2682+
}
2683+
2684+
impl_strict_op(add);
2685+
impl_strict_op(sub);
2686+
impl_strict_op(div);
2687+
impl_strict_op(mul);
2688+
impl_strict_op(max);
2689+
impl_strict_op(min);
2690+
impl_strict_op(eq);
2691+
impl_strict_op(le);
2692+
impl_strict_op(lt);
2693+
2694+
#undef impl_strict_op
2695+
26732696
Expr undef(Type t) {
26742697
return Call::make(t, Call::undef,
26752698
std::vector<Expr>(),

src/IROperator.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,22 @@ Expr saturating_cast(Type t, Expr e);
15781578
* generated code. */
15791579
Expr strict_float(const Expr &e);
15801580

1581+
/**
1582+
* Helper functions for the strict-float variants of the
1583+
* basic floating point operators.
1584+
*/
1585+
/// @{
1586+
Expr strict_add(const Expr &a, const Expr &b);
1587+
Expr strict_sub(const Expr &a, const Expr &b);
1588+
Expr strict_mul(const Expr &a, const Expr &b);
1589+
Expr strict_div(const Expr &a, const Expr &b);
1590+
Expr strict_max(const Expr &a, const Expr &b);
1591+
Expr strict_min(const Expr &a, const Expr &b);
1592+
Expr strict_eq(const Expr &a, const Expr &b);
1593+
Expr strict_le(const Expr &a, const Expr &b);
1594+
Expr strict_lt(const Expr &a, const Expr &b);
1595+
/// @}
1596+
15811597
/** Create an Expr that promises another Expr is clamped but do
15821598
* not generate code to check the assertion or modify the value. No
15831599
* attempt is made to prove the bound at compile time. (If it is

src/Lower.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ void lower_impl(const vector<Function> &output_funcs,
148148

149149
lower_target_query_ops(env, t);
150150

151-
bool any_strict_float = strictify_float(env, t);
152-
result_module.set_any_strict_float(any_strict_float);
151+
bool has_any_strict_float = strictify_float(env, t);
152+
result_module.set_any_strict_float(has_any_strict_float);
153153

154154
// Output functions should all be computed and stored at root.
155155
for (const Function &f : outputs) {
@@ -333,6 +333,13 @@ void lower_impl(const vector<Function> &output_funcs,
333333
debug(1) << "Selecting fast math function implementations...\n";
334334
s = lower_fast_math_functions(s, t);
335335
log("Lowering after selecting fast math functions:", s);
336+
if (!has_any_strict_float) {
337+
has_any_strict_float = any_strict_float(s);
338+
if (has_any_strict_float) {
339+
debug(2) << "Detected strict_float ops after selecting fast math functions.\n";
340+
result_module.set_any_strict_float(has_any_strict_float);
341+
}
342+
}
336343

337344
debug(1) << "Simplifying...\n";
338345
s = simplify(s);

src/StrictifyFloat.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,17 @@ bool strictify_float(std::map<std::string, Function> &env, const Target &t) {
164164
return checker.any_strict || t.has_feature(Target::StrictFloat);
165165
}
166166

167+
bool any_strict_float(const Stmt &s) {
168+
AnyStrictIntrinsics c;
169+
s.accept(&c);
170+
return c.any_strict;
171+
}
172+
173+
bool any_strict_float(const Expr &e) {
174+
AnyStrictIntrinsics c;
175+
e.accept(&c);
176+
return c.any_strict;
177+
}
178+
167179
} // namespace Internal
168180
} // namespace Halide

src/StrictifyFloat.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace Halide {
1212

1313
struct Target;
1414
struct Expr;
15+
struct Stmt;
1516

1617
namespace Internal {
1718

@@ -33,6 +34,12 @@ Expr unstrictify_float(const Call *op);
3334
* strictness). */
3435
bool strictify_float(std::map<std::string, Function> &env, const Target &t);
3536

37+
/** Checks the passed Stmt for the presence of any strict_float ops. */
38+
bool any_strict_float(const Stmt &s);
39+
40+
/** Checks the passed Expr for the presence of any strict_float ops. */
41+
bool any_strict_float(const Expr &s);
42+
3643
} // namespace Internal
3744
} // namespace Halide
3845

test/correctness/fast_function_approximations.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ constexpr RangedAccuracyTest::Validation rlx_abs_val = {1.02, 1e-7};
111111
constexpr RangedAccuracyTest::Validation vrlx_abs_val = {1.1, 1e-6};
112112
constexpr RangedAccuracyTest::Validation rsnbl_abs_val = {2.0, 1e-5};
113113
constexpr RangedAccuracyTest::Validation rlx_abs_val_pct(double pct) {
114-
return {1.0 + 100 * pct, 1e-7};
114+
return {1.0 + 0.01 * pct, 1e-7};
115115
}
116116
constexpr RangedAccuracyTest::Validation max_abs_val(double max_val) {
117117
return {0.0f, max_val};
@@ -171,7 +171,7 @@ struct FunctionToTest {
171171
[](Expr x, Expr y, Halide::ApproximationPrecision prec) { return Halide::fast_atan2(x, y, prec); },
172172
Halide::Internal::ApproximationTables::best_atan_approximation,
173173
{
174-
{ "precise" , {{ -10.0f, 10.0f}, {-10.0f, 10.0f}}, rlx_abs_val_pct(4), rlx_abs_val, rlx_ulp_val, rlx_ulp_val, 70, 30 },
174+
{ "precise" , {{ -10.0f, 10.0f}, {-10.0f, 10.0f}}, rlx_abs_val_pct(6), rlx_abs_val, rlx_ulp_val, rlx_ulp_val, 70, 30 },
175175
}
176176
},
177177
{
@@ -385,7 +385,7 @@ int main(int argc, char **argv) {
385385
Buffer<float, 1> out_ref{steps * steps};
386386
Buffer<float, 1> out_approx{steps * steps};
387387

388-
bool target_has_proper_strict_float_support = !target.has_gpu_feature();
388+
bool target_has_proper_strict_float_support = !target.has_gpu_feature() || target.has_feature(Target::CUDA);
389389

390390
double best_mae_for_backend = 0.0;
391391
if (target.has_feature(Halide::Target::Vulkan)) {
@@ -528,7 +528,7 @@ int main(int argc, char **argv) {
528528
.vectorize(ii, 4);
529529
// TODO(mcourteaux): When vector legalization lowering pass is in, increase vectorize for testing.
530530
} else {
531-
approx_func.vectorize(i, 8);
531+
approx_func.vectorize(i, target.natural_vector_size<float>());
532532
}
533533
approx_func.realize(out_approx);
534534
if (emit_asm) {

0 commit comments

Comments
 (0)