diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/ChannelAsyncOperation.cs b/Stack/Opc.Ua.Core/Stack/Tcp/ChannelAsyncOperation.cs index feba487b4d..687cbbf3b2 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/ChannelAsyncOperation.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/ChannelAsyncOperation.cs @@ -31,6 +31,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Sources; using Microsoft.Extensions.Logging; namespace Opc.Ua.Bindings @@ -39,7 +40,7 @@ namespace Opc.Ua.Bindings /// Stores the results of an asynchronous operation. /// /// - public class ChannelAsyncOperation : IAsyncResult, IDisposable + public class ChannelAsyncOperation : IAsyncResult, IDisposable, IValueTaskSource, IValueTaskSource { /// /// Initializes the object with a callback @@ -51,10 +52,13 @@ public ChannelAsyncOperation(int timeout, AsyncCallback? callback, object? async m_synchronous = false; m_completed = false; m_logger = logger; + m_asyncWaitSource.RunContinuationsAsynchronously = true; if (timeout is > 0 and not int.MaxValue) { - m_timer = new Timer(new TimerCallback(OnTimeout), null, timeout, Timeout.Infinite); + m_timeoutCancellationTokenSource = new CancellationTokenSource(timeout); + m_timeoutCancellationRegistration = + m_timeoutCancellationTokenSource.Token.Register(OnTimeout); } } @@ -76,8 +80,9 @@ protected virtual void Dispose(bool disposing) { lock (m_lock) { - m_timer?.Dispose(); - m_timer = null; + m_timeoutCancellationRegistration.Dispose(); + m_timeoutCancellationTokenSource?.Dispose(); + m_timeoutCancellationTokenSource = null; if (m_event != null) { @@ -86,13 +91,11 @@ protected virtual void Dispose(bool disposing) m_event = null; } - if (m_tcs != null) + if (m_asyncWaitPending) { - if (!m_tcs.Task.IsCompleted) - { - m_tcs.TrySetCanceled(); - } - m_tcs = null; + m_asyncWaitSource.SetException( + new TaskCanceledException("ChannelAsyncOperation was disposed while an async wait was pending.")); + m_asyncWaitPending = false; } } } @@ -223,13 +226,25 @@ public T End(int timeout, bool throwOnError = true) /// The awaitable response returned from the server. /// /// - public async Task EndAsync( + public Task EndAsync( + int timeout, + bool throwOnError = true, + CancellationToken ct = default) + { + return EndValueTaskAsync(timeout, throwOnError, ct).AsTask(); + } + + /// + /// The low-allocation awaitable response returned from the server. + /// + internal async ValueTask EndValueTaskAsync( int timeout, bool throwOnError = true, CancellationToken ct = default) { // check if the request has already completed. bool mustWait = false; + ValueTask waitTask = default; lock (m_lock) { @@ -237,8 +252,8 @@ public async Task EndAsync( if (mustWait) { - m_tcs = new TaskCompletionSource( - TaskCreationOptions.RunContinuationsAsynchronously); + m_asyncWaitPending = true; + waitTask = new ValueTask(this, m_asyncWaitSource.Version); } } @@ -248,19 +263,13 @@ public async Task EndAsync( bool badRequestInterrupted = false; try { - Task awaitableTask = m_tcs!.Task; - if (timeout != int.MaxValue) + if (timeout != int.MaxValue || ct.CanBeCanceled) { - awaitableTask = m_tcs.Task - .WaitAsync(TimeSpan.FromMilliseconds(timeout), ct); + await WaitAsync(waitTask, timeout, ct).ConfigureAwait(false); } - else if (ct != default) + else { - awaitableTask = m_tcs.Task.WaitAsync(ct); - } - if (!await awaitableTask.ConfigureAwait(false)) - { - badRequestInterrupted = true; + _ = await waitTask.ConfigureAwait(false); } } catch (TimeoutException) @@ -271,12 +280,9 @@ public async Task EndAsync( { badRequestInterrupted = true; } - finally + catch (OperationCanceledException) { - lock (m_lock) - { - m_tcs = null; - } + badRequestInterrupted = true; } if (badRequestInterrupted && throwOnError) @@ -368,17 +374,6 @@ public bool IsCompleted } } - /// - /// Called when the operation times out. - /// - private void OnTimeout(object? state) - { - if (m_timer != null) - { - InternalComplete(false, new ServiceResult(StatusCodes.BadRequestTimeout)); - } - } - /// /// Called when an asynchronous operation completes. /// @@ -403,12 +398,17 @@ protected virtual bool InternalComplete(bool doNotBlock, object? result) m_completed = true; - m_timer?.Dispose(); - m_timer = null; + m_timeoutCancellationRegistration.Dispose(); + m_timeoutCancellationTokenSource?.Dispose(); + m_timeoutCancellationTokenSource = null; m_event?.Set(); - m_tcs?.TrySetResult(true); + if (m_asyncWaitPending) + { + m_asyncWaitSource.SetResult(true); + m_asyncWaitPending = false; + } } AsyncCallback? callback = m_callback; @@ -436,17 +436,102 @@ protected virtual bool InternalComplete(bool doNotBlock, object? result) return true; } + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) + { + return m_asyncWaitSource.GetStatus(token); + } + + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) + { + return m_asyncWaitSource.GetStatus(token); + } + + bool IValueTaskSource.GetResult(short token) + { + return m_asyncWaitSource.GetResult(token); + } + + void IValueTaskSource.GetResult(short token) + { + ((IValueTaskSource)m_asyncWaitSource).GetResult(token); + } + + void IValueTaskSource.OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + { + m_asyncWaitSource.OnCompleted(continuation, state, token, flags); + } + + void IValueTaskSource.OnCompleted( + Action continuation, + object? state, + short token, + ValueTaskSourceOnCompletedFlags flags) + { + m_asyncWaitSource.OnCompleted(continuation, state, token, flags); + } + + private void OnTimeout() + { + if (m_timeoutCancellationTokenSource != null) + { + InternalComplete(false, new ServiceResult(StatusCodes.BadRequestTimeout)); + } + } + + private static async Task WaitAsync(ValueTask waitTask, int timeout, CancellationToken ct) + { + Task task = waitTask.AsTask(); + using CancellationTokenSource? timeoutCancellationTokenSource = + timeout != int.MaxValue ? new CancellationTokenSource(timeout) : null; + using CancellationTokenSource? linkedCancellationTokenSource = + timeoutCancellationTokenSource != null && ct.CanBeCanceled ? + CancellationTokenSource.CreateLinkedTokenSource(ct, timeoutCancellationTokenSource.Token) : + null; + + CancellationToken effectiveCancellationToken = + linkedCancellationTokenSource?.Token ?? + timeoutCancellationTokenSource?.Token ?? + ct; + + if (!effectiveCancellationToken.CanBeCanceled) + { + _ = await task.ConfigureAwait(false); + return; + } + +#if NET6_0_OR_GREATER + _ = await task.WaitAsync(effectiveCancellationToken).ConfigureAwait(false); +#else + Task completedTask = await Task.WhenAny( + task, + Task.Delay(Timeout.Infinite, effectiveCancellationToken)).ConfigureAwait(false); + + if (!ReferenceEquals(completedTask, task)) + { + effectiveCancellationToken.ThrowIfCancellationRequested(); + } + + _ = await task.ConfigureAwait(false); +#endif + } + private readonly Lock m_lock = new(); private readonly AsyncCallback? m_callback; private readonly object? m_asyncState; private readonly bool m_synchronous; private readonly ILogger m_logger; + private readonly ManualResetValueTaskSource m_asyncWaitSource = new(); private bool m_completed; + private bool m_asyncWaitPending; private ManualResetEvent? m_event; - private TaskCompletionSource? m_tcs; private T? m_response; private ServiceResult? m_error; - private Timer? m_timer; + private CancellationTokenSource? m_timeoutCancellationTokenSource; + private CancellationTokenRegistration m_timeoutCancellationRegistration; private Dictionary? m_properties; } } diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/TcpListenerChannel.cs b/Stack/Opc.Ua.Core/Stack/Tcp/TcpListenerChannel.cs index c17dc5719c..275f0dee8a 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/TcpListenerChannel.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/TcpListenerChannel.cs @@ -161,7 +161,7 @@ public void Attach(uint channelId, Socket socket) Socket.Handle, ChannelId); - Socket.ReadNextMessage(); + Socket.ReadNextMessageAsync(); } } diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/TcpMessageSocket.cs b/Stack/Opc.Ua.Core/Stack/Tcp/TcpMessageSocket.cs index 2d6948968b..8a50963d82 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/TcpMessageSocket.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/TcpMessageSocket.cs @@ -28,7 +28,7 @@ * ======================================================================*/ using System; -using System.Diagnostics; +using System.Collections.Generic; using System.Net; using System.Net.Sockets; using System.Runtime.InteropServices; @@ -224,9 +224,6 @@ public TcpMessageSocket( m_bufferManager = bufferManager ?? throw new ArgumentNullException(nameof(bufferManager)); m_receiveBufferSize = receiveBufferSize; - m_incomingMessageSize = -1; - m_readComplete = OnReadComplete; - m_readState = ReadState.Ready; } /// @@ -245,8 +242,6 @@ public TcpMessageSocket( m_bufferManager = bufferManager ?? throw new ArgumentNullException(nameof(bufferManager)); m_receiveBufferSize = receiveBufferSize; - m_incomingMessageSize = -1; - m_readComplete = OnReadComplete; } /// @@ -380,326 +375,388 @@ public void Close() /// /// Starts reading messages from the socket. /// - public void ReadNextMessage() + public async Task ReadNextMessageAsync(CancellationToken ct = default) { - lock (m_readLock) + byte[]? receiveBuffer = null; + try { - do + while (true) { - // allocate a buffer large enough to a message chunk. - m_receiveBuffer ??= m_bufferManager.TakeBuffer( + receiveBuffer = m_bufferManager.TakeBuffer( m_receiveBufferSize, - "ReadNextMessage"); + "ReadNextMessageAsync"); + + // Read the fixed-size message header (type + size = 8 bytes). + int headerRead = await ReceiveExactAsync( + receiveBuffer, + 0, + TcpMessageLimits.MessageTypeAndSize, + ct).ConfigureAwait(false); + + if (headerRead == 0) + { + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + receiveBuffer = null; + m_sink?.OnReceiveError(this, + ServiceResult.Create( + StatusCodes.BadConnectionClosed, + "Remote side closed connection.")); + return; + } + + // Validate the message type. + uint messageType = BitConverter.ToUInt32(receiveBuffer, 0); + if (!TcpMessageType.IsValid(messageType)) + { + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + receiveBuffer = null; + m_sink?.OnReceiveError(this, + ServiceResult.Create( + StatusCodes.BadTcpMessageTypeInvalid, + "Message type {0:X8} is invalid.", + messageType)); + return; + } + + // Validate the declared message size. + int messageSize = BitConverter.ToInt32(receiveBuffer, 4); + if (messageSize <= 0 || messageSize > m_receiveBufferSize) + { + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + receiveBuffer = null; + m_sink?.OnReceiveError(this, + ServiceResult.Create( + StatusCodes.BadTcpMessageTooLarge, + "Messages size {0} bytes is too large for buffer of size {1}.", + messageSize, + m_receiveBufferSize)); + return; + } + + // Read the remainder of the message body. + int remaining = messageSize - TcpMessageLimits.MessageTypeAndSize; + if (remaining > 0) + { + int bodyRead = await ReceiveExactAsync( + receiveBuffer, + TcpMessageLimits.MessageTypeAndSize, + remaining, + ct).ConfigureAwait(false); - // read the first 8 bytes of the message which contains the message size. - m_bytesReceived = 0; - m_bytesToReceive = TcpMessageLimits.MessageTypeAndSize; - m_incomingMessageSize = -1; + if (bodyRead == 0) + { + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + receiveBuffer = null; + m_sink?.OnReceiveError(this, + ServiceResult.Create( + StatusCodes.BadConnectionClosed, + "Remote side closed connection.")); + return; + } + } - do + // Deliver the complete message chunk to the sink. + IMessageSink? sink = m_sink; + if (sink != null) + { + var messageChunk = new ArraySegment(receiveBuffer, 0, messageSize); + receiveBuffer = null; // sink now owns the buffer + try + { + sink.OnMessageReceived(this, messageChunk); + } + catch (Exception ex) + { + m_logger.LogError( + ex, + "Unexpected error invoking OnMessageReceived callback."); + } + } + else { - ReadNextBlock(); - } while (m_readState == ReadState.ReadNextBlock); - } while (m_readState == ReadState.ReadNextMessage); + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + receiveBuffer = null; + } + } + } + catch (OperationCanceledException) when (ct.IsCancellationRequested) + { + // Normal cancellation - loop exits silently. + } + catch (Exception ex) + { + m_sink?.OnReceiveError(this, + ServiceResult.Create( + ex, + StatusCodes.BadTcpInternalError, + "Unexpected error receiving data.")); + } + finally + { + if (receiveBuffer != null) + { + m_bufferManager.ReturnBuffer(receiveBuffer, "ReadNextMessageAsync"); + } } } /// - /// Changes the sink used to report reads. + /// Reads exactly bytes starting at + /// in . + /// Returns 0 if the remote side closed the connection before any bytes were read. /// - public void ChangeSink(IMessageSink sink) + private async Task ReceiveExactAsync( + byte[] buffer, + int offset, + int count, + CancellationToken ct) { - lock (m_readLock) + int totalReceived = 0; + while (totalReceived < count) { - m_sink = sink; + int received = await ReceiveAsync( + buffer, + offset + totalReceived, + count - totalReceived, + ct).ConfigureAwait(false); + + if (received == 0) + { + // Connection closed without sending all bytes. + return 0; + } + + totalReceived += received; } + + return totalReceived; } +#if NETFRAMEWORK /// - /// Handles a read complete event. + /// Single async receive call wrapped in a Task (legacy .NET Framework path). /// - private void OnReadComplete(object? sender, SocketAsyncEventArgs e) + private Task ReceiveAsync(byte[] buffer, int offset, int count, CancellationToken ct) { - lock (m_readLock) + Socket? socket; + lock (m_socketLock) { - ServiceResult? error = null; + socket = m_socket; + } - try - { - bool innerCall = m_readState == ReadState.ReadComplete; - error = DoReadComplete(e); - // to avoid recursion, inner calls of OnReadComplete return - // after processing the ReadComplete and let the outer call handle it - if (!innerCall && !ServiceResult.IsBad(error)) - { - while (ReadNext()) - { - } - } - } - catch (Exception ex) + if (socket == null || !socket.Connected) + { + return Task.FromResult(0); + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var args = new SocketAsyncEventArgs(); + args.SetBuffer(buffer, offset, count); + CancellationTokenRegistration registration = ct.Register( + static state => ((TaskCompletionSource)state!).TrySetCanceled(), + tcs); + + args.Completed += (_, e) => + { + registration.Dispose(); + if (e.SocketError != SocketError.Success) { - m_logger.LogError(ex, "Unexpected error during OnReadComplete,"); - error = ServiceResult.Create(ex, StatusCodes.BadTcpInternalError, ex.Message); + tcs.TrySetException(new SocketException((int)e.SocketError)); } - finally + else { - e?.Dispose(); + tcs.TrySetResult(e.BytesTransferred); } - if (m_readState == ReadState.NotConnected && ServiceResult.IsGood(error)) - { - error = ServiceResult.Create( - StatusCodes.BadConnectionClosed, - "Remote side closed connection."); - } + e.Dispose(); + }; - if (ServiceResult.IsBad(error)) + try + { + if (!socket.ReceiveAsync(args)) { - if (m_receiveBuffer != null) + // Completed synchronously. + registration.Dispose(); + SocketError socketError = args.SocketError; + int bytesTransferred = args.BytesTransferred; + args.Dispose(); + + if (socketError != SocketError.Success) { - m_bufferManager.ReturnBuffer(m_receiveBuffer, "OnReadComplete"); - m_receiveBuffer = null; + return Task.FromException(new SocketException((int)socketError)); } - m_sink?.OnReceiveError(this, error); + return Task.FromResult(bytesTransferred); } } - } + catch (Exception ex) + { + registration.Dispose(); + args.Dispose(); + return Task.FromException(ex); + } + return tcs.Task; + } +#else /// - /// Handles a read complete event. + /// Single async receive call using the modern API. /// - private ServiceResult DoReadComplete(SocketAsyncEventArgs e) + private ValueTask ReceiveAsync(byte[] buffer, int offset, int count, CancellationToken ct) { - // complete operation. - int bytesRead = e.BytesTransferred; - m_readState = ReadState.Ready; - + Socket? socket; lock (m_socketLock) { - if (m_receiveBuffer != null) - { - BufferManager.UnlockBuffer(m_receiveBuffer); - } + socket = m_socket; } - if (bytesRead == 0) + if (socket == null || !socket.Connected) { - // Remote end has closed the connection + return new ValueTask(0); + } - // free the empty receive buffer. - if (m_receiveBuffer != null) - { - m_bufferManager.ReturnBuffer(m_receiveBuffer, "DoReadComplete"); - m_receiveBuffer = null; - } + return socket.ReceiveAsync(buffer.AsMemory(offset, count), SocketFlags.None, ct); + } +#endif + + /// + /// Changes the sink used to report reads. + /// + public void ChangeSink(IMessageSink sink) + { + m_sink = sink; + } - m_readState = ReadState.Error; - return ServiceResult.Create( - StatusCodes.BadConnectionClosed, - "Remote side closed connection"); + /// + public ValueTask SendAsync(ReadOnlyMemory buffer, CancellationToken ct = default) + { + if (m_socket == null) + { + throw new InvalidOperationException("The socket is not connected."); } - m_bytesReceived += bytesRead; + return SendAllAsync(buffer, ct); + } - // check if more data left to read. - if (m_bytesReceived < m_bytesToReceive) + /// + public ValueTask SendAsync(IList> buffers, CancellationToken ct = default) + { + if (m_socket == null) { - m_readState = ReadState.ReadNextBlock; - return ServiceResult.Good; + throw new InvalidOperationException("The socket is not connected."); } - // start reading the message body. - if (m_receiveBuffer != null) + return SendBufferListAsync(buffers, ct); + } + + private async ValueTask SendAllAsync(ReadOnlyMemory data, CancellationToken ct) + { + while (data.Length > 0) { - if (m_incomingMessageSize < 0) + int sent = await SendOnceAsync(data, ct).ConfigureAwait(false); + if (sent == 0) { - uint messageType = BitConverter.ToUInt32(m_receiveBuffer, 0); - if (!TcpMessageType.IsValid(messageType)) - { - m_readState = ReadState.Error; - - return ServiceResult.Create( - StatusCodes.BadTcpMessageTypeInvalid, - "Message type {0:X8} is invalid.", - messageType); - } - - m_incomingMessageSize = BitConverter.ToInt32(m_receiveBuffer, 4); - if (m_incomingMessageSize <= 0 || m_incomingMessageSize > m_receiveBufferSize) - { - m_readState = ReadState.Error; - - return ServiceResult.Create( - StatusCodes.BadTcpMessageTooLarge, - "Messages size {0} bytes is too large for buffer of size {1}.", - m_incomingMessageSize, - m_receiveBufferSize); - } - - // set up buffer for reading the message body. - m_bytesToReceive = m_incomingMessageSize; - - m_readState = ReadState.ReadNextBlock; - - return ServiceResult.Good; + throw new SocketException((int)SocketError.ConnectionReset); } - // notify the sink. - IMessageSink sink = m_sink; - if (sink != null) - { - try - { - // send notification (implementor responsible for freeing buffer) on success. - var messageChunk = new ArraySegment( - m_receiveBuffer, - 0, - m_incomingMessageSize); - - // must allocate a new buffer for the next message. - m_receiveBuffer = null; - - sink.OnMessageReceived(this, messageChunk); - } - catch (Exception ex) - { - m_logger.LogError(ex, "Unexpected error invoking OnMessageReceived callback."); - } - } + data = data.Slice(sent); } + } - // free the receive buffer. - if (m_receiveBuffer != null) + private async ValueTask SendBufferListAsync(IList> buffers, CancellationToken ct) + { + foreach (ArraySegment segment in buffers) { - m_bufferManager.ReturnBuffer(m_receiveBuffer, "DoReadComplete"); - m_receiveBuffer = null; + await SendAllAsync( + new ReadOnlyMemory(segment.Array, segment.Offset, segment.Count), + ct).ConfigureAwait(false); } - - // start receiving next message. - m_readState = ReadState.ReadNextMessage; - - return ServiceResult.Good; } - /// - /// Reads the next block of data from the socket. - /// - private void ReadNextBlock() +#if NETFRAMEWORK + private Task SendOnceAsync(ReadOnlyMemory data, CancellationToken ct) { Socket? socket; - - // check if already closed. lock (m_socketLock) { socket = m_socket; + } + + if (socket == null) + { + return Task.FromResult(0); + } + + if (!MemoryMarshal.TryGetArray(data, out ArraySegment segment)) + { + segment = new ArraySegment(data.ToArray()); + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var args = new SocketAsyncEventArgs(); + args.SetBuffer(segment.Array, segment.Offset, segment.Count); + CancellationTokenRegistration registration = ct.Register( + static state => ((TaskCompletionSource)state!).TrySetCanceled(), + tcs); - if (socket == null || !socket.Connected) + args.Completed += (_, e) => + { + registration.Dispose(); + if (e.SocketError != SocketError.Success) { - // buffer is returned in calling code - m_readState = ReadState.NotConnected; - return; + tcs.TrySetException(new SocketException((int)e.SocketError)); + } + else + { + tcs.TrySetResult(e.BytesTransferred); } - } - BufferManager.LockBuffer(m_receiveBuffer!); + e.Dispose(); + }; - SocketAsyncEventArgs? args = null; try { - args = new SocketAsyncEventArgs(); - m_readState = ReadState.Receive; - args.SetBuffer( - m_receiveBuffer, - m_bytesReceived, - m_bytesToReceive - m_bytesReceived); - args.Completed += m_readComplete; - if (!socket.ReceiveAsync(args)) + if (!socket.SendAsync(args)) { - // I/O completed synchronously - if (args.SocketError != SocketError.Success) + registration.Dispose(); + SocketError socketError = args.SocketError; + int bytesTransferred = args.BytesTransferred; + args.Dispose(); + + if (socketError != SocketError.Success) { - throw ServiceResultException.Create( - StatusCodes.BadTcpInternalError, - args.SocketError.ToString()); + return Task.FromException(new SocketException((int)socketError)); } - // set state to inner complete - m_readState = ReadState.ReadComplete; - m_readComplete(null, args); + + return Task.FromResult(bytesTransferred); } - args = null; // ownership transferred - } - catch (ServiceResultException) - { - BufferManager.UnlockBuffer(m_receiveBuffer!); - throw; } catch (Exception ex) { - BufferManager.UnlockBuffer(m_receiveBuffer!); - throw ServiceResultException.Create( - StatusCodes.BadTcpInternalError, - ex, - "BeginReceive failed."); + registration.Dispose(); + args.Dispose(); + return Task.FromException(ex); } - finally - { - args?.Dispose(); - } - } - /// - /// Helper to read next block or message based on current state. - /// - private bool ReadNext() - { - switch (m_readState) - { - case ReadState.ReadNextBlock: - ReadNextBlock(); - return true; - case ReadState.ReadNextMessage: - ReadNextMessage(); - return true; - case ReadState.Ready: - case ReadState.Receive: - case ReadState.ReadComplete: - case ReadState.NotConnected: - case ReadState.Error: - return false; - default: - Debug.Fail("Unexpected read state."); - return false; - } + return tcs.Task; } - - /// - /// Sends a buffer. - /// - /// - /// - public bool Send(IMessageSocketAsyncEventArgs args) +#else + private ValueTask SendOnceAsync(ReadOnlyMemory data, CancellationToken ct) { - if (args is not TcpMessageSocketAsyncEventArgs eventArgs) + Socket? socket; + lock (m_socketLock) { - throw new ArgumentNullException(nameof(args)); + socket = m_socket; } - if (m_socket == null) + + if (socket == null) { - throw new InvalidOperationException("The socket is not connected."); + return new ValueTask(0); } - eventArgs.Args.SocketError = SocketError.NotConnected; - return m_socket.SendAsync(eventArgs.Args); - } - /// - /// Create event args for TcpMessageSocket. - /// - public IMessageSocketAsyncEventArgs MessageSocketEventArgs() - { - return new TcpMessageSocketAsyncEventArgs(); + return socket.SendAsync(data, SocketFlags.None, ct); } +#endif private void ShutdownAndDispose(Socket socket) { @@ -725,31 +782,9 @@ private void ShutdownAndDispose(Socket socket) private readonly ILogger m_logger; private readonly BufferManager m_bufferManager; private readonly int m_receiveBufferSize; - private readonly EventHandler m_readComplete; private readonly Lock m_socketLock = new(); private Socket? m_socket; private bool m_closed; - - /// - /// States for the nested read handler. - /// - private enum ReadState - { - Ready = 0, - ReadNextMessage = 1, - ReadNextBlock = 2, - Receive = 3, - ReadComplete = 4, - NotConnected = 5, - Error = 0xff - } - - private readonly Lock m_readLock = new(); - private byte[]? m_receiveBuffer; - private int m_bytesReceived; - private int m_bytesToReceive; - private int m_incomingMessageSize; - private ReadState m_readState; } } diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/TcpServerChannel.cs b/Stack/Opc.Ua.Core/Stack/Tcp/TcpServerChannel.cs index 3d76874ad7..ad7fc46c98 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/TcpServerChannel.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/TcpServerChannel.cs @@ -192,7 +192,7 @@ private void OnReverseConnectComplete(object? sender, IMessageSocketAsyncEventAr try { // start reading messages. - ar.Socket!.ReadNextMessage(); + ar.Socket!.ReadNextMessageAsync(); // send reverse hello message. using var encoder = new BinaryEncoder( diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryChannel.cs b/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryChannel.cs index 31dc52dfc5..3f25020f13 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryChannel.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryChannel.cs @@ -541,124 +541,105 @@ protected virtual void HandleSocketError(ServiceResult result) } /// - /// Handles a write complete event. + /// Queues a write request for a single contiguous buffer. + /// The buffer is returned to after the send completes. /// - protected virtual void OnWriteComplete(object? sender, IMessageSocketAsyncEventArgs e) + /// + protected void BeginWriteMessage(ArraySegment buffer, object? state) { - ServiceResult error = ServiceResult.Good; + IMessageSocket socket = + Socket ?? throw ServiceResultException.Create( + StatusCodes.BadConnectionClosed, + "The socket was closed by the remote application."); + + Interlocked.Increment(ref m_activeWriteRequests); + byte[] bufferArray = buffer.GetArray(); + var data = new ReadOnlyMemory(bufferArray, buffer.Offset, buffer.Count); + + ValueTask task; try { - if (e.BytesTransferred == 0) - { - error = ServiceResult.Create( - StatusCodes.BadConnectionClosed, - "The socket was closed by the remote application."); - } - if (e.Buffer != null) - { - BufferManager.ReturnBuffer(e.Buffer, "OnWriteComplete"); - } - HandleWriteComplete(e.BufferList, e.UserToken, e.BytesTransferred, error); + task = socket.SendAsync(data); } catch (Exception ex) { - if (ex is InvalidOperationException) - { - // suppress chained exception in HandleWriteComplete/ReturnBuffer - e.BufferList = null; - } - error = ServiceResult.Create( - ex, - StatusCodes.BadTcpInternalError, - "Unexpected error during write operation."); - HandleWriteComplete(e.BufferList, e.UserToken, e.BytesTransferred, error); + BufferManager.ReturnBuffer(bufferArray, "BeginWriteMessage"); + HandleWriteComplete( + null, + state, + 0, + ServiceResult.Create( + ex, + StatusCodes.BadTcpInternalError, + "Unexpected error during write operation.")); + return; } - e.Dispose(); + if (task.IsCompletedSuccessfully) + { + BufferManager.ReturnBuffer(bufferArray, "BeginWriteMessage"); + HandleWriteComplete(null, state, buffer.Count, ServiceResult.Good); + return; + } + + _ = CompleteWriteAsync(task, bufferArray, null, state, buffer.Count); } /// - /// Queues a write request. + /// Queues a write request for a collection of buffers. + /// The buffers are released via after the send completes. /// - /// - protected void BeginWriteMessage(ArraySegment buffer, object? state) + protected void BeginWriteMessage(BufferCollection buffers, object? state) { - ServiceResult error = ServiceResult.Good; - IMessageSocketAsyncEventArgs args = - (Socket?.MessageSocketEventArgs()) - ?? throw ServiceResultException.Create( + IMessageSocket socket = + Socket ?? throw ServiceResultException.Create( StatusCodes.BadConnectionClosed, "The socket was closed by the remote application."); + Interlocked.Increment(ref m_activeWriteRequests); + + ValueTask task; try { - Interlocked.Increment(ref m_activeWriteRequests); - args.SetBuffer(buffer.GetArray(), buffer.Offset, buffer.Count); - args.Completed += OnWriteComplete; - args.UserToken = state; - if (!Socket.Send(args)) - { - // I/O completed synchronously - if (args.IsSocketError || (args.BytesTransferred < buffer.Count)) - { - error = ServiceResult.Create( - StatusCodes.BadConnectionClosed, - args.SocketErrorString); - HandleWriteComplete(null, state, args.BytesTransferred, error); - args.Dispose(); - } - else - { - // success, call Complete - OnWriteComplete(null, args); - } - } + task = socket.SendAsync(buffers); } catch (Exception ex) { - error = ServiceResult.Create( - ex, - StatusCodes.BadTcpInternalError, - "Unexpected error during write operation."); + HandleWriteComplete( + buffers, + state, + 0, + ServiceResult.Create( + ex, + StatusCodes.BadTcpInternalError, + "Unexpected error during write operation.")); + return; + } - HandleWriteComplete(null, state, args.BytesTransferred, error); - args.Dispose(); + if (task.IsCompletedSuccessfully) + { + HandleWriteComplete(buffers, state, buffers.TotalSize, ServiceResult.Good); + return; } + + _ = CompleteWriteAsync(task, null, buffers, state, buffers.TotalSize); } /// - /// Queues a write request. + /// Awaits a pending write and then calls + /// with the outcome. /// - protected void BeginWriteMessage(BufferCollection buffers, object? state) + private async Task CompleteWriteAsync( + ValueTask task, + byte[]? bufferToReturn, + BufferCollection? buffers, + object? state, + int expectedBytes) { ServiceResult error = ServiceResult.Good; - IMessageSocketAsyncEventArgs args = Socket!.MessageSocketEventArgs(); - try { - // m_logger.LogWarning("OUT:{Id}", TcpMessageType.GetTypeAndSize(buffers[0])); - - Interlocked.Increment(ref m_activeWriteRequests); - args.BufferList = buffers; - args.Completed += OnWriteComplete; - args.UserToken = state; - IMessageSocket? socket = Socket; - if (socket == null || !socket.Send(args)) - { - // I/O completed synchronously - if (args.IsSocketError || (args.BytesTransferred < buffers.TotalSize)) - { - error = ServiceResult.Create( - StatusCodes.BadConnectionClosed, - args.SocketErrorString); - HandleWriteComplete(buffers, state, args.BytesTransferred, error); - args.Dispose(); - } - else - { - OnWriteComplete(null, args); - } - } + await task.ConfigureAwait(false); } catch (Exception ex) { @@ -666,9 +647,14 @@ protected void BeginWriteMessage(BufferCollection buffers, object? state) ex, StatusCodes.BadTcpInternalError, "Unexpected error during write operation."); - HandleWriteComplete(buffers, state, args.BytesTransferred, error); - args.Dispose(); } + + if (bufferToReturn != null) + { + BufferManager.ReturnBuffer(bufferToReturn, "CompleteWriteAsync"); + } + + HandleWriteComplete(buffers, state, ServiceResult.IsGood(error) ? expectedBytes : 0, error); } /// diff --git a/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryClientChannel.cs b/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryClientChannel.cs index d538aff9b0..a73647d6be 100644 --- a/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryClientChannel.cs +++ b/Stack/Opc.Ua.Core/Stack/Tcp/UaSCBinaryClientChannel.cs @@ -225,7 +225,7 @@ public async ValueTask ConnectAsync(Uri url, int timeout, CancellationToken ct) CompleteConnect(operation); } - await operation.EndAsync(int.MaxValue, ct: ct).ConfigureAwait(false); + await operation.EndValueTaskAsync(int.MaxValue, ct: ct).ConfigureAwait(false); SendQueuedOperations(); } @@ -264,7 +264,7 @@ public async Task CloseAsync(int timeout, CancellationToken ct = default) { try { - _ = await operation.EndAsync(timeout, false, ct).ConfigureAwait(false); + _ = await operation.EndValueTaskAsync(timeout, false, ct).ConfigureAwait(false); ValidateChannelCloseError(operation.Error); } catch (Exception e) @@ -334,7 +334,7 @@ public async ValueTask SendRequestAsync( } try { - await operation.EndAsync(int.MaxValue, true, ct).ConfigureAwait(false); + await operation.EndValueTaskAsync(int.MaxValue, true, ct).ConfigureAwait(false); } finally { @@ -948,7 +948,7 @@ private void CompleteConnect(WriteOperation operation) } // start reading messages. - Socket.ReadNextMessage(); + Socket.ReadNextMessageAsync(); // send the hello message. SendHelloMessage(operation); @@ -1086,7 +1086,7 @@ private async void OnScheduledHandshakeAsync(object? state) CompleteConnect(operation); // Complete handshake - await operation.EndAsync(int.MaxValue).ConfigureAwait(false); + await operation.EndValueTaskAsync(int.MaxValue).ConfigureAwait(false); SendQueuedOperations(); } diff --git a/Stack/Opc.Ua.Core/Stack/Transport/IMessageSocket.cs b/Stack/Opc.Ua.Core/Stack/Transport/IMessageSocket.cs index 256f88b3a0..9b30174964 100644 --- a/Stack/Opc.Ua.Core/Stack/Transport/IMessageSocket.cs +++ b/Stack/Opc.Ua.Core/Stack/Transport/IMessageSocket.cs @@ -28,6 +28,7 @@ * ======================================================================*/ using System; +using System.Collections.Generic; using System.Linq; using System.Net; using System.Net.Sockets; @@ -191,9 +192,13 @@ public interface IMessageSocket : IDisposable void Close(); /// - /// Starts reading messages from the socket. + /// Starts the async read loop that delivers complete message chunks via + /// until the connection is + /// closed or an error occurs. The returned completes + /// when the loop terminates; errors are reported through + /// before the task completes. /// - void ReadNextMessage(); + Task ReadNextMessageAsync(CancellationToken ct = default); /// /// Changes the sink used to report reads. @@ -201,14 +206,16 @@ public interface IMessageSocket : IDisposable void ChangeSink(IMessageSink sink); /// - /// Sends a buffer. + /// Sends a single contiguous buffer asynchronously. + /// The caller retains ownership of the underlying array; the implementation + /// must not access it after the returned completes. /// - bool Send(IMessageSocketAsyncEventArgs args); + ValueTask SendAsync(ReadOnlyMemory buffer, CancellationToken ct = default); /// - /// Get the message socket event args. + /// Sends a list of buffers as a logical gather-write asynchronously. /// - IMessageSocketAsyncEventArgs MessageSocketEventArgs(); + ValueTask SendAsync(IList> buffers, CancellationToken ct = default); } /// diff --git a/Tests/Opc.Ua.Core.Tests/Stack/Transport/ChannelAsyncOperationTests.cs b/Tests/Opc.Ua.Core.Tests/Stack/Transport/ChannelAsyncOperationTests.cs new file mode 100644 index 0000000000..c3615e0ec8 --- /dev/null +++ b/Tests/Opc.Ua.Core.Tests/Stack/Transport/ChannelAsyncOperationTests.cs @@ -0,0 +1,102 @@ +/* ======================================================================== + * Copyright (c) 2005-2025 The OPC Foundation, Inc. All rights reserved. + * + * OPC Foundation MIT License 1.00 + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + * + * The complete license agreement can be found here: + * http://opcfoundation.org/License/MIT/1.00/ + * ======================================================================*/ + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; +using Opc.Ua.Bindings; + +namespace Opc.Ua.Core.Tests.Stack.Transport +{ + [TestFixture] + [Category("TransportChannelTests")] + [SetCulture("en-us")] + [SetUICulture("en-us")] + [Parallelizable] + public sealed class ChannelAsyncOperationTests + { + [Test] + public async Task EndAsyncReturnsCompletedResponse() + { + using var operation = new ChannelAsyncOperation( + int.MaxValue, + null, + null, + NullLogger.Instance); + + Task completeTask = Task.Run(async () => + { + await Task.Delay(50).ConfigureAwait(false); + operation.Complete(123); + }); + + int result = await operation.EndAsync(int.MaxValue).ConfigureAwait(false); + await completeTask.ConfigureAwait(false); + + Assert.That(result, Is.EqualTo(123)); + Assert.That(operation.Error.StatusCode, Is.EqualTo(StatusCodes.Good)); + } + + [Test] + public void EndAsyncThrowsBadRequestInterruptedWhenCanceled() + { + using var operation = new ChannelAsyncOperation( + int.MaxValue, + null, + null, + NullLogger.Instance); + using var cancellationTokenSource = new CancellationTokenSource(); + + Task task = operation.EndAsync(int.MaxValue, ct: cancellationTokenSource.Token); + cancellationTokenSource.Cancel(); + + ServiceResultException exception = + Assert.ThrowsAsync(() => task); + + Assert.That(exception.StatusCode, Is.EqualTo(StatusCodes.BadRequestInterrupted)); + } + + [Test] + public void EndAsyncThrowsBadRequestTimeoutWhenOperationTimesOut() + { + using var operation = new ChannelAsyncOperation( + 50, + null, + null, + NullLogger.Instance); + + ServiceResultException exception = + Assert.ThrowsAsync(() => operation.EndAsync(int.MaxValue)); + + Assert.That(exception.StatusCode, Is.EqualTo(StatusCodes.BadRequestTimeout)); + } + } +}