Skip to content

Commit 2227b4b

Browse files
committed
Fix type-inference issues in contract!
1 parent c3f6ea4 commit 2227b4b

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/Numerics.jl

+14-14
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,19 @@ function contract(a::Tensor, b::Tensor; kwargs...)
5757
end
5858

5959
function allocate_result(
60-
::typeof(contract), a::Tensor, b::Tensor; fillzero=false, dims=((inds(a), inds(b))), out=nothing
60+
::typeof(contract), a::Tensor, b::Tensor; fillzero=false, dims=((vinds(a), vinds(b))), out=nothing
6161
)
62-
ia = collect(inds(a))
63-
ib = collect(inds(b))
62+
ia = collect(vinds(a))
63+
ib = collect(vinds(b))
6464
i = (dims, ia, ib)
6565

6666
ic = if isnothing(out)
67-
Tuple(setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : (i,)))
67+
setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : [i])
6868
else
6969
out
7070
end
7171

72-
data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero)
72+
data = OMEinsum.get_output_array((parent(a), parent(b)), Int[size(i in ia ? a : b, i) for i in ic]; fillzero)
7373
return Tensor(data, ic)
7474
end
7575

@@ -88,12 +88,12 @@ function contract(a::Tensor; kwargs...)
8888
return contract!(c, a)
8989
end
9090

91-
function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=nonunique(inds(a)), out=nothing)
92-
ia = inds(a)
91+
function allocate_result(::typeof(contract), a::Tensor; fillzero=false, dims=nonunique(vinds(a)), out=nothing)
92+
ia = vinds(a)
9393
i = (dims, ia)
9494

9595
ic::Vector{Symbol} = if isnothing(out)
96-
setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))
96+
setdiff(ia, i isa Base.AbstractVecOrTuple ? i : [i])
9797
else
9898
out
9999
end
@@ -114,11 +114,11 @@ contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs
114114
Perform a binary tensor contraction operation between `a` and `b` and store the result in `c`.
115115
"""
116116
function contract!(c::Tensor, a::Tensor, b::Tensor)
117-
ixs = (inds(a), inds(b))
118-
iy = inds(c)
117+
ixs = (vinds(a), vinds(b))
118+
iy = vinds(c)
119119
xs = (parent(a), parent(b))
120120
y = parent(c)
121-
size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...)
121+
size_dict = merge!(Dict{Symbol,Int}.([vinds(a) .=> size(a), vinds(b) .=> size(b)])...)
122122

123123
einsum!(ixs, iy, xs, y, true, false, size_dict)
124124
return c
@@ -130,9 +130,9 @@ end
130130
Perform a unary tensor contraction operation on `a` and store the result in `c`.
131131
"""
132132
function contract!(y::Tensor, x::Tensor)
133-
ixs = (inds(x),)
134-
iy = inds(y)
135-
size_dict = Dict{Symbol,Int}(inds(x) .=> size(x))
133+
ixs = (vinds(x),)
134+
iy = vinds(y)
135+
size_dict = Dict{Symbol,Int}(vinds(x) .=> size(x))
136136

137137
einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict)
138138
return y

0 commit comments

Comments
 (0)