Skip to content

[dotnet] [bidi] Revisit some core functionality to deserialize without intermediate JsonElement allocation #15575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 129 additions & 36 deletions dotnet/src/webdriver/BiDi/Communication/Broker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,20 @@

namespace OpenQA.Selenium.BiDi.Communication;

public class Broker : IAsyncDisposable
public sealed class Broker : IAsyncDisposable
{
private readonly ILogger _logger = Log.GetLogger<Broker>();

private readonly BiDi _bidi;
private readonly ITransport _transport;

private readonly ConcurrentDictionary<int, TaskCompletionSource<JsonElement>> _pendingCommands = new();
private readonly ConcurrentDictionary<long, CommandInfo> _pendingCommands = new();
private readonly BlockingCollection<MessageEvent> _pendingEvents = [];
private readonly Dictionary<string, Type> _eventTypesMap = [];

private readonly ConcurrentDictionary<string, List<EventHandler>> _eventHandlers = new();

private int _currentCommandId;
private long _currentCommandId;

private static readonly TaskFactory _myTaskFactory = new(CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskContinuationOptions.None, TaskScheduler.Default);

Expand Down Expand Up @@ -89,7 +90,6 @@ internal Broker(BiDi bidi, Uri url)
new JsonStringEnumConverter(JsonNamingPolicy.CamelCase),

// https://github.com/dotnet/runtime/issues/72604
new Json.Converters.Polymorphic.MessageConverter(),
new Json.Converters.Polymorphic.EvaluateResultConverter(),
new Json.Converters.Polymorphic.RemoteValueConverter(),
new Json.Converters.Polymorphic.RealmInfoConverter(),
Expand Down Expand Up @@ -122,23 +122,18 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
var data = await _transport.ReceiveAsync(cancellationToken).ConfigureAwait(false);

var message = JsonSerializer.Deserialize(new ReadOnlySpan<byte>(data), _jsonSerializerContext.Message);
try
{
var data = await _transport.ReceiveAsync(cancellationToken).ConfigureAwait(false);

switch (message)
ProcessReceivedMessage(data);
}
catch (Exception ex)
{
case MessageSuccess messageSuccess:
_pendingCommands[messageSuccess.Id].SetResult(messageSuccess.Result);
_pendingCommands.TryRemove(messageSuccess.Id, out _);
break;
case MessageEvent messageEvent:
_pendingEvents.Add(messageEvent);
break;
case MessageError mesageError:
_pendingCommands[mesageError.Id].SetException(new BiDiException($"{mesageError.Error}: {mesageError.Message}"));
_pendingCommands.TryRemove(mesageError.Id, out _);
break;
if (cancellationToken.IsCancellationRequested is not true && _logger.IsEnabled(LogEventLevel.Error))
{
_logger.Error($"Couldn't process received BiDi remote message: {ex}");
}
}
}
}
Expand All @@ -155,7 +150,7 @@ private async Task ProcessEventsAwaiterAsync()
{
foreach (var handler in eventHandlers.ToArray()) // copy handlers avoiding modified collection while iterating
{
var args = (EventArgs)result.Params.Deserialize(handler.EventArgsType, _jsonSerializerContext)!;
var args = result.Params;

args.BiDi = _bidi;

Expand All @@ -177,40 +172,41 @@ private async Task ProcessEventsAwaiterAsync()
{
if (_logger.IsEnabled(LogEventLevel.Error))
{
_logger.Error($"Unhandled error processing BiDi event: {ex}");
_logger.Error($"Unhandled error processing BiDi event handler: {ex}");
}
}
}
}

public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand command, CommandOptions? options)
public async Task ExecuteCommandAsync<TCommand>(TCommand command, CommandOptions? options)
where TCommand : Command
{
var jsonElement = await ExecuteCommandCoreAsync(command, options).ConfigureAwait(false);

return (TResult)jsonElement.Deserialize(typeof(TResult), _jsonSerializerContext)!;
await ExecuteCommandCoreAsync(command, options).ConfigureAwait(false);
}

public async Task ExecuteCommandAsync<TCommand>(TCommand command, CommandOptions? options)
public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand command, CommandOptions? options)
where TCommand : Command
where TResult : EmptyResult
{
await ExecuteCommandCoreAsync(command, options).ConfigureAwait(false);
var result = await ExecuteCommandCoreAsync(command, options).ConfigureAwait(false);

return (TResult)result;
}

private async Task<JsonElement> ExecuteCommandCoreAsync<TCommand>(TCommand command, CommandOptions? options)
private async Task<EmptyResult> ExecuteCommandCoreAsync<TCommand>(TCommand command, CommandOptions? options)
where TCommand : Command
{
command.Id = Interlocked.Increment(ref _currentCommandId);

var tcs = new TaskCompletionSource<JsonElement>(TaskCreationOptions.RunContinuationsAsynchronously);
var tcs = new TaskCompletionSource<EmptyResult>(TaskCreationOptions.RunContinuationsAsynchronously);

var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30);

using var cts = new CancellationTokenSource(timeout);

cts.Token.Register(() => tcs.TrySetCanceled(cts.Token));

_pendingCommands[command.Id] = tcs;
_pendingCommands[command.Id] = new(command.Id, command.ResultType, tcs);

var data = JsonSerializer.SerializeToUtf8Bytes(command, typeof(TCommand), _jsonSerializerContext);

Expand All @@ -222,6 +218,8 @@ private async Task<JsonElement> ExecuteCommandCoreAsync<TCommand>(TCommand comma
public async Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, Action<TEventArgs> action, SubscriptionOptions? options = null)
where TEventArgs : EventArgs
{
_eventTypesMap[eventName] = typeof(TEventArgs);

var handlers = _eventHandlers.GetOrAdd(eventName, (a) => []);

if (options is BrowsingContextsSubscriptionOptions browsingContextsOptions)
Expand Down Expand Up @@ -249,6 +247,8 @@ public async Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, Act
public async Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, Func<TEventArgs, Task> func, SubscriptionOptions? options = null)
where TEventArgs : EventArgs
{
_eventTypesMap[eventName] = typeof(TEventArgs);

var handlers = _eventHandlers.GetOrAdd(eventName, (a) => []);

if (options is BrowsingContextsSubscriptionOptions browsingContextsOptions)
Expand Down Expand Up @@ -303,12 +303,6 @@ public async Task UnsubscribeAsync(Modules.Session.Subscription subscription, Ev
}

public async ValueTask DisposeAsync()
{
await DisposeAsyncCore();
GC.SuppressFinalize(this);
}

protected virtual async ValueTask DisposeAsyncCore()
{
_pendingEvents.CompleteAdding();

Expand All @@ -320,5 +314,104 @@ protected virtual async ValueTask DisposeAsyncCore()
}

_transport.Dispose();

GC.SuppressFinalize(this);
}

private void ProcessReceivedMessage(byte[]? data)
{
long? id = default;
string? type = default;
string? method = default;
string? error = default;
string? message = default;
Utf8JsonReader resultReader = default;
Utf8JsonReader paramsReader = default;

Utf8JsonReader reader = new(new ReadOnlySpan<byte>(data));
reader.Read();

reader.Read(); // "{"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug.Assert(reader.TokenType == JsonTokenType.StartObject);, to protect against future refactorings?

Or better yet, maybe we should throw in this case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving fast for now, revisit later as not important for end users.


while (reader.TokenType == JsonTokenType.PropertyName)
{
string? propertyName = reader.GetString();
reader.Read();

switch (propertyName)
{
case "id":
id = reader.GetInt64();
break;

case "type":
type = reader.GetString();
break;

case "method":
method = reader.GetString();
break;

case "result":
resultReader = reader; // cloning reader with current position
break;

case "params":
paramsReader = reader; // cloning reader with current position
break;

case "error":
error = reader.GetString();
break;

case "message":
message = reader.GetString();
break;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default: throw new BiDiException($"Unexpected BiDi response: {Encoding.UTF8.GetString(data)}")?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, we will introduce handling of "unknown polymorphic types"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please post an issue to not forget, for now just moving fast to introduce good pattern for BiDi namespace


reader.Skip();
reader.Read();
}

switch (type)
{
case "success":
if (id is null) throw new JsonException("The remote end responded with 'success' message type, but missed required 'id' property.");

var successCommand = _pendingCommands[id.Value];
var messageSuccess = JsonSerializer.Deserialize(ref resultReader, successCommand.ResultType, _jsonSerializerContext)!;
successCommand.TaskCompletionSource.SetResult((EmptyResult)messageSuccess);
_pendingCommands.TryRemove(id.Value, out _);
break;

case "event":
if (method is null) throw new JsonException("The remote end responded with 'event' message type, but missed required 'method' property.");

var eventType = _eventTypesMap[method];

var eventArgs = (EventArgs)JsonSerializer.Deserialize(ref paramsReader, eventType, _jsonSerializerContext)!;

var messageEvent = new MessageEvent(method, eventArgs);
_pendingEvents.Add(messageEvent);
break;

case "error":
if (id is null) throw new JsonException("The remote end responded with 'error' message type, but missed required 'id' property.");

var messageError = new MessageError(id.Value) { Error = error, Message = message };
var errorCommand = _pendingCommands[messageError.Id];
errorCommand.TaskCompletionSource.SetException(new BiDiException($"{messageError.Error}: {messageError.Message}"));
_pendingCommands.TryRemove(messageError.Id, out _);
break;
}
}

class CommandInfo(long id, Type resultType, TaskCompletionSource<EmptyResult> taskCompletionSource)
{
public long Id { get; } = id;

public Type ResultType { get; } = resultType;

public TaskCompletionSource<EmptyResult> TaskCompletionSource { get; } = taskCompletionSource;
};
}
14 changes: 11 additions & 3 deletions dotnet/src/webdriver/BiDi/Communication/Command.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,32 @@
// under the License.
// </copyright>

using System;
using System.Text.Json.Serialization;

namespace OpenQA.Selenium.BiDi.Communication;

public abstract class Command
{
protected Command(string method)
protected Command(string method, Type resultType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR: what do you think of renaming this type BiDiCommand? It is more clear and easier for namespacing (we already have a Command type)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I don't like any BiDi* type in BiDi namespace. Namespaces are especially created to resolve collisions.

{
Method = method;
ResultType = resultType;
}

[JsonPropertyOrder(1)]
public string Method { get; }

[JsonPropertyOrder(0)]
public int Id { get; internal set; }
public long Id { get; internal set; }

[JsonIgnore]
public Type ResultType { get; }
}

internal abstract class Command<TCommandParameters>(TCommandParameters @params, string method) : Command(method)
internal abstract class Command<TCommandParameters, TCommandResult>(TCommandParameters @params, string method) : Command(method, typeof(TCommandResult))
where TCommandParameters : CommandParameters
where TCommandResult : EmptyResult
{
[JsonPropertyOrder(2)]
public TCommandParameters Params { get; } = @params;
Expand All @@ -46,3 +52,5 @@ internal record CommandParameters
{
public static CommandParameters Empty { get; } = new CommandParameters();
}

public record EmptyResult;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe BaseResult or BiDiResult? Since it is derived by results with values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #15593

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answered there, I prefer just Result

Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
namespace OpenQA.Selenium.BiDi.Communication.Json;

#region https://github.com/dotnet/runtime/issues/72604
[JsonSerializable(typeof(MessageSuccess))]
[JsonSerializable(typeof(MessageError))]
[JsonSerializable(typeof(MessageEvent))]

[JsonSerializable(typeof(Modules.Script.EvaluateResultSuccess))]
[JsonSerializable(typeof(Modules.Script.EvaluateResultException))]

Expand Down Expand Up @@ -71,7 +67,7 @@ namespace OpenQA.Selenium.BiDi.Communication.Json;
#endregion

[JsonSerializable(typeof(Command))]
[JsonSerializable(typeof(Message))]
[JsonSerializable(typeof(EmptyResult))]

[JsonSerializable(typeof(Modules.Session.StatusCommand))]
[JsonSerializable(typeof(Modules.Session.StatusResult))]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ public static string GetDiscriminator(this ref Utf8JsonReader reader, string nam
if (propertyName == name)
{
discriminator = readerClone.GetString();

break;
}

readerClone.Skip();
readerClone.Read();
}

return discriminator ?? throw new JsonException($"Couldn't determine '{name}' descriminator.");
return discriminator ?? throw new JsonException($"Couldn't determine '{name}' discriminator.");
}
}
Loading
Loading