Changes for non-CPU array support #375

Open

kshyatt wants to merge 1 commit into main from ksh/cu

Conversation

@kshyatt (Member) commented Feb 6, 2026

Mostly this allows passing an array type instead of just an element type, so that GPU arrays can back MPS tensors.
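To illustrate the intent: where constructors previously took only a scalar element type, they would now also accept a concrete array type that selects the backing storage. A hedged sketch (the array-type-accepting call is what this PR presumably enables; the exact signature is an assumption, not confirmed by this diff):

```julia
using TensorKit, MPSKit, CUDA

# Before: only an element type, CPU storage implied.
ψ = FiniteMPS(rand, ComplexF64, 10, ℂ^2, ℂ^4)

# After (assumed signature): an array type selects GPU-backed storage.
ψ_gpu = FiniteMPS(rand, CuVector{ComplexF64}, 10, ℂ^2, ℂ^4)
```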

kshyatt requested a review from lkdvos on February 6, 2026 at 10:32
        envs::FiniteEnvironments = environments(ψ, O)
    )
-   ens = zeros(scalartype(ψ), length(ψ))
+   ens = zeros(storagetype(ψ), length(ψ))
Member:

This looks a bit strange to me: would we not expect this to be on the CPU anyway, since it originates from a call to dot, which produces a scalar?

If you prefer, I am also happy with writing this entire expression as a single call to sum, to avoid the intermediate allocation altogether.
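A minimal sketch of that sum-based alternative; the loop body, the `∂∂AC` effective-operator call, and the assumption that the enclosing function ultimately sums the per-site contributions are all guesses about the surrounding code, not part of this diff:

```julia
# Hypothetical: replaces the `ens = zeros(...)` buffer plus a fill loop
# with a single reduction; `dot` returns a plain (CPU) scalar either way.
E = sum(1:length(ψ)) do i
    dot(ψ.AC[i], ∂∂AC(i, ψ, O, envs) * ψ.AC[i])
end
```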

        TT = blocktensormaptype(spacetype(bra), numout(V), numin(V), TA)
    else
-       TT = TensorMap{T}
+       TT = TensorKit.TensorMapWithStorage{T, TA}
Member:

Suggested change:
-   TT = TensorKit.TensorMapWithStorage{T, TA}
+   TT = TensorKit.tensormaptype(spacetype(bra), numout(V), numin(V), TA)

# ------------------
function allocate_GL(bra::AbstractMPS, mpo::AbstractMPO, ket::AbstractMPS, i::Int)
    T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
+   TA = similarstoragetype(storagetype(mpo), T)
Member:

I know this is already a strict improvement, but to make this work in full generality, I don't think only using the storagetype of the mpo is sufficient. I think TensorKit has a promote_storagetype function that can be used to also take the storagetype of the states into account.

Member Author:

Sure, I can modify this to take the backing types of bra and ket into account too.
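A hedged sketch of what that could look like; the existence and pairwise signature of `TensorKit.promote_storagetype` are the reviewer's recollection plus an assumption here, so treat this as pseudocode:

```julia
T = Base.promote_type(scalartype(bra), scalartype(mpo), scalartype(ket))
# Assumed API: promote_storagetype combines two storage types at a time.
TA = reduce(TensorKit.promote_storagetype,
    (storagetype(bra), storagetype(mpo), storagetype(ket)))
TA = similarstoragetype(TA, T)  # re-parametrize to the promoted scalar type
```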

)
    # left to middle
-   U = ones(scalartype(H), left_virtualspace(H, 1))
+   U = ones(storagetype(H), left_virtualspace(H, 1))
Member:

I think here we should actually call removeunit instead of writing this as a contraction. This is both more efficient and avoids having to deal with storagetype altogether.
Let me know if you need help with this though
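A hedged sketch of the removeunit route; which tensor carries the trivial leg and its index are assumptions, since the contraction being replaced is not shown in this hunk:

```julia
# Instead of building U = ones(...) and contracting over the trivial left
# virtual leg, drop that unit leg in one call (leg index 1 is an assumption):
t′ = removeunit(t, 1)
```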

MultilineMPO(mpos::AbstractVector{<:AbstractMPO}) = Multiline(mpos)
MultilineMPO(t::MPOTensor) = MultilineMPO(PeriodicMatrix(fill(t, 1, 1)))

+ TensorKit.storagetype(M::MultilineMPO) = storagetype(M.data)
Member:

Does this work? I would expect you would need something like storagetype(eltype(M)); I don't think we defined storagetype(::AbstractVector) anywhere?

Member Author:

It does seem to work. Using storagetype all over makes things a lot more flexible, as we can call it more or less agnostically. We should have some function that allows us to figure out "what is your actual array type".

    ::UndefInitializer, ::Type{TorA}, P::Union{S, CompositeSpace{S}}, Vₗ::S, Vᵣ::S = Vₗ
) where {S <: ElementarySpace, TorA}
-   return TensorMap{eltype}(undef, Vₗ ⊗ P ← Vᵣ)
+   return TensorKit.TensorMapWithStorage{TorA}(undef, Vₗ ⊗ P ← Vᵣ)
Member:

Suggestion to also use TensorKit.tensormaptype here instead.

That being said, this constructor is the worst type piracy I have ever committed, and I would be happy to not have to feel the shame of its existence anymore...
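For concreteness, a hedged sketch of that alternative inside this constructor, reusing the tensormaptype pattern from the earlier suggestion (the HomSpace round-trip via numout/numin is an illustration, not code from this PR):

```julia
V = Vₗ ⊗ P ← Vᵣ
TT = TensorKit.tensormaptype(S, numout(V), numin(V), TorA)
return TT(undef, V)
```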

Comment on lines +81 to 85
function MPSTensor(A::AA) where {T <: Number, AA <: AbstractArray{T}}
    @assert ndims(A) > 2 "MPSTensor should have at least 3 dims, but has $ndims(A)"
    sz = size(A)
-   t = TensorMap(undef, T, foldl(⊗, ComplexSpace.(sz[1:(end - 1)])) ← ℂ^sz[end])
+   t = TensorKit.TensorMapWithStorage{T, AA}(undef, foldl(⊗, ComplexSpace.(sz[1:(end - 1)])) ← ℂ^sz[end])
    t[] .= A
Member:

Suggested change:
-   function MPSTensor(A::AA) where {T <: Number, AA <: AbstractArray{T}}
-       @assert ndims(A) > 2 "MPSTensor should have at least 3 dims, but has $ndims(A)"
-       sz = size(A)
-       t = TensorKit.TensorMapWithStorage{T, AA}(undef, foldl(⊗, ComplexSpace.(sz[1:(end - 1)])) ← ℂ^sz[end])
-       t[] .= A
+   function MPSTensor(A::AbstractArray{<:Number})
+       @assert ndims(A) > 2 "MPSTensor should have at least 3 dims, but has $ndims(A)"
+       sz = size(A)
+       V = foldl(⊗, ComplexSpace.(sz[1:(end - 1)])) ← ℂ^sz[end]
+       t = TensorMap(A, V)

I think this might be slightly cleaner, simply making use of TensorKit functionality?
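A quick usage illustration of the suggested constructor; that TensorMap(A, V) preserves A's storage (so a CuArray stays on the GPU) is an assumption here, as is the CUDA.jl example itself:

```julia
using CUDA

A = CUDA.rand(ComplexF64, 4, 2, 4)  # (left virtual, physical, right virtual)
t = MPSTensor(A)                    # spaces become ℂ^4 ⊗ ℂ^2 ← ℂ^4
```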

Comment on lines +41 to +42
+ TensorKit.storagetype(PA::PeriodicArray{T, N}) where {T, N} = storagetype(T)
+
Member:

I'm a little hesitant to adopt storagetype in such a general manner here in MPSKit, since I don't think it is even exported by TensorKit. If this function should exist, I think we would have to decide on storagetype(A::AbstractArray) = storagetype(eltype(A)), but I'm not sure that really is in scope for that (somewhat internal) function.
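Spelled out, the generic fallback under discussion would look like this (hypothetical; not currently defined in TensorKit, which is exactly the scope question):

```julia
TensorKit.storagetype(A::AbstractArray) = storagetype(eltype(A))
TensorKit.storagetype(::Type{<:AbstractArray{T}}) where {T} = storagetype(T)
# Either method would subsume the PeriodicArray-specific definition above.
```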
