Skip to content

Commit fdbc56e

Browse files
authored
Fix | SqlVector: Explicitly perform little-endian multibyte writes (dotnet#3861)
1 parent 34b400e commit fdbc56e

2 files changed

Lines changed: 70 additions & 27 deletions

File tree

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlVector.cs

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -39,13 +39,12 @@ namespace Microsoft.Data.SqlTypes;
3939

4040
private SqlVector(int length)
4141
{
42-
if (length < 0)
42+
(_elementType, _elementSize, int maxElements) = GetTypeFieldsOrThrow();
43+
if (length < 0 || length > maxElements)
4344
{
4445
throw ADP.InvalidArraySize(nameof(length));
4546
}
4647

47-
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
48-
4948
IsNull = true;
5049

5150
Length = length;
@@ -61,7 +60,11 @@ private SqlVector(int length)
6160
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlTypes/SqlVector.xml' path='docs/members[@name="SqlVector"]/ctor1/*' />
6261
public SqlVector(ReadOnlyMemory<T> memory)
6362
{
64-
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
63+
(_elementType, _elementSize, int maxElements) = GetTypeFieldsOrThrow();
64+
if (memory.Length > maxElements)
65+
{
66+
throw ADP.InvalidArraySize(nameof(memory));
67+
}
6568

6669
IsNull = false;
6770

@@ -74,7 +77,7 @@ public SqlVector(ReadOnlyMemory<T> memory)
7477

7578
internal SqlVector(byte[] tdsBytes)
7679
{
77-
(_elementType, _elementSize) = GetTypeFieldsOrThrow();
80+
(_elementType, _elementSize, _) = GetTypeFieldsOrThrow();
7881

7982
(Length, _size) = GetCountsOrThrow(tdsBytes);
8083

@@ -125,10 +128,11 @@ internal string GetString()
125128

126129
#region Helpers
127130

128-
private (byte, byte) GetTypeFieldsOrThrow()
131+
private static (byte, byte, int) GetTypeFieldsOrThrow()
129132
{
130133
byte elementType;
131134
byte elementSize;
135+
int maxSize;
132136

133137
if (typeof(T) == typeof(float))
134138
{
@@ -139,12 +143,17 @@ internal string GetString()
139143
{
140144
throw SQL.VectorTypeNotSupported(typeof(T).FullName);
141145
}
146+
// The size of a vector (including its header) must not exceed the maximum size of a TDS packet.
147+
// Calculate the maximum number of elements to simplify the validation of input sizes in constructors.
148+
maxSize = (TdsEnums.MAXSIZE - TdsEnums.VECTOR_HEADER_SIZE) / elementSize;
142149

143-
return (elementType, elementSize);
150+
return (elementType, elementSize, maxSize);
144151
}
145152

146153
private byte[] MakeTdsBytes(ReadOnlyMemory<T> values)
147154
{
155+
Debug.Assert(Length <= ushort.MaxValue);
156+
148157
//Refer to TDS section 2.2.5.5.7 for vector header format
149158
// +------------------------+-----------------+----------------------+------------------+----------------------------+--------------+
150159
// | Field | Size (bytes) | Example Value | Description |
@@ -158,32 +167,42 @@ private byte[] MakeTdsBytes(ReadOnlyMemory<T> values)
158167
// +------------------------+-----------------+----------------------+--------------------------------------------------------------+
159168

160169
byte[] result = new byte[_size];
170+
ReadOnlySpan<T> valueSpan = values.Span;
161171

162172
// Header Bytes
163173
result[0] = VecHeaderMagicNo;
164174
result[1] = VecVersionNo;
165-
result[2] = (byte)(Length & 0xFF);
166-
result[3] = (byte)((Length >> 8) & 0xFF);
175+
BinaryPrimitives.WriteUInt16LittleEndian(result.AsSpan(2), (ushort)Length);
167176
result[4] = _elementType;
168177
result[5] = 0x00;
169178
result[6] = 0x00;
170179
result[7] = 0x00;
171180

172-
#if NETFRAMEWORK
173-
// Copy data via marshaling.
174-
if (MemoryMarshal.TryGetArray(values, out ArraySegment<T> segment))
181+
// If .NET is running on a little-endian architecture, cast directly to a byte array and proceed.
182+
// This optimisation relies upon the base type of the vector transporting values in a format and
183+
// endianness which is identical to the client. This is true for all little-endian clients reading
184+
// float32-based vectors.
185+
if (BitConverter.IsLittleEndian)
175186
{
176-
Buffer.BlockCopy(segment.Array, segment.Offset * _elementSize, result, TdsEnums.VECTOR_HEADER_SIZE, segment.Count * _elementSize);
187+
ReadOnlySpan<byte> valuesAsBytes = MemoryMarshal.AsBytes(valueSpan);
188+
189+
valuesAsBytes.CopyTo(result.AsSpan(TdsEnums.VECTOR_HEADER_SIZE));
177190
}
178191
else
179192
{
180-
Buffer.BlockCopy(values.ToArray(), 0, result, TdsEnums.VECTOR_HEADER_SIZE, values.Length * _elementSize);
193+
if (typeof(T) == typeof(float))
194+
{
195+
for (int i = 0, currPosition = TdsEnums.VECTOR_HEADER_SIZE; i < values.Length; i++, currPosition += _elementSize)
196+
{
197+
#if NET
198+
BinaryPrimitives.WriteSingleLittleEndian(result.AsSpan(currPosition), (float)(object)valueSpan[i]);
199+
#else
200+
BinaryPrimitives.WriteInt32LittleEndian(result.AsSpan(currPosition), BitConverterCompatible.SingleToInt32Bits((float)(object)valueSpan[i]));
201+
#endif
202+
}
203+
}
181204
}
182-
#else
183-
// Fast span-based copy.
184-
var byteSpan = MemoryMarshal.AsBytes(values.Span);
185-
byteSpan.CopyTo(result.AsSpan(TdsEnums.VECTOR_HEADER_SIZE));
186-
#endif
205+
187206
return result;
188207
}
189208

@@ -227,15 +246,32 @@ private T[] MakeArray()
227246
return Array.Empty<T>();
228247
}
229248

230-
#if NETFRAMEWORK
231249
// Allocate array and copy bytes into it
232250
T[] result = new T[Length];
233-
Buffer.BlockCopy(_tdsBytes, 8, result, 0, _elementSize * Length);
251+
252+
// See the comment in MakeTdsBytes for more information on this optimisation.
253+
if (BitConverter.IsLittleEndian)
254+
{
255+
Span<byte> valuesAsBytes = MemoryMarshal.AsBytes(result.AsSpan());
256+
257+
_tdsBytes.AsSpan(TdsEnums.VECTOR_HEADER_SIZE).CopyTo(valuesAsBytes);
258+
}
259+
else
260+
{
261+
if (typeof(T) == typeof(float))
262+
{
263+
for (int i = 0, currPosition = TdsEnums.VECTOR_HEADER_SIZE; i < Length; i++, currPosition += _elementSize)
264+
{
265+
#if NET
266+
result[i] = (T)(object)BinaryPrimitives.ReadSingleLittleEndian(_tdsBytes.AsSpan(currPosition));
267+
#else
268+
result[i] = (T)(object)BitConverterCompatible.Int32BitsToSingle(BinaryPrimitives.ReadInt32LittleEndian(_tdsBytes.AsSpan(currPosition)));
269+
#endif
270+
}
271+
}
272+
}
273+
234274
return result;
235-
#else
236-
ReadOnlySpan<byte> dataSpan = _tdsBytes.AsSpan(8, _elementSize * Length);
237-
return MemoryMarshal.Cast<byte, T>(dataSpan).ToArray();
238-
#endif
239275
}
240276

241277
#endregion

src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlTypes/SqlVectorTest.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -28,6 +28,13 @@ public void Construct_Length_Negative()
2828
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(-1));
2929
}
3030

31+
[Fact]
32+
public void Construct_Length_Exceeds_8000()
33+
{
34+
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(1999));
35+
Assert.Throws<ArgumentOutOfRangeException>(() => SqlVector<float>.CreateNull(int.MaxValue / 2));
36+
}
37+
3138
[Fact]
3239
public void Construct_Length()
3340
{

0 commit comments

Comments
 (0)