Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,73 @@ inline bool IsLiteralZero(Node* node) {
node->As<Literal>()->value().bits().IsZero();
}

// A uniform view of a bits-typed node that represents a zero-extension of a
// narrower bits-typed `base` value.
//
// This matches either:
// - `zero_ext(base, new_bit_count=W)` where base has width N < W
// - `concat(0..., base)` where all leading operands are literal zeros and base
// is the final operand.
struct ZeroExtendedBitsView {
// The original value being extended.
Node* base;
// Bit width of `base`.
int64_t base_width;
// Bit width of the zero-extended result (i.e. the node being matched).
int64_t result_width;
// Number of leading zero bits added (result_width - base_width).
int64_t leading_zero_width;
};

// Returns a view if `node` is a zero extension of a narrower bits value, as
// defined by `ZeroExtendedBitsView`.
inline std::optional<ZeroExtendedBitsView> MatchZeroExtendedBits(Node* node) {
if (node == nullptr || !node->GetType()->IsBits()) {
return std::nullopt;
}
const int64_t result_width = node->BitCountOrDie();

if (node->op() == Op::kZeroExt) {
ExtendOp* ext = node->As<ExtendOp>();
Node* base = ext->operand(0);
const int64_t base_width = base->BitCountOrDie();
if (base_width < result_width) {
return ZeroExtendedBitsView{
.base = base,
.base_width = base_width,
.result_width = result_width,
.leading_zero_width = result_width - base_width};
}
return std::nullopt;
}

if (node->op() == Op::kConcat) {
Concat* concat = node->As<Concat>();
if (concat->operand_count() < 2) {
return std::nullopt;
}
int64_t prefix_width = 0;
for (int64_t i = 0; i < concat->operand_count() - 1; ++i) {
Node* prefix = concat->operand(i);
if (!IsLiteralZero(prefix)) {
return std::nullopt;
}
prefix_width += prefix->BitCountOrDie();
}
Node* base = concat->operand(concat->operand_count() - 1);
const int64_t base_width = base->BitCountOrDie();
if (prefix_width <= 0) {
return std::nullopt;
}
return ZeroExtendedBitsView{.base = base,
.base_width = base_width,
.result_width = result_width,
.leading_zero_width = prefix_width};
}

return std::nullopt;
}

// Returns true if the given node is a literal with the value one when
// interpreted as an unsigned number
inline bool IsLiteralUnsignedOne(Node* node) {
Expand Down Expand Up @@ -123,6 +190,29 @@ inline bool AnyTwoOperandsWhere(Node* node,
return false;
}

// Returns true if `pred_a` and `pred_b` match `a` and `b` in either order.
//
// If a match is found, `on_match(matched_a, matched_b)` is invoked with
// `matched_a` being the node that satisfied `pred_a` and `matched_b` being the
// node that satisfied `pred_b`.
//
// This is useful for matching commutative patterns while populating additional
// context via captures in `on_match`.
inline bool MatchNodesInAnyOrder(
Node* a, Node* b, const std::function<bool(Node*)>& pred_a,
const std::function<bool(Node*)>& pred_b,
const std::function<void(Node*, Node*)>& on_match) {
if (pred_a(a) && pred_b(b)) {
on_match(a, b);
return true;
}
if (pred_a(b) && pred_b(a)) {
on_match(b, a);
return true;
}
return false;
}

inline bool HasSingleUse(Node* node) {
if (node->function_base()->HasImplicitUse(node)) {
return node->users().empty();
Expand Down
34 changes: 34 additions & 0 deletions xls/ir/node_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,40 @@ TEST_F(NodeUtilTest, MatchBinarySelectLikeNonMatch) {
EXPECT_FALSE(arms.has_value());
}

TEST_F(NodeUtilTest, MatchNodesInAnyOrder) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue x = fb.Param("x", p->GetBitsType(8));
BValue y = fb.Param("y", p->GetBitsType(8));
BValue sum = fb.Add(x, y);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sum));

const auto pred_x = [](Node* n) { return n->GetName() == "x"; };
const auto pred_y = [](Node* n) { return n->GetName() == "y"; };

// Already in order.
Node* matched_a = nullptr;
Node* matched_b = nullptr;
EXPECT_TRUE(MatchNodesInAnyOrder(f->param(0), f->param(1), pred_x, pred_y,
[&](Node* a, Node* b) {
matched_a = a;
matched_b = b;
}));
EXPECT_EQ(matched_a, f->param(0));
EXPECT_EQ(matched_b, f->param(1));

// Swapped order.
matched_a = nullptr;
matched_b = nullptr;
EXPECT_TRUE(MatchNodesInAnyOrder(f->param(0), f->param(1), pred_y, pred_x,
[&](Node* a, Node* b) {
matched_a = a;
matched_b = b;
}));
EXPECT_EQ(matched_a, f->param(1));
EXPECT_EQ(matched_b, f->param(0));
}

TEST_F(NodeUtilTest, GatherTreeBits) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down
101 changes: 101 additions & 0 deletions xls/passes/bit_slice_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,101 @@ absl::StatusOr<std::optional<Node*>> GetUnscaledIndex(
return unscaled_index;
}

// Simplifies bit-slices that extract the carry bit from an addition of a
// zero-extended value and a literal:
//
// x: bits[N]
// zext: bits[W] = zero_ext(x) or concat(0..., x) (W > N)
// sum: bits[W] = add(zext, K) (K is a literal)
// ret: bits[1] = bit_slice(sum, start=N, width=1)
//
// This bit slice extracts the carry-out bit of x + K (in N bits). Rewrite it
// into a simpler comparison against a literal.
static absl::StatusOr<bool> SimplifyCarryExtraction(BitSlice* bit_slice) {
if (bit_slice->width() != 1) {
return false;
}
Node* add = bit_slice->operand(0);
if (add->op() != Op::kAdd) {
return false;
}
const int64_t add_width = add->BitCountOrDie();
Node* add_lhs = add->operand(0);
Node* add_rhs = add->operand(1);

// Match (zero_ext(v), literal) in either operand order.
std::optional<ZeroExtendedBitsView> maybe_v;
Literal* literal = nullptr;
if (!MatchNodesInAnyOrder(
add_lhs, add_rhs,
[](Node* n) { return MatchZeroExtendedBits(n).has_value(); },
[](Node* n) { return n->Is<Literal>(); },
[&](Node* zeroext_node, Node* literal_node) {
maybe_v = MatchZeroExtendedBits(zeroext_node);
literal = literal_node->As<Literal>();
})) {
return false;
}
DCHECK(maybe_v.has_value());
DCHECK(literal != nullptr);
DCHECK_EQ(maybe_v->result_width, add_width);

Node* v = maybe_v->base; // bits[N]
const int64_t n_width = maybe_v->base_width;
DCHECK_LT(n_width, add_width);
if (n_width <= 0) {
return false;
}
if (bit_slice->start() != n_width) {
return false;
}
const Bits k = literal->value().bits();
DCHECK_EQ(k.bit_count(), add_width);

// Let `A = zero_ext(v)` (so `A[N] = 0`) and `B = k` (so `B[N] = b_n`).
// Then `sum[N] = b_n XOR carry_in`, where `carry_in` is the carry-out from
// adding the low N bits: `v + k_low`.
const bool b_n = k.Get(n_width);
Bits k_low_wide = k.Slice(0, n_width);
if (k_low_wide.IsZero()) {
// No carry-in is possible; sum[N] == bN.
XLS_RETURN_IF_ERROR(
bit_slice->ReplaceUsesWithNew<Literal>(Value(UBits(b_n ? 1 : 0, 1)))
.status());
return true;
}

// `carry_in(v + k_low)` <=> `v >= 2^N - k_low`
Bits k_low_ext = bits_ops::ZeroExtend(k_low_wide, n_width + 1);
Bits two_pow_n = Bits::PowerOfTwo(n_width, /*bit_count=*/n_width + 1);
Bits threshold_ext = bits_ops::Sub(two_pow_n, k_low_ext);
Bits threshold = threshold_ext.Slice(0, n_width);
XLS_ASSIGN_OR_RETURN(Node * threshold_literal,
bit_slice->function_base()->MakeNode<Literal>(
bit_slice->loc(), Value(threshold)));

if (!b_n) {
XLS_ASSIGN_OR_RETURN(Node * cmp,
bit_slice->function_base()->MakeNode<CompareOp>(
bit_slice->loc(), v, threshold_literal, Op::kUGe));
VLOG(3) << absl::StreamFormat(
"Replacing bitslice(add(zext(x), k), start=N) => uge(x, T): %s",
bit_slice->GetName());
XLS_RETURN_IF_ERROR(bit_slice->ReplaceUsesWith(cmp));
return true;
}

// `bN==1`: `sum[N] == !carry_in == v < threshold`
XLS_ASSIGN_OR_RETURN(Node * cmp,
bit_slice->function_base()->MakeNode<CompareOp>(
bit_slice->loc(), v, threshold_literal, Op::kULt));
VLOG(3) << absl::StreamFormat(
"Replacing bitslice(add(zext(x), k), start=N) => ult(x, T): %s",
bit_slice->GetName());
XLS_RETURN_IF_ERROR(bit_slice->ReplaceUsesWith(cmp));
return true;
}

// Attempts to replace the given bit slice with a simpler or more canonical
// form. Returns true if the bit slice was replaced. Any newly created
// bit-slices are added to the worklist.
Expand All @@ -159,6 +254,12 @@ absl::StatusOr<bool> SimplifyBitSlice(BitSlice* bit_slice, int64_t opt_level,
Node* operand = bit_slice->operand(0);
BitsType* operand_type = operand->GetType()->AsBitsOrDie();

XLS_ASSIGN_OR_RETURN(bool carry_rewritten,
SimplifyCarryExtraction(bit_slice));
if (carry_rewritten) {
return true;
}

// Creates a new bit slice and adds it to the worklist.
auto make_bit_slice = [&](const SourceInfo& loc, Node* operand, int64_t start,
int64_t width) -> absl::StatusOr<BitSlice*> {
Expand Down
Loading