Skip to content

Commit c2487ad

Browse files
committed
fix trace output
1 parent 566ff7d commit c2487ad

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

src/symbolic_trace.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,28 @@ end
194194
Symbolic function representing the trace of a matrix product.
195195
"""
196196
function tr_val(factors::Vector{SymbolicMatrix})
197-
# Create a symbolic variable representing this trace
197+
# Create a symbolic term representing this trace
198198
# This avoids issues with Term multiplication and simplification
199199
if isempty(factors)
200-
return 1 # Should handle dim separately, but trace of empty is not passed here usually
200+
return 1
201201
end
202202

203-
# Construct a nice string representation
204-
# e.g. "tr(A * B)"
203+
# Construct the inner expression
204+
# e.g. "A * B"
205205
s_parts = String[]
206206
for (i, f) in enumerate(factors)
207207
push!(s_parts, string(f))
208208
end
209-
name = "tr_val(" * join(s_parts, "*") * ")"
210-
return Symbolics.variable(Symbol(name); T = Real)
209+
210+
# We create an inner variable representing the content
211+
# This is needed because `factors` are our custom types, not Symbolics terms.
212+
# To put them inside a Term, we need a Num/Symbolic object.
213+
inner_content_name = join(s_parts, "*")
214+
inner_var = Symbolics.variable(Symbol(inner_content_name); T=Real)
215+
216+
# Return a Term
217+
# IntU.tr is the function head
218+
return Symbolics.term(tr, inner_var)
211219
end
212220
# Symbolics metadata might go here if needed.
213221
# Actually, we rely on Term wrapping it.

0 commit comments

Comments
 (0)