Description
Background and motivation
This is a refinement of #113068, which was approved a few weeks ago. During API review, several concerns were raised about clarity and ease of use. The proposal was approved in the end, but after further discussion @tannergooding and I decided to refine it to address those concerns.
Copied from the original issue:
We have exposed many functions that can be used with Tensor. Currently, they all operate on the entire backing memory of the Tensor. The problem is that you often want to perform an operation along a dimension rather than over all of the data, and there is currently no easy way to do that. As a very simple example, say your Tensor is 2D; think of it like an Excel table. In Excel you can sum the whole table, but you can also sum each row. This is the functionality we need to add to Tensor.
API Proposal
namespace System.Numerics.Tensors;
- public interface IReadOnlyTensor<TSelf, T> : IEnumerable<T>
- where TSelf : IReadOnlyTensor<TSelf, T>
- {
- // Removed: replaced by GetDimension(dimension).GetSlice(index) on Tensor<T>, TensorSpan<T>, and ReadOnlyTensorSpan<T>.
- // It was used to get a single slice along a dimension of a tensor, e.g. one row of a 2D tensor.
- TSelf SliceAlongDimension(int dimension, nint index);
- }
// Renamed from DimensionCollection to TensorDimensionView for clarity.
// Also moved out of being a nested type so that it can be shared between all tensor types.
// A view over the slices of a tensor along a single dimension; enumerating it yields each slice in order.
+ public readonly ref struct TensorDimensionView<T>
+ {
+ internal TensorDimensionView(TensorSpan<T> tensor, int dimension);
// Number of slices along the dimension.
+ public nint Count { get; }
// Replaces SliceAlongDimension. The index is an nint (as it was on SliceAlongDimension) so that it matches Count
// and every index in [0, Count) is addressable.
+ public TensorSpan<T> GetSlice(nint index);
+ public Enumerator GetEnumerator();
// A ref struct can only implement interfaces starting with C# 13 / .NET 9, hence the guard.
+ public ref struct Enumerator
+ #if NET9_0_OR_GREATER
+ : IEnumerator<TensorSpan<T>>
+ #endif
+ {
+ internal Enumerator(TensorSpan<T> tensor, int dimension);
+ public bool MoveNext();
+ public void Reset();
+ public void Dispose();
// Must be a property (not a field) so it can implement IEnumerator<TensorSpan<T>>.Current.
+ public TensorSpan<T> Current { get; }
+ #if NET9_0_OR_GREATER
// Required by the non-generic IEnumerator; always throws because TensorSpan<T> is a ref struct and cannot be boxed to object.
+ object? IEnumerator.Current { get; }
+ #endif
+ }
+ }
// Read-only counterpart of TensorDimensionView<T>: a view over the slices of a read-only tensor along a single dimension.
+ public readonly ref struct ReadOnlyTensorDimensionView<T>
+ {
+ internal ReadOnlyTensorDimensionView(ReadOnlyTensorSpan<T> tensor, int dimension);
// Number of slices along the dimension.
+ public nint Count { get; }
// Replaces SliceAlongDimension. The index is an nint (as it was on SliceAlongDimension) so that it matches Count
// and every index in [0, Count) is addressable.
+ public ReadOnlyTensorSpan<T> GetSlice(nint index);
+ public Enumerator GetEnumerator();
// A ref struct can only implement interfaces starting with C# 13 / .NET 9, hence the guard.
+ public ref struct Enumerator
+ #if NET9_0_OR_GREATER
+ : IEnumerator<ReadOnlyTensorSpan<T>>
+ #endif
+ {
+ internal Enumerator(ReadOnlyTensorSpan<T> tensor, int dimension);
+ public bool MoveNext();
+ public void Reset();
+ public void Dispose();
// Must be a property (not a field) so it can implement IEnumerator<ReadOnlyTensorSpan<T>>.Current.
+ public ReadOnlyTensorSpan<T> Current { get; }
+ #if NET9_0_OR_GREATER
// Required by the non-generic IEnumerator; always throws because ReadOnlyTensorSpan<T> is a ref struct and cannot be boxed to object.
+ object? IEnumerator.Current { get; }
+ #endif
+ }
+ }
public sealed class Tensor<T>
: ITensor<Tensor<T>, T>
{
// Returns a view over the slices along `dimension`; replaces the Dimensions property / nested DimensionCollection removed below.
// NOTE(review): passing `this` relies on an implicit Tensor<T> -> TensorSpan<T> conversion for the view's constructor — confirm.
+ public TensorDimensionView<T> GetDimension(int dimension) => new TensorDimensionView<T>(this, dimension);
- public DimensionCollection Dimensions {get;}
- public sealed class DimensionCollection
- : System.Collections.Generic.ICollection<Tensor<T>>,
- System.Collections.Generic.IEnumerable<Tensor<T>>,
- System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
- System.Collections.ICollection
- {
- public int Count {get;}
- public void CopyTo(Tensor<T>[] array , int index)
- public Enumerator GetEnumerator();
- //All else explicitly implemented and match what Dictionary does
- public struct Enumerator : IEnumerator<Tensor<T>>
- {
- internal Enumerator(Tensor<T> tensor, int dimension);
- public bool MoveNext();
- public void Reset();
- public void Dispose();
- Tensor<T> IEnumerator<Tensor<T>>.Current;
- object? IEnumerator.Current;
- }
- }
}
public readonly ref struct ReadOnlyTensorSpan<T>
{
// Returns a read-only view over the slices along `dimension`; replaces the Dimensions property / nested DimensionCollection removed below.
+ public ReadOnlyTensorDimensionView<T> GetDimension(int dimension) => new ReadOnlyTensorDimensionView<T>(this, dimension);
- public DimensionCollection Dimensions {get;}
- public ref struct DimensionCollection
- : System.Collections.Generic.ICollection<Tensor<T>>,
- System.Collections.Generic.IEnumerable<Tensor<T>>,
- System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
- System.Collections.ICollection
- {
- public int Count {get;}
- public void CopyTo(Tensor<T>[] array , int index)
- public Enumerator GetEnumerator();
- //All else explicitly implemented and match what Dictionary does
- public ref struct Enumerator : IEnumerator<Tensor<T>>
- {
- internal Enumerator(Tensor<T> tensor, int dimension);
- public bool MoveNext();
- public ref readonly ReadOnlyTensorSpan<T> Current;
- }
- }
}
public ref struct TensorSpan<T>
{
// Returns a view over the slices along `dimension`; replaces the Dimensions property / nested DimensionCollection removed below.
+ public TensorDimensionView<T> GetDimension(int dimension) => new TensorDimensionView<T>(this, dimension);
- public DimensionCollection Dimensions {get;}
- public ref struct DimensionCollection
- : System.Collections.Generic.ICollection<Tensor<T>>,
- System.Collections.Generic.IEnumerable<Tensor<T>>,
- System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
- System.Collections.ICollection
- {
- public int Count {get;}
- public void CopyTo(Tensor<T>[] array , int index)
- public Enumerator GetEnumerator();
- //All else explicitly implemented and match what Dictionary does
- public ref struct Enumerator : IEnumerator<Tensor<T>>
- {
- internal Enumerator(Tensor<T> tensor, int dimension);
- public bool MoveNext();
- public ref TensorSpan<T> Current;
- }
- }
}
API Usage
// Create a 2 x 2 x 2 tensor with 8 elements.
// Tensor is the static factory class; the created instance is a Tensor<int>.
Tensor<int> tensor = Tensor.Create<int>([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
// Loop through the tensor along dimension 0. This yields a 2 x 2 slice 2 times.
foreach (var slice in tensor.GetDimension(0))
{
    // Do stuff with the 2 x 2 slice (a TensorSpan<int>).
    // The first slice will be [[1, 2], [3, 4]] and the second will be [[5, 6], [7, 8]].
}
// Loop through the tensor along dimension 1. This yields a 2-element slice 4 times.
foreach (var slice in tensor.GetDimension(1))
{
    // Do stuff with the 2-element slice.
    // The slices will be [1, 2], then [3, 4], then [5, 6], and last [7, 8].
}
// Loop through the tensor along dimension 2. This yields a single-element slice 8 times.
foreach (var slice in tensor.GetDimension(2))
{
    // Do stuff with the single-element slice.
    // In order, the slices hold the values 1, 2, 3, 4, 5, 6, 7, 8.
}
// Create a 2 x 4 tensor with 8 elements.
tensor = Tensor.Create<int>([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
// Loop through the tensor along dimension 0. This yields a 4-element slice 2 times.
foreach (var slice in tensor.GetDimension(0))
{
    // Do stuff with the 4-element slice.
    // The first slice will be [1, 2, 3, 4] and the second will be [5, 6, 7, 8].
}
// Loop through the tensor along dimension 1. This yields a single-element slice 8 times.
foreach (var slice in tensor.GetDimension(1))
{
    // Do stuff with the single-element slice.
    // In order, the slices hold the values 1, 2, 3, 4, 5, 6, 7, 8.
}
// SliceAlongDimension was removed; this is what the equivalent call looks like now.
// Create a 2 x 4 tensor with 8 elements.
tensor = Tensor.Create<int>([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
// A 4-element slice with values [1, 2, 3, 4].
TensorSpan<int> firstRow = tensor.GetDimension(0).GetSlice(0);
// A 4-element slice with values [5, 6, 7, 8].
TensorSpan<int> secondRow = tensor.GetDimension(0).GetSlice(1);
Alternative Designs
No response
Risks
Low risk, because these are preview types and all of this is new functionality.