Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions xls/ir/node_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,61 @@ std::vector<Node*> RemoveRedundantNodes(

} // namespace

std::optional<ShiftedBitView> IsOneShiftedBit(Node* node) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This strikes me as a specialized form of querying bit_provenance_analysis for the source of a node's bits and then asking if all bits are literal 0 except for 1 bit. Can you look into using that analysis instead of implementing pattern matchers for these two cases?

// Match: shll(zext(b), literal(k))
if (node->op() == Op::kShll) {
Node* shift_base = node->operand(0);
Node* shift_amount = node->operand(1);
if (shift_base->op() == Op::kZeroExt &&
IsSingleBitType(shift_base->operand(0)) &&
shift_amount->Is<Literal>()) {
absl::StatusOr<uint64_t> k_u64 =
shift_amount->As<Literal>()->value().bits().ToUint64();
if (!k_u64.ok()) {
return std::nullopt;
}
return ShiftedBitView{.b = shift_base->operand(0),
.k = static_cast<int64_t>(*k_u64)};
}
}

// Match: concat(0..., b, 0...)
if (node->Is<Concat>()) {
std::optional<int64_t> b_operand_index;
for (int64_t i = 0; i < node->operand_count(); ++i) {
Node* operand = node->operand(i);
if (!IsSingleBitType(operand)) {
continue;
}
if (b_operand_index.has_value()) {
// More than one 1-bit operand.
return std::nullopt;
}
b_operand_index = i;
}
if (!b_operand_index.has_value()) {
return std::nullopt;
}

for (int64_t i = 0; i < node->operand_count(); ++i) {
if (i == *b_operand_index) {
continue;
}
if (!IsLiteralZero(node->operand(i))) {
return std::nullopt;
}
}

int64_t k = 0;
for (int64_t i = *b_operand_index + 1; i < node->operand_count(); ++i) {
k += node->operand(i)->BitCountOrDie();
}
return ShiftedBitView{.b = node->operand(*b_operand_index), .k = k};
}

return std::nullopt;
}

bool IsLiteralWithRunOfSetBits(Node* node, int64_t* leading_zero_count,
int64_t* set_bit_count,
int64_t* trailing_zero_count) {
Expand Down
16 changes: 16 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@

namespace xls {

struct ShiftedBitView {
Node* b;
int64_t k;
};

// Returns (b, k) if `node` is structurally equivalent to a value with a single
// potentially-set bit at position `k` (0 == LSb) controlled by the 1-bit value
// `b`. The recognized forms are:
//
// * shll(zext(b), literal(k))
// * concat(0..., b, 0...)
//
// This is a structural matcher and only recognizes literal zeros / literal shift
// amounts (it does not use any query engine).
std::optional<ShiftedBitView> IsOneShiftedBit(Node* node);

inline bool IsLiteralZero(Node* node) {
return node->Is<Literal>() && node->As<Literal>()->value().IsBits() &&
node->As<Literal>()->value().bits().IsZero();
Expand Down
46 changes: 46 additions & 0 deletions xls/passes/arith_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,52 @@ absl::StatusOr<bool> MatchArithPatterns(int64_t opt_level, Node* n,
}
}

// Pattern:
//
// umod(x, shll(zext(b), k)) -> sel(b, zext(bit_slice(x, 0, k), width(x)), 0)
//
// where b is a 1-bit value.
if (n->op() == Op::kUMod) {
Node* x = n->operand(0);
Node* divisor = n->operand(1);
const int64_t bit_count = x->BitCountOrDie();

auto replace_with_select = [&](Node* b, int64_t k) -> absl::StatusOr<bool> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only ever returns true

XLS_RET_CHECK_EQ(b->BitCountOrDie(), 1);
if (k <= 0 || k >= bit_count) {
XLS_RETURN_IF_ERROR(
n->ReplaceUsesWithNew<Literal>(ZeroOfType(n->GetType())).status());
return true;
}

XLS_ASSIGN_OR_RETURN(
Node * slice, n->function_base()->MakeNode<BitSlice>(
n->loc(), x, /*start=*/0, /*width=*/k));
XLS_ASSIGN_OR_RETURN(
Node * narrowed,
n->function_base()->MakeNode<ExtendOp>(n->loc(), slice, bit_count,
Op::kZeroExt));
XLS_ASSIGN_OR_RETURN(
Node * zero, n->function_base()->MakeNode<Literal>(
n->loc(), Value(UBits(0, bit_count))));
XLS_RETURN_IF_ERROR(
n->ReplaceUsesWithNew<Select>(
b, std::vector<Node*>{zero, narrowed},
/*default_value=*/std::nullopt)
.status());
return true;
};

std::optional<ShiftedBitView> shifted_bit = IsOneShiftedBit(divisor);
if (shifted_bit.has_value()) {
XLS_ASSIGN_OR_RETURN(bool changed,
replace_with_select(shifted_bit->b, shifted_bit->k));
if (changed) {
return true;
}
}
}

// Pattern: UMod/SMod by a literal.
if (n->OpIn({Op::kUMod, Op::kSMod}) &&
query_engine.IsFullyKnown(n->operand(1))) {
Expand Down
22 changes: 22 additions & 0 deletions xls/passes/arith_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,28 @@ TEST_F(ArithSimplificationPassTest, UModByVariable) {
EXPECT_THAT(f->return_value(), m::UMod());
}

TEST_F(ArithSimplificationPassTest, UModShiftedOneBitDivisor) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* bits32 = p->GetBitsType(32);
BValue x = fb.Param("x", bits32);
BValue b = fb.Param("b", p->GetBitsType(1));
const int64_t k = 2;
BValue divisor = fb.Shll(fb.ZeroExtend(b, 32), fb.Literal(UBits(k, 32)));
fb.UMod(x, divisor);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ScopedVerifyEquivalence sve(f, kProverTimeout);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));

EXPECT_THAT(
f->return_value(),
m::Select(m::Param("b"),
/*cases=*/{m::Literal(UBits(0, 32)),
m::ZeroExt(m::BitSlice(m::Param("x"),
/*start=*/0, /*width=*/k))}));
}

TEST_F(ArithSimplificationPassTest, UModOf13) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down
45 changes: 45 additions & 0 deletions xls/passes/optimization_pass_pipeline_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,51 @@ TEST_F(OptimizationPipelineTest, MultiplyBy16StrengthReduction) {
EXPECT_THAT(f->return_value(), m::Concat());
}

TEST_F(OptimizationPipelineTest, UModShiftedOneBitDivisorSimplified) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* bits32 = p->GetBitsType(32);
BValue x = fb.Param("x", bits32);
BValue b = fb.Param("b", p->GetBitsType(1));

// Construct the pattern:
// umod(x, shll(zext(b), k))
const int64_t k = 2;
BValue divisor = fb.Shll(fb.ZeroExtend(b, 32), fb.Literal(UBits(k, 32)));
fb.UMod(x, divisor);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));

// In the pipeline we expect:
//
// umod(x, shll(zext(b), k)) -> sel(b, zext(bit_slice(x, 0, k), width(x)), 0)
//
// which will often then become:
//
// and(sign_ext(b), zext(bit_slice(x, 0, k), width(x))).
auto reduced =
m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/k));
auto and_form = m::And(m::SignExt(m::Param("b")), reduced);
auto and_form_flipped = m::And(reduced, m::SignExt(m::Param("b")));
auto sel_form = m::Select(m::Param("b"),
/*cases=*/{m::Literal(UBits(0, 32)), reduced});
auto narrowed_and = m::And(m::SignExt(m::Param("b")),
m::BitSlice(m::Param("x"), /*start=*/0,
/*width=*/k));
auto narrowed_and_flipped =
m::And(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/k),
m::SignExt(m::Param("b")));
auto concat_form =
m::Concat(m::Literal(UBits(0, 32 - k)), narrowed_and);
auto concat_form_flipped =
m::Concat(m::Literal(UBits(0, 32 - k)), narrowed_and_flipped);
EXPECT_THAT(f->return_value(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any convenient way to restrict the pass pipeline such that we don't need to assert over all these possible forms?

::testing::AnyOf(and_form, and_form_flipped, sel_form, concat_form,
concat_form_flipped))
<< f->DumpIr();
}

TEST_F(OptimizationPipelineTest, LogicAbsorption) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down