Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions src/ast/euf/euf_seq_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ namespace euf {
}
}

// Saturating unsigned addition: returns UINT_MAX instead of wrapping.
static unsigned saturating_add(unsigned a, unsigned b) {
return (b > UINT_MAX - a) ? UINT_MAX : a + b;
// Report whether the unsigned sum a + b is representable, i.e. whether it
// stays at or below UINT_MAX instead of wrapping around.
static bool can_add_without_overflow(unsigned a, unsigned b) {
    // Largest addend that can still be stacked on top of a without wrapping.
    const unsigned headroom = UINT_MAX - a;
    return !(b > headroom);
}

unsigned enode_concat_hash::operator()(enode* n) const {
Expand Down Expand Up @@ -183,13 +183,20 @@ namespace euf {
propagate_simplify(n);
}

{
enode* root = a->get_root();
for (enode* cn : m_concats) {
enode *na, *nb;
if (is_concat(cn, na, nb) &&
(na->get_root() == root || nb->get_root() == root))
propagate_simplify(cn);
}
}

// Re-apply identity and absorption rules over all tracked concat nodes.
// This handles the case where the merge caused a child to become equivalent
// to an identity (ε) or absorbing element (∅) that was not known at
// registration time (e.g. b ~ "" discovered after concat(x, b) was registered).
// Also re-simplifies RE concat nodes when a child's root has become full_seq,
// to handle nullable absorption through nested concats:
// concat(.*, concat(v, w)) = concat(.*, w) when v is nullable but w is not.
for (enode* n : m_concats) {
enode *na, *nb;
if (is_str_concat(n, na, nb)) {
Expand All @@ -207,8 +214,6 @@ namespace euf {
push_merge(n, na); // absorb: concat(∅, b) = ∅
else if (is_re_empty(nb))
push_merge(n, nb); // absorb: concat(a, ∅) = ∅
else if (is_full_seq(na->get_root()) || is_full_seq(nb->get_root()))
propagate_simplify(n);
}
}
}
Expand Down Expand Up @@ -260,21 +265,21 @@ namespace euf {

// Rule 1 extended (right): concat(v*, concat(v*, c)) = concat(v*, c)
enode* b1, *b2;
if (is_concat(b, b1, b2) && same_star_body(a, b1))
if (is_concat(b->get_root(), b1, b2) && same_star_body(a, b1))
push_merge(n, b);

// Rule 1 extended (left): concat(concat(c, v*), v*) = concat(c, v*)
enode* a1, *a2;
if (is_concat(a, a1, a2) && same_star_body(a2, b))
if (is_concat(a->get_root(), a1, a2) && same_star_body(a2, b))
push_merge(n, a);

// Rule 2: Nullable absorption by .*
// concat(.*, v) = .* when v is nullable
if (is_full_seq(a) && is_nullable(b))
if (is_full_seq(a->get_root()) && is_nullable(b->get_root()))
push_merge(n, a);

// concat(v, .*) = .* when v is nullable
if (is_nullable(a) && is_full_seq(b))
if (is_nullable(a->get_root()) && is_full_seq(b->get_root()))
push_merge(n, b);

// concat(.*, concat(v, w)) = concat(.*, w) when v nullable
Expand All @@ -286,12 +291,13 @@ namespace euf {
// Rule 3: Loop merging
// concat(v{l1,h1}, v{l2,h2}) = v{l1+l2,h1+h2}
unsigned lo1, hi1, lo2, hi2;
if (same_loop_body(a, b, lo1, hi1, lo2, hi2)) {
if (same_loop_body(a, b, lo1, hi1, lo2, hi2) &&
can_add_without_overflow(lo1, lo2) &&
can_add_without_overflow(hi1, hi2)) {
ast_manager& m = g.get_manager();
enode* body_n = a->get_arg(0);
// saturating add: prevent silent unsigned wrap-around on large bounds
unsigned lo_merged = saturating_add(lo1, lo2);
unsigned hi_merged = saturating_add(hi1, hi2);
unsigned lo_merged = lo1 + lo2;
unsigned hi_merged = hi1 + hi2;
expr_ref merged(m_seq.re.mk_loop(body_n->get_expr(), lo_merged, hi_merged), m);
enode* merged_n = mk(merged, 1, &body_n);
push_merge(n, merged_n);
Expand All @@ -304,10 +310,11 @@ namespace euf {
}

bool seq_plugin::same_star_body(enode* a, enode* b) {
if (!is_star(a) || !is_star(b))
enode* a_root = a->get_root(), *b_root = b->get_root();
if (!is_star(a_root) || !is_star(b_root))
return false;
// re.star(x) and re.star(y) have congruent bodies if x ~ y
return a->get_arg(0)->get_root() == b->get_arg(0)->get_root();
return a_root->get_arg(0)->get_root() == b_root->get_arg(0)->get_root();
}

bool seq_plugin::same_loop_body(enode* a, enode* b,
Expand Down
126 changes: 126 additions & 0 deletions src/test/euf_seq_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Copyright (c) 2026 Microsoft Corporation
#include "ast/euf/euf_egraph.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <climits>
#include <iostream>

static unsigned s_var = 0;
Expand Down Expand Up @@ -290,6 +291,127 @@ static void test_seq_plugin_loop_merge() {
std::cout << g << "\n";
}

// test seq_plugin: star merge should fire when a child is merged into a star.
static void test_seq_plugin_star_merge_after_child_merge() {
std::cout << "test_seq_plugin_star_merge_after_child_merge\n";
ast_manager m;
reg_decl_plugins(m);
euf::egraph eg(m);
euf::sgraph sg(m, eg);
euf::egraph& g = sg.get_egraph();
seq_util seq(m);
sort_ref str_sort(seq.str.mk_string_sort(), m);
sort_ref re_sort(seq.re.mk_re(str_sort), m);

expr_ref x(m.mk_const("x", str_sort), m);
expr_ref a(m.mk_const("a", re_sort), m);
expr_ref to_re_x(seq.re.mk_to_re(x), m);
expr_ref star_x(seq.re.mk_star(to_re_x), m);
expr_ref concat_expr(seq.re.mk_concat(a, star_x), m);

auto* nc = get_node(g, seq, concat_expr);
auto* na = get_node(g, seq, a);
auto* ns = get_node(g, seq, star_x);
g.propagate();

g.merge(na, ns, nullptr);
g.propagate();

SASSERT(nc->get_root() == ns->get_root());
std::cout << g << "\n";
}

// test seq_plugin: extended star merge should use the concat root.
static void test_seq_plugin_star_merge_extended_root() {
std::cout << "test_seq_plugin_star_merge_extended_root\n";
ast_manager m;
reg_decl_plugins(m);
euf::egraph eg(m);
euf::sgraph sg(m, eg);
euf::egraph& g = sg.get_egraph();
seq_util seq(m);
sort_ref str_sort(seq.str.mk_string_sort(), m);
sort_ref re_sort(seq.re.mk_re(str_sort), m);

expr_ref x(m.mk_const("x", str_sort), m);
expr_ref c(m.mk_const("c", str_sort), m);
expr_ref b(m.mk_const("b", re_sort), m);
expr_ref to_re_x(seq.re.mk_to_re(x), m);
expr_ref to_re_c(seq.re.mk_to_re(c), m);
expr_ref star_x(seq.re.mk_star(to_re_x), m);
expr_ref rhs(seq.re.mk_concat(star_x, to_re_c), m);
expr_ref top(seq.re.mk_concat(star_x, b), m);

auto* ntop = get_node(g, seq, top);
auto* nb = get_node(g, seq, b);
auto* nrhs = get_node(g, seq, rhs);
g.propagate();

g.merge(nb, nrhs, nullptr);
g.propagate();

SASSERT(ntop->get_root() == nb->get_root());
std::cout << g << "\n";
}

// test seq_plugin: nullable absorption should use merged roots.
static void test_seq_plugin_nullable_absorb_root() {
std::cout << "test_seq_plugin_nullable_absorb_root\n";
ast_manager m;
reg_decl_plugins(m);
euf::egraph eg(m);
euf::sgraph sg(m, eg);
euf::egraph& g = sg.get_egraph();
seq_util seq(m);
sort_ref str_sort(seq.str.mk_string_sort(), m);
sort_ref re_sort(seq.re.mk_re(str_sort), m);

expr_ref a(m.mk_const("a", re_sort), m);
expr_ref b(m.mk_const("b", re_sort), m);
expr_ref full_seq(seq.re.mk_full_seq(str_sort), m);
expr_ref eps(seq.re.mk_epsilon(str_sort), m);
expr_ref top(seq.re.mk_concat(a, b), m);

auto* ntop = get_node(g, seq, top);
auto* na = get_node(g, seq, a);
auto* nb = get_node(g, seq, b);
auto* nfull = get_node(g, seq, full_seq);
auto* neps = get_node(g, seq, eps);
g.propagate();

g.merge(na, nfull, nullptr);
g.merge(nb, neps, nullptr);
g.propagate();

SASSERT(ntop->get_root() == na->get_root());
std::cout << g << "\n";
}

// test seq_plugin: loop merge should not fire when bounds overflow.
static void test_seq_plugin_loop_merge_overflow_guard() {
std::cout << "test_seq_plugin_loop_merge_overflow_guard\n";
ast_manager m;
reg_decl_plugins(m);
euf::egraph eg(m);
euf::sgraph sg(m, eg);
euf::egraph& g = sg.get_egraph();
seq_util seq(m);
sort_ref str_sort(seq.str.mk_string_sort(), m);

expr_ref x(m.mk_const("x", str_sort), m);
expr_ref r(seq.re.mk_to_re(x), m);
expr_ref l1(seq.re.mk_loop_proper(r, UINT_MAX, UINT_MAX), m);
expr_ref l2(seq.re.mk_loop_proper(r, 1, 1), m);
expr_ref concat_loops(seq.re.mk_concat(l1, l2), m);

auto* nc = get_node(g, seq, concat_loops);
auto* nl1 = get_node(g, seq, l1);
g.propagate();

SASSERT(nc->get_root() != nl1->get_root());
std::cout << g << "\n";
}

void tst_euf_seq_plugin() {
s_var = 0; test_sgraph_basic();
s_var = 0; test_sgraph_backtrack();
Expand All @@ -300,4 +422,8 @@ void tst_euf_seq_plugin() {
s_var = 0; test_sgraph_egraph_sync();
s_var = 0; test_seq_plugin_identity_after_merge();
s_var = 0; test_seq_plugin_loop_merge();
s_var = 0; test_seq_plugin_star_merge_after_child_merge();
s_var = 0; test_seq_plugin_star_merge_extended_root();
s_var = 0; test_seq_plugin_nullable_absorb_root();
s_var = 0; test_seq_plugin_loop_merge_overflow_guard();
}
Loading