Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer ObjectSerializer => _jsonObjectSerializer;
}
17 changes: 17 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ public Task SendMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponseAsync, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
Expand All @@ -99,6 +110,12 @@ public Task SendStreamMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken);
}

public ObjectSerializer ObjectSerializer => _payloadContentBuilder switch
{
JsonPayloadContentBuilder jsonBuilder => jsonBuilder.ObjectSerializer,
_ => throw new NotSupportedException("Only JsonPayloadContentBuilder is supported to get the ObjectSerializer.")
};

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
{
if (query == null || query.Count == 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Text.Json.Serialization;

namespace Microsoft.Azure.SignalR.Management.ClientInvocation;

sealed class InvocationResponse<T>
{
[JsonPropertyName("result")]
public T Result { get; set; }
}
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint SendClientInvocation(string appName, string hubName, string connectionId)
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
71 changes: 71 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Management.ClientInvocation;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand Down Expand Up @@ -353,6 +358,72 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
// Validate input parameters
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}

// Get API endpoint and prepare for the request
var api = _restApiProvider.SendClientInvocation(_appName, _hubName, connectionId);
InvocationResponse<T>? wrapper = null;
string? errorContent = null;
bool isSuccess = false;
// Send request and capture the response
await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
await using var contentStream = await response.Content.ReadAsStreamAsync(cancellationToken);

var deserialized = await _restClient.ObjectSerializer.DeserializeAsync(
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
contentStream,
typeof(InvocationResponse<T>),
cancellationToken);

wrapper = deserialized as InvocationResponse<T>
?? throw new AzureSignalRException("Failed to deserialize response");
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.BadRequest;
Comment thread
ZSWY666 marked this conversation as resolved.
},
cancellationToken);

// Ensure we have a response
if (!isSuccess)
{
throw new AzureSignalRException(errorContent ?? "Unknown error in response");
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
}

return wrapper!.Result;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}
#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,51 @@ public async Task ListConnectionsInGroupTest(ServiceTransportType serviceTranspo
}
}

[ConditionalTheory]
[SkipIfConnectionStringNotPresent]
[MemberData(nameof(TestData))]
public async Task ClientInvocationTest(ServiceTransportType serviceTransportType, string appName)
{
using var logger = StartLog(out var loggerFactory, nameof(ClientInvocationTest));
using var serviceManager = new ServiceManagerBuilder().WithOptions(o =>
{
o.ConnectionString = TestConfiguration.Instance.ConnectionString;
o.ServiceTransportType = serviceTransportType;
o.ApplicationName = appName;
})
.WithLoggerFactory(loggerFactory)
.BuildServiceManager();
using var hubContext = await serviceManager.CreateHubContextAsync(HubName, default);
var negotationResponse = await hubContext.NegotiateAsync();
var clientConnections = await CreateAndStartClientConnections(negotationResponse.Url, Enumerable.Repeat(negotationResponse.AccessToken, 1));

string expectedStringMessage = "Method Invoked";

clientConnections[0].On("Invoke", (string message) =>
{
if (message == "String Response")
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
{
return expectedStringMessage;
}
if (message == "Null Response")
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
{
return null;
}
return "";
});

var response_string = await hubContext.Clients.Client(clientConnections[0].ConnectionId).InvokeAsync<object>("Invoke", "String Response", default);
var response_null = await hubContext.Clients.Client(clientConnections[0].ConnectionId).InvokeAsync<object>("Invoke", "Null Response", default);

Assert.Equal(response_string.ToString(), expectedStringMessage);
Assert.Null(response_null);

foreach (var connection in clientConnections)
{
await connection.StopAsync();
}
}

private static IDictionary<string, List<string>> GenerateUserGroupDict(IList<string> userNames, IList<string> groupNames)
{
return (from i in Enumerable.Range(0, userNames.Count)
Expand Down
Loading