Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicUtils"
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
authors = ["Shashi Gowda"]
version = "4.35.0"
version = "4.36.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
14 changes: 14 additions & 0 deletions docs/src/manual/codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,18 @@ called within the expression are pure. SymbolicUtils can and will change the num
```@docs
SymbolicUtils.Code.cse
SymbolicUtils.Code.cse_inside_expr
SymbolicUtils.Code.cse_bind_expr
```

### Conditionals

`ifelse` lowers to an `if`/`else` whose branches are still subject to CSE. Two variants pin
the evaluation strategy: `ifelse_eager` always evaluates both branches, while
`ifelse_branching` guarantees the untaken branch is never evaluated — its branch interiors
are excluded from CSE (the conditional itself is still bound, see
[`SymbolicUtils.Code.cse_bind_expr`](@ref)).

```@docs
SymbolicUtils.ifelse_eager
SymbolicUtils.ifelse_branching
```
26 changes: 25 additions & 1 deletion src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,22 @@ end
# conditional; the codegen for `ifelse_branching` emits them inside their `if`/`else` arms.
cse_inside_expr(sym, ::typeof(ifelse_branching)) = false

"""
$(TYPEDSIGNATURES)

Only consulted when [`cse_inside_expr`](@ref) returns `false` for `sym`, which has operation
`f`. Return `true` if CSE should still bind `sym` itself to a temporary variable — so that
multiple references to it share a single computation — while leaving its arguments un-CSEd.
The default `false` leaves the expression fully inline at each reference site.
"""
function cse_bind_expr(sym, f)
return false
end

# Bind the conditional itself so multiple references share one `if`/`else` instead of
# duplicating it per use site. The branches remain un-CSEd (and therefore lazy).
cse_bind_expr(sym, ::typeof(ifelse_branching)) = true

"""
$(TYPEDEF)

Expand Down Expand Up @@ -1737,7 +1753,15 @@ function _cse_compute(expr::BasicSymbolic{T}, state::CSEState) where {T}
if op isa BasicSymbolic{T}
SymbolicUtils.is_function_symbolic(op) || return expr
end
cse_inside_expr(expr, op)::Bool || return expr
if !(cse_inside_expr(expr, op)::Bool)
cse_bind_expr(expr, op)::Bool || return expr
# Bind the node as-is (arguments un-CSEd) so references share one
# computation. `cse!` caches the binding by node id, so every reference
# in this scope resolves to the same temporary.
sym = newsym!(state, T, symtype(expr), shape(expr))
push!(state.sorted_exprs, sym ← expr)
return sym
end
args = copy(parent(args))
for i in eachindex(args)
args[i] = cse!(args[i], state)::BasicSymbolic{T}
Expand Down
9 changes: 4 additions & 5 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,11 +717,10 @@ produces `NaN`/`Inf`). Contrast with [`ifelse_eager`](@ref), which evaluates bot
`ifelse`, `ifelse_eager` and `ifelse_branching` differ only in how code generation lowers them;
`ifelse` is the default and is intended to eventually pick a strategy via a cost heuristic.

!!! note
Opting the conditional out of CSE also means the `ifelse_branching` node itself is not
bound to a shared temporary, so when its result is referenced at multiple sites the
generated `if`/`else` is duplicated at each use (the branches stay lazy and the computed
values are unchanged).
While the branch interiors are excluded from common subexpression elimination (hoisting them
would defeat the laziness), the conditional itself is still bound by CSE (see
[`cse_bind_expr`](@ref)), so multiple references to one `ifelse_branching` expression share a
single `if`/`else`.
"""
ifelse_branching(cond, x, y) = cond ? x : y

Expand Down
38 changes: 38 additions & 0 deletions test/conditionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,41 @@ end
f = compile(ifelse_eager(x > 0, x^2, boomterm(x)), x)
@test_throws ErrorException f(2.0)
end

# Counts how many times it is evaluated; used to detect duplicated computation.
const FIRE_COUNT = Ref(0)
fire(v) = (FIRE_COUNT[] += 1; v)
fireterm(v) = term(fire, v; type = Real, shape = SymbolicUtils.shape(v))

@testset "multiply-referenced ifelse_branching is computed once under CSE" begin
@syms x::Real
z = ifelse_branching(x > 0, fireterm(x)^2, boomterm(x))
# CSE binds the conditional itself to a temporary...
csex = cse(z + 2z)
@test csex isa Let
@test count(p -> p isa Assignment && iscall(p.rhs) &&
operation(p.rhs) === ifelse_branching, csex.pairs) == 1
# ...so referencing it twice evaluates the taken branch once and the untaken
# branch (`boom`) not at all.
f = compile(z + 2z, x)
FIRE_COUNT[] = 0
@test f(2.0) == 12.0
@test FIRE_COUNT[] == 1

# One-sided nesting stays linear: each level is emitted once, matching `ifelse`.
# (A conditional referenced inside *both* branches of an enclosing one is still
# emitted per branch — hoisting it out would evaluate it eagerly.)
nested_b = ifelse_branching(x > 0, x, -x)
nested_i = ifelse(x > 0, x, -x)
for i in 1:5
nested_b = ifelse_branching(x > i, nested_b + 1, -x)
nested_i = ifelse(x > i, nested_i + 1, -x)
end
code_b = string(toexpr(cse(Func([x], [], nested_b))))
code_i = string(toexpr(cse(Func([x], [], nested_i))))
nif(s) = length(collect(eachmatch(r"\bif\b", s)))
@test nif(code_b) == nif(code_i) == 6
fb = compile(nested_b, x)
fi = compile(nested_i, x)
@test fb(2.5) == fi(2.5) == -2.5
end
Loading