Skip to content

Commit f28a8b0

Browse files
committed
Update some info on intrinsics. Fix ^ to != for bools and add a user assert that rejects bools as operands of the bitwise operators.
1 parent 58a6d7c commit f28a8b0

File tree

2 files changed

+85
-31
lines changed

2 files changed

+85
-31
lines changed

src/FastMathFunctions.cpp

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,33 @@
55
#include "IRMutator.h"
66
#include "IROperator.h"
77
#include "IRPrinter.h"
8+
#include "Util.h"
89

910
namespace Halide {
1011
namespace Internal {
11-
namespace ApproxImpl {
1212

13+
namespace {
1314
constexpr double PI = 3.14159265358979323846;
1415
constexpr double ONE_OVER_PI = 1.0 / PI;
1516
constexpr double TWO_OVER_PI = 2.0 / PI;
1617
constexpr double PI_OVER_TWO = PI / 2;
1718

19+
float ulp_to_ae(float max, int ulp) {
20+
internal_assert(max > 0.0);
21+
uint32_t n = reinterpret_bits<uint32_t>(max);
22+
float fn = reinterpret_bits<float>(n + ulp);
23+
return fn - max;
24+
}
25+
26+
uint32_t ae_to_ulp(float smallest, float ae) {
27+
internal_assert(smallest >= 0.0);
28+
float fn = smallest + ae;
29+
return reinterpret_bits<uint32_t>(fn) - reinterpret_bits<uint32_t>(smallest);
30+
}
31+
} // namespace
32+
33+
namespace ApproxImpl {
34+
1835
std::pair<float, float> split_float(double value) {
1936
float high = float(value); // Convert to single precision
2037
float low = float(value - double(high)); // Compute the residual part
@@ -152,7 +169,7 @@ Expr fast_sin(const Expr &x_full, ApproximationPrecision precision) {
152169
Expr k = cast<int>(k_real);
153170
Expr k_mod4 = k % 4; // Halide mod is always positive!
154171
Expr mirror = (k_mod4 == 1) || (k_mod4 == 3);
155-
Expr flip_sign = (k_mod4 > 1) ^ (x_full < 0);
172+
Expr flip_sign = (k_mod4 > 1) != (x_full < 0);
156173

157174
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
158175
Expr x = x_abs - k_real * make_const(type, PI_OVER_TWO);
@@ -417,7 +434,7 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
417434
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
418435
Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
419436
Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
420-
tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
437+
tanh = select(flip_exp != flip_sign, -tanh, tanh);
421438
return common_subexpression_elimination(tanh, true);
422439
#else
423440
// expm1 is developed around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
@@ -465,6 +482,19 @@ struct IntrinsicsInfo {
465482
} intrinsic;
466483
};
467484

485+
// Builds a NativeFunc entry whose precision is given as a maximum absolute
// error `mae`. The matching ULP figure is derived with
// ae_to_ulp(smallest_output, mae).
IntrinsicsInfo::NativeFunc MAE_func(bool fast, float mae, float smallest_output = 0.0f) {
    const auto derived_ulp = ae_to_ulp(smallest_output, mae);
    return IntrinsicsInfo::NativeFunc{fast, OO::MAE, mae, derived_ulp};
}
488+
// Builds a NativeFunc entry whose precision is given as a maximum ULP error
// `mulpe`. The matching absolute error is derived with
// ulp_to_ae(largest_output, mulpe).
IntrinsicsInfo::NativeFunc MULPE_func(bool fast, uint64_t mulpe, float largest_output) {
    const float derived_ae = ulp_to_ae(largest_output, mulpe);
    return IntrinsicsInfo::NativeFunc{fast, OO::MULPE, derived_ae, mulpe};
}
491+
// Builds an IntrinsicImpl entry whose precision is given as a maximum absolute
// error `mae`. The matching ULP figure is derived with
// ae_to_ulp(smallest_output, mae).
IntrinsicsInfo::IntrinsicImpl MAE_intrinsic(float mae, float smallest_output = 0.0f) {
    const auto derived_ulp = ae_to_ulp(smallest_output, mae);
    return IntrinsicsInfo::IntrinsicImpl{OO::MAE, mae, derived_ulp};
}
494+
// Builds an IntrinsicImpl entry whose precision is given as a maximum ULP
// error `mulpe`. The matching absolute error is derived with
// ulp_to_ae(largest_output, mulpe).
IntrinsicsInfo::IntrinsicImpl MULPE_intrinsic(uint64_t mulpe, float largest_output) {
    const float derived_ae = ulp_to_ae(largest_output, mulpe);
    return IntrinsicsInfo::IntrinsicImpl{OO::MULPE, derived_ae, mulpe};
}
497+
468498
struct IntrinsicsInfoPerDeviceAPI {
469499
OO reasonable_behavior; // A reasonable optimization objective for a given function.
470500
float default_mae; // A reasonable desirable MAE (if specified)
@@ -475,37 +505,45 @@ struct IntrinsicsInfoPerDeviceAPI {
475505
// clang-format off
476506
IntrinsicsInfoPerDeviceAPI ii_sin{
477507
OO::MAE, 1e-5f, 0, {
478-
{DeviceAPI::Vulkan, {true}, {}},
479-
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
480-
{DeviceAPI::Metal, {true}, {OO::MAE, 6e-5f, 400'000}},
508+
{DeviceAPI::Vulkan, MAE_func(true, 5e-4f), {}},
509+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(5e-7f)},
510+
{DeviceAPI::Metal, {true}, MAE_intrinsic(1.2e-4f)}, // 2^-13
481511
{DeviceAPI::WebGPU, {true}, {}},
482-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
512+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(5e-7f)},
483513
}};
484514

485515
IntrinsicsInfoPerDeviceAPI ii_cos{
486516
OO::MAE, 1e-5f, 0, {
487-
{DeviceAPI::Vulkan, {true}, {}},
488-
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
489-
{DeviceAPI::Metal, {true}, {OO::MAE, 7e-7f, 5'000}},
517+
{DeviceAPI::Vulkan, MAE_func(true, 5e-4f), {}},
518+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(5e-7f)},
519+
{DeviceAPI::Metal, {true}, MAE_intrinsic(1.2e-4f)}, // Seems to be 7e-7, but spec says 2^-13...
490520
{DeviceAPI::WebGPU, {true}, {}},
491-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
521+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(5e-7f)},
492522
}};
493523

494-
IntrinsicsInfoPerDeviceAPI ii_atan_atan2{
524+
IntrinsicsInfoPerDeviceAPI ii_atan{
495525
OO::MAE, 1e-5f, 0, {
496526
// no intrinsics available
497527
{DeviceAPI::Vulkan, {false}, {}},
498-
{DeviceAPI::Metal, {true}, {OO::MAE, 5e-6f}},
528+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, float(PI * 0.501))}, // They claim <= 5 ULP!
529+
{DeviceAPI::WebGPU, {true}, {}},
530+
}};
531+
532+
IntrinsicsInfoPerDeviceAPI ii_atan2{
533+
OO::MAE, 1e-5f, 0, {
534+
// no intrinsics available
535+
{DeviceAPI::Vulkan, {false}, {}},
536+
{DeviceAPI::Metal, {true}, MAE_intrinsic(5e-6f, 0.0f)},
499537
{DeviceAPI::WebGPU, {true}, {}},
500538
}};
501539

502540
IntrinsicsInfoPerDeviceAPI ii_tan{
503541
OO::MULPE, 0.0f, 2000, {
504-
{DeviceAPI::Vulkan, {true, OO::MAE, 2e-6f, 1'000'000}, {}}, // Vulkan tan seems to mimic our CUDA implementation
505-
{DeviceAPI::CUDA, {false}, {OO::MAE, 2e-6f, 1'000'000}},
506-
{DeviceAPI::Metal, {true}, {OO::MULPE, 2e-6f, 1'000'000}},
542+
{DeviceAPI::Vulkan, MAE_func(true, 2e-6f), {}}, // Vulkan tan() seems to mimic our CUDA implementation
543+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(2e-6f)},
544+
{DeviceAPI::Metal, {true}, MAE_intrinsic(2e-6f)}, // sin()/cos()
507545
{DeviceAPI::WebGPU, {true}, {}},
508-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 2e-6f, 1'000'000}},
546+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(2e-6f)},
509547
}};
510548

511549
IntrinsicsInfoPerDeviceAPI ii_expm1{
@@ -514,16 +552,16 @@ IntrinsicsInfoPerDeviceAPI ii_expm1{
514552

515553
IntrinsicsInfoPerDeviceAPI ii_exp{
516554
OO::MULPE, 0.0f, 50, {
517-
{DeviceAPI::Vulkan, {true}, {}},
518-
{DeviceAPI::CUDA, {false}, {OO::MULPE, 0.0f, 5}},
519-
{DeviceAPI::Metal, {true}, {OO::MULPE, 0.0f, 5}}, // precise::exp() is fast on metal
555+
{DeviceAPI::Vulkan, MULPE_func(true, 3 + 2 * 2, 2.0f), {}},
556+
{DeviceAPI::CUDA, {false}, MULPE_intrinsic(5, 2.0f)},
557+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, 2.0f)}, // precise::exp() is fast on metal
520558
{DeviceAPI::WebGPU, {true}, {}},
521-
{DeviceAPI::OpenCL, {true}, {OO::MULPE, 0.0f, 5}}, // Both exp() and native_exp() are faster than polys.
559+
{DeviceAPI::OpenCL, {true}, MULPE_intrinsic(5, 2.0f)}, // Both exp() and native_exp() are faster than polys.
522560
}};
523561

524562
IntrinsicsInfoPerDeviceAPI ii_log{
525563
OO::MAE, 1e-5f, 1000, {
526-
{DeviceAPI::Vulkan, {true}, {}},
564+
{DeviceAPI::Vulkan, {true, ApproximationPrecision::MULPE, 5e-7f, 3}, {}}, // Precision piecewise defined: 3 ULP outside the range [0.5,2.0]. Absolute error < 2^−21 inside the range [0.5,2.0].
527565
{DeviceAPI::CUDA, {false}, {OO::MAE, 0.0f, 3'800'000}},
528566
{DeviceAPI::Metal, {false}, {OO::MAE, 0.0f, 3'800'000}}, // slow log() on metal
529567
{DeviceAPI::WebGPU, {true}, {}},
@@ -551,6 +589,7 @@ IntrinsicsInfoPerDeviceAPI ii_asin_acos{
551589
OO::MULPE, 1e-5f, 500, {
552590
{DeviceAPI::Vulkan, {true}, {}},
553591
{DeviceAPI::CUDA, {true}, {}},
592+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, PI)},
554593
{DeviceAPI::OpenCL, {true}, {}},
555594
}};
556595
// clang-format on
@@ -559,8 +598,10 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
559598
const IntrinsicsInfoPerDeviceAPI *iipda = nullptr;
560599
switch (op) {
561600
case Call::fast_atan:
601+
iipda = &ii_atan;
602+
break;
562603
case Call::fast_atan2:
563-
iipda = &ii_atan_atan2;
604+
iipda = &ii_atan2;
564605
break;
565606
case Call::fast_cos:
566607
iipda = &ii_cos;
@@ -858,20 +899,24 @@ class LowerFastMathFunctions : public IRMutator {
858899

859900
// No known fast version available, we will expand our own approximation.
860901
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
861-
} else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
902+
} else if (op->is_intrinsic(Call::fast_atan)) {
862903
// Handle fast_atan.
863904
ApproximationPrecision prec = extract_approximation_precision(op);
864-
IntrinsicsInfo ii = resolve_precision(prec, ii_atan_atan2, for_device_api);
905+
IntrinsicsInfo ii = resolve_precision(prec, ii_atan, for_device_api);
865906
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
866907
// The native atan is fast: fall back to native and continue lowering.
867908
return to_native_func(op);
868909
}
869-
870-
if (op->is_intrinsic(Call::fast_atan)) {
871-
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
872-
} else {
873-
return ApproxImpl::fast_atan2(mutate(op->args[0]), mutate(op->args[1]), prec);
910+
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
911+
} else if (op->is_intrinsic(Call::fast_atan2)) {
912+
// Handle fast_atan2.
913+
ApproximationPrecision prec = extract_approximation_precision(op);
914+
IntrinsicsInfo ii = resolve_precision(prec, ii_atan2, for_device_api);
915+
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
916+
// The native atan2 is fast: fall back to native and continue lowering.
917+
return to_native_func(op);
874918
}
919+
return ApproxImpl::fast_atan2(mutate(op->args[0]), mutate(op->args[1]), prec);
875920
} else if (op->is_intrinsic(Call::fast_tan)) {
876921
ApproximationPrecision prec = extract_approximation_precision(op);
877922
IntrinsicsInfo ii = resolve_precision(prec, ii_tan, for_device_api);
@@ -913,7 +958,7 @@ class LowerFastMathFunctions : public IRMutator {
913958
return append_type_suffix(op);
914959
}
915960
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
916-
// The native atan is fast: fall back to native and continue lowering.
961+
// The native exp is fast: fall back to native and continue lowering.
917962
return to_native_func(op);
918963
}
919964

src/IROperator.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,54 +2380,63 @@ Expr reinterpret(Type t, Expr e) {
23802380
// Bitwise AND of two Exprs.
// The operand types are first reconciled by match_types_bitwise(); the result
// has the (reconciled) type of x. A user error is raised unless that type is a
// signed or unsigned integer wider than one bit.
Expr operator&(Expr x, Expr y) {
    match_types_bitwise(x, y, "bitwise and");
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-AND must operate on integer types.";
    std::vector<Expr> operands = {std::move(x), std::move(y)};
    return Call::make(t, Call::bitwise_and, operands, Call::PureIntrinsic);
}
23852386

23862387
// Bitwise AND of an Expr with an integer constant.
// The constant is converted to the type of x via make_const(). A user error is
// raised unless that type is a signed or unsigned integer wider than one bit;
// check_representable() then validates the constant against the type.
Expr operator&(Expr x, int y) {
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-AND must operate on integer types.";
    check_representable(t, y);
    std::vector<Expr> operands = {std::move(x), make_const(t, y)};
    return Call::make(t, Call::bitwise_and, operands, Call::PureIntrinsic);
}
23912393

23922394
// Bitwise AND of an integer constant with an Expr.
// The constant is converted to the type of y via make_const().
// The operand type is validated before the constant, so that a 1-bit or
// non-integer y reports the bitwise-type error rather than a representability
// error about x. This is the same order used by operator&(Expr, int).
Expr operator&(int x, Expr y) {
    Type t = y.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-AND must operate on integer types.";
    check_representable(t, x);
    return Call::make(t, Call::bitwise_and, {make_const(t, x), std::move(y)}, Call::PureIntrinsic);
}
23972400

23982401
// Bitwise OR of two Exprs.
// The operand types are first reconciled by match_types_bitwise(); the result
// has the (reconciled) type of x. A user error is raised unless that type is a
// signed or unsigned integer wider than one bit.
Expr operator|(Expr x, Expr y) {
    match_types_bitwise(x, y, "bitwise or");
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-OR must operate on integer types.";
    std::vector<Expr> operands = {std::move(x), std::move(y)};
    return Call::make(t, Call::bitwise_or, operands, Call::PureIntrinsic);
}
24032407

24042408
// Bitwise OR of an Expr with an integer constant.
// The constant is converted to the type of x via make_const().
// The operand type is validated before the constant, so that a 1-bit or
// non-integer x reports the bitwise-type error rather than a representability
// error about y. This is the same order used by operator&(Expr, int).
Expr operator|(Expr x, int y) {
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-OR must operate on integer types.";
    check_representable(t, y);
    return Call::make(t, Call::bitwise_or, {std::move(x), make_const(t, y)}, Call::PureIntrinsic);
}
24092414

24102415
// Bitwise OR of an integer constant with an Expr.
// The constant is converted to the type of y via make_const().
// The operand type is validated before the constant, so that a 1-bit or
// non-integer y reports the bitwise-type error rather than a representability
// error about x. This is the same order used by operator&(Expr, int).
Expr operator|(int x, Expr y) {
    Type t = y.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-OR must operate on integer types.";
    check_representable(t, x);
    return Call::make(t, Call::bitwise_or, {make_const(t, x), std::move(y)}, Call::PureIntrinsic);
}
24152421

24162422
// Bitwise XOR of two Exprs.
// The operand types are first reconciled by match_types_bitwise(); the result
// has the (reconciled) type of x. A user error is raised unless that type is a
// signed or unsigned integer wider than one bit.
Expr operator^(Expr x, Expr y) {
    match_types_bitwise(x, y, "bitwise xor");
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-XOR must operate on integer types.";
    std::vector<Expr> operands = {std::move(x), std::move(y)};
    return Call::make(t, Call::bitwise_xor, operands, Call::PureIntrinsic);
}
24212428

24222429
// Bitwise XOR of an Expr with an integer constant.
// The constant is converted to the type of x via make_const().
// The operand type is validated before the constant, so that a 1-bit or
// non-integer x reports the bitwise-type error rather than a representability
// error about y. This is the same order used by operator&(Expr, int).
Expr operator^(Expr x, int y) {
    Type t = x.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-XOR must operate on integer types.";
    check_representable(t, y);
    return Call::make(t, Call::bitwise_xor, {std::move(x), make_const(t, y)}, Call::PureIntrinsic);
}
24272435

24282436
// Bitwise XOR of an integer constant with an Expr.
// The constant is converted to the type of y via make_const().
// The operand type is validated before the constant, so that a 1-bit or
// non-integer y reports the bitwise-type error rather than a representability
// error about x. This is the same order used by operator&(Expr, int).
Expr operator^(int x, Expr y) {
    Type t = y.type();
    user_assert(t.is_int_or_uint() && t.bits() > 1) << "Bitwise-XOR must operate on integer types.";
    check_representable(t, x);
    return Call::make(t, Call::bitwise_xor, {make_const(t, x), std::move(y)}, Call::PureIntrinsic);
}
24332442

0 commit comments

Comments
 (0)