diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index a33754220..936b73640 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -140,6 +140,8 @@ public static class Headers public const string AsrsMessageTracingId = AsrsInternalHeaderPrefix + "Message-Tracing-Id"; public const string MicrosoftErrorCode = "x-ms-error-code"; + + public const string AsrsManagementSDKClientInvocationProtocol = AsrsInternalHeaderPrefix + "Protocol"; } public static class ErrorCodes diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs new file mode 100644 index 000000000..ca4532233 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/Rest/SimpleInvocationBinder.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR; + +internal sealed class SimpleInvocationBinder : IInvocationBinder +{ + private readonly Type _returnType; + + public SimpleInvocationBinder(Type returnType) + { + _returnType = returnType ?? throw new ArgumentNullException(nameof(returnType)); + } + + public Type GetReturnType(string invocationId) + { + return _returnType; + } + + public IReadOnlyList GetParameterTypes(string methodName) + { + throw new NotImplementedException(); + } + + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } +} diff --git a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs index f424fd1be..bdcec3341 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections"); } + public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId) + { + return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke"); + } + private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary queries = null) { var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}"; diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs index 4ca0560a6..5de66f0a8 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Net; using System.Net.Http; + using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -15,6 +16,8 @@ using Microsoft.AspNetCore.SignalR; #if NET7_0_OR_GREATER +using System.Buffers; +using System.Net.Http.Headers; using Microsoft.AspNetCore.SignalR.Protocol; #endif using Microsoft.Extensions.Primitives; @@ -359,10 +362,115 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId } #if NET7_0_OR_GREATER + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(methodName)) + { + throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName)); + } + if (string.IsNullOrEmpty(connectionId)) + { + throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId)); + } + if (!_protocolResolver.AllProtocols.All(IsInvocationSupported)) + { + throw new NotSupportedException("Non supported protocol for client invocation."); + } + + var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId); + string? errorContent = null; + var isSuccess = false; + CompletionMessage? responseMessage = null; + + await _restClient.SendMessageWithRetryAsync( + api, + HttpMethod.Post, + methodName, + args, + async response => + { + isSuccess = response.IsSuccessStatusCode; + + if (isSuccess) + { + // 1. Get protocol from header (e.g. "json" or "messagepack") + if (!response.Headers.TryGetValues(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, out var protocolHeaderValues)) + { + throw new HubException("Response is missing protocol header."); + } + var protocolName = protocolHeaderValues.FirstOrDefault(); + if (string.IsNullOrEmpty(protocolName)) + { + throw new HubException("Response protocol header is empty."); + } + + // 2. Pick the hub protocol that matches X-Protocol + var protocol = _protocolResolver.GetProtocol(protocolName, supportedProtocols: null); + + if (protocol == null) + { + throw new InvalidOperationException( + $"The protocol '{protocolName}' is not configured. " + + $"Add the missing protocol using ServiceManagerBuilder.AddHubProtocol() or ServiceManagerBuilder.WithHubProtocols()."); + } + + // 3. Read raw completion payload from response body + + var buffer = await response.Content.ReadAsByteArrayAsync(cancellationToken) + .ConfigureAwait(false); + + if (buffer.Length == 0) + { + throw new HubException("Response payload is empty."); + } + + // 4. Use SimpleInvocationBinder with typeof(T) + var binder = new SimpleInvocationBinder(typeof(T)); + + // 5. Parse the payload bytes into CompletionMessage + var sequence = new ReadOnlySequence(buffer); + var local = sequence; + if (!protocol.TryParseMessage(ref local, binder, out var hubMessage)) + { + throw new HubException("Failed to parse invocation response."); + } + + responseMessage = (CompletionMessage)hubMessage!; + } + else + { + errorContent = await response.Content.ReadAsStringAsync(cancellationToken); + } + + return isSuccess || response.StatusCode == HttpStatusCode.BadRequest; + }, + new MediaTypeWithQualityHeaderValue("application/octet-stream"), + cancellationToken); + + if (!isSuccess) + { + throw new HubException(errorContent ?? "Unknown error in response"); + } + if (responseMessage == null) + { + throw new HubException("Response message is null."); + } + if (responseMessage.Error != null) + { + throw new HubException(responseMessage.Error); + } + + return (T)responseMessage!.Result!; + } + + public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + // This method won't get trigger because in transient we will wait for the returned completion message. + // this is to honor the interface + throw new NotImplementedException(); + } -#pragma warning disable IDE0051 // Will be used in the future updates private static bool IsInvocationSupported(IHubProtocol protocol) -#pragma warning restore IDE0051 // Will be used in the future updates { // Use protocol.Name to check for supported protocols switch (protocol.Name) diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs index de1ef2974..fd71a9ec9 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs +++ b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs @@ -4,18 +4,24 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; - +using Azure.Core.Serialization; +using MessagePack; +using MessagePack.Formatters; +using MessagePack.Resolvers; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Azure.SignalR.Tests; using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; - using Xunit; using Xunit.Abstractions; @@ -971,6 +977,529 @@ public async Task ListConnectionsInGroupTest(ServiceTransportType serviceTranspo } } + #region ClientInvocation Tests + + /// + /// Tests client invocation with default protocol configuration using JSON client. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithDefaultProtocol_JsonClient(ServiceTransportType serviceTransportType) + { + // Arrange: Create service manager with default protocol (no explicit hub protocol configured) + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithDefaultProtocol_JsonClient)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create JSON client connection + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateJsonClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with standard serialization + Assert.Equal(TestEnum.MethodInvoked, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + /// + /// Tests client invocation with explicit JSON protocol configured on ServiceManager. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithExplicitJsonProtocol(ServiceTransportType serviceTransportType) + { + // Arrange: Create service manager with explicit JsonHubProtocol + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithExplicitJsonProtocol)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(new JsonHubProtocol()) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create JSON client connection + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateJsonClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with standard serialization + Assert.Equal(TestEnum.MethodInvoked, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + /// + /// Tests client invocation with explicit MessagePack protocol configured on ServiceManager. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithExplicitMessagePackProtocol(ServiceTransportType serviceTransportType) + { + // Arrange: Create service manager with explicit MessagePackHubProtocol + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithExplicitMessagePackProtocol)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(new MessagePackHubProtocol()) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create MessagePack client connection + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateMessagePackClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with standard serialization + Assert.Equal(TestEnum.MethodInvoked, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + /// + /// Tests client invocation with both JSON and MessagePack protocols configured. + /// Verifies that clients using either protocol can successfully invoke methods. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithMultipleProtocols(ServiceTransportType serviceTransportType) + { + // Arrange: Create service manager with both JSON and MessagePack protocols + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithMultipleProtocols)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(new JsonHubProtocol(), new MessagePackHubProtocol()) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create both JSON and MessagePack client connections + var negotiationResponse = await hubContext.NegotiateAsync(); + var jsonClient = await CreateJsonClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + var messagePackClient = await CreateMessagePackClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // JSON Client - Act: Invoke method that returns all test values + var jsonResult = await hubContext.Clients.Client(jsonClient.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify all values for JSON client + Assert.Equal("Method Invoked", jsonResult.StringValue); + Assert.Equal(TestEnum.MethodInvoked, jsonResult.EnumValue); + Assert.Null(jsonResult.NullValue); + Assert.Equal(TestDateTime, jsonResult.DateTimeValue); + + // Act & Assert: Invoke method that throws exception (JSON) + var ex_json = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(jsonClient.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex_json.Message); + + // MessagePack Client - Act: Invoke method that returns all test values + var msgPackResult = await hubContext.Clients.Client(messagePackClient.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify all values for MessagePack client + Assert.Equal("Method Invoked", msgPackResult.StringValue); + Assert.Equal(TestEnum.MethodInvoked, msgPackResult.EnumValue); + Assert.Null(msgPackResult.NullValue); + Assert.Equal(TestDateTime, msgPackResult.DateTimeValue); + + // Act & Assert: Invoke method that throws exception (MessagePack) + var ex_messagePack = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(messagePackClient.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex_messagePack.Message); + } + finally + { + await jsonClient.StopAsync(); + await messagePackClient.StopAsync(); + } + } + + /// + /// Tests client invocation with MessagePack protocol and Newtonsoft.Json serializer for REST payloads. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithMessagePackAndNewtonsoftJson(ServiceTransportType serviceTransportType) + { + // Arrange: Create service manager with MessagePack protocol and Newtonsoft.Json for REST + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithMessagePackAndNewtonsoftJson)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(new MessagePackHubProtocol()) + .WithNewtonsoftJson() + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create MessagePack client connection + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateMessagePackClientConnectionAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with standard serialization + Assert.Equal(TestEnum.MethodInvoked, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + /// + /// Tests client invocation with custom JSON serializer. + /// Custom serializer converts TestEnum.MethodInvoked to "aaamytest" string. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithCustomJsonSerializer(ServiceTransportType serviceTransportType) + { + // Arrange: Create custom JSON serializer with TestEnumJsonConverter + var jsonOptions = JsonObjectSerializerHubProtocol.CreateDefaultSerializerSettings(); + jsonOptions.Converters.Add(new TestEnumJsonConverter()); + var customProtocol = new JsonObjectSerializerHubProtocol(new JsonObjectSerializer(jsonOptions)); + + // Arrange: Create service manager with custom JSON protocol + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithCustomJsonSerializer)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(customProtocol) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create JSON client with matching custom serializer + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateJsonClientWithCustomSerializerAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with customised serialization (MethodInvoked -> aaamytest) + Assert.Equal(TestEnum.aaamytest, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + /// + /// Tests client invocation with custom MessagePack serializer. + /// Custom serializer converts TestEnum.MethodInvoked to "aaamytest" string. + /// + [ConditionalTheory] + [SkipIfConnectionStringNotPresent] + [InlineData(Management.ServiceTransportType.Transient)] + [InlineData(Management.ServiceTransportType.Persistent)] + public async Task ClientInvocation_WithCustomMessagePackSerializer(ServiceTransportType serviceTransportType) + { + // Arrange: Create custom MessagePack protocol with TestEnumFormatter and TestInvocationResultFormatter + var customProtocol = CreateMessagePackProtocolWithCustomSerializer(); + + // Arrange: Create service manager with custom MessagePack protocol + using var logger = StartLog(out var loggerFactory, nameof(ClientInvocation_WithCustomMessagePackSerializer)); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = TestConfiguration.Instance.ConnectionString; + o.ServiceTransportType = serviceTransportType; + }) + .WithHubProtocols(customProtocol) + .WithLoggerFactory(loggerFactory) + .BuildServiceManager(); + using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default); + + // Arrange: Create MessagePack client with matching custom serializer + var negotiationResponse = await hubContext.NegotiateAsync(); + var clientConnection = await CreateMessagePackClientWithCustomSerializerAsync(negotiationResponse.Url, negotiationResponse.AccessToken); + + try + { + // Act: Invoke method that returns all test values in a single call + var result = await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeAll", TestInput, default).OrTimeout(); + + // Assert: Verify string value + Assert.Equal("Method Invoked", result.StringValue); + + // Assert: Verify enum value with customised serialization (MethodInvoked -> aaamytest) + Assert.Equal(TestEnum.aaamytest, result.EnumValue); + + // Assert: Verify null value + Assert.Null(result.NullValue); + + // Assert: Verify datetime value + Assert.Equal(TestDateTime, result.DateTimeValue); + + // Act & Assert: Invoke method that throws exception + var ex = await Assert.ThrowsAsync(async () => + await hubContext.Clients.Client(clientConnection.ConnectionId) + .InvokeAsync("InvokeException", TestInput, default).OrTimeout()); + Assert.Contains("Test exception", ex.Message); + } + finally + { + await clientConnection.StopAsync(); + } + } + + #endregion + + #region Client Connection Helpers (setup/teardown plumbing) + + private static async Task CreateJsonClientConnectionAsync(string endpoint, string accessToken) + { + var connection = new HubConnectionBuilder() + .WithUrl(endpoint, option => option.AccessTokenProvider = () => Task.FromResult(accessToken)) + .WithAutomaticReconnect() + .AddJsonProtocol() + .Build(); + + await connection.StartAsync(); + RegisterClientInvocationHandlers(connection); + return connection; + } + + private static async Task CreateMessagePackClientConnectionAsync(string endpoint, string accessToken) + { + var connection = new HubConnectionBuilder() + .WithUrl(endpoint, option => option.AccessTokenProvider = () => Task.FromResult(accessToken)) + .WithAutomaticReconnect() + .AddMessagePackProtocol() + .Build(); + + await connection.StartAsync(); + RegisterClientInvocationHandlers(connection); + return connection; + } + + private static async Task CreateJsonClientWithCustomSerializerAsync(string endpoint, string accessToken) + { + var connection = new HubConnectionBuilder() + .WithUrl(endpoint, option => option.AccessTokenProvider = () => Task.FromResult(accessToken)) + .WithAutomaticReconnect() + .AddJsonProtocol(options => options.PayloadSerializerOptions.Converters.Add(new TestEnumJsonConverter())) + .Build(); + + await connection.StartAsync(); + RegisterClientInvocationHandlers(connection); + return connection; + } + + private static async Task CreateMessagePackClientWithCustomSerializerAsync(string endpoint, string accessToken) + { + var messagePackOptions = MessagePackSerializerOptions.Standard.WithResolver( + CompositeResolver.Create( + new IMessagePackFormatter[] { new TestEnumFormatter(), new TestInvocationResultFormatter(), new TestInvocationInputFormatter() }, + new IFormatterResolver[] { StandardResolver.Instance })); + + var connection = new HubConnectionBuilder() + .WithUrl(endpoint, option => option.AccessTokenProvider = () => Task.FromResult(accessToken)) + .WithAutomaticReconnect() + .AddMessagePackProtocol(options => options.SerializerOptions = messagePackOptions) + .Build(); + + await connection.StartAsync(); + RegisterClientInvocationHandlers(connection); + return connection; + } + + private static readonly DateTime TestDateTime = new DateTime(2024, 6, 15, 10, 30, 0, DateTimeKind.Utc); + + private static readonly TestInvocationInput TestInput = new TestInvocationInput + { + StringValue = "Test Input String", + DateTimeValue = TestDateTime, + IntValue = 42 + }; + + private static void RegisterClientInvocationHandlers(HubConnection connection) + { + connection.On("InvokeAll", (Func>)(input => + { + // Validate input was correctly deserialized + if (input.StringValue != TestInput.StringValue || + input.DateTimeValue != TestInput.DateTimeValue || + input.IntValue != TestInput.IntValue) + { + throw new InvalidOperationException($"Input validation failed. Expected: {TestInput.StringValue}, {TestInput.DateTimeValue}, {TestInput.IntValue}. Actual: {input.StringValue}, {input.DateTimeValue}, {input.IntValue}"); + } + + return Task.FromResult(new TestInvocationResult + { + StringValue = "Method Invoked", + EnumValue = TestEnum.MethodInvoked, + NullValue = null, + DateTimeValue = TestDateTime + }); + })); + connection.On("InvokeException", (Func>)(_ => throw new InvalidOperationException("Test exception"))); + } + + private static MessagePackHubProtocol CreateMessagePackProtocolWithCustomSerializer() + { + var messagePackOptions = MessagePackSerializerOptions.Standard.WithResolver( + CompositeResolver.Create( + new IMessagePackFormatter[] { new TestEnumFormatter(), new TestInvocationResultFormatter(), new TestInvocationInputFormatter() }, + new IFormatterResolver[] { StandardResolver.Instance })); + + return new MessagePackHubProtocol( + Extensions.Options.Options.Create(new MessagePackHubProtocolOptions { SerializerOptions = messagePackOptions })); + } + + #endregion + private static IDictionary> GenerateUserGroupDict(IList userNames, IList groupNames) { return (from i in Enumerable.Range(0, userNames.Count) @@ -1054,12 +1583,6 @@ await Task.WhenAll(from connection in connections } } - private static string[] GetTestStringList(string prefix, int count) - { - return (from i in Enumerable.Range(0, count) - select $"{prefix}{i}").ToArray(); - } - private async Task<(string ClientEndpoint, IEnumerable ClientAccessTokens, IServiceHubContext ServiceHubContext)> InitAsync(ServiceTransportType serviceTransportType, string appName, IEnumerable userNames) { var serviceManager = GenerateServiceManager(TestConfiguration.Instance.ConnectionString, serviceTransportType, appName); @@ -1179,4 +1702,244 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except } } } + + private sealed class TestEnumFormatter : IMessagePackFormatter + { + public void Serialize(ref MessagePackWriter writer, TestEnum value, MessagePackSerializerOptions options) + { + if (value == TestEnum.MethodInvoked) + { + writer.Write("aaamytest"); + } + } + + public TestEnum Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + var name = reader.ReadString(); + return name == "aaamytest" + ? TestEnum.aaamytest + : TestEnum.None; + } + } + + private sealed class TestInvocationResultFormatter : IMessagePackFormatter + { + public void Serialize(ref MessagePackWriter writer, TestInvocationResult value, MessagePackSerializerOptions options) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + writer.WriteMapHeader(4); + + writer.Write("StringValue"); + writer.Write(value.StringValue); + + writer.Write("EnumValue"); + var resolver = options.Resolver; + var enumFormatter = resolver.GetFormatterWithVerify(); + enumFormatter.Serialize(ref writer, value.EnumValue, options); + + writer.Write("NullValue"); + writer.WriteNil(); + + writer.Write("DateTimeValue"); + writer.Write(value.DateTimeValue.Ticks); + } + + public TestInvocationResult Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + if (reader.TryReadNil()) + { + return null; + } + + var count = reader.ReadMapHeader(); + + var result = new TestInvocationResult(); + var resolver = options.Resolver; + var enumFormatter = resolver.GetFormatterWithVerify(); + + for (var i = 0; i < count; i++) + { + var propertyName = reader.ReadString(); + + switch (propertyName) + { + case "StringValue": + result.StringValue = reader.ReadString(); + break; + case "EnumValue": + result.EnumValue = enumFormatter.Deserialize(ref reader, options); + break; + case "NullValue": + reader.TryReadNil(); + result.NullValue = null; + break; + case "DateTimeValue": + result.DateTimeValue = new DateTime(reader.ReadInt64(), DateTimeKind.Utc); + break; + default: + reader.Skip(); + break; + } + } + + return result; + } + } + + private sealed class TestInvocationInputFormatter : IMessagePackFormatter + { + public void Serialize(ref MessagePackWriter writer, TestInvocationInput value, MessagePackSerializerOptions options) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + writer.WriteMapHeader(3); + + writer.Write("StringValue"); + writer.Write(value.StringValue); + + writer.Write("DateTimeValue"); + writer.Write(value.DateTimeValue.Ticks); + + writer.Write("IntValue"); + writer.Write(value.IntValue); + } + + public TestInvocationInput Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + if (reader.TryReadNil()) + { + return null; + } + + var count = reader.ReadMapHeader(); + + var result = new TestInvocationInput(); + + for (var i = 0; i < count; i++) + { + var propertyName = reader.ReadString(); + + switch (propertyName) + { + case "StringValue": + result.StringValue = reader.ReadString(); + break; + case "DateTimeValue": + result.DateTimeValue = new DateTime(reader.ReadInt64(), DateTimeKind.Utc); + break; + case "IntValue": + result.IntValue = reader.ReadInt32(); + break; + default: + reader.Skip(); + break; + } + } + + return result; + } + } + + public sealed class TestInvocationInput + { + public string StringValue { get; set; } + public DateTime DateTimeValue { get; set; } + public int IntValue { get; set; } + } + + public sealed class TestInvocationResult + { + public string StringValue { get; set; } + public TestEnum EnumValue { get; set; } + public object NullValue { get; set; } + public DateTime DateTimeValue { get; set; } + } + + public enum TestEnum + { + None, + MethodInvoked, + aaamytest + } + + private sealed class TestEnumJsonConverter : JsonConverter + { + public override TestEnum Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return TestEnum.None; + } + + var name = reader.GetString(); + return name == "aaamytest" + ? TestEnum.aaamytest + : TestEnum.None; + } + + public override void Write(Utf8JsonWriter writer, TestEnum value, JsonSerializerOptions options) + { + if (value == TestEnum.MethodInvoked) + { + writer.WriteStringValue("aaamytest"); + } + else + { + writer.WriteNullValue(); + } + } + } + + private sealed class MessagePackObjectSerializer : ObjectSerializer + { + private readonly MessagePackSerializerOptions _options; + + public MessagePackObjectSerializer(MessagePackSerializerOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public override void Serialize(Stream stream, object value, Type type, CancellationToken cancellationToken) + { + // MessagePack is sync; we honor the token only for consistency. + MessagePackSerializer.Serialize(type, stream, value, _options, cancellationToken: cancellationToken); + stream.Flush(); + } + + public override async ValueTask SerializeAsync(Stream stream, object value, Type type, CancellationToken cancellationToken) + { +#if NETSTANDARD2_0 + // Async overloads may not be available; fall back to sync and wrap in Task. + Serialize(stream, value, type, cancellationToken); + await Task.CompletedTask; +#else + await MessagePackSerializer.SerializeAsync(type, stream, value, _options, cancellationToken); + await stream.FlushAsync(cancellationToken); +#endif + } + + public override object Deserialize(Stream stream, Type returnType, CancellationToken cancellationToken) + { + return MessagePackSerializer.Deserialize(returnType, stream, _options, cancellationToken: cancellationToken); + } + + public override async ValueTask DeserializeAsync(Stream stream, Type returnType, CancellationToken cancellationToken) + { +#if NETSTANDARD2_0 + // Async overloads may not be available; fall back to sync. + return Deserialize(stream, returnType, cancellationToken); +#else + return await MessagePackSerializer.DeserializeAsync(returnType, stream, _options, cancellationToken); +#endif + } + } } diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj b/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj index 0930c7007..ae8ed178f 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj +++ b/test/Microsoft.Azure.SignalR.E2ETests/Microsoft.Azure.SignalR.E2ETests.csproj @@ -7,6 +7,7 @@ + diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs new file mode 100644 index 000000000..7f8a62f93 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Management.Tests/RestHubLifetimeManagerFacts.cs @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Common; +using Microsoft.Azure.SignalR.Tests.Common; +using Moq; +using Moq.Protected; +using Xunit; + +#nullable enable + +namespace Microsoft.Azure.SignalR.Management.Tests +{ + public class RestHubLifetimeManagerFacts + { +#if NET7_0_OR_GREATER + private readonly Mock _httpClientFactoryMock; + private readonly HttpClient _httpClient; + private readonly string _hubName = "TestHub"; + private readonly string _appName = "TestApp"; + private readonly RestHubLifetimeManager _manager; + + private readonly Mock _httpMessageHandlerMock; + + public RestHubLifetimeManagerFacts() + { + _httpMessageHandlerMock = new Mock(); + + _httpClient = new HttpClient(_httpMessageHandlerMock.Object); + + _httpClientFactoryMock = new Mock(); + _httpClientFactoryMock + .Setup(f => f.CreateClient(It.IsAny())) + .Returns(_httpClient); + + var restClient = new RestClient(_httpClientFactoryMock.Object); + + _manager = new RestHubLifetimeManager( + _hubName, + new(FakeEndpointUtils.GetFakeConnectionString(1).First()), + _appName, + restClient, + new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(), + new MessagePackHubProtocol() + }) + ); + } + + [Fact] + public async Task InvokeConnectionAsync_NullMethodName_ThrowsArgumentException() + { + string? methodName = null; + var connectionId = "connection1"; + var args = Array.Empty(); + + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName!, args)); + + Assert.Equal("methodName", exception.ParamName); + + methodName = ""; + exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal("methodName", exception.ParamName); + } + + [Fact] + public async Task InvokeConnectionAsync_NullConnectionId_ThrowsArgumentException() + { + var methodName = "testMethod"; + string? connectionId = null; + var args = Array.Empty(); + + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId!, methodName, args)); + + Assert.Equal("connectionId", exception.ParamName); + + connectionId = ""; + exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal("connectionId", exception.ParamName); + } + + [Fact] + public async Task InvokeConnectionAsync_WithNotFoundResponse_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getError"; + var args = Array.Empty(); + var errorMessage = "Connection does not exist."; + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny() + ) + .ReturnsAsync(() => new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent(errorMessage) + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + } + + [Fact] + public async Task InvokeConnectionAsync_WithBadRequestResponse_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getError"; + var args = Array.Empty(); + var errorMessage = "This is a Bad Request."; + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny() + ) + .ReturnsAsync(() => new HttpResponseMessage(HttpStatusCode.BadRequest) + { + Content = new StringContent(errorMessage) + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await _manager.InvokeConnectionAsync(connectionId, methodName, args)); + Assert.Equal(errorMessage, exception.Message); + } + + [Fact] + public async Task InvokeConnectionAsync_WithStringResult_ReturnsDeserializedValue() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getUsername"; + var args = new object?[] { 42, "test-param", true }; + var expectedResult = "John Doe"; + + // Build a CompletionMessage carrying the string result + var completion = new CompletionMessage( + invocationId: "1234", + error: null, + result: expectedResult, + hasResult: true); + + // Serialize to SignalR JSON frame (with record separator) + var protocol = new JsonHubProtocol(); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }; + + // Protocol header expected by InvokeConnectionAsync + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, protocol.Name); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + + return response; + }); + + // Act + var result = await _manager.InvokeConnectionAsync(connectionId, methodName, args); + + // Assert + Assert.Equal(expectedResult, result); + } + + [Fact] + public async Task InvokeConnectionAsync_WithComplexObjectResult_ReturnsDeserializedObject() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getUserProfile"; + var args = new object?[] { "userId123", new { filter = "personal" } }; + + var expectedProfile = new UserProfile + { + id = 123, + name = "Jane Doe", + active = true, + roles = new[] { "user", "admin" }, + }; + + var completion = new CompletionMessage( + invocationId: "1234", + error: null, + result: expectedProfile, + hasResult: true); + + var protocol = new JsonHubProtocol(); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }; + + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, protocol.Name); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + + return response; + }); + + // Act + var result = await _manager.InvokeConnectionAsync(connectionId, methodName, args); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedProfile.id, result.id); + Assert.Equal(expectedProfile.name, result.name); + Assert.Equal(expectedProfile.active, result.active); + Assert.Equal(expectedProfile.roles.Length, result.roles.Length); + Assert.Contains("admin", result.roles); + } + + [Fact] + public async Task InvokeConnectionAsync_WithMissingProtocolHeader_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getData"; + var args = Array.Empty(); + + var protocol = new JsonHubProtocol(); + var completion = new CompletionMessage("1234", null, "value", hasResult: true); + var payloadBytes = protocol.GetMessageBytes(completion).ToArray(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(payloadBytes), + }); + + // Act & Assert + var ex = await Assert.ThrowsAsync( + () => _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + Assert.Equal("Response is missing protocol header.", ex.Message); + } + + [Fact] + public async Task InvokeConnectionAsync_WithEmptyPayload_ThrowsHubException() + { + // Arrange + var connectionId = "connection1"; + var methodName = "getData"; + var args = Array.Empty(); + + _httpMessageHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(Array.Empty()), + }; + response.Headers.Add(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, "json"); + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/octet-stream"); + return response; + }); + + // Act & Assert + var ex = await Assert.ThrowsAsync( + () => _manager.InvokeConnectionAsync(connectionId, methodName, args)); + + Assert.Equal("Response payload is empty.", ex.Message); + } +#endif + + public class TestHub : Hub { } + + public class UserProfile + { + public int id { get; set; } + public string name { get; set; } = string.Empty; + public bool active { get; set; } + public string[] roles { get; set; } = Array.Empty(); + } + + private sealed class DefaultHubProtocolResolver : IHubProtocolResolver + { + + private readonly List _hubProtocols; + private readonly Dictionary _availableProtocols; + + public IReadOnlyList AllProtocols => _hubProtocols; + + public DefaultHubProtocolResolver(IEnumerable availableProtocols) + { + _availableProtocols = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // We might get duplicates in _hubProtocols, but we're going to check it and overwrite in just a sec. + _hubProtocols = availableProtocols.ToList(); + foreach (var protocol in _hubProtocols) + { + _availableProtocols[protocol.Name] = protocol; + } + } + + public IHubProtocol? GetProtocol(string protocolName, IReadOnlyList? supportedProtocols) + { + protocolName = protocolName ?? throw new ArgumentNullException(nameof(protocolName)); + + if (_availableProtocols.TryGetValue(protocolName, out var protocol) && (supportedProtocols == null || supportedProtocols.Contains(protocolName, StringComparer.OrdinalIgnoreCase))) + { + return protocol; + } + + // null result indicates protocol is not supported + // result will be validated by the caller + return null; + } + } + } +}