Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/irstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ Return a new [`SymbolicUtils.IRStructure`](@ref) containing only the expressions
along with their dependencies.
"""
function subset_ir(ir::IRStructure{T}, expr) where {T}
exprs = Set{BasicSymbolic{T}}()
exprs = OrderedSet{BasicSymbolic{T}}()
buffer = IRStructureSearchBuffer(ir, exprs)
# `Returns(true)` gets all top-level expressions
search_variables!(buffer, expr; is_atomic = Returns(true))
Expand Down
6 changes: 3 additions & 3 deletions src/safe_ctors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ Helper struct for tracking index variable usage in array operations.

# Fields
- `idx_to_axes::IdxToAxesT{T}`: Maps index variables to the axes they index
- `search_buffer::Set{BasicSymbolic{T}}`: Reusable buffer for variable searches
- `search_buffer::OrderedSet{BasicSymbolic{T}}`: Reusable buffer for variable searches
- `buffers::Vector{Vector{IndexedAxis{T}}}`: Pool of reusable buffers

# Details
Expand All @@ -300,12 +300,12 @@ are indexed by which index variables and validates consistency.
"""
struct IndexedAxes{T}
idx_to_axes::IdxToAxesT{T}
search_buffer::Set{BasicSymbolic{T}}
search_buffer::OrderedSet{BasicSymbolic{T}}
buffers::Vector{Vector{IndexedAxis{T}}}
end

function IndexedAxes{T}() where {T}
IndexedAxes{T}(IdxToAxesT{T}(), Set{BasicSymbolic{T}}(), Vector{IndexedAxis{T}}[])
IndexedAxes{T}(IdxToAxesT{T}(), OrderedSet{BasicSymbolic{T}}(), Vector{IndexedAxis{T}}[])
end

function Base.empty!(ix::IndexedAxes)
Expand Down
16 changes: 8 additions & 8 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ end
return DefaultSubstituter{Fold}(d, filter, infer_vartype(d))
end
@inline function DefaultSubstituter{Fold}(d::Pair, filter::F) where {Fold, F}
DefaultSubstituter{Fold}(Dict(d), filter)
DefaultSubstituter{Fold}(OrderedDict(d), filter)
end
@inline function DefaultSubstituter{Fold}(d::AbstractArray{<:Pair}, filter::F) where {Fold, F}
DefaultSubstituter{Fold}(Dict(d), filter)
DefaultSubstituter{Fold}(OrderedDict(d), filter)
end

function (s::Substituter)(ex)
Expand Down Expand Up @@ -261,7 +261,7 @@ julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val(false))
function substitute(expr, dict; fold::Val{Fold}=Val{false}(), filterer=default_substitute_filter) where {Fold}
# This is kind of ugly (inlines some of the constructor logic of `DefaultSubstituter` but is needed to avoid runtime subtyping in
# when calling this function. It makes a very big difference in runtime.
d = dict isa AbstractDict ? dict : Dict(dict)
d = dict isa AbstractDict ? dict : OrderedDict(dict)
isempty(d) && !Fold && return expr
VT = infer_vartype(d)
if VT === Nothing
Expand Down Expand Up @@ -412,7 +412,7 @@ function search_variables!(buffer, expr::SparseMatrixCSC; kw...)
search_variables!(buffer, V; kw...)
end

_default_buffer(::BasicSymbolic{T}) where {T} = Set{BasicSymbolic{T}}()
_default_buffer(::BasicSymbolic{T}) where {T} = OrderedSet{BasicSymbolic{T}}()
_default_buffer(x::Any) = unwrap(x) === x ? Set() : _default_buffer(unwrap(x))

function search_variables(expr; kw...)
Expand All @@ -423,13 +423,13 @@ end

struct ArrayOpReduceCache{T}
new_ranges::RangesT{T}
subrules::Dict{BasicSymbolic{T}, Int}
collapsed_idxs::Set{BasicSymbolic{T}}
subrules::OrderedDict{BasicSymbolic{T}, Int}
collapsed_idxs::OrderedSet{BasicSymbolic{T}}
collapsed_ranges::Vector{StepRange{Int, Int}}
end

function ArrayOpReduceCache{T}() where {T}
ArrayOpReduceCache{T}(RangesT{T}(), Dict{BasicSymbolic{T}, Int}(), Set{BasicSymbolic{T}}(), StepRange{Int, Int}[])
ArrayOpReduceCache{T}(RangesT{T}(), OrderedDict{BasicSymbolic{T}, Int}(), OrderedSet{BasicSymbolic{T}}(), StepRange{Int, Int}[])
end

function Base.empty!(x::ArrayOpReduceCache)
Expand Down Expand Up @@ -651,7 +651,7 @@ scalarization_function(::Type{ArrayOp{T}}) where {T} = _scalarize_arrayop
function _scalarize_arrayop(_, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
@match x begin
BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce, shape = sh) => begin
subrules = Dict()
subrules = OrderedDict()
new_expr = reduce_eliminated_idxs(expr, output_idx, ranges, reduce)
empty!(subrules)

Expand Down
8 changes: 4 additions & 4 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Core ADT for symbolic expressions.
end
struct AddMul
const coeff::Any
const dict::Dict{BasicSymbolicImpl.Type{T}, Number}
const dict::OrderedDict{BasicSymbolicImpl.Type{T}, Number}
const variant::AddMulVariant.T
const metadata::MetadataT
const shape::ShapeT
Expand Down Expand Up @@ -211,7 +211,7 @@ Core ADT for symbolic expressions.
const term::Union{BasicSymbolicImpl.Type{T}, Nothing}
# Optional map from symbolic indices in `output_idx` to the range they can
# take. Any index not present in this takes its full range of values.
const ranges::Dict{BasicSymbolicImpl.Type{T}, StepRange{Int, Int}}
const ranges::OrderedDict{BasicSymbolicImpl.Type{T}, StepRange{Int, Int}}
const metadata::MetadataT
const shape::ShapeT
const type::TypeT
Expand Down Expand Up @@ -262,15 +262,15 @@ The type of the dictionary stored in [`BSImpl.AddMul`](@ref). Passing this to th
[`SymbolicUtils.Add`](@ref) or [`SymbolicUtils.Mul`](@ref) constructors will avoid
allocating a new dictionary.
"""
const ACDict{T} = Dict{BasicSymbolic{T}, Number}
const ACDict{T} = OrderedDict{BasicSymbolic{T}, Number}
"""
The type of the `output_idxs` field in [`BSImpl.ArrayOp`](@ref).
"""
const OutIdxT{T} = SmallV{Union{Int, BasicSymbolic{T}}}
"""
The type of the `ranges` field in [`BSImpl.ArrayOp`](@ref).
"""
const RangesT{T} = Dict{BasicSymbolic{T}, StepRange{Int, Int}}
const RangesT{T} = OrderedDict{BasicSymbolic{T}, StepRange{Int, Int}}
"""
The type of the `sequence` field in [`BSImpl.ArrayMaker`](@ref).
"""
Expand Down
Loading