Skip to content

Commit 0b2e6c9

Browse files
authored
Merge pull request #239 from gkronber/fix_enode_memo_2
Fix hashing and memoization of enodes (VecExpr)
2 parents 1dc53da + 66ea780 commit 0b2e6c9

3 files changed

Lines changed: 29 additions & 54 deletions

File tree

src/EGraphs/egraph.jl

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -224,22 +224,6 @@ Returns the canonical e-class id for a given e-class.
224224

225225
@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]
226226

227-
# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
228-
# if !v_isexpr(n)
229-
# v_hash!(n)
230-
# return n
231-
# end
232-
# l = v_arity(n)
233-
# new_n = v_new(l)
234-
# v_set_flag!(new_n, v_flags(n))
235-
# v_set_head!(new_n, v_head(n))
236-
# for i in v_children_range(n)
237-
# @inbounds new_n[i] = find(g, n[i])
238-
# end
239-
# v_hash!(new_n)
240-
# new_n
241-
# end
242-
243227
function canonicalize!(g::EGraph, n::VecExpr)
244228
v_isexpr(n) || @goto ret
245229
for i in (VECEXPR_META_LENGTH + 1):length(n)
@@ -253,19 +237,16 @@ end
253237

254238
function lookup(g::EGraph, n::VecExpr)::Id
255239
canonicalize!(g, n)
256-
h = IdKey(v_hash(n))
257240

258-
haskey(g.memo, n) ? find(g, g.memo[n]) : 0
241+
id = get(g.memo, n, zero(Id))
242+
iszero(id) ? id : find(g, id)
259243
end
260244

261245

262246
function add_class_by_op(g::EGraph, n, eclass_id)
263247
key = IdKey(v_signature(n))
264-
if haskey(g.classes_by_op, key)
265-
push!(g.classes_by_op[key], eclass_id)
266-
else
267-
g.classes_by_op[key] = [eclass_id]
268-
end
248+
vec = get!(g.classes_by_op, key, Vector{Id}())
249+
push!(vec, eclass_id)
269250
end
270251

271252
"""
@@ -274,7 +255,8 @@ Inserts an e-node in an [`EGraph`](@ref)
274255
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
275256
canonicalize!(g, n)
276257

277-
haskey(g.memo, n) && return g.memo[n]
258+
id = get(g.memo, n, zero(Id))
259+
iszero(id) || return id
278260

279261
if should_copy
280262
n = copy(n)
@@ -291,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
291273
g.memo[n] = id
292274

293275
add_class_by_op(g, n, id)
294-
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
276+
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
295277
g.classes[IdKey(id)] = eclass
296278
modify!(g, eclass)
297279
push!(g.pending, n => id)
@@ -320,28 +302,22 @@ function addexpr!(g::EGraph, se)::Id
320302
se isa EClass && return se.id
321303
e = preprocess(se)
322304

323-
n = if isexpr(e)
324-
args = iscall(e) ? arguments(e) : children(e)
325-
ar = length(args)
326-
n = v_new(ar)
327-
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
328-
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
329-
330-
h = iscall(e) ? operation(e) : head(e)
331-
v_set_head!(n, add_constant!(g, h))
332-
333-
# get the signature from op and arity
334-
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
335-
336-
for i in v_children_range(n)
337-
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
338-
end
339-
n
340-
else # constant enode
341-
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
305+
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)
306+
307+
args = iscall(e) ? arguments(e) : children(e)
308+
ar = length(args)
309+
n = v_new(ar)
310+
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
311+
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
312+
h = iscall(e) ? operation(e) : head(e)
313+
v_set_head!(n, add_constant!(g, h))
314+
# get the signature from op and arity
315+
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
316+
for i in v_children_range(n)
317+
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
342318
end
343-
id = add!(g, n, false)
344-
return id
319+
320+
add!(g, n, false)
345321
end
346322

347323
"""
@@ -431,10 +407,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
431407
while !isempty(g.pending) || !isempty(g.analysis_pending)
432408
while !isempty(g.pending)
433409
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
410+
node = copy(node)
434411
canonicalize!(g, node)
435-
if haskey(g.memo, node)
436-
old_class_id = g.memo[node]
437-
g.memo[node] = eclass_id
412+
old_class_id = get!(g.memo, node, eclass_id)
413+
if old_class_id != eclass_id
438414
did_something = union!(g, old_class_id, eclass_id)
439415
# TODO unique! can node dedup be moved here? compare performance
440416
# did_something && unique!(g[eclass_id].nodes)
@@ -474,9 +450,8 @@ function check_memo(g::EGraph)::Bool
474450
for (id, class) in g.classes
475451
@assert id.val == class.id
476452
for node in class.nodes
477-
if haskey(test_memo, node)
478-
old_id = test_memo[node]
479-
test_memo[node] = id.val
453+
old_id = get!(test_memo, node, id.val)
454+
if old_id != id.val
480455
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
481456
end
482457
end

src/EGraphs/uniquequeue.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ function Base.pop!(uq::UniqueQueue{T}) where {T}
3030
v
3131
end
3232

33-
Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
33+
Base.isempty(uq::UniqueQueue) = isempty(uq.vec)

src/vecexpr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080

8181
"""The hash of the e-node."""
8282
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
83-
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
83+
Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here
8484
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])
8585

8686
"""Set e-node hash to zero."""

0 commit comments

Comments
 (0)