Skip to content

Commit 2834ffe

Browse files
alexreinkingclaude
andauthored
Add solve fuzzer and fix the soundness bugs it surfaced (#9105)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent af9c016 commit 2834ffe

13 files changed

Lines changed: 1146 additions & 372 deletions

src/Bounds.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,6 +1827,18 @@ Interval bounds_of_expr_in_scope(const Expr &expr, const Scope<Interval> &scope,
18271827
return bounds_of_expr_in_scope_with_indent(expr, scope, fb, const_bound, 0);
18281828
}
18291829

1830+
Expr and_condition_over_domain(const Expr &e, const Scope<Interval> &varying) {
1831+
internal_assert(e.type().is_bool()) << "Expr provided to and_condition_over_domain is not boolean: " << e << "\n";
1832+
Interval bounds = bounds_of_expr_in_scope(e, varying);
1833+
internal_assert(bounds.has_lower_bound()) << "Failed to produce bound on boolean value in and_condition_over_domain" << e << "\n";
1834+
// Minimum of a boolean value is sufficient condition, implies expression.
1835+
return simplify(bounds.min);
1836+
}
1837+
1838+
Expr or_condition_over_domain(const Expr &c, const Scope<Interval> &varying) {
1839+
return simplify(!and_condition_over_domain(simplify(!c), varying));
1840+
}
1841+
18301842
void merge_boxes(Box &a, const Box &b) {
18311843
if (b.empty()) {
18321844
return;

src/Bounds.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ Expr find_constant_bound(const Expr &e, Direction d,
4848
* +/-inf. */
4949
Interval find_constant_bounds(const Expr &e, const Scope<Interval> &scope);
5050

51+
/** Take a conditional that includes variables that vary over some
52+
* domain, and convert it to a more conservative (less frequently
53+
* true) condition that doesn't depend on those variables. Formally,
54+
* the output expr implies the input expr.
55+
*
56+
* The condition may be a vector condition, in which case we also
57+
* 'and' over the vector lanes, and return a scalar result. */
58+
Expr and_condition_over_domain(const Expr &c, const Scope<Interval> &varying);
59+
60+
/** Take a conditional that includes variables that vary over some
61+
* domain, and convert it to a weaker (less frequently false) condition
62+
* that doesn't depend on those variables. Formally, the input expr
63+
* implies the output expr. Note that this function might be unable to
64+
* provide a better response than simply const_true(). */
65+
Expr or_condition_over_domain(const Expr &c, const Scope<Interval> &varying);
66+
5167
/** Represents the bounds of a region of arbitrary dimension. Zero
5268
* dimensions corresponds to a scalar region. */
5369
struct Box {

src/Simplify_Cast.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
2020
int64_t old_min = value_info.bounds.min;
2121
bool old_min_defined = value_info.bounds.min_defined;
2222
value_info.cast_to(op->type);
23+
if (value.type().is_float() && op->type.is_int_or_uint()) {
24+
// ExprInfo::cast_to handles integer casts, where narrowing wraps
25+
// and preserves alignment modulo the destination width. Float to
26+
// integer casts saturate instead, so the old alignment can be
27+
// wrong after the cast.
28+
value_info.alignment = ModulusRemainder();
29+
}
2330
if (op->type.is_uint() && op->type.bits() == 64 && old_min_defined && old_min > 0) {
2431
// It's impossible for a cast *to* a uint64 in Halide to lower the
2532
// min. Casts to uint64_t don't overflow for any source type.
@@ -110,13 +117,14 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
110117
} else if (cast &&
111118
op->type.is_int_or_uint() &&
112119
cast->type.is_int_or_uint() &&
120+
cast->value.type().is_int_or_uint() &&
113121
op->type.bits() <= cast->type.bits() &&
114122
op->type.bits() <= op->value.type().bits()) {
115123
// If this is a cast between integer types, where the
116124
// outer cast is narrower than the inner cast and the
117125
// inner cast's argument, the inner cast can be
118-
// eliminated. The inner cast is either a sign extend
119-
// or a zero extend, and the outer cast truncates the extended bits
126+
// eliminated. The inner cast is either a sign-extend
127+
// or a zero-extend, and the outer cast truncates the extended bits.
120128
if (op->type == cast->value.type()) {
121129
return mutate(cast->value, info);
122130
} else {

src/Simplify_LT.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) {
4949

5050
rewrite(broadcast(x, c0) < broadcast(y, c0), broadcast(x < y, c0)) ||
5151

52-
// We can learn more from equality than less with mod.
53-
rewrite(x % y < 1, x % y == 0) ||
54-
rewrite(0 < x % y, x % y != 0) ||
55-
rewrite(x % c0 < c1, x % c0 != fold(c0 - 1), c1 + 1 == c0 && c0 > 0) ||
56-
rewrite(c0 < x % c1, x % c1 == fold(c1 - 1), c0 + 2 == c1 && c1 > 0)) ||
52+
// We can learn more from equality than less with (Euclidean) mod.
53+
(!ty.is_float() && EVAL_IN_LAMBDA //
54+
(rewrite(x % y < 1, x % y == 0) ||
55+
rewrite(0 < x % y, x % y != 0) ||
56+
rewrite(x % c0 < c1, x % c0 != fold(c0 - 1), c1 + 1 == c0 && c0 > 0) ||
57+
rewrite(c0 < x % c1, x % c1 == fold(c1 - 1), c0 + 2 == c1 && c1 > 0)))) ||
5758

5859
(no_overflow(ty) && EVAL_IN_LAMBDA //
5960
(rewrite(ramp(x, y, c0) < ramp(z, y, c0), broadcast(x < z, c0)) ||

0 commit comments

Comments
 (0)