-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy pathsubstitute.jl
More file actions
834 lines (724 loc) · 30 KB
/
Copy pathsubstitute.jl
File metadata and controls
834 lines (724 loc) · 30 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
"""
Substituter{Fold}
An abstract supertype for functors that perform substitution operations on symbolic
expressions. This can also be used as a constructor for the functor used by
[`substitute`](@ref). `Fold` corresponds to the `fold` keyword of `substitute`. Passing
`fold = Val(true)` corresponds to `Substituter{true}` (and similarly for `Val(false)`). To
define substitution rules for custom types that wrap/contain [`BasicSymbolic`](@ref),
define methods for this abstract type. For example,
```julia
struct Equation
lhs::BasicSymbolic{SymReal}
rhs::BasicSymbolic{SymReal}
end
function (subst::Substituter)(eq::Equation)
return Equation(subst(eq.lhs), subst(eq.rhs))
end
```
Custom substitution algorithms should define functors that subtype `Substituter`. For
example, a functor may be defined for substituting until the expression reaches a fixpoint.
These functors should then implement:
- `(s::Substituter{Fold})(ex::BasicSymbolic{T}) where {T}` to perform the appropriate
substitution on the given symbolic expression.
- `get_substitution_dict(::Substituter)` returning an `AbstractDict` of the substitution rules.
Instead of repeatedly calling `substitute` with the same rules, it is usually more
efficient to build a `Substituter` and reuse it.
`Substituter` is also allowed to cache intermediate results as necessary. When
constructing `Substituter` with an `AbstractDict`, it will alias the provided
mapping. Mutating the map such that the identity of the substitution rules
changes invalidates the substituter. It can be reused by clearing the
cache using [`SymbolicUtils.clear_cache!`](@ref).
The caching is only available when the [`SymbolicUtils.vartype`](@ref) of the
expressions is inferable from the substitution rules, or explicitly specified.
As long as either the keys or values of the substitution rules are all
`BasicSymbolic{T}` (for some `T`) the automatic inference will work. To allow
the inference to work for your custom wrapper type, implement
[`SymbolicUtils.infer_vartype`](@ref). For example:
```julia
struct Num <: Real
inner::BasicSymbolic{SymReal}
end
SymbolicUtils.infer_vartype(::Type{Num}) = SymReal
```
Alternatively, the vartype can be provided as the third positional argument
to the `Substituter` constructor.
# Extended help
The following is internal details of `Substituter` and should not be relied on
as public API.
## Fields
$TYPEDFIELDS
"""
abstract type Substituter{Fold} end
# Concrete `Substituter` produced by the `Substituter{Fold}` constructor and by
# `substitute`. Type parameters: `Fold` (whether constant folding is performed), `D` (type
# of the rules dictionary), `F` (type of the filter function) and `C` (type of the cache).
struct DefaultSubstituter{Fold, D <: AbstractDict, F, C} <: Substituter{Fold}
"""
The `AbstractDict` of substitution rules.
"""
dict::D
"""
The filter function to eliminate trees that do not need substitution. It is called on a
call expression and substitution only descends into the expression when it returns `true`.
"""
filter::F
"""
Cache of intermediate results, or `nothing` when no cache is used (the constructors store
`nothing` when the vartype is `Nothing`).
"""
cache::C
end
"""
$TYPEDSIGNATURES
Clear the cached values associated with `subst`. See the documentation of
[`SymbolicUtils.Substituter`](@ref) for more details.
"""
function clear_cache!(subst::DefaultSubstituter)
empty!(subst.cache)
end
"""
$TYPEDSIGNATURES
Get an `AbstractDict` of the substitution rules for the given
[`SymbolicUtils.Substituter`](@ref).
"""
get_substitution_dict(s::DefaultSubstituter) = s.dict
# Infer the vartype from a value by deferring to the method for its type.
function infer_vartype(x)
    return infer_vartype(typeof(x))
end

# Fallback: no vartype can be inferred for an arbitrary type.
function infer_vartype(::Type{T}) where {T}
    return Nothing
end

# For dictionaries, the key type takes precedence; the value type is consulted only when
# nothing can be inferred from the keys.
function infer_vartype(::Type{D}) where {K, V, D <: AbstractDict{K, V}}
    from_keys = infer_vartype(K)
    from_keys === Nothing || return from_keys
    return infer_vartype(V)
end
# The vartype of a `BasicSymbolic{T}` is its type parameter.
function infer_vartype(::Type{BasicSymbolic{T}}) where {T}
    return T
end
# Kept mostly for backward compatibility: constructing the abstract type builds the
# default implementation.
function Substituter{Fold}(args...) where {Fold}
    return DefaultSubstituter{Fold}(args...)
end

"""
$METHODLIST
"""
@inline DefaultSubstituter{Fold}(d) where {Fold} = DefaultSubstituter{Fold}(d, default_substitute_filter)

# Vartype `Nothing`: no cache is created.
@inline function DefaultSubstituter{Fold}(d::AbstractDict, filter::F, ::Type{Nothing}) where {Fold, F}
    return DefaultSubstituter{Fold, typeof(d), F, Nothing}(d, filter, nothing)
end

# Known vartype `T`: create an (initially empty) cache keyed on object identity.
@inline function DefaultSubstituter{Fold}(d::AbstractDict, filter::F, ::Type{T}) where {Fold, F, T}
    # Since `substitute` retains metadata, this needs to be an `IdDict`. Otherwise, keys
    # with different metadata get cached to the same result.
    CacheT = IdDict{BasicSymbolic{T}, BasicSymbolic{T}}
    return DefaultSubstituter{Fold, typeof(d), F, CacheT}(d, filter, CacheT())
end

# Without an explicit vartype, infer it from the dictionary.
@inline function DefaultSubstituter{Fold}(d::AbstractDict, filter) where {Fold}
    vartype = infer_vartype(d)
    return DefaultSubstituter{Fold}(d, filter, vartype)
end

# A single pair of rules is collected into an `OrderedDict` first.
@inline function DefaultSubstituter{Fold}(d::Pair, filter::F) where {Fold, F}
    rules = OrderedDict(d)
    return DefaultSubstituter{Fold}(rules, filter)
end

# An array of pairs is collected into an `OrderedDict` first.
@inline function DefaultSubstituter{Fold}(d::AbstractArray{<:Pair}, filter::F) where {Fold, F}
    rules = OrderedDict(d)
    return DefaultSubstituter{Fold}(rules, filter)
end
# Generic fallback: look the value up in the rules, returning it unchanged if absent.
function (s::Substituter)(ex)
    rules = get_substitution_dict(s)
    return get(rules, ex, ex)
end

# Arrays are substituted element by element.
function (s::Substituter)(ex::AbstractArray)
    return [s(el) for el in ex]
end

# Sparse matrices: substitute only the stored values and rebuild with the same structure.
function (s::Substituter)(ex::SparseMatrixCSC)
    rows, cols, vals = findnz(ex)
    newvals = s(vals)
    nrows, ncols = size(ex)
    return sparse(rows, cols, newvals, nrows, ncols)
end
# Substitute into the symbolic expression `ex` using the rules in `s.dict`.
#
# Outline:
# 1. If `ex` itself is a key of the rules, return the (unwrapped) replacement wrapped via
#    `Const{T}`.
# 2. Return `ex` unchanged if it is not a call, or if `s.filter` rejects it.
# 3. Return the cached result, if a cache is present and holds one.
# 4. Otherwise substitute into the operation and each argument, rebuild the expression
#    when something changed (or when everything is constant and can be folded), and store
#    the result in the cache.
function (s::DefaultSubstituter{Fold})(ex::BasicSymbolic{T}) where {T, Fold}
# Direct hit: the whole expression is a key of the substitution rules.
result = unwrap(get(s.dict, ex, nothing))
result === nothing || return Const{T}(result)
# Non-call expressions without a direct rule are left as they are.
iscall(ex) || return ex
# The filter decides whether substitution may descend into this expression.
s.filter(ex) || return ex
if s.cache !== nothing
result = get(s.cache, ex, nothing)
result === nothing || return result
end
op = operation(ex)
# We need to `unwrap_const` because `op` could be a symbolic function with
# a substitution in `s.dict`, in which case this method will be called recursively
# and return a symbolic result.
_op = unwrap_const(s(op))
args = arguments(ex)
# `dirty` records whether the operation or any argument changed.
dirty = false
# Folding is only considered when the (substituted) operation is not symbolic; it is
# further restricted below to the case where every argument is constant.
can_fold = !(_op isa BasicSymbolic{T})
# Start by aliasing the original argument storage; it is copied on the first change.
newargs = parent(args)
for i in eachindex(args)
arg = args[i]
# This is `unsafe` because comparing `Const` equality is cheap, and we don't want to
# pay the cost of hashconsing if it is identical.
newarg = Const{T}(s(arg); unsafe = true)
# Unchanged argument: only update the folding flag and move on.
if arg === newarg || @manually_scope COMPARE_FULL => true isequal(arg, newarg)::Bool true
can_fold &= isconst(arg)
continue
end
newarg = hashcons(newarg)
# Copy the argument list once, right before the first modification.
if !dirty
newargs = copy(parent(args))::ArgsT{T}
end
dirty = true
can_fold &= isconst(newarg)
newargs[i] = newarg
end
# A changed operation also requires rebuilding the expression.
dirty |= op !== _op
result = if dirty || can_fold
if Fold
combine_fold(T, _op, newargs, metadata(ex), can_fold)::BasicSymbolic{T}
else
maketerm(BasicSymbolic{T}, _op, newargs, metadata(ex))::BasicSymbolic{T}
end
else
ex
end
if s.cache !== nothing
s.cache[ex] = result
end
return result
end
# Build the result of applying `op` to `args`. When `can_fold` is `true`, `op` is called
# on the unwrapped arguments and the value is wrapped with `Const{T}`; otherwise a term is
# constructed with `maketerm`, carrying over `meta`. The one-, two- and three-argument
# cases are spelled out to avoid splatting.
function combine_fold(::Type{T}, op, args::Union{ROArgsT{T}, ArgsT{T}}, meta::MetadataT, can_fold::Bool) where {T}
    @nospecialize op args meta
    can_fold || return maketerm(BasicSymbolic{T}, op, args, meta)::BasicSymbolic{T}
    nargs = length(args)
    folded = if nargs == 1
        op(unwrap_const(args[1]))
    elseif nargs == 2
        op(unwrap_const(args[1]), unwrap_const(args[2]))
    elseif nargs == 3
        op(unwrap_const(args[1]), unwrap_const(args[2]), unwrap_const(args[3]))
    else
        op(unwrap_const.(args)...)
    end
    return Const{T}(folded)
end
"""
$TYPEDSIGNATURES
The default filter function used by [`substitute`](@ref) to determine whether to substitute
within an expression. Returns `false` for expressions that are `Term`s with an `Operator` as
the operation (preventing substitution within operator calls), and `true` otherwise.
# Arguments
- `ex::BasicSymbolic{T}`: The expression to check.
# Returns
- `Bool`: `false` if the expression should not be substituted into, `true` otherwise.
"""
@inline function default_substitute_filter(ex::BasicSymbolic{T}) where {T}
@match ex begin
BSImpl.Term(; f) && if f isa Operator end => false
BSImpl.Term(; f) && if f isa BasicSymbolic{T} end => is_function_symbolic(f)
_ => true
end
end
"""
substitute(expr, dict; fold=Val(false))
substitute any subexpression that matches a key in `dict` with
the corresponding value. If `fold=Val(false)`,
expressions which can be evaluated won't be evaluated.
```julia-repl
julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val(true))
2.414213562373095
julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val(false))
1 + sqrt(2)
```
"""
function substitute(expr, dict; fold::Val{Fold}=Val{false}(), filterer=default_substitute_filter) where {Fold}
# This is kind of ugly (inlines some of the constructor logic of `DefaultSubstituter` but is needed to avoid runtime subtyping in
# when calling this function. It makes a very big difference in runtime.
d = dict isa AbstractDict ? dict : OrderedDict(dict)
isempty(d) && !Fold && return expr
VT = infer_vartype(d)
if VT === Nothing
sub = DefaultSubstituter{Fold, typeof(d), typeof(filterer), Nothing}(d, filterer, nothing)
else
cache = IdDict{BasicSymbolic{VT}, BasicSymbolic{VT}}()
sub = DefaultSubstituter{Fold, typeof(d), typeof(filterer), typeof(cache)}(d, filterer, cache)
end
return sub(expr)
end
# Shared empty rule set used by `evaluate`.
const EMPTY_DICT = Dict{Int, Int}()

# Run `substitute` on `expr` with no substitution rules and folding enabled, forwarding
# the `filterer` keyword.
evaluate(expr; filterer = default_substitute_filter) =
    substitute(expr, EMPTY_DICT; filterer = filterer, fold = Val{true}())
"""
$TYPEDSIGNATURES
Recursively search an expression tree to determine if any subexpression satisfies a given
predicate. This function traverses the expression tree and returns `true` if the predicate
returns `true` for any node in the tree.
# Arguments
- `predicate::F`: A function that takes an expression and returns a `Bool`.
- `expr::BasicSymbolic`: The expression to search.
# Keyword Arguments
- `recurse::G=iscall`: A function determining whether to recurse into a subexpression.
- `default::Bool=false`: The default value to return if the expression is not a call or
recursion is prevented.
# Returns
- `Bool`: `true` if any subexpression satisfies the predicate, `false` otherwise.
"""
function query(predicate::F, expr::BasicSymbolic; recurse::G = iscall, default::Bool = false) where {F, G}
predicate(expr) && return true
iscall(expr) || return default
recurse(expr) || return default
return @match expr begin
BSImpl.Term(; f, args) => any(args) do arg
query(predicate, arg; recurse, default)
end
BSImpl.AddMul(; dict) => any(keys(dict)) do arg
query(predicate, arg; recurse, default)
end
BSImpl.Div(; num, den) => query(predicate, num; recurse, default) || query(predicate, den; recurse, default)
BSImpl.ArrayOp(; expr = inner_expr, term) => begin
query(predicate, @something(term, inner_expr); recurse, default)
end
BSImpl.ArrayMaker(; values) => any(values) do arg
query(predicate, arg; recurse, default)
end
end
end
query(predicate::F, expr; kw...) where {F} = predicate(expr)
# Fallback for values that are not handled by a more specific method: nothing is added to
# `buffer`.
function search_variables!(buffer, expr; kw...)
    return nothing
end
"""
$(TYPEDSIGNATURES)
The default `is_atomic` predicate for [`search_variables!`](@ref). `ex` is considered
atomic if one of the following conditions is true:
- It is a `Sym` and not an internal index variable for an arrayop
- It is a `Term`, the operation is a `BasicSymbolic` and the operation represents a
dependent variable according to [`is_function_symbolic`](@ref).
- It is a `Term`, the operation is `getindex` and the variable being indexed is atomic.
"""
function default_is_atomic(ex::BasicSymbolic{T}) where {T}
@match ex begin
BSImpl.Sym(; name) => name !== IDXS_SYM
BSImpl.Term(; f) && if f isa Operator end => true
BSImpl.Term(; f) && if f isa BasicSymbolic{T} end => !is_function_symbolic(f)
BSImpl.Term(; f, args) && if f === getindex end => default_is_atomic(args[1])
_ => false
end
end
"""
$(TYPEDSIGNATURES)
Find all variables used in `expr` and add them to `buffer`. A variable is identified by the
predicate `is_atomic`. The predicate `recurse` determines whether to search further inside
`expr` if it is not a variable. Note that `recurse` must at least return `false` if
`iscall` returns `false`.
Wrappers for [`BasicSymbolic`](@ref) should implement this function by unwrapping.
See also: [`default_is_atomic`](@ref).
"""
function search_variables!(buffer, expr::BasicSymbolic{T}; is_atomic::F = default_is_atomic, recurse::G = iscall, _seen::Union{Base.IdSet{BasicSymbolic{T}}, Nothing} = nothing) where {T, F, G}
if is_atomic(expr)
push!(buffer, expr)
return nothing
end
recurse(expr) || return nothing
seen = _seen === nothing ? Base.IdSet{BasicSymbolic{T}}() : _seen
_search_variables_impl!(buffer, expr, is_atomic, recurse, seen)
return nothing
end
# Recursive worker behind `search_variables!`. `seen` holds the nodes already visited
# (compared by identity), so each distinct node is processed at most once per traversal.
# Atomic expressions are pushed to `buffer`; other expressions are descended into only if
# `recurse` allows it. Always returns `nothing`.
function _search_variables_impl!(buffer, expr::BasicSymbolic{T}, is_atomic::F, recurse::G, seen::Base.IdSet{BasicSymbolic{T}}) where {T, F, G}
# Skip nodes that were already handled.
expr in seen && return nothing
push!(seen, expr)
if is_atomic(expr)
push!(buffer, expr)
return nothing
end
recurse(expr) || return nothing
@match expr begin
BSImpl.Term(; f, args) => begin
# A symbolic operation may itself contain variables.
if f isa BasicSymbolic{T}
_search_variables_impl!(buffer, f, is_atomic, recurse, seen)
end
for arg in args
_search_variables_impl!(buffer, arg, is_atomic, recurse, seen)
end
end
BSImpl.AddMul(; dict) => begin
# Only the keys of the dictionary are searched.
for k in keys(dict)
_search_variables_impl!(buffer, k, is_atomic, recurse, seen)
end
end
BSImpl.Div(; num, den) => begin
_search_variables_impl!(buffer, num, is_atomic, recurse, seen)
_search_variables_impl!(buffer, den, is_atomic, recurse, seen)
end
BSImpl.ArrayOp(; expr = inner_expr, term) => begin
# Search `term` when it is set, otherwise the inner expression.
_search_variables_impl!(buffer, @something(term, inner_expr), is_atomic, recurse, seen)
end
BSImpl.ArrayMaker(; values) => @union_split_smallvec values for val in values
_search_variables_impl!(buffer, val, is_atomic, recurse, seen)
end
end
return nothing
end
# Arrays and sets: search every element after unwrapping it.
function search_variables!(buffer, expr::Union{AbstractArray, AbstractSet}; kw...)
    foreach(expr) do el
        search_variables!(buffer, unwrap(el); kw...)
    end
    return nothing
end

# Sparse matrices: only the stored values are searched.
function search_variables!(buffer, expr::SparseMatrixCSC; kw...)
    stored = findnz(expr)[3]
    return search_variables!(buffer, stored; kw...)
end
# Buffer used by `search_variables` for a symbolic expression: an ordered set of symbolics
# of the same vartype.
function _default_buffer(::BasicSymbolic{T}) where {T}
    return OrderedSet{BasicSymbolic{T}}()
end

# For any other value, unwrap it and retry; values that unwrap to themselves get an
# untyped `Set`.
function _default_buffer(x::Any)
    inner = unwrap(x)
    return inner === x ? Set() : _default_buffer(inner)
end

# Non-mutating counterpart of `search_variables!`: allocates a buffer suited to `expr`,
# fills it and returns it.
function search_variables(expr; kw...)
    found = _default_buffer(expr)
    search_variables!(found, expr; kw...)
    return found
end
# Scratch buffers used by `_reduce_eliminated_idxs`, bundled so they can be emptied and
# reused instead of being reallocated for every reduction.
struct ArrayOpReduceCache{T}
# Ranges filled in by `Code.unidealize_indices` and then merged with the caller's ranges.
new_ranges::RangesT{T}
# Substitution rules mapping each collapsed index to its current integer value.
subrules::OrderedDict{BasicSymbolic{T}, Int}
# Indices that have a range but are not part of the output index list.
collapsed_idxs::OrderedSet{BasicSymbolic{T}}
# The ranges of `collapsed_idxs`, in the same order.
collapsed_ranges::Vector{StepRange{Int, Int}}
end
# Construct a cache whose buffers are all empty.
function ArrayOpReduceCache{T}() where {T}
ArrayOpReduceCache{T}(RangesT{T}(), OrderedDict{BasicSymbolic{T}, Int}(), OrderedSet{BasicSymbolic{T}}(), StepRange{Int, Int}[])
end
# Empty every buffer of the cache in place and return the cache itself.
function Base.empty!(x::ArrayOpReduceCache)
    (; new_ranges, subrules, collapsed_idxs, collapsed_ranges) = x
    empty!(collapsed_ranges)
    empty!(collapsed_idxs)
    empty!(subrules)
    empty!(new_ranges)
    return x
end
# One reusable `ArrayOpReduceCache` per supported vartype, stored in a `TaskLocalValue`
# (presumably one instance per task — confirm against the `TaskLocalValue` definition).
const ARRAYOP_REDUCE_SYMREAL = TaskLocalValue{ArrayOpReduceCache{SymReal}}(ArrayOpReduceCache{SymReal})
const ARRAYOP_REDUCE_SAFEREAL = TaskLocalValue{ArrayOpReduceCache{SafeReal}}(ArrayOpReduceCache{SafeReal})
# Fetch the cache for the given vartype. The cache is emptied before it is returned, so
# callers always start from clean buffers.
arrayop_reduce_cache(::Type{SymReal}) = empty!(ARRAYOP_REDUCE_SYMREAL[])
arrayop_reduce_cache(::Type{SafeReal}) = empty!(ARRAYOP_REDUCE_SAFEREAL[])
# Uncached implementation of `reduce_eliminated_idxs`. Every index that has a range but
# does not appear in `output_idx` is collapsed: the expression is substituted for each
# combination of values of those indices and the results are combined with `reduce`.
#
# The scratch buffers come from `arrayop_reduce_cache(T)`, which hands out a shared,
# freshly-emptied cache.
function _reduce_eliminated_idxs(expr::BasicSymbolic{T}, output_idx::OutIdxT{T}, ranges::RangesT{T}, @nospecialize(reduce)) where {T}
cache = arrayop_reduce_cache(T)
new_ranges = cache.new_ranges
subrules = cache.subrules
# `unidealize_indices` receives `new_ranges` as a buffer (presumably filling in ranges it
# determines — confirm in `Code`); the explicitly provided ranges are merged in afterwards.
new_expr = Code.unidealize_indices(expr, ranges, new_ranges)
merge!(new_ranges, ranges)
# collapsed = (all indices with a range) \ (output indices)
collapsed = cache.collapsed_idxs
union!(collapsed, keys(new_ranges))
setdiff!(collapsed, output_idx)
collapsed_ranges = cache.collapsed_ranges
for i in collapsed
push!(collapsed_ranges, new_ranges[i])
end
# Iterate over the cartesian product of the collapsed ranges; `subrules` is overwritten
# in place for every combination before substituting.
return mapreduce(reduce, Iterators.product(collapsed_ranges...)) do iidxs
for (idx, ii) in zip(iidxs, collapsed)
subrules[ii] = idx
end
return substitute(new_expr, subrules)::BasicSymbolic{T}
end::BasicSymbolic{T}
end
# Wrapper that makes a value hash and compare by object identity rather than by content.
# `h` stores the `objectid` of `val` at construction time.
struct IdHashWrapper{T}
    val::T
    h::UInt
end

# Wrap `x`, recording its `objectid` as the hash seed.
function IdHashWrapper(x)
    return IdHashWrapper(x, objectid(x))
end

# Hash by the stored object id.
function Base.hash(x::IdHashWrapper, h::UInt)
    return hash(x.h, h)
end

# Two wrappers are equal only when they wrap the very same object.
function Base.isequal(a::IdHashWrapper, b::IdHashWrapper)
    return a.val === b.val
end
# Entry points to `_reduce_eliminated_idxs` wrapped in `@cache`, one per vartype
# (`SymReal` and `SafeReal`). NOTE(review): `@cache` presumably memoizes on the
# arguments — confirm against the macro definition. The ranges are passed inside an
# `IdHashWrapper`, which hashes and compares them by object identity.
@cache function reduce_eliminated_idxs_1(expr::BasicSymbolic{SymReal}, output_idx::OutIdxT{SymReal}, ranges_wrapped::IdHashWrapper{RangesT{SymReal}}, reduce)::BasicSymbolic{SymReal}
@nospecialize reduce
_reduce_eliminated_idxs(expr, output_idx, ranges_wrapped.val, reduce)::BasicSymbolic{SymReal}
end
@cache function reduce_eliminated_idxs_2(expr::BasicSymbolic{SafeReal}, output_idx::OutIdxT{SafeReal}, ranges_wrapped::IdHashWrapper{RangesT{SafeReal}}, reduce)::BasicSymbolic{SafeReal}
@nospecialize reduce
_reduce_eliminated_idxs(expr, output_idx, ranges_wrapped.val, reduce)::BasicSymbolic{SafeReal}
end
"""
$TYPEDSIGNATURES
Reduce (collapse) all indices in an `ArrayOp` expression that are not present in the output
index list. This function performs a reduction operation (like summation) over the eliminated
indices by substituting concrete values for each eliminated index and combining the results
using the provided reduction function.
# Arguments
- `expr::BasicSymbolic{T}`: The expression containing indices to be reduced.
- `output_idx::OutIdxT{T}`: The indices that should remain in the output (not reduced).
- `ranges::RangesT{T}`: A dictionary mapping indices to their ranges.
- `reduce`: The reduction function to apply (e.g., `+` for summation).
# Returns
- `BasicSymbolic{T}`: The expression with eliminated indices reduced according to the
reduction function.
"""
function reduce_eliminated_idxs(expr::BasicSymbolic{T}, output_idx::OutIdxT{T}, ranges::RangesT{T}, @nospecialize(reduce)) where {T}
wrapped = IdHashWrapper(ranges)
if T === SymReal
return reduce_eliminated_idxs_1(expr, output_idx, wrapped, reduce)::BasicSymbolic{T}
elseif T === SafeReal
return reduce_eliminated_idxs_2(expr, output_idx, wrapped, reduce)::BasicSymbolic{T}
end
_unreachable()
end
"""
$TYPEDSIGNATURES
Given a function `f`, return a function that will scalarize an expression with `f` as the
head. The returned function is passed `f`, the expression with `f` as the head, and
`Val(true)` or `Val(false)` indicating whether to recursively scalarize or not.
This function provides a dispatch mechanism for customizing scalarization behavior based on
the operation type. Different operations may require different scalarization strategies
(e.g., array operations, determinants, indexing operations).
# Arguments
- `f`: The function/operation to get a scalarization function for.
# Returns
- A function that takes `(f, x::BasicSymbolic{T}, ::Val{toplevel})` and returns the
scalarized form of `x`.
"""
scalarization_function(@nospecialize(_)) = _default_scalarize
scalarization_function(::typeof(+)) = _scalarize_add

# Scalarize a sum: scalarize and unwrap every term, then add the results back together.
function _scalarize_add(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    scalarize_term = unwrap_const ∘ Base.Fix2(scalarize, Val{toplevel}())
    terms = map(scalarize_term, arguments(x))
    return reduce(+, terms)
end
scalarization_function(::typeof(*)) = _scalarize_mul

# Scalarize a product. Three cases:
# - scalar result: hand the scalarized factors to `mul_worker`;
# - array result whose first factor is an array: multiply all scalarized factors;
# - array result with a scalar first factor: multiply the remaining factors, then scale
#   each element by that first factor.
function _scalarize_mul(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    args = arguments(x)
    scal_args = map(Base.Fix2(scalarize, Val{toplevel}()), args)
    if !is_array_shape(shape(x))
        return mul_worker(T, scal_args)
    elseif is_array_shape(shape(args[1]))
        return mapreduce(unwrap_const, *, scal_args)
    end
    coeff = scal_args[1]
    rest = mapreduce(unwrap_const, *, @view(scal_args[2:end]))
    return map(Base.Fix2(*, coeff), rest)
end
scalarization_function(::typeof(broadcast)) = _scalarize_broadcast

# Scalarize a `broadcast` call: scalarize and unwrap every argument, then perform the
# broadcast on the results. Arguments that are still symbolic are wrapped in `Ref` so
# that they are treated as scalars by `broadcast`.
function _scalarize_broadcast(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    args = arguments(x)
    scal_args = Vector{Any}(undef, length(args))
    map!(unwrap_const ∘ Base.Fix2(scalarize, Val{toplevel}()), scal_args, args)
    for (i, val) in pairs(scal_args)
        val isa BasicSymbolic{T} || continue
        # A symbolic value remaining after scalarization is expected to be a scalar.
        @assert !is_array_shape(shape(val))
        scal_args[i] = Ref(val)
    end
    return broadcast(scal_args...)
end
scalarization_function(::Union{typeof(adjoint), typeof(transpose)}) = _scalarize_adjoint_transpose

# Scalarize the single argument into an array of symbolics and apply `adjoint` or
# `transpose` (whichever `f` is) to that array.
function _scalarize_adjoint_transpose(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    inner = scalarize(arguments(x)[1], Val{toplevel}())::Array{BasicSymbolic{T}}
    f === adjoint && return adjoint(inner)
    f === transpose && return transpose(inner)
    _unreachable()
end
scalarization_function(::typeof(/)) = _scalarize_rdiv

# Scalarize `num / den` after scalarizing both operands:
# - array denominator: compute the quotient and collect its elements;
# - array numerator with scalar denominator: divide every element by the denominator;
# - both scalar: plain division.
function _scalarize_rdiv(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    args = arguments(x)
    num = scalarize(args[1], Val{toplevel}())
    den = scalarize(args[2], Val{toplevel}())
    num_is_array = is_array_shape(shape(num))
    den_is_array = is_array_shape(shape(den))
    if den_is_array
        quotient = num / den
        return BasicSymbolic{T}[quotient[i] for i in eachindex(quotient)]
    elseif num_is_array
        return map(Base.Fix2(/, den), num)
    else
        return num / den
    end
end
scalarization_function(::typeof(\)) = _scalarize_ldiv

# Scalarize `fst \ lst` after scalarizing both operands:
# - array left operand: the expression is not expanded, its elements are collected by
#   indexing `x`;
# - scalar left operand with array right operand: divide every element by the left operand;
# - both scalar: plain left division.
function _scalarize_ldiv(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    args = arguments(x)
    fst = scalarize(args[1], Val{toplevel}())
    lst = scalarize(args[2], Val{toplevel}())
    fst_is_array = is_array_shape(shape(fst))
    lst_is_array = is_array_shape(shape(lst))
    if fst_is_array
        return BasicSymbolic{T}[x[i] for i in eachindex(x)]
    elseif lst_is_array
        return map(Base.Fix2(/, fst), lst)
    else
        return fst \ lst
    end
end
scalarization_function(::Union{typeof(-), typeof(^)}) = _default_scalarize_array

# Re-apply `f` to the unwrapped arguments of `x`. With `toplevel` the arguments are used
# as they are; otherwise each argument is scalarized first.
function _default_scalarize_array(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    prepare = toplevel ? unwrap_const : unwrap_const ∘ Base.Fix2(scalarize, Val{toplevel}())
    return f(map(prepare, arguments(x))...)
end
# Default scalarization. Array-shaped expressions are expanded by indexing every element.
# Scalar expressions are rebuilt by applying `f` to the unwrapped arguments, which are
# scalarized first unless `toplevel` is set.
function _default_scalarize(f, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @nospecialize f
    if is_array_shape(shape(x))
        return [x[idx] for idx in eachindex(x)]
    end
    prepare = toplevel ? unwrap_const : unwrap_const ∘ scalarize
    return f(map(prepare, arguments(x))...)
end
scalarization_function(::Type{ArrayOp{T}}) where {T} = _scalarize_arrayop

# Scalarize an `ArrayOp`. The indices that do not appear in the output are first reduced
# away with `reduce_eliminated_idxs`. Then, for every position of the result shape, the
# symbolic output indices are substituted by the concrete position (with constant folding
# enabled). Substitution does not descend into nested `ArrayOp`s. Unless `toplevel` is
# set, each resulting element is scalarized again.
#
# Returns an array of elements, or the single element itself when the shape is empty.
function _scalarize_arrayop(_, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @match x begin
        BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce, shape = sh) => begin
            # Reused (and overwritten) for every position of the output.
            subrules = OrderedDict()
            new_expr = reduce_eliminated_idxs(expr, output_idx, ranges, reduce)
            # Keep substitution out of nested `ArrayOp`s.
            scope_filter = function (y)
                @match y begin
                    BSImpl.ArrayOp(; output_idx=_) => false
                    _ => true
                end
            end
            res = map(Iterators.product(sh...)) do idxs
                for (i, ii) in enumerate(output_idx)
                    # Integer entries of the output index are not substituted.
                    ii isa Int && continue
                    subrules[ii] = idxs[i]
                end
                elem = substitute(new_expr, subrules; filterer = scope_filter, fold = Val{true}())
                toplevel ? elem : scalarize(elem)
            end
            return isempty(sh) ? res[] : res
        end
    end
end
scalarization_function(::Type{ArrayMaker{T}}) where {T} = _scalarize_arraymaker
# Scalarize an `ArrayMaker` by materializing it into an `Array{BasicSymbolic{T}}` of the
# same size as `x`: every (region, value) pair is scalarized and written into its region.
# Regions are translated so that the first index of each axis of the shape maps to 1.
# NOTE(review): `buffer` is allocated with `undef`; this assumes the regions together cover
# the whole array — confirm against how `ArrayMaker` is constructed.
function _scalarize_arraymaker(_, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
sh = shape(x)::ShapeVecT
buffer = Array{BasicSymbolic{T}}(undef, size(x))
# Scratch vector holding the translated ranges of the region currently being written.
regbuf = ShapeVecT()
sizehint!(regbuf, length(sh))
@match x begin
BSImpl.ArrayMaker(; regions, values) => begin
@union_split_smallvec regions @union_split_smallvec values @union_split_smallvec sh begin
for (reg, val) in zip(regions, values)
empty!(regbuf)
@union_split_smallvec reg for (ax, canonical_ax) in zip(reg, sh)
# Shift the region along this axis by the first index of the axis.
offset = first(canonical_ax)
push!(regbuf, (first(ax) - offset + 1):(last(ax) - offset + 1))
end
setindex!(buffer, scalarize(val, Val{toplevel}())::AbstractArray{BasicSymbolic{T}}, regbuf...)
end
end
end
end
return buffer
end
# Expressions whose operation is a `Mapper` or a `Mapreducer` are scalarized with the same
# routine as `ArrayOp`s.
scalarization_function(::Mapper) = _scalarize_arrayop
scalarization_function(::Mapreducer) = _scalarize_arrayop
# Scalarize `LinearAlgebra.norm(v)` / `LinearAlgebra.norm(v, p)`:
# - no `p`: square root of the sum of `abs2` of the elements;
# - constant infinite `p`: maximum (`+Inf`) or minimum (`-Inf`) of the absolute values.
#   Any real infinity is recognized, not only the `Float64` one;
# - constant even integer `p`: `sum(v .^ p) ^ (1 / p)` (no `abs` needed for even powers);
# - any other constant `p`: `sum(abs.(v) .^ p) ^ (1 / p)`;
# - non-constant `p`: `sum(v .^ p) ^ (1 / p)`.
function _scalarize_norm(_, x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    @match x begin
        BSImpl.Term(; args) => begin
            if length(args) == 1
                return sqrt(scalarize(sum(abs2, args[1]), Val{toplevel}())::BasicSymbolic{T})
            end
            @match args[2] begin
                BSImpl.Const(; val) => begin
                    if val isa Real && isinf(val)
                        if val > 0
                            return scalarize(mapreduce(abs, max, args[1]), Val{toplevel}())::BasicSymbolic{T}
                        else
                            return scalarize(mapreduce(abs, min, args[1]), Val{toplevel}())::BasicSymbolic{T}
                        end
                    elseif safe_isinteger(val::Number) && iseven(Int(val::Number))
                        return (sum(scalarize(args[1], Val{toplevel}()) .^ args[2])) ^ (1 / args[2])
                    else
                        return (sum(abs.(scalarize(args[1], Val{toplevel}())) .^ args[2])) ^ (1 / args[2])
                    end
                end
                _ => return (sum(scalarize(args[1], Val{toplevel}()) .^ args[2])) ^ (1 / args[2])
            end
        end
    end
end
scalarization_function(::typeof(LinearAlgebra.norm)) = _scalarize_norm
"""
$TYPEDSIGNATURES
Convert a symbolic expression with array operations into a fully scalarized form. This function
expands array operations into element-wise operations, converting symbolic array expressions
into arrays of scalar symbolic expressions.
For `ArrayOp` expressions, this function reduces eliminated indices and substitutes concrete
values for output indices to generate scalar expressions for each array element.
# Arguments
- `x::BasicSymbolic{T}`: The symbolic expression to scalarize.
- `::Val{toplevel}=Val{false}()`: Whether to evaluate constant expressions at the top level.
When `true`, constant subexpressions are evaluated; when `false`, they are recursively
scalarized.
# Returns
- The scalarized expression. For array-shaped expressions, returns an array of scalar
expressions. For scalar expressions, returns the expression unchanged or with recursively
scalarized subexpressions.
"""
function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T, toplevel}
sh = shape(x)
sh isa Unknown && return x
@match x begin
BSImpl.Const(; val) => begin
is_array_shape(sh) || return x
if val isa SparseMatrixCSC
I, J, V = findnz(val)
V = Const{T}.(V)
return sparse(I, J, V, size(val)::NTuple{2, Int}...)
else
return Const{T}.(val)
end
end
BSImpl.Sym(;) => is_array_shape(sh) ? [x[idx] for idx in eachindex(x)] : x
_ => begin
f = operation(x)
if f isa BasicSymbolic{T} || f isa Operator
return length(sh) == 0 ? x : [x[idx] for idx in eachindex(x)]
end
return scalarization_function(f)(f, x, Val{toplevel}())
end
end
end
# Arrays are scalarized element by element, forwarding the `toplevel` flag.
function scalarize(arr::AbstractArray, ::Val{toplevel} = Val{false}()) where {toplevel}
    level = Val{toplevel}()
    return map(el -> scalarize(el, level), arr)
end

# Anything else is already scalar as far as this function is concerned; extra arguments
# are ignored.
function scalarize(x, _...)
    return x
end
scalarization_function(::typeof(inv)) = _inv_scal

# `inv` is not expanded symbolically: an array-shaped result is split into its indexed
# elements, anything else is returned as is.
function _inv_scal(::typeof(inv), x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    sh = shape(x)
    if sh isa ShapeVecT && !isempty(sh)
        return [x[idx] for idx in eachindex(x)]
    end
    return x
end
scalarization_function(::typeof(LinearAlgebra.det)) = _det_scal

# Scalarize `det(A)`. With an unknown argument shape the expression is `collect`ed; with
# a scalar argument it is returned unchanged. Otherwise the argument is turned into a
# matrix of elements (`collect` when `toplevel`, `scalarize` otherwise) and the
# determinant is expanded by the matrix method of `_det_scal`.
function _det_scal(::typeof(LinearAlgebra.det), x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
    arg = arguments(x)[1]
    sh = shape(arg)
    if sh isa Unknown
        return collect(x)
    end
    sh = sh::ShapeVecT
    if isempty(sh)
        return x
    end
    elements = toplevel ? collect(arg) : scalarize(arg)
    return _det_scal(LinearAlgebra.det, T, elements)
end
# Determinant of a matrix of elements by cofactor expansion along the first column. A
# single-element matrix is its own determinant; otherwise the signed products of each
# first-column entry with the determinant of its minor are summed via `add_worker`.
function _det_scal(::typeof(LinearAlgebra.det), ::Type{T}, x::AbstractMatrix) where {T}
    length(x) == 1 && return x[]
    nrows, ncols = size(x)
    terms = BasicSymbolic{T}[]
    for row in 1:nrows
        # Minor: drop the current row and the first column.
        minor = view(x, setdiff(axes(x, 1), row), 2:ncols)
        minor_det = _det_scal(LinearAlgebra.det, T, minor)
        cofactor_sign = isodd(row) ? 1 : -1
        push!(terms, cofactor_sign * x[row, 1] * minor_det)
    end
    return add_worker(T, terms)
end
scalarization_function(::typeof(getindex)) = _getindex_scal
# Scalarize an indexing expression `x = getindex(arr, idxs...)`.
# - If the result is itself array-shaped, it is split into its indexed elements.
# - Otherwise the indexed array (`args[1]`) is scalarized and the element is looked up in
#   it, either through a `StableIndex{Int}` built from `x` or, when that construction
#   throws, through the unwrapped indices shifted by the first index of each axis.
function _getindex_scal(::typeof(getindex), x::BasicSymbolic{T}, ::Val{toplevel}) where {T, toplevel}
sh = shape(x)
if length(sh) > 0
return [x[idx] for idx in eachindex(x)]
end
args = MData.variant_getfield(x, BSImpl.Term, :args)
# NOTE(review): any exception thrown by the constructor is swallowed here and treated as
# "no stable index available" — presumably this covers non-constant indices; confirm.
idx = try
StableIndex{Int}(x)
catch
nothing
end
if idx !== nothing
return getindex(scalarize(args[1]), idx)
end
# Fallback: subtract (first index of each axis - 1) from every index so that the indices
# are 1-based into the scalarized array.
idxs = Iterators.map((-), Iterators.map(unwrap_const, Iterators.drop(args, 1)), Iterators.map(Base.Fix2((-), 1) ∘ first, shape(args[1])))
return getindex(scalarize(args[1]), idxs...)
end