diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index a33754220..936b73640 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -140,6 +140,8 @@ public static class Headers public const string AsrsMessageTracingId = AsrsInternalHeaderPrefix + "Message-Tracing-Id"; public const string MicrosoftErrorCode = "x-ms-error-code"; + + public const string AsrsManagementSDKClientInvocationProtocol = AsrsInternalHeaderPrefix + "Protocol"; } public static class ErrorCodes diff --git a/src/Microsoft.Azure.SignalR.Common/Exceptions/AzureSignalRInaccessibleEndpointException.cs b/src/Microsoft.Azure.SignalR.Common/Exceptions/AzureSignalRInaccessibleEndpointException.cs index f5f66bb47..e351b7994 100644 --- a/src/Microsoft.Azure.SignalR.Common/Exceptions/AzureSignalRInaccessibleEndpointException.cs +++ b/src/Microsoft.Azure.SignalR.Common/Exceptions/AzureSignalRInaccessibleEndpointException.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -12,7 +12,6 @@ public class AzureSignalRInaccessibleEndpointException : AzureSignalRException private const string ErrorPhenomenon = "Unable to access SignalR service."; private const string SuggestAction = "Please make sure the endpoint or DNS setting is correct."; - public AzureSignalRInaccessibleEndpointException(string requestUri, Exception innerException) : base(string.IsNullOrEmpty(requestUri) ? $"{ErrorPhenomenon} {innerException.Message} {SuggestAction}" : $"{ErrorPhenomenon} {innerException.Message} {SuggestAction} Request Uri: {requestUri}", innerException) { } @@ -24,4 +23,4 @@ protected AzureSignalRInaccessibleEndpointException(SerializationInfo info, Stre { } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/BinaryPayloadContentBuilder.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/BinaryPayloadContentBuilder.cs index 644364e90..f9fa21d3c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/BinaryPayloadContentBuilder.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/BinaryPayloadContentBuilder.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Net.Http; - using Microsoft.AspNetCore.SignalR.Protocol; #nullable enable diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/IPayloadContentBuilder.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/IPayloadContentBuilder.cs index ec8b7c22e..5f0118df5 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/IPayloadContentBuilder.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/IPayloadContentBuilder.cs @@ -3,7 +3,6 @@ using System; using System.Net.Http; - using Microsoft.AspNetCore.SignalR.Protocol; #nullable enable diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/JsonPayloadContentBuilder.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/JsonPayloadContentBuilder.cs index de1a33585..08e658c30 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/JsonPayloadContentBuilder.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/JsonPayloadContentBuilder.cs @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer) { return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint); } + + public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer; } diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs new file mode 100644 index 000000000..ca4532233 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR; + +internal sealed class SimpleInvocationBinder : IInvocationBinder +{ + private readonly Type _returnType; + + public SimpleInvocationBinder(Type returnType) + { + _returnType = returnType ?? throw new ArgumentNullException(nameof(returnType)); + } + + public Type GetReturnType(string invocationId) + { + return _returnType; + } + + public IReadOnlyList GetParameterTypes(string methodName) + { + throw new NotImplementedException(); + } + + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } +} diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs index 866aa34a4..639137569 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs @@ -55,7 +55,7 @@ public Task SendAsync( Func>? handleExpectedResponseAsync, CancellationToken cancellationToken = default) { - return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken); + return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken); } public Task SendWithRetryAsync( @@ -73,7 +73,7 @@ public Task SendWithRetryAsync( Func>? handleExpectedResponseAsync = null, CancellationToken cancellationToken = default) { - return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken); + return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken); } public Task SendMessageWithRetryAsync( @@ -81,10 +81,22 @@ public Task SendMessageWithRetryAsync( HttpMethod httpMethod, string methodName, object?[] args, - Func? handleExpectedResponse = null, + Func>? handleExpectedResponse = null, CancellationToken cancellationToken = default) { - return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken); + return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, null, cancellationToken); + } + + public Task SendMessageWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string methodName, + object?[] args, + Func>? handleExpectedResponse = null, + string? accepts = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, accepts, cancellationToken); } public Task SendStreamMessageWithRetryAsync( @@ -96,7 +108,7 @@ public Task SendStreamMessageWithRetryAsync( Func? handleExpectedResponse = null, CancellationToken cancellationToken = default) { - return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken); + return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), null, cancellationToken); } private static Uri GetUri(string url, IDictionary? query) @@ -164,11 +176,15 @@ private async Task SendAsyncCore( HubMessage? body, Type? typeHint, Func>? handleExpectedResponseAsync = null, + string? accepts = null, CancellationToken cancellationToken = default) { using var httpClient = _httpClientFactory.CreateClient(httpClientName); using var request = BuildRequest(api, httpMethod, body, typeHint); - + if (accepts != null) + { + request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue(accepts)); + } try { using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); diff --git a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs index b019698da..a3e9c9d9f 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager Create(string hubName) where THub var httpClientFactory = _serviceProvider.GetRequiredService(); var serviceEndpoint = _serviceProvider.GetRequiredService().Endpoints.First().Key; var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder()); - return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ApplicationName, restClient); + var protocolResolver = _serviceProvider.GetRequiredService(); + return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver); } default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType)); } diff --git a/src/Microsoft.Azure.SignalR.Management/Microsoft.Azure.SignalR.Management.csproj b/src/Microsoft.Azure.SignalR.Management/Microsoft.Azure.SignalR.Management.csproj index 2ff8be7a9..e966ad949 100644 --- a/src/Microsoft.Azure.SignalR.Management/Microsoft.Azure.SignalR.Management.csproj +++ b/src/Microsoft.Azure.SignalR.Management/Microsoft.Azure.SignalR.Management.csproj @@ -44,6 +44,7 @@ + diff --git a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs index f424fd1be..bdcec3341 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections"); } + public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId) + { + return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke"); + } + private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary queries = null) { var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}"; diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs index 68f24204f..85bdd72eb 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs @@ -14,6 +14,11 @@ using Azure; using Microsoft.AspNetCore.SignalR; +#if NET7_0_OR_GREATER +using System.Buffers; +using System.IO; +using Microsoft.AspNetCore.SignalR.Protocol; +#endif using Microsoft.Extensions.Primitives; using static Microsoft.Azure.SignalR.Constants; @@ -31,13 +36,15 @@ internal class RestHubLifetimeManager : HubLifetimeManager, IService private readonly RestApiProvider _restApiProvider; private readonly string _hubName; private readonly string _appName; + private readonly IHubProtocolResolver _protocolResolver; - public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient) + public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver) { _restApiProvider = new RestApiProvider(endpoint); _appName = appName; _hubName = hubName; _restClient = restClient; + _protocolResolver = protocolResolver; } public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) @@ -353,6 +360,144 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken); } +#if NET7_0_OR_GREATER + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(methodName)) + { + throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName)); + } + if (string.IsNullOrEmpty(connectionId)) + { + throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId)); + } + if (!_protocolResolver.AllProtocols.All(IsInvocationSupported)) + { + throw new NotSupportedException("Non supported protocol for client invocation."); + } + + var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId); + string? errorContent = null; + var isSuccess = false; + CompletionMessage? responseMessage = null; + + await _restClient.SendMessageWithRetryAsync( + api, + HttpMethod.Post, + methodName, + args, + async response => + { + isSuccess = response.IsSuccessStatusCode; + + if (isSuccess) + { + // 1. Get protocol from header (e.g. "json" or "messagepack") + if (!response.Headers.TryGetValues(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, out var protocolHeaderValues)) + { + throw new HubException("Response is missing protocol header."); + } + var protocolName = protocolHeaderValues.FirstOrDefault(); + if (string.IsNullOrEmpty(protocolName)) + { + throw new HubException("Response protocol header is empty."); + } + + // 2. Pick the hub protocol that matches X-Protocol + var protocol = _protocolResolver.AllProtocols + .FirstOrDefault(p => string.Equals(p.Name, protocolName, StringComparison.OrdinalIgnoreCase)); + + if (protocol == null) + { + if (string.Equals(protocolName, "messagepack", StringComparison.OrdinalIgnoreCase) && + _protocolResolver.AllProtocols.Count == 1 && + _protocolResolver.AllProtocols[0] is JsonObjectSerializerHubProtocol jsonObjectSerializerHubProtocol) + { + // The hub protocol is the default one. Service will convert it to MessagePack and keep backward compatibility as users may depend on this feature for MessagePack client support. + protocol = new MessagePackHubProtocol(); + } + else + { + throw new NotSupportedException($"The protocol '{protocolName}' is not supported."); + } + } + + // 3. Read raw completion payload from response body + + byte[] buffer; + await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken) + .ConfigureAwait(false); + using var ms = new MemoryStream(); + await stream.CopyToAsync(ms, cancellationToken).ConfigureAwait(false); + + if (ms.Length == 0) + { + throw new HubException("Response payload is empty."); + } + + buffer = ms.ToArray(); + + // 4. Use SimpleInvocationBinder with typeof(T) + var binder = new SimpleInvocationBinder(typeof(T)); + + // 5. Parse the payload bytes into CompletionMessage + var sequence = new ReadOnlySequence(buffer); + var local = sequence; + if (!protocol.TryParseMessage(ref local, binder, out var hubMessage)) + { + throw new HubException("Failed to parse invocation response."); + } + + responseMessage = (CompletionMessage)hubMessage!; + } + else + { + errorContent = await response.Content.ReadAsStringAsync(cancellationToken); + } + + return isSuccess || response.StatusCode == HttpStatusCode.BadRequest; + }, + "application/octet-stream", + cancellationToken); + + if (!isSuccess) + { + throw new HubException(errorContent ?? "Unknown error in response"); + } + if (responseMessage == null) + { + throw new HubException("Response message is null."); + } + if (responseMessage.Error != null) + { + throw new HubException(responseMessage.Error); + } + + return (T)responseMessage!.Result!; + } + + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + // This method won't get trigger because in transient we will wait for the returned completion message. + // this is to honor the interface + throw new NotImplementedException(); + } + + private static bool IsInvocationSupported(IHubProtocol protocol) + { + // Use protocol.Name to check for supported protocols + switch (protocol.Name) + { + case "json": + case "messagepack": + return true; + default: + return false; + } + } + +#endif + private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) => response.IsSuccessStatusCode || (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase)); diff --git a/src/Microsoft.Azure.SignalR.Management/Serialization/JsonObjectSerializerHubProtocol.cs b/src/Microsoft.Azure.SignalR.Management/Serialization/JsonObjectSerializerHubProtocol.cs index c35ad5196..6baeb40c8 100644 --- a/src/Microsoft.Azure.SignalR.Management/Serialization/JsonObjectSerializerHubProtocol.cs +++ b/src/Microsoft.Azure.SignalR.Management/Serialization/JsonObjectSerializerHubProtocol.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. #nullable enable @@ -11,8 +11,9 @@ using System.Text.Json; #if NET5_0_OR_GREATER +using System.IO; using System.Text.Json.Serialization; - +using System.Diagnostics.CodeAnalysis; #endif using Azure.Core.Serialization; @@ -25,14 +26,13 @@ namespace Microsoft.Azure.SignalR.Management /// /// Implements the SignalR Hub Protocol using . /// Modified from https://github.com/dotnet/aspnetcore/blob/d9660d157627af710b71c636fa8cb139616cadba/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs - /// /// /// Changes compared to original version: /// - /// Change to unsupported as we don't need it. Related codes removed. - /// Use instead of in the serialization. + /// Change to a seperate version for Net7.0. /// /// + /// internal sealed class JsonObjectSerializerHubProtocol : IHubProtocol { private const string ResultPropertyName = "result"; @@ -86,13 +86,28 @@ public bool IsVersionSupported(int version) return version == Version; } +#if NET7_0_OR_GREATER + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, [NotNullWhen(true)] out HubMessage? message) + { + + if (!TextMessageParser.TryParseMessage(ref input, out var payload)) + { + message = null!; + return false; + } + + message = ParseMessage(payload, binder); + + return message != null; + } +#else public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) { //We don't need reading message with this protocol. throw new NotSupportedException(); } +#endif - /// public void WriteMessage(HubMessage message, IBufferWriter output) { WriteMessageCore(message, output); @@ -105,6 +120,141 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) return HubProtocolExtensions.GetMessageBytes(this, message); } +#if NET7_0_OR_GREATER + private HubMessage ParseMessage(ReadOnlySequence input, IInvocationBinder binder) + { + try + { + using var doc = JsonDocument.Parse(input); + var root = doc.RootElement; + if (root.ValueKind != JsonValueKind.Object) + { + throw new InvalidDataException("Expected JSON object for hub message."); + } + + int? type = null; + string? invocationId = null; + string? error = null; + var hasResult = false; + object? result = null; + JsonElement? resultElement = null; + + // type + if (root.TryGetProperty(TypePropertyName, out var typeProp)) + { + if (typeProp.ValueKind != JsonValueKind.Number || !typeProp.TryGetInt32(out var messageType)) + { + throw new InvalidDataException($"Expected '{TypePropertyName}' to be of type {JsonTokenType.Number}."); + } + + type = messageType; + } + + // invocationId + if (root.TryGetProperty(InvocationIdPropertyName, out var invocationIdProp) && + invocationIdProp.ValueKind == JsonValueKind.String) + { + invocationId = invocationIdProp.GetString(); + } + + // error + if (root.TryGetProperty(ErrorPropertyName, out var errorProp) && + errorProp.ValueKind == JsonValueKind.String) + { + error = errorProp.GetString(); + } + + // result + if (root.TryGetProperty(ResultPropertyName, out var resultProp)) + { + hasResult = true; + resultElement = resultProp; + } + + HubMessage message; + + switch (type) + { + case HubProtocolConstants.CompletionMessageType: + if (invocationId is null) + { + throw new InvalidDataException($"Missing required property '{InvocationIdPropertyName}'."); + } + + if (hasResult && resultElement.HasValue) + { + var returnType = binder.GetReturnType(invocationId); + if (returnType is null) + { + result = null; + } + else + { + try + { + result = BindTypeFromElement(resultElement.Value, returnType); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {returnType.Name}. {ex.Message}"; + hasResult = false; + } + } + } + + message = BindCompletionMessage(invocationId, error, result, hasResult); + break; + case HubProtocolConstants.InvocationMessageType: + case HubProtocolConstants.StreamInvocationMessageType: + case HubProtocolConstants.StreamItemMessageType: + case HubProtocolConstants.CancelInvocationMessageType: + case HubProtocolConstants.PingMessageType: + case HubProtocolConstants.CloseMessageType: + throw new NotSupportedException($"Not supported message type: {type}."); + case null: + throw new InvalidDataException($"Missing required property '{TypePropertyName}'."); + default: + return null!; + } + return message; + } + catch (JsonException jrex) + { + throw new InvalidDataException("Error reading JSON.", jrex); + } + } + + private object? BindTypeFromElement(JsonElement element, Type type) + { + // For normal types, deserialize using ObjectSerializer from the element's raw JSON + var raw = element.GetRawText(); + var stream = new MemoryStream(Encoding.UTF8.GetBytes(raw)); + return BindType(ref stream, type); + } + + private object? BindType(ref MemoryStream reader, Type type) => ObjectSerializer.Deserialize(reader, type, default); + + private static HubMessage BindCompletionMessage(string invocationId, string? error, object? result, bool hasResult) + { + if (string.IsNullOrEmpty(invocationId)) + { + throw new InvalidDataException($"Missing required property '{InvocationIdPropertyName}'."); + } + + if (error != null && hasResult) + { + throw new InvalidDataException("The 'error' and 'result' properties are mutually exclusive."); + } + + if (hasResult) + { + return new CompletionMessage(invocationId, error, result, hasResult: true); + } + + return new CompletionMessage(invocationId, error, result: null, hasResult: false); + } +#endif + private void WriteMessageCore(HubMessage message, IBufferWriter stream) { var reusableWriter = ReusableUtf8JsonWriter.Get(stream); @@ -319,6 +469,6 @@ internal static JsonSerializerOptions CreateDefaultSerializerSettings() DefaultBufferSize = 16 * 1024, Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, }; - } + } } } diff --git a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs index b0241653f..856bc31f9 100644 --- a/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/WebsocketsHubLifetimeManager.cs @@ -2,6 +2,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. #nullable enable using System; +#if NET7_0_OR_GREATER +using System.Linq; +#endif using System.Threading; using System.Threading.Tasks; @@ -269,6 +272,11 @@ public override async Task InvokeConnectionAsync(string connectionId, stri throw new ArgumentNullException(nameof(methodName)); } + if (!ProtocolResolver.AllProtocols.All(IsInvocationSupported)) + { + throw new NotSupportedException("Non supported protocol for client invocation."); + } + // cancellationToken is required to be cancellable. using var cts = new CancellationTokenSource(DefaultInvocationTimeoutTimespan); @@ -297,6 +305,19 @@ public override Task SetConnectionResultAsync(string connectionId, CompletionMes // this is to honor the interface throw new NotImplementedException(); } + + private static bool IsInvocationSupported(IHubProtocol protocol) + { + // Use protocol.Name to check for supported protocols + switch (protocol.Name) + { + case "json": + case "messagepack": + return true; + default: + return false; + } + } #endif protected override T AppendMessageTracingId(T message) diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs index 0255c9bd2..d1b2a8654 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs @@ -22,12 +22,14 @@ internal abstract class ServiceLifetimeManagerBase : HubLifetimeManager serviceConnectionManager, IHubProtocolResolver protocolResolver, IOptions globalHubOptions, IOptions> hubOptions, ILogger logger) { Logger = logger ?? throw new ArgumentNullException(nameof(logger)); ServiceConnectionContainer = serviceConnectionManager; _messageSerializer = new DefaultHubMessageSerializer(protocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols); + _protocolResolver = protocolResolver; } public override Task OnConnectedAsync(HubConnectionContext connection) @@ -326,6 +328,8 @@ protected virtual T AppendMessageTracingId(T message) where T : ServiceMessag return message.WithTracingId(); } + protected IHubProtocolResolver ProtocolResolver => _protocolResolver; + private async Task WriteCoreAsync(T message, Func task) where T : ServiceMessage, IMessageWithTracingId { try diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs index de1ef2974..2cbd6e5d5 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs +++ b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs @@ -4,18 +4,24 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; - +using Azure.Core.Serialization; +using MessagePack; +using MessagePack.Formatters; +using MessagePack.Resolvers; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Azure.SignalR.Tests; using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; - using Xunit; using Xunit.Abstractions; @@ -971,6 +977,219 @@ public async Task ListConnectionsInGroupTest(ServiceTransportType serviceTranspo } } + public static IEnumerable ClientInvocationTestData => from serviceTransportType in ServiceTransportType + from intention in (ClientInvocationTestIntentions[])Enum.GetValues(typeof(ClientInvocationTestIntentions)) + select new object[] { serviceTransportType, intention }; + + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [MemberData(nameof(ClientInvocationTestData))] + public async Task ClientInvocationTest(ServiceTransportType serviceTransportType, ClientInvocationTestIntentions intention) + { + bool testJson = true; + bool testMessagePack = true; + bool hasCustomisedSerializer = false; + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocationTest)); + var serviceManagerBuilder = new ServiceManagerBuilder().WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }); + + switch (intention) + { + case ClientInvocationTestIntentions.Default: + { + if (serviceTransportType == Management.ServiceTransportType.Persistent) + { + // In persistent mode, service will not convert the json message to messagepack automatically + testMessagePack = false; + } + break; + } + case ClientInvocationTestIntentions.Json: + { + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(new JsonHubProtocol()); + testMessagePack = false; + break; + } + case ClientInvocationTestIntentions.MessagePack: + { + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(new MessagePackHubProtocol()); + testJson = false; + break; + } + case ClientInvocationTestIntentions.MultipleProtocols: + { + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(new JsonHubProtocol(), new MessagePackHubProtocol()); + break; + } + case ClientInvocationTestIntentions.MessagePackWithNewtonSoft: + { + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(new MessagePackHubProtocol()) + .WithNewtonsoftJson(); + break; + } + case ClientInvocationTestIntentions.JsonWithCustomisedSerializer: + { + var options = JsonObjectSerializerHubProtocol.CreateDefaultSerializerSettings(); + options.Converters.Add(new TestEnumJsonConverter()); + var objectSerializer = new JsonObjectSerializer(options); + var protocol = new JsonObjectSerializerHubProtocol(objectSerializer); + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(protocol); + testMessagePack = false; + hasCustomisedSerializer = true; + break; + } + case ClientInvocationTestIntentions.MessagePackWithCustomisedSerializer: + { + var protocol = GetMessagePackHubProtocolWithCustomisedSerializer(); + serviceManagerBuilder = serviceManagerBuilder.WithHubProtocols(protocol); + testJson = false; + hasCustomisedSerializer = true; + break; + } + } + + var serviceManager = serviceManagerBuilder.WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + if (testJson) + { + await TestClientInvocationAsync(hubContext, "json", hasCustomisedSerializer); + } + if (testMessagePack) + { + await TestClientInvocationAsync(hubContext, "messagepack", hasCustomisedSerializer); + } + } + + private static async Task CreateAndStartClientConnectionWithProtocolAsync(string endpoint, string accessToken, string protocol = "json", bool hasCustomisedSerializer = false) + { + var connectionBuilder = new HubConnectionBuilder() + .WithUrl(endpoint, option => + { + option.AccessTokenProvider = () => + { + return Task.FromResult(accessToken); + }; + }) + .WithAutomaticReconnect(); + + var messagePackOptions = MessagePackSerializerOptions.Standard + .WithResolver( + CompositeResolver.Create( + // formatters (our enum formatter first) + new IMessagePackFormatter[] + { + new TestEnumFormatter(), + new TestObjectFormatter(), + }, + // resolvers + new IFormatterResolver[] + { + StandardResolver.Instance, + })); + + switch (protocol.ToLower()) + { + case "json": + if (hasCustomisedSerializer) + { + connectionBuilder.AddJsonProtocol(options => + { + options.PayloadSerializerOptions.Converters.Add(new TestEnumJsonConverter()); + }); + } + else + { + connectionBuilder.AddJsonProtocol(); + } + break; + case "messagepack": + if (hasCustomisedSerializer) + { + connectionBuilder.AddMessagePackProtocol( + options => + { + options.SerializerOptions = messagePackOptions; + }); + } + else + { + connectionBuilder.AddMessagePackProtocol(); + } + break; + default: + throw new ArgumentException($"The protocol '{protocol}' is not supported."); + } + var connection = connectionBuilder.Build(); + + await connection.StartAsync(); + + connection.On("InvokeString", (Func>)(args => + { + return Task.FromResult("Method Invoked"); + })); + connection.On("InvokeObject", (Func>)(args => + { + return Task.FromResult(new testObject { Name = "Method Invoked", EnumValue = TestEnum.MethodInvoked }); + })); + connection.On("InvokeEnum", (Func>)(args => + { + return Task.FromResult(TestEnum.MethodInvoked); + })); + connection.On("InvokeNull", (Func>)(args => + { + return Task.FromResult(null); + })); + connection.On("InvokeException", (Func>)(args => + { + throw new InvalidOperationException("Test exception"); + })); + + return connection; + } + + private static async Task TestClientInvocationAsync(ServiceHubContext context, string protocol, bool hasCustomisedSerializer = false) + { + var negotationResponse = await context.NegotiateAsync(); + + var expectedStringMessage = "Method Invoked"; + var expectedEnumString = "MethodInvoked"; + var expectedCustomisedEnumString = "aaamytest"; + + var clientConnection = await CreateAndStartClientConnectionWithProtocolAsync(negotationResponse.Url, negotationResponse.AccessToken, protocol, hasCustomisedSerializer); + + var response_string = await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeString", "", default).OrTimeout(); + Assert.Equal(expectedStringMessage, response_string.ToString()); + + var response_object = await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeObject", "", default).OrTimeout(); + Assert.Equal(expectedStringMessage, response_object.Name); + Assert.Equal(hasCustomisedSerializer ? expectedCustomisedEnumString : expectedEnumString, response_object.EnumValue.ToString()); + + var response_null = await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeNull", "", default).OrTimeout(); + Assert.Null(response_null); + + var ex = await Assert.ThrowsAsync(async () => + await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeException", "", default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + + if (hasCustomisedSerializer) + { + var response_enum = await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeEnum", "", default); + Assert.Equal(expectedCustomisedEnumString, response_enum.ToString()); + } + else + { + var response_enum = await context.Clients.Client(clientConnection.ConnectionId).InvokeAsync("InvokeEnum", "", default); + Assert.Equal(expectedEnumString, response_enum.ToString()); + } + + await clientConnection.StopAsync(); + } + private static IDictionary> GenerateUserGroupDict(IList userNames, IList groupNames) { return (from i in Enumerable.Range(0, userNames.Count) @@ -1054,12 +1273,6 @@ await Task.WhenAll(from connection in connections } } - private static string[] GetTestStringList(string prefix, int count) - { - return (from i in Enumerable.Range(0, count) - select $"{prefix}{i}").ToArray(); - } - private async Task<(string ClientEndpoint, IEnumerable ClientAccessTokens, IServiceHubContext ServiceHubContext)> InitAsync(ServiceTransportType serviceTransportType, string appName, IEnumerable userNames) { var serviceManager = GenerateServiceManager(TestConfiguration.Instance.ConnectionString, serviceTransportType, appName); @@ -1179,4 +1392,200 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except } } } + + private sealed class TestEnumFormatter : IMessagePackFormatter + { + public void Serialize(ref MessagePackWriter writer, TestEnum value, MessagePackSerializerOptions options) + { + if (value == TestEnum.MethodInvoked) + { + writer.Write("aaamytest"); + } + } + + public TestEnum Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + var name = reader.ReadString(); + return name == "aaamytest" + ? TestEnum.aaamytest + : TestEnum.None; + } + } + + private sealed class TestObjectFormatter : IMessagePackFormatter + { + public void Serialize(ref MessagePackWriter writer, testObject value, MessagePackSerializerOptions options) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + writer.WriteMapHeader(2); + + writer.Write("Name"); + writer.Write(value.Name); + + writer.Write("EnumValue"); + var resolver = options.Resolver; + var enumFormatter = resolver.GetFormatterWithVerify(); + enumFormatter.Serialize(ref writer, value.EnumValue, options); + } + + public testObject Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + if (reader.TryReadNil()) + { + return null; + } + + var count = reader.ReadMapHeader(); + + var result = new testObject(); + var resolver = options.Resolver; + var enumFormatter = resolver.GetFormatterWithVerify(); + + for (var i = 0; i < count; i++) + { + var propertyName = reader.ReadString(); + + switch (propertyName) + { + case "Name": + result.Name = reader.ReadString(); + break; + case "EnumValue": + result.EnumValue = enumFormatter.Deserialize(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + return result; + } + } + + public sealed class testObject + { + public string Name { get; set; } + public TestEnum EnumValue { get; set; } + } + + public enum TestEnum + { + None, + MethodInvoked, + aaamytest + } + + public enum ClientInvocationTestIntentions + { + Default, + Json, + MessagePack, + MultipleProtocols, + MessagePackWithNewtonSoft, + JsonWithCustomisedSerializer, + MessagePackWithCustomisedSerializer + } + + private static MessagePackHubProtocol GetMessagePackHubProtocolWithCustomisedSerializer() + { + var messagePackOptions = MessagePackSerializerOptions.Standard + .WithResolver( + CompositeResolver.Create( + new IMessagePackFormatter[] + { + new TestEnumFormatter(), + new TestObjectFormatter(), + }, + new IFormatterResolver[] + { + StandardResolver.Instance, + })); + + var protocolOptions = new MessagePackHubProtocolOptions + { + SerializerOptions = messagePackOptions, + }; + + return new MessagePackHubProtocol( + Extensions.Options.Options.Create(protocolOptions) + ); + } + + private sealed class TestEnumJsonConverter : JsonConverter + { + public override TestEnum Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return TestEnum.None; + } + + var name = reader.GetString(); + return name == "aaamytest" + ? TestEnum.aaamytest + : TestEnum.None; + } + + public override void Write(Utf8JsonWriter writer, TestEnum value, JsonSerializerOptions options) + { + if (value == TestEnum.MethodInvoked) + { + writer.WriteStringValue("aaamytest"); + } + else + { + writer.WriteNullValue(); + } + } + } + + private sealed class MessagePackObjectSerializer : ObjectSerializer + { + private readonly MessagePackSerializerOptions _options; + + public MessagePackObjectSerializer(MessagePackSerializerOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public override void Serialize(Stream stream, object value, Type type, CancellationToken cancellationToken) + { + // MessagePack is sync; we honor the token only for consistency. + MessagePackSerializer.Serialize(type, stream, value, _options, cancellationToken: cancellationToken); + stream.Flush(); + } + + public override async ValueTask SerializeAsync(Stream stream, object value, Type type, CancellationToken cancellationToken) + { +#if NETSTANDARD2_0 + // Async overloads may not be available; fall back to sync and wrap in Task. + Serialize(stream, value, type, cancellationToken); + await Task.CompletedTask; +#else + await MessagePackSerializer.SerializeAsync(type, stream, value, _options, cancellationToken); + await stream.FlushAsync(cancellationToken); +#endif + } + + public override object Deserialize(Stream stream, Type returnType, CancellationToken cancellationToken) + { + return MessagePackSerializer.Deserialize(returnType, stream, _options, cancellationToken: cancellationToken); + } + + public override async ValueTask DeserializeAsync(Stream stream, Type returnType, CancellationToken cancellationToken) + { +#if NETSTANDARD2_0 + // Async overloads may not be available; fall back to sync. + return Deserialize(stream, returnType, cancellationToken); +#else + return await MessagePackSerializer.DeserializeAsync(returnType, stream, _options, cancellationToken); +#endif + } + } } diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj b/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj index 0930c7007..ae8ed178f 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj +++ b/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj @@ -7,6 +7,7 @@ + diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/Microsoft.Azure.SignalR.Management.Tests.csproj b/test/Microsoft.Azure.SignalR.Management.Tests/Microsoft.Azure.SignalR.Management.Tests.csproj index 8fc914aa1..79c0ce4f4 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/Microsoft.Azure.SignalR.Management.Tests.csproj +++ b/test/Microsoft.Azure.SignalR.Management.Tests/Microsoft.Azure.SignalR.Management.Tests.csproj @@ -8,7 +8,7 @@ - + diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs new file mode 100644 index 000000000..2a93983ae --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs @@ -0,0 +1,337 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Common; +using Microsoft.Azure.SignalR.Tests.Common; +using Moq; +using Moq.Protected; +using Xunit; + +#nullable enable + +namespace Microsoft.Azure.SignalR.Management.Tests +{ + public class RestHubLifetimeManagerFacts + { +#if NET7_0_OR_GREATER + private readonly Mock _httpClientFactoryMock; + private readonly HttpClient _httpClient; + private readonly string _hubName = "TestHub"; + private readonly string _appName = "TestApp"; + private readonly RestHubLifetimeManager _manager; + + private readonly Mock _httpMessageHandlerMock; + + public RestHubLifetimeManagerFacts() + { + _httpMessageHandlerMock = new Mock(); + + _httpClient = new HttpClient(_httpMessageHandlerMock.Object); + + _httpClientFactoryMock = new Mock(); + _httpClientFactoryMock + .Setup(f => f.CreateClient(It.IsAny())) + .Returns(_httpClient); + + var restClient = new RestClient(_httpClientFactoryMock.Object); + + _manager = new RestHubLifetimeManager( + _hubName, + new(FakeEndpointUtils.GetFakeConnectionString(1).First()), + _appName, + restClient, + new DefaultHubProtocolResolver() + ); + } + + [Fact] + public async Task InvokeConnectionAsync_NullMethodName_ThrowsArgumentException() + { + string? methodName = null; + var connectionId = "connection1"; + var args = Array.Empty(); + + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName!, args)); + + Assert.Equal("methodName", exception.ParamName); + + methodName = ""; + exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal("methodName", exception.ParamName); + } + + [Fact] + public async Task InvokeConnectionAsync_NullConnectionId_ThrowsArgumentException() + { + var methodName = "testMethod"; + string? connectionId = null; + var args = Array.Empty(); + + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId!, methodName, args)); + + Assert.Equal("connectionId", exception.ParamName); + + connectionId = ""; + exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal("connectionId", exception.ParamName); + } + + [Fact] + public async Task InvokeConnectionAsync_WithNotFoundResponse_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getError"; + var args = Array.Empty(); + var errorMessage = "Connection does not exist."; + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny() + ) + .ReturnsAsync(() => new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent(errorMessage) + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + } + + [Fact] + public async Task InvokeConnectionAsync_WithBadRequestResponse_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getError"; + var args = Array.Empty(); + var errorMessage = "This is a Bad Request."; + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny() + ) + .ReturnsAsync(() => new HttpResponseMessage(HttpStatusCode.BadRequest) + { + Content = new StringContent(errorMessage) + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal(errorMessage, exception.Message); + } + + [Fact] + public async Task InvokeConnectionAsync_WithStringResult_ReturnsDeserializedValue() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getUsername"; + var args = new object?[] { 42, "test-param", true }; + var expectedResult = "John Doe"; + + // Build a CompletionMessage carrying the string result + var completion = new CompletionMessage( + invocationId: "1234", + error: null, + result: expectedResult, + hasResult: true); + + // Serialize to SignalR JSON frame (with record separator) + var protocol = new JsonHubProtocol(); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }; + + // Protocol header expected by InvokeConnectionAsync + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, protocol.Name); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + + return response; + }); + + // Act + var result = await _manager.InvokeConnectionAsync(connectionId, methodName, args); + + // Assert + Assert.Equal(expectedResult, result); + } + + [Fact] + public async Task InvokeConnectionAsync_WithComplexObjectResult_ReturnsDeserializedObject() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getUserProfile"; + var args = new object?[] { "userId123", new { filter = "personal" } }; + + var expectedProfile = new UserProfile + { + id = 123, + name = "Jane Doe", + active = true, + roles = new[] { "user", "admin" }, + }; + + var completion = new CompletionMessage( + invocationId: "1234", + error: null, + result: expectedProfile, + hasResult: true); + + var protocol = new JsonHubProtocol(); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }; + + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, protocol.Name); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + + return response; + }); + + // Act + var result = await _manager.InvokeConnectionAsync(connectionId, methodName, args); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedProfile.id, result.id); + Assert.Equal(expectedProfile.name, result.name); + Assert.Equal(expectedProfile.active, result.active); + Assert.Equal(expectedProfile.roles.Length, result.roles.Length); + Assert.Contains("admin", result.roles); + } + + [Fact] + public async Task InvokeConnectionAsync_WithMissingProtocolHeader_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getData"; + var args = Array.Empty(); + + var protocol = new JsonHubProtocol(); + var completion = new CompletionMessage("1234", null, "value", hasResult: true); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }); + + // Act & Assert + var ex = await Assert.ThrowsAsync( + () => _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + Assert.Equal("Response is missing protocol header.", ex.Message); + } + + [Fact] + public async Task InvokeConnectionAsync_WithEmptyPayload_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getData"; + var args = Array.Empty(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(Array.Empty()), + }; + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, "json"); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + return response; + }); + + // Act & Assert + var ex = await Assert.ThrowsAsync( + () => _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + Assert.Equal("Response payload is empty.", ex.Message); + } +#endif + + public class TestHub : Hub { } + + public class UserProfile + { + public int id { get; set; } + public string name { get; set; } = string.Empty; + public bool active { get; set; } + public string[] roles { get; set; } = Array.Empty(); + } + + private sealed class DefaultHubProtocolResolver : IHubProtocolResolver + { + public IReadOnlyList AllProtocols => new List + { + new JsonHubProtocol(), + new MessagePackHubProtocol() + }; + + public IHubProtocol? GetProtocol(string protocolName, IReadOnlyList? supportedProtocols) + { + throw new NotImplementedException(); + } + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/WebsocketsHubLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/WebsocketsHubLifetimeManagerFacts.cs index f3d2dfebb..7fb907e23 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/WebsocketsHubLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/WebsocketsHubLifetimeManagerFacts.cs @@ -1,22 +1,24 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; -using System.Threading.Tasks; +using System.Linq; using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Moq; using Xunit; -using System; namespace Microsoft.Azure.SignalR.Management.Tests { public class WebsocketsHubLifetimeManagerFacts { private readonly Mock> _serviceConnectionManagerMock; - private readonly Mock _protocolResolverMock; + private readonly DefaultHubProtocolResolver _protocolResolver; private readonly Mock> _globalHubOptionsMock; private readonly Mock>> _hubOptionsMock; private readonly Mock _loggerFactoryMock; @@ -28,7 +30,7 @@ public class WebsocketsHubLifetimeManagerFacts public WebsocketsHubLifetimeManagerFacts() { _serviceConnectionManagerMock = new Mock>(); - _protocolResolverMock = new Mock(); + _protocolResolver = new DefaultHubProtocolResolver([new JsonHubProtocol(), new MessagePackHubProtocol()]); _globalHubOptionsMock = new Mock>(); _hubOptionsMock = new Mock>>(); _loggerFactoryMock = new Mock(); @@ -49,7 +51,7 @@ public void Constructor_ShouldInitialize_WhenDependenciesAreValid() // Act var manager = new WebSocketsHubLifetimeManager( _serviceConnectionManagerMock.Object, - _protocolResolverMock.Object, + _protocolResolver, _globalHubOptionsMock.Object, _hubOptionsMock.Object, _loggerFactoryMock.Object, @@ -71,7 +73,7 @@ public async Task InvokeConnectionAsync_ShouldInvokeMethod_WhenArgumentsAreValid // Arrange var manager = new WebSocketsHubLifetimeManager( _serviceConnectionManagerMock.Object, - _protocolResolverMock.Object, + _protocolResolver, _globalHubOptionsMock.Object, _hubOptionsMock.Object, _loggerFactoryMock.Object, @@ -108,7 +110,7 @@ public async Task InvokeConnectionAsync_ShouldThrowArgumentNullException_WhenCon // Arrange var manager = new WebSocketsHubLifetimeManager( _serviceConnectionManagerMock.Object, - _protocolResolverMock.Object, + _protocolResolver, _globalHubOptionsMock.Object, _hubOptionsMock.Object, _loggerFactoryMock.Object, @@ -133,7 +135,7 @@ public async Task InvokeConnectionAsync_ShouldThrowArgumentNullException_WhenMet // Arrange var manager = new WebSocketsHubLifetimeManager( _serviceConnectionManagerMock.Object, - _protocolResolverMock.Object, + _protocolResolver, _globalHubOptionsMock.Object, _hubOptionsMock.Object, _loggerFactoryMock.Object, @@ -151,6 +153,41 @@ public async Task InvokeConnectionAsync_ShouldThrowArgumentNullException_WhenMet // Act & Assert await Assert.ThrowsAsync(() => manager.InvokeConnectionAsync(connectionId, invalidMethodName, args, cancellationToken)); } + + private sealed class DefaultHubProtocolResolver : IHubProtocolResolver + { + + private readonly List _hubProtocols; + private readonly Dictionary _availableProtocols; + + public IReadOnlyList AllProtocols => _hubProtocols; + + public DefaultHubProtocolResolver(IEnumerable availableProtocols) + { + _availableProtocols = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // We might get duplicates in _hubProtocols, but we're going to check it and overwrite in just a sec. + _hubProtocols = availableProtocols.ToList(); + foreach (var protocol in _hubProtocols) + { + _availableProtocols[protocol.Name] = protocol; + } + } + + public IHubProtocol GetProtocol(string protocolName, IReadOnlyList supportedProtocols) + { + protocolName = protocolName ?? throw new ArgumentNullException(nameof(protocolName)); + + if (_availableProtocols.TryGetValue(protocolName, out var protocol) && (supportedProtocols == null || supportedProtocols.Contains(protocolName, StringComparer.OrdinalIgnoreCase))) + { + return protocol; + } + + // null result indicates protocol is not supported + // result will be validated by the caller + return null; + } + } #endif } diff --git a/test/appsettings.Test.json b/test/appsettings.Test.json index c9408e0b8..93efd92c0 100644 --- a/test/appsettings.Test.json +++ b/test/appsettings.Test.json @@ -1,4 +1,4 @@ -{ +{ "ReadMe": "Recommend external contributors fill in `ConnectionString`, change service mode to `Classic` and then run E2E test before create a pull request.", "Azure": { "SignalR": {