Skip to content

Commit 081a9e6

Browse files
authored
Merge pull request #243 from JuliaSymbolics/3.0_minor_fixes_and_improvements
3.0 minor fixes and improvements
2 parents dab995c + 3373cb6 commit 081a9e6

5 files changed

Lines changed: 37 additions & 18 deletions

File tree

src/EGraphs/egraph.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ function Base.union!(
361361
)
362362

363363
g.classes[id_1] = new_eclass
364+
modify!(g, new_eclass)
364365

365366
return true
366367
end
@@ -425,21 +426,23 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
425426
eclass = g.classes[eclass_id_key]
426427

427428
node_data = make(g, node)
428-
if !isnothing(eclass.data)
429-
joined_data = join(eclass.data, node_data)
430-
431-
if joined_data != eclass.data
432-
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data)
433-
# eclass.data = joined_data
429+
if !isnothing(node_data)
430+
if !isnothing(eclass.data)
431+
joined_data = join(eclass.data, node_data)
432+
433+
if joined_data != eclass.data
434+
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data)
435+
# eclass.data = joined_data
436+
modify!(g, eclass)
437+
append!(g.analysis_pending, eclass.parents)
438+
end
439+
else
440+
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data)
441+
# eclass.data = node_data
434442
modify!(g, eclass)
435443
append!(g.analysis_pending, eclass.parents)
436444
end
437-
else
438-
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, node_data)
439-
# eclass.data = node_data
440-
modify!(g, eclass)
441445
end
442-
443446
end
444447
end
445448
n_unions

src/EGraphs/saturation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,16 @@ function eqsat_search!(
8686
@debug "$rule is banned"
8787
continue
8888
end
89-
ids_left = cached_ids(g, rule.left)
90-
ids_right = is_bidirectional(rule) ? cached_ids(g, rule.right) : UNDEF_ID_VEC
9189

90+
ids_left = cached_ids(g, rule.left)
9291
for i in ids_left
9392
cansearch(scheduler, rule_idx, i) || continue
9493
n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer)
9594
inform!(scheduler, rule_idx, i, n_matches)
9695
end
9796

9897
if is_bidirectional(rule)
98+
ids_right = cached_ids(g, rule.right)
9999
for i in ids_right
100100
cansearch(scheduler, rule_idx, i) || continue
101101
n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer)

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,27 +36,27 @@ function buffer_readable(g, limit, ematch_buffer)
3636
k = length(ematch_buffer)
3737

3838
while k > limit
39-
delimiter = ematch_buffer[k]
39+
delimiter = ematch_buffer.v[k]
4040
@assert delimiter == 0xffffffffffffffffffffffffffffffff
4141
n = k - 1
4242

4343
next_delimiter_idx = 0
4444
n_elems = 0
4545
for i in n:-1:1
4646
n_elems += 1
47-
if ematch_buffer[i] == 0xffffffffffffffffffffffffffffffff
47+
if ematch_buffer.v[i] == 0xffffffffffffffffffffffffffffffff
4848
n_elems -= 1
4949
next_delimiter_idx = i
5050
break
5151
end
5252
end
5353

54-
match_info = ematch_buffer[next_delimiter_idx + 1]
54+
match_info = ematch_buffer.v[next_delimiter_idx + 1]
5555
id = v_pair_first(match_info)
5656
rule_idx = reinterpret(Int, v_pair_last(match_info))
5757
rule_idx = abs(rule_idx)
5858

59-
bindings = @view ematch_buffer[(next_delimiter_idx + 2):n]
59+
bindings = @view ematch_buffer.v[(next_delimiter_idx + 2):n]
6060

6161
print("$id E-Classes: ", map(x -> reinterpret(Int, v_pair_first(x)), bindings))
6262
print(" Nodes: ", map(x -> reinterpret(Int, v_pair_last(x)), bindings), "\n")

test/egraphs/ematch.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,22 @@ end
289289
@test test_equality(some_theory, :(a * b * 0), 0)
290290
end
291291

292+
@testset "Dynamic rule predicates in EMatcher" begin
293+
g = EGraph(:(2 * 3))
294+
zero_id = addexpr!(g, 0)
295+
296+
some_theory = @theory begin
297+
~a * ~b => 0 where (iszero(a) || iszero(b))
298+
~a * ~b --> ~b * ~a
299+
end
300+
301+
Base.iszero(ec::EClass) = in_same_class(g, zero_id, ec.id)
302+
303+
saturate!(g, some_theory)
304+
305+
@test test_equality(some_theory, :(a * b * 0), 0)
306+
end
307+
292308
@testset "Inequalities" begin
293309
failme = @theory p begin
294310
p != !p

test/integration/stream_fusion.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ fold_theory = @theory x y z begin
4949
x::Number * y::Number => x * y
5050
x::Number + y::Number => x + y
5151
x::Number / y::Number => x / y
52-
x::Number - y::Number => x / y
52+
x::Number - y::Number => x - y
5353
# etc...
5454
end
5555

0 commit comments

Comments
 (0)