Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion xls/passes/range_query_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,83 @@ absl::Status RangeQueryVisitor::HandleDecode(Decode* decode) {
absl::Status RangeQueryVisitor::HandleDynamicBitSlice(
DynamicBitSlice* dynamic_bit_slice) {
INITIALIZE_OR_SKIP(dynamic_bit_slice);
return absl::OkStatus(); // TODO(taktoa): implement
ASSIGN_INTERVAL_SET_REF_OR_RETURN(to_slice, dynamic_bit_slice->operand(0));
ASSIGN_INTERVAL_SET_REF_OR_RETURN(start, dynamic_bit_slice->start());
// Many IntervalSet query APIs (e.g. Size/Values/Intervals/ConvexHull) require
// the interval set to be normalized. The RangeQueryEngine may hold
// non-normalized sets, so normalize local copies here.
IntervalSet to_slice_norm = to_slice;
if (!to_slice_norm.IsNormalized()) {
to_slice_norm.Normalize();
}
IntervalSet start_norm = start;
if (!start_norm.IsNormalized()) {
start_norm.Normalize();
}

const int64_t input_width = dynamic_bit_slice->operand(0)->BitCountOrDie();
const int64_t result_width = dynamic_bit_slice->width();

// Strategy:
// - If the start is precise, compute the slice directly.
// - If the start set is small, enumerate possible start indices and union
// the results (exact within the current `to_slice` interval set).
// - Otherwise, try a quick "always overshift" check, and fall back to
// maximal if we can't prove anything cheaply.
//
// Note that DynamicBitSlice(x, start, width) is equivalent to:
// BitSlice(Shrl(x, start), 0, width)
// with overshift yielding 0 for the shift-right.

// Fast path: if the start index is precise, we can compute the result
// directly.
if (start_norm.IsPrecise()) {
const int64_t shift_amount =
bits_ops::UnsignedBitsToSaturatedInt64(*start_norm.GetPreciseValue());
if (shift_amount >= input_width) {
return SetIntervalSet(dynamic_bit_slice,
IntervalSet::Precise(UBits(0, result_width)));
}
return SetIntervalSet(
dynamic_bit_slice,
interval_ops::BitSlice(to_slice_norm, shift_amount, result_width));
}

// Exact path: enumerate possible start indices when the set is small enough.
if (std::optional<int64_t> sz = start_norm.Size();
sz.has_value() && *sz <= kMaxIterationSize) {
IntervalSet result(result_width);
for (const Bits& s : start_norm.Values()) {
const int64_t shift_amount = bits_ops::UnsignedBitsToSaturatedInt64(s);
if (shift_amount >= input_width) {
result.AddInterval(Interval::Precise(UBits(0, result_width)));
} else {
IntervalSet sliced =
interval_ops::BitSlice(to_slice_norm, shift_amount, result_width);
result = IntervalSet::Combine(result, sliced);
}
result.Normalize();
if (result.Intervals().size() > kMaxResIntervalSetSize) {
result = interval_ops::MinimizeIntervals(std::move(result),
kDefaultIntervalSize);
}
}
result.Normalize();
return SetIntervalSet(dynamic_bit_slice, std::move(result));
}

// If the start is always large enough to overshift, the result is always 0.
if (std::optional<Bits> start_lb = start_norm.LowerBound();
start_lb.has_value()) {
const int64_t min_shift = bits_ops::UnsignedBitsToSaturatedInt64(*start_lb);
if (min_shift >= input_width) {
return SetIntervalSet(dynamic_bit_slice,
IntervalSet::Precise(UBits(0, result_width)));
}
}

// Conservative fallback: if the start set is large, avoid enumeration.
return SetIntervalSet(dynamic_bit_slice, IntervalSet::Maximal(result_width));
}

absl::Status RangeQueryVisitor::HandleDynamicCountedFor(
Expand Down
155 changes: 155 additions & 0 deletions xls/passes/range_query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,51 @@ FUZZ_TEST(RangeQueryEngineFuzzTest, ConcatIsCorrect)
NonemptyNormalizedIntervalSet(5),
NonemptyNormalizedIntervalSet(3));

// This is a property test: for all concrete x/start values allowed by the
// given interval sets, the range analysis result must cover the concrete
// DynamicBitSlice(x, start, 4) value.
//
// We use Covers() (not equality) because RangeQueryEngine is allowed to be
// conservative; this test is intended to detect unsoundness.
void DynamicBitSliceIsCorrect(const IntervalSet& x_intervals,
const IntervalSet& start_intervals) {
constexpr std::string_view kTestName =
"RangeQueryEngineFuzzTest.DynamicBitSliceIsCorrect";

auto p = std::make_unique<VerifiedPackage>(kTestName);
FunctionBuilder fb(kTestName, p.get());
BValue x = fb.Param("x", p->GetBitsType(x_intervals.BitCount()));
BValue start = fb.Param("start", p->GetBitsType(start_intervals.BitCount()));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

x_intervals.ForEachElement([&](const Bits& x_bits) -> bool {
start_intervals.ForEachElement([&](const Bits& start_bits) -> bool {
const int64_t shift_amount =
bits_ops::UnsignedBitsToSaturatedInt64(start_bits);
Bits out(4);
if (shift_amount < x_bits.bit_count()) {
out = bits_ops::ShiftRightLogical(x_bits, shift_amount)
.Slice(0, /*width=*/4);
}
EXPECT_TRUE(engine.GetIntervalSetTree(expr.node()).Get({}).Covers(out));
return false;
});
return false;
});
}

FUZZ_TEST(RangeQueryEngineFuzzTest, DynamicBitSliceIsCorrect)
.WithDomains(NonemptyNormalizedIntervalSet(6),
NonemptyNormalizedIntervalSet(3));

TEST_F(RangeQueryEngineTest, Decode) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down Expand Up @@ -831,6 +876,116 @@ TEST_F(RangeQueryEngineTest, DecodePreciseOverflow) {
EXPECT_EQ("0b00_0000_0000", engine.ToString(expr.node()));
}

TEST_F(RangeQueryEngineTest, DynamicBitSlicePreciseStart) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Literal(UBits(2, 3));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet expected =
interval_ops::BitSlice(x_intervals, /*start=*/2, /*width=*/4);
IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, expected);
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceSmallStartSet) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(4));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
IntervalSet start_intervals = IntervalSet::Of(
{Interval::Precise(UBits(0, 4)), Interval::Precise(UBits(2, 4)),
Interval::Precise(UBits(7, 4)), Interval::Precise(UBits(8, 4))});
start_intervals.Normalize();

IntervalSet expected(4);
expected.Normalize();
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 0, 4));
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 2, 4));
expected =
IntervalSet::Combine(expected, interval_ops::BitSlice(x_intervals, 7, 4));
expected.AddInterval(Interval::Precise(UBits(0, 4)));
expected.Normalize();

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, expected);
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceAlwaysOvershiftIsZero) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(4));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
IntervalSet start_intervals =
IntervalSet::Of({Interval(UBits(8, 4), UBits(15, 4))});
start_intervals.Normalize();

RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
engine.SetIntervalSetTree(start.node(),
BitsLTT(start.node(), start_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_EQ(got, IntervalSet::Precise(UBits(0, 4)));
}

TEST_F(RangeQueryEngineTest, DynamicBitSliceLargeStartFallsBackToMaximal) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());

BValue x = fb.Param("x", p->GetBitsType(8));
BValue start = fb.Param("start", p->GetBitsType(16));
BValue expr = fb.DynamicBitSlice(x, start, /*width=*/4);

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

IntervalSet x_intervals =
IntervalSet::Of({Interval(UBits(10, 8), UBits(200, 8))});
RangeQueryEngine engine;
engine.SetIntervalSetTree(x.node(),
BitsLTT(x.node(), x_intervals.Intervals()));
XLS_ASSERT_OK(engine.Populate(f));

IntervalSet got = engine.GetIntervals(expr.node()).Get({});
EXPECT_TRUE(got.IsMaximal());
}

TEST_F(RangeQueryEngineTest, DecodeUnconstrained) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down