Skip to content
4 changes: 2 additions & 2 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ end
function merge_analysis_data!(a::EClass{D}, b::EClass{D})::Tuple{Bool,Bool,Union{D,Nothing}} where {D}
if !isnothing(a.data) && !isnothing(b.data)
new_a_data = join(a.data, b.data)
(a.data == new_a_data, b.data == new_a_data, new_a_data)
(a.data != new_a_data, b.data != new_a_data, new_a_data)
elseif isnothing(a.data) && !isnothing(b.data)
# a merged, b not merged
(true, false, b.data)
Expand Down Expand Up @@ -508,7 +508,7 @@ function rebuild!(g::EGraph)
n_unions = process_unions!(g)
trimmed_nodes = rebuild_classes!(g)
# @assert check_memo(g)
# @assert check_analysis(g)
@assert check_analysis(g)
Comment thread
0x0f0f0f marked this conversation as resolved.
Outdated
g.clean = true

@debug "REBUILT" n_unions trimmed_nodes
Expand Down
35 changes: 22 additions & 13 deletions test/tutorials/lambda_theory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,47 +102,55 @@ end

const LambdaAnalysis = Set{Symbol}

getdata(eclass) = isnothing(eclass.data) ? LambdaAnalysis() : eclass.data
getdata(eclass) = eclass.data

function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType}
v_isexpr(n) || LambdaAnalysis()
v_isexpr(n) || return LambdaAnalysis()
if v_iscall(n)
h = v_head(n)
op = get_constant(g, h)
args = v_children(n)
eclass = g[args[1]]
free = getdata(eclass)
free = copy(getdata(eclass))

if op == Variable
push!(free, get_constant(g, v_head(eclass.nodes[1])))

elseif op == Let
v, a, b = args[1:3]
v, a, b = args[1:3] # v=a in b
vclass = g[v]
vsy = get_constant(g, v_head(vclass.nodes[1]))
adata = getdata(g[a])
bdata = getdata(g[b])
union!(free, adata)
delete!(free, v)
union!(free, bdata)

delete!(free, vsy)
union!(free, adata)
elseif op == λ
v, b = args[1:2]
vclass = g[v]
vsy = get_constant(g, v_head(vclass.nodes[1]))
bdata = getdata(g[b])
union!(free, bdata)
delete!(free, v)

delete!(free, vsy)
elseif op == Apply
l, v = args[1:2]
ldata = getdata(g[l])
vdata = getdata(g[v])
union!(free, ldata)
union!(free, vdata)

end
return free
end
end

EGraphs.join(from::LambdaAnalysis, to::LambdaAnalysis) = union(from, to)
function EGraphs.join(from::LambdaAnalysis, to::LambdaAnalysis)
if issubset(from, to) # includes case from==to
from
elseif issubset(to, from)
to
else
error("inconsistent free variable sets from: $from to: $to")
end
Comment thread
0x0f0f0f marked this conversation as resolved.
end

function fresh_var_generator()
idx = 0
Expand All @@ -159,6 +167,7 @@ freshvar = fresh_var_generator()
# The final ruleset then looks like below and correctly renames variables when needed:

λT = @theory v e c v1 v2 a b body begin
# let(v,e,body) means let v = e in body
Let(v, e, c::Any) --> c
Let(v1, e, Variable(v1)) --> e
Let(v1, e, Variable(v2)) => v1 == v2 ? e : Variable(v2)
Expand Down Expand Up @@ -203,7 +212,7 @@ params = SaturationParams(
)
saturate!(g, λT, params)
two_ = extract!(g, astsize)
@test two_ == λ(:a₁, λ(:a₇, Apply(Variable(:a₁), Apply(Variable(:a₁), Variable(:a₇)))))
@test two_ == λ(:x, λ(:y, Apply(Variable(:x), Apply(Variable(:x), Variable(:y)))))
two_

# which is the same as `two` up to $\alpha$-conversion:
Expand Down