Skip to content

Commit 93e8d56

Browse files
authored
Merge branch 'master' into feature/compression-v2
2 parents a3454e9 + 20ec49d commit 93e8d56

35 files changed

+635
-380
lines changed

.github/workflows/CI.yml

+47-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
include-all-prereleases: true
4949
- uses: julia-actions/cache@v1
5050
with:
51-
cache-name: Unit Tests CI - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
51+
cache-name: CI / ${{ matrix.test_group }} / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
5252
- name: Add Julia registries
5353
run: |
5454
using Pkg
@@ -63,6 +63,52 @@ jobs:
6363
- uses: codecov/codecov-action@v3
6464
with:
6565
files: lcov.info
66+
python:
67+
name: ${{ matrix.test_group }} / Julia ${{ matrix.version }}
68+
runs-on: ubuntu-latest
69+
strategy:
70+
fail-fast: false
71+
matrix:
72+
version:
73+
- '1'
74+
test_group:
75+
- python
76+
os:
77+
- ubuntu-latest
78+
arch:
79+
- x64
80+
steps:
81+
- uses: actions/checkout@v4
82+
- uses: julia-actions/setup-julia@v1
83+
with:
84+
version: ${{ matrix.version }}
85+
arch: ${{ matrix.arch }}
86+
include-all-prereleases: true
87+
- uses: julia-actions/cache@v1
88+
with:
89+
cache-name: CI / ${{ matrix.test_group }} / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
90+
- name: Set and export registry flavor preference
91+
run: echo "JULIA_PKG_SERVER_REGISTRY_PREFERENCE=${JULIA_PKG_SERVER_REGISTRY_PREFERENCE:-eager}" >> ${GITHUB_ENV}
92+
shell: bash
93+
- name: Add Julia registries
94+
run: |
95+
using Pkg
96+
pkg"registry add https://github.com/bsc-quantic/Registry.git"
97+
pkg"registry add General"
98+
shell: julia --color=yes {0}
99+
- name: Run tests
100+
run: |
101+
julia --color=yes --code-coverage=user --depwarn=yes --project=test/python/ -e '
102+
using Pkg
103+
Pkg.instantiate()
104+
Pkg.resolve()
105+
include("test/python/runtests.jl")'
106+
shell: bash
107+
- uses: julia-actions/julia-buildpkg@v1
108+
- uses: julia-actions/julia-processcoverage@v1
109+
- uses: codecov/codecov-action@v3
110+
with:
111+
files: lcov.info
66112
docs:
67113
name: Documentation
68114
runs-on: ubuntu-latest

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,6 @@ dist
420420
.julia
421421
*.excalidraw
422422
archive/
423-
test/.CondaPkg/
423+
**/*/.CondaPkg/
424424
.CondaPkg/
425425
CondaPkg.toml

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Tenet"
22
uuid = "85d41934-b9cd-44e1-8730-56d86f15f3ec"
33
authors = ["Sergio Sánchez Ramírez <[email protected]>"]
4-
version = "0.8.0"
4+
version = "0.8.3"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -71,15 +71,15 @@ GraphMakie = "0.4,0.5"
7171
Graphs = "1.7"
7272
ITensorMPS = "0.2, 0.3"
7373
ITensorNetworks = "0.11"
74-
ITensors = "0.6, 0.7"
74+
ITensors = "0.6, 0.7, 0.8"
7575
KrylovKit = "0.7, 0.8, 0.9"
7676
LinearAlgebra = "1.10"
7777
Makie = "0.18,0.19,0.20, 0.21, 0.22"
7878
OMEinsum = "0.7, 0.8"
7979
PythonCall = "0.9"
8080
Quac = "0.3"
8181
Random = "1.10"
82-
Reactant = "0.2.18"
82+
Reactant = "0.2.22"
8383
ScopedValues = "1"
8484
Serialization = "1.10"
8585
SparseArrays = "1.10"

ext/TenetITensorMPSExt.jl

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module TenetITensorMPSExt
22

33
using Tenet
4-
using Tenet: Tenet, MPS, tensors, form, inds, lanes, id, Site, Lane
5-
using ITensors
6-
using ITensorMPS
7-
using ITensors: ITensor, Index, dim
4+
using Tenet: Tenet, MPS, tensors, form, inds, lanes, Site, Lane
5+
using ITensors: ITensors, ITensor, Index, dim, siteinds
6+
using ITensorMPS: ITensorMPS, linkinds
87

98
# Convert an AbstractMPS to an ITensor MPS
109
function Base.convert(::Type{ITensorMPS.MPS}, mps::Tenet.AbstractMPS)
@@ -76,15 +75,15 @@ function Base.convert(::Type{MPS}, itensors_mps::ITensorMPS.MPS)
7675
links = linkinds(itensors_mps)
7776

7877
tensors_vec = []
79-
first_ten = array(itensors_mps[1], sites[1], links[1])
78+
first_ten = ITensors.array(itensors_mps[1], sites[1], links[1])
8079
push!(tensors_vec, first_ten)
8180

8281
# Extract the bulk tensors
8382
for j in 2:(length(itensors_mps) - 1)
84-
ten = array(itensors_mps[j], sites[j], links[j - 1], links[j]) # Indices are ordered as (site index, left link, right link)
83+
ten = ITensors.array(itensors_mps[j], sites[j], links[j - 1], links[j]) # Indices are ordered as (site index, left link, right link)
8584
push!(tensors_vec, ten)
8685
end
87-
last_ten = array(itensors_mps[end], sites[end], links[end])
86+
last_ten = ITensors.array(itensors_mps[end], sites[end], links[end])
8887
push!(tensors_vec, last_ten)
8988

9089
mps = Tenet.MPS(tensors_vec)

ext/TenetReactantExt.jl

+36-28
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,41 @@ const stablehlo = MLIR.Dialects.stablehlo
99

1010
const Enzyme = Reactant.Enzyme
1111

12-
@static if isdefined(Reactant, :traced_type_inner)
13-
# we specify `mode` and `track_numbers` types due to ambiguity
14-
Base.@nospecializeinfer function Reactant.traced_type_inner(
15-
@nospecialize(TT::Type{<:Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
16-
)
17-
A_traced = Reactant.traced_type_inner(Tenet.parenttype(TT), seen, mode, track_numbers)
18-
T = eltype(A_traced)
19-
N = ndims(TT)
20-
return Tensor{T,N,A_traced}
21-
end
12+
# we specify `mode` and `track_numbers` types due to ambiguity
13+
Base.@nospecializeinfer function Reactant.traced_type_inner(
14+
@nospecialize(_::Type{Tensor}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
15+
)
16+
return Tensor
17+
end
2218

23-
# we specify `mode` and `track_numbers` types due to ambiguity
24-
Base.@nospecializeinfer function Reactant.traced_type_inner(
25-
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
26-
seen,
27-
mode::Reactant.TraceMode,
28-
@nospecialize(track_numbers::Type)
29-
)
30-
return T
31-
end
19+
Base.@nospecializeinfer function Reactant.traced_type_inner(
20+
@nospecialize(_::Type{Tensor{T}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
21+
) where {T}
22+
return Tensor{TracedRNumber{T}}
23+
end
24+
25+
Base.@nospecializeinfer function Reactant.traced_type_inner(
26+
@nospecialize(_::Type{Tensor{T,N}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
27+
) where {T,N}
28+
return Tensor{TracedRNumber{T,N}}
29+
end
30+
31+
Base.@nospecializeinfer function Reactant.traced_type_inner(
32+
@nospecialize(_::Type{Tensor{T,N,A}}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
33+
) where {T,N,A}
34+
A_traced = Reactant.traced_type_inner(A, seen, mode, track_numbers)
35+
T_traced = eltype(A_traced)
36+
return Tensor{T_traced,N,A_traced}
37+
end
38+
39+
# we specify `mode` and `track_numbers` types due to ambiguity
40+
Base.@nospecializeinfer function Reactant.traced_type_inner(
41+
@nospecialize(T::Type{<:Tenet.AbstractTensorNetwork}),
42+
seen,
43+
mode::Reactant.TraceMode,
44+
@nospecialize(track_numbers::Type)
45+
)
46+
return T
3247
end
3348

3449
function Reactant.make_tracer(seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...) where {RT<:Tensor}
@@ -184,8 +199,8 @@ Base.@nospecializeinfer @noinline function Tenet.contract(
184199

185200
# TODO replace for `Ops.convert`/`adapt` when it's available (there can be problems with nested array structures)
186201
T = Base.promote_eltype(a, b)
187-
da = eltype(a) != T ? TracedRArray{T,ndims(a)}(parent(a)) : parent(a)
188-
db = eltype(b) != T ? TracedRArray{T,ndims(b)}(parent(b)) : parent(b)
202+
da = eltype(a) != T ? TracedRArray{Reactant.unwrapped_eltype(T),ndims(a)}(parent(a)) : parent(a)
203+
db = eltype(b) != T ? TracedRArray{Reactant.unwrapped_eltype(T),ndims(b)}(parent(b)) : parent(b)
189204

190205
data = Reactant.Ops.dot_general(da, db; contracting_dimensions, batching_dimensions)
191206

@@ -221,11 +236,4 @@ function Base.conj(@nospecialize(x::Tensor{TracedRNumber{T},N,<:TracedRArray}))
221236
Tensor(conj(parent(x)), inds(x))
222237
end
223238

224-
# fix infinite recursion on Reactant rewrite of invoke/call step
225-
@reactant_overlay @noinline function Base.replace!(
226-
tn::Tenet.AbstractQuantum, old_new::Base.AbstractVecOrTuple{Pair{Symbol,Symbol}}
227-
)
228-
Base.inferencebarrier(Base.replace!)(tn, old_new)
229-
end
230-
231239
end

src/Ansatz.jl

+65-16
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,12 @@ function inds(kwargs::NamedTuple{(:bond,)}, tn::AbstractAnsatz)
165165
return only(inds(tensor1) inds(tensor2))
166166
end
167167

168-
# TODO fix this properly when we do the mapping
169-
tensors(kwargs::NamedTuple{(:at,),Tuple{L}}, tn::AbstractAnsatz) where {L<:Lane} = tensors(tn; at=Site(kwargs.at))
168+
# TODO fix this properly when we do the mapping
169+
function tensors(kwargs::NamedTuple{(:at,),Tuple{L}}, tn::AbstractAnsatz) where {L<:Lane}
170+
hassite(tn, Site(kwargs.at)) && return tensors(tn; at=Site(kwargs.at))
171+
hassite(tn, Site(kwargs.at; dual=true)) && return tensors(tn; at=Site(kwargs.at; dual=true))
172+
throw(ArgumentError("Lane $kwargs.at not found"))
173+
end
170174

171175
"""
172176
tensors(tn::AbstractAnsatz; bond)
@@ -375,7 +379,8 @@ Compute the expectation value of an observable on a [`AbstractAnsatz`](@ref) Ten
375379
function expect::AbstractAnsatz, observable; bra=adjoint(ψ))
376380
@assert socket(ψ) == State() "ψ must be a state"
377381
@assert socket(bra) == State(; dual=true) "bra must be a dual state"
378-
contract(merge(ψ, observable, bra))
382+
383+
return expect(form(ψ), ψ, observable; bra)
379384
end
380385

381386
function expect::AbstractAnsatz, observables::AbstractVecOrTuple; bra=adjoint(ψ))
@@ -384,6 +389,50 @@ function expect(ψ::AbstractAnsatz, observables::AbstractVecOrTuple; bra=adjoint
384389
end
385390
end
386391

392+
function expect(::NonCanonical, ψ::AbstractAnsatz, observable; bra=adjoint(ψ))
393+
return contract(merge(ψ, observable, bra))
394+
end
395+
396+
# TODO: Try to find a better way to do this
397+
function expect(::MixedCanonical, ψ::AbstractAnsatz, observable; bra=adjoint(ψ))
398+
return contract(merge(ψ, observable, bra))
399+
end
400+
401+
function expect(::Canonical, ψ::Tenet.AbstractAnsatz, observable; bra=adjoint(ψ))
402+
obs_sites = unique(id.(sites(observable)))
403+
404+
ket_Λ = []
405+
bra_Λ = []
406+
ket_tensors = []
407+
bra_tensors = []
408+
for i in obs_sites
409+
replace!(observable, inds(observable; at=Site(i)) => Symbol(:input, i))
410+
replace!(observable, inds(observable; at=Site(i; dual=true)) => Symbol(:output, i))
411+
replace!(ψ, inds(ψ; at=Site(i)) => Symbol(:input, i))
412+
replace!(bra, inds(bra; at=Site(i; dual=true)) => Symbol(:output, i))
413+
414+
replace!(bra, inds(bra; bond=(Lane(i), Lane(i + 1))) => inds(ψ; bond=(Lane(i), Lane(i + 1))))
415+
replace!(bra, inds(bra; bond=(Lane(i - 1), Lane(i))) => inds(ψ; bond=(Lane(i - 1), Lane(i))))
416+
417+
push!(ket_Λ, tensors(ψ; bond=(Lane(i - 1), Lane(i))))
418+
push!(bra_Λ, tensors(bra; bond=(Lane(i - 1), Lane(i))))
419+
420+
push!(ket_tensors, tensors(ψ; at=Site(i)))
421+
push!(bra_tensors, tensors(bra; at=Site(i; dual=true)))
422+
end
423+
424+
push!(ket_Λ, tensors(ψ; bond=(Lane(obs_sites[end]), Lane(obs_sites[end] + 1))))
425+
push!(bra_Λ, tensors(bra; bond=(Lane(obs_sites[end]), Lane(obs_sites[end] + 1))))
426+
427+
t = contract(
428+
contract(ket_Λ..., ket_tensors...; dims=[]),
429+
contract(bra_Λ..., bra_tensors...; dims=[]),
430+
tensors(Quantum(observable))[1],
431+
)
432+
433+
return t
434+
end
435+
387436
"""
388437
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, normalize = false)
389438
@@ -485,10 +534,10 @@ function simple_update_2site!(::NonCanonical, ψ::AbstractAnsatz, gate; kwargs..
485534
gate = copy(gate)
486535

487536
# contract involved sites
488-
bond = (sitel, siter) = extrema(lanes(gate))
537+
bond = (lanel, laner) = extrema(lanes(gate))
489538
vind = inds(ψ; bond)
490-
linds = filter(!=(vind), inds(tensors(ψ; at=sitel)))
491-
rinds = filter(!=(vind), inds(tensors(ψ; at=siter)))
539+
linds = filter(!=(vind), inds(tensors(ψ; at=lanel)))
540+
rinds = filter(!=(vind), inds(tensors(ψ; at=laner)))
492541
contract!(ψ; bond)
493542

494543
# TODO replace for `merge!` when #243 is fixed
@@ -519,20 +568,20 @@ end
519568
# TODO remove `normalize` argument?
520569
function simple_update_2site!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, normalize=false, canonize=true)
521570
# Contract the exterior Λ tensors
522-
sitel, siter = extrema(lanes(gate))
523-
(0 < id(sitel) < nlanes(ψ) || 0 < id(siter) < nlanes(ψ)) ||
571+
lanel, laner = extrema(lanes(gate))
572+
(0 < id(lanel) < nlanes(ψ) || 0 < id(laner) < nlanes(ψ)) ||
524573
throw(ArgumentError("The sites in the bond must be between 1 and $(nlanes(ψ))"))
525574

526-
Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Lane(id(sitel) - 1), sitel))
527-
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Lane(id(siter) + 1)))
575+
Λᵢ₋₁ = id(lanel) == 1 ? nothing : tensors(ψ; bond=(Lane(id(lanel) - 1), lanel))
576+
Λᵢ₊₁ = id(lanel) == nsites(ψ) - 1 ? nothing : tensors(ψ; bond=(laner, Lane(id(laner) + 1)))
528577

529-
!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Lane(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
530-
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Lane(id(siter) + 1)), direction=:left, delete_Λ=false)
578+
!isnothing(Λᵢ₋₁) && absorb!(ψ; bond=(Lane(id(lanel) - 1), lanel), dir=:right, delete_Λ=false)
579+
!isnothing(Λᵢ₊₁) && absorb!(ψ; bond=(laner, Lane(id(laner) + 1)), dir=:left, delete_Λ=false)
531580

532581
simple_update_2site!(NonCanonical(), ψ, gate; threshold, maxdim, normalize=false, canonize=false)
533582

534583
# contract the updated tensors with the inverse of Λᵢ and Λᵢ₊₂, to get the new Γ tensors
535-
U, Vt = tensors(ψ; at=sitel), tensors(ψ; at=siter)
584+
U, Vt = tensors(ψ; at=lanel), tensors(ψ; at=laner)
536585
Γᵢ₋₁ = if isnothing(Λᵢ₋₁)
537586
U
538587
else
@@ -545,13 +594,13 @@ function simple_update_2site!(::Canonical, ψ::AbstractAnsatz, gate; threshold,
545594
end
546595

547596
# Update the tensors in the tensor network
548-
replace!(ψ, tensors(ψ; at=sitel) => Γᵢ₋₁)
549-
replace!(ψ, tensors(ψ; at=siter) => Γᵢ)
597+
replace!(ψ, tensors(ψ; at=lanel) => Γᵢ₋₁)
598+
replace!(ψ, tensors(ψ; at=laner) => Γᵢ)
550599

551600
if canonize
552601
canonize!(ψ; normalize)
553602
else
554-
normalize && normalize!(ψ, collect((sitel, siter)))
603+
normalize && normalize!(ψ, collect((lanel, laner)))
555604
end
556605

557606
return ψ

src/Lattice.jl

+18-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,24 @@ Graphs.dst(edge::Bond) = edge.dst
1414
Graphs.reverse(edge::Bond) = Bond(Graphs.dst(edge), Graphs.src(edge))
1515
Base.show(io::IO, edge::Bond) = write(io, "Bond: $(Graphs.src(edge)) - $(Graphs.dst(edge))")
1616

17-
Pair(e::Bond) = src(e) => dst(e)
18-
Tuple(e::Bond) = (src(e), dst(e))
17+
Pair(e::Bond) = Graphs.src(e) => Graphs.dst(e)
18+
Tuple(e::Bond) = (Graphs.src(e), Graphs.dst(e))
19+
20+
function Base.iterate(bond::Bond, state=0)
21+
if state == 0
22+
(Graphs.src(bond), 1)
23+
elseif state == 1
24+
(Graphs.dst(bond), 2)
25+
else
26+
nothing
27+
end
28+
end
29+
30+
Base.IteratorSize(::Type{Bond}) = Base.HasLength()
31+
Base.length(::Bond) = 2
32+
Base.IteratorEltype(::Type{Bond{L}}) where {L} = Base.HasEltype()
33+
Base.eltype(::Bond{L}) where {L} = L
34+
Base.isdone(::Bond, state) = state == 2
1935

2036
"""
2137
Lattice

0 commit comments

Comments
 (0)