Skip to content

Commit 9c9b9dd

Browse files
committed
Adapt to v0.15
1 parent fdc9112 commit 9c9b9dd

6 files changed

Lines changed: 52 additions & 81 deletions

File tree

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
name = "TensorKitManifolds"
22
uuid = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
33
authors = ["Jutho Haegeman <jutho.haegeman@ugent.be>", "Markus Hauru <markus@mhauru.org>"]
4-
version = "0.7.2"
4+
version = "0.7.3"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
89
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
910

1011
[compat]
11-
TensorKit = "0.13,0.14"
12+
MatrixAlgebraKit = "0.5.0"
13+
TensorKit = "0.15"
1214
julia = "1.10"
1315

1416
[extras]

src/TensorKitManifolds.jl

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ export Grassmann, Stiefel, Unitary
77
export inner, retract, transport, transport!
88

99
using TensorKit
10+
using MatrixAlgebraKit: MatrixAlgebraKit, AbstractAlgorithm, Algorithm, PolarViaSVD,
11+
LAPACK_DivideAndConquer, diagview
12+
import MatrixAlgebraKit as MAK
1013

1114
# Every submodule -- Grassmann, Stiefel, and Unitary -- implements their own methods for
1215
# these. The signatures should be
@@ -28,23 +31,8 @@ checkbase(x, y, z, args...) = checkbase(checkbase(x, y), z, args...)
2831
# the machine epsilon for the elements of an object X, name inspired from eltype
2932
scalareps(X) = eps(real(scalartype(X)))
3033

31-
# default SVD algorithm used in the algorithms
32-
default_svd_alg(::AbstractTensorMap) = TensorKit.SVD()
33-
34-
function isisometry(W::AbstractTensorMap; tol=10 * scalareps(W))
35-
WdW = W' * W
36-
s = zero(float(real(scalartype(W))))
37-
for (c, b) in blocks(WdW)
38-
_subtractone!(b)
39-
s += dim(c) * length(b)
40-
end
41-
return norm(WdW) <= tol * sqrt(s)
42-
end
43-
44-
function isunitary(W::AbstractTensorMap; tol=10 * scalareps(W))
45-
return isisometry(W; tol=tol) && isisometry(W'; tol=tol)
46-
end
47-
34+
# TODO: these functions should be replaced by MAK functions
35+
projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
4836
function projecthermitian!(W::AbstractTensorMap)
4937
codomain(W) == domain(W) ||
5038
throw(DomainError("Tensor with distinct domain and codomain cannot be hermitian."))
@@ -53,6 +41,8 @@ function projecthermitian!(W::AbstractTensorMap)
5341
end
5442
return W
5543
end
44+
45+
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))
5646
function projectantihermitian!(W::AbstractTensorMap)
5747
codomain(W) == domain(W) ||
5848
throw(DomainError("Tensor with distinct domain and codomain cannot be anithermitian."))
@@ -62,27 +52,18 @@ function projectantihermitian!(W::AbstractTensorMap)
6252
return W
6353
end
6454

65-
struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm
66-
end
67-
function projectisometric!(W::AbstractTensorMap; alg=default_svd_alg(W))
68-
if alg isa TensorKit.Polar || alg isa TensorKit.SDD
69-
foreach(blocks(W)) do (c, b)
70-
return _polarsdd!(b)
71-
end
72-
elseif alg isa TensorKit.SVD
73-
foreach(blocks(W)) do (c, b)
74-
return _polarsvd!(b)
75-
end
76-
elseif alg isa PolarNewton
77-
foreach(blocks(W)) do (c, b)
78-
return _polarnewton!(b)
79-
end
80-
else
81-
throw(ArgumentError("unkown algorithm for projectisometric!: alg = $alg"))
55+
projectisometric(W::AbstractTensorMap; kwargs...) = projectisometric!(copy(W); kwargs...)
56+
function projectisometric!(W::AbstractTensorMap;
57+
alg::AbstractAlgorithm=MAK.select_algorithm(left_polar!, W))
58+
TensorKit.foreachblock(W) do c, (b,)
59+
return _left_polar!(b, alg)
8260
end
8361
return W
8462
end
8563

64+
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap, kwargs...)
65+
return projectcomplement!(copy(X), W; kwargs...)
66+
end
8667
function projectcomplement!(X::AbstractTensorMap, W::AbstractTensorMap;
8768
tol=10 * scalareps(X))
8869
P = W' * X
@@ -97,18 +78,6 @@ function projectcomplement!(X::AbstractTensorMap, W::AbstractTensorMap;
9778
return X
9879
end
9980

100-
projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
101-
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))
102-
103-
function projectisometric(W::AbstractTensorMap;
104-
alg=default_svd_alg(W))
105-
return projectisometric!(copy(W); alg=alg)
106-
end
107-
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap,
108-
tol=10 * scalareps(X))
109-
return projectcomplement!(copy(X), W; tol=tol)
110-
end
111-
11281
include("auxiliary.jl")
11382
include("grassmann.jl")
11483
include("stiefel.jl")

src/auxiliary.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -70,37 +70,42 @@ function _subtractone!(a::AbstractMatrix)
7070
view(a, diagind(a)) .= view(a, diagind(a)) .- 1
7171
return a
7272
end
73-
function _polarsdd!(A::StridedMatrix)
74-
U, S, V = svd!(A; alg=LinearAlgebra.DivideAndConquer())
75-
return mul!(A, U, V')
76-
end
77-
function _polarsvd!(A::StridedMatrix)
78-
U, S, V = svd!(A; alg=LinearAlgebra.QRIteration())
79-
return mul!(A, U, V')
73+
74+
# TODO: _left_polar! is more or less the same as MAK.left_polar! but doesn't compute the P
75+
# which is not needed here. Can we unify this?
76+
function _left_polar!(A::StridedMatrix, alg::PolarViaSVD=PolarViaSVD(LAPACK_DivideAndConquer()))
77+
U, _, Vᴴ = svd_compact!(A, alg.svdalg)
78+
return mul!(A, U, Vᴴ)
8079
end
80+
81+
# TODO: can we move this to a dedicated MAK algorithm?
82+
MatrixAlgebraKit.@algdef PolarNewton
83+
84+
_left_polar!(A::StridedMatrix, alg::PolarNewton) = _polarnewton!(A; alg.kwargs...)
8185
function _polarnewton!(A::StridedMatrix; tol=10 * scalareps(A), maxiter=5)
8286
m, n = size(A)
8387
@assert m >= n
8488
A2 = copy(A)
85-
Q, R = qr!(A2)
86-
Ri = ldiv!(UpperTriangular(R)', TensorKit.MatrixAlgebra.one!(similar(R)))
89+
Q, R = LinearAlgebra.qr!(A2)
90+
Ri = ldiv!(UpperTriangular(R)', MatrixAlgebraKit.one!(similar(R)))
8791
R, Ri = _avgdiff!(R, Ri)
8892
i = 1
8993
R2 = view(A, 1:n, 1:n)
9094
fill!(view(A, (n + 1):m, 1:n), zero(eltype(A)))
9195
copyto!(R2, R)
9296
while maximum(abs, Ri) > tol
9397
if i == maxiter # if not converged by now, fall back to sdd
94-
_polarsdd!(Ri)
98+
_left_polar!(Ri)
9599
break
96100
end
97-
Ri = ldiv!(lu!(R2)', TensorKit.MatrixAlgebra.one!(Ri))
101+
Ri = ldiv!(lu!(R2)', MatrixAlgebraKit.one!(Ri))
98102
R, Ri = _avgdiff!(R, Ri)
99103
copyto!(R2, R)
100104
i += 1
101105
end
102106
return lmul!(Q, A)
103107
end
108+
104109
# in place computation of the average and difference of two arrays
105110
function _avgdiff!(A::AbstractArray, B::AbstractArray)
106111
axes(A) == axes(B) || throw(DimensionMismatch())
@@ -124,7 +129,7 @@ end
124129
function _stiefelexp(W::StridedMatrix, A::StridedMatrix, Z::StridedMatrix, α)
125130
n, p = size(W)
126131
r = min(2 * p, n)
127-
QQ, _ = qr!([W Z])
132+
QQ, _ = LinearAlgebra.qr!([W Z])
128133
Q = similar(W, n, r - p)
129134
@inbounds for j in Base.OneTo(r - p)
130135
for i in Base.OneTo(n)
@@ -139,7 +144,7 @@ function _stiefelexp(W::StridedMatrix, A::StridedMatrix, Z::StridedMatrix, α)
139144
A2[1:p, (p + 1):end] .= (-α) .* (R')
140145
A2[(p + 1):end, (p + 1):end] .= 0
141146
U = [W Q] * exp(A2)
142-
U = _polarnewton!(U)
147+
U = _left_polar!(U, PolarNewton())
143148
W′ = U[:, 1:p]
144149
Q′ = U[:, (p + 1):end]
145150
R′ = R
@@ -152,7 +157,7 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
152157
r = min(2 * p, n)
153158
P = Wold' * Wnew
154159
dW = Wnew - Wold * P
155-
QQ, _ = qr!([Wold dW])
160+
QQ, _ = LinearAlgebra.qr!([Wold dW])
156161
Q = similar(Wold, n, r - p)
157162
@inbounds for j in Base.OneTo(r - p)
158163
for i in Base.OneTo(n)
@@ -161,23 +166,17 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
161166
end
162167
Q = lmul!(QQ, Q)
163168
R = Q' * dW
164-
Wext = [Wold Q]
165-
F = qr!([P; R])
166-
U = lmul!(F.Q, TensorKit.MatrixAlgebra.one!(similar(P, r, r)))
169+
F = LinearAlgebra.qr!([P; R])
170+
U = lmul!(F.Q, MatrixAlgebraKit.one!(similar(P, r, r)))
167171
U[1:p, 1:p] .= P
168172
U[(p + 1):r, 1:p] .= R
169173
X = view(U, 1:p, (p + 1):r)
170174
Y = view(U, (p + 1):r, (p + 1):r)
171175
if p < n
172-
YSVD = svd!(Y)
173-
mul!(X, X * (YSVD.V), (YSVD.U)')
174-
UsqrtS = YSVD.U
175-
@inbounds for j in 1:size(UsqrtS, 2)
176-
s = sqrt(YSVD.S[j])
177-
@simd for i in 1:size(UsqrtS, 1)
178-
UsqrtS[i, j] *= s
179-
end
180-
end
176+
USVᴴ = svd_compact!(Y)
177+
mul!(X, X * USVᴴ[3]', USVᴴ[1]')
178+
diagview(USVᴴ[2]) .= sqrt.(diagview(USVᴴ[2]))
179+
UsqrtS = rmul!(USVᴴ[1], USVᴴ[2])
181180
mul!(Y, UsqrtS, UsqrtS')
182181
end
183182
logU = _projectantihermitian!(log(U))

src/grassmann.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function Base.getproperty(Δ::GrassmannTangent, sym::Symbol)
6060
elseif sym (:U, :S, :V)
6161
v = Base.getfield(Δ, sym)
6262
v !== nothing && return v
63-
U, S, V, = tsvd.Z; alg=default_svd_alg.Z))
63+
U, S, V, = svd_compact.Z)
6464
Base.setfield!(Δ, :U, U)
6565
Base.setfield!(Δ, :S, S)
6666
Base.setfield!(Δ, :V, V)
@@ -198,7 +198,7 @@ function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothin
198198
space(Wold) == space(Wnew) || throw(SpaceMismatch())
199199
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
200200
Wneworth = Wnew - Wold * WodWn
201-
Vd, cS, VY = tsvd!(WodWn; alg=default_svd_alg(WodWn))
201+
Vd, cS, VY = svd_compact!(WodWn)
202202
Scmplx = acos(cS)
203203
# acos always returns a complex TensorMap. We cast back to real if possible.
204204
S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx

src/unitary.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using TensorKit
77
import TensorKit: similarstoragetype, SectorDict
88
using ..TensorKitManifolds: projectantihermitian!, projectisometric!, PolarNewton
99
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!
10+
import MatrixAlgebraKit as MAK
1011

1112
struct UnitaryTangent{T<:AbstractTensorMap,TA<:AbstractTensorMap}
1213
W::T
@@ -82,10 +83,10 @@ end
8283
project(X, W; metric=:euclidean) = project!(copy(X), W; metric=:euclidean)
8384

8485
# geodesic retraction, coincides with Stiefel retraction (which is not geodesic for p < n)
85-
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=nothing)
86+
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=MAK.select_algorithm(left_polar!, W))
8687
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
8788
E = exp* Δ.A)
88-
W′ = projectisometric!(W * E; alg=SDD())
89+
W′ = projectisometric!(W * E; alg)
8990
A′ = Δ.A
9091
return W′, UnitaryTangent(W′, A′)
9192
end
@@ -104,7 +105,7 @@ function transport!(Θ::UnitaryTangent, W::AbstractTensorMap, Δ::UnitaryTangent
104105
end
105106
function transport::UnitaryTangent, W::AbstractTensorMap, Δ::UnitaryTangent, α::Real, W′;
106107
alg=:stiefel)
107-
return transport!(copy(Θ), W, Δ, α, W′; alg=alg)
108+
return transport!(copy(Θ), W, Δ, α, W′; alg)
108109
end
109110

110111
# transport_parallel correspondings to the torsion-free Levi-Civita connection

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ const α = 0.75
88

99
@testset "Grassmann with space $V" for V in spaces
1010
for T in (Float64,)
11-
W, = leftorth(randn(T, V * V * V, V * V); alg=Polar())
11+
W, = left_polar(randn(T, V * V * V, V * V))
1212
X = randn(T, space(W))
1313
Y = randn(T, space(W))
1414
Δ = @inferred Grassmann.project(X, W)
@@ -124,7 +124,7 @@ end
124124

125125
@testset "Unitary with space $V" for V in spaces
126126
for T in (Float64, ComplexF64)
127-
W, = leftorth(randn(T, V * V * V, V * V); alg=Polar())
127+
W, = left_polar(randn(T, V * V * V, V * V))
128128
X = randn(T, space(W))
129129
Y = randn(T, space(W))
130130
Δ = @inferred Unitary.project(X, W)

0 commit comments

Comments
 (0)