Skip to content

Commit 20ec49d

Browse files
authored
Add vinds for internal use of inds without Tuple specialization (#335)
* Add `vinds` for internal use of `inds` without `Tuple` specialization * Fix type-inference in `dim(::Tensor, ::Symbol)` * Update `dim` test * Fix type-inference issues in `contract!`
1 parent 6753076 commit 20ec49d

File tree

4 files changed

+27
-25
lines changed

4 files changed

+27
-25
lines changed

src/Numerics.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,19 @@ function contract(a::Tensor, b::Tensor; kwargs...)
5757
end
5858

5959
function allocate_result(
60-
::typeof(contract), a::Tensor, b::Tensor; fillzero=false, dims=((inds(a), inds(b))), out=nothing
60+
::typeof(contract), a::Tensor, b::Tensor; fillzero=false, dims=((vinds(a), vinds(b))), out=nothing
6161
)
62-
ia = collect(inds(a))
63-
ib = collect(inds(b))
62+
ia = collect(vinds(a))
63+
ib = collect(vinds(b))
6464
i = (dims, ia, ib)
6565

6666
ic = if isnothing(out)
67-
Tuple(setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : (i,)))
67+
setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : [i])
6868
else
6969
out
7070
end
7171

72-
data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero)
72+
data = OMEinsum.get_output_array((parent(a), parent(b)), Int[size(i in ia ? a : b, i) for i in ic]; fillzero)
7373
return Tensor(data, ic)
7474
end
7575

@@ -88,12 +88,12 @@ function contract(a::Tensor; kwargs...)
8888
return contract!(c, a)
8989
end
9090

91-
function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=nonunique(inds(a)), out=nothing)
92-
ia = inds(a)
91+
function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=nonunique(vinds(a)), out=nothing)
92+
ia = vinds(a)
9393
i = (dims, ia)
9494

9595
ic::Vector{Symbol} = if isnothing(out)
96-
setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))
96+
setdiff(ia, i isa Base.AbstractVecOrTuple ? i : [i])
9797
else
9898
out
9999
end
@@ -114,11 +114,11 @@ contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs
114114
Perform a binary tensor contraction operation between `a` and `b` and store the result in `c`.
115115
"""
116116
function contract!(c::Tensor, a::Tensor, b::Tensor)
117-
ixs = (inds(a), inds(b))
118-
iy = inds(c)
117+
ixs = (vinds(a), vinds(b))
118+
iy = vinds(c)
119119
xs = (parent(a), parent(b))
120120
y = parent(c)
121-
size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...)
121+
size_dict = merge!(Dict{Symbol,Int}.([vinds(a) .=> size(a), vinds(b) .=> size(b)])...)
122122

123123
einsum!(ixs, iy, xs, y, true, false, size_dict)
124124
return c
@@ -130,9 +130,9 @@ end
130130
Perform a unary tensor contraction operation on `a` and store the result in `c`.
131131
"""
132132
function contract!(y::Tensor, x::Tensor)
133-
ixs = (inds(x),)
134-
iy = inds(y)
135-
size_dict = Dict{Symbol,Int}(inds(x) .=> size(x))
133+
ixs = (vinds(x),)
134+
iy = vinds(y)
135+
size_dict = Dict{Symbol,Int}(vinds(x) .=> size(x))
136136

137137
einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict)
138138
return y

src/Tensor.jl

+11-9
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ Return the indices of the tensor in the order of the dimensions.
4242
"""
4343
inds(t::Tensor) = Tuple(t.inds)
4444

45+
# WARN internal use only because it can mutate `Tensor`
46+
vinds(t::Tensor) = t.inds
47+
4548
function Base.copy(t::Tensor{T,N,<:SubArray{T,N}}) where {T,N}
4649
data = copy(t.data)
47-
inds = t.inds
48-
return Tensor(data, inds)
50+
return Tensor(data, vinds(t))
4951
end
5052

5153
"""
@@ -55,7 +57,7 @@ Return a uninitialize tensor of the same size, eltype and [`inds`](@ref) as `ten
5557
"""
5658
Base.similar(t::Tensor; inds=inds(t)) = Tensor(similar(parent(t)), inds)
5759
Base.similar(t::Tensor, S::Type; inds=inds(t)) = Tensor(similar(parent(t), S), inds)
58-
function Base.similar(t::Tensor{T,N}, S::Type, dims::Base.Dims{N}; inds=inds(t)) where {T,N}
60+
function Base.similar(t::Tensor{T,N}, S::Type, dims::Base.Dims{N}; inds=vinds(t)) where {T,N}
5961
return Tensor(similar(parent(t), S, dims), inds)
6062
end
6163
function Base.similar(t::Tensor, ::Type, dims::Base.Dims{N}; kwargs...) where {N}
@@ -71,7 +73,7 @@ end
7173
7274
Return a tensor of the same size, eltype and [`inds`](@ref) as `tensor` but filled with zeros.
7375
"""
74-
Base.zero(t::Tensor) = Tensor(zero(parent(t)), inds(t))
76+
Base.zero(t::Tensor) = Tensor(zero(parent(t)), vinds(t))
7577

7678
function __find_index_permutation(a, b)
7779
inds_b = collect(Union{Missing,Symbol}, b)
@@ -107,8 +109,8 @@ Base.isequal(a::Tensor{A,0}, b::Tensor{B,0}) where {A,B} = isequal(only(a), only
107109
Base.isapprox(a::AbstractArray, b::Tensor) = false
108110
Base.isapprox(a::Tensor, b::AbstractArray) = false
109111
function Base.isapprox(a::Tensor, b::Tensor; kwargs...)
110-
issetequal(inds(a), inds(b)) || return false
111-
perm = __find_index_permutation(inds(a), inds(b))
112+
issetequal(vinds(a), vinds(b)) || return false
113+
perm = __find_index_permutation(vinds(a), vinds(b))
112114
return all(eachindex(IndexCartesian(), a)) do i
113115
j = CartesianIndex(Tuple(permute!(collect(Tuple(i)), invperm(perm))))
114116
isapprox(a[i], b[j]; kwargs...)
@@ -150,7 +152,7 @@ parenttype(::T) where {T<:Tensor} = parenttype(T)
150152
Return the location of the dimension of `tensor` corresponding to the given index `i`.
151153
"""
152154
dim(::Tensor, i::Number) = i
153-
dim(t::Tensor, i::Symbol) = first(findall(==(i), inds(t)))
155+
dim(t::Tensor, i::Symbol) = findfirst(==(i), vinds(t))
154156

155157
# Iteration interface
156158
Base.IteratorSize(T::Type{Tensor}) = Iterators.IteratorSize(parenttype(T))
@@ -240,10 +242,10 @@ Return a view of the tensor where the index for dimension `dim` equals `i`.
240242
241243
See also: [`selectdim`](@ref)
242244
"""
243-
Base.selectdim(t::Tensor, d::Integer, i) = Tensor(selectdim(parent(t), d, i), inds(t))
245+
Base.selectdim(t::Tensor, d::Integer, i) = Tensor(selectdim(parent(t), d, i), vinds(t))
244246
function Base.selectdim(t::Tensor, d::Integer, i::Integer)
245247
data = selectdim(parent(t), d, i)
246-
indices = [label for (i, label) in enumerate(inds(t)) if i != d]
248+
indices = [label for (i, label) in enumerate(vinds(t)) if i != d]
247249
return Tensor(data, indices)
248250
end
249251

src/TensorNetwork.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ tensors(tn::AbstractTensorNetwork; kwargs...) = tensors(sort_nt(values(kwargs)),
209209
function tensors(::@NamedTuple{}, tn::AbstractTensorNetwork)
210210
tn = TensorNetwork(tn)
211211
get!(tn.sorted_tensors) do
212-
sort!(collect(keys(tn.tensormap)); by=sort collect inds)
212+
sort!(collect(keys(tn.tensormap)); by=sort collect vinds)
213213
end
214214
end
215215

test/unit/Tensor_test.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
@test dim(tensor, label) == i
168168
end
169169

170-
@test_throws BoundsError dim(tensor, :_)
170+
@test isnothing(dim(tensor, :_))
171171
end
172172

173173
@testset "Broadcasting" begin

0 commit comments

Comments
 (0)