@@ -56,7 +56,7 @@ mutable struct CodegenState{T}
5656 In such a case, if we assume that the index of `x + y` in `ir` is `4`, then
5757 `cache[4] == Symbol("%3")`.
5858 """
59- const cache:: Dictionary{Int32, Symbol}
59+ const cache:: Dictionary{Int32, Union{ Symbol, Int, Float64, Float32} }
6060 """
6161 Rewrite rules, similar to `NameState`.
6262 """
@@ -78,7 +78,7 @@ written to `block` and `ir` is the underlying `IRStructure`. Rewrite rules can o
7878be supplied as the last argument.
7979"""
8080function CodegenState (expr:: Expr , block:: Expr , ir:: IRStructure{T} , rewrites = Dict ()) where {T}
81- CodegenState {T} (expr, block, ir, Dictionary {Int32, Symbol} (), rewrites, 0 )
81+ CodegenState {T} (expr, block, ir, Dictionary {Int32, Union{ Symbol, Int, Float64, Float32} } (), rewrites, 0 )
8282end
8383
8484"""
@@ -199,7 +199,7 @@ function enter_scope(cs::CodegenState{T}) where {T}
199199 new_scope = Expr (:block )
200200 bm = bookmark (cs)
201201 scoped_cs = CodegenState {T} (new_scope, new_scope, cs. ir, cs. cache, cs. rewrites, cs. misc_idx)
202-
202+
203203 return scoped_cs, bm
204204end
205205
312312function fast_toexpr (sym:: CodegenPrimitive , ir:: IRStructure{T} , rewrites:: Dict{Any, Any} ) where {T}
313313 expr = block = Expr (:block )
314314 state = CodegenState (expr, block, ir, rewrites)
315- lhs = state (sym):: Symbol
315+ lhs = state (sym)
316+ if ! (lhs isa Symbol)
317+ return lhs
318+ end
316319 for line in expr. args
317320 if Meta. isexpr (line, :(= )) && line. args[1 ] === lhs
318321 return line. args[2 ]
@@ -535,7 +538,7 @@ function codegen_function!(::Type{ArrayMaker{T}}, cs::CodegenState{T}, expr::Bas
535538 end
536539 return declare! (cs, get_misc_identifier (cs), result)
537540 end
538-
541+
539542 if _allocator != = zeros && ! __allocator_is_returns_expr (T, _allocator) &&
540543 isequal (regions[1 ], sh) && __is_fill_zero (cs. ir[first (values_exprs_idxs)])
541544 output_buffer = codegen_allocator_call! (
853856
854857function codegen_ir! (cs:: CodegenState{T} , idx:: Integer ) where {T}
855858 cached = get (cs. cache, idx, nothing )
856- if cached isa Symbol
859+ if cached != = nothing
857860 return cached
858861 end
859862 ir = cs. ir
@@ -875,6 +878,9 @@ function codegen_ir!(cs::CodegenState{T}, idx::Integer) where {T}
875878 @match sym begin
876879 BSImpl. Const (; val) => if val isa CodegenPrimitive
877880 return cs (val)
881+ elseif val isa Union{Int, Float64, Float32}
882+ insert! (cs. cache, idx, val)
883+ return val
878884 else
879885 return codegen! (cs, idx, val)
880886 end
@@ -933,10 +939,13 @@ function (cs::CodegenState)(@nospecialize(thing))
933939 if uthing != = thing
934940 return cs (uthing)
935941 end
942+ if thing isa Union{Int, Float64, Float32}
943+ return thing
944+ end
936945 return declare! (cs, get_misc_identifier (cs), thing)
937946end
938947
939- function (cs:: CodegenState )(expr:: BasicSymbolic{T} ):: Symbol where {T}
948+ function (cs:: CodegenState )(expr:: BasicSymbolic{T} ) where {T}
940949 idx = populate_ir! (cs. ir, expr)
941950 codegen_ir! (cs, idx)
942951end
@@ -1070,9 +1079,9 @@ function (cs::CodegenState{T})(fn::Func) where {T}
10701079end
10711080
10721081function (cs:: CodegenState )(ex:: SetArray )
1073- arr = cs (ex. arr):: Symbol
1082+ arr = cs (ex. arr):: Union{ Symbol, Int, Float64, Float32}
10741083 lhss = []
1075- rhss = Symbol[]
1084+ rhss = Union{ Symbol, Int, Float64, Float32} []
10761085 for (i, elem) in enumerate (ex. elems)
10771086 if elem isa AtIndex
10781087 push! (lhss, cs (elem. i))
0 commit comments