Skip to content

[API Proposal]: Tensor Operators #112781

Open
@michaelgsharp

Description

@michaelgsharp

Background and motivation

Tensors are often used to perform multi-dimensional math efficiently. We need to add operators to our tensor types to support this. The language team is working on two features that enable it: first, the ability to restrict which types get certain operators, and second, the ability for users to implement their own compound assignment operators, such as `+=`.

Note: this API is currently blocked on two language features that are expected to ship in .NET 10. All of the following APIs are based on the current in-progress design of those features; if anything major changes, we will need to adjust them accordingly.

API Proposal

namespace System.Numerics.Tensors;

/// <summary>
/// Operator extensions for tensor types. Non-assignment operators return a new
/// <c>TSelf</c>; compound-assignment operators (<c>+=</c>, <c>++</c>, ...) update
/// the tensor in place and return <see langword="void"/>.
/// </summary>
/// <remarks>
/// Each operator group is gated on the matching generic-math interface on the
/// element type <c>T</c> (for example, <c>%</c> exists only when <c>T</c> implements
/// <c>IModulusOperators&lt;T, T, T&gt;</c>).
/// In-place operators on reference-type tensors take the receiver as <c>this TSelf</c>.
/// On value-type tensors they take it as <c>ref this TSelf</c>, so the update is
/// applied to the caller's variable rather than to a copy.
/// NOTE(review): <c>T</c> appears only in the constraints, not in the receiver
/// <c>(TSelf)</c>. Confirm that the extension-block design can infer it this way.
/// </remarks>
public static partial class Tensor
{
    // ---------------------------------------------------------------------
    // Binary arithmetic operators. Each has tensor-scalar, tensor-tensor and
    // scalar-tensor overloads and returns a new tensor.
    // NOTE(review): the tensor-tensor overloads are presumably element-wise
    // (not matrix multiplication); confirm and document broadcasting rules.
    // ---------------------------------------------------------------------

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IMultiplyOperators<T, T, T>
    {
        public static TSelf operator *(TSelf left, T right) { ... }
        public static TSelf operator *(TSelf left, TSelf right) { ... }
        public static TSelf operator *(T left, TSelf right) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IAdditionOperators<T, T, T>
    {
        public static TSelf operator +(TSelf left, T right) { ... }
        public static TSelf operator +(TSelf left, TSelf right) { ... }
        public static TSelf operator +(T left, TSelf right) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IModulusOperators<T, T, T>
    {
        public static TSelf operator %(TSelf left, T right) { ... }
        public static TSelf operator %(TSelf left, TSelf right) { ... }
        public static TSelf operator %(T left, TSelf right) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IDivisionOperators<T, T, T>
    {
        public static TSelf operator /(TSelf left, T right) { ... }
        public static TSelf operator /(TSelf left, TSelf right) { ... }
        public static TSelf operator /(T left, TSelf right) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : ISubtractionOperators<T, T, T>
    {
        public static TSelf operator -(TSelf left, T right) { ... }
        public static TSelf operator -(TSelf left, TSelf right) { ... }
        public static TSelf operator -(T left, TSelf right) { ... }
    }

    // REVIEW: no == for now

    // ---------------------------------------------------------------------
    // Unary operators. Each returns a new tensor.
    // ---------------------------------------------------------------------

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IUnaryNegationOperators<T, T>
    {
        public static TSelf operator -(TSelf tensor) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IUnaryPlusOperators<T, T>
    {
        public static TSelf operator +(TSelf tensor) { ... }
    }

    // ---------------------------------------------------------------------
    // Bitwise operators. Each returns a new tensor.
    // ---------------------------------------------------------------------

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IBitwiseOperators<T, T, T>
    {
        public static TSelf operator ~(TSelf tensor) { ... }

        public static TSelf operator &(TSelf left, T right) { ... }
        public static TSelf operator &(TSelf left, TSelf right) { ... }
        public static TSelf operator &(T left, TSelf right) { ... }

        public static TSelf operator |(TSelf left, T right) { ... }
        public static TSelf operator |(TSelf left, TSelf right) { ... }
        public static TSelf operator |(T left, TSelf right) { ... }

        public static TSelf operator ^(TSelf left, T right) { ... }
        public static TSelf operator ^(TSelf left, TSelf right) { ... }
        public static TSelf operator ^(T left, TSelf right) { ... }
    }

    // ---------------------------------------------------------------------
    // Shift operators. Each returns a new tensor.
    // IShiftOperators<TSelf, TOther, TResult>: TOther is the shift-amount type
    // and TResult is the result type, so "shift a T by an int, yielding a T"
    // is IShiftOperators<T, int, T>.
    // ---------------------------------------------------------------------

    extension<TSelf, T>(TSelf)
        where TSelf : IReadOnlyTensor<TSelf, T>
        where T : IShiftOperators<T, int, T>
    {
        // Static binary operators take the tensor as an ordinary parameter (no `this`),
        // matching the other binary operators above.
        public static TSelf operator <<(TSelf tensor, int shiftAmount) { ... }
        public static TSelf operator >>(TSelf tensor, int shiftAmount) { ... }
        public static TSelf operator >>>(TSelf tensor, int shiftAmount) { ... }
    }

    // ---------------------------------------------------------------------
    // Compound-assignment operators. These update the tensor in place, so they
    // need no allocation. They require the mutable ITensor rather than
    // IReadOnlyTensor. Every operator has a reference-type (`this`) variant and
    // a value-type (`ref this`) variant.
    // ---------------------------------------------------------------------

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IMultiplyOperators<T, T, T>
    {
        public static void operator *=(this TSelf tensor, T other) { ... }
        public static void operator *=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IMultiplyOperators<T, T, T>
    {
        public static void operator *=(ref this TSelf tensor, T other) { ... }
        public static void operator *=(ref this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IAdditionOperators<T, T, T>
    {
        public static void operator +=(this TSelf tensor, T other) { ... }
        public static void operator +=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IAdditionOperators<T, T, T>
    {
        public static void operator +=(ref this TSelf tensor, T other) { ... }
        public static void operator +=(ref this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IIncrementOperators<T>
    {
        public static void operator ++(this TSelf tensor) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IIncrementOperators<T>
    {
        public static void operator ++(ref this TSelf tensor) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IModulusOperators<T, T, T>
    {
        public static void operator %=(this TSelf tensor, T other) { ... }
        public static void operator %=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IModulusOperators<T, T, T>
    {
        public static void operator %=(ref this TSelf tensor, T other) { ... }
        public static void operator %=(ref this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IDivisionOperators<T, T, T>
    {
        public static void operator /=(this TSelf tensor, T other) { ... }
        public static void operator /=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IDivisionOperators<T, T, T>
    {
        public static void operator /=(ref this TSelf tensor, T other) { ... }
        public static void operator /=(ref this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : ISubtractionOperators<T, T, T>
    {
        public static void operator -=(this TSelf tensor, T other) { ... }
        public static void operator -=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : ISubtractionOperators<T, T, T>
    {
        public static void operator -=(ref this TSelf tensor, T other) { ... }
        public static void operator -=(ref this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IDecrementOperators<T>
    {
        public static void operator --(this TSelf tensor) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IDecrementOperators<T>
    {
        public static void operator --(ref this TSelf tensor) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IShiftOperators<T, int, T>
    {
        public static void operator <<=(this TSelf tensor, int shiftAmount) { ... }
        public static void operator >>=(this TSelf tensor, int shiftAmount) { ... }
        public static void operator >>>=(this TSelf tensor, int shiftAmount) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IShiftOperators<T, int, T>
    {
        public static void operator <<=(ref this TSelf tensor, int shiftAmount) { ... }
        public static void operator >>=(ref this TSelf tensor, int shiftAmount) { ... }
        public static void operator >>>=(ref this TSelf tensor, int shiftAmount) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, class
        where T : IBitwiseOperators<T, T, T>
    {
        public static void operator &=(this TSelf tensor, T other) { ... }
        public static void operator &=(this TSelf tensor, TSelf other) { ... }

        public static void operator |=(this TSelf tensor, T other) { ... }
        public static void operator |=(this TSelf tensor, TSelf other) { ... }

        public static void operator ^=(this TSelf tensor, T other) { ... }
        public static void operator ^=(this TSelf tensor, TSelf other) { ... }
    }

    extension<TSelf, T>(TSelf)
        where TSelf : ITensor<TSelf, T>, struct
        where T : IBitwiseOperators<T, T, T>
    {
        public static void operator &=(ref this TSelf tensor, T other) { ... }
        public static void operator &=(ref this TSelf tensor, TSelf other) { ... }

        public static void operator |=(ref this TSelf tensor, T other) { ... }
        public static void operator |=(ref this TSelf tensor, TSelf other) { ... }

        public static void operator ^=(ref this TSelf tensor, T other) { ... }
        public static void operator ^=(ref this TSelf tensor, TSelf other) { ... }
    }
}

API Usage

Tensor<int> tensor = Tensor.Create<int>([1, 2, 3, 4], [2, 2]);

// In place: after this the values inside should be [3, 4, 5, 6]
tensor += 2;

// In place: after this the values inside should be [9, 12, 15, 18]
tensor *= 3;

// Allocates a new tensor with values [10, 13, 16, 19]
Tensor<int> tensor2 = tensor + 1;

Alternative Designs

None. We needed this functionality, so the language team designed the best way to support it in C#, and we are adopting that design.

Risks

Low risk, because the type is still in preview and all of these APIs are new. The main risk is the dependency on the language team's work and any delays it may have.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions