Initial steps to Mooncake support#391
Draft
kshyatt wants to merge 4 commits into
Draft
Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/PEPSKitMooncakeExt.jl b/ext/PEPSKitMooncakeExt.jl
index 74b528f..486e8c1 100644
--- a/ext/PEPSKitMooncakeExt.jl
+++ b/ext/PEPSKitMooncakeExt.jl
@@ -63,11 +63,11 @@ end
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.FullPullback}
t, dt = arrayify(t_dt)
alg = primal(alg_dalg)
-
+
D, V = eigh_full!(t; alg.fwd_alg.alg)
(D̃, Ṽ), inds = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.fwd_alg.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(D), inds)
-
+
DVtrunc = (D̃, Ṽ)
# pack output
DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ))
@@ -88,7 +88,7 @@ end
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_trunc)}, t_dt::CoDual{<:TensorKit.AbstractTensorMap}, alg_dalg::CoDual{EighAdjoint{F, R}}) where {F, R <: PEPSKit.TruncPullback}
t, dt = arrayify(t_dt)
alg = primal(alg_dalg)
-
+
D, V, truncerror = eigh_trunc(t, alg)
gtol = PEPSKit._get_pullback_gauge_tol(alg.rrule_alg.verbosity)
output = (D, V, truncerror)
diff --git a/test/mooncake/eigh_wrapper.jl b/test/mooncake/eigh_wrapper.jl
index a05adaf..71fcabc 100644
--- a/test/mooncake/eigh_wrapper.jl
+++ b/test/mooncake/eigh_wrapper.jl
@@ -3,7 +3,7 @@ using Test
using Random
using LinearAlgebra
using TensorKit
-using Mooncake
+using Mooncake
using Accessors
using PEPSKit
@@ -35,7 +35,7 @@ iter_alg = EighAdjoint(; fwd_alg = (; alg = :Lanczos), rrule_alg = (; alg = :Tru
full_lossfun = A -> lossfun(A, full_alg, R)
trunc_lossfun = A -> lossfun(A, trunc_alg, R)
iter_lossfun = A -> lossfun(A, iter_alg, R)
-
+
full_rrule = Mooncake.build_rrule(full_lossfun, r)
trunc_rrule = Mooncake.build_rrule(trunc_lossfun, r)
iter_rrule = Mooncake.build_rrule(iter_lossfun, r)
diff --git a/test/mooncake/svd_wrapper.jl b/test/mooncake/svd_wrapper.jl
index 2c010ca..a99c8d4 100644
--- a/test/mooncake/svd_wrapper.jl
+++ b/test/mooncake/svd_wrapper.jl
@@ -119,7 +119,7 @@ symm_R = randn(dtype, space(symm_r))
@test g_full[2] ≈ g_trunc[2] rtol = rtol
@test g_full[2] ≈ g_iter[2] rtol = rtol
@test g_trunc[2] ≈ g_iter[2] rtol = rtol
-
+
full_lossfun = A -> lossfun(A, full_alg, symm_R, symm_trspace)
trunc_lossfun = A -> lossfun(A, trunc_alg, symm_R, symm_trspace)
iter_lossfun = A -> lossfun(A, iter_alg, symm_R, symm_trspace)
@@ -137,7 +137,7 @@ symm_R = randn(dtype, space(symm_r))
@test g_trunc_tr[2] ≈ g_iter_tr[2] rtol = rtol
iter_alg_fallback = @set iter_alg.fwd_alg.fallback_threshold = 0.4 # Do dense decomposition in one block, sparse one in the other
-
+
fb_lossfun = A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace)
fb_rrule = Mooncake.build_rrule(fb_lossfun, symm_r)
l_iter_fb, g_iter_fb = Mooncake.value_and_gradient!!(fb_rrule, fb_lossfun, symm_r)
@@ -159,7 +159,7 @@ end
no_broadening_no_cutoff_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-30
small_broadening_alg = @set alg.rrule_alg.degeneracy_atol = 1.0e-13
-
+
only_lossfun = A -> lossfun(A, alg, symm_R, symm_trspace)
no_broadening_lossfun = A -> lossfun(A, no_broadening_no_cutoff_alg, symm_R, symm_trspace)
small_broadening_lossfun = A -> lossfun(A, small_broadening_alg, symm_R, symm_trspace) |
Member
Author
|
This is all falling right now due to some tags needing to hit elsewhere but I am making progress locally :) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Currently based off the existing ChainRules support
TODO:
svd_trunc!eig(h)_trunc!