Skip to content

Commit aab4935

Browse files
committed
Merge remote-tracking branch 'origin/main' into abadams/faster_inlining
2 parents 6be5c9e + 2fad88f commit aab4935

13 files changed

Lines changed: 413 additions & 44 deletions

File tree

.github/workflows/pip.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
fetch-tags: true
4848

4949
- uses: ilammy/msvc-dev-cmd@v1
50-
- uses: lukka/get-cmake@v4.3.2
50+
- uses: lukka/get-cmake@v4.3.3
5151
with:
5252
cmakeVersion: "~3.28.0"
5353

python_bindings/src/halide/halide_/PyHalide.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@
3838
PYBIND11_MODULE(HALIDE_PYBIND_MODULE_NAME, m) {
3939
using namespace Halide::PythonBindings;
4040

41+
#if PY_VERSION_HEX >= 0x030E0000
42+
// CPython 3.14 caches C stack limits in the thread state, which does not
43+
// interact well with Halide's user-space compiler stack when warnings call
44+
// back into Python. Stay on Python's stack by default, but preserve the
45+
// existing Halide behavior when the user explicitly opts in.
46+
if (Halide::Internal::get_env_variable("HL_COMPILER_STACK_SIZE").empty()) {
47+
Halide::set_compiler_stack_size(0);
48+
}
49+
#endif
50+
4151
// Order of definitions matters somewhat:
4252
// things used for default arguments must be registered
4353
// prior to that usage.

src/IROperator.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,12 +609,23 @@ Expr lossless_negate(const Expr &x) {
609609
} else if (const FloatImm *f = x.as<FloatImm>()) {
610610
return FloatImm::make(f->type, -f->value);
611611
} else if (const Cast *c = x.as<Cast>()) {
612-
Expr value = lossless_negate(c->value);
613-
if (value.defined()) {
614-
// This logic is only sound if we know the cast can't overflow.
615-
value = lossless_cast(c->type, value);
612+
// Unsigned inner types wrap modularly (-uint8(65) = 191), and signed
613+
// integer inner types wrap at INT_TYPE_MIN (-int8(-128) = -128), so both
614+
// make cast(outer, -inner) != -cast(outer, inner). Floats are exact.
615+
// For signed integers, only proceed when bounds exclude INT_TYPE_MIN.
616+
bool inner_negation_safe = c->value.type().is_float();
617+
if (!inner_negation_safe && c->value.type().is_int()) {
618+
ConstantInterval ci = constant_integer_bounds(c->value);
619+
inner_negation_safe = ci.min_defined && !c->value.type().is_min(ci.min);
620+
}
621+
if (inner_negation_safe) {
622+
Expr value = lossless_negate(c->value);
616623
if (value.defined()) {
617-
return value;
624+
// This logic is only sound if we know the cast can't overflow.
625+
value = lossless_cast(c->type, value);
626+
if (value.defined()) {
627+
return value;
628+
}
618629
}
619630
}
620631
} else if (const Ramp *r = x.as<Ramp>()) {

src/SlidingWindow.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "SlidingWindow.h"
22

33
#include "Bounds.h"
4+
#include "CSE.h"
45
#include "CompilerLogger.h"
56
#include "Debug.h"
67
#include "ExprUsesVar.h"
@@ -86,7 +87,7 @@ class ExpandExpr : public IRMutator {
8687
// Perform all the substitutions in a scope
8788
Expr expand_expr(const Expr &e, const Scope<Expr> &scope) {
8889
ExpandExpr ee(scope);
89-
Expr result = ee(e);
90+
Expr result = common_subexpression_elimination(ee(e));
9091
debug(4) << "Expanded " << e << " into " << result << "\n";
9192
return result;
9293
}
@@ -223,6 +224,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
223224
Expr loop_min;
224225
set<int> &slid_dimensions;
225226
Scope<Expr> scope;
227+
Scope<Interval> &bounds_scope;
226228

227229
// For loops strictly between the loop being slid over and the current
228230
// node (not including the loop being slid over itself).
@@ -282,8 +284,8 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
282284
internal_assert(min_val && max_val);
283285
Expr min_req = *min_val;
284286
Expr max_req = *max_val;
285-
min_req = expand_expr(min_req, scope);
286-
max_req = expand_expr(max_req, scope);
287+
min_req = simplify(expand_expr(min_req, scope), bounds_scope);
288+
max_req = simplify(expand_expr(max_req, scope), bounds_scope);
287289

288290
debug(3) << func_args[i] << ":" << min_req << ", " << max_req << "\n";
289291
if (expr_depends_on_var(min_req, loop_var) ||
@@ -594,7 +596,10 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
594596
}
595597

596598
Stmt visit(const LetStmt *op) override {
597-
ScopedBinding<Expr> bind(scope, op->name, simplify(expand_expr(op->value, scope)));
599+
ScopedBinding<Interval> bind_bounds(bounds_scope, op->name,
600+
bounds_of_expr_in_scope(op->value, bounds_scope));
601+
ScopedBinding<Expr> bind(scope, op->name, simplify(expand_expr(op->value, scope), bounds_scope));
602+
598603
Stmt new_body = mutate(op->body);
599604

600605
Expr value = op->value;
@@ -613,8 +618,10 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
613618
}
614619

615620
public:
616-
SlidingWindowOnFunctionAndLoop(Function f, string v, Expr v_min, set<int> &slid_dimensions)
617-
: func(std::move(f)), loop_var(std::move(v)), loop_min(std::move(v_min)), slid_dimensions(slid_dimensions) {
621+
SlidingWindowOnFunctionAndLoop(Function f, string v, Expr v_min, set<int> &slid_dimensions,
622+
Scope<Interval> &bounds_scope)
623+
: func(std::move(f)), loop_var(std::move(v)), loop_min(std::move(v_min)),
624+
slid_dimensions(slid_dimensions), bounds_scope(bounds_scope) {
618625
}
619626

620627
Expr new_loop_min;
@@ -755,9 +762,16 @@ class SlidingWindow : public IRMutator {
755762
// Keep track of realizations we want to slide, from innermost to
756763
// outermost.
757764
list<Function> sliding;
765+
Scope<Interval> bounds_scope;
758766

759767
using IRMutator::visit;
760768

769+
Stmt visit(const LetStmt *op) override {
770+
ScopedBinding<Interval> bind(bounds_scope, op->name,
771+
bounds_of_expr_in_scope(op->value, bounds_scope));
772+
return IRMutator::visit(op);
773+
}
774+
761775
Stmt visit(const Realize *op) override {
762776
// Find the args for this function
763777
map<string, Function>::const_iterator iter = env.find(op->name);
@@ -827,7 +841,14 @@ class SlidingWindow : public IRMutator {
827841

828842
set<int> &slid_dims = slid_dimensions[func.name()];
829843
size_t old_slid_dims_size = slid_dims.size();
830-
SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dims);
844+
845+
Interval min_bounds = bounds_of_expr_in_scope(loop_min, bounds_scope);
846+
Interval max_bounds = bounds_of_expr_in_scope(loop_max, bounds_scope);
847+
ScopedBinding<Interval> bind_bounds(bounds_scope, op->name,
848+
Interval(min_bounds.min, max_bounds.max));
849+
850+
SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dims, bounds_scope);
851+
831852
body = slider(body);
832853

833854
if (func.schedule().memory_type() == MemoryType::Register &&

src/Solve.cpp

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ class SolveExpression : public IRMutator {
312312
} else if (mul_a && mul_b && equal(mul_a->b, mul_b->b)) {
313313
// f(x)*a - g(x)*a -> (f(x) - g(x))*a;
314314
expr = mutate((mul_a->a - mul_b->a) * mul_a->b);
315+
} else if (mul_a && equal(mul_a->a, b)) {
316+
// f(x)*a - f(x) -> f(x) * (a - 1)
317+
expr = mutate(b * (mul_a->b - 1));
318+
} else if (mul_b && equal(mul_b->a, a)) {
319+
// f(x) - f(x)*a -> f(x) * (1 - a)
320+
expr = mutate(a * (make_one(a.type()) - mul_b->b));
315321
} else if (div_a && !a_failed && no_overflow_int(op->type) && can_prove(div_a->b != 0)) {
316322
// f(x)/a - g(x) -> (f(x) - g(x) * a) / a
317323
// Same overflow and div-by-zero concerns as the Add case above.
@@ -1053,7 +1059,35 @@ class SolveForInterval : public IRVisitor {
10531059
if (!already_solved) {
10541060
SolverResult solved = solve_expression(le, var, scope);
10551061
if (!solved.fully_solved) {
1056-
fail();
1062+
// solve_expression failed; try direct max/min decomposition on the LHS.
1063+
if (const Max *max_fallback = le->a.as<Max>()) {
1064+
// max(a, b) <= c <==> a <= c && b <= c
1065+
(max_fallback->a <= le->b && max_fallback->b <= le->b).accept(this);
1066+
} else if (const Min *min_fallback = le->a.as<Min>()) {
1067+
// min(a, b) <= c <==> a <= c || b <= c
1068+
(min_fallback->a <= le->b || min_fallback->b <= le->b).accept(this);
1069+
} else if (const Mul *mul_fallback = le->a.as<Mul>()) {
1070+
// max/min(a, b) * pos_c <= rhs <==> a*pos_c <= rhs [&&/||] b*pos_c <= rhs
1071+
const Max *mxf = mul_fallback->a.as<Max>();
1072+
const Min *mnf = mul_fallback->a.as<Min>();
1073+
Expr factor = mul_fallback->b;
1074+
if (!mxf && !mnf) {
1075+
mxf = mul_fallback->b.as<Max>();
1076+
mnf = mul_fallback->b.as<Min>();
1077+
factor = mul_fallback->a;
1078+
}
1079+
if (mxf && is_positive_const(factor)) {
1080+
// max(a, b) * pos_c <= rhs <==> a*pos_c <= rhs && b*pos_c <= rhs
1081+
(mxf->a * factor <= le->b && mxf->b * factor <= le->b).accept(this);
1082+
} else if (mnf && is_positive_const(factor)) {
1083+
// min(a, b) * pos_c <= rhs <==> a*pos_c <= rhs || b*pos_c <= rhs
1084+
(mnf->a * factor <= le->b || mnf->b * factor <= le->b).accept(this);
1085+
} else {
1086+
fail();
1087+
}
1088+
} else {
1089+
fail();
1090+
}
10571091
} else {
10581092
already_solved = true;
10591093
solved.result.accept(this);
@@ -1110,7 +1144,35 @@ class SolveForInterval : public IRVisitor {
11101144
if (!already_solved) {
11111145
SolverResult solved = solve_expression(ge, var, scope);
11121146
if (!solved.fully_solved) {
1113-
fail();
1147+
// solve_expression failed; try direct max/min decomposition on the LHS.
1148+
if (const Max *max_fallback = ge->a.as<Max>()) {
1149+
// max(a, b) >= c <==> a >= c || b >= c
1150+
(max_fallback->a >= ge->b || max_fallback->b >= ge->b).accept(this);
1151+
} else if (const Min *min_fallback = ge->a.as<Min>()) {
1152+
// min(a, b) >= c <==> a >= c && b >= c
1153+
(min_fallback->a >= ge->b && min_fallback->b >= ge->b).accept(this);
1154+
} else if (const Mul *mul_fallback = ge->a.as<Mul>()) {
1155+
// max/min(a, b) * pos_c >= rhs <==> a*pos_c >= rhs [||/&&] b*pos_c >= rhs
1156+
const Max *mxf = mul_fallback->a.as<Max>();
1157+
const Min *mnf = mul_fallback->a.as<Min>();
1158+
Expr factor = mul_fallback->b;
1159+
if (!mxf && !mnf) {
1160+
mxf = mul_fallback->b.as<Max>();
1161+
mnf = mul_fallback->b.as<Min>();
1162+
factor = mul_fallback->a;
1163+
}
1164+
if (mxf && is_positive_const(factor)) {
1165+
// max(a, b) * pos_c >= rhs <==> a*pos_c >= rhs || b*pos_c >= rhs
1166+
(mxf->a * factor >= ge->b || mxf->b * factor >= ge->b).accept(this);
1167+
} else if (mnf && is_positive_const(factor)) {
1168+
// min(a, b) * pos_c >= rhs <==> a*pos_c >= rhs && b*pos_c >= rhs
1169+
(mnf->a * factor >= ge->b && mnf->b * factor >= ge->b).accept(this);
1170+
} else {
1171+
fail();
1172+
}
1173+
} else {
1174+
fail();
1175+
}
11141176
} else {
11151177
already_solved = true;
11161178
solved.result.accept(this);

test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ tests(
300300
sliding_over_guard_with_if.cpp
301301
sliding_reduction.cpp
302302
sliding_window.cpp
303+
sliding_window_cascade.cpp
303304
solve.cpp
304305
sort_exprs.cpp
305306
specialize.cpp

0 commit comments

Comments
 (0)