diff --git a/.editorconfig b/.editorconfig index 63f9d967..37b7467e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -20,7 +20,7 @@ csharp_style_var_for_built_in_types = false:none csharp_style_var_when_type_is_apparent = false:none dotnet_naming_rule.constants_rule.import_to_resharper = as_predefined dotnet_naming_rule.constants_rule.severity = warning -dotnet_naming_rule.constants_rule.style = all_upper_style +dotnet_naming_rule.constants_rule.style = upper_camel_case_style dotnet_naming_rule.constants_rule.symbols = constants_symbols dotnet_naming_rule.private_constants_rule.import_to_resharper = as_predefined dotnet_naming_rule.private_constants_rule.resharper_style = AaBb, NUM_ + AA_BB diff --git a/src/Common/ArrayBuilder.cs b/src/Common/ArrayBuilder.cs index 87bf5ef4..5c07433a 100644 --- a/src/Common/ArrayBuilder.cs +++ b/src/Common/ArrayBuilder.cs @@ -4,6 +4,7 @@ // Modified for generic types and not backed by a pool +using System.Buffers; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -63,12 +64,6 @@ public ref T this[int index] } } - public override string ToString() - { - string s = Raw.Slice(0, _pos).ToString(); - return s; - } - /// Returns the underlying storage of the builder. public Span Raw => _array.AsSpan(); @@ -185,3 +180,208 @@ private void Grow(int additionalCapacityBeyondPos) _array = newArray; } } + +internal ref struct PoolArrayBuilder +{ + private T[]? _array; + private Span _raw; + private int _pos; + + public PoolArrayBuilder(Span initialBuffer) + { + _array = null; + _raw = initialBuffer; + _pos = 0; + } + + public PoolArrayBuilder(int initialCapacity) + { + _array = ArrayPool.Shared.Rent(initialCapacity); + _raw = _array; + _pos = 0; + } + + public int Length + { + get => _pos; + set + { + Debug.Assert(value >= 0); + Debug.Assert(value <= _raw.Length); + _pos = value; + } + } + + public bool IsEmpty => 0 >= (uint)_pos; + + public bool IsDefault => _raw.IsEmpty && _array is null && _pos == 0; + + public int Capacity => _raw.Length; + + public void EnsureCapacity(int capacity) + { + // This is not expected to be called this with negative capacity + Debug.Assert(capacity >= 0); + + if ((uint)capacity > (uint)_raw.Length) + Grow(capacity - _pos); + } + + /// + /// Get a pinnable reference to the builder. + /// Does not ensure there is a null char after + /// This overload is pattern matched in the C# 7.3+ compiler so you can omit + /// the explicit method call, and write eg "fixed (char* c = builder)" + /// + public ref T GetPinnableReference() + { + return ref MemoryMarshal.GetReference(_raw); + } + + public ref T this[int index] + { + get + { + Debug.Assert(index < _pos); + return ref _raw[index]; + } + } + + /// Returns the underlying storage of the builder. + public Span Raw => _raw; + + public ReadOnlySpan AsSpan() => _raw.Slice(0, _pos); + public ReadOnlySpan AsSpan(int start) => _raw.Slice(start, _pos - start); + public ReadOnlySpan AsSpan(int start, int length) => _raw.Slice(start, length); + + public ArraySegment AsSegment() => new(_array ?? Array.Empty()); + public ArraySegment AsSegment(int start) => new(_array ?? Array.Empty(), start, _pos - start); + public ArraySegment AsSegment(int start, int length) => new(_array ?? Array.Empty(), start, length); + + public bool TryCopyTo(Span destination, out int written) { + if (_raw.Slice(0, _pos).TryCopyTo(destination)) + { + written = _pos; + return true; + } + + written = 0; + return false; + } + + public void Insert(int index, in T value, int count) + { + if (_pos > _raw.Length - count) + { + Grow(count); + } + + int remaining = _pos - index; + _raw.Slice(index, remaining).CopyTo(_raw.Slice(index + count)); + _raw.Slice(index, count).Fill(value); + _pos += count; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(T c) + { + int pos = _pos; + if ((uint)pos < (uint)_raw.Length) + { + _raw[pos] = c; + _pos = pos + 1; + } + else + { + GrowAndAppend(c); + } + } + + public void Append(in T c, int count) + { + if (_pos > _raw.Length - count) + { + Grow(count); + } + + Span dst = _raw.Slice(_pos, count); + for (int i = 0; i < dst.Length; i++) + { + dst[i] = c; + } + _pos += count; + } + + public void Append(ReadOnlySpan value) + { + int pos = _pos; + if (pos > _raw.Length - value.Length) + { + Grow(value.Length); + } + + value.CopyTo(_raw.Slice(_pos)); + _pos += value.Length; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AppendSpan(int length) + { + int origPos = _pos; + if (origPos > _raw.Length - length) + { + Grow(length); + } + + _pos = origPos + length; + return _raw.Slice(origPos, length); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowAndAppend(T c) + { + Grow(1); + Append(c); + } + + /// + /// Resize the internal buffer either by doubling current buffer size or + /// by adding to + /// whichever is greater. + /// + /// + /// Number of chars requested beyond current position. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void Grow(int additionalCapacityBeyondPos) + { + Debug.Assert(additionalCapacityBeyondPos > 0); + Debug.Assert(_pos > _raw.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); + + T[] newArray = ArrayPool.Shared.Rent( + (int)Math.Max((uint)(_pos + additionalCapacityBeyondPos), (uint)_raw.Length * 2) + ); + _raw.Slice(0, _pos).CopyTo(newArray); + _raw = newArray; + if (_array is not null) { + ArrayPool.Shared.Return(_array); + } + _array = newArray; + } + + public void Dispose() { + var array = _array; + _array = null; + _raw = default; + if (array is not null) { + ArrayPool.Shared.Return(array); + } + } + + public T[] ToArray() { + var array = new T[_pos]; + _raw.Slice(0, _pos).CopyTo(array.AsSpan(0, _pos)); + Dispose(); + return array; + } +} diff --git a/src/Common/AssemblyInfo.cs b/src/Common/AssemblyInfo.cs index 8cfad412..4b93d270 100644 --- a/src/Common/AssemblyInfo.cs +++ b/src/Common/AssemblyInfo.cs @@ -1,5 +1,6 @@ using System.Runtime.CompilerServices; +[assembly: CLSCompliant(true)] [assembly: InternalsVisibleTo("SurrealDB.Abstractions")] [assembly: InternalsVisibleTo("SurrealDB.Configuration")] [assembly: InternalsVisibleTo("SurrealDB.Driver.Rest")] diff --git a/src/Common/BitOperations.cs b/src/Common/BitOperations.cs new file mode 100644 index 00000000..3ffd9620 --- /dev/null +++ b/src/Common/BitOperations.cs @@ -0,0 +1,699 @@ +// BitOperations without intrinsics for netstandard21 support +#if !NETCOREAPP3_0_OR_GREATER +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using SurrealDB.Common; + +// ReSharper disable once CheckNamespace +namespace System.Numerics; + +public static class BitOperations +{ + // C# no-alloc optimization that directly wraps the data section of the dll (similar to string constants) + // https://github.com/dotnet/roslyn/pull/24621 + + private static ReadOnlySpan TrailingZeroCountDeBruijn => new byte[32] + { + 00, 01, 28, 02, 29, 14, 24, 03, + 30, 22, 20, 15, 25, 17, 04, 08, + 31, 27, 13, 23, 21, 19, 16, 07, + 26, 12, 18, 06, 11, 05, 10, 09 + }; + + private static ReadOnlySpan Log2DeBruijn => new byte[32] + { + 00, 09, 01, 10, 13, 21, 02, 29, + 11, 14, 16, 18, 22, 25, 03, 30, + 08, 12, 20, 28, 15, 17, 24, 07, + 19, 27, 23, 06, 26, 05, 04, 31 + }; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsPow2(int value) => (value & (value - 1)) == 0 && value > 0; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static bool IsPow2(uint value) => (value & (value - 1)) == 0 && value != 0; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsPow2(long value) => (value & (value - 1)) == 0 && value > 0; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static bool IsPow2(ulong value) => (value & (value - 1)) == 0 && value != 0; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsPow2(nint value) => (value & (value - 1)) == 0 && value > 0; + + /// + /// Evaluate whether a given integral value is a power of 2. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static bool IsPow2(nuint value) => (value & (value - 1)) == 0 && value != 0; + + /// Round the given integral value up to a power of 2. + /// The value. + /// + /// The smallest power of 2 which is greater than or equal to . + /// If is 0 or the result overflows, returns 0. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static uint RoundUpToPowerOf2(uint value) + { + // Based on https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + --value; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + return value + 1; + } + + /// + /// Round the given integral value up to a power of 2. + /// + /// The value. + /// + /// The smallest power of 2 which is greater than or equal to . + /// If is 0 or the result overflows, returns 0. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static ulong RoundUpToPowerOf2(ulong value) + { + // Based on https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + --value; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return value + 1; + } + + /// + /// Round the given integral value up to a power of 2. + /// + /// The value. + /// + /// The smallest power of 2 which is greater than or equal to . + /// If is 0 or the result overflows, returns 0. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static nuint RoundUpToPowerOf2(nuint value) + { +#if TARGET_64BIT + return (nuint)RoundUpToPowerOf2((ulong)value); +#else + return (nuint)RoundUpToPowerOf2((uint)value); +#endif + } + + /// + /// Count the number of leading zero bits in a mask. + /// Similar in behavior to the x86 instruction LZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int LeadingZeroCount(uint value) + { + // Unguarded fallback contract is 0->31, BSR contract is 0->undefined + if (value == 0) + { + return 32; + } + + return 31 ^ Log2SoftwareFallback(value); + } + + /// + /// Count the number of leading zero bits in a mask. + /// Similar in behavior to the x86 instruction LZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int LeadingZeroCount(ulong value) + { + uint hi = (uint)(value >> 32); + + if (hi == 0) + { + return 32 + LeadingZeroCount((uint)value); + } + + return LeadingZeroCount(hi); + } + + /// + /// Count the number of leading zero bits in a mask. + /// Similar in behavior to the x86 instruction LZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int LeadingZeroCount(nuint value) + { +#if TARGET_64BIT + return LeadingZeroCount((ulong)value); +#else + return LeadingZeroCount((uint)value); +#endif + } + + /// + /// Returns the integer (floor) log of the specified value, base 2. + /// Note that by convention, input value 0 returns 0 since log(0) is undefined. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int Log2(uint value) + { + // The 0->0 contract is fulfilled by setting the LSB to 1. + // Log(1) is 0, and setting the LSB for values > 1 does not change the log2 result. + value |= 1; + + // Fallback contract is 0->0 + return Log2SoftwareFallback(value); + } + + /// + /// Returns the integer (floor) log of the specified value, base 2. + /// Note that by convention, input value 0 returns 0 since log(0) is undefined. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int Log2(ulong value) + { + value |= 1; + + uint hi = (uint)(value >> 32); + + if (hi == 0) + { + return Log2((uint)value); + } + + return 32 + Log2(hi); + } + + /// + /// Returns the integer (floor) log of the specified value, base 2. + /// Note that by convention, input value 0 returns 0 since log(0) is undefined. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int Log2(nuint value) + { +#if TARGET_64BIT + return Log2((ulong)value); +#else + return Log2((uint)value); +#endif + } + + /// + /// Returns the integer (floor) log of the specified value, base 2. + /// Note that by convention, input value 0 returns 0 since Log(0) is undefined. + /// Does not directly use any hardware intrinsics, nor does it incur branching. + /// + /// The value. + private static int Log2SoftwareFallback(uint value) + { + // No AggressiveInlining due to large method size + // Has conventional contract 0->0 (Log(0) is undefined) + + // Fill trailing zeros with ones, eg 00010010 becomes 00011111 + value |= value >> 01; + value |= value >> 02; + value |= value >> 04; + value |= value >> 08; + value |= value >> 16; + + // uint.MaxValue >> 27 is always in range [0 - 31] so we use Unsafe.AddByteOffset to avoid bounds check + return Unsafe.AddByteOffset( + // Using deBruijn sequence, k=2, n=5 (2^5=32) : 0b_0000_0111_1100_0100_1010_1100_1101_1101u + ref MemoryMarshal.GetReference(Log2DeBruijn), + // uint|long -> IntPtr cast on 32-bit platforms does expensive overflow checks not needed here + (IntPtr)(int)((value * 0x07C4ACDDu) >> 27)); + } + + /// Returns the integer (ceiling) log of the specified value, base 2. + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int Log2Ceiling(uint value) + { + int result = Log2(value); + if (PopCount(value) != 1) + { + result++; + } + return result; + } + + /// Returns the integer (ceiling) log of the specified value, base 2. + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int Log2Ceiling(ulong value) + { + int result = Log2(value); + if (PopCount(value) != 1) + { + result++; + } + return result; + } + + /// + /// Returns the population count (number of bits set) of a mask. + /// Similar in behavior to the x86 instruction POPCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int PopCount(uint value) + { + const uint c1 = 0x_55555555u; + const uint c2 = 0x_33333333u; + const uint c3 = 0x_0F0F0F0Fu; + const uint c4 = 0x_01010101u; + + value -= (value >> 1) & c1; + value = (value & c2) + ((value >> 2) & c2); + value = (((value + (value >> 4)) & c3) * c4) >> 24; + + return (int)value; + } + + /// + /// Returns the population count (number of bits set) of a mask. + /// Similar in behavior to the x86 instruction POPCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int PopCount(ulong value) + { +#if TARGET_32BIT + return PopCount((uint)value) // lo + + PopCount((uint)(value >> 32)); // hi +#else + const ulong c1 = 0x_55555555_55555555ul; + const ulong c2 = 0x_33333333_33333333ul; + const ulong c3 = 0x_0F0F0F0F_0F0F0F0Ful; + const ulong c4 = 0x_01010101_01010101ul; + + value -= (value >> 1) & c1; + value = (value & c2) + ((value >> 2) & c2); + value = (((value + (value >> 4)) & c3) * c4) >> 56; + + return (int)value; +#endif + } + + /// + /// Returns the population count (number of bits set) of a mask. + /// Similar in behavior to the x86 instruction POPCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int PopCount(nuint value) + { +#if TARGET_64BIT + return PopCount((ulong)value); +#else + return PopCount((uint)value); +#endif + } + + /// + /// Count the number of trailing zero bits in an integer value. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int TrailingZeroCount(int value) + => TrailingZeroCount((uint)value); + + /// + /// Count the number of trailing zero bits in an integer value. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [CLSCompliant(false)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int TrailingZeroCount(uint value) + { + // Unguarded fallback contract is 0->0, BSF contract is 0->undefined + if (value == 0) + { + return 32; + } + + // uint.MaxValue >> 27 is always in range [0 - 31] so we use Unsafe.AddByteOffset to avoid bounds check + return Unsafe.AddByteOffset( + // Using deBruijn sequence, k=2, n=5 (2^5=32) : 0b_0000_0111_0111_1100_1011_0101_0011_0001u + ref MemoryMarshal.GetReference(TrailingZeroCountDeBruijn), + // uint|long -> IntPtr cast on 32-bit platforms does expensive overflow checks not needed here + (IntPtr)(int)(((value & (uint)-(int)value) * 0x077CB531u) >> 27)); // Multi-cast mitigates redundant conv.u8 + } + + /// + /// Count the number of trailing zero bits in a mask. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int TrailingZeroCount(long value) + => TrailingZeroCount((ulong)value); + + /// + /// Count the number of trailing zero bits in a mask. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int TrailingZeroCount(ulong value) + { + uint lo = (uint)value; + + if (lo == 0) + { + return 32 + TrailingZeroCount((uint)(value >> 32)); + } + + return TrailingZeroCount(lo); + } + + /// + /// Count the number of trailing zero bits in a mask. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int TrailingZeroCount(nint value) + => TrailingZeroCount((nuint)value); + + /// + /// Count the number of trailing zero bits in a mask. + /// Similar in behavior to the x86 instruction TZCNT. + /// + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static int TrailingZeroCount(nuint value) + { +#if TARGET_64BIT + return TrailingZeroCount((ulong)value); +#else + return TrailingZeroCount((uint)value); +#endif + } + + /// + /// Rotates the specified value left by the specified number of bits. + /// Similar in behavior to the x86 instruction ROL. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static uint RotateLeft(uint value, int offset) + => (value << offset) | (value >> (32 - offset)); + + /// + /// Rotates the specified value left by the specified number of bits. + /// Similar in behavior to the x86 instruction ROL. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..63] is treated as congruent mod 64. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static ulong RotateLeft(ulong value, int offset) + => (value << offset) | (value >> (64 - offset)); + + /// + /// Rotates the specified value left by the specified number of bits. + /// Similar in behavior to the x86 instruction ROL. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32 on a 32-bit process, + /// and any value outside the range [0..63] is treated as congruent mod 64 on a 64-bit process. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static nuint RotateLeft(nuint value, int offset) + { +#if TARGET_64BIT + return (nuint)RotateLeft((ulong)value, offset); +#else + return (nuint)RotateLeft((uint)value, offset); +#endif + } + + /// + /// Rotates the specified value right by the specified number of bits. + /// Similar in behavior to the x86 instruction ROR. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static uint RotateRight(uint value, int offset) + => (value >> offset) | (value << (32 - offset)); + + /// + /// Rotates the specified value right by the specified number of bits. + /// Similar in behavior to the x86 instruction ROR. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..63] is treated as congruent mod 64. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static ulong RotateRight(ulong value, int offset) + => (value >> offset) | (value << (64 - offset)); + + /// + /// Rotates the specified value right by the specified number of bits. + /// Similar in behavior to the x86 instruction ROR. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32 on a 32-bit process, + /// and any value outside the range [0..63] is treated as congruent mod 64 on a 64-bit process. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [CLSCompliant(false)] + public static nuint RotateRight(nuint value, int offset) + { +#if TARGET_64BIT + return (nuint)RotateRight((ulong)value, offset); +#else + return (nuint)RotateRight((uint)value, offset); +#endif + } + + /// + /// Accumulates the CRC (Cyclic redundancy check) checksum. + /// + /// The base value to calculate checksum on + /// The data for which to compute the checksum + /// The CRC-checksum + [CLSCompliant(false)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Crc32C(uint crc, byte data) + { + return Crc32Fallback.Crc32C(crc, data); + } + + /// + /// Accumulates the CRC (Cyclic redundancy check) checksum. + /// + /// The base value to calculate checksum on + /// The data for which to compute the checksum + /// The CRC-checksum + [CLSCompliant(false)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Crc32C(uint crc, ushort data) + { + return Crc32Fallback.Crc32C(crc, data); + } + + /// + /// Accumulates the CRC (Cyclic redundancy check) checksum. + /// + /// The base value to calculate checksum on + /// The data for which to compute the checksum + /// The CRC-checksum + [CLSCompliant(false)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Crc32C(uint crc, uint data) + { + return Crc32Fallback.Crc32C(crc, data); + } + + /// + /// Accumulates the CRC (Cyclic redundancy check) checksum. + /// + /// The base value to calculate checksum on + /// The data for which to compute the checksum + /// The CRC-checksum + [CLSCompliant(false)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint Crc32C(uint crc, ulong data) + { + return Crc32Fallback.Crc32C(crc, data); + } + + private static class Crc32Fallback + { + // Pre-computed CRC-32 transition table. + // While this implementation is based on the Castagnoli CRC-32 polynomial (CRC-32C), + // x32 + x28 + x27 + x26 + x25 + x23 + x22 + x20 + x19 + x18 + x14 + x13 + x11 + x10 + x9 + x8 + x6 + x0, + // this version uses reflected bit ordering, so 0x1EDC6F41 becomes 0x82F63B78u + private static readonly uint[] s_crcTable = Crc32ReflectedTable.Generate(0x82F63B78u); + + internal static uint Crc32C(uint crc, byte data) + { + ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable); + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8); + + return crc; + } + + internal static uint Crc32C(uint crc, ushort data) + { + ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable); + + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8); + data >>= 8; + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8); + + return crc; + } + + internal static uint Crc32C(uint crc, uint data) + { + ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable); + return Crc32CCore(ref lookupTable, crc, data); + } + + internal static uint Crc32C(uint crc, ulong data) + { + ref uint lookupTable = ref MemoryHelper.GetArrayDataReference(s_crcTable); + + crc = Crc32CCore(ref lookupTable, crc, (uint)data); + data >>= 32; + crc = Crc32CCore(ref lookupTable, crc, (uint)data); + + return crc; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint Crc32CCore(ref uint lookupTable, uint crc, uint data) + { + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8); + data >>= 8; + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8); + data >>= 8; + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ (byte)data)) ^ (crc >> 8); + data >>= 8; + crc = Unsafe.Add(ref lookupTable, (nint)(byte)(crc ^ data)) ^ (crc >> 8); + + return crc; + } + } + + /// + /// Reset the lowest significant bit in the given value + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static uint ResetLowestSetBit(uint value) + { + // It's lowered to BLSR on x86 + return value & (value - 1); + } + + /// + /// Reset specific bit in the given value + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static uint ResetBit(uint value, int bitPos) + { + // TODO: Recognize BTR on x86 and LSL+BIC on ARM + return value & ~(uint)(1 << bitPos); + } +} +#endif + + +internal static class Crc32ReflectedTable +{ + internal static uint[] Generate(uint reflectedPolynomial) + { + uint[] table = new uint[256]; + + for (int i = 0; i < 256; i++) + { + uint val = (uint)i; + + for (int j = 0; j < 8; j++) + { + if ((val & 0b0000_0001) == 0) + { + val >>= 1; + } + else + { + val = (val >> 1) ^ reflectedPolynomial; + } + } + + table[i] = val; + } + + return table; + } +} diff --git a/src/Common/BufferStreamReader.cs b/src/Common/BufferStreamReader.cs new file mode 100644 index 00000000..942c89e2 --- /dev/null +++ b/src/Common/BufferStreamReader.cs @@ -0,0 +1,113 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +/// Allows reading a stream efficiently +public struct BufferStreamReader : IDisposable { + private Stream? _arbitraryStream; + private MemoryStream? _memoryStream; + private readonly int _bufferSize; + private byte[]? _poolArray; + + private BufferStreamReader(Stream? arbitraryStream, MemoryStream? memoryStream, int bufferSize) { + _arbitraryStream = arbitraryStream; + _memoryStream = memoryStream; + _bufferSize = bufferSize; + _poolArray = null; + } + + public Stream Stream => _memoryStream ?? _arbitraryStream!; + + public BufferStreamReader(Stream stream, int bufferSize) { + ThrowArgIfStreamCantRead(stream); + this = stream switch { + // inhering from ms doesnt guarantee a good GetBuffer impl, such as RecyclableMemoryStream, + // therefore we have to check the exact runtime type. + MemoryStream ms when ms.GetType() == typeof(MemoryStream) => new(null, ms, bufferSize), + _ => new(stream, null, bufferSize) + }; + } + + /// Reads up to bytes from the underlying . + /// The expected number of bytes to read + /// The cancellation token + /// The context bound memory representing the bytes read. + /// The returned memory is invalid outside this instance. Do not reference the memory outside of the scope! + public ValueTask> ReadAsync(int expectedSize, CancellationToken ct = default) { + var memoryStream = _memoryStream; + var stream = _arbitraryStream; + ThrowIfNull(stream is null & memoryStream is null); + + if (memoryStream is not null) { + if (memoryStream.TryReadToBuffer(expectedSize, out ReadOnlyMemory read)) { + return new(read); + } + + // unable to access the memory stream buffer + // handle as regular stream + stream = memoryStream; + } + + Debug.Assert(stream is not null); + // reserve the buffer + var buffer = _poolArray; + if (buffer is null) { + _poolArray = buffer = ArrayPool.Shared.Rent(_bufferSize); + } + + // negative buffer size -> read as much as possible + expectedSize = expectedSize < 0 ? buffer.Length : expectedSize; + + return new(stream.ReadToBufferAsync(buffer.AsMemory(0, Math.Min(buffer.Length, expectedSize)), ct)); + } + + /// + public ReadOnlySpan Read(int expectedSize) { + var memoryStream = _memoryStream; + var stream = _arbitraryStream; + ThrowIfNull(stream is null & memoryStream is null); + + if (memoryStream is not null) { + if (memoryStream.TryReadToBuffer(expectedSize, out ReadOnlySpan read)) { + return read; + } + + // unable to access the memory stream buffer + // handle as regular stream + stream = memoryStream; + } + + Debug.Assert(stream is not null); + // reserve the buffer + var buffer = _poolArray; + if (buffer is null) { + _poolArray = buffer = ArrayPool.Shared.Rent(_bufferSize); + } + + return stream.ReadToBuffer(buffer.AsSpan(0, Math.Min(buffer.Length, expectedSize))); + } + + + public void Dispose() { + var poolArray = _poolArray; + _poolArray = null; + if (poolArray is not null) { + ArrayPool.Shared.Return(poolArray); + } + } + + private static void ThrowIfNull([DoesNotReturnIf(true)] bool isNull, [CallerArgumentExpression(nameof(isNull))] string expression = "") { + if (isNull) { + throw new InvalidOperationException($"The expression cannot be null. `{expression}`"); + } + } + + private static void ThrowArgIfStreamCantRead(Stream stream, [CallerArgumentExpression(nameof(stream))] string argName = "") { + if (!stream.CanRead) { + throw new ArgumentException("The stream must be readable", argName); + } + } +} diff --git a/src/Common/CallerArgumentExpressionAttribute.cs b/src/Common/CallerArgumentExpressionAttribute.cs index 6f7f11b7..e853763a 100644 --- a/src/Common/CallerArgumentExpressionAttribute.cs +++ b/src/Common/CallerArgumentExpressionAttribute.cs @@ -1,5 +1,5 @@ // ReSharper disable CheckNamespace -#if !(NET6_0 || NET_5_0 || NET5_0_OR_GREATER || NETCOREAPP3_0_OR_GREATER) +#if !NETCOREAPP3_0_OR_GREATER #pragma warning disable IDE0130 namespace System.Runtime.CompilerServices; diff --git a/src/Common/DisposingCache.cs b/src/Common/DisposingCache.cs new file mode 100644 index 00000000..85e11c8c --- /dev/null +++ b/src/Common/DisposingCache.cs @@ -0,0 +1,101 @@ +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +/// Thread-safe sliding cache that disposed values when evicting +internal sealed class DisposingCache + where K : notnull + where V : IDisposable { + private int _evictLock; + private long _lastEvictedTicks; // timestamp of latest eviction operation. + private readonly long _evictionIntervalTicks; // min timespan needed to trigger a new evict operation. + private readonly long _slidingExpirationTicks; // max timespan allowed for cache entries to remain inactive. + private readonly ConcurrentDictionary _cache = new(); + + public DisposingCache(TimeSpan slidingExpiration, TimeSpan evictionInterval) { + _slidingExpirationTicks = slidingExpiration.Ticks; + _evictionIntervalTicks = evictionInterval.Ticks; + _lastEvictedTicks = DateTime.UtcNow.Ticks; + } + + public V GetOrAdd(K key, Func valueFactory) { + CacheEntry entry = _cache.GetOrAdd(key, static (key, valueFactory) => new(valueFactory(key)), valueFactory); + EnsureSlidingEviction(entry); + + return entry.Value; + } + + public bool TryAdd(K key, V value) { + CacheEntry entry = new(value); + bool added = _cache.TryAdd(key, entry); + EnsureSlidingEviction(entry); + + return added; + } + + public bool TryRemove(K key, [MaybeNullWhen(false)] out V value) { + if (_cache.TryRemove(key, out var entry)) { + EnsureSlidingEviction(entry); + value = entry.Value; + return true; + } + + value = default; + return false; + } + + public bool TryGetValue(K key, [MaybeNullWhen(false)] out V value) { + if (_cache.TryRemove(key, out var entry)) { + EnsureSlidingEviction(entry); + value = entry.Value; + return true; + } + + value = default; + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EnsureSlidingEviction(CacheEntry entry) { + long utcNowTicks = DateTime.UtcNow.Ticks; + Volatile.Write(ref entry.LastUsedTicks, utcNowTicks); + + if (utcNowTicks - Volatile.Read(ref _lastEvictedTicks) >= _evictionIntervalTicks) { + if (Interlocked.CompareExchange(ref _evictLock, 1, 0) == 0) { + if (utcNowTicks - _lastEvictedTicks >= _evictionIntervalTicks) { + EvictStaleCacheEntries(utcNowTicks); + Volatile.Write(ref _lastEvictedTicks, utcNowTicks); + } + + Volatile.Write(ref _evictLock, 0); + } + } + } + + public void Clear() { + _cache.Clear(); + _lastEvictedTicks = DateTime.UtcNow.Ticks; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void EvictStaleCacheEntries(long utcNowTicks) { + foreach (KeyValuePair kvp in _cache) { + if (utcNowTicks - Volatile.Read(ref kvp.Value.LastUsedTicks) >= _slidingExpirationTicks) { + if (_cache.TryRemove(kvp.Key, out var entry)) { + entry.Value.Dispose(); + } + } + } + } + + private sealed class CacheEntry { + public readonly V Value; + public long LastUsedTicks; + + public CacheEntry(V value) { + Value = value; + } + } +} diff --git a/src/Common/GCHelper.cs b/src/Common/GCHelper.cs new file mode 100644 index 00000000..93d130d9 --- /dev/null +++ b/src/Common/GCHelper.cs @@ -0,0 +1,55 @@ +using System.Diagnostics; +using System.Numerics; +using System.Runtime.CompilerServices; + +// ReSharper disable once CheckNamespace +namespace System.Buffers; + +internal static class GCHelper { + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int SelectBucketIndex(int bufferSize) + { + // Buffers are bucketed so that a request between 2^(n-1) + 1 and 2^n is given a buffer of 2^n + // Bucket index is log2(bufferSize - 1) with the exception that buffers between 1 and 16 bytes + // are combined, and the index is slid down by 3 to compensate. + // Zero is a valid bufferSize, and it is assigned the highest bucket index so that zero-length + // buffers are not retained by the pool. The pool will return the Array.Empty singleton for these. + return BitOperations.Log2((uint)bufferSize - 1 | 15) - 3; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int GetMaxSizeForBucket(int binIndex) + { + int maxSize = 16 << binIndex; + Debug.Assert(maxSize >= 0); + return maxSize; + } + + internal enum MemoryPressure + { + Low, + Medium, + High + } + + internal static MemoryPressure GetMemoryPressure() + { +#if NETCOREAPP3_0_OR_GREATER + const double HighPressureThreshold = .90; // Percent of GC memory pressure threshold we consider "high" + const double MediumPressureThreshold = .70; // Percent of GC memory pressure threshold we consider "medium" + GCMemoryInfo memoryInfo = GC.GetGCMemoryInfo(); + + if (memoryInfo.MemoryLoadBytes >= memoryInfo.HighMemoryLoadThresholdBytes * HighPressureThreshold) + { + return MemoryPressure.High; + } + + if (memoryInfo.MemoryLoadBytes >= memoryInfo.HighMemoryLoadThresholdBytes * MediumPressureThreshold) + { + return MemoryPressure.Medium; + } +#endif + return MemoryPressure.Low; + } +} diff --git a/src/Common/Gen2GCCallback.cs b/src/Common/Gen2GCCallback.cs new file mode 100644 index 00000000..81a4eb5e --- /dev/null +++ b/src/Common/Gen2GCCallback.cs @@ -0,0 +1,112 @@ +using System.Diagnostics; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; + +// Source: https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Gen2GcCallback.cs + +// ReSharper disable once CheckNamespace +namespace System; + +/// +/// Schedules a callback roughly every gen 2 GC (you may see a Gen 0 an Gen 1 but only once) +/// (We can fix this by capturing the Gen 2 count at startup and testing, but I mostly don't care) +/// +internal sealed class Gen2GcCallback : CriticalFinalizerObject +{ + private readonly Func? _callback0; + private readonly Func? _callback1; + private GCHandle _weakTargetObj; + + private Gen2GcCallback(Func callback) + { + _callback0 = callback; + } + + private Gen2GcCallback(Func callback, object targetObj) + { + _callback1 = callback; + _weakTargetObj = GCHandle.Alloc(targetObj, GCHandleType.Weak); + } + + /// + /// Schedule 'callback' to be called in the next GC. If the callback returns true it is + /// rescheduled for the next Gen 2 GC. Otherwise the callbacks stop. + /// + public static void Register(Func callback) + { + // Create a unreachable object that remembers the callback function and target object. + new Gen2GcCallback(callback); + } + + /// + /// Schedule 'callback' to be called in the next GC. If the callback returns true it is + /// rescheduled for the next Gen 2 GC. Otherwise the callbacks stop. + /// + /// NOTE: This callback will be kept alive until either the callback function returns false, + /// or the target object dies. + /// + public static void Register(Func callback, object targetObj) + { + // Create a unreachable object that remembers the callback function and target object. + new Gen2GcCallback(callback, targetObj); + } + + ~Gen2GcCallback() + { + if (_weakTargetObj.IsAllocated) + { + // Check to see if the target object is still alive. + object? targetObj = _weakTargetObj.Target; + if (targetObj == null) + { + // The target object is dead, so this callback object is no longer needed. + _weakTargetObj.Free(); + return; + } + + // Execute the callback method. + try + { + Debug.Assert(_callback1 != null); + if (!_callback1(targetObj)) + { + // If the callback returns false, this callback object is no longer needed. + _weakTargetObj.Free(); + return; + } + } + catch + { + // Ensure that we still get a chance to resurrect this object, even if the callback throws an exception. +#if DEBUG + // Except in DEBUG, as we really shouldn't be hitting any exceptions here. + throw; +#endif + } + } + else + { + // Execute the callback method. + try + { + Debug.Assert(_callback0 != null); + if (!_callback0()) + { + // If the callback returns false, this callback object is no longer needed. + return; + } + } + catch + { + // Ensure that we still get a chance to resurrect this object, even if the callback throws an exception. +#if DEBUG + // Except in DEBUG, as we really shouldn't be hitting any exceptions here. + throw; +#endif + } + } + + // Resurrect ourselves by re-registering for finalization. + GC.ReRegisterForFinalize(this); + } +} diff --git a/src/Common/IsExternalInit.cs b/src/Common/IsExternalInit.cs index 4129a311..7f1adb48 100644 --- a/src/Common/IsExternalInit.cs +++ b/src/Common/IsExternalInit.cs @@ -1,5 +1,5 @@ // ReSharper disable CheckNamespace -#if !(NET6_0 || NET_5_0 || NET5_0_OR_GREATER) +#if !NET5_0_OR_GREATER #pragma warning disable IDE0130 namespace System.Runtime.CompilerServices; diff --git a/src/Common/LatentReadonly.cs b/src/Common/LatentReadonly.cs new file mode 100644 index 00000000..1405a50e --- /dev/null +++ b/src/Common/LatentReadonly.cs @@ -0,0 +1,34 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +public abstract record LatentReadonly { + private bool _isReadonly; + + protected void MakeReadonly() { + _isReadonly = true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected bool IsReadonly() => _isReadonly; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected void Set(out T field, in T value) { + ThrowIfReadonly(); + field = value; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ThrowIfReadonly() { + if (_isReadonly) { + ThrowReadonly(); + } + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowReadonly() { + throw new InvalidOperationException("The object is readonly and cannot be mutated."); + } +} diff --git a/src/Common/MemoryExtensions.cs b/src/Common/MemoryExtensions.cs deleted file mode 100644 index e2106d58..00000000 --- a/src/Common/MemoryExtensions.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace SurrealDB.Common; - -internal static class MemoryExtensions { - public static ReadOnlySpan SliceToMin(in this ReadOnlySpan span, int length) { - return span.Length <= length ? span : span.Slice(0, length); - } -} diff --git a/src/Common/MemoryHelper.cs b/src/Common/MemoryHelper.cs new file mode 100644 index 00000000..09407f2e --- /dev/null +++ b/src/Common/MemoryHelper.cs @@ -0,0 +1,41 @@ +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +internal static class MemoryHelper { + /// + /// Returns a reference to the 0th element of . If the array is empty, returns a reference to where the 0th element + /// would have been stored. Such a reference may be used for pinning but must never be dereferenced. + /// + /// is . + /// + /// This method does not perform array variance checks. The caller must manually perform any array variance checks + /// if the caller wishes to write to the returned reference. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ref T GetArrayDataReference(T[] array) { +#if !NET6_0_OR_GREATER + return ref Unsafe.As(ref Unsafe.As(array).Data); +#else + return ref System.Runtime.InteropServices.MemoryMarshal.GetArrayDataReference(array); +#endif + } + + + // CLR arrays are laid out in memory as follows (multidimensional array bounds are optional): + // [ sync block || pMethodTable || num components || MD array bounds || array data .. ] + // ^ ^ ^ ^ returned reference + // | | \-- ref Unsafe.As(array).Data + // \-- array \-- ref Unsafe.As(array).Data + // The BaseSize of an array includes all the fields before the array data, + // including the sync block and method table. The reference to RawData.Data + // points at the number of components, skipping over these two pointer-sized fields. + internal sealed class RawArrayData + { + public uint Length; // Array._numComponents padded to IntPtr +#if TARGET_64BIT + public uint Padding; +#endif + public byte Data; + } +} diff --git a/src/Common/NullableAttributes.cs b/src/Common/NullableAttributes.cs new file mode 100644 index 00000000..bf53cdf6 --- /dev/null +++ b/src/Common/NullableAttributes.cs @@ -0,0 +1,66 @@ +// ReSharper disable CheckNamespace +#if !NET5_0_OR_GREATER + +#pragma warning disable IDE0130 +namespace System.Diagnostics.CodeAnalysis; +#pragma warning restore IDE0130 + +/// Specifies that the method or property will ensure that the listed field and property members have not-null values when returning with the specified return value condition. +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] +internal sealed class MemberNotNullWhenAttribute : Attribute +{ + /// Initializes the attribute with the specified return value condition and a field or property member. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, string member) + { + ReturnValue = returnValue; + Members = new[] { member }; + } + + /// Initializes the attribute with the specified return value condition and list of field and property members. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, params string[] members) + { + ReturnValue = returnValue; + Members = members; + } + + /// Gets the return value condition. + public bool ReturnValue { get; } + + /// Gets field or property member names. + public string[] Members { get; } +} + +/// Specifies that the method or property will ensure that the listed field and property members have not-null values. +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] +internal sealed class MemberNotNullAttribute : Attribute +{ + /// Initializes the attribute with a field or property member. + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullAttribute(string member) => Members = new[] { member }; + + /// Initializes the attribute with the list of field and property members. + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullAttribute(params string[] members) => Members = members; + + /// Gets field or property member names. + public string[] Members { get; } +} + + +#endif diff --git a/src/Common/StreamExtensions.cs b/src/Common/StreamExtensions.cs new file mode 100644 index 00000000..9e299425 --- /dev/null +++ b/src/Common/StreamExtensions.cs @@ -0,0 +1,45 @@ +namespace SurrealDB.Common; + +internal static class StreamExtensions { + public static bool TryReadToBuffer(this MemoryStream memoryStream, int expectedSize, out ReadOnlyMemory read) { + if (!memoryStream.TryGetBuffer(out var buffer)) { + read = default; + return false; + } + // negative size -> read to end + expectedSize = expectedSize < 0 ? Int32.MaxValue : expectedSize; + // fake a read call + var pos = (int)memoryStream.Position; + var cap = (int)memoryStream.Length; + var len = Math.Min(expectedSize, cap - pos); + memoryStream.Position += len; + read = buffer.AsMemory(pos, len); + return true; + } + + public static bool TryReadToBuffer(this MemoryStream memoryStream, int expectedSize, out ReadOnlySpan read) { + if (!memoryStream.TryGetBuffer(out var buffer)) { + read = default; + return false; + } + // negative size -> read to end + expectedSize = expectedSize < 0 ? Int32.MaxValue : expectedSize; + // fake a read call + var pos = (int)memoryStream.Position; + var cap = (int)memoryStream.Length; + var len = Math.Min(expectedSize, cap - pos); + memoryStream.Position += len; + read = buffer.AsSpan(pos, len); + return true; + } + + public static async Task> ReadToBufferAsync(this Stream stream, Memory buffer, CancellationToken ct) { + int read = await stream.ReadAsync(buffer, ct); + return buffer.Slice(0, read); + } + + public static ReadOnlySpan ReadToBuffer(this Stream stream, Span buffer) { + int read = stream.Read(buffer); + return buffer.Slice(0, read); + } +} diff --git a/src/Common/StringHelper.cs b/src/Common/StringHelper.cs index 7952c0c0..4ec80848 100644 --- a/src/Common/StringHelper.cs +++ b/src/Common/StringHelper.cs @@ -16,7 +16,7 @@ public static bool IsEmpty([NotNullWhen(false)] this string? str) { } public static string Concat(ReadOnlySpan p0, ReadOnlySpan p1, ReadOnlySpan p2 = default, ReadOnlySpan p3 = default) { -#if NET5_0_OR_GREATER || NETCOREAPP3_0_OR_GREATER +#if NETCOREAPP3_0_OR_GREATER return String.Concat(p0, p1, p2, p3); #else int cap = p0.Length + p1.Length + p2.Length + p3.Length; diff --git a/src/Common/TaskExtensions.cs b/src/Common/TaskExtensions.cs new file mode 100644 index 00000000..9661c73e --- /dev/null +++ b/src/Common/TaskExtensions.cs @@ -0,0 +1,59 @@ +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +/// Extension methods for Tasks +public static class TaskExtensions { + /// The task is invariant. + /// The previous context is not restored upon completion. + /// Equivalent to Task.ConfigureAwait(false). + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ConfiguredTaskAwaitable Inv(this Task t) => t.ConfigureAwait(false); + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ConfiguredTaskAwaitable Inv(this Task t) => t.ConfigureAwait(false); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ConfiguredValueTaskAwaitable Inv(in this ValueTask t) => t.ConfigureAwait(false); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ConfiguredValueTaskAwaitable Inv(in this ValueTask t) => t.ConfigureAwait(false); + + /// Creates a task awaiting the . + /// the handle is null + public static Task ToTask(this WaitHandle handle) + { + if (handle == null) { + throw new ArgumentNullException(nameof(handle)); + } + + TaskCompletionSource tcs = new(); + RegisteredWaitHandle? shared = null; + RegisteredWaitHandle produced = ThreadPool.RegisterWaitForSingleObject( + handle, + (state, timedOut) => + { + tcs.SetResult(null); + + while (true) + { + RegisteredWaitHandle? consumed = Interlocked.CompareExchange(ref shared, null, null); + if (consumed is not null) + { + consumed.Unregister(null); + break; + } + } + }, + state: null, + millisecondsTimeOutInterval: Timeout.Infinite, + executeOnlyOnce: true); + + // Publish the RegisteredWaitHandle so that the callback can see it. + Interlocked.CompareExchange(ref shared, produced, null); + + return tcs.Task; + } + +} diff --git a/src/Common/TaskList.cs b/src/Common/TaskList.cs deleted file mode 100644 index 10d13c9a..00000000 --- a/src/Common/TaskList.cs +++ /dev/null @@ -1,119 +0,0 @@ -using System.Collections; -using System.Diagnostics; - -namespace SurrealDB.Common; - -public sealed class TaskList { - private readonly object _lock = new(); - private readonly Node _root; - private Node _tail; - private int _len; - - public TaskList() { - _root = _tail = new(Task.CompletedTask); - } - - public void Add(Task task) { - lock (_lock) { - Debug.Assert(_tail.Next is null); - Node tail = new(task) { Prev = _tail }; - _tail.Next = tail; - _tail = tail; - _len += 1; - } - } - - public void Trim() { - lock (_lock) { - Node? pos = _root; - do { - Node cur = pos; - pos = pos.Next; - Task task = cur.Task; - if (task.IsCompleted) { - Remove(cur); - } - } while (pos is not null); - } - } - - public ValueTask WhenAll() { - return _len == 0 ? default : new(Task.WhenAll(Drain())); - } - - public DrainIterator Drain() { - return new DrainIterator(this); - } - - /// - /// Removes the node from the list. Requires _lock! - /// - private bool Remove(Node node) { - if (Object.ReferenceEquals(_root, node)) { - // Do not remove the root! - return false; - } - Node? prev = node.Prev; - Node? next = node.Next; - if (Object.ReferenceEquals(_tail, node)) { - _tail = prev!; // cannot be null because of _root - } - if (prev is not null) { - prev.Next = next; - } - if (next is not null) { - next.Prev = prev; - } - - _len -= 1; - return true; - } - - private sealed class Node { - public readonly Task Task; - public Node? Next; - public Node? Prev; - - public Node(Task task) { - Task = task; - } - } - - public struct DrainIterator : IEnumerable, IEnumerator { - private readonly TaskList _list; - - public DrainIterator(TaskList list) { - _list = list; - } - - public DrainIterator GetEnumerator() { - return new(_list); - } - - IEnumerator IEnumerable.GetEnumerator() { - return GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() { - return GetEnumerator(); - } - - public bool MoveNext() { - lock (_list._lock) { - return _list.Remove(_list._tail); - } - } - - public void Reset() { - throw new NotSupportedException("Cannot be reset"); - } - - public Task Current => _list._tail.Task; - - object IEnumerator.Current => Current; - - public void Dispose() { - // not needed - } - } -} diff --git a/src/Common/ValidateReadonly.cs b/src/Common/ValidateReadonly.cs new file mode 100644 index 00000000..9459e2c9 --- /dev/null +++ b/src/Common/ValidateReadonly.cs @@ -0,0 +1,75 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.Serialization; +using System.Text; + +namespace SurrealDB.Common; + +public abstract record ValidateReadonly : LatentReadonly { + /// Evaluates , if it yields any errors, throws a with the errors + protected void ValidateOrThrow(string? message = null) { + PoolArrayBuilder<(string, string)> errors = new(); + foreach (var error in Validations()) { + errors.Append(error); + } + + if (!errors.IsDefault) { + AggregatePropertyValidationException.Throw(message, errors.ToArray()); + } + } + + /// Performs a sequence of validation on properties. + /// If a validation fails, yields the PropertyName and the corresponding error message. + protected abstract IEnumerable<(string PropertyName, string Message)> Validations(); +} + +[Serializable] +public class AggregatePropertyValidationException : Exception { + public AggregatePropertyValidationException() { + } + + public AggregatePropertyValidationException(string? message) : base(message) { + } + + public AggregatePropertyValidationException(string? message, Exception? inner) : base(message, inner) { + } + + public AggregatePropertyValidationException(string? message, (string Property, string Error)[]? errors, Exception? inner) : base(message, inner) { + Errors = errors; + } + + protected AggregatePropertyValidationException( + SerializationInfo info, + StreamingContext context) : base(info, context) { + } + + public (string Property, string Error)[]? Errors { get; set; } + + public override string ToString() { + ValueStringBuilder sb = new(stackalloc char[512]); + Span<(string PropertyName, string Message)>.Enumerator en = Errors.AsSpan().GetEnumerator(); + sb.Append(Message ?? "Validation failed with the following errors:"); + if (en.MoveNext()) { + sb.Append(Environment.NewLine); + sb.Append("- `"); + sb.Append(en.Current.PropertyName); + sb.Append("`: "); + sb.Append(en.Current.Message); + } + while (en.MoveNext()) { + sb.Append(Environment.NewLine); + sb.Append("- `"); + sb.Append(en.Current.PropertyName); + sb.Append("`: "); + sb.Append(en.Current.Message); + } + + return sb.ToString(); + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + public static void Throw(string? message, (string, string)[]? errors = null, Exception? inner = default) { + throw new AggregatePropertyValidationException(message, errors, inner); + } +} diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs new file mode 100644 index 00000000..e18ca7fa --- /dev/null +++ b/src/Common/WebSocketExtensions.cs @@ -0,0 +1,22 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Common; + +internal static class WebSocketExtensions { + /// Throws a if the result is a close ack. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ThrowIfClose(this WebSocketReceiveResult result) { + if (result.CloseStatus is not null) { + ThrowConnectionClosed(); + } + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowConnectionClosed() { + throw new OperationCanceledException("Connection closed"); + } +} diff --git a/src/Common/WsStream.cs b/src/Common/WsStream.cs deleted file mode 100644 index 3a7473a7..00000000 --- a/src/Common/WsStream.cs +++ /dev/null @@ -1,155 +0,0 @@ -using System.Buffers; -using System.Diagnostics.CodeAnalysis; -using System.Net.WebSockets; - -namespace SurrealDB.Common; - -public sealed class WsStream : Stream { - private readonly IDisposable _prefixOwner; - /// - /// The prefix is the memory already obtained to be consumed before queries the socket - /// - private readonly ReadOnlyMemory _prefix; - private int _prefixConsumed; - - private readonly WebSocket _ws; - public bool EndOfMessage { get; private set; } = false; - - public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => false; - public override long Length => ThrowSeekDisallowed(); - public override long Position { get => ThrowSeekDisallowed(); set => ThrowSeekDisallowed(); } - - public WsStream(IDisposable prefixOwner, ReadOnlyMemory prefix, WebSocket ws) { - _prefixOwner = prefixOwner; - _prefix = prefix; - _ws = ws; - } - - public override void Flush() { - // Readonly - } - - public override int Read(byte[] buffer, int offset, int count) { - return Read(buffer.AsSpan(offset, count)); - } - - /// - /// Use , or if possible. - /// - public override int Read(Span buffer) { - int read = 0; - // consume the prefix - ReadOnlySpan pref = ConsumePrefixAsSpan(buffer.Length); - if (!pref.IsEmpty) { - pref.CopyTo(buffer); - buffer = buffer.Slice(pref.Length); - read += pref.Length; - } - - if (buffer.IsEmpty) { - return read; - } - - using IMemoryOwner o = MemoryPool.Shared.Rent(buffer.Length); - Memory m = o.Memory.Slice(0, buffer.Length); - buffer.CopyTo(m.Span); - return read + ReadSync(m); - } - - /// - public int Read(Memory buffer) { - int read = 0; - // consume the prefix - ReadOnlySpan pref = ConsumePrefixAsSpan(buffer.Length); - if (!pref.IsEmpty) { - pref.CopyTo(buffer.Span); - buffer = buffer.Slice(pref.Length); - read += pref.Length; - } - - return read + ReadSync(buffer); - } - - private int ReadSync(Memory buffer) { - // This causes issues if the scheduler is exclusive. - return ReadAsync(buffer).Result; - } - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); - } - - public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - var pref = ConsumePrefix(buffer.Length); - if (!pref.IsEmpty) { - pref.CopyTo(buffer); - return pref.Length; - } - - int read = 0; - if (!EndOfMessage) { - ValueWebSocketReceiveResult rsp = await _ws.ReceiveAsync(buffer, cancellationToken); - read = rsp.Count; - - EndOfMessage = rsp.EndOfMessage; - } - - return read; - } - - private ReadOnlySpan ConsumePrefixAsSpan(int length) { - var prefix = ConsumePrefix(length); - return prefix.Span; - } - - private ReadOnlyMemory ConsumePrefix(int length) { - int len = _prefix.Length; - int con = _prefixConsumed; - if (con == len) { - return default; - } - int rem = len - con; - int inc = Math.Min(rem, length); - _prefixConsumed = con + inc; - return _prefix.Slice(con, inc); - } - - private void DisposePrefix() { - _prefixConsumed = _prefix.Length; - _prefixOwner.Dispose(); - } - - public override long Seek(long offset, SeekOrigin origin) { - return ThrowSeekDisallowed(); - } - - public override void SetLength(long value) { - ThrowWriteDisallowed(); - } - - public override void Write(byte[] buffer, int offset, int count) { - ThrowWriteDisallowed(); - } - - protected override void Dispose(bool disposing) { - base.Dispose(disposing); - DisposePrefix(); - } - - public override async ValueTask DisposeAsync() { - await base.DisposeAsync(); - DisposePrefix(); - } - - [DoesNotReturn] - private static void ThrowWriteDisallowed() { - throw new InvalidOperationException("Cannot write a readonly stream"); - } - - [DoesNotReturn] - private static long ThrowSeekDisallowed() { - throw new InvalidOperationException("Cannot seek in the stream"); - } -} diff --git a/src/Driver/Rest/RestClientExtensions.cs b/src/Driver/Rest/RestClientExtensions.cs index 350cdae1..45d9190e 100644 --- a/src/Driver/Rest/RestClientExtensions.cs +++ b/src/Driver/Rest/RestClientExtensions.cs @@ -5,7 +5,6 @@ using SurrealDB.Common; using SurrealDB.Json; -using SurrealDB.Models; using SurrealDB.Models.Result; using DriverResponse = SurrealDB.Models.Result.DriverResponse; diff --git a/src/Driver/Rpc/DatabaseRpc.cs b/src/Driver/Rpc/DatabaseRpc.cs index 29eca395..2b48e2f1 100644 --- a/src/Driver/Rpc/DatabaseRpc.cs +++ b/src/Driver/Rpc/DatabaseRpc.cs @@ -55,7 +55,7 @@ public async Task Open(CancellationToken ct = default) { // Open connection InvalidConfigException.ThrowIfNull(_config.RpcEndpoint); - await _client.Open(_config.RpcEndpoint!, ct); + await _client.OpenAsync(_config.RpcEndpoint!, ct); // Authenticate if (_config.Username != null && _config.Password != null) { @@ -70,7 +70,7 @@ public async Task Open(CancellationToken ct = default) { public async Task Close(CancellationToken ct = default) { _configured = false; - await _client.Close(ct); + await _client.CloseAsync(ct); } /// diff --git a/src/Driver/Rpc/RpcClientExtensions.cs b/src/Driver/Rpc/RpcClientExtensions.cs index 29dfaa8d..1788182b 100644 --- a/src/Driver/Rpc/RpcClientExtensions.cs +++ b/src/Driver/Rpc/RpcClientExtensions.cs @@ -1,9 +1,10 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using SurrealDB.Common; using SurrealDB.Json; -using SurrealDB.Models; using SurrealDB.Models.Result; using SurrealDB.Ws; @@ -12,7 +13,6 @@ namespace SurrealDB.Driver.Rpc; internal static class RpcClientExtensions { - internal static async Task ToSurreal(this Task rsp) => ToSurreal(await rsp); internal static DriverResponse ToSurreal(this WsClient.Response rsp){ if (rsp.id is null) { @@ -90,7 +90,7 @@ private static DriverResponse FromNestedStatus(in WsClient.Response rsp) { return DriverResponse.FromOwned(builder.AsSegment()); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static void ThrowIdMissing() { throw new InvalidOperationException("Response does not have an id."); } diff --git a/src/Json/Time/DateOnlyConv.cs b/src/Json/Time/DateOnlyConv.cs index 394093a6..b9331309 100644 --- a/src/Json/Time/DateOnlyConv.cs +++ b/src/Json/Time/DateOnlyConv.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -44,14 +46,14 @@ public static bool TryParse(string? s, out DateOnly value) { public static string ToString(in DateOnly value) { return $"{value.Year.ToString("D4")}-{value.Month.ToString("D2")}-{value.Day.ToString("D2")}"; } - - [DoesNotReturn] + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static DateOnly ThrowParseInvalid(string? s) { throw new ParseException($"Unable to parse DateOnly from `{s}`"); } - [DoesNotReturn] - private DateOnly ThrowJsonTokenTypeInvalid() { + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static DateOnly ThrowJsonTokenTypeInvalid() { throw new JsonException("Cannot deserialize a non string token as a DateOnly."); } -} \ No newline at end of file +} diff --git a/src/Json/Time/DateTimeConv.cs b/src/Json/Time/DateTimeConv.cs index 2491ee9c..3ff89051 100644 --- a/src/Json/Time/DateTimeConv.cs +++ b/src/Json/Time/DateTimeConv.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -49,13 +51,13 @@ public static string ToString(in DateTime value) { return value.ToString("O"); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static DateTime ThrowParseInvalid(string? s) { throw new ParseException($"Unable to parse DateTime from `{s}`"); } - - [DoesNotReturn] + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static DateTime ThrowJsonTokenTypeInvalid() { throw new JsonException("Cannot deserialize a non numeric non string token as a DateTime."); } diff --git a/src/Json/Time/DateTimeOffsetConv.cs b/src/Json/Time/DateTimeOffsetConv.cs index 0b93f4aa..041bf5a1 100644 --- a/src/Json/Time/DateTimeOffsetConv.cs +++ b/src/Json/Time/DateTimeOffsetConv.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -49,13 +51,13 @@ public static string ToString(in DateTimeOffset value) { return value.ToString("O"); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static DateTimeOffset ThrowParseInvalid(string? s) { throw new ParseException($"Unable to parse DateTimeOffset from `{s}`"); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static DateTimeOffset ThrowJsonTokenInvalid() { throw new JsonException("Cannot deserialize a non numeric non string token as a DateTime."); } -} \ No newline at end of file +} diff --git a/src/Json/Time/TimeOnlyConv.cs b/src/Json/Time/TimeOnlyConv.cs index ff91dfb0..c3219b5c 100644 --- a/src/Json/Time/TimeOnlyConv.cs +++ b/src/Json/Time/TimeOnlyConv.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -46,13 +48,13 @@ public static string ToString(in TimeOnly value) { return $"{value.Hour.ToString("D2")}:{value.Minute.ToString("D2")}:{value.Second.ToString("D2")}.{value.FractionString()}"; } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static TimeOnly ThrowParseInvalid(string? s) { throw new ParseException($"Unable to parse TimeOnly from `{s}`"); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private TimeOnly ThrowJsonTokenTypeInvalid() { throw new JsonException("Cannot deserialize a non string token as a TimeOnly."); } -} \ No newline at end of file +} diff --git a/src/Json/Time/TimeSpanConv.cs b/src/Json/Time/TimeSpanConv.cs index 238e3d46..0f101d4c 100644 --- a/src/Json/Time/TimeSpanConv.cs +++ b/src/Json/Time/TimeSpanConv.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; @@ -57,14 +59,14 @@ public static string ToString(in TimeSpan value) { return $"{value.Days}d{value.Hours}h{value.Minutes}m{value.Seconds}s{value.Milliseconds}ms"; } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static TimeSpan ThrowParseInvalid(string? s) { throw new ParseException($"Unable to parse TimeSpan from `{s}`"); } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static TimeSpan ThrowJsonTokenTypeInvalid() { throw new JsonException("Cannot deserialize a non numeric non string token as a TimeSpan."); } -} \ No newline at end of file +} diff --git a/src/Models/Result/ResultContentException.cs b/src/Models/Result/ResultContentException.cs index 590147d3..0b3bb4f5 100644 --- a/src/Models/Result/ResultContentException.cs +++ b/src/Models/Result/ResultContentException.cs @@ -1,4 +1,6 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Runtime.Serialization; namespace SurrealDB.Models.Result; @@ -16,15 +18,15 @@ public ResultContentException(string? message) : base(message) { public ResultContentException(string? message, Exception? innerException) : base(message, innerException) { } - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] public static ErrorResult ExpectedAnyError() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain any {nameof(ErrorResult)}"); - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] public static OkResult ExpectedAnyOk() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain any {nameof(OkResult)}"); - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] public static ErrorResult ExpectedSingleError() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain exactly one {nameof(ErrorResult)}"); - [DoesNotReturn] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] public static OkResult ExpectedSingleOk() => throw new ResultContentException($"The {nameof(Result.DriverResponse)} does not contain exactly one {nameof(OkResult)}"); } diff --git a/src/Models/Result/ResultValue.cs b/src/Models/Result/ResultValue.cs index dbea56fd..14f8b179 100644 --- a/src/Models/Result/ResultValue.cs +++ b/src/Models/Result/ResultValue.cs @@ -259,7 +259,7 @@ public Kind GetKind() { return Kind.None; } - [DoesNotReturn, DebuggerStepThrough,] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static ResultValue ThrowUnknownJsonValueKind(JsonElement json) { throw new ArgumentOutOfRangeException(nameof(json), json.ValueKind, "Unknown value kind."); } @@ -386,7 +386,7 @@ public override string ToString() { return Inner.ToString(); } - [DoesNotReturn, DebuggerStepThrough,] + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] private static int ThrowInvalidCompareTypes() { throw new InvalidOperationException("Cannot compare SurrealResult of different types, if one or more is not numeric.."); } diff --git a/src/Models/Thing.cs b/src/Models/Thing.cs index 813c51ef..9b2ec3af 100644 --- a/src/Models/Thing.cs +++ b/src/Models/Thing.cs @@ -1,5 +1,3 @@ -using SurrealDB.Json; - using System.Diagnostics; using System.Diagnostics.Contracts; using System.Runtime.CompilerServices; @@ -7,8 +5,6 @@ using System.Text.Json; using System.Text.Json.Serialization; -using SurrealDB.Common; - namespace SurrealDB.Models; /// diff --git a/src/Ws/Handler.cs b/src/Ws/Handler.cs index cc36fe91..c869ad1f 100644 --- a/src/Ws/Handler.cs +++ b/src/Ws/Handler.cs @@ -6,11 +6,11 @@ internal interface IHandler : IDisposable { public bool Persistent { get; } - public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm); + public void Dispatch(WsHeaderWithMessage header); } internal sealed class ResponseHandler : IHandler { - private readonly TaskCompletionSource<(WsTx.RspHeader, WsTx.NtyHeader, Stream)> _tcs = new(); + private readonly TaskCompletionSource _tcs = new(); private readonly string _id; private readonly CancellationToken _ct; @@ -19,14 +19,15 @@ public ResponseHandler(string id, CancellationToken ct) { _ct = ct; } - public Task<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> Task => _tcs!.Task; + public Task Task => _tcs!.Task; public string Id => _id; public bool Persistent => false; - public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm) { - _tcs.SetResult((rsp, nty, stm)); + public void Dispatch(WsHeaderWithMessage header) { + _ct.ThrowIfCancellationRequested(); + _tcs.SetResult(header); } public void Dispose() { @@ -35,11 +36,11 @@ public void Dispose() { } -internal class NotificationHandler : IHandler, IAsyncEnumerable<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> { - private readonly Ws _mediator; +internal class NotificationHandler : IHandler, IAsyncEnumerable { + private WsReceiverDeflater _mediator; private readonly CancellationToken _ct; - private TaskCompletionSource<(WsTx.RspHeader, WsTx.NtyHeader, Stream)> _tcs = new(); - public NotificationHandler(Ws mediator, string id, CancellationToken ct) { + private TaskCompletionSource _tcs = new(); + public NotificationHandler(WsReceiverDeflater mediator, string id, CancellationToken ct) { _mediator = mediator; Id = id; _ct = ct; @@ -48,8 +49,9 @@ public NotificationHandler(Ws mediator, string id, CancellationToken ct) { public string Id { get; } public bool Persistent => true; - public void Handle(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm) { - _tcs.SetResult((rsp, nty, stm)); + public void Dispatch(WsHeaderWithMessage header) { + _ct.ThrowIfCancellationRequested(); + _tcs.SetResult(header); _tcs = new(); } @@ -57,9 +59,9 @@ public void Dispose() { _tcs.TrySetCanceled(); } - public async IAsyncEnumerator<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> GetAsyncEnumerator(CancellationToken cancellationToken = default) { + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { while (!_ct.IsCancellationRequested) { - (WsTx.RspHeader, WsTx.NtyHeader, Stream) res; + WsHeaderWithMessage res; try { res = await _tcs.Task; } catch (OperationCanceledException) { @@ -71,7 +73,7 @@ public void Dispose() { // unregister before throwing if (_ct.IsCancellationRequested) { - _mediator.Unregister(this); + _mediator.Unregister(Id); } } } diff --git a/src/Ws/HeaderHelper.cs b/src/Ws/HeaderHelper.cs new file mode 100644 index 00000000..8c9fc8f0 --- /dev/null +++ b/src/Ws/HeaderHelper.cs @@ -0,0 +1,296 @@ +using System.Text.Json; + +using SurrealDB.Common; +using SurrealDB.Json; + +namespace SurrealDB.Ws; + +public static class HeaderHelper { + /// Generates a random base64 string of the length specified. + public static string GetRandomId(int length) { + Span buf = stackalloc byte[length]; + ThreadRng.Shared.NextBytes(buf); + return Convert.ToBase64String(buf); + } + + public static WsHeader Parse(ReadOnlySpan utf8) { + var (rsp, rspOff, rspErr) = RspHeader.Parse(utf8); + if (rspErr is null) { + return new(rsp, default, (int)rspOff); + } + var (nty, ntyOff, ntyErr) = NtyHeader.Parse(utf8); + if (ntyErr is null) { + return new(default, nty, (int)ntyOff); + } + + return default; + } +} + +public readonly record struct WsHeader(RspHeader Response, NtyHeader Notify, int BytesLength) { + public string? Id => (Response.IsDefault, Notify.IsDefault) switch { + (true, false) => Notify.id, + (false, true) => Response.id, + _ => null + }; +} + +public readonly record struct WsHeaderWithMessage(WsHeader Header, WsReceiverMessageReader Reader) : IDisposable { + public void Dispose() { + Reader.Dispose(); + } +} + +public readonly record struct NtyHeader(string? id, string? method, WsClient.Error err) { + public bool IsDefault => default == this; + + /// + /// Parses the head including the result propertyname, excluding the result array. + /// + internal static (NtyHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) { + Fsm fsm = new() { + Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })), + State = Fsms.Start, + }; + while (fsm.MoveNext()) {} + + if (!fsm.Success) { + return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}"); + } + return (new(fsm.Id, fsm.Method, fsm.Error), default, default); + } + + private enum Fsms { + Start, // -> Prop + Prop, // -> PropId | PropAsync | PropMethod | ProsResult + PropId, // -> Prop | End + PropMethod, // -> Prop | End + PropError, // -> End + PropParams, // -> End + End + } + + private ref struct Fsm { + public Fsms State; + public Utf8JsonReader Lexer; + public string? Err; + public bool Success; + + public string? Name; + public string? Id; + public WsClient.Error Error; + public string? Method; + + public bool MoveNext() { + return State switch { + Fsms.Start => Start(), + Fsms.Prop => Prop(), + Fsms.PropId => PropId(), + Fsms.PropMethod => PropMethod(), + Fsms.PropError => PropError(), + Fsms.PropParams => PropParams(), + Fsms.End => End(), + _ => false + }; + } + + private bool Start() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) { + Err = "Unable to read token StartObject"; + return false; + } + + State = Fsms.Prop; + return true; + + } + + private bool End() { + Success = !String.IsNullOrEmpty(Id) && !String.IsNullOrEmpty(Method); + return false; + } + + private bool Prop() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) { + Err = "Unable to read PropertyName"; + return false; + } + + Name = Lexer.GetString(); + if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropId; + return true; + } + if ("method".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropMethod; + return true; + } + if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropError; + return true; + } + if ("params".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropParams; + return true; + } + + Err = $"Unknown PropertyName `{Name}`"; + return false; + } + + private bool PropId() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { + Err = "Unable to read `id` property value"; + return false; + } + + State = Fsms.Prop; + Id = Lexer.GetString(); + return true; + } + + private bool PropError() { + Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared); + State = Fsms.End; + return true; + } + + private bool PropMethod() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { + Err = "Unable to read `method` property value"; + return false; + } + + State = Fsms.Prop; + Method = Lexer.GetString(); + return true; + } + + private bool PropParams() { + // Do not parse the result! + // The complete result is not present in the buffer! + // The result is returned as a unevaluated asynchronous stream! + State = Fsms.End; + return true; + } + } +} + +public readonly record struct RspHeader(string? id, WsClient.Error err) { + public bool IsDefault => default == this; + + /// + /// Parses the head including the result propertyname, excluding the result array. + /// + internal static (RspHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) { + Fsm fsm = new() { + Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })), + State = Fsms.Start, + }; + while (fsm.MoveNext()) {} + + if (!fsm.Success) { + return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}"); + } + return (new(fsm.Id, fsm.Error), default, default); + } + + private enum Fsms { + Start, // -> Prop + Prop, // -> PropId | PropError | ProsResult + PropId, // -> Prop | End + PropError, // -> End + PropResult, // -> End + End + } + + private ref struct Fsm { + public Fsms State; + public Utf8JsonReader Lexer; + public string? Err; + public bool Success; + + public string? Name; + public string? Id; + public WsClient.Error Error; + + public bool MoveNext() { + return State switch { + Fsms.Start => Start(), + Fsms.Prop => Prop(), + Fsms.PropId => PropId(), + Fsms.PropError => PropError(), + Fsms.PropResult => PropResult(), + Fsms.End => End(), + _ => false + }; + } + + private bool Start() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) { + Err = "Unable to read token StartObject"; + return false; + } + + State = Fsms.Prop; + return true; + + } + + private bool End() { + Success = !String.IsNullOrEmpty(Id); + return false; + } + + private bool Prop() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) { + Err = "Unable to read PropertyName"; + return false; + } + + Name = Lexer.GetString(); + if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropId; + return true; + } + if ("result".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropResult; + return true; + } + if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) { + State = Fsms.PropError; + return true; + } + + Err = $"Unknown PropertyName `{Name}`"; + return false; + } + + private bool PropId() { + if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { + Err = "Unable to read `id` property value"; + return false; + } + + State = Fsms.Prop; + Id = Lexer.GetString(); + return true; + } + + private bool PropError() { + Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared); + State = Fsms.End; + return true; + } + + + private bool PropResult() { + // Do not parse the result! + // The complete result is not present in the buffer! + // The result is returned as a unevaluated asynchronous stream! + State = Fsms.End; + return true; + } + } +} + diff --git a/src/Ws/Helpers/BoundedChannelPool.cs b/src/Ws/Helpers/BoundedChannelPool.cs new file mode 100644 index 00000000..010e09e2 --- /dev/null +++ b/src/Ws/Helpers/BoundedChannelPool.cs @@ -0,0 +1,475 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading.Channels; + +// ReSharper disable once CheckNamespace +namespace System.Buffers; + +public abstract class BoundedChannelPool { + private static readonly TlsOverPerCoreLockedStacksBoundedChannelPool s_boundedShared = new(); + + public static BoundedChannelPool Shared => s_boundedShared; + + public abstract BoundedChannel Rent(int minimumLength); + + public abstract void Return(BoundedChannel channel); +} + +public sealed class BoundedChannel : Channel, IDisposable { + private BoundedChannelPool? _owner; + + public BoundedChannel(Channel wrapped, int capacity, BoundedChannelPool owner) { + Reader = wrapped.Reader; + Writer = wrapped.Writer; + Capacity = capacity; + _owner = owner; + } + + public int Capacity { get; } + + public void Dispose() { + var owner = _owner; + _owner = null; + if (owner is not null) { + owner.Return(this); + } + } +} + +// Source: https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Buffers/TlsOverPerCoreLockedStacksArrayPool.cs +// modified for use with channels +public sealed class TlsOverPerCoreLockedStacksBoundedChannelPool : BoundedChannelPool { + /// The number of buckets (array sizes) in the pool, one for each array length, starting from length 16. + private const int NumBuckets = 27; // GCHelper.SelectBucketIndex(1024 * 1024 * 1024 + 1) + /// Maximum number of per-core stacks to use per array size. + private const int MaxPerCorePerArraySizeStacks = 64; // selected to avoid needing to worry about processor groups + /// The maximum number of buffers to store in a bucket's global queue. + private const int MaxBuffersPerArraySizePerCore = 8; + + /// A per-thread array of arrays, to cache one array per array size per thread. + [ThreadStatic] + private static ThreadLocalArray[]? t_tlsBuckets; + /// Used to keep track of all thread local buckets for trimming if needed. + private readonly ConditionalWeakTable _allTlsBuckets = new ConditionalWeakTable(); + /// + /// An array of per-core array stacks. The slots are lazily initialized to avoid creating + /// lots of overhead for unused array sizes. + /// + private readonly PerCoreLockedStacks?[] _buckets = new PerCoreLockedStacks[NumBuckets]; + /// Whether the callback to trim arrays in response to memory pressure has been created. + private int _trimCallbackCreated; + + /// Allocate a new PerCoreLockedStacks and try to store it into the array. + private PerCoreLockedStacks CreatePerCoreLockedStacks(int bucketIndex) + { + var inst = new PerCoreLockedStacks(); + return Interlocked.CompareExchange(ref _buckets[bucketIndex], inst, null) ?? inst; + } + + /// Gets an ID for the pool to use with events. + private int Id => GetHashCode(); + + public override BoundedChannel Rent(int minimumLength) + { + BoundedChannel? buffer; + + // Get the bucket number for the array length. The result may be out of range of buckets, + // either for too large a value or for 0 and negative values. + int bucketIndex = GCHelper.SelectBucketIndex(minimumLength); + + // First, try to get an array from TLS if possible. + ThreadLocalArray[]? tlsBuckets = t_tlsBuckets; + if (tlsBuckets is not null && (uint)bucketIndex < (uint)tlsBuckets.Length) + { + buffer = tlsBuckets[bucketIndex].Channel; + if (buffer is not null) + { + tlsBuckets[bucketIndex].Channel = null; + return buffer; + } + } + + // Next, try to get an array from one of the per-core stacks. + PerCoreLockedStacks?[] perCoreBuckets = _buckets; + if ((uint)bucketIndex < (uint)perCoreBuckets.Length) + { + PerCoreLockedStacks? b = perCoreBuckets[bucketIndex]; + if (b is not null) + { + buffer = b.TryPop(); + if (buffer is not null) + { + return buffer; + } + } + + // No buffer available. Ensure the length we'll allocate matches that of a bucket + // so we can later return it. + minimumLength = GCHelper.GetMaxSizeForBucket(bucketIndex); + } + else if (minimumLength <= 0) + { + throw new ArgumentOutOfRangeException(nameof(minimumLength)); + } + + // allocate a new bounded channel, that belongs to this instance + buffer = new(Channel.CreateBounded(minimumLength), minimumLength, this); + + return buffer; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override void Return(BoundedChannel? channel) { + if (channel is null) { + ThrowChannelNull(); + } + + if (channel.Reader.Completion.IsCompleted) { + ThrowChannelCompleted(); + } + + // Determine with what bucket this array length is associated + int bucketIndex = GCHelper.SelectBucketIndex(channel.Capacity); + + // Make sure our TLS buckets are initialized. Technically we could avoid doing + // this if the array being returned is erroneous or too large for the pool, but the + // former condition is an error we don't need to optimize for, and the latter is incredibly + // rare, given a max size of 1B elements. + ThreadLocalArray[] tlsBuckets = t_tlsBuckets ?? InitializeTlsBucketsAndTrimming(); + + bool haveBucket = false; + bool returned = true; + if ((uint)bucketIndex < (uint)tlsBuckets.Length) + { + haveBucket = true; + + // Check to see if the buffer is the correct size for this bucket. + if (channel.Capacity != GCHelper.GetMaxSizeForBucket(bucketIndex)) { + ThrowChannelNotOfPool(); + } + + // Store the array into the TLS bucket. If there's already an array in it, + // push that array down into the per-core stacks, preferring to keep the latest + // one in TLS for better locality. + ref ThreadLocalArray tla = ref tlsBuckets[bucketIndex]; + BoundedChannel? prev = tla.Channel; + tla = new ThreadLocalArray(channel); + if (prev is not null) + { + PerCoreLockedStacks stackBucket = _buckets[bucketIndex] ?? CreatePerCoreLockedStacks(bucketIndex); + returned = stackBucket.TryPush(prev); + } + } + } + + public bool Trim() + { + int currentMilliseconds = Environment.TickCount; + GCHelper.MemoryPressure pressure = GCHelper.GetMemoryPressure(); + + // Trim each of the per-core buckets. + PerCoreLockedStacks?[] perCoreBuckets = _buckets; + for (int i = 0; i < perCoreBuckets.Length; i++) + { + perCoreBuckets[i]?.Trim(currentMilliseconds, Id, pressure, GCHelper.GetMaxSizeForBucket(i)); + } + + // Trim each of the TLS buckets. Note that threads may be modifying their TLS slots concurrently with + // this trimming happening. We do not force synchronization with those operations, so we accept the fact + // that we may end up firing a trimming event even if an array wasn't trimmed, and potentially + // trim an array we didn't need to. Both of these should be rare occurrences. + + // Under high pressure, release all thread locals. + if (pressure == GCHelper.MemoryPressure.High) { + foreach (KeyValuePair tlsBuckets in _allTlsBuckets) + { +#if NET6_0_OR_GREATER + Array.Clear(tlsBuckets.Key); +#else + tlsBuckets.Key.AsSpan().Clear(); +#endif + } + } + else + { + // Otherwise, release thread locals based on how long we've observed them to be stored. This time is + // approximate, with the time set not when the array is stored but when we see it during a Trim, so it + // takes at least two Trim calls (and thus two gen2 GCs) to drop an array, unless we're in high memory + // pressure. These values have been set arbitrarily; we could tune them in the future. + uint millisecondsThreshold = pressure switch + { + GCHelper.MemoryPressure.Medium => 15_000, + _ => 30_000, + }; + + foreach (KeyValuePair tlsBuckets in _allTlsBuckets) + { + ThreadLocalArray[] buckets = tlsBuckets.Key; + for (int i = 0; i < buckets.Length; i++) + { + if (buckets[i].Channel is null) + { + continue; + } + + // We treat 0 to mean it hasn't yet been seen in a Trim call. In the very rare case where Trim records 0, + // it'll take an extra Trim call to remove the array. + int lastSeen = buckets[i].MillisecondsTimeStamp; + if (lastSeen == 0) + { + buckets[i].MillisecondsTimeStamp = currentMilliseconds; + } + else if ((currentMilliseconds - lastSeen) >= millisecondsThreshold) + { + // Time noticeably wrapped, or we've surpassed the threshold. + // Clear out the array, and log its being trimmed if desired. + Interlocked.Exchange(ref buckets[i].Channel, null); + } + } + } + } + + return true; + } + + private ThreadLocalArray[] InitializeTlsBucketsAndTrimming() + { + Debug.Assert(t_tlsBuckets is null, $"Non-null {nameof(t_tlsBuckets)}"); + + var tlsBuckets = new ThreadLocalArray[NumBuckets]; + t_tlsBuckets = tlsBuckets; + + _allTlsBuckets.Add(tlsBuckets, null); + if (Interlocked.Exchange(ref _trimCallbackCreated, 1) == 0) + { + Gen2GcCallback.Register(s => ((TlsOverPerCoreLockedStacksBoundedChannelPool)s).Trim(), this); + } + + return tlsBuckets; + } + + /// Stores a set of stacks of arrays, with one stack per core. + private sealed class PerCoreLockedStacks + { + /// Number of locked stacks to employ. + private static readonly int s_lockedStackCount = Math.Min(Environment.ProcessorCount, MaxPerCorePerArraySizeStacks); + /// The stacks. + private readonly LockedStack[] _perCoreStacks; + + /// Initializes the stacks. + public PerCoreLockedStacks() + { + // Create the stacks. We create as many as there are processors, limited by our max. + var stacks = new LockedStack[s_lockedStackCount]; + for (int i = 0; i < stacks.Length; i++) + { + stacks[i] = new LockedStack(); + } + _perCoreStacks = stacks; + } + + /// Try to push the array into the stacks. If each is full when it's tested, the array will be dropped. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryPush(BoundedChannel array) + { + // Try to push on to the associated stack first. If that fails, + // round-robin through the other stacks. + LockedStack[] stacks = _perCoreStacks; + int index = (int)((uint)Thread.GetCurrentProcessorId() % (uint)s_lockedStackCount); // mod by constant in tier 1 + for (int i = 0; i < stacks.Length; i++) + { + if (stacks[index].TryPush(array)) return true; + if (++index == stacks.Length) index = 0; + } + + return false; + } + + /// Try to get an array from the stacks. If each is empty when it's tested, null will be returned. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public BoundedChannel? TryPop() + { + // Try to pop from the associated stack first. If that fails, round-robin through the other stacks. + BoundedChannel? arr; + LockedStack[] stacks = _perCoreStacks; + int index = (int)((uint)Thread.GetCurrentProcessorId() % (uint)s_lockedStackCount); // mod by constant in tier 1 + for (int i = 0; i < stacks.Length; i++) + { + if ((arr = stacks[index].TryPop()) is not null) return arr; + if (++index == stacks.Length) index = 0; + } + return null; + } + + public void Trim(int currentMilliseconds, int id, GCHelper.MemoryPressure pressure, int bucketSize) + { + LockedStack[] stacks = _perCoreStacks; + for (int i = 0; i < stacks.Length; i++) + { + stacks[i].Trim(currentMilliseconds, id, pressure, bucketSize); + } + } + } + + /// Provides a simple, bounded stack of arrays, protected by a lock. + private sealed class LockedStack + { + /// The arrays in the stack. + private readonly BoundedChannel?[] _arrays = new BoundedChannel[MaxBuffersPerArraySizePerCore]; + /// Number of arrays stored in . + private int _count; + /// Timestamp set by Trim when it sees this as 0. + private int _millisecondsTimestamp; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryPush(BoundedChannel array) + { + bool enqueued = false; + Monitor.Enter(this); + BoundedChannel?[] arrays = _arrays; + int count = _count; + if ((uint)count < (uint)arrays.Length) + { + if (count == 0) + { + // Reset the time stamp now that we're transitioning from empty to non-empty. + // Trim will see this as 0 and initialize it to the current time when Trim is called. + _millisecondsTimestamp = 0; + } + + arrays[count] = array; + _count = count + 1; + enqueued = true; + } + Monitor.Exit(this); + return enqueued; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public BoundedChannel? TryPop() + { + BoundedChannel? arr = null; + Monitor.Enter(this); + BoundedChannel?[] arrays = _arrays; + int count = _count - 1; + if ((uint)count < (uint)arrays.Length) + { + arr = arrays[count]; + arrays[count] = null; + _count = count; + } + Monitor.Exit(this); + return arr; + } + + public void Trim(int currentMilliseconds, int id, GCHelper.MemoryPressure pressure, int bucketSize) + { + const int StackTrimAfterMS = 60 * 1000; // Trim after 60 seconds for low/moderate pressure + const int StackHighTrimAfterMS = 10 * 1000; // Trim after 10 seconds for high pressure + const int StackLowTrimCount = 1; // Trim one item when pressure is low + const int StackMediumTrimCount = 2; // Trim two items when pressure is moderate + const int StackHighTrimCount = MaxBuffersPerArraySizePerCore; // Trim all items when pressure is high + const int StackLargeBucket = 16384; // If the bucket is larger than this we'll trim an extra when under high pressure + const int StackModerateTypeSize = 16; // If T is larger than this we'll trim an extra when under high pressure + const int StackLargeTypeSize = 32; // If T is larger than this we'll trim an extra (additional) when under high pressure + + if (_count == 0) + { + return; + } + + int trimMilliseconds = pressure == GCHelper.MemoryPressure.High ? StackHighTrimAfterMS : StackTrimAfterMS; + + lock (this) + { + if (_count == 0) + { + return; + } + + if (_millisecondsTimestamp == 0) + { + _millisecondsTimestamp = currentMilliseconds; + return; + } + + if ((currentMilliseconds - _millisecondsTimestamp) <= trimMilliseconds) + { + return; + } + + // We've elapsed enough time since the first item went into the stack. + // Drop the top item so it can be collected and make the stack look a little newer. + + int trimCount = StackLowTrimCount; + switch (pressure) + { + case GCHelper.MemoryPressure.High: + trimCount = StackHighTrimCount; + + // When pressure is high, aggressively trim larger arrays. + if (bucketSize > StackLargeBucket) + { + trimCount++; + } + if (Unsafe.SizeOf() > StackModerateTypeSize) + { + trimCount++; + } + if (Unsafe.SizeOf() > StackLargeTypeSize) + { + trimCount++; + } + break; + + case GCHelper.MemoryPressure.Medium: + trimCount = StackMediumTrimCount; + break; + } + + while (_count > 0 && trimCount-- > 0) + { + BoundedChannel? array = _arrays[--_count]; + Debug.Assert(array is not null, "No nulls should have been present in slots < _count."); + _arrays[_count] = null; + } + + _millisecondsTimestamp = _count > 0 ? + _millisecondsTimestamp + (trimMilliseconds / 4) : // Give the remaining items a bit more time + 0; + } + } + } + + /// Wrapper for arrays stored in ThreadStatic buckets. + private struct ThreadLocalArray + { + /// The stored array. + public BoundedChannel? Channel; + /// Environment.TickCount timestamp for when this array was observed by Trim. + public int MillisecondsTimeStamp; + + public ThreadLocalArray(BoundedChannel channel) + { + Channel = channel; + MillisecondsTimeStamp = 0; + } + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowChannelNull() { + throw new ArgumentNullException("channel"); + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowChannelNotOfPool() { + throw new ArgumentException("The channel does not belong to the bool", "channel"); + } + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowChannelCompleted() { + throw new ArgumentException("Cannot add a completed channel to the pool", "channel"); + } +} + diff --git a/src/Ws/Ws.cs b/src/Ws/Ws.cs deleted file mode 100644 index c52f8839..00000000 --- a/src/Ws/Ws.cs +++ /dev/null @@ -1,131 +0,0 @@ -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; - -using SurrealDB.Common; - -namespace SurrealDB.Ws; - -public sealed class Ws : IDisposable, IAsyncDisposable { - private readonly CancellationTokenSource _cts = new(); - private readonly WsTx _tx = new(); - private readonly ConcurrentDictionary _handlers = new(); - private Task _recv = Task.CompletedTask; - - public bool Connected => _tx.Connected; - - public async Task Open(Uri remote, CancellationToken ct = default) { - await _tx.Open(remote, ct); - _recv = Task.Run(async () => await Receive(_cts.Token), _cts.Token); - } - - public async Task Close(CancellationToken ct = default) { - Task t1 = _tx.Close(ct); - Task t2 = Task.Run(ClearHandlers, ct); - _cts.Cancel(); - - await t1; - await t2; - } - - /// - /// Sends the request and awaits a response from the server - /// - public async Task<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> RequestOnce(string id, Stream request, CancellationToken ct = default) { - ResponseHandler handler = new(id, ct); - Register(handler); - await _tx.Tw(request, ct); - return await handler.Task; - } - - /// - /// Sends the request and awaits responses from the server until manually canceled using the cancellation token - /// - public async IAsyncEnumerable<(WsTx.RspHeader rsp, WsTx.NtyHeader nty, Stream stm)> RequestPersists(string id, Stream request, [EnumeratorCancellation] CancellationToken ct = default) { - NotificationHandler handler = new(this, id, ct); - Register(handler); - await _tx.Tw(request, ct); - await foreach (var res in handler) { - yield return res; - } - } - - internal void Register(IHandler handler) { - if (!_handlers.TryAdd(handler.Id, handler)) { - ThrowDuplicateId(handler.Id); - } - } - - internal void Unregister(IHandler handler) { - if (!_handlers.TryRemove(handler.Id, out var h)) { - return; - } - - h.Dispose(); - } - - private async Task Receive(CancellationToken stoppingToken) { - while (!stoppingToken.IsCancellationRequested) { - var (id, response, notify, stream) = await _tx.Tr(stoppingToken); - - stoppingToken.ThrowIfCancellationRequested(); - - if (String.IsNullOrEmpty(id)) { - continue; // Invalid response - } - - if (!_handlers.TryGetValue(id, out IHandler? handler)) { - // assume that unhandled responses belong to other clients - // discard! - await stream.DisposeAsync(); - continue; - } - - handler.Handle(response, notify, stream); - - if (!handler.Persistent) { - // persistent handlers are for notifications and are not removed automatically - Unregister(handler); - } - - if (stream is WsStream wsStream) { - while (!wsStream.EndOfMessage) { - await Task.Delay(13, stoppingToken); - } - } - } - } - - private void ClearHandlers() { - foreach (var handler in _handlers.Values) { - Unregister(handler); - } - } - - public void Dispose() { - try { - Close().Wait(); - _tx.Dispose(); - } catch (OperationCanceledException) { - // expected - } catch (AggregateException) { - // wrapping OperationCanceledException - } - } - - public async ValueTask DisposeAsync() { - try { - await Close(); - _tx.Dispose(); - } catch (OperationCanceledException) { - // expected - } catch (AggregateException) { - // wrapping OperationCanceledException for async - } - } - - [DoesNotReturn] - private static void ThrowDuplicateId(string id) { - throw new ArgumentOutOfRangeException(nameof(id), $"A request with the Id `{id}` is already registered"); - } -} diff --git a/src/Ws/Ws.csproj b/src/Ws/Ws.csproj index dd553b1f..ba51b4d5 100644 --- a/src/Ws/Ws.csproj +++ b/src/Ws/Ws.csproj @@ -2,6 +2,7 @@ SurrealDB.Ws + true @@ -20,6 +21,6 @@ + - diff --git a/src/Ws/WsClient.cs b/src/Ws/WsClient.cs index 661584d2..d37b9edc 100644 --- a/src/Ws/WsClient.cs +++ b/src/Ws/WsClient.cs @@ -1,6 +1,10 @@ +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; +using System.Threading.Channels; using Microsoft.IO; @@ -9,83 +13,109 @@ namespace SurrealDB.Ws; -/// -/// The client used to connect to the Surreal server via JSON RPC. -/// -public sealed class WsClient : IDisposable, IAsyncDisposable { - private static readonly Lazy s_manager = new(static () => new()); +/// The client used to connect to the Surreal server via JSON RPC. +public sealed class WsClient : IDisposable { // Do not get any funny ideas and fill this fucker up. - public static readonly List EmptyList = new(); + private static readonly List s_emptyList = new(); - private readonly Ws _ws = new(); + private readonly ClientWebSocket _ws = new(); + private readonly RecyclableMemoryStreamManager _memoryManager; + private readonly WsTransmitter _transmitter; + private readonly WsReceiverDeflater _deflater; + private readonly WsReceiverInflater _inflater; - /// - /// Indicates whether the client is connected or not. - /// - public bool Connected => _ws.Connected; + private readonly int _idBytes; - /// - /// Generates a random base64 string of the length specified. - /// - public static string GetRandomId(int length) { - Span buf = stackalloc byte[length]; - ThreadRng.Shared.NextBytes(buf); - return Convert.ToBase64String(buf); + public WsClient() + : this(WsClientOptions.Default) { } - /// - /// Opens the connection to the Surreal server. - /// - public async Task Open(Uri url, CancellationToken ct = default) { + public WsClient(WsClientOptions options) { + options.ValidateAndMakeReadonly(); + _memoryManager = options.MemoryManager; + _transmitter = new(_ws, _memoryManager.BlockSize); + var tx = Channel.CreateBounded(options.TxChannelCapacity); + _deflater = new(tx.Reader, options.ReceiveHeaderBytesMax, options.RequestExpiration, TimeSpan.FromSeconds(1)); + _inflater = new(_ws, tx.Writer, _memoryManager, _memoryManager.BlockSize, options.MessageChannelCapacity); + + _idBytes = options.IdBytes; + } + + /// Indicates whether the client is connected or not. + public bool Connected => _ws.State == WebSocketState.Open; + + public WebSocketState State => _ws.State; + + /// Opens the connection to the Surreal server. + public async Task OpenAsync(Uri url, CancellationToken ct = default) { ThrowIfConnected(); - await _ws.Open(url, ct); + await _ws.ConnectAsync(url, ct).Inv(); + _deflater.Open(); + _inflater.Open(); } /// /// Closes the connection to the Surreal server. /// - public async Task Close(CancellationToken ct = default) { - await _ws.Close(ct); + public async Task CloseAsync(CancellationToken ct = default) { + ThrowIfDisconnected(); + await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client connection closed orderly", ct).Inv(); + await _deflater.CloseAsync().Inv(); + await _inflater.CloseAsync().Inv(); } /// public void Dispose() { + _deflater.Dispose(); + _inflater.Dispose(); _ws.Dispose(); } - /// - public ValueTask DisposeAsync() { - return _ws.DisposeAsync(); - } - /// /// Sends the specified request to the Surreal server, and returns the response. /// public async Task Send(Request req, CancellationToken ct = default) { ThrowIfDisconnected(); - req.id ??= GetRandomId(6); - req.parameters ??= EmptyList; - - await using RecyclableMemoryStream stream = new(s_manager.Value); + req.id ??= HeaderHelper.GetRandomId(_idBytes); + req.parameters ??= s_emptyList; - await JsonSerializer.SerializeAsync(stream, req, SerializerOptions.Shared, ct); - // Now Position = Length = EndOfMessage - // Write the buffer to the websocket - stream.Position = 0; - var (rsp, nty, stm) = await _ws.RequestOnce(req.id, stream, ct); - if (!nty.IsDefault) { - ThrowExpectRspGotNty(); + // listen for the response + ResponseHandler handler = new(req.id, ct); + if (!_deflater.RegisterOrGet(handler)) { + return default; } - - if (rsp.IsDefault) { - ThrowRspDefault(); + // send request + var stream = await SerializeAsync(req, ct).Inv(); + await _transmitter.SendAsync(stream, ct).Inv(); + // await response, dispose message when done + using var response = await handler.Task.Inv(); + // validate header + var responseHeader = response.Header.Response; + if (!response.Header.Notify.IsDefault) { + ThrowExpectResponseGotNotify(); + } + if (responseHeader.IsDefault) { + ThrowInvalidResponse(); } - var bdy = await JsonSerializer.DeserializeAsync(stm, SerializerOptions.Shared, ct); - if (bdy is null) { - ThrowRspDefault(); + // position stream beyond header and deserialize message body + response.Reader.Position = response.Header.BytesLength; + // deserialize body + var body = await JsonSerializer.DeserializeAsync(response.Reader, SerializerOptions.Shared, ct).Inv(); + if (body is null) { + ThrowInvalidResponse(); } - return new(rsp.id, rsp.err, ExtractResult(bdy)); + + return new(responseHeader.id, responseHeader.err, ExtractResult(body)); + } + + private async Task SerializeAsync(Request req, CancellationToken ct) { + RecyclableMemoryStream stream = new(_memoryManager); + + await JsonSerializer.SerializeAsync(stream, req, SerializerOptions.Shared, ct).Inv(); + // position = Length = EndOfMessage -> position = 0 + stream.Position = 0; + return stream; } private static JsonElement ExtractResult(JsonDocument root) { @@ -104,13 +134,13 @@ private void ThrowIfConnected() { } } - [DoesNotReturn] - private static void ThrowExpectRspGotNty() { + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowExpectResponseGotNotify() { throw new InvalidOperationException("Expected a response, got a notification"); } - [DoesNotReturn] - private static void ThrowRspDefault() { + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowInvalidResponse() { throw new InvalidOperationException("Invalid response"); } @@ -140,6 +170,4 @@ public record struct Notify( string? method, [property: JsonPropertyName("params"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault),] List? parameters); - - } diff --git a/src/Ws/WsClientOptions.cs b/src/Ws/WsClientOptions.cs new file mode 100644 index 00000000..8f63f56c --- /dev/null +++ b/src/Ws/WsClientOptions.cs @@ -0,0 +1,109 @@ +using Microsoft.IO; + +using SurrealDB.Common; + +namespace SurrealDB.Ws; + +public sealed record WsClientOptions : ValidateReadonly { + public const int MaxArraySize = 0X7FFFFFC7; + + /// The maximum number of pending messages in the client inbound channel. Default 1024 + /// This does not refer to the number of simultaneous queries, but the number of unread messages, the size of the "inbox". + /// A message may consist of multiple blocks. Only the message counts towards this number. + public int TxChannelCapacity { + get => _txChannelCapacity; + set => Set(out _txChannelCapacity, in value); + } + + /// The maximum number of pending blocks in a single message channel. Default 64 + /// The number of blocks of a message that can be received by the client, before they are consumed, + /// by reading from the . + /// A block have up to bytes. + public int MessageChannelCapacity { + get => _messageChannelCapacity; + set => Set(out _messageChannelCapacity, in value); + } + + /// The maximum number of bytes a received header can consist of. Default 4 * 1024bytes + /// The client receives a message with a and the message. + /// This is the length the socket "peeks" at the beginning of the network stream, in oder to fully deserialize the or . + /// The entire header must be contained within the peeked memory. + /// The length is bound to . + /// Longer lengths introduce additional overhead. + public int ReceiveHeaderBytesMax { + get => _receiveHeaderBytesMax; + set => Set(out _receiveHeaderBytesMax, in value); + } + /// The number of bytes the id consists of. + /// The id is base64 encoded, therefore 6 bytes = 4 characters. Use values in steps of 6. + public int IdBytes { + get => _idBytes; + set => Set(out _idBytes, in value); + } + /// Defines the resize behaviour of the streams used for handling messages. + public RecyclableMemoryStreamManager MemoryManager { + get => _memoryManager!; // Validated not null + set => Set(out _memoryManager, in value); + } + + /// The maximum time a request is awaited, before a is thrown. + /// Limited by the internal cache eviction timeout (1s) & pressure/traffic. + public TimeSpan RequestExpiration { + get => _requestExpiration; + set => Set(out _requestExpiration, value); + } + + private RecyclableMemoryStreamManager? _memoryManager; + private int _idBytes = 6; + private int _receiveHeaderBytesMax = 4 * 1024; + private int _txChannelCapacity = 1024; + private int _messageChannelCapacity = 64; + private TimeSpan _requestExpiration = TimeSpan.FromSeconds(10); + + public void ValidateAndMakeReadonly() { + if (!IsReadonly()) { + ValidateOrThrow(); + MakeReadonly(); + } + } + + protected override IEnumerable<(string PropertyName, string Message)> Validations() { + if (TxChannelCapacity <= 0) { + yield return (nameof(TxChannelCapacity), "cannot be less then or equal to zero"); + } + if (TxChannelCapacity > MaxArraySize) { + yield return (nameof(TxChannelCapacity), "cannot be greater then MaxArraySize"); + } + + if (MessageChannelCapacity <= 0) { + yield return (nameof(MessageChannelCapacity), "cannot be less then or equal to zero"); + } + if (MessageChannelCapacity > MaxArraySize) { + yield return (nameof(MessageChannelCapacity), "cannot be greater then MaxArraySize"); + } + + if (ReceiveHeaderBytesMax <= 0) { + yield return (nameof(ReceiveHeaderBytesMax), "cannot be less then or equal to zero"); + } + + if (ReceiveHeaderBytesMax > (_memoryManager?.BlockSize ?? 0)) { + yield return (nameof(ReceiveHeaderBytesMax), "cannot be greater then MemoryManager.BlockSize"); + } + + if (_memoryManager is null) { + yield return (nameof(MemoryManager), "cannot be null"); + } + + if (RequestExpiration <= TimeSpan.Zero) { + yield return (nameof(RequestExpiration), "expiration time cannot be less then or equal to zero"); + } + } + + public static WsClientOptions Default { get; } = CreateDefault(); + + private static WsClientOptions CreateDefault() { + WsClientOptions o = new() { MemoryManager = new(), }; + o.ValidateAndMakeReadonly(); + return o; + } +} diff --git a/src/Ws/WsReceiverDeflater.cs b/src/Ws/WsReceiverDeflater.cs new file mode 100644 index 00000000..96ba78aa --- /dev/null +++ b/src/Ws/WsReceiverDeflater.cs @@ -0,0 +1,155 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Channels; + +using SurrealDB.Common; + +namespace SurrealDB.Ws; + +/// Listens for s and dispatches them by their headers to different s. +internal sealed class WsReceiverDeflater : IDisposable { + private readonly ChannelReader _channel; + private readonly DisposingCache _handlers; + private CancellationTokenSource? _cts; + private Task? _execute; + private readonly int _maxHeaderBytes; + + public WsReceiverDeflater(ChannelReader channel, int maxHeaderBytes, TimeSpan cacheSlidingExpiration, TimeSpan cacheEvictionInterval) { + _channel = channel; + _maxHeaderBytes = maxHeaderBytes; + _handlers = new(cacheSlidingExpiration, cacheEvictionInterval); + } + + [MemberNotNullWhen(true, nameof(_cts)), MemberNotNullWhen(true, nameof(_execute))] + public bool Connected => _cts is not null & _execute is not null; + + private async Task Execute(CancellationToken ct) { + Debug.Assert(ct.CanBeCanceled); + + while (!ct.IsCancellationRequested) { + await Consume(ct).Inv(); + ct.ThrowIfCancellationRequested(); + } + } + + private async Task Consume(CancellationToken ct) { + var log = WsReceiverDeflaterEventSource.Log; + + ct.ThrowIfCancellationRequested(); + + // log that we are waiting for a message from the channel + log.MessageAwaiting(); + + var message = await ReadAsync(ct).Inv(); + // log that a message has been retrieved from the channel + log.MessageReceived(message.Header.Id); + + // find the handler + string? id = message.Header.Id; + if (id is null || !_handlers.TryGetValue(id, out var handler)) { + // invalid format, or no registered -> discard message + message.Dispose(); + // log that the message has been discarded + log.MessageDiscarded(id); + return; + } + Debug.Assert(id == handler.Id); + + // dispatch the message to the handler + try { + handler.Dispatch(message); + } catch (Exception ex) { + // handler is canceled -> unregister + Unregister(id); + // log that the dispatch has resulted in a exception + log.HandlerUnregisteredAfterException(id, ex); + } + + if (!handler.Persistent) { + Unregister(id); + // log that the handler has been unregistered + log.HandlerUnregisterdFleeting(id); + } + } + + public void Unregister(string id) { + if (_handlers.TryRemove(id, out var handler)) + { + handler.Dispose(); + } + } + + public bool RegisterOrGet(IHandler handler) { + return _handlers.TryAdd(handler.Id, handler); + } + + private async Task ReadAsync(CancellationToken ct) { + var message = await _channel.ReadAsync(ct).Inv(); + + // receive the first part of the message + var bytes = ArrayPool.Shared.Rent(_maxHeaderBytes); + int read = await message.ReadAsync(bytes, ct).Inv(); + // peek instead of reading + message.Position = 0; + // parse the header portion of the stream, without reading the `result` property. + // the header is a sum-type of all possible headers. + var header = HeaderHelper.Parse(bytes.AsSpan(0, read)); + ArrayPool.Shared.Return(bytes); + return new(header, message); + } + + public void Open() { + var log = WsReceiverDeflaterEventSource.Log; + + ThrowIfConnected(); + _cts = new(); + _execute = Execute(_cts.Token); + + log.Opened(); + } + + public async Task CloseAsync() { + var log = WsReceiverDeflaterEventSource.Log; + + ThrowIfDisconnected(); + Task task; + _cts.Cancel(); + _cts.Dispose(); // not relly needed here + _cts = null; + task = _execute; + _execute = null; + + log.CloseBegin(); + + try { + await task.Inv(); + } catch (OperationCanceledException) { + // expected on close using cts + } + + log.CloseFinish(); + } + + public void Dispose() { + var log = WsReceiverDeflaterEventSource.Log; + + _cts?.Cancel(); + _cts?.Dispose(); + + log.Disposed(); + } + + [MemberNotNull(nameof(_cts)), MemberNotNull(nameof(_execute))] + private void ThrowIfDisconnected() { + if (!Connected) { + throw new InvalidOperationException("The connection is not open."); + } + } + + private void ThrowIfConnected() { + if (Connected) { + throw new InvalidOperationException("The connection is already open"); + } + } +} diff --git a/src/Ws/WsReceiverDeflaterEventSource.cs b/src/Ws/WsReceiverDeflaterEventSource.cs new file mode 100644 index 00000000..7b2d04d5 --- /dev/null +++ b/src/Ws/WsReceiverDeflaterEventSource.cs @@ -0,0 +1,104 @@ +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Ws; + +[EventSource(Guid = "03c50b03-e245-46e5-a99a-6eaa28990a41", Name = "WsReceiverDeflaterEventSource")] +public sealed class WsReceiverDeflaterEventSource : EventSource +{ + private WsReceiverDeflaterEventSource() { } + + public static WsReceiverDeflaterEventSource Log { get; } = new(); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MessageReceived(string? messageId) { + if (IsEnabled()) { + MessageReceivedCore(messageId); + } + } + + [Event(1, Level = EventLevel.Verbose, Message = "Message (Id = {0}) pulled from channel")] + private void MessageReceivedCore(string? messageId) => WriteEvent(1, messageId); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MessageDiscarded(string? messageId) { + if (IsEnabled()) { + MessageDiscardedCore(messageId); + } + } + + [Event(2, Level = EventLevel.Warning, Message = "No handler registered for the message (Id = {0})")] + private void MessageDiscardedCore(string? messageId) => WriteEvent(2, messageId); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void HandlerUnregisteredAfterException(string handlerId, Exception ex) { + if (IsEnabled()) { + HandlerUnregisteredAfterExceptionCore(handlerId, ex.ToString()); + } + } + + [Event(3, Level = EventLevel.Error, Message = "The handler (Id = {0}) threw an exception during dispatch, and was unregistered. ERROR: {1}")] + private unsafe void HandlerUnregisteredAfterExceptionCore(string handlerId, string ex) { + WriteEvent(3, handlerId, ex); + } + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void HandlerUnregisterdFleeting(string handlerId) { + if (IsEnabled()) { + HandlerUnregisterdFleetingCore(handlerId); + } + } + + [Event(4, Level = EventLevel.Verbose, Message = "The handler (Id = {0}) is fleeting and was unregistered after dispatch")] + private void HandlerUnregisterdFleetingCore(string handlerId) => WriteEvent(4, handlerId); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MessageAwaiting() { + if (IsEnabled()) { + MessageAwaitingCore(); + } + } + + [Event(5, Level = EventLevel.Verbose, Message = "Waiting for message to pull from channel")] + private void MessageAwaitingCore() => WriteEvent(5); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Opened() { + if (IsEnabled()) { + OpenedCore(); + } + } + + [Event(6, Level = EventLevel.Informational, Message = "Opened and is now pulling from the channel")] + private void OpenedCore() => WriteEvent(6); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void CloseBegin() { + if (IsEnabled()) { + CloseBeginCore(); + } + } + + [Event(7, Level = EventLevel.Informational, Message = "Closed and stopped pulling from the channel")] + private void CloseBeginCore() => WriteEvent(7); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void CloseFinish() { + if (IsEnabled()) { + CloseFinishCore(); + } + } + + [Event(8, Level = EventLevel.Informational, Message = "Closing has finished")] + private void CloseFinishCore() => WriteEvent(8); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Disposed() { + if (IsEnabled()) { + DisposedCore(); + } + } + + [Event(9, Level = EventLevel.Informational, Message = "Disposed")] + private void DisposedCore() => WriteEvent(9); +} diff --git a/src/Ws/WsReceiverInflater.cs b/src/Ws/WsReceiverInflater.cs new file mode 100644 index 00000000..ba54e5f9 --- /dev/null +++ b/src/Ws/WsReceiverInflater.cs @@ -0,0 +1,136 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Threading.Channels; + +using Microsoft.IO; + +using SurrealDB.Common; + +namespace SurrealDB.Ws; + +/// Receives messages from a websocket server and passes them to a channel +public sealed class WsReceiverInflater : IDisposable { + private readonly ClientWebSocket _socket; + private readonly ChannelWriter _channel; + private readonly RecyclableMemoryStreamManager _memoryManager; + private CancellationTokenSource? _cts; + private Task? _execute; + + private readonly int _blockSize; + private readonly int _messageSize; + + public WsReceiverInflater(ClientWebSocket socket, ChannelWriter channel, RecyclableMemoryStreamManager memoryManager, int blockSize, int messageSize) { + _socket = socket; + _channel = channel; + _memoryManager = memoryManager; + _blockSize = blockSize; + _messageSize = messageSize; + } + + private async Task Execute(CancellationToken ct) { + Debug.Assert(ct.CanBeCanceled); + while (!ct.IsCancellationRequested) { + var buffer = ArrayPool.Shared.Rent(_blockSize); + await Produce(buffer, ct).Inv(); + ArrayPool.Shared.Return(buffer); + ct.ThrowIfCancellationRequested(); + } + } + + private async Task Produce(byte[] buffer, CancellationToken ct) { + var log = WsReceiverInflaterEventSource.Log; + + // log that we are waiting for the socket + log.SocketWaiting(); + // receive the first part + var result = await _socket.ReceiveAsync(buffer, ct).Inv(); + // log that we have received a message from the socket + log.SockedReceived(result); + + ct.ThrowIfCancellationRequested(); + // create a new message with a RecyclableMemoryStream + // use buffer instead of the build the builtin IBufferWriter, bc of thread safely issues related to locking + WsReceiverMessageReader msg = new(_memoryManager, _messageSize); + // begin adding the message to the output + var push = _channel.WriteAsync(msg, ct); + + await msg.AppendResultAsync(buffer, result, ct).Inv(); + + while (!result.EndOfMessage && !ct.IsCancellationRequested) { + // receive more parts + result = await _socket.ReceiveAsync(buffer, ct).Inv(); + // log that we have received a message from the socket + log.SockedReceived(result); + await msg.AppendResultAsync(buffer, result, ct).Inv(); + } + + // log that we have completely received the message + log.MessageReceiveFinished(); + // finish adding the message to the output + await push.Inv(); + // log that the message has been pushed to the channel + log.MessagePushed(); + } + + + [MemberNotNullWhen(true, nameof(_cts)), MemberNotNullWhen(true, nameof(_execute))] + public bool Connected => _cts is not null & _execute is not null; + + public void Open() { + var log = WsReceiverInflaterEventSource.Log; + + ThrowIfConnected(); + _cts = new(); + _execute = Execute(_cts.Token); + + log.Opened(); + } + + public async Task CloseAsync() { + var log = WsReceiverInflaterEventSource.Log; + + ThrowIfDisconnected(); + var task = _execute; + _cts.Cancel(); + _cts.Dispose(); // not relly needed here + _cts = null; + _execute = null; + + log.CloseBegin(); + + try { + await task.Inv(); + } catch (OperationCanceledException) { + // expected on close using cts + } catch (WebSocketException) { + // expected on abort + } + + log.CloseFinish(); + } + + [MemberNotNull(nameof(_cts)), MemberNotNull(nameof(_execute))] + private void ThrowIfDisconnected() { + if (!Connected) { + throw new InvalidOperationException("The connection is not open."); + } + } + + private void ThrowIfConnected() { + if (Connected) { + throw new InvalidOperationException("The connection is already open"); + } + } + + public void Dispose() { + var log = WsReceiverInflaterEventSource.Log; + + _cts?.Cancel(); + _cts?.Dispose(); + _channel.TryComplete(); + + log.Disposed(); + } +} diff --git a/src/Ws/WsReceiverInflaterEventSource.cs b/src/Ws/WsReceiverInflaterEventSource.cs new file mode 100644 index 00000000..4415dc16 --- /dev/null +++ b/src/Ws/WsReceiverInflaterEventSource.cs @@ -0,0 +1,100 @@ +using System.Diagnostics.Tracing; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; + +namespace SurrealDB.Ws; + +[EventSource(Guid = "91a1c84b-f0aa-43c8-ad21-6ff518a8fa01", Name = "WsReceiverInflaterEventSource")] +public sealed class WsReceiverInflaterEventSource : EventSource { + private WsReceiverInflaterEventSource() { } + + public static WsReceiverInflaterEventSource Log { get; } = new(); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void SocketWaiting() { + if (IsEnabled()) { + SocketWaitingCore(); + } + } + + [Event(1, Level = EventLevel.Verbose, Message = "Waiting to receive a block from the socket")] + private void SocketWaitingCore() => WriteEvent(1); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void SockedReceived(WebSocketReceiveResult result) { + if (IsEnabled()) { + SockedReceivedCore(result.Count, result.EndOfMessage, result.CloseStatus is not null); + } + } + + [ThreadStatic] + private static object[]? _socketReceivedArgs; + [Event(2, Level = EventLevel.Verbose, Message = "Received a block from the socket (Count = {0}, EndOfMessage = {1}, Closed = {2})")] + private unsafe void SockedReceivedCore(int count, bool endOfMessage, bool closed) { + _socketReceivedArgs ??= new object[3]; + _socketReceivedArgs[0] = count; + _socketReceivedArgs[1] = endOfMessage; + _socketReceivedArgs[2] = closed; + WriteEvent(2, _socketReceivedArgs); + } + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MessageReceiveFinished() { + if (IsEnabled()) { + MessageReceiveFinishedCore(); + } + } + + [Event(3, Level = EventLevel.Verbose, Message = "Finished receiving the message from the socket")] + private void MessageReceiveFinishedCore() => WriteEvent(3); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MessagePushed() { + if (IsEnabled()) { + MessagePushedCore(); + } + } + + [Event(4, Level = EventLevel.Informational, Message = "Pushed the message to the channel")] + private void MessagePushedCore() => WriteEvent(4); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Opened() { + if (IsEnabled()) { + OpenedCore(); + } + } + + [Event(5, Level = EventLevel.Informational, Message = "Opened and is now pushing to the channel")] + private void OpenedCore() => WriteEvent(5); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void CloseBegin() { + if (IsEnabled()) { + CloseBeginCore(); + } + } + + [Event(6, Level = EventLevel.Informational, Message = "Closed and stopped pushing to the channel")] + private void CloseBeginCore() => WriteEvent(6); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void CloseFinish() { + if (IsEnabled()) { + CloseFinishCore(); + } + } + + [Event(7, Level = EventLevel.Informational, Message = "Closing has finished")] + private void CloseFinishCore() => WriteEvent(7); + + [NonEvent, MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Disposed() { + if (IsEnabled()) { + DisposedCore(); + } + } + + [Event(8, Level = EventLevel.Informational, Message = "Disposed")] + private void DisposedCore() => WriteEvent(8); +} diff --git a/src/Ws/WsReceiverMessageReader.cs b/src/Ws/WsReceiverMessageReader.cs new file mode 100644 index 00000000..f59b3ab5 --- /dev/null +++ b/src/Ws/WsReceiverMessageReader.cs @@ -0,0 +1,204 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; + +using Microsoft.IO; + +using SurrealDB.Common; + +namespace SurrealDB.Ws; + +public sealed class WsReceiverMessageReader : Stream { + private readonly BoundedChannel _channel; + private readonly RecyclableMemoryStream _stream; + private int _endOfMessage; + + internal WsReceiverMessageReader(RecyclableMemoryStreamManager memoryManager, int channelCapacity) { + _stream = new(memoryManager); + _channel = BoundedChannelPool.Shared.Rent(channelCapacity); + _endOfMessage = 0; + } + + public bool HasReceivedEndOfMessage => Interlocked.CompareExchange(ref _endOfMessage, 0, 0) != 0; + + protected override void Dispose(bool disposing) { + if (!disposing) { + return; + } + + _stream.Dispose(); + _channel.Dispose(); + } + + private async ValueTask SetReceivedAsync(WebSocketReceiveResult result, CancellationToken ct) { + await _channel.Writer.WriteAsync(result, ct).Inv(); + if (result.EndOfMessage) { + Interlocked.Exchange(ref _endOfMessage, 1); + } + } + + private ValueTask ReceiveAsync(CancellationToken ct) { + return _channel.Reader.ReadAsync(ct); + } + + private WebSocketReceiveResult Receive(CancellationToken ct) { + var t = ReceiveAsync(ct); + return t.IsCompleted ? t.Result : t.AsTask().Result; + } + + internal ValueTask AppendResultAsync(ReadOnlyMemory buffer, WebSocketReceiveResult result, CancellationToken ct) { + ReadOnlySpan span = buffer.Span.Slice(0, result.Count); + lock (_stream) { + var pos = _stream.Position; + _stream.Write(span); + _stream.Position = pos; + } + ct.ThrowIfCancellationRequested(); + + return SetReceivedAsync(result, ct); + } + +#region Stream members + + public override bool CanRead => true; + public override bool CanSeek => true; + public override bool CanWrite => false; + public override long Length { + get { + lock (_stream) { + return _stream.Length; + } + } + } + + public override long Position { + get { + lock (_stream) { + return _stream.Position; + } + } + set { + lock (_stream) { + _stream.Position = value; + } + } + } + public override void Flush() { + ThrowCantWrite(); + } + + public override int Read(byte[] buffer, int offset, int count) { + return Read(buffer.AsSpan(offset, count)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override int Read(Span buffer) { + return Read(buffer, default); + } + + public int Read(Span buffer, CancellationToken ct) { + if (0 >= (uint)buffer.Length || ct.IsCancellationRequested) { + return 0; + } + + int read; + lock (_stream) { + // attempt to read from present buffer + read = _stream.Read(buffer); + } + ct.ThrowIfCancellationRequested(); + + if (read == buffer.Length || HasReceivedEndOfMessage) { + return read; + } + + WebSocketReceiveResult result; + do { + result = Receive(ct); + int inc; + lock (_stream) { + inc = _stream.Read(buffer.Slice(read)); + } + + ct.ThrowIfCancellationRequested(); + Debug.Assert(inc == result.Count); + read += inc; + } while (!result.EndOfMessage); + + return read; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) { + return ReadAsync(buffer.AsMemory(offset, count), ct).AsTask(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public override ValueTask ReadAsync(Memory buffer, CancellationToken ct = default) { + if (0 >= (uint)buffer.Length || ct.IsCancellationRequested) { + return new(0); + } + + int read; + lock (_stream) { + // attempt to read from present buffer + read = _stream.Read(buffer.Span); + } + ct.ThrowIfCancellationRequested(); + + if (read == buffer.Length || HasReceivedEndOfMessage) { + return new(read); + } + + return new(ReadFromChannelAsync(buffer, read, ct)); + } + + private async Task ReadFromChannelAsync(Memory buffer, int read, CancellationToken ct) { + WebSocketReceiveResult? result; + do { + result = await ReceiveAsync(ct).Inv(); + int inc; + lock (_stream) { + inc = _stream.Read(buffer.Span.Slice(read)); + } + + ct.ThrowIfCancellationRequested(); + + Debug.Assert(inc == result.Count); + read += inc; + } while (!result.EndOfMessage); + + return read; + } + + public override int ReadByte() { + Span buffer = stackalloc byte[1]; + int read = Read(buffer); + return read == 0 ? -1 : buffer[0]; + } + + public override long Seek(long offset, SeekOrigin origin) { + lock (_stream) { + return _stream.Seek(offset, origin); + } + } + + public override void SetLength(long value) { + lock (_stream) { + _stream.SetLength(value); + } + } + + public override void Write(byte[] buffer, int offset, int count) { + ThrowCantWrite(); + } + + +#endregion + + [DoesNotReturn, DebuggerStepThrough, MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowCantWrite() { + throw new NotSupportedException("The stream does not support writing"); + } +} diff --git a/src/Ws/WsTransmitter.cs b/src/Ws/WsTransmitter.cs new file mode 100644 index 00000000..a083c5d3 --- /dev/null +++ b/src/Ws/WsTransmitter.cs @@ -0,0 +1,37 @@ +using System.Net.WebSockets; + +using SurrealDB.Common; + +namespace SurrealDB.Ws; + +/// Sends messages from a channel to a websocket server. +public sealed class WsTransmitter { + private readonly ClientWebSocket _ws; + private readonly int _blockSize; + + public WsTransmitter(ClientWebSocket ws, int blockSize) { + _ws = ws; + _blockSize = blockSize; + } + + public async Task SendAsync(Stream stream, CancellationToken ct) { + // reader is disposed by the consumer + using BufferStreamReader reader = new(stream, _blockSize); + await Consume(reader, ct).Inv(); + } + + + private async Task Consume(BufferStreamReader reader, CancellationToken ct) { + bool isFinalBlock = false; + while (!isFinalBlock && !ct.IsCancellationRequested) { + var rom = await reader.ReadAsync(_blockSize, ct).Inv(); + isFinalBlock = rom.Length != _blockSize; + await _ws.SendAsync(rom, WebSocketMessageType.Text, isFinalBlock, ct).Inv(); + } + + if (!isFinalBlock) { + // ensure that the message is always terminated + await _ws.SendAsync(default, WebSocketMessageType.Text, true, ct).Inv(); + } + } +} diff --git a/src/Ws/WsTx.cs b/src/Ws/WsTx.cs deleted file mode 100644 index 03893a29..00000000 --- a/src/Ws/WsTx.cs +++ /dev/null @@ -1,404 +0,0 @@ -using System.Buffers; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Net.WebSockets; -using System.Text.Json; - -using SurrealDB.Common; -using SurrealDB.Json; - -namespace SurrealDB.Ws; - -public sealed class WsTx : IDisposable { - private readonly ClientWebSocket _ws = new(); - - public static int DefaultBufferSize => 16 * 1024; - - /// - /// Indicates whether the client is connected or not. - /// - public bool Connected => _ws.State == WebSocketState.Open; - - public async Task Open(Uri remote, CancellationToken ct = default) { - ThrowIfConnected(); - await _ws.ConnectAsync(remote, ct); - } - - public async Task Close(CancellationToken ct = default) { - if (_ws.State == WebSocketState.Closed) { - return; - } - - try { - await _ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "client disconnect", ct); - } catch (OperationCanceledException) { - if (ct.IsCancellationRequested) { - // Catch any canceled exception that is generated during the close, - // but still throw for cancellations that we requested. - throw; - } - } - } - - public void Dispose() { - _ws.Dispose(); - } - - /// - /// Receives a response stream from the socket. - /// Parses the header. - /// The body contains the result array including the end object token `[...]}`. - /// - public async Task<(string? id, RspHeader rsp, NtyHeader nty, Stream body)> Tr(CancellationToken ct) { - ThrowIfDisconnected(); - // this method assumes that the header size never exceeds DefaultBufferSize! - IMemoryOwner owner = MemoryPool.Shared.Rent(DefaultBufferSize); - var r = await _ws.ReceiveAsync(owner.Memory, ct); - - if (r.Count <= 0) { - return (default, default, default, default!); - } - - // parse the header - var (rsp, nty, off) = ParseHeader(owner.Memory.Span.Slice(0, r.Count)); - string? id = rsp.IsDefault ? nty.id : rsp.id; - if (String.IsNullOrEmpty(id)) { - ThrowHeaderId(); - } - // returns a stream over the remainder of the body - Stream body = CreateBody(r, owner, owner.Memory.Slice(off)); - return (id, rsp, nty, body); - } - - private static (RspHeader rsp, NtyHeader nty, int off) ParseHeader(ReadOnlySpan utf8) { - var (rsp, rspOff, rspErr) = RspHeader.Parse(utf8); - if (rspErr is null) { - return (rsp, default, (int)rspOff); - } - var (nty, ntyOff, ntyErr) = NtyHeader.Parse(utf8); - if (ntyErr is null) { - return (default, nty, (int)ntyOff); - } - - throw new JsonException($"Failed to parse RspHeader or NotifyHeader: {rspErr} \n--AND--\n {ntyErr}", null, 0, Math.Max(rspOff, ntyOff)); - } - - private Stream CreateBody(ValueWebSocketReceiveResult res, IDisposable owner, ReadOnlyMemory rem) { - // check if rsp is already completely in the buffer - if (res.EndOfMessage) { - // create a rented stream from the remainder. - MemoryStream s = RentedMemoryStream.FromMemory(owner, rem, true, true); - s.SetLength(res.Count); - return s; - } - - // the rsp is not recv completely! - // create a stream wrapping the websocket - // with the recv portion as a prefix - Debug.Assert(res.Count == rem.Length); - return new WsStream(owner, rem, _ws); - } - - /// - /// Sends the stream over the socket. - /// - /// - /// Fast if used with a with exposed buffer! - /// - public async Task Tw(Stream req, CancellationToken ct) { - ThrowIfDisconnected(); - if (req is MemoryStream ms && ms.TryGetBuffer(out ArraySegment raw)) { - // We can obtain the raw buffer from the request, send it - await _ws.SendAsync(raw, WebSocketMessageType.Text, true, ct); - return; - } - - using IMemoryOwner owner = MemoryPool.Shared.Rent(DefaultBufferSize); - bool end = false; - while (!end && !ct.IsCancellationRequested) { - int read = await req.ReadAsync(owner.Memory, ct); - end = read != owner.Memory.Length; - ReadOnlyMemory used = owner.Memory.Slice(0, read); - await _ws.SendAsync(used, WebSocketMessageType.Text, end, ct); - - ThrowIfDisconnected(); - ct.ThrowIfCancellationRequested(); - } - Debug.Assert(end, "Unfinished message sent!"); - } - - [DoesNotReturn] - private static void ThrowHeaderId() { - throw new InvalidOperationException("Header has no associated id!"); - } - - private void ThrowIfDisconnected() { - if (!Connected) { - throw new InvalidOperationException("The connection is not open."); - } - } - - private void ThrowIfConnected() { - if (Connected) { - throw new InvalidOperationException("The connection is already open"); - } - } - - [DoesNotReturn] - private static void ThrowParseHead(string err, long off) { - throw new JsonException(err, default, default, off); - } - - public readonly record struct NtyHeader(string? id, string? method, WsClient.Error err) { - public bool IsDefault => default == this; - - /// - /// Parses the head including the result propertyname, excluding the result array. - /// - internal static (NtyHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) { - Fsm fsm = new() { - Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })), - State = Fsms.Start, - }; - while (fsm.MoveNext()) {} - - if (!fsm.Success) { - return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}"); - } - return (new(fsm.Id, fsm.Method, fsm.Error), default, default); - } - - private enum Fsms { - Start, // -> Prop - Prop, // -> PropId | PropAsync | PropMethod | ProsResult - PropId, // -> Prop | End - PropMethod, // -> Prop | End - PropError, // -> End - PropParams, // -> End - End - } - - private ref struct Fsm { - public Fsms State; - public Utf8JsonReader Lexer; - public string? Err; - public bool Success; - - public string? Name; - public string? Id; - public WsClient.Error Error; - public string? Method; - - public bool MoveNext() { - return State switch { - Fsms.Start => Start(), - Fsms.Prop => Prop(), - Fsms.PropId => PropId(), - Fsms.PropMethod => PropMethod(), - Fsms.PropError => PropError(), - Fsms.PropParams => PropParams(), - Fsms.End => End(), - _ => false - }; - } - - private bool Start() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) { - Err = "Unable to read token StartObject"; - return false; - } - - State = Fsms.Prop; - return true; - - } - - private bool End() { - Success = !String.IsNullOrEmpty(Id) && !String.IsNullOrEmpty(Method); - return false; - } - - private bool Prop() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) { - Err = "Unable to read PropertyName"; - return false; - } - - Name = Lexer.GetString(); - if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropId; - return true; - } - if ("method".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropMethod; - return true; - } - if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropError; - return true; - } - if ("params".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropParams; - return true; - } - - Err = $"Unknown PropertyName `{Name}`"; - return false; - } - - private bool PropId() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { - Err = "Unable to read `id` property value"; - return false; - } - - State = Fsms.Prop; - Id = Lexer.GetString(); - return true; - } - - private bool PropError() { - Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared); - State = Fsms.End; - return true; - } - - private bool PropMethod() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { - Err = "Unable to read `method` property value"; - return false; - } - - State = Fsms.Prop; - Method = Lexer.GetString(); - return true; - } - - private bool PropParams() { - // Do not parse the result! - // The complete result is not present in the buffer! - // The result is returned as a unevaluated asynchronous stream! - State = Fsms.End; - return true; - } - } - } - - public readonly record struct RspHeader(string? id, WsClient.Error err) { - public bool IsDefault => default == this; - - /// - /// Parses the head including the result propertyname, excluding the result array. - /// - internal static (RspHeader head, long off, string? err) Parse(in ReadOnlySpan utf8) { - Fsm fsm = new() { - Lexer = new(utf8, false, new JsonReaderState(new() { CommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true })), - State = Fsms.Start, - }; - while (fsm.MoveNext()) {} - - if (!fsm.Success) { - return (default, fsm.Lexer.BytesConsumed, $"Error while parsing {nameof(RspHeader)} at {fsm.Lexer.TokenStartIndex}: {fsm.Err}"); - } - return (new(fsm.Id, fsm.Error), default, default); - } - - private enum Fsms { - Start, // -> Prop - Prop, // -> PropId | PropError | ProsResult - PropId, // -> Prop | End - PropError, // -> End - PropResult, // -> End - End - } - - private ref struct Fsm { - public Fsms State; - public Utf8JsonReader Lexer; - public string? Err; - public bool Success; - - public string? Name; - public string? Id; - public WsClient.Error Error; - - public bool MoveNext() { - return State switch { - Fsms.Start => Start(), - Fsms.Prop => Prop(), - Fsms.PropId => PropId(), - Fsms.PropError => PropError(), - Fsms.PropResult => PropResult(), - Fsms.End => End(), - _ => false - }; - } - - private bool Start() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.StartObject) { - Err = "Unable to read token StartObject"; - return false; - } - - State = Fsms.Prop; - return true; - - } - - private bool End() { - Success = !String.IsNullOrEmpty(Id); - return false; - } - - private bool Prop() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.PropertyName) { - Err = "Unable to read PropertyName"; - return false; - } - - Name = Lexer.GetString(); - if ("id".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropId; - return true; - } - if ("result".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropResult; - return true; - } - if ("error".Equals(Name, StringComparison.OrdinalIgnoreCase)) { - State = Fsms.PropError; - return true; - } - - Err = $"Unknown PropertyName `{Name}`"; - return false; - } - - private bool PropId() { - if (!Lexer.Read() || Lexer.TokenType != JsonTokenType.String) { - Err = "Unable to read `id` property value"; - return false; - } - - State = Fsms.Prop; - Id = Lexer.GetString(); - return true; - } - - private bool PropError() { - Error = JsonSerializer.Deserialize(ref Lexer, SerializerOptions.Shared); - State = Fsms.End; - return true; - } - - - private bool PropResult() { - // Do not parse the result! - // The complete result is not present in the buffer! - // The result is returned as a unevaluated asynchronous stream! - State = Fsms.End; - return true; - } - } - } -} diff --git a/tests/Driver.Tests/DatabaseTests.cs b/tests/Driver.Tests/DatabaseTests.cs index fc44e8c8..439d530d 100644 --- a/tests/Driver.Tests/DatabaseTests.cs +++ b/tests/Driver.Tests/DatabaseTests.cs @@ -10,98 +10,81 @@ public sealed class RestDatabaseTest : DatabaseTestDriver { [Collection("SurrealDBRequired")] public abstract class DatabaseTestDriver - : DriverBase where T : IDatabase, IDisposable, new() { - protected override async Task Run(T db) { - db.GetConfig().Should().BeEquivalentTo(TestHelper.Default); - - var useResp = await db.Use(TestHelper.Database, TestHelper.Namespace); - TestHelper.AssertOk(useResp); - var infoResp = await db.Info(); - TestHelper.AssertOk(infoResp); - - var signInStatus = await db.Signin(new RootAuth(TestHelper.User, TestHelper.Pass)); - - TestHelper.AssertOk(signInStatus); - - (string id1, string id2) = ("id1", "id2"); - var res1 = await db.Create( - "person", - new { - Title = "Founder & CEO", - Name = new { First = "Tobie", Last = "Morgan Hitchcock", }, - Marketing = true, - Identifier = ThreadRng.Shared.Next(), - } - ); - - TestHelper.AssertOk(res1); + [Fact] + public async Task TestSuite() => await DbHandle.WithDatabase( + async db => { + db.GetConfig().Should().BeEquivalentTo(TestHelper.Default); - var res2 = await db.Create( - "person", - new { - Title = "Contributor", - Name = new { First = "Prophet", Last = "Lamb", }, - Marketing = false, - Identifier = ThreadRng.Shared.Next(), - } - ); + var useResp = await db.Use(TestHelper.Database, TestHelper.Namespace); + TestHelper.AssertOk(useResp); + var infoResp = await db.Info(); + TestHelper.AssertOk(infoResp); - TestHelper.AssertOk(res2); + var signInStatus = await db.Signin(new RootAuth(TestHelper.User, TestHelper.Pass)); - Thing thing2 = ("person", id2); - TestHelper.AssertOk(await db.Update(thing2, new { Marketing = false, })); + TestHelper.AssertOk(signInStatus); - TestHelper.AssertOk(await db.Select(thing2)); + (string id1, string id2) = ("id1", "id2"); + var res1 = await db.Create( + "person", + new { + Title = "Founder & CEO", + Name = new { First = "Tobie", Last = "Morgan Hitchcock", }, + Marketing = true, + Identifier = ThreadRng.Shared.Next(), + } + ); - TestHelper.AssertOk(await db.Delete(thing2)); + TestHelper.AssertOk(res1); - Thing thing1 = ("person", id1); - TestHelper.AssertOk( - await db.Change( - thing1, + var res2 = await db.Create( + "person", new { - Title = "Founder & CEO", - Name = new { First = "Tobie", Last = "Hitchcock Morgan", }, + Title = "Contributor", + Name = new { First = "Prophet", Last = "Lamb", }, Marketing = false, Identifier = ThreadRng.Shared.Next(), } - ) - ); + ); - string newTitle = "Founder & CEO & Ruler of the known free World"; - var modifyResp = await db.Modify(thing1, new[] { - Patch.Replace("/Title", newTitle), - }); - TestHelper.AssertOk(modifyResp); + TestHelper.AssertOk(res2); - TestHelper.AssertOk(await db.Let("tbl", "person")); + Thing thing2 = ("person", id2); + TestHelper.AssertOk(await db.Update(thing2, new { Marketing = false, })); - var queryResp = await db.Query( - "SELECT $props FROM $tbl WHERE title = $title", - new Dictionary { ["props"] = "title, identifier", ["tbl"] = "person", ["title"] = newTitle, } - ); + TestHelper.AssertOk(await db.Select(thing2)); - TestHelper.AssertOk(queryResp); + TestHelper.AssertOk(await db.Delete(thing2)); - await db.Close(); - } -} + Thing thing1 = ("person", id1); + TestHelper.AssertOk( + await db.Change( + thing1, + new { + Title = "Founder & CEO", + Name = new { First = "Tobie", Last = "Hitchcock Morgan", }, + Marketing = false, + Identifier = ThreadRng.Shared.Next(), + } + ) + ); -/// -/// The test driver executes the testsuite on the client. -/// -[Collection("SurrealDBRequired")] -public abstract class DriverBase - where T : IDatabase, IDisposable, new() { + string newTitle = "Founder & CEO & Ruler of the known free World"; + var modifyResp = await db.Modify(thing1, new[] { Patch.Replace("/Title", newTitle), }); + TestHelper.AssertOk(modifyResp); + TestHelper.AssertOk(await db.Let("tbl", "person")); - [Fact] - public async Task TestSuite() { - using var handle = await DbHandle.Create(); - await Run(handle.Database); - } + var queryResp = await db.Query( + "SELECT $props FROM $tbl WHERE title = $title", + new Dictionary { + ["props"] = "title, identifier", ["tbl"] = "person", ["title"] = newTitle, + } + ); - protected abstract Task Run(T db); + TestHelper.AssertOk(queryResp); + } + ); } diff --git a/tests/Driver.Tests/Queries/GeneralQueryTests.cs b/tests/Driver.Tests/Queries/GeneralQueryTests.cs index 06852cb4..ccd9ac50 100644 --- a/tests/Driver.Tests/Queries/GeneralQueryTests.cs +++ b/tests/Driver.Tests/Queries/GeneralQueryTests.cs @@ -304,7 +304,7 @@ public async Task SimultaneousDatabaseOperations() => await DbHandle.WithData async db => { var taskCount = 50; var tasks = Enumerable.Range(0, taskCount).Select(i => DbTask(i, db)); - await Task.WhenAll(tasks).ConfigureAwait(false); + await Task.WhenAll(tasks).Inv(); } ); @@ -314,17 +314,17 @@ private async Task DbTask(int i, T db) { var expectedResult = new TestObject(i, i); Thing thing = Thing.From("object", expectedResult.Key); - var createResponse = await db.Create(thing, expectedResult).ConfigureAwait(false); + var createResponse = await db.Create(thing, expectedResult).Inv(); AssertResponse(createResponse, expectedResult); Logger.WriteLine($"Create {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}"); - var selectResponse = await db.Select(thing).ConfigureAwait(false); + var selectResponse = await db.Select(thing).Inv(); AssertResponse(selectResponse, expectedResult); Logger.WriteLine($"Select {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}"); string sql = "SELECT * FROM $record"; Dictionary param = new() { ["record"] = thing }; - var queryResponse = await db.Query(sql, param).ConfigureAwait(false); + var queryResponse = await db.Query(sql, param).Inv(); AssertResponse(queryResponse, expectedResult); Logger.WriteLine($"Query {i} - Thread ID {Thread.CurrentThread.ManagedThreadId}"); diff --git a/tests/Driver.Tests/Roundtrip/RoundTripTests.cs b/tests/Driver.Tests/Roundtrip/RoundTripTests.cs index 0d827678..51a2d74a 100644 --- a/tests/Driver.Tests/Roundtrip/RoundTripTests.cs +++ b/tests/Driver.Tests/Roundtrip/RoundTripTests.cs @@ -1,5 +1,3 @@ -using System.Collections; - using SurrealDB.Json; using SurrealDB.Models.Result; diff --git a/tests/Shared/ConsoleOutEventListener.cs b/tests/Shared/ConsoleOutEventListener.cs new file mode 100644 index 00000000..117cee60 --- /dev/null +++ b/tests/Shared/ConsoleOutEventListener.cs @@ -0,0 +1,16 @@ +using System.Diagnostics.Tracing; + +namespace SurrealDB.Shared.Tests; + +public class ConsoleOutEventListener : EventListener { + + protected override void OnEventWritten(EventWrittenEventArgs eventData) { + string message = $"{eventData.TimeStamp:T} - {eventData.EventSource.Name} - {eventData.EventName} - {eventData.OSThreadId}: {eventData.Message}"; + if (eventData.Payload is null) { + Console.WriteLine(message); + } else { + Console.WriteLine(message, eventData.Payload.ToArray()); + } + base.OnEventWritten(eventData); + } +} diff --git a/tests/Shared/DbHandle.cs b/tests/Shared/DbHandle.cs index a6ca36b6..0581ca00 100644 --- a/tests/Shared/DbHandle.cs +++ b/tests/Shared/DbHandle.cs @@ -1,4 +1,7 @@ +using System.Diagnostics.Tracing; + using SurrealDB.Abstractions; +using SurrealDB.Ws; namespace SurrealDB.Shared.Tests; @@ -20,22 +23,28 @@ public static async Task> Create() { [DebuggerStepThrough] public static async Task WithDatabase(Func action) { + // enable console logging for events + using TestEventListener l = new(); + l.EnableEvents(WsReceiverInflaterEventSource.Log, EventLevel.LogAlways); + l.EnableEvents(WsReceiverDeflaterEventSource.Log, EventLevel.LogAlways); + // connect to the database using DbHandle db = await Create(); + // execute test methods await action(db.Database); } public T Database { get; } - ~DbHandle() { - Dispose(); - } - public void Dispose() { Process? p = _process; _process = null; if (p is not null) { - Database.Dispose(); - p.Kill(); + DisposeActual(p); } } + + private void DisposeActual(Process p) { + Database.Dispose(); + p.Kill(); + } } diff --git a/tests/Shared/TestEventListener.cs b/tests/Shared/TestEventListener.cs new file mode 100644 index 00000000..eaf5ab89 --- /dev/null +++ b/tests/Shared/TestEventListener.cs @@ -0,0 +1,51 @@ +using System.Diagnostics.Tracing; +using System.Runtime.ExceptionServices; +using System.Text; +using System.Text.RegularExpressions; + +namespace SurrealDB.Shared.Tests; + +public class TestEventListener : EventListener { + private StreamWriter? _writer = new(File.OpenWrite($"./{nameof(TestEventListener)}_{DateTimeOffset.UtcNow:s}.log")); + + public TestEventListener() { + AppDomain.CurrentDomain.FirstChanceException += FirstChanceException; + } + + private void FirstChanceException(object? sender, FirstChanceExceptionEventArgs e) { + ValueStringBuilder sb = new(stackalloc char[512]); + sb.Append(DateTime.UtcNow.ToString("T")); + sb.Append(" Error: "); + sb.Append(e.Exception.ToString()); + + WriteLine(sb.ToString()); + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) { + string message = $"{eventData.TimeStamp:T} {eventData.EventSource.Name}: {eventData.Message} (E:{eventData.EventName}, T:{eventData.OSThreadId})"; + if (eventData.Payload is null) { + WriteLine(message); + } else { + object?[] parameters = eventData.Payload.ToArray(); + WriteLine(message, parameters); + } + base.OnEventWritten(eventData); + } + + private void WriteLine(string message) { + _writer?.WriteLine(message); + } + + private void WriteLine(string format, object?[] args) { + _writer?.WriteLine(format, args); + } + + public override void Dispose() { + var writer = Interlocked.Exchange(ref _writer, null); + if (writer is not null) { + writer.Dispose(); + AppDomain.CurrentDomain.FirstChanceException -= FirstChanceException; + } + base.Dispose(); + } +}