Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ let
# Small: 1 chain covers ~5% of the IR (<10%)
small_exprs_si = Set{BasicSymbolic{SymReal}}([chains_si[1]])
# Large: first 12 chains cover ~60% of the IR (>50%), leaving 8 chains out
large_exprs_si = Set{BasicSymbolic{SymReal}}(chains_si[1:12])
large_exprs_si = chains_si[1:12]

si["small"] = @benchmarkable SymbolicUtils.subset_ir($(subber_si.ir), $small_exprs_si)
si["large"] = @benchmarkable SymbolicUtils.subset_ir($(subber_si.ir), $large_exprs_si)
Expand Down
22 changes: 9 additions & 13 deletions src/irstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,28 +440,24 @@ function subset_ir(
exprs::Union{AbstractArray{BasicSymbolic{T}}, AbstractSet{BasicSymbolic{T}}}
) where {T}
new_ir = IRStructure{T}()
reachables = get_cached_mask!(ir, length(ir))
expr_reach = get_cached_idxs!(ir)
visited = get_cached_mask!(ir, length(ir))
reachables = get_cached_idxs!(ir)
empty!(reachables)
for expr in exprs
expr_i = get(ir.definition, expr, 0)
iszero(expr_i) && _throw_expr_not_in_ir(expr)
reachables[expr_i] = true
empty!(expr_reach)
get_reachability!(expr_reach, ir, expr_i)
reachables[expr_reach] .= true
get_reachability!(reachables, ir, expr_i; visited)
visited[expr_i] = true
push!(reachables, expr_i)
end

n_new_verts = count(reachables)
n_new_verts = length(reachables)
Graphs.add_vertices!(new_ir.dependency_graph, n_new_verts)
sizehint!(new_ir, n_new_verts)

# Instead of calling `populate_ir!`, we can directly build the new IR.
# Iterate in topological order (children before parents) so that when we
# translate edges to new indices, the dependency is already in `new_ir.definition`.
topo_order = Graphs.topological_sort_by_dfs(ir.dependency_graph)
inew = 0
for iold in Iterators.reverse(topo_order)
reachables[iold] || continue
# `reachables` is in topological order
for iold in reachables
inew += 1
# Add expression to the IR
sym = ir.symbols[iold]
Expand Down
Loading