Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 129 additions & 44 deletions Stack/Opc.Ua.Core/Stack/Tcp/ChannelAsyncOperation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
using Microsoft.Extensions.Logging;

namespace Opc.Ua.Bindings
Expand All @@ -39,7 +40,7 @@ namespace Opc.Ua.Bindings
/// Stores the results of an asynchronous operation.
/// </summary>
/// <typeparam name="T"></typeparam>
public class ChannelAsyncOperation<T> : IAsyncResult, IDisposable
public class ChannelAsyncOperation<T> : IAsyncResult, IDisposable, IValueTaskSource<bool>, IValueTaskSource
{
/// <summary>
/// Initializes the object with a callback
Expand All @@ -51,10 +52,13 @@ public ChannelAsyncOperation(int timeout, AsyncCallback? callback, object? async
m_synchronous = false;
m_completed = false;
m_logger = logger;
m_asyncWaitSource.RunContinuationsAsynchronously = true;

if (timeout is > 0 and not int.MaxValue)
{
m_timer = new Timer(new TimerCallback(OnTimeout), null, timeout, Timeout.Infinite);
m_timeoutCancellationTokenSource = new CancellationTokenSource(timeout);
m_timeoutCancellationRegistration =
m_timeoutCancellationTokenSource.Token.Register(OnTimeout);
}
}

Expand All @@ -76,8 +80,9 @@ protected virtual void Dispose(bool disposing)
{
lock (m_lock)
{
m_timer?.Dispose();
m_timer = null;
m_timeoutCancellationRegistration.Dispose();
m_timeoutCancellationTokenSource?.Dispose();
m_timeoutCancellationTokenSource = null;

if (m_event != null)
{
Expand All @@ -86,13 +91,11 @@ protected virtual void Dispose(bool disposing)
m_event = null;
}

if (m_tcs != null)
if (m_asyncWaitPending)
{
if (!m_tcs.Task.IsCompleted)
{
m_tcs.TrySetCanceled();
}
m_tcs = null;
m_asyncWaitSource.SetException(
new TaskCanceledException("ChannelAsyncOperation was disposed while an async wait was pending."));
m_asyncWaitPending = false;
}
}
}
Expand Down Expand Up @@ -223,22 +226,34 @@ public T End(int timeout, bool throwOnError = true)
/// The awaitable response returned from the server.
/// </summary>
/// <exception cref="ServiceResultException"></exception>
public async Task<T> EndAsync(
public Task<T> EndAsync(
int timeout,
bool throwOnError = true,
CancellationToken ct = default)
{
return EndValueTaskAsync(timeout, throwOnError, ct).AsTask();
}

/// <summary>
/// The low-allocation awaitable response returned from the server.
/// </summary>
internal async ValueTask<T> EndValueTaskAsync(
int timeout,
bool throwOnError = true,
CancellationToken ct = default)
{
// check if the request has already completed.
bool mustWait = false;
ValueTask<bool> waitTask = default;

lock (m_lock)
{
mustWait = !m_completed;

if (mustWait)
{
m_tcs = new TaskCompletionSource<bool>(
TaskCreationOptions.RunContinuationsAsynchronously);
m_asyncWaitPending = true;
waitTask = new ValueTask<bool>(this, m_asyncWaitSource.Version);
}
}

Expand All @@ -248,19 +263,13 @@ public async Task<T> EndAsync(
bool badRequestInterrupted = false;
try
{
Task<bool> awaitableTask = m_tcs!.Task;
if (timeout != int.MaxValue)
if (timeout != int.MaxValue || ct.CanBeCanceled)
{
awaitableTask = m_tcs.Task
.WaitAsync(TimeSpan.FromMilliseconds(timeout), ct);
await WaitAsync(waitTask, timeout, ct).ConfigureAwait(false);
}
else if (ct != default)
else
{
awaitableTask = m_tcs.Task.WaitAsync(ct);
}
if (!await awaitableTask.ConfigureAwait(false))
{
badRequestInterrupted = true;
_ = await waitTask.ConfigureAwait(false);
}
}
catch (TimeoutException)
Expand All @@ -271,12 +280,9 @@ public async Task<T> EndAsync(
{
badRequestInterrupted = true;
}
finally
catch (OperationCanceledException)
{
lock (m_lock)
{
m_tcs = null;
}
badRequestInterrupted = true;
}

if (badRequestInterrupted && throwOnError)
Expand Down Expand Up @@ -368,17 +374,6 @@ public bool IsCompleted
}
}

/// <summary>
/// Called when the operation times out.
/// </summary>
private void OnTimeout(object? state)
{
if (m_timer != null)
{
InternalComplete(false, new ServiceResult(StatusCodes.BadRequestTimeout));
}
}

/// <summary>
/// Called when an asynchronous operation completes.
/// </summary>
Expand All @@ -403,12 +398,17 @@ protected virtual bool InternalComplete(bool doNotBlock, object? result)

m_completed = true;

m_timer?.Dispose();
m_timer = null;
m_timeoutCancellationRegistration.Dispose();
m_timeoutCancellationTokenSource?.Dispose();
m_timeoutCancellationTokenSource = null;

m_event?.Set();

m_tcs?.TrySetResult(true);
if (m_asyncWaitPending)
{
m_asyncWaitSource.SetResult(true);
m_asyncWaitPending = false;
}
}

AsyncCallback? callback = m_callback;
Expand Down Expand Up @@ -436,17 +436,102 @@ protected virtual bool InternalComplete(bool doNotBlock, object? result)
return true;
}

ValueTaskSourceStatus IValueTaskSource<bool>.GetStatus(short token)
{
return m_asyncWaitSource.GetStatus(token);
}

ValueTaskSourceStatus IValueTaskSource.GetStatus(short token)
{
return m_asyncWaitSource.GetStatus(token);
}

bool IValueTaskSource<bool>.GetResult(short token)
{
return m_asyncWaitSource.GetResult(token);
}

void IValueTaskSource.GetResult(short token)
{
((IValueTaskSource)m_asyncWaitSource).GetResult(token);
}

void IValueTaskSource<bool>.OnCompleted(
Action<object?> continuation,
object? state,
short token,
ValueTaskSourceOnCompletedFlags flags)
{
m_asyncWaitSource.OnCompleted(continuation, state, token, flags);
}

void IValueTaskSource.OnCompleted(
Action<object?> continuation,
object? state,
short token,
ValueTaskSourceOnCompletedFlags flags)
{
m_asyncWaitSource.OnCompleted(continuation, state, token, flags);
}

private void OnTimeout()
{
if (m_timeoutCancellationTokenSource != null)
{
InternalComplete(false, new ServiceResult(StatusCodes.BadRequestTimeout));
}
}

private static async Task WaitAsync(ValueTask<bool> waitTask, int timeout, CancellationToken ct)
{
Task<bool> task = waitTask.AsTask();
using CancellationTokenSource? timeoutCancellationTokenSource =
timeout != int.MaxValue ? new CancellationTokenSource(timeout) : null;
using CancellationTokenSource? linkedCancellationTokenSource =
timeoutCancellationTokenSource != null && ct.CanBeCanceled ?
CancellationTokenSource.CreateLinkedTokenSource(ct, timeoutCancellationTokenSource.Token) :
null;

CancellationToken effectiveCancellationToken =
linkedCancellationTokenSource?.Token ??
timeoutCancellationTokenSource?.Token ??
ct;

if (!effectiveCancellationToken.CanBeCanceled)
{
_ = await task.ConfigureAwait(false);
return;
}

#if NET6_0_OR_GREATER
_ = await task.WaitAsync(effectiveCancellationToken).ConfigureAwait(false);
#else
Task completedTask = await Task.WhenAny(
task,
Task.Delay(Timeout.Infinite, effectiveCancellationToken)).ConfigureAwait(false);

if (!ReferenceEquals(completedTask, task))
{
effectiveCancellationToken.ThrowIfCancellationRequested();
}

_ = await task.ConfigureAwait(false);
#endif
}

private readonly Lock m_lock = new();
private readonly AsyncCallback? m_callback;
private readonly object? m_asyncState;
private readonly bool m_synchronous;
private readonly ILogger m_logger;
private readonly ManualResetValueTaskSource<bool> m_asyncWaitSource = new();
private bool m_completed;
private bool m_asyncWaitPending;
private ManualResetEvent? m_event;
private TaskCompletionSource<bool>? m_tcs;
private T? m_response;
private ServiceResult? m_error;
private Timer? m_timer;
private CancellationTokenSource? m_timeoutCancellationTokenSource;
private CancellationTokenRegistration m_timeoutCancellationRegistration;
private Dictionary<string, object>? m_properties;
}
}
2 changes: 1 addition & 1 deletion Stack/Opc.Ua.Core/Stack/Tcp/TcpListenerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public void Attach(uint channelId, Socket socket)
Socket.Handle,
ChannelId);

Socket.ReadNextMessage();
Socket.ReadNextMessageAsync();
}
}

Expand Down
Loading
Loading