Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions xls/ir/interval_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
namespace xls::interval_ops {

namespace {

// How many exact calculations we are willing to perform (or exact values we are
// willing to enumerate) before falling back to a conservative approximation.
static constexpr int64_t kMaxExactCalculations = 16;

TernaryVector ExtractTernaryInterval(const Interval& interval) {
Bits lcp = bits_ops::LongestCommonPrefixMSB(
{interval.LowerBound(), interval.UpperBound()});
Expand Down Expand Up @@ -360,8 +365,6 @@ std::optional<IntervalSet> MaybePerformExactCalculation(
})) {
return std::nullopt;
}
// How many exact calculations we are willing to perform.
static constexpr int64_t kMaxExactCalculations = 16;
int64_t required_calculations = 1;
for (const IntervalSet& is : input_operands) {
// required_calculations *= is.Size();
Expand Down Expand Up @@ -581,7 +584,7 @@ IntervalSet PerformVariadicOp(Calculate calc,
}

result_intervals.Normalize();
return MinimizeIntervals(result_intervals, /*size=*/16);
return MinimizeIntervals(result_intervals, /*size=*/kMaxExactCalculations);
}

template <typename Calculate>
Expand Down Expand Up @@ -1049,6 +1052,63 @@ IntervalSet BitSlice(const IntervalSet& a, int64_t start, int64_t width) {
return Truncate(Shrl(a, IntervalSet::Precise(UBits(start, 64))), width);
}

IntervalSet DynamicBitSlice(const IntervalSet& to_slice,
const IntervalSet& start, int64_t width) {
if (to_slice.IsEmpty() || start.IsEmpty()) {
return IntervalSet(width);
}
CHECK(to_slice.IsNormalized());
CHECK(start.IsNormalized());

const int64_t input_width = to_slice.BitCount();

// Fast path: if the start index is precise, we can compute the result
// directly.
if (start.IsPrecise()) {
const int64_t shift_amount =
bits_ops::UnsignedBitsToSaturatedInt64(*start.GetPreciseValue());
if (shift_amount >= input_width) {
return IntervalSet::Precise(UBits(0, width));
}
return BitSlice(to_slice, shift_amount, width);
}

// Exact path: enumerate possible start indices when the set is small enough.
if (std::optional<int64_t> sz = start.Size();
sz.has_value() && *sz <= kMaxExactCalculations) {
IntervalSet result(width);
for (const Bits& s : start.Values()) {
const int64_t shift_amount = bits_ops::UnsignedBitsToSaturatedInt64(s);
IntervalSet sliced = (shift_amount >= input_width)
? IntervalSet::Precise(UBits(0, width))
: BitSlice(to_slice, shift_amount, width);
result = IntervalSet::Combine(result, sliced);
result.Normalize();
if (result.IsMaximal()) {
return result;
}
}
result.Normalize();
// Avoid unbounded growth if the union gets too large.
if (result.NumberOfIntervals() > kMaxExactCalculations) {
result = MinimizeIntervals(std::move(result),
/*size=*/kMaxExactCalculations);
}
return result;
}

// If the start is always large enough to overshift, the result is always 0.
if (std::optional<Bits> start_lb = start.LowerBound(); start_lb.has_value()) {
const int64_t min_shift = bits_ops::UnsignedBitsToSaturatedInt64(*start_lb);
if (min_shift >= input_width) {
return IntervalSet::Precise(UBits(0, width));
}
}

// Conservative fallback: if the start set is large, avoid enumeration.
return IntervalSet::Maximal(width);
}

IntervalSet Concat(absl::Span<IntervalSet const> sets) {
std::vector<ArgumentBehavior> behaviors(sets.size(),
kMonotoneNonSizePreserving);
Expand Down
2 changes: 2 additions & 0 deletions xls/ir/interval_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ IntervalSet SignExtend(const IntervalSet& a, int64_t width);
IntervalSet ZeroExtend(const IntervalSet& a, int64_t width);
IntervalSet Truncate(const IntervalSet& a, int64_t width);
IntervalSet BitSlice(const IntervalSet& a, int64_t start, int64_t width);
IntervalSet DynamicBitSlice(const IntervalSet& to_slice,
const IntervalSet& start, int64_t width);

// Cmp
IntervalSet Eq(const IntervalSet& a, const IntervalSet& b);
Expand Down
29 changes: 29 additions & 0 deletions xls/ir/interval_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,35 @@ FUZZ_TEST(IntervalOpsTest, BitSliceZ3Fuzz)
.WithDomains(IntervalDomain(8), fuzztest::InRange<int8_t>(1, 15),
fuzztest::InRange<int8_t>(1, 15));

TEST(IntervalOpsTest, DynamicBitSlice) {
IntervalSet x = FromRanges({{10, 200}}, 8);
IntervalSet start = FromValues({0, 2, 7, 8}, 4);

IntervalSet expected(4);
expected = IntervalSet::Combine(expected, interval_ops::BitSlice(x, 0, 4));
expected = IntervalSet::Combine(expected, interval_ops::BitSlice(x, 2, 4));
expected = IntervalSet::Combine(expected, interval_ops::BitSlice(x, 7, 4));
expected = IntervalSet::Combine(expected, IntervalSet::Precise(UBits(0, 4)));
expected.Normalize();

EXPECT_EQ(interval_ops::DynamicBitSlice(x, start, /*width=*/4), expected);
}

void DynamicBitSliceZ3Fuzz(absl::Span<std::pair<int64_t, int64_t> const> lhs,
absl::Span<std::pair<int64_t, int64_t> const> rhs) {
BinaryOpFuzz(
"dynamic_bit_slice",
[](FunctionBuilder& fb, BValue l, BValue r) {
return fb.DynamicBitSlice(l, r, /*width=*/4);
},
[](const IntervalSet& l, const IntervalSet& r) {
return interval_ops::DynamicBitSlice(l, r, /*width=*/4);
},
lhs, rhs, /*bits=*/8);
}
FUZZ_TEST(IntervalOpsTest, DynamicBitSliceZ3Fuzz)
.WithDomains(IntervalDomain(8), IntervalDomain(8));

void EqZ3Fuzz(absl::Span<std::pair<int64_t, int64_t> const> lhs,
absl::Span<std::pair<int64_t, int64_t> const> rhs) {
BinaryOpFuzz(
Expand Down
19 changes: 17 additions & 2 deletions xls/passes/range_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class RangeQueryVisitor : public DfsVisitor {

// The maximum number of points covered by an interval set that can be
// iterated over in an analysis.
static constexpr int64_t kMaxIterationSize = 1024;

// Wrapper around GetIntervalSetTree for consistency with the
// SetIntervalSetTree wrapper.
Expand Down Expand Up @@ -517,7 +516,23 @@ absl::Status RangeQueryVisitor::HandleDecode(Decode* decode) {
absl::Status RangeQueryVisitor::HandleDynamicBitSlice(
DynamicBitSlice* dynamic_bit_slice) {
INITIALIZE_OR_SKIP(dynamic_bit_slice);
return absl::OkStatus(); // TODO(taktoa): implement
ASSIGN_INTERVAL_SET_REF_OR_RETURN(to_slice, dynamic_bit_slice->operand(0));
ASSIGN_INTERVAL_SET_REF_OR_RETURN(start, dynamic_bit_slice->start());
// Many IntervalSet query APIs (e.g. Size/Values/Intervals/ConvexHull) require
// the interval set to be normalized. The RangeQueryEngine may hold
// non-normalized sets, so normalize local copies here.
IntervalSet to_slice_norm = to_slice;
if (!to_slice_norm.IsNormalized()) {
to_slice_norm.Normalize();
}
IntervalSet start_norm = start;
if (!start_norm.IsNormalized()) {
start_norm.Normalize();
}

return SetIntervalSet(dynamic_bit_slice, interval_ops::DynamicBitSlice(
to_slice_norm, start_norm,
dynamic_bit_slice->width()));
}

absl::Status RangeQueryVisitor::HandleDynamicCountedFor(
Expand Down
155 changes: 155 additions & 0 deletions xls/passes/range_query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,51 @@ FUZZ_TEST(RangeQueryEngineFuzzTest, ConcatIsCorrect)
NonemptyNormalizedIntervalSet(5),
NonemptyNormalizedIntervalSet(3));

// This is a property test: for all concrete x/start values allowed by the
// given interval sets, the range analysis result must cover the concrete
// DynamicBitSlice(x, start, 4) value.
//
// We use Covers() (not equality) because RangeQueryEngine is allowed to be
// conservative; this test is intended to detect unsoundness.
void DynamicBitSliceIsCorrect(const IntervalSet& x_intervals,
const IntervalSet& start_intervals) {
constexpr std::string_view kTestName =
"RangeQueryEngineFuzzTest.DynamicBitSliceIsCorrect";

auto p = std::make_unique<VerifiedPackage>(kTestName);
FunctionBuilder fb(kTestName, p.get());
BValue x = fb.Param("x", p->GetBitsType(x_intervals.BitCount()));
BValue start = fb.Param("start", p->GetBitsType(start_intervals.BitCount()));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

x_intervals.ForEachElement([&](const Bits& x_bits) -> bool {
start_intervals.ForEachElement([&](const Bits& start_bits) -> bool {
const int64_t shift_amount =
bits_ops::UnsignedBitsToSaturatedInt64(start_bits);
Bits out(4);
if (shift_amount < x_bits.bit_count()) {
out = bits_ops::ShiftRightLogical(x_bits, shift_amount)
.Slice(0, /*width=*/4);
}
EXPECT_TRUE(engine.GetIntervalSetTree(expr.node()).Get({}).Covers(out));
return false;
});
return false;
});
}

FUZZ_TEST(RangeQueryEngineFuzzTest, DynamicBitSliceIsCorrect)
.WithDomains(NonemptyNormalizedIntervalSet(6),
NonemptyNormalizedIntervalSet(3));

TEST_F(RangeQueryEngineTest, Decode) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down Expand Up @@ -831,6 +876,116 @@ TEST_F(RangeQueryEngineTest, DecodePreciseOverflow) {
EXPECT_EQ("0b00_0000_0000", engine.ToString(expr.node()));
}

TEST_F(RangeQueryEngineTest, DynamicBitSlicePreciseStart) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Literal(UBits(2, 3));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet expected =
interval_ops::BitSlice(x_intervals, /*start=*/2, /*width=*/4);
IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, expected);
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceSmallStartSet) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(4));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
IntervalSet start_intervals = IntervalSet::Of(
{Interval::Precise(UBits(0, 4)), Interval::Precise(UBits(2, 4)),
Interval::Precise(UBits(7, 4)), Interval::Precise(UBits(8, 4))});
start_intervals.Normalize();

IntervalSet expected(4);
expected.Normalize();
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 0, 4));
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 2, 4));
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 7, 4));
expected.AddInterval(Interval::Precise(UBits(0, 4)));
expected.Normalize();

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, expected);
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceAlwaysOvershiftIsZero) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(4));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
IntervalSet start_intervals =
IntervalSet::Of({Interval(UBits(8, 4), UBits(15, 4))});
start_intervals.Normalize();

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, IntervalSet::Precise(UBits(0, 4)));
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceLargeStartFallsBackToMaximal) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(16));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_TRUE(got.IsMaximal());
}

TEST_F(RangeQueryEngineTest, DecodeUnconstrained) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down