From 7c83a188b9e8849d4913880e06227e99887abe11 Mon Sep 17 00:00:00 2001 From: Antao Almada Date: Wed, 21 Feb 2024 22:00:42 +0000 Subject: [PATCH] Add AggregateChecked and a vectorized SumChecked --- .../SumCheckedTests.cs | 87 +++++ .../AggregateChecked.cs | 314 ++++++++++++++++++ .../Operations/Addition.cs | 16 +- .../Operations/SumChecked.cs | 20 ++ .../Operators/AdditionOperators.cs | 2 +- .../Operators/SumOperator.cs | 16 + 6 files changed, 446 insertions(+), 9 deletions(-) create mode 100644 src/NetFabric.Numerics.Tensors.UnitTests/SumCheckedTests.cs create mode 100644 src/NetFabric.Numerics.Tensors/AggregateChecked.cs create mode 100644 src/NetFabric.Numerics.Tensors/Operations/SumChecked.cs diff --git a/src/NetFabric.Numerics.Tensors.UnitTests/SumCheckedTests.cs b/src/NetFabric.Numerics.Tensors.UnitTests/SumCheckedTests.cs new file mode 100644 index 0000000..3f951eb --- /dev/null +++ b/src/NetFabric.Numerics.Tensors.UnitTests/SumCheckedTests.cs @@ -0,0 +1,87 @@ +using System.Numerics.Tensors; + +namespace NetFabric.Numerics.Tensors.UnitTests; + +public class SumCheckedTests +{ + public static TheoryData SumCheckedOverflowData + => new() { + new[] { int.MaxValue, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, + new[] { 1, int.MaxValue, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, + new[] { 1, 1, int.MaxValue, 1, 1, 1, 1, 1, 1, 1, 1 }, + new[] { 1, 1, 1, int.MaxValue, 1, 1, 1, 1, 1, 1, 1 }, + new[] { 1, 1, 1, 1, int.MaxValue, 1, 1, 1, 1, 1, 1 }, + new[] { 1, 1, 1, 1, 1, int.MaxValue, 1, 1, 1, 1, 1 }, + new[] { 1, 1, 1, 1, 1, 1, int.MaxValue, 1, 1, 1, 1 }, + new[] { 1, 1, 1, 1, 1, 1, 1, int.MaxValue, 1, 1, 1 }, + }; + + [Theory] + [MemberData(nameof(SumCheckedOverflowData))] + public static void SumChecked_With_Overflow_Should_Throw(int[] source) + { + // arrange + + // act + void action() => TensorOperations.SumChecked(source); + + // assert + _ = Assert.Throws(action); + } + + public static TheoryData SumCheckedData + => new() { + { 0 }, { 1 }, { 2 }, { 3 }, { 4 }, { 5 }, { 6 }, { 7 }, { 8 }, { 9 }, { 10 }, { 100 }, + }; + + static void SumChecked_Should_Succeed(int count) + where T : struct, INumber + { + // arrange + var source = new T[count]; + var expected = T.Zero; + var random = new Random(42); + for (var index = 0; index < source.Length; index++) + { + var value = T.CreateChecked(random.Next(10)); + source[index] = value; + expected += value; + } + + // act + var result = TensorOperations.SumChecked(source); + + // assert + Assert.Equal(expected, result); + } + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Short_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Int_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Long_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Half_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Float_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); + + [Theory] + [MemberData(nameof(SumCheckedData))] + public void SumChecked_Double_Should_Succeed(int count) + => SumChecked_Should_Succeed(count); +} \ No newline at end of file diff --git a/src/NetFabric.Numerics.Tensors/AggregateChecked.cs b/src/NetFabric.Numerics.Tensors/AggregateChecked.cs new file mode 100644 index 0000000..18bc171 --- /dev/null +++ b/src/NetFabric.Numerics.Tensors/AggregateChecked.cs @@ -0,0 +1,314 @@ +namespace NetFabric.Numerics.Tensors; + +public static partial class Tensor +{ + /// + /// Aggregates the elements of a using the specified aggregation operator. + /// + /// The type of the elements in the source span. + /// The type of the aggregation operator that must implement the interface. + /// The span of elements to aggregate. + /// The result of the aggregation. + public static T AggregateChecked(ReadOnlySpan source) + where T : struct + where TAggregateOperator : struct, IAggregationOperator + => AggregateChecked, TAggregateOperator>(source); + + /// + /// Aggregates the elements of a using the specified transform and aggregation operators. + /// + /// The type of the elements in the source span. + /// The type of the elements after the transform operation. + /// The type of the result of the aggregation. + /// The type of the transform operator that must implement the interface. + /// The type of the aggregation operator that must implement the interface. + /// The span of elements to transform and aggregate. + /// The result of the aggregation. + /// The transform operator is applied to the source elements before the aggregation operator. + public static TResult AggregateChecked(ReadOnlySpan source) + where TSource : struct + where TTransformed : struct + where TResult : struct + where TTransformOperator : struct, IUnaryOperator + where TAggregateOperator : struct, IAggregationOperator + { + // initialize aggregate + var aggregate = TAggregateOperator.Seed; + var indexSource = nint.Zero; + + // aggregate using hardware acceleration if available + if (TTransformOperator.IsVectorizable && + TAggregateOperator.IsVectorizable && + Vector.IsHardwareAccelerated && + Vector.IsSupported && + (typeof(TTransformed) == typeof(short) || typeof(TTransformed) == typeof(int) || typeof(TTransformed) == typeof(long)) && + //Vector.IsSupported && + Vector.IsSupported) + { + // convert source span to vector span without copies + var sourceVectors = MemoryMarshal.Cast>(source); + + // check if there is at least one vector to aggregate + if (sourceVectors.Length > 0) + { + // initialize aggregate vector + var resultVector = new Vector(TAggregateOperator.Seed); + + // aggregate the source vectors into the aggregate vector + ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors); + var indexVector = nint.Zero; + for (; indexVector < sourceVectors.Length; indexVector++) + { + var transformedVector = TTransformOperator.Invoke(ref Unsafe.Add(ref sourceVectorsRef, indexVector)); + resultVector = TAggregateOperator.Invoke(ref resultVector, ref transformedVector); + } + + // aggregate the aggregate vector into the aggregate + for(var index = 0; index < Vector.Count; index++) + { + aggregate = TAggregateOperator.Invoke(aggregate, resultVector[index]); + } + + // skip the source elements already aggregated + indexSource = indexVector * Vector.Count; + } + } + + // aggregate the remaining elements in the source + ref var sourceRef = ref MemoryMarshal.GetReference(source); + var remaining = source.Length - (int)indexSource; + if (remaining >= 8) + { + var partial1 = TAggregateOperator.Seed; + var partial2 = TAggregateOperator.Seed; + var partial3 = TAggregateOperator.Seed; + for (; indexSource + 3 < source.Length; indexSource += 4) + { + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + partial1 = TAggregateOperator.Invoke(partial1, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + partial2 = TAggregateOperator.Invoke(partial2, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + partial3 = TAggregateOperator.Invoke(partial3, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 3))); + } + aggregate = TAggregateOperator.Invoke(aggregate, partial1); + aggregate = TAggregateOperator.Invoke(aggregate, partial2); + aggregate = TAggregateOperator.Invoke(aggregate, partial3); + + remaining = source.Length - (int)indexSource; + } + + switch (remaining) + { + case 7: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 4))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 5))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 6))); + break; + case 6: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 4))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 5))); + break; + case 5: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 4))); + break; + case 4: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 3))); + break; + case 3: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2))); + break; + case 2: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1))); + break; + case 1: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource))); + break; + case 0: + break; + default: + Throw.Exception("Should not happen!"); + break; + } + + return aggregate; + } + + /// + /// Aggregates the elements of two using the specified transform and aggregation operators. + /// + /// The type of the elements in the source spans. + /// The type of the transform operator that must implement the interface. + /// The type of the aggregation operator that must implement the interface. + /// The first span of elements to transform and aggregate. + /// The second span of elements to transform and aggregate. + /// The result of the aggregation. + /// The transform operator is applied to the source elements before the aggregation operator. + public static T AggregateChecked(ReadOnlySpan x, ReadOnlySpan y) + where T : struct + where TTransformOperator : struct, IBinaryOperator + where TAggregateOperator : struct, IAggregationOperator + => AggregateChecked(x, y); + + /// + /// Aggregates the elements of two using the specified transform and aggregation operators. + /// + /// The type of the elements in the first source span. + /// The type of the elements in the second source span. + /// The type of the elements after the transform operation. + /// The type of the result of the aggregation. + /// The type of the transform operator that must implement the interface. + /// The type of the aggregation operator that must implement the interface. + /// The first span of elements to transform and aggregate. + /// The second span of elements to transform and aggregate. + /// The result of the aggregation. + /// The transform operator is applied to the source elements before the aggregation operator. + public static TResult AggregateChecked(ReadOnlySpan x, ReadOnlySpan y) + where TSource1 : struct + where TSource2 : struct + where TTransformed : struct + where TResult : struct + where TTransformOperator : struct, IBinaryOperator + where TAggregateOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + Throw.ArgumentException(nameof(y), "source spans must have the same size."); + + // initialize aggregate + var aggregate = TAggregateOperator.Seed; + var indexSource = nint.Zero; + + // aggregate using hardware acceleration if available + if (TTransformOperator.IsVectorizable && + TAggregateOperator.IsVectorizable && + Vector.IsHardwareAccelerated && + Vector.IsSupported && + Vector.IsSupported && + Vector.IsSupported && + Vector.IsSupported) + { + // convert source span to vector span without copies + var xVectors = MemoryMarshal.Cast>(x); + var yVectors = MemoryMarshal.Cast>(y); + + // check if there is at least one vector to aggregate + if (xVectors.Length > 0) + { + // initialize aggregate vector + var resultVector = new Vector(TAggregateOperator.Seed); + + // aggregate the source vectors into the aggregate vector + ref var xVectorsRef = ref MemoryMarshal.GetReference(xVectors); + ref var yVectorsRef = ref MemoryMarshal.GetReference(yVectors); + var indexVector = nint.Zero; + for (; indexVector < xVectors.Length; indexVector++) + { + var transformedVector = TTransformOperator.Invoke(ref Unsafe.Add(ref xVectorsRef, indexVector), ref Unsafe.Add(ref yVectorsRef, indexVector)); + resultVector = TAggregateOperator.Invoke(ref resultVector, ref transformedVector); + } + + // aggregate the aggregate vector into the aggregate + for(var index = 0; index < Vector.Count; index++) + { + aggregate = TAggregateOperator.Invoke(aggregate, resultVector[index]); + } + + // skip the source elements already aggregated + indexSource = indexVector * Vector.Count; + } + } + + // aggregate the remaining elements in the source + ref var xRef = ref MemoryMarshal.GetReference(x); + ref var yRef = ref MemoryMarshal.GetReference(y); + var remaining = x.Length - (int)indexSource; + if (remaining >= 8) + { + var partial1 = TAggregateOperator.Seed; + var partial2 = TAggregateOperator.Seed; + var partial3 = TAggregateOperator.Seed; + for (; indexSource + 3 < x.Length; indexSource += 4) + { + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + partial1 = TAggregateOperator.Invoke(partial1, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + partial2 = TAggregateOperator.Invoke(partial2, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + partial3 = TAggregateOperator.Invoke(partial3, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 3), Unsafe.Add(ref yRef, indexSource + 3))); + } + aggregate = TAggregateOperator.Invoke(aggregate, partial1); + aggregate = TAggregateOperator.Invoke(aggregate, partial2); + aggregate = TAggregateOperator.Invoke(aggregate, partial3); + remaining = x.Length - (int)indexSource; + } + + switch(remaining) + { + case 7: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 3), Unsafe.Add(ref yRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 4), Unsafe.Add(ref yRef, indexSource + 4))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 5), Unsafe.Add(ref yRef, indexSource + 5))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 6), Unsafe.Add(ref yRef, indexSource + 6))); + break; + case 6: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 3), Unsafe.Add(ref yRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 4), Unsafe.Add(ref yRef, indexSource + 4))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 5), Unsafe.Add(ref yRef, indexSource + 5))); + break; + case 5: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 3), Unsafe.Add(ref yRef, indexSource + 3))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 4), Unsafe.Add(ref yRef, indexSource + 4))); + break; + case 4: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 3), Unsafe.Add(ref yRef, indexSource + 3))); + break; + case 3: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2))); + break; + case 2: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1))); + break; + case 1: + aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource))); + break; + case 0: + break; + default: + Throw.Exception("Should not happen!"); + break; + } + + return aggregate; + } + +} + diff --git a/src/NetFabric.Numerics.Tensors/Operations/Addition.cs b/src/NetFabric.Numerics.Tensors/Operations/Addition.cs index 64c090d..d65a36a 100644 --- a/src/NetFabric.Numerics.Tensors/Operations/Addition.cs +++ b/src/NetFabric.Numerics.Tensors/Operations/Addition.cs @@ -18,19 +18,19 @@ public static void Add(ReadOnlySpan left, ReadOnlySpan right, Span d where T : struct, IAdditionOperators => Tensor.Apply>(left, right, destination); - public static void CheckedAdd(ReadOnlySpan left, T right, Span destination) + public static void AddChecked(ReadOnlySpan left, T right, Span destination) where T : struct, IAdditionOperators - => Tensor.Apply>(left, right, destination); + => Tensor.Apply>(left, right, destination); - public static void CheckedAdd(ReadOnlySpan left, ValueTuple right, Span destination) + public static void AddChecked(ReadOnlySpan left, ValueTuple right, Span destination) where T : struct, IAdditionOperators - => Tensor.Apply>(left, right, destination); + => Tensor.Apply>(left, right, destination); - public static void CheckedAdd(ReadOnlySpan left, ValueTuple right, Span destination) + public static void AddChecked(ReadOnlySpan left, ValueTuple right, Span destination) where T : struct, IAdditionOperators - => Tensor.Apply>(left, right, destination); + => Tensor.Apply>(left, right, destination); - public static void CheckedAdd(ReadOnlySpan left, ReadOnlySpan right, Span destination) + public static void AddChecked(ReadOnlySpan left, ReadOnlySpan right, Span destination) where T : struct, IAdditionOperators - => Tensor.Apply>(left, right, destination); + => Tensor.Apply>(left, right, destination); } \ No newline at end of file diff --git a/src/NetFabric.Numerics.Tensors/Operations/SumChecked.cs b/src/NetFabric.Numerics.Tensors/Operations/SumChecked.cs new file mode 100644 index 0000000..1f00b00 --- /dev/null +++ b/src/NetFabric.Numerics.Tensors/Operations/SumChecked.cs @@ -0,0 +1,20 @@ +namespace NetFabric.Numerics.Tensors; + +public static partial class TensorOperations +{ + public static T SumChecked(ReadOnlySpan source) + where T : struct, IAdditionOperators, IAdditiveIdentity + => Tensor.AggregateChecked>(source); + + //public static ValueTuple SumChecked2D(ReadOnlySpan source) + // where T : struct, IAdditionOperators, IAdditiveIdentity + // => Tensor.AggregateChecked2D>(source); + + //public static ValueTuple SumChecked3D(ReadOnlySpan source) + // where T : struct, IAdditionOperators, IAdditiveIdentity + // => Tensor.AggregateChecked3D>(source); + + //public static ValueTuple SumChecked4D(ReadOnlySpan source) + // where T : struct, IAdditionOperators, IAdditiveIdentity + // => Tensor.AggregateChecked4D>(source); +} \ No newline at end of file diff --git a/src/NetFabric.Numerics.Tensors/Operators/AdditionOperators.cs b/src/NetFabric.Numerics.Tensors/Operators/AdditionOperators.cs index be7f77c..0d5fcd0 100644 --- a/src/NetFabric.Numerics.Tensors/Operators/AdditionOperators.cs +++ b/src/NetFabric.Numerics.Tensors/Operators/AdditionOperators.cs @@ -13,7 +13,7 @@ public static Vector Invoke(ref readonly Vector x, ref readonly Vector => x + y; } -readonly struct CheckedAddOperator +readonly struct AddCheckedOperator : IBinaryOperator where T : struct, IAdditionOperators { diff --git a/src/NetFabric.Numerics.Tensors/Operators/SumOperator.cs b/src/NetFabric.Numerics.Tensors/Operators/SumOperator.cs index 4c8ca03..7078191 100644 --- a/src/NetFabric.Numerics.Tensors/Operators/SumOperator.cs +++ b/src/NetFabric.Numerics.Tensors/Operators/SumOperator.cs @@ -11,6 +11,22 @@ public static T Seed public static T Invoke(T x, T y) => x + y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector Invoke(ref readonly Vector x, ref readonly Vector y) + => x + y; +} + +readonly struct SumCheckedOperator + : IAggregationOperator + where T : struct, IAdditiveIdentity, IAdditionOperators +{ + public static T Seed + => T.AdditiveIdentity; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static T Invoke(T x, T y) + => checked(x + y); + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector Invoke(ref readonly Vector x, ref readonly Vector y) => x + y;