@@ -65,14 +65,19 @@ function _robust_real(x)
6565
6666 # Try early plain value recovery
6767 v = Symbolics. value (x_un)
68+ if v isa AbstractFloat
69+ return rationalize (v, tol= 1e-13 )
70+ end
6871 if v isa Real
6972 return v
7073 end
7174 if v isa Complex
72- if iszero (imag (v))
73- return _robust_real (real (v))
75+ rv = _robust_real (real (v))
76+ iv = _robust_real (imag (v))
77+ if iszero (iv)
78+ return rv
7479 end
75- return Complex (_robust_real ( real (v)), _robust_real ( imag (v)) )
80+ return Complex (rv, iv )
7681 end
7782
7883 # Handle symbolic complex calls
@@ -92,18 +97,43 @@ function _robust_real(x)
9297 # Try symbolic simplification to see if it's real
9398 try
9499 nx = _safe_Num (x_un)
95- ix = Symbolics. simplify (imag (nx))
96- if _iszero (ix)
100+
101+ # FAST PATH: If already a value, handle directly
102+ v = Symbolics. value (nx)
103+ if v isa AbstractFloat
104+ return rationalize (v, tol= 1e-13 )
105+ end
106+ if v isa Real
107+ return v
108+ end
109+
110+ # SLOW PATH: Simplify and check again
111+ nx = Symbolics. simplify (nx)
112+ v = Symbolics. value (nx)
113+ if v isa AbstractFloat
114+ return rationalize (v, tol= 1e-13 )
115+ end
116+ if v isa Real
117+ return v
118+ end
119+
120+ # Check if it has non-zero imaginary part
121+ # Symbolics.simplify(imag(nx)) should be fast now since nx is simplified
122+ if _iszero (Symbolics. simplify (imag (nx)))
97123 rx = Symbolics. simplify (real (nx))
98124 vx = Symbolics. value (rx)
125+ if vx isa AbstractFloat
126+ return rationalize (vx, tol= 1e-13 )
127+ end
99128 if vx isa Real
100129 return vx
101130 end
102131 return rx
103132 end
133+ return nx
104134 catch
105135 end
106-
136+
107137 return x_un
108138end
109139
@@ -202,7 +232,15 @@ function _integrate_core(
202232
203233 # Rewrite rules to ensure abs(z)^2 becomes (z * conj(z)) or real^2 + imag^2
204234 r_abs2 = @rule abs2 (~ x) => (~ x) * conj (~ x)
205- r_abs_sq = @rule abs (~ x)^ 2 => (~ x) * conj (~ x)
235+ r_abs_pow = @rule abs (~ x)^ ~ n => begin
236+ n_un = Symbolics. unwrap (~ n)
237+ if n_un isa Number && isinteger (n_un) && iseven (Int (n_un))
238+ k = Int (n_un) ÷ 2
239+ ((~ x) * conj (~ x))^ k
240+ else
241+ nothing
242+ end
243+ end
206244 r_abs = @rule abs (~ x) => hypot (real (~ x), imag (~ x))
207245
208246 r_real = @rule real (~ x) => (1 // 2 ) * (~ x + conj (~ x))
@@ -220,11 +258,42 @@ function _integrate_core(
220258 r_complex = @rule complex (~ x, ~ y) => ~ x + im* ~ y
221259 r_complex_base = @rule Base. complex (~ x, ~ y) => ~ x + im* ~ y
222260
261+ # Fix nested powers: ((x^a)^b) -> x^(a*b)
262+ function power_simplifier (x, a, b)
263+ new_expon = a * b
264+ if isinteger (new_expon)
265+ return x^ Int (new_expon)
266+ end
267+ return x^ new_expon
268+ end
269+ r_pow_nested = @rule ((~ x)^ ~ a)^ ~ b => power_simplifier (~ x, ~ a, ~ b)
270+
271+ # Even powers of hypot: hypot(x,y)^4 -> (x^2 + y^2)^2
272+ r_hypot_pow = @rule hypot (~ x, ~ y)^ ~ n => begin
273+ n_un = Symbolics. unwrap (~ n)
274+ if n_un isa Number && isinteger (n_un) && iseven (Int (n_un))
275+ k = Int (n_un) ÷ 2
276+ ((~ x)^ 2 + (~ y)^ 2 )^ k
277+ else
278+ nothing
279+ end
280+ end
281+
282+ # Convert float integer powers to Int: x^2.0 -> x^2
283+ r_float_to_int_pow = @rule (~ x)^ ~ a => begin
284+ a_un = Symbolics. unwrap (~ a)
285+ if a_un isa AbstractFloat && isinteger (a_un)
286+ (~ x)^ Int (a_un)
287+ else
288+ nothing
289+ end
290+ end
291+
223292 expr_unwrapped = Symbolics. unwrap (expr)
224293
225294 # Apply rewrites
226295 chain = SymbolicUtils. Chain ([
227- r_abs_sq ,
296+ r_abs_pow ,
228297 r_abs2,
229298 r_abs,
230299 r_real,
@@ -235,6 +304,9 @@ function _integrate_core(
235304 r_hypot_default,
236305 r_complex,
237306 r_complex_base,
307+ r_pow_nested,
308+ r_hypot_pow,
309+ r_float_to_int_pow, # Add float fix
238310 ])
239311 expr_rewritten =
240312 SymbolicUtils. Postwalk (
@@ -247,41 +319,14 @@ function _integrate_core(
247319 end
248320 )(expr_unwrapped)
249321
250- # Manual power fixing function
251- function fix_powers (t)
252- if Symbolics. iscall (t)
253- op = Symbolics. operation (t)
254- if op == (^ )
255- args = Symbolics. arguments (t)
256- base = args[1 ]
257- expon = args[2 ]
258-
259- if Symbolics. iscall (base) && Symbolics. operation (base) == (^ )
260- base_args = Symbolics. arguments (base)
261- inner_base = base_args[1 ]
262- inner_expon = base_args[2 ]
263-
264- new_expon = inner_expon * expon
265- if isinteger (new_expon)
266- return inner_base^ Int (new_expon)
267- else
268- return inner_base^ new_expon
269- end
270- end
271-
272- if expon isa Rational && isinteger (expon)
273- return base^ Int (expon)
274- end
275- end
276- end
277- return t
278- end
279-
280- expr_rewritten = SymbolicUtils. Postwalk (fix_powers)(expr_rewritten)
322+ # Manual power fixing function removed
323+
324+ # expr_rewritten = SymbolicUtils.Postwalk(fix_powers)(expr_rewritten)
281325 if expr_rewritten isa Complex
282326 return _integrate_core (expr_rewritten, dim, subs_dict, matcher, measure_type)
283327 end
284328 expr_num = _safe_Num (expr_rewritten)
329+
285330
286331 # Expand again
287332 try
@@ -326,24 +371,20 @@ function _integrate_core(
326371 robust_substitute (Symbolics. unwrap (expr_num), subs_dict)
327372 end
328373
374+
329375 # Expand
330376 expanded_expr = try
331377 Symbolics. expand (_safe_Num (Symbolics. unwrap (expr_subbed)))
332378 catch
333379 _safe_Num (expr_subbed)
334380 end
381+
382+
383+
335384
336385 # Apply rewrites again to catch complex(...) introduced by substitution
337- expanded_expr = SymbolicUtils. Postwalk (SymbolicUtils. PassThrough (chain))(
338- Symbolics. unwrap (expanded_expr),
339- )
340- # Safe wrap and expand again to distribute
341- try
342- expanded_expr = Symbolics. expand (_safe_Num (expanded_expr))
343- catch e
344- println (" DEBUG: Expansion failed: " , e)
345- expanded_expr = _safe_Num (expanded_expr)
346- end
386+ # NOTE: This second Postwalk was removed to improve performance on large expanded expressions.
387+ # The simplification/expansion above should be sufficient.
347388
348389 # Helper to traverse product
349390 function process_term_wrapped (term)
@@ -375,7 +416,8 @@ function _integrate_core(
375416
376417 # Handle result
377418 final_res = integrate_num_expr (expanded_expr)
378- return _robust_real (final_res)
419+ res = _robust_real (final_res)
420+ return res
379421end
380422
381423function integrate (expr:: LazySum , measure)
@@ -416,7 +458,8 @@ function fallback_integrate(expr, measure)
416458 info = measure_info (measure)
417459 if info != = nothing
418460 subs_dict, matcher, dim, measure_type = info
419- return _robust_real (_integrate_core (expr, dim, subs_dict, matcher, measure_type))
461+ # _integrate_core already calls _robust_real on the final result
462+ return _integrate_core (expr, dim, subs_dict, matcher, measure_type)
420463 end
421464
422465 # Optional: fallback to manual implementation if measure_info is not provided
@@ -567,15 +610,15 @@ function process_term(term, matcher::AbstractIndexMatcher, dim, measure_type = :
567610 rule = get (INTEGRATION_RULES, first (measure_type), nothing )
568611 end
569612
570- if rule != = nothing
571- val = rule (u_indices, u_bar_indices, dim, measure_type)
572- if _symbolic_isequal (val, 0 )
573- return 0
613+ if rule != = nothing
614+ val = rule (u_indices, u_bar_indices, dim, measure_type)
615+ if _symbolic_isequal (val, 0 )
616+ return 0
617+ end
618+ return coeff * val
619+ else
620+ error (" Unknown measure type: $measure_type " )
574621 end
575- return coeff * val
576- else
577- error (" Unknown measure type: $measure_type " )
578- end
579622end
580623
581624# Register standard rules
@@ -791,10 +834,18 @@ function integrate_indices_orthogonal(indices::Vector{Tuple{Int,Int}}, dim)
791834 sigma_counts[c_sigma] = get (sigma_counts, c_sigma, 0 ) + 1
792835 end
793836
837+ val_mat, lookup = get_weingarten_orthogonal_data (n ÷ 2 , dim)
838+
794839 total = 0 // 1
795840 for (c_pi, count_pi) in pi_counts
841+ idx_pi = get (lookup, c_pi, nothing )
842+ idx_pi === nothing && continue
843+
796844 for (c_sigma, count_sigma) in sigma_counts
797- val = weingarten_orthogonal_val (c_pi, c_sigma, dim)
845+ idx_sigma = get (lookup, c_sigma, nothing )
846+ idx_sigma === nothing && continue
847+
848+ val = val_mat[idx_pi, idx_sigma]
798849 total += (count_pi * count_sigma) * val
799850 end
800851 end
0 commit comments