@@ -42,10 +42,12 @@ Return the indices of the tensor in the order of the dimensions.
42
42
"""
43
43
inds (t:: Tensor ) = Tuple (t. inds)
44
44
45
+ # WARN internal use only because it can mutate `Tensor`
46
+ vinds (t:: Tensor ) = t. inds
47
+
45
48
function Base. copy (t:: Tensor{T,N,<:SubArray{T,N}} ) where {T,N}
46
49
data = copy (t. data)
47
- inds = t. inds
48
- return Tensor (data, inds)
50
+ return Tensor (data, vinds (t))
49
51
end
50
52
51
53
"""
@@ -55,7 +57,7 @@ Return a uninitialize tensor of the same size, eltype and [`inds`](@ref) as `ten
55
57
"""
56
58
Base. similar (t:: Tensor ; inds= inds (t)) = Tensor (similar (parent (t)), inds)
57
59
Base. similar (t:: Tensor , S:: Type ; inds= inds (t)) = Tensor (similar (parent (t), S), inds)
58
- function Base. similar (t:: Tensor{T,N} , S:: Type , dims:: Base.Dims{N} ; inds= inds (t)) where {T,N}
60
+ function Base. similar (t:: Tensor{T,N} , S:: Type , dims:: Base.Dims{N} ; inds= vinds (t)) where {T,N}
59
61
return Tensor (similar (parent (t), S, dims), inds)
60
62
end
61
63
function Base. similar (t:: Tensor , :: Type , dims:: Base.Dims{N} ; kwargs... ) where {N}
71
73
72
74
Return a tensor of the same size, eltype and [`inds`](@ref) as `tensor` but filled with zeros.
73
75
"""
74
- Base. zero (t:: Tensor ) = Tensor (zero (parent (t)), inds (t))
76
+ Base. zero (t:: Tensor ) = Tensor (zero (parent (t)), vinds (t))
75
77
76
78
function __find_index_permutation (a, b)
77
79
inds_b = collect (Union{Missing,Symbol}, b)
@@ -107,8 +109,8 @@ Base.isequal(a::Tensor{A,0}, b::Tensor{B,0}) where {A,B} = isequal(only(a), only
107
109
Base. isapprox (a:: AbstractArray , b:: Tensor ) = false
108
110
Base. isapprox (a:: Tensor , b:: AbstractArray ) = false
109
111
function Base. isapprox (a:: Tensor , b:: Tensor ; kwargs... )
110
- issetequal (inds (a), inds (b)) || return false
111
- perm = __find_index_permutation (inds (a), inds (b))
112
+ issetequal (vinds (a), vinds (b)) || return false
113
+ perm = __find_index_permutation (vinds (a), vinds (b))
112
114
return all (eachindex (IndexCartesian (), a)) do i
113
115
j = CartesianIndex (Tuple (permute! (collect (Tuple (i)), invperm (perm))))
114
116
isapprox (a[i], b[j]; kwargs... )
@@ -150,7 +152,7 @@ parenttype(::T) where {T<:Tensor} = parenttype(T)
150
152
Return the location of the dimension of `tensor` corresponding to the given index `i`.
151
153
"""
152
154
dim (:: Tensor , i:: Number ) = i
153
- dim (t:: Tensor , i:: Symbol ) = first ( findall ( == (i), inds (t) ))
155
+ dim (t:: Tensor , i:: Symbol ) = findfirst ( == (i), vinds (t ))
154
156
155
157
# Iteration interface
156
158
Base. IteratorSize (T:: Type{Tensor} ) = Iterators. IteratorSize (parenttype (T))
@@ -240,10 +242,10 @@ Return a view of the tensor where the index for dimension `dim` equals `i`.
240
242
241
243
See also: [`selectdim`](@ref)
242
244
"""
243
- Base. selectdim (t:: Tensor , d:: Integer , i) = Tensor (selectdim (parent (t), d, i), inds (t))
245
+ Base. selectdim (t:: Tensor , d:: Integer , i) = Tensor (selectdim (parent (t), d, i), vinds (t))
244
246
function Base. selectdim (t:: Tensor , d:: Integer , i:: Integer )
245
247
data = selectdim (parent (t), d, i)
246
- indices = [label for (i, label) in enumerate (inds (t)) if i != d]
248
+ indices = [label for (i, label) in enumerate (vinds (t)) if i != d]
247
249
return Tensor (data, indices)
248
250
end
249
251
0 commit comments