Skip to content

Commit d2f6624

Browse files
committed
Use TensorPrimitives in FloatSpan
instead of questionable AVX usage.
1 parent 187ac5a commit d2f6624

2 files changed

Lines changed: 18 additions & 288 deletions

File tree

engine/Sandbox.Engine/Sandbox.Engine.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="10.0.0" />
4545
<PackageReference Include="Svg.Skia" Version="2.0.0.4" />
4646
<PackageReference Include="Azure.Messaging.WebPubSub.Client" Version="1.0.0" />
47+
<PackageReference Include="System.Numerics.Tensors" Version="10.0.0" />
4748
</ItemGroup>
4849

4950
<ItemGroup>
Lines changed: 17 additions & 288 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
using System.Numerics;
1+
using System;
22
using System.Runtime.CompilerServices;
3-
using System.Runtime.Intrinsics;
4-
using System.Runtime.Intrinsics.X86;
3+
using System.Numerics.Tensors;
54

65
namespace Sandbox;
76

87
/// <summary>
9-
/// Allows easy SIMD/AVX2 fast math on a span of floats
8+
/// Provides vectorized operations over a span of floats.
109
/// </summary>
1110
public ref struct FloatSpan
1211
{
@@ -17,333 +16,63 @@ public FloatSpan( Span<float> span )
1716
_span = span;
1817
}
1918

20-
/// <summary>
21-
/// Uses SIMD/AVX2 to find the maximum value in a span of floats.
22-
/// </summary>
2319
[MethodImpl( MethodImplOptions.AggressiveInlining )]
2420
public float Max()
2521
{
26-
if ( _span.IsEmpty ) return 0.0f;
27-
28-
int i = 0;
29-
float max = float.MinValue;
30-
31-
if ( Avx.IsSupported )
32-
{
33-
var maxVector = Vector256.Create( float.MinValue );
34-
35-
// Get a pointer to the span data
36-
unsafe
37-
{
38-
fixed ( float* ptr = _span )
39-
{
40-
for ( ; i <= _span.Length - 8; i += 8 )
41-
{
42-
var v = Avx.LoadVector256( ptr + i ); // Correct memory load
43-
maxVector = Avx.Max( maxVector, v );
44-
}
45-
}
46-
}
47-
48-
// Reduce maxVector to a single float
49-
max = Math.Max( max, maxVector.GetElement( 0 ) );
50-
max = Math.Max( max, maxVector.GetElement( 1 ) );
51-
max = Math.Max( max, maxVector.GetElement( 2 ) );
52-
max = Math.Max( max, maxVector.GetElement( 3 ) );
53-
max = Math.Max( max, maxVector.GetElement( 4 ) );
54-
max = Math.Max( max, maxVector.GetElement( 5 ) );
55-
max = Math.Max( max, maxVector.GetElement( 6 ) );
56-
max = Math.Max( max, maxVector.GetElement( 7 ) );
57-
}
58-
59-
// Handle remaining elements
60-
for ( ; i < _span.Length; i++ )
61-
max = Math.Max( max, _span[i] );
62-
63-
return max;
22+
return _span.IsEmpty ? 0.0f : TensorPrimitives.Max( _span );
6423
}
6524

66-
/// <summary>
67-
/// Uses SIMD/AVX2 to find the minimum value in a span of floats.
68-
/// </summary>
6925
[MethodImpl( MethodImplOptions.AggressiveInlining )]
7026
public float Min()
7127
{
72-
if ( _span.IsEmpty ) return 0.0f;
73-
74-
int i = 0;
75-
float max = float.MaxValue;
76-
77-
if ( Avx.IsSupported )
78-
{
79-
var maxVector = Vector256.Create( float.MaxValue );
80-
81-
// Get a pointer to the span data
82-
unsafe
83-
{
84-
fixed ( float* ptr = _span )
85-
{
86-
for ( ; i <= _span.Length - 8; i += 8 )
87-
{
88-
var v = Avx.LoadVector256( ptr + i ); // Correct memory load
89-
maxVector = Avx.Min( maxVector, v );
90-
}
91-
}
92-
}
93-
94-
max = Math.Min( max, maxVector.GetElement( 0 ) );
95-
max = Math.Min( max, maxVector.GetElement( 1 ) );
96-
max = Math.Min( max, maxVector.GetElement( 2 ) );
97-
max = Math.Min( max, maxVector.GetElement( 3 ) );
98-
max = Math.Min( max, maxVector.GetElement( 4 ) );
99-
max = Math.Min( max, maxVector.GetElement( 5 ) );
100-
max = Math.Min( max, maxVector.GetElement( 6 ) );
101-
max = Math.Min( max, maxVector.GetElement( 7 ) );
102-
}
103-
104-
// Handle remaining elements
105-
for ( ; i < _span.Length; i++ )
106-
max = Math.Min( max, _span[i] );
107-
108-
return max;
28+
return _span.IsEmpty ? 0.0f : TensorPrimitives.Min( _span );
10929
}
11030

11131
[MethodImpl( MethodImplOptions.AggressiveInlining )]
11232
public float Average()
11333
{
114-
if ( _span.IsEmpty ) return 0.0f;
115-
116-
int i = 0;
117-
float sum = 0f;
118-
float len = _span.Length;
119-
120-
if ( Avx.IsSupported )
121-
{
122-
var sumVector = Vector256<float>.Zero;
123-
124-
unsafe
125-
{
126-
fixed ( float* ptr = _span )
127-
{
128-
// Sum using AVX2
129-
for ( ; i <= len - 8; i += 8 )
130-
{
131-
var v = Avx.LoadVector256( ptr + i );
132-
sumVector = Avx.Add( sumVector, v );
133-
}
134-
}
135-
}
136-
137-
// Reduce sumVector to a single float
138-
sum += sumVector.GetElement( 0 );
139-
sum += sumVector.GetElement( 1 );
140-
sum += sumVector.GetElement( 2 );
141-
sum += sumVector.GetElement( 3 );
142-
sum += sumVector.GetElement( 4 );
143-
sum += sumVector.GetElement( 5 );
144-
sum += sumVector.GetElement( 6 );
145-
sum += sumVector.GetElement( 7 );
146-
}
147-
148-
// Handle remaining elements
149-
for ( ; i < len; i++ )
150-
sum += _span[i];
151-
152-
return sum / len;
34+
return _span.IsEmpty ? 0.0f : TensorPrimitives.Average( _span );
15335
}
15436

15537
[MethodImpl( MethodImplOptions.AggressiveInlining )]
15638
public float Sum()
15739
{
158-
if ( _span.IsEmpty ) return 0.0f;
159-
160-
int i = 0;
161-
162-
float sum = 0f;
163-
164-
if ( Avx.IsSupported )
165-
{
166-
var sumVector = Vector256<float>.Zero;
167-
168-
unsafe
169-
{
170-
fixed ( float* ptr = _span )
171-
{
172-
for ( ; i <= _span.Length - 8; i += 8 )
173-
{
174-
var v = Avx.LoadVector256( ptr + i );
175-
sumVector = Avx.Add( sumVector, v );
176-
}
177-
}
178-
}
179-
180-
// Reduce sumVector using horizontal adds
181-
var temp = Avx.HorizontalAdd( sumVector, sumVector );
182-
temp = Avx.HorizontalAdd( temp, temp );
183-
temp = Avx.HorizontalAdd( temp, temp );
184-
185-
sum += temp.GetElement( 0 );
186-
187-
}
188-
189-
for ( ; i < _span.Length; i++ )
190-
sum += _span[i];
191-
192-
return sum;
40+
return _span.IsEmpty ? 0.0f : TensorPrimitives.Sum( _span );
19341
}
19442

19543
[MethodImpl( MethodImplOptions.AggressiveInlining )]
19644
public void Set( float value )
19745
{
198-
int i = 0;
199-
200-
if ( Avx.IsSupported )
201-
{
202-
unsafe
203-
{
204-
fixed ( float* ptr = _span )
205-
{
206-
var v = Vector256.Create( value );
207-
for ( ; i <= _span.Length - 8; i += 8 )
208-
{
209-
Avx.Store( ptr + i, v );
210-
}
211-
}
212-
}
213-
}
214-
215-
for ( ; i < _span.Length; i++ )
216-
_span[i] = value;
46+
_span.Fill( value );
21747
}
21848

21949
[MethodImpl( MethodImplOptions.AggressiveInlining )]
220-
public readonly void Set( in Span<float> values )
50+
public readonly void Set( ReadOnlySpan<float> values )
22151
{
222-
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
223-
224-
unsafe
225-
{
226-
var size = _span.Length * sizeof( float );
227-
228-
fixed ( float* srcPtr = values, dstPtr = _span )
229-
{
230-
NativeLowLevel.Copy( (IntPtr)srcPtr, (IntPtr)dstPtr, (uint)size );
231-
}
232-
}
52+
values.CopyTo( _span );
23353
}
23454

23555
[MethodImpl( MethodImplOptions.AggressiveInlining )]
236-
public readonly void CopyScaled( in Span<float> values, float scale )
56+
public readonly void CopyScaled( ReadOnlySpan<float> values, float scale )
23757
{
238-
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
239-
240-
int i = 0;
241-
242-
if ( Avx.IsSupported )
243-
{
244-
var scaleVector = Vector256.Create( scale );
245-
246-
unsafe
247-
{
248-
fixed ( float* srcPtr = values, dstPtr = _span )
249-
{
250-
for ( ; i <= _span.Length - 8; i += 8 )
251-
{
252-
var v = Avx.LoadVector256( srcPtr + i );
253-
v = Avx.Multiply( v, scaleVector );
254-
Avx.Store( dstPtr + i, v );
255-
}
256-
}
257-
}
258-
}
259-
260-
for ( ; i < _span.Length; i++ )
261-
_span[i] = values[i] * scale;
58+
TensorPrimitives.Multiply( values, scale, _span );
26259
}
26360

26461
[MethodImpl( MethodImplOptions.AggressiveInlining )]
265-
public readonly void Add( in Span<float> values )
62+
public readonly void Add( ReadOnlySpan<float> values )
26663
{
267-
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
268-
269-
int i = 0;
270-
271-
if ( Avx.IsSupported )
272-
{
273-
unsafe
274-
{
275-
fixed ( float* srcPtr = values, dstPtr = _span )
276-
{
277-
for ( ; i <= _span.Length - 8; i += 8 )
278-
{
279-
var v = Avx.LoadVector256( srcPtr + i );
280-
var dst = Avx.LoadVector256( dstPtr + i );
281-
dst = Avx.Add( dst, v );
282-
Avx.Store( dstPtr + i, dst );
283-
}
284-
}
285-
}
286-
}
287-
288-
for ( ; i < _span.Length; i++ )
289-
_span[i] += values[i];
64+
TensorPrimitives.Add( _span, values, _span );
29065
}
29166

29267
[MethodImpl( MethodImplOptions.AggressiveInlining )]
293-
public readonly void AddScaled( in Span<float> values, float scale )
68+
public readonly void AddScaled( ReadOnlySpan<float> values, float scale )
29469
{
295-
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
296-
297-
int i = 0;
298-
299-
if ( Avx.IsSupported )
300-
{
301-
var scaleVector = Vector256.Create( scale );
302-
303-
unsafe
304-
{
305-
fixed ( float* srcPtr = values, dstPtr = _span )
306-
{
307-
for ( ; i <= _span.Length - 8; i += 8 )
308-
{
309-
var v = Avx.LoadVector256( srcPtr + i );
310-
v = Avx.Multiply( v, scaleVector );
311-
var dst = Avx.LoadVector256( dstPtr + i );
312-
dst = Avx.Add( dst, v );
313-
Avx.Store( dstPtr + i, dst );
314-
}
315-
}
316-
}
317-
}
318-
319-
for ( ; i < _span.Length; i++ )
320-
_span[i] += values[i] * scale;
70+
TensorPrimitives.MultiplyAdd( values, scale, _span, _span );
32171
}
32272

32373
[MethodImpl( MethodImplOptions.AggressiveInlining )]
32474
public readonly void Scale( float scale )
32575
{
326-
int i = 0;
327-
328-
if ( Avx.IsSupported )
329-
{
330-
var scaleVector = Vector256.Create( scale );
331-
332-
unsafe
333-
{
334-
fixed ( float* dstPtr = _span )
335-
{
336-
for ( ; i <= _span.Length - 8; i += 8 )
337-
{
338-
var v = Avx.LoadVector256( dstPtr + i );
339-
v = Avx.Multiply( v, scaleVector );
340-
Avx.Store( dstPtr + i, v );
341-
}
342-
}
343-
}
344-
}
345-
346-
for ( ; i < _span.Length; i++ )
347-
_span[i] *= scale;
76+
TensorPrimitives.Multiply( _span, scale, _span );
34877
}
34978
}

0 commit comments

Comments
 (0)