Skip to content

Commit 44c236a

Browse files
authored
Merge branch 'master' into feature/normalize-for-canonical
2 parents e67918c + 08dd131 commit 44c236a

File tree

4 files changed

+53
-3
lines changed

4 files changed

+53
-3
lines changed

src/Ansatz.jl

+47-2
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function inds(kwargs::NamedTuple{(:bond,)}, tn::AbstractAnsatz)
165165
return only(inds(tensor1) inds(tensor2))
166166
end
167167

168-
# TODO fix this properly when we do the mapping
168+
# TODO fix this properly when we do the mapping
169169
function tensors(kwargs::NamedTuple{(:at,),Tuple{L}}, tn::AbstractAnsatz) where {L<:Lane}
170170
hassite(tn, Site(kwargs.at)) && return tensors(tn; at=Site(kwargs.at))
171171
hassite(tn, Site(kwargs.at; dual=true)) && return tensors(tn; at=Site(kwargs.at; dual=true))
@@ -379,7 +379,8 @@ Compute the expectation value of an observable on a [`AbstractAnsatz`](@ref) Ten
379379
function expect::AbstractAnsatz, observable; bra=adjoint(ψ))
380380
@assert socket(ψ) == State() "ψ must be a state"
381381
@assert socket(bra) == State(; dual=true) "bra must be a dual state"
382-
contract(merge(ψ, observable, bra))
382+
383+
return expect(form(ψ), ψ, observable; bra)
383384
end
384385

385386
function expect::AbstractAnsatz, observables::AbstractVecOrTuple; bra=adjoint(ψ))
@@ -388,6 +389,50 @@ function expect(ψ::AbstractAnsatz, observables::AbstractVecOrTuple; bra=adjoint
388389
end
389390
end
390391

392+
function expect(::NonCanonical, ψ::AbstractAnsatz, observable; bra=adjoint(ψ))
393+
return contract(merge(ψ, observable, bra))
394+
end
395+
396+
# TODO: Try to find a better way to do this
397+
function expect(::MixedCanonical, ψ::AbstractAnsatz, observable; bra=adjoint(ψ))
398+
return contract(merge(ψ, observable, bra))
399+
end
400+
401+
function expect(::Canonical, ψ::Tenet.AbstractAnsatz, observable; bra=adjoint(ψ))
402+
obs_sites = unique(id.(sites(observable)))
403+
404+
ket_Λ = []
405+
bra_Λ = []
406+
ket_tensors = []
407+
bra_tensors = []
408+
for i in obs_sites
409+
replace!(observable, inds(observable; at=Site(i)) => Symbol(:input, i))
410+
replace!(observable, inds(observable; at=Site(i; dual=true)) => Symbol(:output, i))
411+
replace!(ψ, inds(ψ; at=Site(i)) => Symbol(:input, i))
412+
replace!(bra, inds(bra; at=Site(i; dual=true)) => Symbol(:output, i))
413+
414+
replace!(bra, inds(bra; bond=(Lane(i), Lane(i + 1))) => inds(ψ; bond=(Lane(i), Lane(i + 1))))
415+
replace!(bra, inds(bra; bond=(Lane(i - 1), Lane(i))) => inds(ψ; bond=(Lane(i - 1), Lane(i))))
416+
417+
push!(ket_Λ, tensors(ψ; bond=(Lane(i - 1), Lane(i))))
418+
push!(bra_Λ, tensors(bra; bond=(Lane(i - 1), Lane(i))))
419+
420+
push!(ket_tensors, tensors(ψ; at=Site(i)))
421+
push!(bra_tensors, tensors(bra; at=Site(i; dual=true)))
422+
end
423+
424+
push!(ket_Λ, tensors(ψ; bond=(Lane(obs_sites[end]), Lane(obs_sites[end] + 1))))
425+
push!(bra_Λ, tensors(bra; bond=(Lane(obs_sites[end]), Lane(obs_sites[end] + 1))))
426+
427+
t = contract(
428+
contract(ket_Λ..., ket_tensors...; dims=[]),
429+
contract(bra_Λ..., bra_tensors...; dims=[]),
430+
tensors(Quantum(observable))[1],
431+
)
432+
433+
return t
434+
end
435+
391436
"""
392437
evolve!(ψ::AbstractAnsatz, gate; threshold = nothing, maxdim = nothing, normalize = false)
393438

src/TensorNetwork.jl

+2
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,8 @@ function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor})
507507
tn = TensorNetwork(tn)
508508
old_tensor, new_tensor = pair
509509

510+
old_tensor tn || throw(ArgumentError("Old tensor not found in TensorNetwork"))
511+
510512
old_tensor === new_tensor && return tn
511513

512514
issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match"))

test/unit/MPS_test.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using LinearAlgebra
1414
@test nsites(ψ; set=:outputs) == 3
1515
@test issetequal(sites(ψ), [site"1", site"2", site"3"])
1616
@test boundary(ψ) == Open()
17-
@test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) == nothing
17+
@test inds(ψ; at=lane"1", dir=:left) == inds(ψ; at=lane"3", dir=:right) == nothing
1818
end
1919

2020
@testset "case 2" begin

test/unit/TensorNetwork_test.jl

+3
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ end
484484

485485
old_tensor = t_lm
486486

487+
# Test that it throws an error if the old tensor is not in the tensor network
488+
@test_throws ArgumentError replace!(tn, Tensor(ones(2, 2), (:i, :j)) => t_ij)
489+
487490
@test_throws ArgumentError begin
488491
new_tensor = Tensor(rand(2, 2), (:a, :b))
489492
replace!(tn, old_tensor => new_tensor)

0 commit comments

Comments
 (0)