Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TensorNetwork serialization #213

Merged
merged 3 commits into from
Oct 1, 2024
Merged

Fix TensorNetwork serialization #213

merged 3 commits into from
Oct 1, 2024

Conversation

jofrevalles
Copy link
Member

This PR fixes the serialization problem of TensorNetworks. Since this function was not tested until now, I don't know how long it has not been working. This PR also adds a testset to cover this function.

In the following code snippets, I show the serialization problem with an MPS, but the same problem holds for any TensorNetwork.

Current problem

The deserialize function does not work:

julia> using Tenet; using Serialization

julia> mps = rand(Chain, Open, State; n=10, χ=10)
MPS (inputs=0, outputs=10)

julia> buffer = IOBuffer()
IOBuffer(data=UInt8[...], readable=true, writable=true, seekable=true, append=false, size=0, maxsize=Inf, ptr=1, mark=-1)

julia> serialize(buffer, mps)

julia> seekstart(buffer)
IOBuffer(data=UInt8[...], readable=true, writable=true, seekable=true, append=false, size=7794, maxsize=Inf, ptr=1, mark=-1)

julia> content = read(buffer)
7794-element Vector{UInt8}:
...

julia> read_buffer  = IOBuffer(content)
IOBuffer(data=UInt8[...], readable=true, writable=false, seekable=true, append=false, size=7794, maxsize=Inf, ptr=1, mark=-1)

julia> mps2 = deserialize(read_buffer)
ERROR: MethodError: no method matching deserialize(::Serializer{IOBuffer}, ::Vector{Tensor})

Closest candidates are:
  deserialize(::AbstractSerializer, ::Type{T}) where T<:Base.GenericCondition
   @ Serialization ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/Serialization/src/Serialization.jl:1592
  deserialize(::AbstractSerializer, ::Type{T}) where T<:Distributed.WorkerPool
   @ Distributed ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/workerpool.jl:69
  deserialize(::AbstractSerializer, ::Type{T}) where T<:Base.AbstractLock
   @ Serialization ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/Serialization/src/Serialization.jl:1580
  ...

Stacktrace:
...

With this PR

This PR solves the problem:

julia> using Tenet; using Serialization

julia> mps = rand(Chain, Open, State; n=10, χ=10)
MPS (inputs=0, outputs=10)

julia> buffer = IOBuffer()
IOBuffer(data=UInt8[...], readable=true, writable=true, seekable=true, append=false, size=0, maxsize=Inf, ptr=1, mark=-1)

julia> serialize(buffer, mps)

julia> seekstart(buffer)
IOBuffer(data=UInt8[...], readable=true, writable=true, seekable=true, append=false, size=7836, maxsize=Inf, ptr=1, mark=-1)

julia> content = read(buffer)
7836-element Vector{UInt8}:
...

julia> read_buffer  = IOBuffer(content)
IOBuffer(data=UInt8[...], readable=true, writable=false, seekable=true, append=false, size=7836, maxsize=Inf, ptr=1, mark=-1)

julia> mps2 = deserialize(read_buffer)
MPS (inputs=0, outputs=10)

julia> mps == mps2
true

@jofrevalles jofrevalles requested a review from mofeing October 1, 2024 11:03
@@ -728,6 +728,7 @@ end

function Serialization.serialize(s::AbstractSerializer, obj::TensorNetwork)
Serialization.writetag(s.io, Serialization.OBJECT_TAG)
serialize(s, TensorNetwork)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm I'm trying to understand this line. Is this needed in order to tell Julia that the object to deserialize is a TensorNetwork right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly, this is what I have seen in similar posts (for example).

Copy link
Member

@mofeing mofeing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@mofeing mofeing added the backport Fix must be back ported label Oct 1, 2024
@mofeing mofeing merged commit 5825ffc into master Oct 1, 2024
6 checks passed
@mofeing mofeing deleted the fix/serialization branch October 1, 2024 17:11
jofrevalles added a commit that referenced this pull request Oct 2, 2024
* Fix TensorNetwork serialization

* Add Serialization cover tests

* Format test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backport Fix must be back ported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants