diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 120234e20..eb8b5175a 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -279,7 +279,7 @@ let # Small: 1 chain covers ~5% of the IR (<10%) small_exprs_si = Set{BasicSymbolic{SymReal}}([chains_si[1]]) # Large: first 12 chains cover ~60% of the IR (>50%), leaving 8 chains out - large_exprs_si = Set{BasicSymbolic{SymReal}}(chains_si[1:12]) + large_exprs_si = chains_si[1:12] si["small"] = @benchmarkable SymbolicUtils.subset_ir($(subber_si.ir), $small_exprs_si) si["large"] = @benchmarkable SymbolicUtils.subset_ir($(subber_si.ir), $large_exprs_si) diff --git a/src/irstructure.jl b/src/irstructure.jl index bc0ed0515..7555d2166 100644 --- a/src/irstructure.jl +++ b/src/irstructure.jl @@ -440,28 +440,24 @@ function subset_ir( exprs::Union{AbstractArray{BasicSymbolic{T}}, AbstractSet{BasicSymbolic{T}}} ) where {T} new_ir = IRStructure{T}() - reachables = get_cached_mask!(ir, length(ir)) - expr_reach = get_cached_idxs!(ir) + visited = get_cached_mask!(ir, length(ir)) + reachables = get_cached_idxs!(ir) + empty!(reachables) for expr in exprs expr_i = get(ir.definition, expr, 0) iszero(expr_i) && _throw_expr_not_in_ir(expr) - reachables[expr_i] = true - empty!(expr_reach) - get_reachability!(expr_reach, ir, expr_i) - reachables[expr_reach] .= true + get_reachability!(reachables, ir, expr_i; visited) + visited[expr_i] = true + push!(reachables, expr_i) end - n_new_verts = count(reachables) + n_new_verts = length(reachables) Graphs.add_vertices!(new_ir.dependency_graph, n_new_verts) sizehint!(new_ir, n_new_verts) - # Instead of calling `populate_ir!`, we can directly build the new IR. - # Iterate in topological order (children before parents) so that when we - # translate edges to new indices, the dependency is already in `new_ir.definition`. - topo_order = Graphs.topological_sort_by_dfs(ir.dependency_graph) inew = 0 - for iold in Iterators.reverse(topo_order) - reachables[iold] || continue + # `reachables` is in topological order + for iold in reachables inew += 1 # Add expression to the IR sym = ir.symbols[iold]