Skip to content

Commit 8d2462b

Browse files
committed
fold conj, real, imag on the symbolic imaginary unit
Add three `@match` arms in `Base.real`, `Base.conj`, `Base.imag` for `BasicSymbolic` that match `Sym(:im; type = Number)` structurally and fold to `0`, `-im`, and `1` respectively. `Symbolics.IM` is defined as a `Sym{VartypeT}(:im; type = Number)` and used as a stand-in for `1im` to keep expressions inside `BasicSymbolic{<:Real}` algebra (where multiplying by a Julia `Complex` literal would otherwise materialise an opaque `complex(re, im)` `Term`). Until now those three operations on `IM` produced opaque `conj(im)` / `real(im)` / `imag(im)` wrappers that `simplify` could not reduce, so downstream code that algebraically conjugates `IM`-bearing expressions (e.g. SQA's `qadjoint` / `inner_adjoint`) had to special-case `IM` themselves. Matching is structural (name + symtype) rather than by identity, so no new dependency on Symbolics is introduced. A same-named sym with a non-`Number` symtype is left alone (test case).
1 parent 7ffce39 commit 8d2462b

2 files changed

Lines changed: 14 additions & 0 deletions

File tree

src/methods.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,8 @@ function Base.real(s::BasicSymbolic{T}) where {T}
479479
@match s begin
480480
BSImpl.Const(; val) => Const{T}(real(val))
481481
BSImpl.Term(; f, args) && if f === complex && length(args) == 2 end => args[1]
482+
# Match `Symbolics.IM = Sym(:im; type = Number)` structurally.
483+
BSImpl.Sym(; name) && if name === :im && symtype(s) === Number end => zero_of_vartype(T)
482484
_ => Term{T}(real, ArgsT{T}((s,)); type = Real)
483485
end
484486
end
@@ -490,6 +492,7 @@ function Base.conj(s::BasicSymbolic{T}) where {T}
490492
BSImpl.Term(; f, args, type, shape) && if f === complex && length(args) == 2 end => begin
491493
BSImpl.Term{T}(f, ArgsT{T}(args[1], -args[2]); type, shape)
492494
end
495+
BSImpl.Sym(; name) && if name === :im && symtype(s) === Number end => -s
493496
_ => Term{T}(conj, ArgsT{T}((s,)); type = symtype(s), shape = shape(s))
494497
end
495498
end
@@ -498,6 +501,7 @@ function Base.imag(s::BasicSymbolic{T}) where {T}
498501
@match s begin
499502
BSImpl.Const(; val) => Const{T}(imag(val))
500503
BSImpl.Term(; f, args) && if f === complex && length(args) == 2 end => args[2]
504+
BSImpl.Sym(; name) && if name === :im && symtype(s) === Number end => one_of_vartype(T)
501505
_ => Term{T}(imag, ArgsT{T}((s,)); type = Real)
502506
end
503507
end

test/rulesets.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ end
5959
@test unwrap_const(simplify(Term{SymReal}(zero, [a]))) == 0
6060
@test unwrap_const(simplify(Term{SymReal}(zero, [b + 1]))) == 0
6161
@test unwrap_const(simplify(Term{SymReal}(zero, [x + 2]))) == 0
62+
63+
# Fold for the symbolic imaginary unit (matches `Symbolics.IM`).
64+
let IM = SymbolicUtils.Sym{SymReal}(:im; type = Number)
65+
@eqtest conj(IM) == -IM
66+
@test unwrap_const(real(IM)) == 0
67+
@test unwrap_const(imag(IM)) == 1
68+
# Same name, non-Number symtype: untouched.
69+
not_im = SymbolicUtils.Sym{SymReal}(:im; type = Real)
70+
@eqtest conj(not_im) == not_im
71+
end
6272
end
6373

6474
@testset "LiteralReal" begin

0 commit comments

Comments
 (0)