
[API Proposal]: Tensor Operations Per Dimension #113068

@michaelgsharp

Description


Background and motivation

We have exposed many functions that operate on Tensor<T>. Currently, all of them operate on the entire backing memory of the tensor. Often, though, you want to apply an operation to each slice along a dimension rather than to the whole data, and there is currently no easy way to do that. As a simple example, think of a 2-D Tensor<T> as an Excel table: in Excel you can sum the whole table, but you can also sum each row. This is the functionality we need to add to Tensor<T>.
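To make the Excel analogy concrete, here is a minimal sketch in plain C#. It uses a rectangular `int[,]` array as a stand-in for a 2-D `Tensor<int>` (an assumption for illustration; it does not use System.Numerics.Tensors) and contrasts today's whole-data reduction with the per-row reduction this proposal is after.

```csharp
using System;

// A rectangular array standing in for a 2-D Tensor<int> (2 rows x 3 columns).
int[,] table =
{
    { 1, 2, 3 },
    { 4, 5, 6 },
};

// Today's behavior: one reduction over all of the data.
int total = 0;
foreach (int value in table)
{
    total += value;
}

// The behavior this proposal is after: one reduction per row.
int[] rowSums = new int[table.GetLength(0)];
for (int row = 0; row < table.GetLength(0); row++)
{
    for (int col = 0; col < table.GetLength(1); col++)
    {
        rowSums[row] += table[row, col];
    }
}

Console.WriteLine(total);                      // 21
Console.WriteLine(string.Join(", ", rowSums)); // 6, 15
```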

With the approach below, we don't have to add a per-dimension overload to every operation that could support this behavior. We currently expect the following methods to need it:

CosineSimilarity
StdDev
Average
Max
MaxMagnitude
MaxMagnitudeNumber
MaxNumber
Min
MinMagnitude
MinMagnitudeNumber
MinNumber
Norm
Product
SoftMax
Sum
SumOfSquares
IndexOfMax
IndexOfMaxMagnitude
IndexOfMin
IndexOfMinMagnitude
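Each scalar reduction in this list, applied to a single slice, is just a function from that slice to a result, which is why one generic entry point can cover all of them. A minimal sketch of the idea for the 2-D case with plain C# arrays; `AggregateAlong` is a hypothetical helper for illustration only (not the proposed `AggregateMany`), and it supports dimensions 0 and 1 only.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of System.Numerics.Tensors: applies `func` to each
// 1-D slice of a 2-D array along `dimension` (0 = rows, 1 = columns) and returns
// one result per slice.
static T[] AggregateAlong<T>(T[,] data, Func<T[], T> func, int dimension)
{
    if (dimension is not (0 or 1))
    {
        throw new ArgumentOutOfRangeException(nameof(dimension));
    }

    int count = data.GetLength(dimension);      // number of slices
    int length = data.GetLength(1 - dimension); // elements per slice
    var results = new T[count];
    for (int i = 0; i < count; i++)
    {
        var slice = new T[length];
        for (int j = 0; j < length; j++)
        {
            // dimension 0: slice i is row i; dimension 1: slice i is column i.
            slice[j] = dimension == 0 ? data[i, j] : data[j, i];
        }
        results[i] = func(slice);
    }
    return results;
}

int[,] matrix = { { 1, 2 }, { 3, 4 } };

// Different reductions plug into the same helper; no per-operation overloads.
int[] rowSums = AggregateAlong(matrix, s => s.Sum(), 0);
int[] columnMaxes = AggregateAlong(matrix, s => s.Max(), 1);

Console.WriteLine(string.Join(", ", rowSums));     // 3, 7
Console.WriteLine(string.Join(", ", columnMaxes)); // 3, 4
```

As in the API Usage example, dimension 0 produces one result per row (summing [[1, 2], [3, 4]] gives 3 and 7). Supporting another reduction means passing a different delegate rather than adding an overload.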

API Proposal

namespace System.Numerics.Tensors;

public interface IReadOnlyTensor<TSelf, T> : IEnumerable<T>
    where TSelf : IReadOnlyTensor<TSelf, T>
{
    // Returns the slice at the given index along the given dimension, e.g. one row of a 2-D tensor.
    TSelf SliceAlongDimension(int dimension, nint index);
}

public static partial class Tensor
{
    // These are for operations like Sum that take in a tensor and return a single T.
    public delegate T TensorAggregationPredicate<T>(in ReadOnlyTensorSpan<T> tensor);
    public static Tensor<T> AggregateMany<T>(this Tensor<T> tensor, TensorAggregationPredicate<T> func, int dimension = -1);

    // These are for things like CosineSimilarity that take in a tensor and return a new tensor as well.
    public delegate Tensor<T> TensorSelectionPredicate<T>(in ReadOnlyTensorSpan<T> tensor);
    public static Tensor<T> SelectMany<T>(this Tensor<T> tensor, TensorSelectionPredicate<T> func, int dimension = -1);

    public static Tensor<T> ToTensor<T>(this IEnumerable<Tensor<T>> source);
}


public sealed class Tensor<T>
    : ITensor<Tensor<T>, T>
{
    public DimensionCollection Dimensions {get;}
    
    public sealed class DimensionCollection
        : System.Collections.Generic.ICollection<Tensor<T>>,
          System.Collections.Generic.IEnumerable<Tensor<T>>,
          System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
          System.Collections.ICollection
    {
        public int Count {get;}
        public void CopyTo(Tensor<T>[] array, int index);
        public Enumerator GetEnumerator();

        // All other members are explicitly implemented, matching what Dictionary does.

        public struct Enumerator : IEnumerator<Tensor<T>>
        {
            
            internal Enumerator(Tensor<T> tensor, int dimension);

            public bool MoveNext();

            public void Reset();

            public void Dispose();

            Tensor<T> IEnumerator<Tensor<T>>.Current { get; }

            object? IEnumerator.Current { get; }
        }     
    }
}


public readonly ref struct ReadOnlyTensorSpan<T>
{
    public DimensionCollection Dimensions {get;}
    
    public ref struct DimensionCollection
        : System.Collections.Generic.ICollection<Tensor<T>>,
          System.Collections.Generic.IEnumerable<Tensor<T>>,
          System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
          System.Collections.ICollection
    {
        public int Count {get;}
        public void CopyTo(Tensor<T>[] array, int index);
        public Enumerator GetEnumerator();

        // All other members are explicitly implemented, matching what Dictionary does.

        public ref struct Enumerator : IEnumerator<Tensor<T>>
        {
            
            internal Enumerator(ReadOnlyTensorSpan<T> tensor, int dimension);

            public bool MoveNext();

            public ref readonly ReadOnlyTensorSpan<T> Current { get; }
        }     
    }
}

public ref struct TensorSpan<T>
{
    public DimensionCollection Dimensions {get;}
    
    public ref struct DimensionCollection
        : System.Collections.Generic.ICollection<Tensor<T>>,
          System.Collections.Generic.IEnumerable<Tensor<T>>,
          System.Collections.Generic.IReadOnlyCollection<Tensor<T>>,
          System.Collections.ICollection
    {
        public int Count {get;}
        public void CopyTo(Tensor<T>[] array, int index);
        public Enumerator GetEnumerator();

        // All other members are explicitly implemented, matching what Dictionary does.

        public ref struct Enumerator : IEnumerator<Tensor<T>>
        {
            
            internal Enumerator(TensorSpan<T> tensor, int dimension);

            public bool MoveNext();

            public ref TensorSpan<T> Current { get; }
        }     
    }   
}
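How a slice along a dimension maps onto a tensor's flat backing memory can be sketched for a row-major 2-D buffer, using the same 3x3 data as the API Usage example. `SliceAlong` is a hypothetical helper, not the proposed `SliceAlongDimension`; it copies the elements into a new array for simplicity, and this sketch says nothing about whether the real method copies or returns a view.

```csharp
using System;

// Hypothetical helper, not the proposed API: returns a copy of the 1-D slice at
// `index` along `dimension` of a row-major buffer with shape [rows, cols].
static int[] SliceAlong(int[] data, int rows, int cols, int dimension, int index)
{
    if (dimension == 0)
    {
        // Fixing dimension 0 selects one row: `cols` contiguous elements.
        return data.AsSpan(index * cols, cols).ToArray();
    }

    if (dimension == 1)
    {
        // Fixing dimension 1 selects one column: `rows` elements, `cols` apart.
        var result = new int[rows];
        for (int r = 0; r < rows; r++)
        {
            result[r] = data[r * cols + index];
        }
        return result;
    }

    throw new ArgumentOutOfRangeException(nameof(dimension));
}

int[] buffer = { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; // shape [3, 3], row-major

int[] firstRow = SliceAlong(buffer, 3, 3, 0, 0);    // 1, 2, 3
int[] lastRow = SliceAlong(buffer, 3, 3, 0, 2);     // 7, 8, 9
int[] firstColumn = SliceAlong(buffer, 3, 3, 1, 0); // 1, 4, 7

Console.WriteLine(string.Join(", ", firstRow));
Console.WriteLine(string.Join(", ", lastRow));
Console.WriteLine(string.Join(", ", firstColumn));
```

A row is contiguous in row-major memory, while a column is strided by the row length, which is why a per-dimension operation cannot simply run over the flat backing memory.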

API Usage

using System.Numerics.Tensors;

Tensor<int> t1 = Tensor.Create([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]);

// Will be [1, 2, 3]
Tensor<int> slice = t1.SliceAlongDimension(0, 0);

// Will be [7, 8, 9]
slice = t1.SliceAlongDimension(0, 2);
for (int i = 0; i < 3; i++)
{
    // Get each row in turn.
    Tensor<int> row = t1.SliceAlongDimension(0, i);
    // Do something with each row.
}

Tensor<int> tensor = Tensor.Create([1, 2, 3, 4], [2, 2]);

// sum will be [3, 7]
Tensor<int> sum = tensor.AggregateMany((in ReadOnlyTensorSpan<int> t) => Tensor.Sum(t), 0);

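// cosineTensor will be [1, 1, 1] (up to floating-point rounding), since each row is compared with itself.
Tensor<float> tensor1 = Tensor.Create<float>([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3], [], false);
Tensor<float> cosineTensor = tensor1.SelectMany((in ReadOnlyTensorSpan<float> t) => Tensor.CosineSimilarity(t, t), 0);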
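Unlike `AggregateMany`, `SelectMany` takes a function that returns a tensor for each slice. A minimal sketch of that shape with plain C# arrays, using SoftMax from the list above as the per-slice function. `SelectRows` is a hypothetical helper that handles dimension 0 only and, as an assumption, stacks the per-row results in row order.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not the proposed API: applies `func` to each row of a 2-D
// array and stacks the per-row results in row order (dimension 0 only).
static double[][] SelectRows(double[,] data, Func<double[], double[]> func)
{
    int rows = data.GetLength(0);
    int cols = data.GetLength(1);
    var results = new double[rows][];
    for (int r = 0; r < rows; r++)
    {
        var row = new double[cols];
        for (int c = 0; c < cols; c++)
        {
            row[c] = data[r, c];
        }
        results[r] = func(row);
    }
    return results;
}

// Softmax of one 1-D slice; subtracting the max first avoids overflow in Math.Exp.
static double[] SoftMax(double[] x)
{
    double max = x.Max();
    double[] exps = x.Select(v => Math.Exp(v - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

double[,] input = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
double[][] perRow = SelectRows(input, SoftMax);

foreach (double[] output in perRow)
{
    // Each row is normalized on its own, so each row sums to 1 (before display rounding).
    Console.WriteLine(string.Join(", ", output.Select(v => v.ToString("F4"))));
}
```

Because SoftMax is applied to each row independently, every output row sums to 1, whereas a whole-tensor SoftMax would make all nine outputs sum to 1 together.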

Alternative Designs

We could explicitly overload every method that needs this behavior. That would mean 20 additional overloads today, and the number would grow over time. If users wanted this behavior for a method we hadn't overloaded yet, they would either have to implement it themselves (which has already happened recently) or wait for us to add it. The proposed approach avoids both problems.

Risks

Low risk: Tensor<T> is still a preview type, and all of these APIs are new.

Metadata


Labels: api-approved, area-System.Numerics.Tensors, in-pr
