Skip to content

Initial steps to Mooncake support#391

Draft
kshyatt wants to merge 4 commits into
mainfrom
ksh/moonsvd
Draft

Initial steps to Mooncake support#391
kshyatt wants to merge 4 commits into
mainfrom
ksh/moonsvd

Conversation

@kshyatt
Copy link
Copy Markdown
Member

@kshyatt kshyatt commented Jun 1, 2026

Currently based off the existing ChainRules support

TODO:

  • full support for actually in place svd_trunc!
  • end-to-end support for all the fixed point differentiation logic
  • support for eig(h)_trunc!
  • GPU support and tests

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 1, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/PEPSKitMooncakeExt.jl b/ext/PEPSKitMooncakeExt.jl
index 74b528f..486e8c1 100644
--- a/ext/PEPSKitMooncakeExt.jl
+++ b/ext/PEPSKitMooncakeExt.jl
@@ -63,11 +63,11 @@ end
 function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
     t, dt = arrayify(t_dt)
     alg = primal(alg_dalg)
-    
+
     D, V = eigh_full!(t; alg.fwd_alg.alg)
     (D̃, Ṽ), inds = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
     ϵ = MatrixAlgebraKit.truncation_error(diagview(D), inds)
-    
+
     DVtrunc = (D̃, Ṽ)
     # pack output
     DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
@@ -88,7 +88,7 @@ end
 function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.TruncPullback}
     t, dt = arrayify(t_dt)
     alg = primal(alg_dalg)
-    
+
     D, V, truncerror = eigh_trunc(t, alg)
     gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
     output = (D, V, truncerror)
diff --git a/test/mooncake/eigh_wrapper.jl b/test/mooncake/eigh_wrapper.jl
index a05adaf..71fcabc 100644
--- a/test/mooncake/eigh_wrapper.jl
+++ b/test/mooncake/eigh_wrapper.jl
@@ -3,7 +3,7 @@ using Test
 using Random
 using LinearAlgebra
 using TensorKit
-using Mooncake 
+using Mooncake
 using Accessors
 using PEPSKit
 
@@ -35,7 +35,7 @@ iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :Tru
     full_lossfun = A -> lossfun(A, full_alg, R)
     trunc_lossfun = A -> lossfun(A, trunc_alg, R)
     iter_lossfun = A -> lossfun(A, iter_alg, R)
-    
+
     full_rrule = Mooncake.build_rrule(full_lossfun, r)
     trunc_rrule = Mooncake.build_rrule(trunc_lossfun, r)
     iter_rrule = Mooncake.build_rrule(iter_lossfun, r)
diff --git a/test/mooncake/svd_wrapper.jl b/test/mooncake/svd_wrapper.jl
index 2c010ca..a99c8d4 100644
--- a/test/mooncake/svd_wrapper.jl
+++ b/test/mooncake/svd_wrapper.jl
@@ -119,7 +119,7 @@ symm_R = randn(dtype, space(symm_r))
     @test g_full[2] ≈ g_trunc[2] rtol = rtol
     @test g_full[2] ≈ g_iter[2] rtol = rtol
     @test g_trunc[2] ≈ g_iter[2] rtol = rtol
-    
+
     full_lossfun = A -> lossfun(A, full_alg, symm_R, symm_trspace)
     trunc_lossfun = A -> lossfun(A, trunc_alg, symm_R, symm_trspace)
     iter_lossfun = A -> lossfun(A, iter_alg, symm_R, symm_trspace)
@@ -137,7 +137,7 @@ symm_R = randn(dtype, space(symm_r))
     @test g_trunc_tr[2] ≈ g_iter_tr[2] rtol = rtol
 
     iter_alg_fallback = @set iter_alg.fwd_alg.fallback_threshold = 0.4  # Do dense decomposition in one block, sparse one in the other
-    
+
     fb_lossfun = A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace)
     fb_rrule = Mooncake.build_rrule(fb_lossfun, symm_r)
     l_iter_fb, g_iter_fb = Mooncake.value_and_gradient!!(fb_rrule, fb_lossfun, symm_r)
@@ -159,7 +159,7 @@ end
 
     no_broadening_no_cutoff_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-30
     small_broadening_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-13
-    
+
     only_lossfun = A -> lossfun(A, alg, symm_R, symm_trspace)
     no_broadening_lossfun = A -> lossfun(A, no_broadening_no_cutoff_alg, symm_R, symm_trspace)
     small_broadening_lossfun = A -> lossfun(A, small_broadening_alg, symm_R, symm_trspace)

@kshyatt
Copy link
Copy Markdown
Member Author

kshyatt commented Jun 1, 2026

This is all falling right now due to some tags needing to hit elsewhere but I am making progress locally :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant