diff --git a/.editorconfig b/.editorconfig
index 63f9d967..37b7467e 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -20,7 +20,7 @@ csharp_style_var_for_built_in_types = false:none
csharp_style_var_when_type_is_apparent = false:none
dotnet_naming_rule.constants_rule.import_to_resharper = as_predefined
dotnet_naming_rule.constants_rule.severity = warning
-dotnet_naming_rule.constants_rule.style = all_upper_style
+dotnet_naming_rule.constants_rule.style = upper_camel_case_style
dotnet_naming_rule.constants_rule.symbols = constants_symbols
dotnet_naming_rule.private_constants_rule.import_to_resharper = as_predefined
dotnet_naming_rule.private_constants_rule.resharper_style = AaBb, NUM_ + AA_BB
diff --git a/src/Common/ArrayBuilder.cs b/src/Common/ArrayBuilder.cs
index 87bf5ef4..5c07433a 100644
--- a/src/Common/ArrayBuilder.cs
+++ b/src/Common/ArrayBuilder.cs
@@ -4,6 +4,7 @@
// Modified for generic types and not backed by a pool
+using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
@@ -63,12 +64,6 @@ public ref T this[int index]
}
}
- public override string ToString()
- {
- string s = Raw.Slice(0, _pos).ToString();
- return s;
- }
-
/// Returns the underlying storage of the builder.
public Span Raw => _array.AsSpan();
@@ -185,3 +180,208 @@ private void Grow(int additionalCapacityBeyondPos)
_array = newArray;
}
}
+
+internal ref struct PoolArrayBuilder
+{
+ private T[]? _array;
+ private Span _raw;
+ private int _pos;
+
+ public PoolArrayBuilder(Span initialBuffer)
+ {
+ _array = null;
+ _raw = initialBuffer;
+ _pos = 0;
+ }
+
+ public PoolArrayBuilder(int initialCapacity)
+ {
+ _array = ArrayPool.Shared.Rent(initialCapacity);
+ _raw = _array;
+ _pos = 0;
+ }
+
+ public int Length
+ {
+ get => _pos;
+ set
+ {
+ Debug.Assert(value >= 0);
+ Debug.Assert(value <= _raw.Length);
+ _pos = value;
+ }
+ }
+
+ public bool IsEmpty => 0 >= (uint)_pos;
+
+ public bool IsDefault => _raw.IsEmpty && _array is null && _pos == 0;
+
+ public int Capacity => _raw.Length;
+
+ public void EnsureCapacity(int capacity)
+ {
+ // This is not expected to be called this with negative capacity
+ Debug.Assert(capacity >= 0);
+
+ if ((uint)capacity > (uint)_raw.Length)
+ Grow(capacity - _pos);
+ }
+
+ ///
+ /// Get a pinnable reference to the builder.
+ /// Does not ensure there is a null char after
+ /// This overload is pattern matched in the C# 7.3+ compiler so you can omit
+ /// the explicit method call, and write eg "fixed (char* c = builder)"
+ ///
+ public ref T GetPinnableReference()
+ {
+ return ref MemoryMarshal.GetReference(_raw);
+ }
+
+ public ref T this[int index]
+ {
+ get
+ {
+ Debug.Assert(index < _pos);
+ return ref _raw[index];
+ }
+ }
+
+ /// Returns the underlying storage of the builder.
+ public Span Raw => _raw;
+
+ public ReadOnlySpan AsSpan() => _raw.Slice(0, _pos);
+ public ReadOnlySpan AsSpan(int start) => _raw.Slice(start, _pos - start);
+ public ReadOnlySpan AsSpan(int start, int length) => _raw.Slice(start, length);
+
+ public ArraySegment AsSegment() => new(_array ?? Array.Empty());
+ public ArraySegment AsSegment(int start) => new(_array ?? Array.Empty(), start, _pos - start);
+ public ArraySegment AsSegment(int start, int length) => new(_array ?? Array.Empty(), start, length);
+
+ public bool TryCopyTo(Span destination, out int written) {
+ if (_raw.Slice(0, _pos).TryCopyTo(destination))
+ {
+ written = _pos;
+ return true;
+ }
+
+ written = 0;
+ return false;
+ }
+
+ public void Insert(int index, in T value, int count)
+ {
+ if (_pos > _raw.Length - count)
+ {
+ Grow(count);
+ }
+
+ int remaining = _pos - index;
+ _raw.Slice(index, remaining).CopyTo(_raw.Slice(index + count));
+ _raw.Slice(index, count).Fill(value);
+ _pos += count;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Append(T c)
+ {
+ int pos = _pos;
+ if ((uint)pos < (uint)_raw.Length)
+ {
+ _raw[pos] = c;
+ _pos = pos + 1;
+ }
+ else
+ {
+ GrowAndAppend(c);
+ }
+ }
+
+ public void Append(in T c, int count)
+ {
+ if (_pos > _raw.Length - count)
+ {
+ Grow(count);
+ }
+
+ Span dst = _raw.Slice(_pos, count);
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = c;
+ }
+ _pos += count;
+ }
+
+ public void Append(ReadOnlySpan value)
+ {
+ int pos = _pos;
+ if (pos > _raw.Length - value.Length)
+ {
+ Grow(value.Length);
+ }
+
+ value.CopyTo(_raw.Slice(_pos));
+ _pos += value.Length;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public Span AppendSpan(int length)
+ {
+ int origPos = _pos;
+ if (origPos > _raw.Length - length)
+ {
+ Grow(length);
+ }
+
+ _pos = origPos + length;
+ return _raw.Slice(origPos, length);
+ }
+
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void GrowAndAppend(T c)
+ {
+ Grow(1);
+ Append(c);
+ }
+
+ ///
+ /// Resize the internal buffer either by doubling current buffer size or
+ /// by adding to
+ /// whichever is greater.
+ ///
+ ///
+ /// Number of chars requested beyond current position.
+ ///
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void Grow(int additionalCapacityBeyondPos)
+ {
+ Debug.Assert(additionalCapacityBeyondPos > 0);
+ Debug.Assert(_pos > _raw.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed.");
+
+ T[] newArray = ArrayPool.Shared.Rent(
+ (int)Math.Max((uint)(_pos + additionalCapacityBeyondPos), (uint)_raw.Length * 2)
+ );
+ _raw.Slice(0, _pos).CopyTo(newArray);
+ _raw = newArray;
+ if (_array is not null) {
+ ArrayPool.Shared.Return(_array);
+ }
+ _array = newArray;
+ }
+
+ public void Dispose() {
+ var array = _array;
+ _array = null;
+ _raw = default;
+ if (array is not null) {
+ ArrayPool.Shared.Return(array);
+ }
+ }
+
+ public T[] ToArray() {
+ var array = new T[_pos];
+ _raw.Slice(0, _pos).CopyTo(array.AsSpan(0, _pos));
+ Dispose();
+ return array;
+ }
+}
diff --git a/src/Common/AssemblyInfo.cs b/src/Common/AssemblyInfo.cs
index 8cfad412..4b93d270 100644
--- a/src/Common/AssemblyInfo.cs
+++ b/src/Common/AssemblyInfo.cs
@@ -1,5 +1,6 @@
using System.Runtime.CompilerServices;
+[assembly: CLSCompliant(true)]
[assembly: InternalsVisibleTo("SurrealDB.Abstractions")]
[assembly: InternalsVisibleTo("SurrealDB.Configuration")]
[assembly: InternalsVisibleTo("SurrealDB.Driver.Rest")]
diff --git a/src/Common/BitOperations.cs b/src/Common/BitOperations.cs
new file mode 100644
index 00000000..3ffd9620
--- /dev/null
+++ b/src/Common/BitOperations.cs
@@ -0,0 +1,699 @@
+// BitOperations without intrinsics for netstandard21 support
+#if !NETCOREAPP3_0_OR_GREATER
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+using SurrealDB.Common;
+
+// ReSharper disable once CheckNamespace
+namespace System.Numerics;
+
+public static class BitOperations
+{
+ // C# no-alloc optimization that directly wraps the data section of the dll (similar to string constants)
+ // https://github.com/dotnet/roslyn/pull/24621
+
+ private static ReadOnlySpan TrailingZeroCountDeBruijn => new byte[32]
+ {
+ 00, 01, 28, 02, 29, 14, 24, 03,
+ 30, 22, 20, 15, 25, 17, 04, 08,
+ 31, 27, 13, 23, 21, 19, 16, 07,
+ 26, 12, 18, 06, 11, 05, 10, 09
+ };
+
+ private static ReadOnlySpan Log2DeBruijn => new byte[32]
+ {
+ 00, 09, 01, 10, 13, 21, 02, 29,
+ 11, 14, 16, 18, 22, 25, 03, 30,
+ 08, 12, 20, 28, 15, 17, 24, 07,
+ 19, 27, 23, 06, 26, 05, 04, 31
+ };
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static bool IsPow2(int value) => (value & (value - 1)) == 0 && value > 0;
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static bool IsPow2(uint value) => (value & (value - 1)) == 0 && value != 0;
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static bool IsPow2(long value) => (value & (value - 1)) == 0 && value > 0;
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static bool IsPow2(ulong value) => (value & (value - 1)) == 0 && value != 0;
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static bool IsPow2(nint value) => (value & (value - 1)) == 0 && value > 0;
+
+ ///
+ /// Evaluate whether a given integral value is a power of 2.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static bool IsPow2(nuint value) => (value & (value - 1)) == 0 && value != 0;
+
+ /// Round the given integral value up to a power of 2.
+ /// The value.
+ ///
+ /// The smallest power of 2 which is greater than or equal to .
+ /// If is 0 or the result overflows, returns 0.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static uint RoundUpToPowerOf2(uint value)
+ {
+ // Based on https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+ --value;
+ value |= value >> 1;
+ value |= value >> 2;
+ value |= value >> 4;
+ value |= value >> 8;
+ value |= value >> 16;
+ return value + 1;
+ }
+
+ ///
+ /// Round the given integral value up to a power of 2.
+ ///
+ /// The value.
+ ///
+ /// The smallest power of 2 which is greater than or equal to .
+ /// If is 0 or the result overflows, returns 0.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static ulong RoundUpToPowerOf2(ulong value)
+ {
+ // Based on https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+ --value;
+ value |= value >> 1;
+ value |= value >> 2;
+ value |= value >> 4;
+ value |= value >> 8;
+ value |= value >> 16;
+ value |= value >> 32;
+ return value + 1;
+ }
+
+ ///
+ /// Round the given integral value up to a power of 2.
+ ///
+ /// The value.
+ ///
+ /// The smallest power of 2 which is greater than or equal to .
+ /// If is 0 or the result overflows, returns 0.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static nuint RoundUpToPowerOf2(nuint value)
+ {
+#if TARGET_64BIT
+ return (nuint)RoundUpToPowerOf2((ulong)value);
+#else
+ return (nuint)RoundUpToPowerOf2((uint)value);
+#endif
+ }
+
+ ///
+ /// Count the number of leading zero bits in a mask.
+ /// Similar in behavior to the x86 instruction LZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int LeadingZeroCount(uint value)
+ {
+ // Unguarded fallback contract is 0->31, BSR contract is 0->undefined
+ if (value == 0)
+ {
+ return 32;
+ }
+
+ return 31 ^ Log2SoftwareFallback(value);
+ }
+
+ ///
+ /// Count the number of leading zero bits in a mask.
+ /// Similar in behavior to the x86 instruction LZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int LeadingZeroCount(ulong value)
+ {
+ uint hi = (uint)(value >> 32);
+
+ if (hi == 0)
+ {
+ return 32 + LeadingZeroCount((uint)value);
+ }
+
+ return LeadingZeroCount(hi);
+ }
+
+ ///
+ /// Count the number of leading zero bits in a mask.
+ /// Similar in behavior to the x86 instruction LZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int LeadingZeroCount(nuint value)
+ {
+#if TARGET_64BIT
+ return LeadingZeroCount((ulong)value);
+#else
+ return LeadingZeroCount((uint)value);
+#endif
+ }
+
+ ///
+ /// Returns the integer (floor) log of the specified value, base 2.
+ /// Note that by convention, input value 0 returns 0 since log(0) is undefined.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int Log2(uint value)
+ {
+ // The 0->0 contract is fulfilled by setting the LSB to 1.
+ // Log(1) is 0, and setting the LSB for values > 1 does not change the log2 result.
+ value |= 1;
+
+ // Fallback contract is 0->0
+ return Log2SoftwareFallback(value);
+ }
+
+ ///
+ /// Returns the integer (floor) log of the specified value, base 2.
+ /// Note that by convention, input value 0 returns 0 since log(0) is undefined.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int Log2(ulong value)
+ {
+ value |= 1;
+
+ uint hi = (uint)(value >> 32);
+
+ if (hi == 0)
+ {
+ return Log2((uint)value);
+ }
+
+ return 32 + Log2(hi);
+ }
+
+ ///
+ /// Returns the integer (floor) log of the specified value, base 2.
+ /// Note that by convention, input value 0 returns 0 since log(0) is undefined.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int Log2(nuint value)
+ {
+#if TARGET_64BIT
+ return Log2((ulong)value);
+#else
+ return Log2((uint)value);
+#endif
+ }
+
+ ///
+ /// Returns the integer (floor) log of the specified value, base 2.
+ /// Note that by convention, input value 0 returns 0 since Log(0) is undefined.
+ /// Does not directly use any hardware intrinsics, nor does it incur branching.
+ ///
+ /// The value.
+ private static int Log2SoftwareFallback(uint value)
+ {
+ // No AggressiveInlining due to large method size
+ // Has conventional contract 0->0 (Log(0) is undefined)
+
+ // Fill trailing zeros with ones, eg 00010010 becomes 00011111
+ value |= value >> 01;
+ value |= value >> 02;
+ value |= value >> 04;
+ value |= value >> 08;
+ value |= value >> 16;
+
+ // uint.MaxValue >> 27 is always in range [0 - 31] so we use Unsafe.AddByteOffset to avoid bounds check
+ return Unsafe.AddByteOffset(
+ // Using deBruijn sequence, k=2, n=5 (2^5=32) : 0b_0000_0111_1100_0100_1010_1100_1101_1101u
+ ref MemoryMarshal.GetReference(Log2DeBruijn),
+ // uint|long -> IntPtr cast on 32-bit platforms does expensive overflow checks not needed here
+ (IntPtr)(int)((value * 0x07C4ACDDu) >> 27));
+ }
+
+ /// Returns the integer (ceiling) log of the specified value, base 2.
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static int Log2Ceiling(uint value)
+ {
+ int result = Log2(value);
+ if (PopCount(value) != 1)
+ {
+ result++;
+ }
+ return result;
+ }
+
+ /// Returns the integer (ceiling) log of the specified value, base 2.
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static int Log2Ceiling(ulong value)
+ {
+ int result = Log2(value);
+ if (PopCount(value) != 1)
+ {
+ result++;
+ }
+ return result;
+ }
+
+ ///
+ /// Returns the population count (number of bits set) of a mask.
+ /// Similar in behavior to the x86 instruction POPCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int PopCount(uint value)
+ {
+ const uint c1 = 0x_55555555u;
+ const uint c2 = 0x_33333333u;
+ const uint c3 = 0x_0F0F0F0Fu;
+ const uint c4 = 0x_01010101u;
+
+ value -= (value >> 1) & c1;
+ value = (value & c2) + ((value >> 2) & c2);
+ value = (((value + (value >> 4)) & c3) * c4) >> 24;
+
+ return (int)value;
+ }
+
+ ///
+ /// Returns the population count (number of bits set) of a mask.
+ /// Similar in behavior to the x86 instruction POPCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int PopCount(ulong value)
+ {
+#if TARGET_32BIT
+ return PopCount((uint)value) // lo
+ + PopCount((uint)(value >> 32)); // hi
+#else
+ const ulong c1 = 0x_55555555_55555555ul;
+ const ulong c2 = 0x_33333333_33333333ul;
+ const ulong c3 = 0x_0F0F0F0F_0F0F0F0Ful;
+ const ulong c4 = 0x_01010101_01010101ul;
+
+ value -= (value >> 1) & c1;
+ value = (value & c2) + ((value >> 2) & c2);
+ value = (((value + (value >> 4)) & c3) * c4) >> 56;
+
+ return (int)value;
+#endif
+ }
+
+ ///
+ /// Returns the population count (number of bits set) of a mask.
+ /// Similar in behavior to the x86 instruction POPCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int PopCount(nuint value)
+ {
+#if TARGET_64BIT
+ return PopCount((ulong)value);
+#else
+ return PopCount((uint)value);
+#endif
+ }
+
+ ///
+ /// Count the number of trailing zero bits in an integer value.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int TrailingZeroCount(int value)
+ => TrailingZeroCount((uint)value);
+
+ ///
+ /// Count the number of trailing zero bits in an integer value.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [CLSCompliant(false)]
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int TrailingZeroCount(uint value)
+ {
+ // Unguarded fallback contract is 0->0, BSF contract is 0->undefined
+ if (value == 0)
+ {
+ return 32;
+ }
+
+ // uint.MaxValue >> 27 is always in range [0 - 31] so we use Unsafe.AddByteOffset to avoid bounds check
+ return Unsafe.AddByteOffset(
+ // Using deBruijn sequence, k=2, n=5 (2^5=32) : 0b_0000_0111_0111_1100_1011_0101_0011_0001u
+ ref MemoryMarshal.GetReference(TrailingZeroCountDeBruijn),
+ // uint|long -> IntPtr cast on 32-bit platforms does expensive overflow checks not needed here
+ (IntPtr)(int)(((value & (uint)-(int)value) * 0x077CB531u) >> 27)); // Multi-cast mitigates redundant conv.u8
+ }
+
+ ///
+ /// Count the number of trailing zero bits in a mask.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int TrailingZeroCount(long value)
+ => TrailingZeroCount((ulong)value);
+
+ ///
+ /// Count the number of trailing zero bits in a mask.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int TrailingZeroCount(ulong value)
+ {
+ uint lo = (uint)value;
+
+ if (lo == 0)
+ {
+ return 32 + TrailingZeroCount((uint)(value >> 32));
+ }
+
+ return TrailingZeroCount(lo);
+ }
+
+ ///
+ /// Count the number of trailing zero bits in a mask.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int TrailingZeroCount(nint value)
+ => TrailingZeroCount((nuint)value);
+
+ ///
+ /// Count the number of trailing zero bits in a mask.
+ /// Similar in behavior to the x86 instruction TZCNT.
+ ///
+ /// The value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static int TrailingZeroCount(nuint value)
+ {
+#if TARGET_64BIT
+ return TrailingZeroCount((ulong)value);
+#else
+ return TrailingZeroCount((uint)value);
+#endif
+ }
+
+ ///
+ /// Rotates the specified value left by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROL.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static uint RotateLeft(uint value, int offset)
+ => (value << offset) | (value >> (32 - offset));
+
+ ///
+ /// Rotates the specified value left by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROL.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..63] is treated as congruent mod 64.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static ulong RotateLeft(ulong value, int offset)
+ => (value << offset) | (value >> (64 - offset));
+
+ ///
+ /// Rotates the specified value left by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROL.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32 on a 32-bit process,
+ /// and any value outside the range [0..63] is treated as congruent mod 64 on a 64-bit process.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static nuint RotateLeft(nuint value, int offset)
+ {
+#if TARGET_64BIT
+ return (nuint)RotateLeft((ulong)value, offset);
+#else
+ return (nuint)RotateLeft((uint)value, offset);
+#endif
+ }
+
+ ///
+ /// Rotates the specified value right by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROR.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static uint RotateRight(uint value, int offset)
+ => (value >> offset) | (value << (32 - offset));
+
+ ///
+ /// Rotates the specified value right by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROR.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..63] is treated as congruent mod 64.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static ulong RotateRight(ulong value, int offset)
+ => (value >> offset) | (value << (64 - offset));
+
+ ///
+ /// Rotates the specified value right by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROR.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32 on a 32-bit process,
+ /// and any value outside the range [0..63] is treated as congruent mod 64 on a 64-bit process.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [CLSCompliant(false)]
+ public static nuint RotateRight(nuint value, int offset)
+ {
+#if TARGET_64BIT
+ return (nuint)RotateRight((ulong)value, offset);
+#else
+ return (nuint)RotateRight((uint)value, offset);
+#endif
+ }
+
+ ///
+ /// Accumulates the CRC (Cyclic redundancy check) checksum.
+ ///
+ /// The base value to calculate checksum on
+ /// The data for which to compute the checksum
+ /// The CRC-checksum
+ [CLSCompliant(false)]
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static uint Crc32C(uint crc, byte data)
+ {
+ return Crc32Fallback.Crc32C(crc, data);
+ }
+
+ ///
+ /// Accumulates the CRC (Cyclic redundancy check) checksum.
+ ///
+ /// The base value to calculate checksum on
+ /// The data for which to compute the checksum
+ /// The CRC-checksum
+ [CLSCompliant(false)]
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static uint Crc32C(uint crc, ushort data)
+ {
+ return Crc32Fallback.Crc32C(crc, data);
+ }
+
+ ///
+ /// Accumulates the CRC (Cyclic redundancy check) checksum.
+ ///
+ /// The base value to calculate checksum on
+ /// The data for which to compute the checksum
+ /// The CRC-checksum
+ [CLSCompliant(false)]
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static uint Crc32C(uint crc, uint data)
+ {
+ return Crc32Fallback.Crc32C(crc, data);
+ }
+
+ ///
+ /// Accumulates the CRC (Cyclic redundancy check) checksum.
+ ///
+ /// The base value to calculate checksum on
+ /// The data for which to compute the checksum
+ /// The CRC-checksum
+ [CLSCompliant(false)]
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static uint Crc32C(uint crc, ulong data)
+ {
+ return Crc32Fallback.Crc32C(crc, data);
+ }
+
+ private static class Crc32Fallback
+ {
+ // Pre-computed CRC-32 transition table.
+ // While this implementation is based on the Castagnoli CRC-32 polynomial (CRC-32C),
+ // x32 + x28 + x27 + x26 + x25 + x23 + x22 + x20 + x19 + x18 + x14 + x13 + x11 + x10 + x9 + x8 + x6 + x0,
+ // this version uses reflected bit ordering, so 0x1EDC6F41 becomes 0x82F63B78u
+ private static readonly uint[] s_crcTable = Crc32ReflectedTable.Generate(0x82F63B78u);
+
+ internal static uint Crc32C(uint crc, byte data)
+ {
+ ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable);
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8);
+
+ return crc;
+ }
+
+ internal static uint Crc32C(uint crc, ushort data)
+ {
+ ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable);
+
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8);
+ data >>= 8;
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8);
+
+ return crc;
+ }
+
+ internal static uint Crc32C(uint crc, uint data)
+ {
+ ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable);
+ return Crc32CCore(ref lookupTable, crc, data);
+ }
+
+ internal static uint Crc32C(uint crc, ulong data)
+ {
+ ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable);
+
+ crc = Crc32CCore(ref lookupTable, crc, (uint)data);
+ data >>= 32;
+ crc = Crc32CCore(ref lookupTable, crc, (uint)data);
+
+ return crc;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint Crc32CCore(ref uint lookupTable, uint crc, uint data)
+ {
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8);
+ data >>= 8;
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8);
+ data >>= 8;
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8);
+ data >>= 8;
+ crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8);
+
+ return crc;
+ }
+ }
+
+ ///
+ /// Reset the lowest significant bit in the given value
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static uint ResetLowestSetBit(uint value)
+ {
+ // It's lowered to BLSR on x86
+ return value & (value - 1);
+ }
+
+ ///
+ /// Reset specific bit in the given value
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static uint ResetBit(uint value, int bitPos)
+ {
+ // TODO: Recognize BTR on x86 and LSL+BIC on ARM
+ return value & ~(uint)(1 << bitPos);
+ }
+}
+#endif
+
+
+internal static class Crc32ReflectedTable
+{
+ internal static uint[] Generate(uint reflectedPolynomial)
+ {
+ uint[] table = new uint[256];
+
+ for (int i = 0; i < 256; i++)
+ {
+ uint val = (uint)i;
+
+ for (int j = 0; j < 8; j++)
+ {
+ if ((val & 0b0000_0001) == 0)
+ {
+ val >>= 1;
+ }
+ else
+ {
+ val = (val >> 1) ^ reflectedPolynomial;
+ }
+ }
+
+ table[i] = val;
+ }
+
+ return table;
+ }
+}
diff --git a/src/Common/BufferStreamReader.cs b/src/Common/BufferStreamReader.cs
new file mode 100644
index 00000000..942c89e2
--- /dev/null
+++ b/src/Common/BufferStreamReader.cs
@@ -0,0 +1,113 @@
+using System.Buffers;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+/// Allows reading a stream efficiently
+public struct BufferStreamReader : IDisposable {
+ private Stream? _arbitraryStream;
+ private MemoryStream? _memoryStream;
+ private readonly int _bufferSize;
+ private byte[]? _poolArray;
+
+ private BufferStreamReader(Stream? arbitraryStream, MemoryStream? memoryStream, int bufferSize) {
+ _arbitraryStream = arbitraryStream;
+ _memoryStream = memoryStream;
+ _bufferSize = bufferSize;
+ _poolArray = null;
+ }
+
+ public Stream Stream => _memoryStream ?? _arbitraryStream!;
+
+ public BufferStreamReader(Stream stream, int bufferSize) {
+ ThrowArgIfStreamCantRead(stream);
+ this = stream switch {
+ // inhering from ms doesnt guarantee a good GetBuffer impl, such as RecyclableMemoryStream,
+ // therefore we have to check the exact runtime type.
+ MemoryStream ms when ms.GetType() == typeof(MemoryStream) => new(null, ms, bufferSize),
+ _ => new(stream, null, bufferSize)
+ };
+ }
+
+ /// Reads up to bytes from the underlying .
+ /// The expected number of bytes to read
+ /// The cancellation token
+ /// The context bound memory representing the bytes read.
+ /// The returned memory is invalid outside this instance. Do not reference the memory outside of the scope!
+ public ValueTask> ReadAsync(int expectedSize, CancellationToken ct = default) {
+ var memoryStream = _memoryStream;
+ var stream = _arbitraryStream;
+ ThrowIfNull(stream is null & memoryStream is null);
+
+ if (memoryStream is not null) {
+ if (memoryStream.TryReadToBuffer(expectedSize, out ReadOnlyMemory read)) {
+ return new(read);
+ }
+
+ // unable to access the memory stream buffer
+ // handle as regular stream
+ stream = memoryStream;
+ }
+
+ Debug.Assert(stream is not null);
+ // reserve the buffer
+ var buffer = _poolArray;
+ if (buffer is null) {
+ _poolArray = buffer = ArrayPool.Shared.Rent(_bufferSize);
+ }
+
+ // negative buffer size -> read as much as possible
+ expectedSize = expectedSize < 0 ? buffer.Length : expectedSize;
+
+ return new(stream.ReadToBufferAsync(buffer.AsMemory(0, Math.Min(buffer.Length, expectedSize)), ct));
+ }
+
+ ///
+ public ReadOnlySpan Read(int expectedSize) {
+ var memoryStream = _memoryStream;
+ var stream = _arbitraryStream;
+ ThrowIfNull(stream is null & memoryStream is null);
+
+ if (memoryStream is not null) {
+ if (memoryStream.TryReadToBuffer(expectedSize, out ReadOnlySpan read)) {
+ return read;
+ }
+
+ // unable to access the memory stream buffer
+ // handle as regular stream
+ stream = memoryStream;
+ }
+
+ Debug.Assert(stream is not null);
+ // reserve the buffer
+ var buffer = _poolArray;
+ if (buffer is null) {
+ _poolArray = buffer = ArrayPool.Shared.Rent(_bufferSize);
+ }
+
+ return stream.ReadToBuffer(buffer.AsSpan(0, Math.Min(buffer.Length, expectedSize)));
+ }
+
+
+ public void Dispose() {
+ var poolArray = _poolArray;
+ _poolArray = null;
+ if (poolArray is not null) {
+ ArrayPool.Shared.Return(poolArray);
+ }
+ }
+
+ private static void ThrowIfNull([DoesNotReturnIf(true)] bool isNull, [CallerArgumentExpression(nameof(isNull))] string expression = "") {
+ if (isNull) {
+ throw new InvalidOperationException($"The expression cannot be null. `{expression}`");
+ }
+ }
+
+ private static void ThrowArgIfStreamCantRead(Stream stream, [CallerArgumentExpression(nameof(stream))] string argName = "") {
+ if (!stream.CanRead) {
+ throw new ArgumentException("The stream must be readable", argName);
+ }
+ }
+}
diff --git a/src/Common/CallerArgumentExpressionAttribute.cs b/src/Common/CallerArgumentExpressionAttribute.cs
index 6f7f11b7..e853763a 100644
--- a/src/Common/CallerArgumentExpressionAttribute.cs
+++ b/src/Common/CallerArgumentExpressionAttribute.cs
@@ -1,5 +1,5 @@
// ReSharper disable CheckNamespace
-#if !(NET6_0 || NET_5_0 || NET5_0_OR_GREATER || NETCOREAPP3_0_OR_GREATER)
+#if !NETCOREAPP3_0_OR_GREATER
#pragma warning disable IDE0130
namespace System.Runtime.CompilerServices;
diff --git a/src/Common/DisposingCache.cs b/src/Common/DisposingCache.cs
new file mode 100644
index 00000000..85e11c8c
--- /dev/null
+++ b/src/Common/DisposingCache.cs
@@ -0,0 +1,101 @@
+using System.Collections.Concurrent;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+/// Thread-safe sliding cache that disposed values when evicting
+internal sealed class DisposingCache
+ where K : notnull
+ where V : IDisposable {
+ private int _evictLock;
+ private long _lastEvictedTicks; // timestamp of latest eviction operation.
+ private readonly long _evictionIntervalTicks; // min timespan needed to trigger a new evict operation.
+ private readonly long _slidingExpirationTicks; // max timespan allowed for cache entries to remain inactive.
+ private readonly ConcurrentDictionary _cache = new();
+
+ public DisposingCache(TimeSpan slidingExpiration, TimeSpan evictionInterval) {
+ _slidingExpirationTicks = slidingExpiration.Ticks;
+ _evictionIntervalTicks = evictionInterval.Ticks;
+ _lastEvictedTicks = DateTime.UtcNow.Ticks;
+ }
+
+ public V GetOrAdd(K key, Func valueFactory) {
+ CacheEntry entry = _cache.GetOrAdd(key, static (key, valueFactory) => new(valueFactory(key)), valueFactory);
+ EnsureSlidingEviction(entry);
+
+ return entry.Value;
+ }
+
+ public bool TryAdd(K key, V value) {
+ CacheEntry entry = new(value);
+ bool added = _cache.TryAdd(key, entry);
+ EnsureSlidingEviction(entry);
+
+ return added;
+ }
+
+ public bool TryRemove(K key, [MaybeNullWhen(false)] out V value) {
+ if (_cache.TryRemove(key, out var entry)) {
+ EnsureSlidingEviction(entry);
+ value = entry.Value;
+ return true;
+ }
+
+ value = default;
+ return false;
+ }
+
+ public bool TryGetValue(K key, [MaybeNullWhen(false)] out V value) {
+ if (_cache.TryRemove(key, out var entry)) {
+ EnsureSlidingEviction(entry);
+ value = entry.Value;
+ return true;
+ }
+
+ value = default;
+ return false;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private void EnsureSlidingEviction(CacheEntry entry) {
+ long utcNowTicks = DateTime.UtcNow.Ticks;
+ Volatile.Write(ref entry.LastUsedTicks, utcNowTicks);
+
+ if (utcNowTicks - Volatile.Read(ref _lastEvictedTicks) >= _evictionIntervalTicks) {
+ if (Interlocked.CompareExchange(ref _evictLock, 1, 0) == 0) {
+ if (utcNowTicks - _lastEvictedTicks >= _evictionIntervalTicks) {
+ EvictStaleCacheEntries(utcNowTicks);
+ Volatile.Write(ref _lastEvictedTicks, utcNowTicks);
+ }
+
+ Volatile.Write(ref _evictLock, 0);
+ }
+ }
+ }
+
+ public void Clear() {
+ _cache.Clear();
+ _lastEvictedTicks = DateTime.UtcNow.Ticks;
+ }
+
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private void EvictStaleCacheEntries(long utcNowTicks) {
+ foreach (KeyValuePair kvp in _cache) {
+ if (utcNowTicks - Volatile.Read(ref kvp.Value.LastUsedTicks) >= _slidingExpirationTicks) {
+ if (_cache.TryRemove(kvp.Key, out var entry)) {
+ entry.Value.Dispose();
+ }
+ }
+ }
+ }
+
+ private sealed class CacheEntry {
+ public readonly V Value;
+ public long LastUsedTicks;
+
+ public CacheEntry(V value) {
+ Value = value;
+ }
+ }
+}
diff --git a/src/Common/GCHelper.cs b/src/Common/GCHelper.cs
new file mode 100644
index 00000000..93d130d9
--- /dev/null
+++ b/src/Common/GCHelper.cs
@@ -0,0 +1,55 @@
+using System.Diagnostics;
+using System.Numerics;
+using System.Runtime.CompilerServices;
+
+// ReSharper disable once CheckNamespace
+namespace System.Buffers;
+
+internal static class GCHelper {
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static int SelectBucketIndex(int bufferSize)
+ {
+ // Buffers are bucketed so that a request between 2^(n-1) + 1 and 2^n is given a buffer of 2^n
+ // Bucket index is log2(bufferSize - 1) with the exception that buffers between 1 and 16 bytes
+ // are combined, and the index is slid down by 3 to compensate.
+ // Zero is a valid bufferSize, and it is assigned the highest bucket index so that zero-length
+ // buffers are not retained by the pool. The pool will return the Array.Empty singleton for these.
+ return BitOperations.Log2((uint)bufferSize - 1 | 15) - 3;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal static int GetMaxSizeForBucket(int binIndex)
+ {
+ int maxSize = 16 << binIndex;
+ Debug.Assert(maxSize >= 0);
+ return maxSize;
+ }
+
+ internal enum MemoryPressure
+ {
+ Low,
+ Medium,
+ High
+ }
+
+ internal static MemoryPressure GetMemoryPressure()
+ {
+#if NETCOREAPP3_0_OR_GREATER
+ const double HighPressureThreshold = .90; // Percent of GC memory pressure threshold we consider "high"
+ const double MediumPressureThreshold = .70; // Percent of GC memory pressure threshold we consider "medium"
+ GCMemoryInfo memoryInfo = GC.GetGCMemoryInfo();
+
+ if (memoryInfo.MemoryLoadBytes >= memoryInfo.HighMemoryLoadThresholdBytes * HighPressureThreshold)
+ {
+ return MemoryPressure.High;
+ }
+
+ if (memoryInfo.MemoryLoadBytes >= memoryInfo.HighMemoryLoadThresholdBytes * MediumPressureThreshold)
+ {
+ return MemoryPressure.Medium;
+ }
+#endif
+ return MemoryPressure.Low;
+ }
+}
diff --git a/src/Common/Gen2GCCallback.cs b/src/Common/Gen2GCCallback.cs
new file mode 100644
index 00000000..81a4eb5e
--- /dev/null
+++ b/src/Common/Gen2GCCallback.cs
@@ -0,0 +1,112 @@
+using System.Diagnostics;
+using System.Runtime.ConstrainedExecution;
+using System.Runtime.InteropServices;
+
+// Source: https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Gen2GcCallback.cs
+
+// ReSharper disable once CheckNamespace
+namespace System;
+
+///
+/// Schedules a callback roughly every gen 2 GC (you may see a Gen 0 an Gen 1 but only once)
+/// (We can fix this by capturing the Gen 2 count at startup and testing, but I mostly don't care)
+///
+internal sealed class Gen2GcCallback : CriticalFinalizerObject
+{
+ private readonly Func? _callback0;
+ private readonly Func? _callback1;
+ private GCHandle _weakTargetObj;
+
+ private Gen2GcCallback(Func callback)
+ {
+ _callback0 = callback;
+ }
+
+ private Gen2GcCallback(Func callback, object targetObj)
+ {
+ _callback1 = callback;
+ _weakTargetObj = GCHandle.Alloc(targetObj, GCHandleType.Weak);
+ }
+
+ ///
+ /// Schedule 'callback' to be called in the next GC. If the callback returns true it is
+ /// rescheduled for the next Gen 2 GC. Otherwise the callbacks stop.
+ ///
+ public static void Register(Func callback)
+ {
+ // Create a unreachable object that remembers the callback function and target object.
+ new Gen2GcCallback(callback);
+ }
+
+ ///
+ /// Schedule 'callback' to be called in the next GC. If the callback returns true it is
+ /// rescheduled for the next Gen 2 GC. Otherwise the callbacks stop.
+ ///
+ /// NOTE: This callback will be kept alive until either the callback function returns false,
+ /// or the target object dies.
+ ///
+ public static void Register(Func callback, object targetObj)
+ {
+ // Create a unreachable object that remembers the callback function and target object.
+ new Gen2GcCallback(callback, targetObj);
+ }
+
+ ~Gen2GcCallback()
+ {
+ if (_weakTargetObj.IsAllocated)
+ {
+ // Check to see if the target object is still alive.
+ object? targetObj = _weakTargetObj.Target;
+ if (targetObj == null)
+ {
+ // The target object is dead, so this callback object is no longer needed.
+ _weakTargetObj.Free();
+ return;
+ }
+
+ // Execute the callback method.
+ try
+ {
+ Debug.Assert(_callback1 != null);
+ if (!_callback1(targetObj))
+ {
+ // If the callback returns false, this callback object is no longer needed.
+ _weakTargetObj.Free();
+ return;
+ }
+ }
+ catch
+ {
+ // Ensure that we still get a chance to resurrect this object, even if the callback throws an exception.
+#if DEBUG
+ // Except in DEBUG, as we really shouldn't be hitting any exceptions here.
+ throw;
+#endif
+ }
+ }
+ else
+ {
+ // Execute the callback method.
+ try
+ {
+ Debug.Assert(_callback0 != null);
+ if (!_callback0())
+ {
+ // If the callback returns false, this callback object is no longer needed.
+ return;
+ }
+ }
+ catch
+ {
+ // Ensure that we still get a chance to resurrect this object, even if the callback throws an exception.
+#if DEBUG
+ // Except in DEBUG, as we really shouldn't be hitting any exceptions here.
+ throw;
+#endif
+ }
+ }
+
+ // Resurrect ourselves by re-registering for finalization.
+ GC.ReRegisterForFinalize(this);
+ }
+}
diff --git a/src/Common/IsExternalInit.cs b/src/Common/IsExternalInit.cs
index 4129a311..7f1adb48 100644
--- a/src/Common/IsExternalInit.cs
+++ b/src/Common/IsExternalInit.cs
@@ -1,5 +1,5 @@
// ReSharper disable CheckNamespace
-#if !(NET6_0 || NET_5_0 || NET5_0_OR_GREATER)
+#if !NET5_0_OR_GREATER
#pragma warning disable IDE0130
namespace System.Runtime.CompilerServices;
diff --git a/src/Common/LatentReadonly.cs b/src/Common/LatentReadonly.cs
new file mode 100644
index 00000000..1405a50e
--- /dev/null
+++ b/src/Common/LatentReadonly.cs
@@ -0,0 +1,34 @@
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+public abstract record LatentReadonly {
+ private bool _isReadonly;
+
+ protected void MakeReadonly() {
+ _isReadonly = true;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ protected bool IsReadonly() => _isReadonly;
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ protected void Set(out T field, in T value) {
+ ThrowIfReadonly();
+ field = value;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private void ThrowIfReadonly() {
+ if (_isReadonly) {
+ ThrowReadonly();
+ }
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowReadonly() {
+ throw new InvalidOperationException("The object is readonly and cannot be mutated.");
+ }
+}
diff --git a/src/Common/MemoryExtensions.cs b/src/Common/MemoryExtensions.cs
deleted file mode 100644
index e2106d58..00000000
--- a/src/Common/MemoryExtensions.cs
+++ /dev/null
@@ -1,7 +0,0 @@
-namespace SurrealDB.Common;
-
-internal static class MemoryExtensions {
- public static ReadOnlySpan SliceToMin(in this ReadOnlySpan span, int length) {
- return span.Length <= length ? span : span.Slice(0, length);
- }
-}
diff --git a/src/Common/MemoryHelper.cs b/src/Common/MemoryHelper.cs
new file mode 100644
index 00000000..09407f2e
--- /dev/null
+++ b/src/Common/MemoryHelper.cs
@@ -0,0 +1,41 @@
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+internal static class MemoryHelper {
+ ///
+ /// Returns a reference to the 0th element of . If the array is empty, returns a reference to where the 0th element
+ /// would have been stored. Such a reference may be used for pinning but must never be dereferenced.
+ ///
+ /// is .
+ ///
+ /// This method does not perform array variance checks. The caller must manually perform any array variance checks
+ /// if the caller wishes to write to the returned reference.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ref T GetArrayDataReference(T[] array) {
+#if !NET6_0_OR_GREATER
+ return ref Unsafe.As(ref Unsafe.As(array).Data);
+#else
+ return ref System.Runtime.InteropServices.MemoryMarshal.GetArrayDataReference(array);
+#endif
+ }
+
+
+ // CLR arrays are laid out in memory as follows (multidimensional array bounds are optional):
+ // [ sync block || pMethodTable || num components || MD array bounds || array data .. ]
+ // ^ ^ ^ ^ returned reference
+ // | | \-- ref Unsafe.As(array).Data
+ // \-- array \-- ref Unsafe.As(array).Data
+ // The BaseSize of an array includes all the fields before the array data,
+ // including the sync block and method table. The reference to RawData.Data
+ // points at the number of components, skipping over these two pointer-sized fields.
+ internal sealed class RawArrayData
+ {
+ public uint Length; // Array._numComponents padded to IntPtr
+#if TARGET_64BIT
+ public uint Padding;
+#endif
+ public byte Data;
+ }
+}
diff --git a/src/Common/NullableAttributes.cs b/src/Common/NullableAttributes.cs
new file mode 100644
index 00000000..bf53cdf6
--- /dev/null
+++ b/src/Common/NullableAttributes.cs
@@ -0,0 +1,66 @@
+// ReSharper disable CheckNamespace
+#if !NET5_0_OR_GREATER
+
+#pragma warning disable IDE0130
+namespace System.Diagnostics.CodeAnalysis;
+#pragma warning restore IDE0130
+
+/// Specifies that the method or property will ensure that the listed field and property members have not-null values when returning with the specified return value condition.
+[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)]
+internal sealed class MemberNotNullWhenAttribute : Attribute
+{
+ /// Initializes the attribute with the specified return value condition and a field or property member.
+ ///
+ /// The return value condition. If the method returns this value, the associated parameter will not be null.
+ ///
+ ///
+ /// The field or property member that is promised to be not-null.
+ ///
+ public MemberNotNullWhenAttribute(bool returnValue, string member)
+ {
+ ReturnValue = returnValue;
+ Members = new[] { member };
+ }
+
+ /// Initializes the attribute with the specified return value condition and list of field and property members.
+ ///
+ /// The return value condition. If the method returns this value, the associated parameter will not be null.
+ ///
+ ///
+ /// The list of field and property members that are promised to be not-null.
+ ///
+ public MemberNotNullWhenAttribute(bool returnValue, params string[] members)
+ {
+ ReturnValue = returnValue;
+ Members = members;
+ }
+
+ /// Gets the return value condition.
+ public bool ReturnValue { get; }
+
+ /// Gets field or property member names.
+ public string[] Members { get; }
+}
+
+/// Specifies that the method or property will ensure that the listed field and property members have not-null values.
+[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)]
+internal sealed class MemberNotNullAttribute : Attribute
+{
+ /// Initializes the attribute with a field or property member.
+ ///
+ /// The field or property member that is promised to be not-null.
+ ///
+ public MemberNotNullAttribute(string member) => Members = new[] { member };
+
+ /// Initializes the attribute with the list of field and property members.
+ ///
+ /// The list of field and property members that are promised to be not-null.
+ ///
+ public MemberNotNullAttribute(params string[] members) => Members = members;
+
+ /// Gets field or property member names.
+ public string[] Members { get; }
+}
+
+
+#endif
diff --git a/src/Common/StreamExtensions.cs b/src/Common/StreamExtensions.cs
new file mode 100644
index 00000000..9e299425
--- /dev/null
+++ b/src/Common/StreamExtensions.cs
@@ -0,0 +1,45 @@
+namespace SurrealDB.Common;
+
+internal static class StreamExtensions {
+ public static bool TryReadToBuffer(this MemoryStream memoryStream, int expectedSize, out ReadOnlyMemory read) {
+ if (!memoryStream.TryGetBuffer(out var buffer)) {
+ read = default;
+ return false;
+ }
+ // negative size -> read to end
+ expectedSize = expectedSize < 0 ? Int32.MaxValue : expectedSize;
+ // fake a read call
+ var pos = (int)memoryStream.Position;
+ var cap = (int)memoryStream.Length;
+ var len = Math.Min(expectedSize, cap - pos);
+ memoryStream.Position += len;
+ read = buffer.AsMemory(pos, len);
+ return true;
+ }
+
+ public static bool TryReadToBuffer(this MemoryStream memoryStream, int expectedSize, out ReadOnlySpan read) {
+ if (!memoryStream.TryGetBuffer(out var buffer)) {
+ read = default;
+ return false;
+ }
+ // negative size -> read to end
+ expectedSize = expectedSize < 0 ? Int32.MaxValue : expectedSize;
+ // fake a read call
+ var pos = (int)memoryStream.Position;
+ var cap = (int)memoryStream.Length;
+ var len = Math.Min(expectedSize, cap - pos);
+ memoryStream.Position += len;
+ read = buffer.AsSpan(pos, len);
+ return true;
+ }
+
+ public static async Task> ReadToBufferAsync(this Stream stream, Memory buffer, CancellationToken ct) {
+ int read = await stream.ReadAsync(buffer, ct);
+ return buffer.Slice(0, read);
+ }
+
+ public static ReadOnlySpan ReadToBuffer(this Stream stream, Span buffer) {
+ int read = stream.Read(buffer);
+ return buffer.Slice(0, read);
+ }
+}
diff --git a/src/Common/StringHelper.cs b/src/Common/StringHelper.cs
index 7952c0c0..4ec80848 100644
--- a/src/Common/StringHelper.cs
+++ b/src/Common/StringHelper.cs
@@ -16,7 +16,7 @@ public static bool IsEmpty([NotNullWhen(false)] this string? str) {
}
public static string Concat(ReadOnlySpan p0, ReadOnlySpan p1, ReadOnlySpan p2 = default, ReadOnlySpan p3 = default) {
-#if NET5_0_OR_GREATER || NETCOREAPP3_0_OR_GREATER
+#if NETCOREAPP3_0_OR_GREATER
return String.Concat(p0, p1, p2, p3);
#else
int cap = p0.Length + p1.Length + p2.Length + p3.Length;
diff --git a/src/Common/TaskExtensions.cs b/src/Common/TaskExtensions.cs
new file mode 100644
index 00000000..9661c73e
--- /dev/null
+++ b/src/Common/TaskExtensions.cs
@@ -0,0 +1,59 @@
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+/// Extension methods for Tasks
+public static class TaskExtensions {
+ /// The task is invariant.
+ /// The previous context is not restored upon completion.
+ /// Equivalent to Task.ConfigureAwait(false)
.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ConfiguredTaskAwaitable Inv(this Task t) => t.ConfigureAwait(false);
+
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ConfiguredTaskAwaitable Inv(this Task t) => t.ConfigureAwait(false);
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ConfiguredValueTaskAwaitable Inv(in this ValueTask t) => t.ConfigureAwait(false);
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static ConfiguredValueTaskAwaitable Inv(in this ValueTask t) => t.ConfigureAwait(false);
+
+ /// Creates a task awaiting the .
+ /// the handle is null
+ public static Task ToTask(this WaitHandle handle)
+ {
+ if (handle == null) {
+ throw new ArgumentNullException(nameof(handle));
+ }
+
+ TaskCompletionSource tcs = new();
+ RegisteredWaitHandle? shared = null;
+ RegisteredWaitHandle produced = ThreadPool.RegisterWaitForSingleObject(
+ handle,
+ (state, timedOut) =>
+ {
+ tcs.SetResult(null);
+
+ while (true)
+ {
+ RegisteredWaitHandle? consumed = Interlocked.CompareExchange(ref shared, null, null);
+ if (consumed is not null)
+ {
+ consumed.Unregister(null);
+ break;
+ }
+ }
+ },
+ state: null,
+ millisecondsTimeOutInterval: Timeout.Infinite,
+ executeOnlyOnce: true);
+
+ // Publish the RegisteredWaitHandle so that the callback can see it.
+ Interlocked.CompareExchange(ref shared, produced, null);
+
+ return tcs.Task;
+ }
+
+}
diff --git a/src/Common/TaskList.cs b/src/Common/TaskList.cs
deleted file mode 100644
index 10d13c9a..00000000
--- a/src/Common/TaskList.cs
+++ /dev/null
@@ -1,119 +0,0 @@
-using System.Collections;
-using System.Diagnostics;
-
-namespace SurrealDB.Common;
-
-public sealed class TaskList {
- private readonly object _lock = new();
- private readonly Node _root;
- private Node _tail;
- private int _len;
-
- public TaskList() {
- _root = _tail = new(Task.CompletedTask);
- }
-
- public void Add(Task task) {
- lock (_lock) {
- Debug.Assert(_tail.Next is null);
- Node tail = new(task) { Prev = _tail };
- _tail.Next = tail;
- _tail = tail;
- _len += 1;
- }
- }
-
- public void Trim() {
- lock (_lock) {
- Node? pos = _root;
- do {
- Node cur = pos;
- pos = pos.Next;
- Task task = cur.Task;
- if (task.IsCompleted) {
- Remove(cur);
- }
- } while (pos is not null);
- }
- }
-
- public ValueTask WhenAll() {
- return _len == 0 ? default : new(Task.WhenAll(Drain()));
- }
-
- public DrainIterator Drain() {
- return new DrainIterator(this);
- }
-
- ///
- /// Removes the node from the list. Requires _lock!
- ///
- private bool Remove(Node node) {
- if (Object.ReferenceEquals(_root, node)) {
- // Do not remove the root!
- return false;
- }
- Node? prev = node.Prev;
- Node? next = node.Next;
- if (Object.ReferenceEquals(_tail, node)) {
- _tail = prev!; // cannot be null because of _root
- }
- if (prev is not null) {
- prev.Next = next;
- }
- if (next is not null) {
- next.Prev = prev;
- }
-
- _len -= 1;
- return true;
- }
-
- private sealed class Node {
- public readonly Task Task;
- public Node? Next;
- public Node? Prev;
-
- public Node(Task task) {
- Task = task;
- }
- }
-
- public struct DrainIterator : IEnumerable, IEnumerator {
- private readonly TaskList _list;
-
- public DrainIterator(TaskList list) {
- _list = list;
- }
-
- public DrainIterator GetEnumerator() {
- return new(_list);
- }
-
- IEnumerator IEnumerable.GetEnumerator() {
- return GetEnumerator();
- }
-
- IEnumerator IEnumerable.GetEnumerator() {
- return GetEnumerator();
- }
-
- public bool MoveNext() {
- lock (_list._lock) {
- return _list.Remove(_list._tail);
- }
- }
-
- public void Reset() {
- throw new NotSupportedException("Cannot be reset");
- }
-
- public Task Current => _list._tail.Task;
-
- object IEnumerator.Current => Current;
-
- public void Dispose() {
- // not needed
- }
- }
-}
diff --git a/src/Common/ValidateReadonly.cs b/src/Common/ValidateReadonly.cs
new file mode 100644
index 00000000..9459e2c9
--- /dev/null
+++ b/src/Common/ValidateReadonly.cs
@@ -0,0 +1,75 @@
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using System.Runtime.Serialization;
+using System.Text;
+
+namespace SurrealDB.Common;
+
+public abstract record ValidateReadonly : LatentReadonly {
+ /// Evaluates , if it yields any errors, throws a with the errors
+ protected void ValidateOrThrow(string? message = null) {
+ PoolArrayBuilder<(string, string)> errors = new();
+ foreach (var error in Validations()) {
+ errors.Append(error);
+ }
+
+ if (!errors.IsDefault) {
+ AggregatePropertyValidationException.Throw(message, errors.ToArray());
+ }
+ }
+
+ /// Performs a sequence of validation on properties.
+ /// If a validation fails, yields the PropertyName and the corresponding error message.
+ protected abstract IEnumerable<(string PropertyName, string Message)> Validations();
+}
+
+[Serializable]
+public class AggregatePropertyValidationException : Exception {
+ public AggregatePropertyValidationException() {
+ }
+
+ public AggregatePropertyValidationException(string? message) : base(message) {
+ }
+
+ public AggregatePropertyValidationException(string? message, Exception? inner) : base(message, inner) {
+ }
+
+ public AggregatePropertyValidationException(string? message, (string Property, string Error)[]? errors, Exception? inner) : base(message, inner) {
+ Errors = errors;
+ }
+
+ protected AggregatePropertyValidationException(
+ SerializationInfo info,
+ StreamingContext context) : base(info, context) {
+ }
+
+ public (string Property, string Error)[]? Errors { get; set; }
+
+ public override string ToString() {
+ ValueStringBuilder sb = new(stackalloc char[512]);
+ Span<(string PropertyName, string Message)>.Enumerator en = Errors.AsSpan().GetEnumerator();
+ sb.Append(Message ?? "Validation failed with the following errors:");
+ if (en.MoveNext()) {
+ sb.Append(Environment.NewLine);
+ sb.Append("- `");
+ sb.Append(en.Current.PropertyName);
+ sb.Append("`: ");
+ sb.Append(en.Current.Message);
+ }
+ while (en.MoveNext()) {
+ sb.Append(Environment.NewLine);
+ sb.Append("- `");
+ sb.Append(en.Current.PropertyName);
+ sb.Append("`: ");
+ sb.Append(en.Current.Message);
+ }
+
+ return sb.ToString();
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ public static void Throw(string? message, (string, string)[]? errors = null, Exception? inner = default) {
+ throw new AggregatePropertyValidationException(message, errors, inner);
+ }
+}
diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs
new file mode 100644
index 00000000..e18ca7fa
--- /dev/null
+++ b/src/Common/WebSocketExtensions.cs
@@ -0,0 +1,22 @@
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Net.WebSockets;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Common;
+
+internal static class WebSocketExtensions {
+ /// Throws a if the result is a close ack.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void ThrowIfClose(this WebSocketReceiveResult result) {
+ if (result.CloseStatus is not null) {
+ ThrowConnectionClosed();
+ }
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowConnectionClosed() {
+ throw new OperationCanceledException("Connection closed");
+ }
+}
diff --git a/src/Common/WsStream.cs b/src/Common/WsStream.cs
deleted file mode 100644
index 3a7473a7..00000000
--- a/src/Common/WsStream.cs
+++ /dev/null
@@ -1,155 +0,0 @@
-using System.Buffers;
-using System.Diagnostics.CodeAnalysis;
-using System.Net.WebSockets;
-
-namespace SurrealDB.Common;
-
-public sealed class WsStream : Stream {
- private readonly IDisposable _prefixOwner;
- ///
- /// The prefix is the memory already obtained to be consumed before queries the socket
- ///
- private readonly ReadOnlyMemory _prefix;
- private int _prefixConsumed;
-
- private readonly WebSocket _ws;
- public bool EndOfMessage { get; private set; } = false;
-
- public override bool CanRead => true;
- public override bool CanSeek => false;
- public override bool CanWrite => false;
- public override long Length => ThrowSeekDisallowed();
- public override long Position { get => ThrowSeekDisallowed(); set => ThrowSeekDisallowed(); }
-
- public WsStream(IDisposable prefixOwner, ReadOnlyMemory prefix, WebSocket ws) {
- _prefixOwner = prefixOwner;
- _prefix = prefix;
- _ws = ws;
- }
-
- public override void Flush() {
- // Readonly
- }
-
- public override int Read(byte[] buffer, int offset, int count) {
- return Read(buffer.AsSpan(offset, count));
- }
-
- ///
- /// Use , or if possible.
- ///
- public override int Read(Span buffer) {
- int read = 0;
- // consume the prefix
- ReadOnlySpan pref = ConsumePrefixAsSpan(buffer.Length);
- if (!pref.IsEmpty) {
- pref.CopyTo(buffer);
- buffer = buffer.Slice(pref.Length);
- read += pref.Length;
- }
-
- if (buffer.IsEmpty) {
- return read;
- }
-
- using IMemoryOwner o = MemoryPool.Shared.Rent(buffer.Length);
- Memory m = o.Memory.Slice(0, buffer.Length);
- buffer.CopyTo(m.Span);
- return read + ReadSync(m);
- }
-
- ///
- public int Read(Memory buffer) {
- int read = 0;
- // consume the prefix
- ReadOnlySpan pref = ConsumePrefixAsSpan(buffer.Length);
- if (!pref.IsEmpty) {
- pref.CopyTo(buffer.Span);
- buffer = buffer.Slice(pref.Length);
- read += pref.Length;
- }
-
- return read + ReadSync(buffer);
- }
-
- private int ReadSync(Memory buffer) {
- // This causes issues if the scheduler is exclusive.
- return ReadAsync(buffer).Result;
- }
-
- public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) {
- return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
- }
-
- public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) {
- var pref = ConsumePrefix(buffer.Length);
- if (!pref.IsEmpty) {
- pref.CopyTo(buffer);
- return pref.Length;
- }
-
- int read = 0;
- if (!EndOfMessage) {
- ValueWebSocketReceiveResult rsp = await _ws.ReceiveAsync(buffer, cancellationToken);
- read = rsp.Count;
-
- EndOfMessage = rsp.EndOfMessage;
- }
-
- return read;
- }
-
- private ReadOnlySpan ConsumePrefixAsSpan(int length) {
- var prefix = ConsumePrefix(length);
- return prefix.Span;
- }
-
- private ReadOnlyMemory ConsumePrefix(int length) {
- int len = _prefix.Length;
- int con = _prefixConsumed;
- if (con == len) {
- return default;
- }
- int rem = len - con;
- int inc = Math.Min(rem, length);
- _prefixConsumed = con + inc;
- return _prefix.Slice(con, inc);
- }
-
- private void DisposePrefix() {
- _prefixConsumed = _prefix.Length;
- _prefixOwner.Dispose();
- }
-
- public override long Seek(long offset, SeekOrigin origin) {
- return ThrowSeekDisallowed();
- }
-
- public override void SetLength(long value) {
- ThrowWriteDisallowed();
- }
-
- public override void Write(byte[] buffer, int offset, int count) {
- ThrowWriteDisallowed();
- }
-
- protected override void Dispose(bool disposing) {
- base.Dispose(disposing);
- DisposePrefix();
- }
-
- public override async ValueTask DisposeAsync() {
- await base.DisposeAsync();
- DisposePrefix();
- }
-
- [DoesNotReturn]
- private static void ThrowWriteDisallowed() {
- throw new InvalidOperationException("Cannot write a readonly stream");
- }
-
- [DoesNotReturn]
- private static long ThrowSeekDisallowed() {
- throw new InvalidOperationException("Cannot seek in the stream");
- }
-}
diff --git a/src/Driver/Rest/RestClientExtensions.cs b/src/Driver/Rest/RestClientExtensions.cs
index 350cdae1..45d9190e 100644
--- a/src/Driver/Rest/RestClientExtensions.cs
+++ b/src/Driver/Rest/RestClientExtensions.cs
@@ -5,7 +5,6 @@
using SurrealDB.Common;
using SurrealDB.Json;
-using SurrealDB.Models;
using SurrealDB.Models.Result;
using DriverResponse = SurrealDB.Models.Result.DriverResponse;
diff --git a/src/Driver/Rpc/DatabaseRpc.cs b/src/Driver/Rpc/DatabaseRpc.cs
index 29eca395..2b48e2f1 100644
--- a/src/Driver/Rpc/DatabaseRpc.cs
+++ b/src/Driver/Rpc/DatabaseRpc.cs
@@ -55,7 +55,7 @@ public async Task Open(CancellationToken ct = default) {
// Open connection
InvalidConfigException.ThrowIfNull(_config.RpcEndpoint);
- await _client.Open(_config.RpcEndpoint!, ct);
+ await _client.OpenAsync(_config.RpcEndpoint!, ct);
// Authenticate
if (_config.Username != null && _config.Password != null) {
@@ -70,7 +70,7 @@ public async Task Open(CancellationToken ct = default) {
public async Task Close(CancellationToken ct = default) {
_configured = false;
- await _client.Close(ct);
+ await _client.CloseAsync(ct);
}
///
diff --git a/src/Driver/Rpc/RpcClientExtensions.cs b/src/Driver/Rpc/RpcClientExtensions.cs
index 29dfaa8d..1788182b 100644
--- a/src/Driver/Rpc/RpcClientExtensions.cs
+++ b/src/Driver/Rpc/RpcClientExtensions.cs
@@ -1,9 +1,10 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using SurrealDB.Common;
using SurrealDB.Json;
-using SurrealDB.Models;
using SurrealDB.Models.Result;
using SurrealDB.Ws;
@@ -12,7 +13,6 @@
namespace SurrealDB.Driver.Rpc;
internal static class RpcClientExtensions {
-
internal static async Task ToSurreal(this Task rsp) => ToSurreal(await rsp);
internal static DriverResponse ToSurreal(this WsClient.Response rsp){
if (rsp.id is null) {
@@ -90,7 +90,7 @@ private static DriverResponse FromNestedStatus(in WsClient.Response rsp) {
return DriverResponse.FromOwned(builder.AsSegment());
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static void ThrowIdMissing() {
throw new InvalidOperationException("Response does not have an id.");
}
diff --git a/src/Json/Time/DateOnlyConv.cs b/src/Json/Time/DateOnlyConv.cs
index 394093a6..b9331309 100644
--- a/src/Json/Time/DateOnlyConv.cs
+++ b/src/Json/Time/DateOnlyConv.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -44,14 +46,14 @@ public static bool TryParse(string? s, out DateOnly value) {
public static string ToString(in DateOnly value) {
return $"{value.Year.ToString("D4")}-{value.Month.ToString("D2")}-{value.Day.ToString("D2")}";
}
-
- [DoesNotReturn]
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static DateOnly ThrowParseInvalid(string? s) {
throw new ParseException($"Unable to parse DateOnly from `{s}`");
}
- [DoesNotReturn]
- private DateOnly ThrowJsonTokenTypeInvalid() {
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static DateOnly ThrowJsonTokenTypeInvalid() {
throw new JsonException("Cannot deserialize a non string token as a DateOnly.");
}
-}
\ No newline at end of file
+}
diff --git a/src/Json/Time/DateTimeConv.cs b/src/Json/Time/DateTimeConv.cs
index 2491ee9c..3ff89051 100644
--- a/src/Json/Time/DateTimeConv.cs
+++ b/src/Json/Time/DateTimeConv.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -49,13 +51,13 @@ public static string ToString(in DateTime value) {
return value.ToString("O");
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static DateTime ThrowParseInvalid(string? s) {
throw new ParseException($"Unable to parse DateTime from `{s}`");
}
-
- [DoesNotReturn]
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static DateTime ThrowJsonTokenTypeInvalid() {
throw new JsonException("Cannot deserialize a non numeric non string token as a DateTime.");
}
diff --git a/src/Json/Time/DateTimeOffsetConv.cs b/src/Json/Time/DateTimeOffsetConv.cs
index 0b93f4aa..041bf5a1 100644
--- a/src/Json/Time/DateTimeOffsetConv.cs
+++ b/src/Json/Time/DateTimeOffsetConv.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -49,13 +51,13 @@ public static string ToString(in DateTimeOffset value) {
return value.ToString("O");
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static DateTimeOffset ThrowParseInvalid(string? s) {
throw new ParseException($"Unable to parse DateTimeOffset from `{s}`");
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static DateTimeOffset ThrowJsonTokenInvalid() {
throw new JsonException("Cannot deserialize a non numeric non string token as a DateTime.");
}
-}
\ No newline at end of file
+}
diff --git a/src/Json/Time/TimeOnlyConv.cs b/src/Json/Time/TimeOnlyConv.cs
index ff91dfb0..c3219b5c 100644
--- a/src/Json/Time/TimeOnlyConv.cs
+++ b/src/Json/Time/TimeOnlyConv.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -46,13 +48,13 @@ public static string ToString(in TimeOnly value) {
return $"{value.Hour.ToString("D2")}:{value.Minute.ToString("D2")}:{value.Second.ToString("D2")}.{value.FractionString()}";
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static TimeOnly ThrowParseInvalid(string? s) {
throw new ParseException($"Unable to parse TimeOnly from `{s}`");
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private TimeOnly ThrowJsonTokenTypeInvalid() {
throw new JsonException("Cannot deserialize a non string token as a TimeOnly.");
}
-}
\ No newline at end of file
+}
diff --git a/src/Json/Time/TimeSpanConv.cs b/src/Json/Time/TimeSpanConv.cs
index 238e3d46..0f101d4c 100644
--- a/src/Json/Time/TimeSpanConv.cs
+++ b/src/Json/Time/TimeSpanConv.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
@@ -57,14 +59,14 @@ public static string ToString(in TimeSpan value) {
return $"{value.Days}d{value.Hours}h{value.Minutes}m{value.Seconds}s{value.Milliseconds}ms";
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static TimeSpan ThrowParseInvalid(string? s) {
throw new ParseException($"Unable to parse TimeSpan from `{s}`");
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static TimeSpan ThrowJsonTokenTypeInvalid() {
throw new JsonException("Cannot deserialize a non numeric non string token as a TimeSpan.");
}
-}
\ No newline at end of file
+}
diff --git a/src/Models/Result/ResultContentException.cs b/src/Models/Result/ResultContentException.cs
index 590147d3..0b3bb4f5 100644
--- a/src/Models/Result/ResultContentException.cs
+++ b/src/Models/Result/ResultContentException.cs
@@ -1,4 +1,6 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
namespace SurrealDB.Models.Result;
@@ -16,15 +18,15 @@ public ResultContentException(string? message) : base(message) {
public ResultContentException(string? message, Exception? innerException) : base(message, innerException) {
}
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
public static ErrorResult ExpectedAnyError() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain any {nameof(ErrorResult)}");
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
public static OkResult ExpectedAnyOk() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain any {nameof(OkResult)}");
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
public static ErrorResult ExpectedSingleError() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain exactly one {nameof(ErrorResult)}");
- [DoesNotReturn]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
public static OkResult ExpectedSingleOk() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain exactly one {nameof(OkResult)}");
}
diff --git a/src/Models/Result/ResultValue.cs b/src/Models/Result/ResultValue.cs
index dbea56fd..14f8b179 100644
--- a/src/Models/Result/ResultValue.cs
+++ b/src/Models/Result/ResultValue.cs
@@ -259,7 +259,7 @@ public Kind GetKind() {
return Kind.None;
}
- [DoesNotReturn, DebuggerStepThrough,]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static ResultValue ThrowUnknownJsonValueKind(JsonElement json) {
throw new ArgumentOutOfRangeException(nameof(json), json.ValueKind, "Unknown value kind.");
}
@@ -386,7 +386,7 @@ public override string ToString() {
return Inner.ToString();
}
- [DoesNotReturn, DebuggerStepThrough,]
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
private static int ThrowInvalidCompareTypes() {
throw new InvalidOperationException("Cannot compare SurrealResult of different types, if one or more is not numeric..");
}
diff --git a/src/Models/Thing.cs b/src/Models/Thing.cs
index 813c51ef..9b2ec3af 100644
--- a/src/Models/Thing.cs
+++ b/src/Models/Thing.cs
@@ -1,5 +1,3 @@
-using SurrealDB.Json;
-
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.Runtime.CompilerServices;
@@ -7,8 +5,6 @@
using System.Text.Json;
using System.Text.Json.Serialization;
-using SurrealDB.Common;
-
namespace SurrealDB.Models;
///
diff --git a/src/Ws/Handler.cs b/src/Ws/Handler.cs
index cc36fe91..c869ad1f 100644
--- a/src/Ws/Handler.cs
+++ b/src/Ws/Handler.cs
@@ -6,11 +6,11 @@ internal interface IHandler : IDisposable {
public bool Persistent { get; }
- public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm);
+ public void Dispatch(WsHeaderWithMessage header);
}
internal sealed class ResponseHandler : IHandler {
- private readonly TaskCompletionSource<(WsTx.RspHeader, WsTx.NtyHeader, Stream)> _tcs = new();
+ private readonly TaskCompletionSource _tcs = new();
private readonly string _id;
private readonly CancellationToken _ct;
@@ -19,14 +19,15 @@ public ResponseHandler(string id, CancellationToken ct) {
_ct = ct;
}
- public Task<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> Task => _tcs!.Task;
+ public Task Task => _tcs!.Task;
public string Id => _id;
public bool Persistent => false;
- public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm) {
- _tcs.SetResult((rsp, nty, stm));
+ public void Dispatch(WsHeaderWithMessage header) {
+ _ct.ThrowIfCancellationRequested();
+ _tcs.SetResult(header);
}
public void Dispose() {
@@ -35,11 +36,11 @@ public void Dispose() {
}
-internal class NotificationHandler : IHandler, IAsyncEnumerable<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> {
- private readonly Ws _mediator;
+internal class NotificationHandler : IHandler, IAsyncEnumerable {
+ private WsReceiverDeflater _mediator;
private readonly CancellationToken _ct;
- private TaskCompletionSource<(WsTx.RspHeader, WsTx.NtyHeader, Stream)> _tcs = new();
- public NotificationHandler(Ws mediator, string id, CancellationToken ct) {
+ private TaskCompletionSource _tcs = new();
+ public NotificationHandler(WsReceiverDeflater mediator, string id, CancellationToken ct) {
_mediator = mediator;
Id = id;
_ct = ct;
@@ -48,8 +49,9 @@ public NotificationHandler(Ws mediator, string id, CancellationToken ct) {
public string Id { get; }
public bool Persistent => true;
- public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm) {
- _tcs.SetResult((rsp, nty, stm));
+ public void Dispatch(WsHeaderWithMessage header) {
+ _ct.ThrowIfCancellationRequested();
+ _tcs.SetResult(header);
_tcs = new();
}
@@ -57,9 +59,9 @@ public void Dispose() {
_tcs.TrySetCanceled();
}
- public async IAsyncEnumerator<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> GetAsyncEnumerator(CancellationToken cancellationToken = default) {
+ public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) {
while (!_ct.IsCancellationRequested) {
- (WsTx.RspHeader, WsTx.NtyHeader, Stream) res;
+ WsHeaderWithMessage res;
try {
res = await _tcs.Task;
} catch (OperationCanceledException) {
@@ -71,7 +73,7 @@ public void Dispose() {
// unregister before throwing
if (_ct.IsCancellationRequested) {
- _mediator.Unregister(this);
+ _mediator.Unregister(Id);
}
}
}
diff --git a/src/Ws/HeaderHelper.cs b/src/Ws/HeaderHelper.cs
new file mode 100644
index 00000000..8c9fc8f0
--- /dev/null
+++ b/src/Ws/HeaderHelper.cs
@@ -0,0 +1,296 @@
+using System.Text.Json;
+
+using SurrealDB.Common;
+using SurrealDB.Json;
+
+namespace SurrealDB.Ws;
+
+public static class HeaderHelper {
+ /// Generates a random base64 string of the length specified.
+ public static string GetRandomId(int length) {
+ Span buf = stackalloc byte[length];
+ ThreadRng.Shared.NextBytes(buf);
+ return Convert.ToBase64String(buf);
+ }
+
+ public static WsHeader Parse(ReadOnlySpan utf8) {
+ var (rsp, rspOff, rspErr) = RspHeader.Parse(utf8);
+ if (rspErr is null) {
+ return new(rsp, default, (int)rspOff);
+ }
+ var (nty, ntyOff, ntyErr) = NtyHeader.Parse(utf8);
+ if (ntyErr is null) {
+ return new(default, nty, (int)ntyOff);
+ }
+
+ return default;
+ }
+}
+
+public readonly record struct WsHeader(RspHeader Response, NtyHeader Notify, int BytesLength) {
+ public string? Id => (Response.IsDefault, Notify.IsDefault) switch {
+ (true, false) => Notify.id,
+ (false, true) => Response.id,
+ _ => null
+ };
+}
+
+public readonly record struct WsHeaderWithMessage(WsHeader Header, WsReceiverMessageReader Reader) : IDisposable {
+ public void Dispose() {
+ Reader.Dispose();
+ }
+}
+
+public readonly record struct NtyHeader(string? id, string? method, WsClient.Error err) {
+ public bool IsDefault => default == this;
+
+ ///
+ /// Parses the head including the result propertyname, excluding the result array.
+ ///
+ internal static (NtyHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) {
+ Fsm fsm = new() {
+ Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })),
+ State = Fsms.Start,
+ };
+ while (fsm.MoveNext()) {}
+
+ if (!fsm.Success) {
+ return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}");
+ }
+ return (new(fsm.Id, fsm.Method, fsm.Error), default, default);
+ }
+
+ private enum Fsms {
+ Start, // -> Prop
+ Prop, // -> PropId | PropAsync | PropMethod | ProsResult
+ PropId, // -> Prop | End
+ PropMethod, // -> Prop | End
+ PropError, // -> End
+ PropParams, // -> End
+ End
+ }
+
+ private ref struct Fsm {
+ public Fsms State;
+ public Utf8JsonReader Lexer;
+ public string? Err;
+ public bool Success;
+
+ public string? Name;
+ public string? Id;
+ public WsClient.Error Error;
+ public string? Method;
+
+ public bool MoveNext() {
+ return State switch {
+ Fsms.Start => Start(),
+ Fsms.Prop => Prop(),
+ Fsms.PropId => PropId(),
+ Fsms.PropMethod => PropMethod(),
+ Fsms.PropError => PropError(),
+ Fsms.PropParams => PropParams(),
+ Fsms.End => End(),
+ _ => false
+ };
+ }
+
+ private bool Start() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) {
+ Err = "Unable to read token StartObject";
+ return false;
+ }
+
+ State = Fsms.Prop;
+ return true;
+
+ }
+
+ private bool End() {
+ Success = !String.IsNullOrEmpty(Id) && !String.IsNullOrEmpty(Method);
+ return false;
+ }
+
+ private bool Prop() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) {
+ Err = "Unable to read PropertyName";
+ return false;
+ }
+
+ Name = Lexer.GetString();
+ if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropId;
+ return true;
+ }
+ if ("method".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropMethod;
+ return true;
+ }
+ if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropError;
+ return true;
+ }
+ if ("params".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropParams;
+ return true;
+ }
+
+ Err = $"Unknown PropertyName `{Name}`";
+ return false;
+ }
+
+ private bool PropId() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
+ Err = "Unable to read `id` property value";
+ return false;
+ }
+
+ State = Fsms.Prop;
+ Id = Lexer.GetString();
+ return true;
+ }
+
+ private bool PropError() {
+ Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared);
+ State = Fsms.End;
+ return true;
+ }
+
+ private bool PropMethod() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
+ Err = "Unable to read `method` property value";
+ return false;
+ }
+
+ State = Fsms.Prop;
+ Method = Lexer.GetString();
+ return true;
+ }
+
+ private bool PropParams() {
+ // Do not parse the result!
+ // The complete result is not present in the buffer!
+ // The result is returned as a unevaluated asynchronous stream!
+ State = Fsms.End;
+ return true;
+ }
+ }
+}
+
+public readonly record struct RspHeader(string? id, WsClient.Error err) {
+ public bool IsDefault => default == this;
+
+ ///
+ /// Parses the head including the result propertyname, excluding the result array.
+ ///
+ internal static (RspHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) {
+ Fsm fsm = new() {
+ Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })),
+ State = Fsms.Start,
+ };
+ while (fsm.MoveNext()) {}
+
+ if (!fsm.Success) {
+ return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}");
+ }
+ return (new(fsm.Id, fsm.Error), default, default);
+ }
+
+ private enum Fsms {
+ Start, // -> Prop
+ Prop, // -> PropId | PropError | ProsResult
+ PropId, // -> Prop | End
+ PropError, // -> End
+ PropResult, // -> End
+ End
+ }
+
+ private ref struct Fsm {
+ public Fsms State;
+ public Utf8JsonReader Lexer;
+ public string? Err;
+ public bool Success;
+
+ public string? Name;
+ public string? Id;
+ public WsClient.Error Error;
+
+ public bool MoveNext() {
+ return State switch {
+ Fsms.Start => Start(),
+ Fsms.Prop => Prop(),
+ Fsms.PropId => PropId(),
+ Fsms.PropError => PropError(),
+ Fsms.PropResult => PropResult(),
+ Fsms.End => End(),
+ _ => false
+ };
+ }
+
+ private bool Start() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) {
+ Err = "Unable to read token StartObject";
+ return false;
+ }
+
+ State = Fsms.Prop;
+ return true;
+
+ }
+
+ private bool End() {
+ Success = !String.IsNullOrEmpty(Id);
+ return false;
+ }
+
+ private bool Prop() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) {
+ Err = "Unable to read PropertyName";
+ return false;
+ }
+
+ Name = Lexer.GetString();
+ if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropId;
+ return true;
+ }
+ if ("result".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropResult;
+ return true;
+ }
+ if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
+ State = Fsms.PropError;
+ return true;
+ }
+
+ Err = $"Unknown PropertyName `{Name}`";
+ return false;
+ }
+
+ private bool PropId() {
+ if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
+ Err = "Unable to read `id` property value";
+ return false;
+ }
+
+ State = Fsms.Prop;
+ Id = Lexer.GetString();
+ return true;
+ }
+
+ private bool PropError() {
+ Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared);
+ State = Fsms.End;
+ return true;
+ }
+
+
+ private bool PropResult() {
+ // Do not parse the result!
+ // The complete result is not present in the buffer!
+ // The result is returned as a unevaluated asynchronous stream!
+ State = Fsms.End;
+ return true;
+ }
+ }
+}
+
diff --git a/src/Ws/Helpers/BoundedChannelPool.cs b/src/Ws/Helpers/BoundedChannelPool.cs
new file mode 100644
index 00000000..010e09e2
--- /dev/null
+++ b/src/Ws/Helpers/BoundedChannelPool.cs
@@ -0,0 +1,475 @@
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using System.Threading.Channels;
+
+// ReSharper disable once CheckNamespace
+namespace System.Buffers;
+
+public abstract class BoundedChannelPool {
+ private static readonly TlsOverPerCoreLockedStacksBoundedChannelPool s_boundedShared = new();
+
+ public static BoundedChannelPool Shared => s_boundedShared;
+
+ public abstract BoundedChannel Rent(int minimumLength);
+
+ public abstract void Return(BoundedChannel channel);
+}
+
+public sealed class BoundedChannel : Channel, IDisposable {
+ private BoundedChannelPool? _owner;
+
+ public BoundedChannel(Channel wrapped, int capacity, BoundedChannelPool owner) {
+ Reader = wrapped.Reader;
+ Writer = wrapped.Writer;
+ Capacity = capacity;
+ _owner = owner;
+ }
+
+ public int Capacity { get; }
+
+ public void Dispose() {
+ var owner = _owner;
+ _owner = null;
+ if (owner is not null) {
+ owner.Return(this);
+ }
+ }
+}
+
+// Source: https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Buffers/TlsOverPerCoreLockedStacksArrayPool.cs
+// modified for use with channels
+public sealed class TlsOverPerCoreLockedStacksBoundedChannelPool : BoundedChannelPool {
+ /// The number of buckets (array sizes) in the pool, one for each array length, starting from length 16.
+ private const int NumBuckets = 27; // GCHelper.SelectBucketIndex(1024 * 1024 * 1024 + 1)
+ /// Maximum number of per-core stacks to use per array size.
+ private const int MaxPerCorePerArraySizeStacks = 64; // selected to avoid needing to worry about processor groups
+ /// The maximum number of buffers to store in a bucket's global queue.
+ private const int MaxBuffersPerArraySizePerCore = 8;
+
+ /// A per-thread array of arrays, to cache one array per array size per thread.
+ [ThreadStatic]
+ private static ThreadLocalArray[]? t_tlsBuckets;
+ /// Used to keep track of all thread local buckets for trimming if needed.
+ private readonly ConditionalWeakTable _allTlsBuckets = new ConditionalWeakTable();
+ ///
+ /// An array of per-core array stacks. The slots are lazily initialized to avoid creating
+ /// lots of overhead for unused array sizes.
+ ///
+ private readonly PerCoreLockedStacks?[] _buckets = new PerCoreLockedStacks[NumBuckets];
+ /// Whether the callback to trim arrays in response to memory pressure has been created.
+ private int _trimCallbackCreated;
+
+ /// Allocate a new PerCoreLockedStacks and try to store it into the array.
+ private PerCoreLockedStacks CreatePerCoreLockedStacks(int bucketIndex)
+ {
+ var inst = new PerCoreLockedStacks();
+ return Interlocked.CompareExchange(ref _buckets[bucketIndex], inst, null) ?? inst;
+ }
+
+ /// Gets an ID for the pool to use with events.
+ private int Id => GetHashCode();
+
+ public override BoundedChannel Rent(int minimumLength)
+ {
+ BoundedChannel? buffer;
+
+ // Get the bucket number for the array length. The result may be out of range of buckets,
+ // either for too large a value or for 0 and negative values.
+ int bucketIndex = GCHelper.SelectBucketIndex(minimumLength);
+
+ // First, try to get an array from TLS if possible.
+ ThreadLocalArray[]? tlsBuckets = t_tlsBuckets;
+ if (tlsBuckets is not null && (uint)bucketIndex < (uint)tlsBuckets.Length)
+ {
+ buffer = tlsBuckets[bucketIndex].Channel;
+ if (buffer is not null)
+ {
+ tlsBuckets[bucketIndex].Channel = null;
+ return buffer;
+ }
+ }
+
+ // Next, try to get an array from one of the per-core stacks.
+ PerCoreLockedStacks?[] perCoreBuckets = _buckets;
+ if ((uint)bucketIndex < (uint)perCoreBuckets.Length)
+ {
+ PerCoreLockedStacks? b = perCoreBuckets[bucketIndex];
+ if (b is not null)
+ {
+ buffer = b.TryPop();
+ if (buffer is not null)
+ {
+ return buffer;
+ }
+ }
+
+ // No buffer available. Ensure the length we'll allocate matches that of a bucket
+ // so we can later return it.
+ minimumLength = GCHelper.GetMaxSizeForBucket(bucketIndex);
+ }
+ else if (minimumLength <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(minimumLength));
+ }
+
+ // allocate a new bounded channel, that belongs to this instance
+ buffer = new(Channel.CreateBounded(minimumLength), minimumLength, this);
+
+ return buffer;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public override void Return(BoundedChannel? channel) {
+ if (channel is null) {
+ ThrowChannelNull();
+ }
+
+ if (channel.Reader.Completion.IsCompleted) {
+ ThrowChannelCompleted();
+ }
+
+ // Determine with what bucket this array length is associated
+ int bucketIndex = GCHelper.SelectBucketIndex(channel.Capacity);
+
+ // Make sure our TLS buckets are initialized. Technically we could avoid doing
+ // this if the array being returned is erroneous or too large for the pool, but the
+ // former condition is an error we don't need to optimize for, and the latter is incredibly
+ // rare, given a max size of 1B elements.
+ ThreadLocalArray[] tlsBuckets = t_tlsBuckets ?? InitializeTlsBucketsAndTrimming();
+
+ bool haveBucket = false;
+ bool returned = true;
+ if ((uint)bucketIndex < (uint)tlsBuckets.Length)
+ {
+ haveBucket = true;
+
+ // Check to see if the buffer is the correct size for this bucket.
+ if (channel.Capacity != GCHelper.GetMaxSizeForBucket(bucketIndex)) {
+ ThrowChannelNotOfPool();
+ }
+
+ // Store the array into the TLS bucket. If there's already an array in it,
+ // push that array down into the per-core stacks, preferring to keep the latest
+ // one in TLS for better locality.
+ ref ThreadLocalArray tla = ref tlsBuckets[bucketIndex];
+ BoundedChannel? prev = tla.Channel;
+ tla = new ThreadLocalArray(channel);
+ if (prev is not null)
+ {
+ PerCoreLockedStacks stackBucket = _buckets[bucketIndex] ?? CreatePerCoreLockedStacks(bucketIndex);
+ returned = stackBucket.TryPush(prev);
+ }
+ }
+ }
+
+ public bool Trim()
+ {
+ int currentMilliseconds = Environment.TickCount;
+ GCHelper.MemoryPressure pressure = GCHelper.GetMemoryPressure();
+
+ // Trim each of the per-core buckets.
+ PerCoreLockedStacks?[] perCoreBuckets = _buckets;
+ for (int i = 0; i < perCoreBuckets.Length; i++)
+ {
+ perCoreBuckets[i]?.Trim(currentMilliseconds, Id, pressure, GCHelper.GetMaxSizeForBucket(i));
+ }
+
+ // Trim each of the TLS buckets. Note that threads may be modifying their TLS slots concurrently with
+ // this trimming happening. We do not force synchronization with those operations, so we accept the fact
+ // that we may end up firing a trimming event even if an array wasn't trimmed, and potentially
+ // trim an array we didn't need to. Both of these should be rare occurrences.
+
+ // Under high pressure, release all thread locals.
+ if (pressure == GCHelper.MemoryPressure.High) {
+ foreach (KeyValuePair tlsBuckets in _allTlsBuckets)
+ {
+#if NET6_0_OR_GREATER
+ Array.Clear(tlsBuckets.Key);
+#else
+ tlsBuckets.Key.AsSpan().Clear();
+#endif
+ }
+ }
+ else
+ {
+ // Otherwise, release thread locals based on how long we've observed them to be stored. This time is
+ // approximate, with the time set not when the array is stored but when we see it during a Trim, so it
+ // takes at least two Trim calls (and thus two gen2 GCs) to drop an array, unless we're in high memory
+ // pressure. These values have been set arbitrarily; we could tune them in the future.
+ uint millisecondsThreshold = pressure switch
+ {
+ GCHelper.MemoryPressure.Medium => 15_000,
+ _ => 30_000,
+ };
+
+ foreach (KeyValuePair tlsBuckets in _allTlsBuckets)
+ {
+ ThreadLocalArray[] buckets = tlsBuckets.Key;
+ for (int i = 0; i < buckets.Length; i++)
+ {
+ if (buckets[i].Channel is null)
+ {
+ continue;
+ }
+
+ // We treat 0 to mean it hasn't yet been seen in a Trim call. In the very rare case where Trim records 0,
+ // it'll take an extra Trim call to remove the array.
+ int lastSeen = buckets[i].MillisecondsTimeStamp;
+ if (lastSeen == 0)
+ {
+ buckets[i].MillisecondsTimeStamp = currentMilliseconds;
+ }
+ else if ((currentMilliseconds - lastSeen) >= millisecondsThreshold)
+ {
+ // Time noticeably wrapped, or we've surpassed the threshold.
+ // Clear out the array, and log its being trimmed if desired.
+ Interlocked.Exchange(ref buckets[i].Channel, null);
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ private ThreadLocalArray[] InitializeTlsBucketsAndTrimming()
+ {
+ Debug.Assert(t_tlsBuckets is null, $"Non-null {nameof(t_tlsBuckets)}");
+
+ var tlsBuckets = new ThreadLocalArray[NumBuckets];
+ t_tlsBuckets = tlsBuckets;
+
+ _allTlsBuckets.Add(tlsBuckets, null);
+ if (Interlocked.Exchange(ref _trimCallbackCreated, 1) == 0)
+ {
+ Gen2GcCallback.Register(s => ((TlsOverPerCoreLockedStacksBoundedChannelPool)s).Trim(), this);
+ }
+
+ return tlsBuckets;
+ }
+
+ /// Stores a set of stacks of arrays, with one stack per core.
+ private sealed class PerCoreLockedStacks
+ {
+ /// Number of locked stacks to employ.
+ private static readonly int s_lockedStackCount = Math.Min(Environment.ProcessorCount, MaxPerCorePerArraySizeStacks);
+ /// The stacks.
+ private readonly LockedStack[] _perCoreStacks;
+
+ /// Initializes the stacks.
+ public PerCoreLockedStacks()
+ {
+ // Create the stacks. We create as many as there are processors, limited by our max.
+ var stacks = new LockedStack[s_lockedStackCount];
+ for (int i = 0; i < stacks.Length; i++)
+ {
+ stacks[i] = new LockedStack();
+ }
+ _perCoreStacks = stacks;
+ }
+
+ /// Try to push the array into the stacks. If each is full when it's tested, the array will be dropped.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool TryPush(BoundedChannel array)
+ {
+ // Try to push on to the associated stack first. If that fails,
+ // round-robin through the other stacks.
+ LockedStack[] stacks = _perCoreStacks;
+ int index = (int)((uint)Thread.GetCurrentProcessorId() % (uint)s_lockedStackCount); // mod by constant in tier 1
+ for (int i = 0; i < stacks.Length; i++)
+ {
+ if (stacks[index].TryPush(array)) return true;
+ if (++index == stacks.Length) index = 0;
+ }
+
+ return false;
+ }
+
+ /// Try to get an array from the stacks. If each is empty when it's tested, null will be returned.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public BoundedChannel? TryPop()
+ {
+ // Try to pop from the associated stack first. If that fails, round-robin through the other stacks.
+ BoundedChannel? arr;
+ LockedStack[] stacks = _perCoreStacks;
+ int index = (int)((uint)Thread.GetCurrentProcessorId() % (uint)s_lockedStackCount); // mod by constant in tier 1
+ for (int i = 0; i < stacks.Length; i++)
+ {
+ if ((arr = stacks[index].TryPop()) is not null) return arr;
+ if (++index == stacks.Length) index = 0;
+ }
+ return null;
+ }
+
+ public void Trim(int currentMilliseconds, int id, GCHelper.MemoryPressure pressure, int bucketSize)
+ {
+ LockedStack[] stacks = _perCoreStacks;
+ for (int i = 0; i < stacks.Length; i++)
+ {
+ stacks[i].Trim(currentMilliseconds, id, pressure, bucketSize);
+ }
+ }
+ }
+
+ /// Provides a simple, bounded stack of arrays, protected by a lock.
+ private sealed class LockedStack
+ {
+ /// The arrays in the stack.
+ private readonly BoundedChannel?[] _arrays = new BoundedChannel[MaxBuffersPerArraySizePerCore];
+ /// Number of arrays stored in .
+ private int _count;
+ /// Timestamp set by Trim when it sees this as 0.
+ private int _millisecondsTimestamp;
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public bool TryPush(BoundedChannel array)
+ {
+ bool enqueued = false;
+ Monitor.Enter(this);
+ BoundedChannel?[] arrays = _arrays;
+ int count = _count;
+ if ((uint)count < (uint)arrays.Length)
+ {
+ if (count == 0)
+ {
+ // Reset the time stamp now that we're transitioning from empty to non-empty.
+ // Trim will see this as 0 and initialize it to the current time when Trim is called.
+ _millisecondsTimestamp = 0;
+ }
+
+ arrays[count] = array;
+ _count = count + 1;
+ enqueued = true;
+ }
+ Monitor.Exit(this);
+ return enqueued;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public BoundedChannel? TryPop()
+ {
+ BoundedChannel? arr = null;
+ Monitor.Enter(this);
+ BoundedChannel?[] arrays = _arrays;
+ int count = _count - 1;
+ if ((uint)count < (uint)arrays.Length)
+ {
+ arr = arrays[count];
+ arrays[count] = null;
+ _count = count;
+ }
+ Monitor.Exit(this);
+ return arr;
+ }
+
+ public void Trim(int currentMilliseconds, int id, GCHelper.MemoryPressure pressure, int bucketSize)
+ {
+ const int StackTrimAfterMS = 60 * 1000; // Trim after 60 seconds for low/moderate pressure
+ const int StackHighTrimAfterMS = 10 * 1000; // Trim after 10 seconds for high pressure
+ const int StackLowTrimCount = 1; // Trim one item when pressure is low
+ const int StackMediumTrimCount = 2; // Trim two items when pressure is moderate
+ const int StackHighTrimCount = MaxBuffersPerArraySizePerCore; // Trim all items when pressure is high
+ const int StackLargeBucket = 16384; // If the bucket is larger than this we'll trim an extra when under high pressure
+ const int StackModerateTypeSize = 16; // If T is larger than this we'll trim an extra when under high pressure
+ const int StackLargeTypeSize = 32; // If T is larger than this we'll trim an extra (additional) when under high pressure
+
+ if (_count == 0)
+ {
+ return;
+ }
+
+ int trimMilliseconds = pressure == GCHelper.MemoryPressure.High ? StackHighTrimAfterMS : StackTrimAfterMS;
+
+ lock (this)
+ {
+ if (_count == 0)
+ {
+ return;
+ }
+
+ if (_millisecondsTimestamp == 0)
+ {
+ _millisecondsTimestamp = currentMilliseconds;
+ return;
+ }
+
+ if ((currentMilliseconds - _millisecondsTimestamp) <= trimMilliseconds)
+ {
+ return;
+ }
+
+ // We've elapsed enough time since the first item went into the stack.
+ // Drop the top item so it can be collected and make the stack look a little newer.
+
+ int trimCount = StackLowTrimCount;
+ switch (pressure)
+ {
+ case GCHelper.MemoryPressure.High:
+ trimCount = StackHighTrimCount;
+
+ // When pressure is high, aggressively trim larger arrays.
+ if (bucketSize > StackLargeBucket)
+ {
+ trimCount++;
+ }
+ if (Unsafe.SizeOf() > StackModerateTypeSize)
+ {
+ trimCount++;
+ }
+ if (Unsafe.SizeOf() > StackLargeTypeSize)
+ {
+ trimCount++;
+ }
+ break;
+
+ case GCHelper.MemoryPressure.Medium:
+ trimCount = StackMediumTrimCount;
+ break;
+ }
+
+ while (_count > 0 && trimCount-- > 0)
+ {
+ BoundedChannel? array = _arrays[--_count];
+ Debug.Assert(array is not null, "No nulls should have been present in slots < _count.");
+ _arrays[_count] = null;
+ }
+
+ _millisecondsTimestamp = _count > 0 ?
+ _millisecondsTimestamp + (trimMilliseconds / 4) : // Give the remaining items a bit more time
+ 0;
+ }
+ }
+ }
+
+ /// Wrapper for arrays stored in ThreadStatic buckets.
+ private struct ThreadLocalArray
+ {
+ /// The stored array.
+ public BoundedChannel? Channel;
+ /// Environment.TickCount timestamp for when this array was observed by Trim.
+ public int MillisecondsTimeStamp;
+
+ public ThreadLocalArray(BoundedChannel channel)
+ {
+ Channel = channel;
+ MillisecondsTimeStamp = 0;
+ }
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowChannelNull() {
+ throw new ArgumentNullException("channel");
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowChannelNotOfPool() {
+ throw new ArgumentException("The channel does not belong to the bool", "channel");
+ }
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowChannelCompleted() {
+ throw new ArgumentException("Cannot add a completed channel to the pool", "channel");
+ }
+}
+
diff --git a/src/Ws/Ws.cs b/src/Ws/Ws.cs
deleted file mode 100644
index c52f8839..00000000
--- a/src/Ws/Ws.cs
+++ /dev/null
@@ -1,131 +0,0 @@
-using System.Collections.Concurrent;
-using System.Diagnostics.CodeAnalysis;
-using System.Runtime.CompilerServices;
-
-using SurrealDB.Common;
-
-namespace SurrealDB.Ws;
-
-public sealed class Ws : IDisposable, IAsyncDisposable {
- private readonly CancellationTokenSource _cts = new();
- private readonly WsTx _tx = new();
- private readonly ConcurrentDictionary _handlers = new();
- private Task _recv = Task.CompletedTask;
-
- public bool Connected => _tx.Connected;
-
- public async Task Open(Uri remote, CancellationToken ct = default) {
- await _tx.Open(remote, ct);
- _recv = Task.Run(async () => await Receive(_cts.Token), _cts.Token);
- }
-
- public async Task Close(CancellationToken ct = default) {
- Task t1 = _tx.Close(ct);
- Task t2 = Task.Run(ClearHandlers, ct);
- _cts.Cancel();
-
- await t1;
- await t2;
- }
-
- ///
- /// Sends the request and awaits a response from the server
- ///
- public async Task<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> RequestOnce(string id, Stream request, CancellationToken ct = default) {
- ResponseHandler handler = new(id, ct);
- Register(handler);
- await _tx.Tw(request, ct);
- return await handler.Task;
- }
-
- ///
- /// Sends the request and awaits responses from the server until manually canceled using the cancellation token
- ///
- public async IAsyncEnumerable<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> RequestPersists(string id, Stream request, [EnumeratorCancellation] CancellationToken ct = default) {
- NotificationHandler handler = new(this, id, ct);
- Register(handler);
- await _tx.Tw(request, ct);
- await foreach (var res in handler) {
- yield return res;
- }
- }
-
- internal void Register(IHandler handler) {
- if (!_handlers.TryAdd(handler.Id, handler)) {
- ThrowDuplicateId(handler.Id);
- }
- }
-
- internal void Unregister(IHandler handler) {
- if (!_handlers.TryRemove(handler.Id, out var h)) {
- return;
- }
-
- h.Dispose();
- }
-
- private async Task Receive(CancellationToken stoppingToken) {
- while (!stoppingToken.IsCancellationRequested) {
- var (id, response, notify, stream) = await _tx.Tr(stoppingToken);
-
- stoppingToken.ThrowIfCancellationRequested();
-
- if (String.IsNullOrEmpty(id)) {
- continue; // Invalid response
- }
-
- if (!_handlers.TryGetValue(id, out IHandler? handler)) {
- // assume that unhandled responses belong to other clients
- // discard!
- await stream.DisposeAsync();
- continue;
- }
-
- handler.Handle(response, notify, stream);
-
- if (!handler.Persistent) {
- // persistent handlers are for notifications and are not removed automatically
- Unregister(handler);
- }
-
- if (stream is WsStream wsStream) {
- while (!wsStream.EndOfMessage) {
- await Task.Delay(13, stoppingToken);
- }
- }
- }
- }
-
- private void ClearHandlers() {
- foreach (var handler in _handlers.Values) {
- Unregister(handler);
- }
- }
-
- public void Dispose() {
- try {
- Close().Wait();
- _tx.Dispose();
- } catch (OperationCanceledException) {
- // expected
- } catch (AggregateException) {
- // wrapping OperationCanceledException
- }
- }
-
- public async ValueTask DisposeAsync() {
- try {
- await Close();
- _tx.Dispose();
- } catch (OperationCanceledException) {
- // expected
- } catch (AggregateException) {
- // wrapping OperationCanceledException for async
- }
- }
-
- [DoesNotReturn]
- private static void ThrowDuplicateId(string id) {
- throw new ArgumentOutOfRangeException(nameof(id), $"A request with the Id `{id}` is already registered");
- }
-}
diff --git a/src/Ws/Ws.csproj b/src/Ws/Ws.csproj
index dd553b1f..ba51b4d5 100644
--- a/src/Ws/Ws.csproj
+++ b/src/Ws/Ws.csproj
@@ -2,6 +2,7 @@
SurrealDB.Ws
+ true
@@ -20,6 +21,6 @@
+
-
diff --git a/src/Ws/WsClient.cs b/src/Ws/WsClient.cs
index 661584d2..d37b9edc 100644
--- a/src/Ws/WsClient.cs
+++ b/src/Ws/WsClient.cs
@@ -1,6 +1,10 @@
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Net.WebSockets;
+using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Threading.Channels;
using Microsoft.IO;
@@ -9,83 +13,109 @@
namespace SurrealDB.Ws;
-///
-/// The client used to connect to the Surreal server via JSON RPC.
-///
-public sealed class WsClient : IDisposable, IAsyncDisposable {
- private static readonly Lazy s_manager = new(static () => new());
+/// The client used to connect to the Surreal server via JSON RPC.
+public sealed class WsClient : IDisposable {
// Do not get any funny ideas and fill this fucker up.
- public static readonly List EmptyList = new();
+ private static readonly List s_emptyList = new();
- private readonly Ws _ws = new();
+ private readonly ClientWebSocket _ws = new();
+ private readonly RecyclableMemoryStreamManager _memoryManager;
+ private readonly WsTransmitter _transmitter;
+ private readonly WsReceiverDeflater _deflater;
+ private readonly WsReceiverInflater _inflater;
- ///
- /// Indicates whether the client is connected or not.
- ///
- public bool Connected => _ws.Connected;
+ private readonly int _idBytes;
- ///
- /// Generates a random base64 string of the length specified.
- ///
- public static string GetRandomId(int length) {
- Span buf = stackalloc byte[length];
- ThreadRng.Shared.NextBytes(buf);
- return Convert.ToBase64String(buf);
+ public WsClient()
+ : this(WsClientOptions.Default) {
}
- ///
- /// Opens the connection to the Surreal server.
- ///
- public async Task Open(Uri url, CancellationToken ct = default) {
+ public WsClient(WsClientOptions options) {
+ options.ValidateAndMakeReadonly();
+ _memoryManager = options.MemoryManager;
+ _transmitter = new(_ws, _memoryManager.BlockSize);
+ var tx = Channel.CreateBounded(options.TxChannelCapacity);
+ _deflater = new(tx.Reader, options.ReceiveHeaderBytesMax, options.RequestExpiration, TimeSpan.FromSeconds(1));
+ _inflater = new(_ws, tx.Writer, _memoryManager, _memoryManager.BlockSize, options.MessageChannelCapacity);
+
+ _idBytes = options.IdBytes;
+ }
+
+ /// Indicates whether the client is connected or not.
+ public bool Connected => _ws.State == WebSocketState.Open;
+
+ public WebSocketState State => _ws.State;
+
+ /// Opens the connection to the Surreal server.
+ public async Task OpenAsync(Uri url, CancellationToken ct = default) {
ThrowIfConnected();
- await _ws.Open(url, ct);
+ await _ws.ConnectAsync(url, ct).Inv();
+ _deflater.Open();
+ _inflater.Open();
}
///
/// Closes the connection to the Surreal server.
///
- public async Task Close(CancellationToken ct = default) {
- await _ws.Close(ct);
+ public async Task CloseAsync(CancellationToken ct = default) {
+ ThrowIfDisconnected();
+ await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client connection closed orderly", ct).Inv();
+ await _deflater.CloseAsync().Inv();
+ await _inflater.CloseAsync().Inv();
}
///
public void Dispose() {
+ _deflater.Dispose();
+ _inflater.Dispose();
_ws.Dispose();
}
- ///
- public ValueTask DisposeAsync() {
- return _ws.DisposeAsync();
- }
-
///
/// Sends the specified request to the Surreal server, and returns the response.
///
public async Task Send(Request req, CancellationToken ct = default) {
ThrowIfDisconnected();
- req.id ??= GetRandomId(6);
- req.parameters ??= EmptyList;
-
- await using RecyclableMemoryStream stream = new(s_manager.Value);
+ req.id ??= HeaderHelper.GetRandomId(_idBytes);
+ req.parameters ??= s_emptyList;
- await JsonSerializer.SerializeAsync(stream, req, SerializerOptions.Shared, ct);
- // Now Position = Length = EndOfMessage
- // Write the buffer to the websocket
- stream.Position = 0;
- var (rsp, nty, stm) = await _ws.RequestOnce(req.id, stream, ct);
- if (!nty.IsDefault) {
- ThrowExpectRspGotNty();
+ // listen for the response
+ ResponseHandler handler = new(req.id, ct);
+ if (!_deflater.RegisterOrGet(handler)) {
+ return default;
}
-
- if (rsp.IsDefault) {
- ThrowRspDefault();
+ // send request
+ var stream = await SerializeAsync(req, ct).Inv();
+ await _transmitter.SendAsync(stream, ct).Inv();
+ // await response, dispose message when done
+ using var response = await handler.Task.Inv();
+ // validate header
+ var responseHeader = response.Header.Response;
+ if (!response.Header.Notify.IsDefault) {
+ ThrowExpectResponseGotNotify();
+ }
+ if (responseHeader.IsDefault) {
+ ThrowInvalidResponse();
}
- var bdy = await JsonSerializer.DeserializeAsync(stm, SerializerOptions.Shared, ct);
- if (bdy is null) {
- ThrowRspDefault();
+ // position stream beyond header and deserialize message body
+ response.Reader.Position = response.Header.BytesLength;
+ // deserialize body
+ var body = await JsonSerializer.DeserializeAsync(response.Reader, SerializerOptions.Shared, ct).Inv();
+ if (body is null) {
+ ThrowInvalidResponse();
}
- return new(rsp.id, rsp.err, ExtractResult(bdy));
+
+ return new(responseHeader.id, responseHeader.err, ExtractResult(body));
+ }
+
+ private async Task SerializeAsync(Request req, CancellationToken ct) {
+ RecyclableMemoryStream stream = new(_memoryManager);
+
+ await JsonSerializer.SerializeAsync(stream, req, SerializerOptions.Shared, ct).Inv();
+ // position = Length = EndOfMessage -> position = 0
+ stream.Position = 0;
+ return stream;
}
private static JsonElement ExtractResult(JsonDocument root) {
@@ -104,13 +134,13 @@ private void ThrowIfConnected() {
}
}
- [DoesNotReturn]
- private static void ThrowExpectRspGotNty() {
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowExpectResponseGotNotify() {
throw new InvalidOperationException("Expected a response, got a notification");
}
- [DoesNotReturn]
- private static void ThrowRspDefault() {
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowInvalidResponse() {
throw new InvalidOperationException("Invalid response");
}
@@ -140,6 +170,4 @@ public record struct Notify(
string? method,
[property: JsonPropertyName("params"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault),]
List? parameters);
-
-
}
diff --git a/src/Ws/WsClientOptions.cs b/src/Ws/WsClientOptions.cs
new file mode 100644
index 00000000..8f63f56c
--- /dev/null
+++ b/src/Ws/WsClientOptions.cs
@@ -0,0 +1,109 @@
+using Microsoft.IO;
+
+using SurrealDB.Common;
+
+namespace SurrealDB.Ws;
+
+public sealed record WsClientOptions : ValidateReadonly {
+ public const int MaxArraySize = 0X7FFFFFC7;
+
+ /// The maximum number of pending messages in the client inbound channel. Default 1024
+ /// This does not refer to the number of simultaneous queries, but the number of unread messages, the size of the "inbox".
+ /// A message may consist of multiple blocks. Only the message counts towards this number.
+ public int TxChannelCapacity {
+ get => _txChannelCapacity;
+ set => Set(out _txChannelCapacity, in value);
+ }
+
+ /// The maximum number of pending blocks in a single message channel. Default 64
+ /// The number of blocks of a message that can be received by the client, before they are consumed,
+ /// by reading from the .
+ /// A block have up to bytes.
+ public int MessageChannelCapacity {
+ get => _messageChannelCapacity;
+ set => Set(out _messageChannelCapacity, in value);
+ }
+
+ /// The maximum number of bytes a received header can consist of. Default 4 * 1024bytes
+ /// The client receives a message with a and the message.
+ /// This is the length the socket "peeks" at the beginning of the network stream, in oder to fully deserialize the or .
+ /// The entire header must be contained within the peeked memory.
+ /// The length is bound to .
+ /// Longer lengths introduce additional overhead.
+ public int ReceiveHeaderBytesMax {
+ get => _receiveHeaderBytesMax;
+ set => Set(out _receiveHeaderBytesMax, in value);
+ }
+ /// The number of bytes the id consists of.
+ /// The id is base64 encoded, therefore 6 bytes = 4 characters. Use values in steps of 6.
+ public int IdBytes {
+ get => _idBytes;
+ set => Set(out _idBytes, in value);
+ }
+ /// Defines the resize behaviour of the streams used for handling messages.
+ public RecyclableMemoryStreamManager MemoryManager {
+ get => _memoryManager!; // Validated not null
+ set => Set(out _memoryManager, in value);
+ }
+
+ /// The maximum time a request is awaited, before a is thrown.
+ /// Limited by the internal cache eviction timeout (1s) & pressure/traffic.
+ public TimeSpan RequestExpiration {
+ get => _requestExpiration;
+ set => Set(out _requestExpiration, value);
+ }
+
+ private RecyclableMemoryStreamManager? _memoryManager;
+ private int _idBytes = 6;
+ private int _receiveHeaderBytesMax = 4 * 1024;
+ private int _txChannelCapacity = 1024;
+ private int _messageChannelCapacity = 64;
+ private TimeSpan _requestExpiration = TimeSpan.FromSeconds(10);
+
+ public void ValidateAndMakeReadonly() {
+ if (!IsReadonly()) {
+ ValidateOrThrow();
+ MakeReadonly();
+ }
+ }
+
+ protected override IEnumerable<(string PropertyName, string Message)> Validations() {
+ if (TxChannelCapacity <= 0) {
+ yield return (nameof(TxChannelCapacity), "cannot be less then or equal to zero");
+ }
+ if (TxChannelCapacity > MaxArraySize) {
+ yield return (nameof(TxChannelCapacity), "cannot be greater then MaxArraySize");
+ }
+
+ if (MessageChannelCapacity <= 0) {
+ yield return (nameof(MessageChannelCapacity), "cannot be less then or equal to zero");
+ }
+ if (MessageChannelCapacity > MaxArraySize) {
+ yield return (nameof(MessageChannelCapacity), "cannot be greater then MaxArraySize");
+ }
+
+ if (ReceiveHeaderBytesMax <= 0) {
+ yield return (nameof(ReceiveHeaderBytesMax), "cannot be less then or equal to zero");
+ }
+
+ if (ReceiveHeaderBytesMax > (_memoryManager?.BlockSize ?? 0)) {
+ yield return (nameof(ReceiveHeaderBytesMax), "cannot be greater then MemoryManager.BlockSize");
+ }
+
+ if (_memoryManager is null) {
+ yield return (nameof(MemoryManager), "cannot be null");
+ }
+
+ if (RequestExpiration <= TimeSpan.Zero) {
+ yield return (nameof(RequestExpiration), "expiration time cannot be less then or equal to zero");
+ }
+ }
+
+ public static WsClientOptions Default { get; } = CreateDefault();
+
+ private static WsClientOptions CreateDefault() {
+ WsClientOptions o = new() { MemoryManager = new(), };
+ o.ValidateAndMakeReadonly();
+ return o;
+ }
+}
diff --git a/src/Ws/WsReceiverDeflater.cs b/src/Ws/WsReceiverDeflater.cs
new file mode 100644
index 00000000..96ba78aa
--- /dev/null
+++ b/src/Ws/WsReceiverDeflater.cs
@@ -0,0 +1,155 @@
+using System.Buffers;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading.Channels;
+
+using SurrealDB.Common;
+
+namespace SurrealDB.Ws;
+
+/// Listens for s and dispatches them by their headers to different s.
+internal sealed class WsReceiverDeflater : IDisposable {
+ private readonly ChannelReader _channel;
+ private readonly DisposingCache _handlers;
+ private CancellationTokenSource? _cts;
+ private Task? _execute;
+ private readonly int _maxHeaderBytes;
+
+ public WsReceiverDeflater(ChannelReader channel, int maxHeaderBytes, TimeSpan cacheSlidingExpiration, TimeSpan cacheEvictionInterval) {
+ _channel = channel;
+ _maxHeaderBytes = maxHeaderBytes;
+ _handlers = new(cacheSlidingExpiration, cacheEvictionInterval);
+ }
+
+ [MemberNotNullWhen(true, nameof(_cts)), MemberNotNullWhen(true, nameof(_execute))]
+ public bool Connected => _cts is not null & _execute is not null;
+
+ private async Task Execute(CancellationToken ct) {
+ Debug.Assert(ct.CanBeCanceled);
+
+ while (!ct.IsCancellationRequested) {
+ await Consume(ct).Inv();
+ ct.ThrowIfCancellationRequested();
+ }
+ }
+
+ private async Task Consume(CancellationToken ct) {
+ var log = WsReceiverDeflaterEventSource.Log;
+
+ ct.ThrowIfCancellationRequested();
+
+ // log that we are waiting for a message from the channel
+ log.MessageAwaiting();
+
+ var message = await ReadAsync(ct).Inv();
+ // log that a message has been retrieved from the channel
+ log.MessageReceived(message.Header.Id);
+
+ // find the handler
+ string? id = message.Header.Id;
+ if (id is null || !_handlers.TryGetValue(id, out var handler)) {
+ // invalid format, or no registered -> discard message
+ message.Dispose();
+ // log that the message has been discarded
+ log.MessageDiscarded(id);
+ return;
+ }
+ Debug.Assert(id == handler.Id);
+
+ // dispatch the message to the handler
+ try {
+ handler.Dispatch(message);
+ } catch (Exception ex) {
+ // handler is canceled -> unregister
+ Unregister(id);
+ // log that the dispatch has resulted in a exception
+ log.HandlerUnregisteredAfterException(id, ex);
+ }
+
+ if (!handler.Persistent) {
+ Unregister(id);
+ // log that the handler has been unregistered
+ log.HandlerUnregisterdFleeting(id);
+ }
+ }
+
+ public void Unregister(string id) {
+ if (_handlers.TryRemove(id, out var handler))
+ {
+ handler.Dispose();
+ }
+ }
+
+ public bool RegisterOrGet(IHandler handler) {
+ return _handlers.TryAdd(handler.Id, handler);
+ }
+
+ private async Task ReadAsync(CancellationToken ct) {
+ var message = await _channel.ReadAsync(ct).Inv();
+
+ // receive the first part of the message
+ var bytes = ArrayPool.Shared.Rent(_maxHeaderBytes);
+ int read = await message.ReadAsync(bytes, ct).Inv();
+ // peek instead of reading
+ message.Position = 0;
+ // parse the header portion of the stream, without reading the `result` property.
+ // the header is a sum-type of all possible headers.
+ var header = HeaderHelper.Parse(bytes.AsSpan(0, read));
+ ArrayPool.Shared.Return(bytes);
+ return new(header, message);
+ }
+
+ public void Open() {
+ var log = WsReceiverDeflaterEventSource.Log;
+
+ ThrowIfConnected();
+ _cts = new();
+ _execute = Execute(_cts.Token);
+
+ log.Opened();
+ }
+
+ public async Task CloseAsync() {
+ var log = WsReceiverDeflaterEventSource.Log;
+
+ ThrowIfDisconnected();
+ Task task;
+ _cts.Cancel();
+ _cts.Dispose(); // not relly needed here
+ _cts = null;
+ task = _execute;
+ _execute = null;
+
+ log.CloseBegin();
+
+ try {
+ await task.Inv();
+ } catch (OperationCanceledException) {
+ // expected on close using cts
+ }
+
+ log.CloseFinish();
+ }
+
+ public void Dispose() {
+ var log = WsReceiverDeflaterEventSource.Log;
+
+ _cts?.Cancel();
+ _cts?.Dispose();
+
+ log.Disposed();
+ }
+
+ [MemberNotNull(nameof(_cts)), MemberNotNull(nameof(_execute))]
+ private void ThrowIfDisconnected() {
+ if (!Connected) {
+ throw new InvalidOperationException("The connection is not open.");
+ }
+ }
+
+ private void ThrowIfConnected() {
+ if (Connected) {
+ throw new InvalidOperationException("The connection is already open");
+ }
+ }
+}
diff --git a/src/Ws/WsReceiverDeflaterEventSource.cs b/src/Ws/WsReceiverDeflaterEventSource.cs
new file mode 100644
index 00000000..7b2d04d5
--- /dev/null
+++ b/src/Ws/WsReceiverDeflaterEventSource.cs
@@ -0,0 +1,104 @@
+using System.Diagnostics.Tracing;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Ws;
+
+[EventSource(Guid = "03c50b03-e245-46e5-a99a-6eaa28990a41", Name = "WsReceiverDeflaterEventSource")]
+public sealed class WsReceiverDeflaterEventSource : EventSource
+{
+ private WsReceiverDeflaterEventSource() { }
+
+ public static WsReceiverDeflaterEventSource Log { get; } = new();
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void MessageReceived(string? messageId) {
+ if (IsEnabled()) {
+ MessageReceivedCore(messageId);
+ }
+ }
+
+ [Event(1, Level = EventLevel.Verbose, Message = "Message (Id = {0}) pulled from channel")]
+ private void MessageReceivedCore(string? messageId) => WriteEvent(1, messageId);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void MessageDiscarded(string? messageId) {
+ if (IsEnabled()) {
+ MessageDiscardedCore(messageId);
+ }
+ }
+
+ [Event(2, Level = EventLevel.Warning, Message = "No handler registered for the message (Id = {0})")]
+ private void MessageDiscardedCore(string? messageId) => WriteEvent(2, messageId);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void HandlerUnregisteredAfterException(string handlerId, Exception ex) {
+ if (IsEnabled()) {
+ HandlerUnregisteredAfterExceptionCore(handlerId, ex.ToString());
+ }
+ }
+
+ [Event(3, Level = EventLevel.Error, Message = "The handler (Id = {0}) threw an exception during dispatch, and was unregistered. ERROR: {1}")]
+ private unsafe void HandlerUnregisteredAfterExceptionCore(string handlerId, string ex) {
+ WriteEvent(3, handlerId, ex);
+ }
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void HandlerUnregisterdFleeting(string handlerId) {
+ if (IsEnabled()) {
+ HandlerUnregisterdFleetingCore(handlerId);
+ }
+ }
+
+ [Event(4, Level = EventLevel.Verbose, Message = "The handler (Id = {0}) is fleeting and was unregistered after dispatch")]
+ private void HandlerUnregisterdFleetingCore(string handlerId) => WriteEvent(4, handlerId);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void MessageAwaiting() {
+ if (IsEnabled()) {
+ MessageAwaitingCore();
+ }
+ }
+
+ [Event(5, Level = EventLevel.Verbose, Message = "Waiting for message to pull from channel")]
+ private void MessageAwaitingCore() => WriteEvent(5);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Opened() {
+ if (IsEnabled()) {
+ OpenedCore();
+ }
+ }
+
+ [Event(6, Level = EventLevel.Informational, Message = "Opened and is now pulling from the channel")]
+ private void OpenedCore() => WriteEvent(6);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void CloseBegin() {
+ if (IsEnabled()) {
+ CloseBeginCore();
+ }
+ }
+
+ [Event(7, Level = EventLevel.Informational, Message = "Closed and stopped pulling from the channel")]
+ private void CloseBeginCore() => WriteEvent(7);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void CloseFinish() {
+ if (IsEnabled()) {
+ CloseFinishCore();
+ }
+ }
+
+ [Event(8, Level = EventLevel.Informational, Message = "Closing has finished")]
+ private void CloseFinishCore() => WriteEvent(8);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Disposed() {
+ if (IsEnabled()) {
+ DisposedCore();
+ }
+ }
+
+ [Event(9, Level = EventLevel.Informational, Message = "Disposed")]
+ private void DisposedCore() => WriteEvent(9);
+}
diff --git a/src/Ws/WsReceiverInflater.cs b/src/Ws/WsReceiverInflater.cs
new file mode 100644
index 00000000..ba54e5f9
--- /dev/null
+++ b/src/Ws/WsReceiverInflater.cs
@@ -0,0 +1,136 @@
+using System.Buffers;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Net.WebSockets;
+using System.Threading.Channels;
+
+using Microsoft.IO;
+
+using SurrealDB.Common;
+
+namespace SurrealDB.Ws;
+
+/// Receives messages from a websocket server and passes them to a channel
+public sealed class WsReceiverInflater : IDisposable {
+ private readonly ClientWebSocket _socket;
+ private readonly ChannelWriter _channel;
+ private readonly RecyclableMemoryStreamManager _memoryManager;
+ private CancellationTokenSource? _cts;
+ private Task? _execute;
+
+ private readonly int _blockSize;
+ private readonly int _messageSize;
+
+ public WsReceiverInflater(ClientWebSocket socket, ChannelWriter channel, RecyclableMemoryStreamManager memoryManager, int blockSize, int messageSize) {
+ _socket = socket;
+ _channel = channel;
+ _memoryManager = memoryManager;
+ _blockSize = blockSize;
+ _messageSize = messageSize;
+ }
+
+ private async Task Execute(CancellationToken ct) {
+ Debug.Assert(ct.CanBeCanceled);
+ while (!ct.IsCancellationRequested) {
+ var buffer = ArrayPool.Shared.Rent(_blockSize);
+ await Produce(buffer, ct).Inv();
+ ArrayPool.Shared.Return(buffer);
+ ct.ThrowIfCancellationRequested();
+ }
+ }
+
+ private async Task Produce(byte[] buffer, CancellationToken ct) {
+ var log = WsReceiverInflaterEventSource.Log;
+
+ // log that we are waiting for the socket
+ log.SocketWaiting();
+ // receive the first part
+ var result = await _socket.ReceiveAsync(buffer, ct).Inv();
+ // log that we have received a message from the socket
+ log.SockedReceived(result);
+
+ ct.ThrowIfCancellationRequested();
+ // create a new message with a RecyclableMemoryStream
+ // use buffer instead of the build the builtin IBufferWriter, bc of thread safely issues related to locking
+ WsReceiverMessageReader msg = new(_memoryManager, _messageSize);
+ // begin adding the message to the output
+ var push = _channel.WriteAsync(msg, ct);
+
+ await msg.AppendResultAsync(buffer, result, ct).Inv();
+
+ while (!result.EndOfMessage && !ct.IsCancellationRequested) {
+ // receive more parts
+ result = await _socket.ReceiveAsync(buffer, ct).Inv();
+ // log that we have received a message from the socket
+ log.SockedReceived(result);
+ await msg.AppendResultAsync(buffer, result, ct).Inv();
+ }
+
+ // log that we have completely received the message
+ log.MessageReceiveFinished();
+ // finish adding the message to the output
+ await push.Inv();
+ // log that the message has been pushed to the channel
+ log.MessagePushed();
+ }
+
+
+ [MemberNotNullWhen(true, nameof(_cts)), MemberNotNullWhen(true, nameof(_execute))]
+ public bool Connected => _cts is not null & _execute is not null;
+
+ public void Open() {
+ var log = WsReceiverInflaterEventSource.Log;
+
+ ThrowIfConnected();
+ _cts = new();
+ _execute = Execute(_cts.Token);
+
+ log.Opened();
+ }
+
+ public async Task CloseAsync() {
+ var log = WsReceiverInflaterEventSource.Log;
+
+ ThrowIfDisconnected();
+ var task = _execute;
+ _cts.Cancel();
+ _cts.Dispose(); // not relly needed here
+ _cts = null;
+ _execute = null;
+
+ log.CloseBegin();
+
+ try {
+ await task.Inv();
+ } catch (OperationCanceledException) {
+ // expected on close using cts
+ } catch (WebSocketException) {
+ // expected on abort
+ }
+
+ log.CloseFinish();
+ }
+
+ [MemberNotNull(nameof(_cts)), MemberNotNull(nameof(_execute))]
+ private void ThrowIfDisconnected() {
+ if (!Connected) {
+ throw new InvalidOperationException("The connection is not open.");
+ }
+ }
+
+ private void ThrowIfConnected() {
+ if (Connected) {
+ throw new InvalidOperationException("The connection is already open");
+ }
+ }
+
+ public void Dispose() {
+ var log = WsReceiverInflaterEventSource.Log;
+
+ _cts?.Cancel();
+ _cts?.Dispose();
+ _channel.TryComplete();
+
+ log.Disposed();
+ }
+}
diff --git a/src/Ws/WsReceiverInflaterEventSource.cs b/src/Ws/WsReceiverInflaterEventSource.cs
new file mode 100644
index 00000000..4415dc16
--- /dev/null
+++ b/src/Ws/WsReceiverInflaterEventSource.cs
@@ -0,0 +1,100 @@
+using System.Diagnostics.Tracing;
+using System.Net.WebSockets;
+using System.Runtime.CompilerServices;
+
+namespace SurrealDB.Ws;
+
+[EventSource(Guid = "91a1c84b-f0aa-43c8-ad21-6ff518a8fa01", Name = "WsReceiverInflaterEventSource")]
+public sealed class WsReceiverInflaterEventSource : EventSource {
+ private WsReceiverInflaterEventSource() { }
+
+ public static WsReceiverInflaterEventSource Log { get; } = new();
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void SocketWaiting() {
+ if (IsEnabled()) {
+ SocketWaitingCore();
+ }
+ }
+
+ [Event(1, Level = EventLevel.Verbose, Message = "Waiting to receive a block from the socket")]
+ private void SocketWaitingCore() => WriteEvent(1);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void SockedReceived(WebSocketReceiveResult result) {
+ if (IsEnabled()) {
+ SockedReceivedCore(result.Count, result.EndOfMessage, result.CloseStatus is not null);
+ }
+ }
+
+ [ThreadStatic]
+ private static object[]? _socketReceivedArgs;
+ [Event(2, Level = EventLevel.Verbose, Message = "Received a block from the socket (Count = {0}, EndOfMessage = {1}, Closed = {2})")]
+ private unsafe void SockedReceivedCore(int count, bool endOfMessage, bool closed) {
+ _socketReceivedArgs ??= new object[3];
+ _socketReceivedArgs[0] = count;
+ _socketReceivedArgs[1] = endOfMessage;
+ _socketReceivedArgs[2] = closed;
+ WriteEvent(2, _socketReceivedArgs);
+ }
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void MessageReceiveFinished() {
+ if (IsEnabled()) {
+ MessageReceiveFinishedCore();
+ }
+ }
+
+ [Event(3, Level = EventLevel.Verbose, Message = "Finished receiving the message from the socket")]
+ private void MessageReceiveFinishedCore() => WriteEvent(3);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void MessagePushed() {
+ if (IsEnabled()) {
+ MessagePushedCore();
+ }
+ }
+
+ [Event(4, Level = EventLevel.Informational, Message = "Pushed the message to the channel")]
+ private void MessagePushedCore() => WriteEvent(4);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Opened() {
+ if (IsEnabled()) {
+ OpenedCore();
+ }
+ }
+
+ [Event(5, Level = EventLevel.Informational, Message = "Opened and is now pushing to the channel")]
+ private void OpenedCore() => WriteEvent(5);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void CloseBegin() {
+ if (IsEnabled()) {
+ CloseBeginCore();
+ }
+ }
+
+ [Event(6, Level = EventLevel.Informational, Message = "Closed and stopped pushing to the channel")]
+ private void CloseBeginCore() => WriteEvent(6);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void CloseFinish() {
+ if (IsEnabled()) {
+ CloseFinishCore();
+ }
+ }
+
+ [Event(7, Level = EventLevel.Informational, Message = "Closing has finished")]
+ private void CloseFinishCore() => WriteEvent(7);
+
+ [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Disposed() {
+ if (IsEnabled()) {
+ DisposedCore();
+ }
+ }
+
+ [Event(8, Level = EventLevel.Informational, Message = "Disposed")]
+ private void DisposedCore() => WriteEvent(8);
+}
diff --git a/src/Ws/WsReceiverMessageReader.cs b/src/Ws/WsReceiverMessageReader.cs
new file mode 100644
index 00000000..f59b3ab5
--- /dev/null
+++ b/src/Ws/WsReceiverMessageReader.cs
@@ -0,0 +1,204 @@
+using System.Buffers;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Net.WebSockets;
+using System.Runtime.CompilerServices;
+
+using Microsoft.IO;
+
+using SurrealDB.Common;
+
+namespace SurrealDB.Ws;
+
+public sealed class WsReceiverMessageReader : Stream {
+ private readonly BoundedChannel _channel;
+ private readonly RecyclableMemoryStream _stream;
+ private int _endOfMessage;
+
+ internal WsReceiverMessageReader(RecyclableMemoryStreamManager memoryManager, int channelCapacity) {
+ _stream = new(memoryManager);
+ _channel = BoundedChannelPool.Shared.Rent(channelCapacity);
+ _endOfMessage = 0;
+ }
+
+ public bool HasReceivedEndOfMessage => Interlocked.CompareExchange(ref _endOfMessage, 0, 0) != 0;
+
+ protected override void Dispose(bool disposing) {
+ if (!disposing) {
+ return;
+ }
+
+ _stream.Dispose();
+ _channel.Dispose();
+ }
+
+ private async ValueTask SetReceivedAsync(WebSocketReceiveResult result, CancellationToken ct) {
+ await _channel.Writer.WriteAsync(result, ct).Inv();
+ if (result.EndOfMessage) {
+ Interlocked.Exchange(ref _endOfMessage, 1);
+ }
+ }
+
+ private ValueTask ReceiveAsync(CancellationToken ct) {
+ return _channel.Reader.ReadAsync(ct);
+ }
+
+ private WebSocketReceiveResult Receive(CancellationToken ct) {
+ var t = ReceiveAsync(ct);
+ return t.IsCompleted ? t.Result : t.AsTask().Result;
+ }
+
+ internal ValueTask AppendResultAsync(ReadOnlyMemory buffer, WebSocketReceiveResult result, CancellationToken ct) {
+ ReadOnlySpan span = buffer.Span.Slice(0, result.Count);
+ lock (_stream) {
+ var pos = _stream.Position;
+ _stream.Write(span);
+ _stream.Position = pos;
+ }
+ ct.ThrowIfCancellationRequested();
+
+ return SetReceivedAsync(result, ct);
+ }
+
+#region Stream members
+
+ public override bool CanRead => true;
+ public override bool CanSeek => true;
+ public override bool CanWrite => false;
+ public override long Length {
+ get {
+ lock (_stream) {
+ return _stream.Length;
+ }
+ }
+ }
+
+ public override long Position {
+ get {
+ lock (_stream) {
+ return _stream.Position;
+ }
+ }
+ set {
+ lock (_stream) {
+ _stream.Position = value;
+ }
+ }
+ }
+ public override void Flush() {
+ ThrowCantWrite();
+ }
+
+ public override int Read(byte[] buffer, int offset, int count) {
+ return Read(buffer.AsSpan(offset, count));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public override int Read(Span buffer) {
+ return Read(buffer, default);
+ }
+
+ public int Read(Span buffer, CancellationToken ct) {
+ if (0 >= (uint)buffer.Length || ct.IsCancellationRequested) {
+ return 0;
+ }
+
+ int read;
+ lock (_stream) {
+ // attempt to read from present buffer
+ read = _stream.Read(buffer);
+ }
+ ct.ThrowIfCancellationRequested();
+
+ if (read == buffer.Length || HasReceivedEndOfMessage) {
+ return read;
+ }
+
+ WebSocketReceiveResult result;
+ do {
+ result = Receive(ct);
+ int inc;
+ lock (_stream) {
+ inc = _stream.Read(buffer.Slice(read));
+ }
+
+ ct.ThrowIfCancellationRequested();
+ Debug.Assert(inc == result.Count);
+ read += inc;
+ } while (!result.EndOfMessage);
+
+ return read;
+ }
+
+ public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) {
+ return ReadAsync(buffer.AsMemory(offset, count), ct).AsTask();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public override ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) {
+ if (0 >= (uint)buffer.Length || ct.IsCancellationRequested) {
+ return new(0);
+ }
+
+ int read;
+ lock (_stream) {
+ // attempt to read from present buffer
+ read = _stream.Read(buffer.Span);
+ }
+ ct.ThrowIfCancellationRequested();
+
+ if (read == buffer.Length || HasReceivedEndOfMessage) {
+ return new(read);
+ }
+
+ return new(ReadFromChannelAsync(buffer, read, ct));
+ }
+
+ private async Task ReadFromChannelAsync(Memory buffer, int read, CancellationToken ct) {
+ WebSocketReceiveResult? result;
+ do {
+ result = await ReceiveAsync(ct).Inv();
+ int inc;
+ lock (_stream) {
+ inc = _stream.Read(buffer.Span.Slice(read));
+ }
+
+ ct.ThrowIfCancellationRequested();
+
+ Debug.Assert(inc == result.Count);
+ read += inc;
+ } while (!result.EndOfMessage);
+
+ return read;
+ }
+
+ public override int ReadByte() {
+ Span buffer = stackalloc byte[1];
+ int read = Read(buffer);
+ return read == 0 ? -1 : buffer[0];
+ }
+
+ public override long Seek(long offset, SeekOrigin origin) {
+ lock (_stream) {
+ return _stream.Seek(offset, origin);
+ }
+ }
+
+ public override void SetLength(long value) {
+ lock (_stream) {
+ _stream.SetLength(value);
+ }
+ }
+
+ public override void Write(byte[] buffer, int offset, int count) {
+ ThrowCantWrite();
+ }
+
+
+#endregion
+
+ [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)]
+ private static void ThrowCantWrite() {
+ throw new NotSupportedException("The stream does not support writing");
+ }
+}
diff --git a/src/Ws/WsTransmitter.cs b/src/Ws/WsTransmitter.cs
new file mode 100644
index 00000000..a083c5d3
--- /dev/null
+++ b/src/Ws/WsTransmitter.cs
@@ -0,0 +1,37 @@
+using System.Net.WebSockets;
+
+using SurrealDB.Common;
+
+namespace SurrealDB.Ws;
+
+/// Sends messages from a channel to a websocket server.
+public sealed class WsTransmitter {
+ private readonly ClientWebSocket _ws;
+ private readonly int _blockSize;
+
+ public WsTransmitter(ClientWebSocket ws, int blockSize) {
+ _ws = ws;
+ _blockSize = blockSize;
+ }
+
+ public async Task SendAsync(Stream stream, CancellationToken ct) {
+ // reader is disposed by the consumer
+ using BufferStreamReader reader = new(stream, _blockSize);
+ await Consume(reader, ct).Inv();
+ }
+
+
+ private async Task Consume(BufferStreamReader reader, CancellationToken ct) {
+ bool isFinalBlock = false;
+ while (!isFinalBlock && !ct.IsCancellationRequested) {
+ var rom = await reader.ReadAsync(_blockSize, ct).Inv();
+ isFinalBlock = rom.Length != _blockSize;
+ await _ws.SendAsync(rom, WebSocketMessageType.Text, isFinalBlock, ct).Inv();
+ }
+
+ if (!isFinalBlock) {
+ // ensure that the message is always terminated
+ await _ws.SendAsync(default, WebSocketMessageType.Text, true, ct).Inv();
+ }
+ }
+}
diff --git a/src/Ws/WsTx.cs b/src/Ws/WsTx.cs
deleted file mode 100644
index 03893a29..00000000
--- a/src/Ws/WsTx.cs
+++ /dev/null
@@ -1,404 +0,0 @@
-using System.Buffers;
-using System.Diagnostics;
-using System.Diagnostics.CodeAnalysis;
-using System.Net.WebSockets;
-using System.Text.Json;
-
-using SurrealDB.Common;
-using SurrealDB.Json;
-
-namespace SurrealDB.Ws;
-
-public sealed class WsTx : IDisposable {
- private readonly ClientWebSocket _ws = new();
-
- public static int DefaultBufferSize => 16 * 1024;
-
- ///
- /// Indicates whether the client is connected or not.
- ///
- public bool Connected => _ws.State == WebSocketState.Open;
-
- public async Task Open(Uri remote, CancellationToken ct = default) {
- ThrowIfConnected();
- await _ws.ConnectAsync(remote, ct);
- }
-
- public async Task Close(CancellationToken ct = default) {
- if (_ws.State == WebSocketState.Closed) {
- return;
- }
-
- try {
- await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client disconnect", ct);
- } catch (OperationCanceledException) {
- if (ct.IsCancellationRequested) {
- // Catch any canceled exception that is generated during the close,
- // but still throw for cancellations that we requested.
- throw;
- }
- }
- }
-
- public void Dispose() {
- _ws.Dispose();
- }
-
- ///
- /// Receives a response stream from the socket.
- /// Parses the header.
- /// The body contains the result array including the end object token `[...]}`.
- ///
- public async Task<(string? id, RspHeader rsp, NtyHeader nty, Stream body)> Tr(CancellationToken ct) {
- ThrowIfDisconnected();
- // this method assumes that the header size never exceeds DefaultBufferSize!
- IMemoryOwner owner = MemoryPool.Shared.Rent(DefaultBufferSize);
- var r = await _ws.ReceiveAsync(owner.Memory, ct);
-
- if (r.Count <= 0) {
- return (default, default, default, default!);
- }
-
- // parse the header
- var (rsp, nty, off) = ParseHeader(owner.Memory.Span.Slice(0, r.Count));
- string? id = rsp.IsDefault ? nty.id : rsp.id;
- if (String.IsNullOrEmpty(id)) {
- ThrowHeaderId();
- }
- // returns a stream over the remainder of the body
- Stream body = CreateBody(r, owner, owner.Memory.Slice(off));
- return (id, rsp, nty, body);
- }
-
- private static (RspHeader rsp, NtyHeader nty, int off) ParseHeader(ReadOnlySpan utf8) {
- var (rsp, rspOff, rspErr) = RspHeader.Parse(utf8);
- if (rspErr is null) {
- return (rsp, default, (int)rspOff);
- }
- var (nty, ntyOff, ntyErr) = NtyHeader.Parse(utf8);
- if (ntyErr is null) {
- return (default, nty, (int)ntyOff);
- }
-
- throw new JsonException($"Failed to parse RspHeader or NotifyHeader: {rspErr} \n--AND--\n {ntyErr}", null, 0, Math.Max(rspOff, ntyOff));
- }
-
- private Stream CreateBody(ValueWebSocketReceiveResult res, IDisposable owner, ReadOnlyMemory rem) {
- // check if rsp is already completely in the buffer
- if (res.EndOfMessage) {
- // create a rented stream from the remainder.
- MemoryStream s = RentedMemoryStream.FromMemory(owner, rem, true, true);
- s.SetLength(res.Count);
- return s;
- }
-
- // the rsp is not recv completely!
- // create a stream wrapping the websocket
- // with the recv portion as a prefix
- Debug.Assert(res.Count == rem.Length);
- return new WsStream(owner, rem, _ws);
- }
-
- ///
- /// Sends the stream over the socket.
- ///
- ///
- /// Fast if used with a with exposed buffer!
- ///
- public async Task Tw(Stream req, CancellationToken ct) {
- ThrowIfDisconnected();
- if (req is MemoryStream ms && ms.TryGetBuffer(out ArraySegment raw)) {
- // We can obtain the raw buffer from the request, send it
- await _ws.SendAsync(raw, WebSocketMessageType.Text, true, ct);
- return;
- }
-
- using IMemoryOwner owner = MemoryPool.Shared.Rent(DefaultBufferSize);
- bool end = false;
- while (!end && !ct.IsCancellationRequested) {
- int read = await req.ReadAsync(owner.Memory, ct);
- end = read != owner.Memory.Length;
- ReadOnlyMemory used = owner.Memory.Slice(0, read);
- await _ws.SendAsync(used, WebSocketMessageType.Text, end, ct);
-
- ThrowIfDisconnected();
- ct.ThrowIfCancellationRequested();
- }
- Debug.Assert(end, "Unfinished message sent!");
- }
-
- [DoesNotReturn]
- private static void ThrowHeaderId() {
- throw new InvalidOperationException("Header has no associated id!");
- }
-
- private void ThrowIfDisconnected() {
- if (!Connected) {
- throw new InvalidOperationException("The connection is not open.");
- }
- }
-
- private void ThrowIfConnected() {
- if (Connected) {
- throw new InvalidOperationException("The connection is already open");
- }
- }
-
- [DoesNotReturn]
- private static void ThrowParseHead(string err, long off) {
- throw new JsonException(err, default, default, off);
- }
-
- public readonly record struct NtyHeader(string? id, string? method, WsClient.Error err) {
- public bool IsDefault => default == this;
-
- ///
- /// Parses the head including the result propertyname, excluding the result array.
- ///
- internal static (NtyHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) {
- Fsm fsm = new() {
- Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })),
- State = Fsms.Start,
- };
- while (fsm.MoveNext()) {}
-
- if (!fsm.Success) {
- return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}");
- }
- return (new(fsm.Id, fsm.Method, fsm.Error), default, default);
- }
-
- private enum Fsms {
- Start, // -> Prop
- Prop, // -> PropId | PropAsync | PropMethod | ProsResult
- PropId, // -> Prop | End
- PropMethod, // -> Prop | End
- PropError, // -> End
- PropParams, // -> End
- End
- }
-
- private ref struct Fsm {
- public Fsms State;
- public Utf8JsonReader Lexer;
- public string? Err;
- public bool Success;
-
- public string? Name;
- public string? Id;
- public WsClient.Error Error;
- public string? Method;
-
- public bool MoveNext() {
- return State switch {
- Fsms.Start => Start(),
- Fsms.Prop => Prop(),
- Fsms.PropId => PropId(),
- Fsms.PropMethod => PropMethod(),
- Fsms.PropError => PropError(),
- Fsms.PropParams => PropParams(),
- Fsms.End => End(),
- _ => false
- };
- }
-
- private bool Start() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) {
- Err = "Unable to read token StartObject";
- return false;
- }
-
- State = Fsms.Prop;
- return true;
-
- }
-
- private bool End() {
- Success = !String.IsNullOrEmpty(Id) && !String.IsNullOrEmpty(Method);
- return false;
- }
-
- private bool Prop() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) {
- Err = "Unable to read PropertyName";
- return false;
- }
-
- Name = Lexer.GetString();
- if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropId;
- return true;
- }
- if ("method".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropMethod;
- return true;
- }
- if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropError;
- return true;
- }
- if ("params".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropParams;
- return true;
- }
-
- Err = $"Unknown PropertyName `{Name}`";
- return false;
- }
-
- private bool PropId() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
- Err = "Unable to read `id` property value";
- return false;
- }
-
- State = Fsms.Prop;
- Id = Lexer.GetString();
- return true;
- }
-
- private bool PropError() {
- Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared);
- State = Fsms.End;
- return true;
- }
-
- private bool PropMethod() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
- Err = "Unable to read `method` property value";
- return false;
- }
-
- State = Fsms.Prop;
- Method = Lexer.GetString();
- return true;
- }
-
- private bool PropParams() {
- // Do not parse the result!
- // The complete result is not present in the buffer!
- // The result is returned as a unevaluated asynchronous stream!
- State = Fsms.End;
- return true;
- }
- }
- }
-
- public readonly record struct RspHeader(string? id, WsClient.Error err) {
- public bool IsDefault => default == this;
-
- ///
- /// Parses the head including the result propertyname, excluding the result array.
- ///
- internal static (RspHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) {
- Fsm fsm = new() {
- Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })),
- State = Fsms.Start,
- };
- while (fsm.MoveNext()) {}
-
- if (!fsm.Success) {
- return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}");
- }
- return (new(fsm.Id, fsm.Error), default, default);
- }
-
- private enum Fsms {
- Start, // -> Prop
- Prop, // -> PropId | PropError | ProsResult
- PropId, // -> Prop | End
- PropError, // -> End
- PropResult, // -> End
- End
- }
-
- private ref struct Fsm {
- public Fsms State;
- public Utf8JsonReader Lexer;
- public string? Err;
- public bool Success;
-
- public string? Name;
- public string? Id;
- public WsClient.Error Error;
-
- public bool MoveNext() {
- return State switch {
- Fsms.Start => Start(),
- Fsms.Prop => Prop(),
- Fsms.PropId => PropId(),
- Fsms.PropError => PropError(),
- Fsms.PropResult => PropResult(),
- Fsms.End => End(),
- _ => false
- };
- }
-
- private bool Start() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) {
- Err = "Unable to read token StartObject";
- return false;
- }
-
- State = Fsms.Prop;
- return true;
-
- }
-
- private bool End() {
- Success = !String.IsNullOrEmpty(Id);
- return false;
- }
-
- private bool Prop() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) {
- Err = "Unable to read PropertyName";
- return false;
- }
-
- Name = Lexer.GetString();
- if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropId;
- return true;
- }
- if ("result".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropResult;
- return true;
- }
- if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) {
- State = Fsms.PropError;
- return true;
- }
-
- Err = $"Unknown PropertyName `{Name}`";
- return false;
- }
-
- private bool PropId() {
- if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) {
- Err = "Unable to read `id` property value";
- return false;
- }
-
- State = Fsms.Prop;
- Id = Lexer.GetString();
- return true;
- }
-
- private bool PropError() {
- Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared);
- State = Fsms.End;
- return true;
- }
-
-
- private bool PropResult() {
- // Do not parse the result!
- // The complete result is not present in the buffer!
- // The result is returned as a unevaluated asynchronous stream!
- State = Fsms.End;
- return true;
- }
- }
- }
-}
diff --git a/tests/Driver.Tests/DatabaseTests.cs b/tests/Driver.Tests/DatabaseTests.cs
index fc44e8c8..439d530d 100644
--- a/tests/Driver.Tests/DatabaseTests.cs
+++ b/tests/Driver.Tests/DatabaseTests.cs
@@ -10,98 +10,81 @@ public sealed class RestDatabaseTest : DatabaseTestDriver {
[Collection("SurrealDBRequired")]
public abstract class DatabaseTestDriver
- : DriverBase
where T : IDatabase, IDisposable, new() {
- protected override async Task Run(T db) {
- db.GetConfig().Should().BeEquivalentTo(TestHelper.Default);
-
- var useResp = await db.Use(TestHelper.Database, TestHelper.Namespace);
- TestHelper.AssertOk(useResp);
- var infoResp = await db.Info();
- TestHelper.AssertOk(infoResp);
-
- var signInStatus = await db.Signin(new RootAuth(TestHelper.User, TestHelper.Pass));
-
- TestHelper.AssertOk(signInStatus);
-
- (string id1, string id2) = ("id1", "id2");
- var res1 = await db.Create(
- "person",
- new {
- Title = "Founder & CEO",
- Name = new { First = "Tobie", Last = "Morgan Hitchcock", },
- Marketing = true,
- Identifier = ThreadRng.Shared.Next(),
- }
- );
-
- TestHelper.AssertOk(res1);
+ [Fact]
+ public async Task TestSuite() => await DbHandle.WithDatabase(
+ async db => {
+ db.GetConfig().Should().BeEquivalentTo(TestHelper.Default);
- var res2 = await db.Create(
- "person",
- new {
- Title = "Contributor",
- Name = new { First = "Prophet", Last = "Lamb", },
- Marketing = false,
- Identifier = ThreadRng.Shared.Next(),
- }
- );
+ var useResp = await db.Use(TestHelper.Database, TestHelper.Namespace);
+ TestHelper.AssertOk(useResp);
+ var infoResp = await db.Info();
+ TestHelper.AssertOk(infoResp);
- TestHelper.AssertOk(res2);
+ var signInStatus = await db.Signin(new RootAuth(TestHelper.User, TestHelper.Pass));
- Thing thing2 = ("person", id2);
- TestHelper.AssertOk(await db.Update(thing2, new { Marketing = false, }));
+ TestHelper.AssertOk(signInStatus);
- TestHelper.AssertOk(await db.Select(thing2));
+ (string id1, string id2) = ("id1", "id2");
+ var res1 = await db.Create(
+ "person",
+ new {
+ Title = "Founder & CEO",
+ Name = new { First = "Tobie", Last = "Morgan Hitchcock", },
+ Marketing = true,
+ Identifier = ThreadRng.Shared.Next(),
+ }
+ );
- TestHelper.AssertOk(await db.Delete(thing2));
+ TestHelper.AssertOk(res1);
- Thing thing1 = ("person", id1);
- TestHelper.AssertOk(
- await db.Change(
- thing1,
+ var res2 = await db.Create(
+ "person",
new {
- Title = "Founder & CEO",
- Name = new { First = "Tobie", Last = "Hitchcock Morgan", },
+ Title = "Contributor",
+ Name = new { First = "Prophet", Last = "Lamb", },
Marketing = false,
Identifier = ThreadRng.Shared.Next(),
}
- )
- );
+ );
- string newTitle = "Founder & CEO & Ruler of the known free World";
- var modifyResp = await db.Modify(thing1, new[] {
- Patch.Replace("/Title", newTitle),
- });
- TestHelper.AssertOk(modifyResp);
+ TestHelper.AssertOk(res2);
- TestHelper.AssertOk(await db.Let("tbl", "person"));
+ Thing thing2 = ("person", id2);
+ TestHelper.AssertOk(await db.Update(thing2, new { Marketing = false, }));
- var queryResp = await db.Query(
- "SELECT $props FROM $tbl WHERE title = $title",
- new Dictionary { ["props"] = "title, identifier", ["tbl"] = "person", ["title"] = newTitle, }
- );
+ TestHelper.AssertOk(await db.Select(thing2));
- TestHelper.AssertOk(queryResp);
+ TestHelper.AssertOk(await db.Delete(thing2));
- await db.Close();
- }
-}
+ Thing thing1 = ("person", id1);
+ TestHelper.AssertOk(
+ await db.Change(
+ thing1,
+ new {
+ Title = "Founder & CEO",
+ Name = new { First = "Tobie", Last = "Hitchcock Morgan", },
+ Marketing = false,
+ Identifier = ThreadRng.Shared.Next(),
+ }
+ )
+ );
-///
-/// The test driver executes the testsuite on the client.
-///
-[Collection("SurrealDBRequired")]
-public abstract class DriverBase
- where T : IDatabase, IDisposable, new() {
+ string newTitle = "Founder & CEO & Ruler of the known free World";
+ var modifyResp = await db.Modify(thing1, new[] { Patch.Replace("/Title", newTitle), });
+ TestHelper.AssertOk(modifyResp);
+ TestHelper.AssertOk(await db.Let("tbl", "person"));
- [Fact]
- public async Task TestSuite() {
- using var handle = await DbHandle.Create();
- await Run(handle.Database);
- }
+ var queryResp = await db.Query(
+ "SELECT $props FROM $tbl WHERE title = $title",
+ new Dictionary {
+ ["props"] = "title, identifier", ["tbl"] = "person", ["title"] = newTitle,
+ }
+ );
- protected abstract Task Run(T db);
+ TestHelper.AssertOk(queryResp);
+ }
+ );
}
diff --git a/tests/Driver.Tests/Queries/GeneralQueryTests.cs b/tests/Driver.Tests/Queries/GeneralQueryTests.cs
index 06852cb4..ccd9ac50 100644
--- a/tests/Driver.Tests/Queries/GeneralQueryTests.cs
+++ b/tests/Driver.Tests/Queries/GeneralQueryTests.cs
@@ -304,7 +304,7 @@ public async Task SimultaneousDatabaseOperations() => await DbHandle.WithData
async db => {
var taskCount = 50;
var tasks = Enumerable.Range(0, taskCount).Select(i => DbTask(i, db));
- await Task.WhenAll(tasks).ConfigureAwait(false);
+ await Task.WhenAll(tasks).Inv();
}
);
@@ -314,17 +314,17 @@ private async Task DbTask(int i, T db) {
var expectedResult = new TestObject(i, i);
Thing thing = Thing.From("object", expectedResult.Key);
- var createResponse = await db.Create(thing, expectedResult).ConfigureAwait(false);
+ var createResponse = await db.Create(thing, expectedResult).Inv();
AssertResponse(createResponse, expectedResult);
Logger.WriteLine($"Create {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}");
- var selectResponse = await db.Select(thing).ConfigureAwait(false);
+ var selectResponse = await db.Select(thing).Inv();
AssertResponse(selectResponse, expectedResult);
Logger.WriteLine($"Select {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}");
string sql = "SELECT * FROM $record";
Dictionary param = new() { ["record"] = thing };
- var queryResponse = await db.Query(sql, param).ConfigureAwait(false);
+ var queryResponse = await db.Query(sql, param).Inv();
AssertResponse(queryResponse, expectedResult);
Logger.WriteLine($"Query {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}");
diff --git a/tests/Driver.Tests/Roundtrip/RoundTripTests.cs b/tests/Driver.Tests/Roundtrip/RoundTripTests.cs
index 0d827678..51a2d74a 100644
--- a/tests/Driver.Tests/Roundtrip/RoundTripTests.cs
+++ b/tests/Driver.Tests/Roundtrip/RoundTripTests.cs
@@ -1,5 +1,3 @@
-using System.Collections;
-
using SurrealDB.Json;
using SurrealDB.Models.Result;
diff --git a/tests/Shared/ConsoleOutEventListener.cs b/tests/Shared/ConsoleOutEventListener.cs
new file mode 100644
index 00000000..117cee60
--- /dev/null
+++ b/tests/Shared/ConsoleOutEventListener.cs
@@ -0,0 +1,16 @@
+using System.Diagnostics.Tracing;
+
+namespace SurrealDB.Shared.Tests;
+
+public class ConsoleOutEventListener : EventListener {
+
+ protected override void OnEventWritten(EventWrittenEventArgs eventData) {
+ string message = $"{eventData.TimeStamp:T} - {eventData.EventSource.Name} - {eventData.EventName} - {eventData.OSThreadId}: {eventData.Message}";
+ if (eventData.Payload is null) {
+ Console.WriteLine(message);
+ } else {
+ Console.WriteLine(message, eventData.Payload.ToArray());
+ }
+ base.OnEventWritten(eventData);
+ }
+}
diff --git a/tests/Shared/DbHandle.cs b/tests/Shared/DbHandle.cs
index a6ca36b6..0581ca00 100644
--- a/tests/Shared/DbHandle.cs
+++ b/tests/Shared/DbHandle.cs
@@ -1,4 +1,7 @@
+using System.Diagnostics.Tracing;
+
using SurrealDB.Abstractions;
+using SurrealDB.Ws;
namespace SurrealDB.Shared.Tests;
@@ -20,22 +23,28 @@ public static async Task> Create() {
[DebuggerStepThrough]
public static async Task WithDatabase(Func action) {
+ // enable console logging for events
+ using TestEventListener l = new();
+ l.EnableEvents(WsReceiverInflaterEventSource.Log, EventLevel.LogAlways);
+ l.EnableEvents(WsReceiverDeflaterEventSource.Log, EventLevel.LogAlways);
+ // connect to the database
using DbHandle db = await Create();
+ // execute test methods
await action(db.Database);
}
public T Database { get; }
- ~DbHandle() {
- Dispose();
- }
-
public void Dispose() {
Process? p = _process;
_process = null;
if (p is not null) {
- Database.Dispose();
- p.Kill();
+ DisposeActual(p);
}
}
+
+ private void DisposeActual(Process p) {
+ Database.Dispose();
+ p.Kill();
+ }
}
diff --git a/tests/Shared/TestEventListener.cs b/tests/Shared/TestEventListener.cs
new file mode 100644
index 00000000..eaf5ab89
--- /dev/null
+++ b/tests/Shared/TestEventListener.cs
@@ -0,0 +1,51 @@
+using System.Diagnostics.Tracing;
+using System.Runtime.ExceptionServices;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace SurrealDB.Shared.Tests;
+
+public class TestEventListener : EventListener {
+ private StreamWriter? _writer = new(File.OpenWrite($"./{nameof(TestEventListener)}_{DateTimeOffset.UtcNow:s}.log"));
+
+ public TestEventListener() {
+ AppDomain.CurrentDomain.FirstChanceException += FirstChanceException;
+ }
+
+ private void FirstChanceException(object? sender, FirstChanceExceptionEventArgs e) {
+ ValueStringBuilder sb = new(stackalloc char[512]);
+ sb.Append(DateTime.UtcNow.ToString("T"));
+ sb.Append(" Error: ");
+ sb.Append(e.Exception.ToString());
+
+ WriteLine(sb.ToString());
+ }
+
+ protected override void OnEventWritten(EventWrittenEventArgs eventData) {
+ string message = $"{eventData.TimeStamp:T} {eventData.EventSource.Name}: {eventData.Message} (E:{eventData.EventName}, T:{eventData.OSThreadId})";
+ if (eventData.Payload is null) {
+ WriteLine(message);
+ } else {
+ object?[] parameters = eventData.Payload.ToArray();
+ WriteLine(message, parameters);
+ }
+ base.OnEventWritten(eventData);
+ }
+
+ private void WriteLine(string message) {
+ _writer?.WriteLine(message);
+ }
+
+ private void WriteLine(string format, object?[] args) {
+ _writer?.WriteLine(format, args);
+ }
+
+ public override void Dispose() {
+ var writer = Interlocked.Exchange(ref _writer, null);
+ if (writer is not null) {
+ writer.Dispose();
+ AppDomain.CurrentDomain.FirstChanceException -= FirstChanceException;
+ }
+ base.Dispose();
+ }
+}