Skip to content

Commit 23c3d89

Browse files
committed
almost working refacctored code
1 parent c85cebe commit 23c3d89

2 files changed

Lines changed: 136 additions & 81 deletions

File tree

src/integration_core.jl

Lines changed: 111 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,19 @@ function _robust_real(x)
6565

6666
# Try early plain value recovery
6767
v = Symbolics.value(x_un)
68+
if v isa AbstractFloat
69+
return rationalize(v, tol=1e-13)
70+
end
6871
if v isa Real
6972
return v
7073
end
7174
if v isa Complex
72-
if iszero(imag(v))
73-
return _robust_real(real(v))
75+
rv = _robust_real(real(v))
76+
iv = _robust_real(imag(v))
77+
if iszero(iv)
78+
return rv
7479
end
75-
return Complex(_robust_real(real(v)), _robust_real(imag(v)))
80+
return Complex(rv, iv)
7681
end
7782

7883
# Handle symbolic complex calls
@@ -92,18 +97,43 @@ function _robust_real(x)
9297
# Try symbolic simplification to see if it's real
9398
try
9499
nx = _safe_Num(x_un)
95-
ix = Symbolics.simplify(imag(nx))
96-
if _iszero(ix)
100+
101+
# FAST PATH: If already a value, handle directly
102+
v = Symbolics.value(nx)
103+
if v isa AbstractFloat
104+
return rationalize(v, tol=1e-13)
105+
end
106+
if v isa Real
107+
return v
108+
end
109+
110+
# SLOW PATH: Simplify and check again
111+
nx = Symbolics.simplify(nx)
112+
v = Symbolics.value(nx)
113+
if v isa AbstractFloat
114+
return rationalize(v, tol=1e-13)
115+
end
116+
if v isa Real
117+
return v
118+
end
119+
120+
# Check if it has non-zero imaginary part
121+
# Symbolics.simplify(imag(nx)) should be fast now since nx is simplified
122+
if _iszero(Symbolics.simplify(imag(nx)))
97123
rx = Symbolics.simplify(real(nx))
98124
vx = Symbolics.value(rx)
125+
if vx isa AbstractFloat
126+
return rationalize(vx, tol=1e-13)
127+
end
99128
if vx isa Real
100129
return vx
101130
end
102131
return rx
103132
end
133+
return nx
104134
catch
105135
end
106-
136+
107137
return x_un
108138
end
109139

@@ -202,7 +232,15 @@ function _integrate_core(
202232

203233
# Rewrite rules to ensure abs(z)^2 becomes (z * conj(z)) or real^2 + imag^2
204234
r_abs2 = @rule abs2(~x) => (~x) * conj(~x)
205-
r_abs_sq = @rule abs(~x)^2 => (~x) * conj(~x)
235+
r_abs_pow = @rule abs(~x)^~n => begin
236+
n_un = Symbolics.unwrap(~n)
237+
if n_un isa Number && isinteger(n_un) && iseven(Int(n_un))
238+
k = Int(n_un) ÷ 2
239+
((~x) * conj(~x))^k
240+
else
241+
nothing
242+
end
243+
end
206244
r_abs = @rule abs(~x) => hypot(real(~x), imag(~x))
207245

208246
r_real = @rule real(~x) => (1//2) * (~x + conj(~x))
@@ -220,11 +258,42 @@ function _integrate_core(
220258
r_complex = @rule complex(~x, ~y) => ~x + im*~y
221259
r_complex_base = @rule Base.complex(~x, ~y) => ~x + im*~y
222260

261+
# Fix nested powers: ((x^a)^b) -> x^(a*b)
262+
function power_simplifier(x, a, b)
263+
new_expon = a * b
264+
if isinteger(new_expon)
265+
return x^Int(new_expon)
266+
end
267+
return x^new_expon
268+
end
269+
r_pow_nested = @rule ((~x)^~a)^~b => power_simplifier(~x, ~a, ~b)
270+
271+
# Even powers of hypot: hypot(x,y)^4 -> (x^2 + y^2)^2
272+
r_hypot_pow = @rule hypot(~x, ~y)^~n => begin
273+
n_un = Symbolics.unwrap(~n)
274+
if n_un isa Number && isinteger(n_un) && iseven(Int(n_un))
275+
k = Int(n_un) ÷ 2
276+
((~x)^2 + (~y)^2)^k
277+
else
278+
nothing
279+
end
280+
end
281+
282+
# Convert float integer powers to Int: x^2.0 -> x^2
283+
r_float_to_int_pow = @rule (~x)^~a => begin
284+
a_un = Symbolics.unwrap(~a)
285+
if a_un isa AbstractFloat && isinteger(a_un)
286+
(~x)^Int(a_un)
287+
else
288+
nothing
289+
end
290+
end
291+
223292
expr_unwrapped = Symbolics.unwrap(expr)
224293

225294
# Apply rewrites
226295
chain = SymbolicUtils.Chain([
227-
r_abs_sq,
296+
r_abs_pow,
228297
r_abs2,
229298
r_abs,
230299
r_real,
@@ -235,6 +304,9 @@ function _integrate_core(
235304
r_hypot_default,
236305
r_complex,
237306
r_complex_base,
307+
r_pow_nested,
308+
r_hypot_pow,
309+
r_float_to_int_pow, # Add float fix
238310
])
239311
expr_rewritten =
240312
SymbolicUtils.Postwalk(
@@ -247,41 +319,14 @@ function _integrate_core(
247319
end
248320
)(expr_unwrapped)
249321

250-
# Manual power fixing function
251-
function fix_powers(t)
252-
if Symbolics.iscall(t)
253-
op = Symbolics.operation(t)
254-
if op == (^)
255-
args = Symbolics.arguments(t)
256-
base = args[1]
257-
expon = args[2]
258-
259-
if Symbolics.iscall(base) && Symbolics.operation(base) == (^)
260-
base_args = Symbolics.arguments(base)
261-
inner_base = base_args[1]
262-
inner_expon = base_args[2]
263-
264-
new_expon = inner_expon * expon
265-
if isinteger(new_expon)
266-
return inner_base^Int(new_expon)
267-
else
268-
return inner_base^new_expon
269-
end
270-
end
271-
272-
if expon isa Rational && isinteger(expon)
273-
return base^Int(expon)
274-
end
275-
end
276-
end
277-
return t
278-
end
279-
280-
expr_rewritten = SymbolicUtils.Postwalk(fix_powers)(expr_rewritten)
322+
# Manual power fixing function removed
323+
324+
# expr_rewritten = SymbolicUtils.Postwalk(fix_powers)(expr_rewritten)
281325
if expr_rewritten isa Complex
282326
return _integrate_core(expr_rewritten, dim, subs_dict, matcher, measure_type)
283327
end
284328
expr_num = _safe_Num(expr_rewritten)
329+
285330

286331
# Expand again
287332
try
@@ -326,24 +371,20 @@ function _integrate_core(
326371
robust_substitute(Symbolics.unwrap(expr_num), subs_dict)
327372
end
328373

374+
329375
# Expand
330376
expanded_expr = try
331377
Symbolics.expand(_safe_Num(Symbolics.unwrap(expr_subbed)))
332378
catch
333379
_safe_Num(expr_subbed)
334380
end
381+
382+
383+
335384

336385
# Apply rewrites again to catch complex(...) introduced by substitution
337-
expanded_expr = SymbolicUtils.Postwalk(SymbolicUtils.PassThrough(chain))(
338-
Symbolics.unwrap(expanded_expr),
339-
)
340-
# Safe wrap and expand again to distribute
341-
try
342-
expanded_expr = Symbolics.expand(_safe_Num(expanded_expr))
343-
catch e
344-
println("DEBUG: Expansion failed: ", e)
345-
expanded_expr = _safe_Num(expanded_expr)
346-
end
386+
# NOTE: This second Postwalk was removed to improve performance on large expanded expressions.
387+
# The simplification/expansion above should be sufficient.
347388

348389
# Helper to traverse product
349390
function process_term_wrapped(term)
@@ -375,7 +416,8 @@ function _integrate_core(
375416

376417
# Handle result
377418
final_res = integrate_num_expr(expanded_expr)
378-
return _robust_real(final_res)
419+
res = _robust_real(final_res)
420+
return res
379421
end
380422

381423
function integrate(expr::LazySum, measure)
@@ -416,7 +458,8 @@ function fallback_integrate(expr, measure)
416458
info = measure_info(measure)
417459
if info !== nothing
418460
subs_dict, matcher, dim, measure_type = info
419-
return _robust_real(_integrate_core(expr, dim, subs_dict, matcher, measure_type))
461+
# _integrate_core already calls _robust_real on the final result
462+
return _integrate_core(expr, dim, subs_dict, matcher, measure_type)
420463
end
421464

422465
# Optional: fallback to manual implementation if measure_info is not provided
@@ -567,15 +610,15 @@ function process_term(term, matcher::AbstractIndexMatcher, dim, measure_type = :
567610
rule = get(INTEGRATION_RULES, first(measure_type), nothing)
568611
end
569612

570-
if rule !== nothing
571-
val = rule(u_indices, u_bar_indices, dim, measure_type)
572-
if _symbolic_isequal(val, 0)
573-
return 0
613+
if rule !== nothing
614+
val = rule(u_indices, u_bar_indices, dim, measure_type)
615+
if _symbolic_isequal(val, 0)
616+
return 0
617+
end
618+
return coeff * val
619+
else
620+
error("Unknown measure type: $measure_type")
574621
end
575-
return coeff * val
576-
else
577-
error("Unknown measure type: $measure_type")
578-
end
579622
end
580623

581624
# Register standard rules
@@ -791,10 +834,18 @@ function integrate_indices_orthogonal(indices::Vector{Tuple{Int,Int}}, dim)
791834
sigma_counts[c_sigma] = get(sigma_counts, c_sigma, 0) + 1
792835
end
793836

837+
val_mat, lookup = get_weingarten_orthogonal_data(n ÷ 2, dim)
838+
794839
total = 0 // 1
795840
for (c_pi, count_pi) in pi_counts
841+
idx_pi = get(lookup, c_pi, nothing)
842+
idx_pi === nothing && continue
843+
796844
for (c_sigma, count_sigma) in sigma_counts
797-
val = weingarten_orthogonal_val(c_pi, c_sigma, dim)
845+
idx_sigma = get(lookup, c_sigma, nothing)
846+
idx_sigma === nothing && continue
847+
848+
val = val_mat[idx_pi, idx_sigma]
798849
total += (count_pi * count_sigma) * val
799850
end
800851
end

src/weingarten.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,29 @@ Reference:
9696
"""
9797
@memoize function irrep_dimension(part::Vector{Int}, d)
9898
conj_part = conjugate_partition(part)
99-
cols = length(part) > 0 ? part[1] : 0
100-
101-
# We need to iterate over all boxes (i, j) in the Young diagram
102-
prod_val = 1 // 1
103-
104-
for i = 1:length(part)
105-
for j = 1:part[i]
106-
# Hook length h_{i,j} = lambda[i] - i + lambda'[j] - j + 1
107-
hook_length = part[i] - i + conj_part[j] - j + 1
108-
109-
# Content c_{i,j} = j - i
110-
# Term = d + c_{i,j} = d + j - i
111-
term = d + j - i
112-
113-
# Update product
114-
prod_val *= (d isa Integer ? term // hook_length : term / hook_length)
99+
100+
if d isa Integer
101+
prod_val = 1 // 1
102+
for i = 1:length(part)
103+
for j = 1:part[i]
104+
hook_length = part[i] - i + conj_part[j] - j + 1
105+
term = d + j - i
106+
prod_val *= term // hook_length
107+
end
115108
end
109+
return prod_val
110+
else
111+
num = 1
112+
den = 1
113+
for i = 1:length(part)
114+
for j = 1:part[i]
115+
hook_length = part[i] - i + conj_part[j] - j + 1
116+
num *= (d + j - i)
117+
den *= hook_length
118+
end
119+
end
120+
return num / den
116121
end
117-
118-
return prod_val
119122
end
120123

121124
function get_binary_partition(part::Vector{Int})
@@ -193,6 +196,7 @@ Reference:
193196
@memoize function weingarten(partition_type::Vector{Int}, d)
194197
# Wg(sigma, d) where sigma has cycle type `partition_type`.
195198
n = sum(partition_type)
199+
n_fact = factorial(big(n))
196200

197201
# Iterate over all partitions of n
198202
parts = partitions(n)
@@ -215,13 +219,13 @@ Reference:
215219
dim_lam = irrep_dimension(lam, d)
216220

217221
term = (
218-
d isa Integer ? ((f_lam)^2 * chi_lam_mu) // dim_lam :
219-
((f_lam)^2 * chi_lam_mu) / dim_lam
222+
d isa Integer ? (big(f_lam)^2 * chi_lam_mu) // dim_lam :
223+
(big(f_lam)^2 * chi_lam_mu) / dim_lam
220224
)
221225
sum_val += term
222226
end
223227

224-
return (d isa Integer ? sum_val // (factorial(n)^2) : sum_val / (factorial(n)^2))
228+
return (d isa Integer ? sum_val // (n_fact^2) : sum_val / (n_fact^2))
225229
end
226230

227231

0 commit comments

Comments
 (0)