Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;

internal sealed class SimpleInvocationBinder : IInvocationBinder
{
private readonly Type _returnType;

public SimpleInvocationBinder(Type returnType)
{
_returnType = returnType ?? throw new ArgumentNullException(nameof(returnType));
}

public Type GetReturnType(string invocationId)
{
return _returnType;
}

public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
// No parameters for responses in this scenario
return Array.Empty<Type>();
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
}

#pragma warning disable IDE0060 // Remove unused parameter
public Type GetStreamItemType(string streamId)
#pragma warning restore IDE0060 // Remove unused parameter
{
// No streaming in this scenario
return typeof(object);
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
}
}
35 changes: 30 additions & 5 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public Task SendAsync(
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken);
}

public Task SendWithRetryAsync(
Expand All @@ -73,18 +73,38 @@ public Task SendWithRetryAsync(
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, null, cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
Action<HttpRequestMessage>? preProcessRequest = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);

return SendAsyncCore(Constants.HttpClientNames.MessageResilient,
api,
httpMethod,
new InvocationMessage(methodName, args),
null,
handleExpectedResponse,
preProcessRequest,
cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
Expand All @@ -96,7 +116,7 @@ public Task SendStreamMessageWithRetryAsync(
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), null, cancellationToken);
}

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
Expand Down Expand Up @@ -164,11 +184,16 @@ private async Task SendAsyncCore(
HubMessage? body,
Type? typeHint,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
Action<HttpRequestMessage>? preProcessRequest = null,
CancellationToken cancellationToken = default)
{
using var httpClient = _httpClientFactory.CreateClient(httpClientName);
using var request = BuildRequest(api, httpMethod, body, typeHint);

// preprocess the request
// used for client invocation to add extra headers today.
preProcessRequest?.Invoke(request);

try
{
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Text.Json.Serialization;

namespace Microsoft.Azure.SignalR.Management.ClientInvocation;

sealed class InvocationResponse
{
[JsonPropertyName("result")]
public string Result { get; set; }

[JsonPropertyName("protocol")]
public string protocol { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var protocolResolver = _serviceProvider.GetRequiredService<IHubProtocolResolver>();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

<ItemGroup Condition=" '$(TargetFrameworkIdentifier)' != '.NETStandard' ">
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.11" />
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
134 changes: 133 additions & 1 deletion src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.Azure.SignalR.Management.ClientInvocation;
using Microsoft.AspNetCore.SignalR.Protocol;
using System.Buffers;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand All @@ -31,13 +36,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly IHubProtocolResolver _protocolResolver;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_protocolResolver = protocolResolver;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -353,6 +360,131 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
CompletionMessage? responseMessage = null;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
// 1. Read the outer InvocationResponse (JSON)
await using var contentStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var InvocationResponse = await JsonSerializer.DeserializeAsync<InvocationResponse>(contentStream, cancellationToken: cancellationToken).ConfigureAwait(false);
if (InvocationResponse == null || InvocationResponse.protocol == null || InvocationResponse.Result == null)
{
throw new HubException("Response is null or incomplete.");
}

// 2. Pick the hub protocol that matches clientResponse.Protocol
var protocol = _protocolResolver.AllProtocols
.FirstOrDefault(p => string.Equals(p.Name, InvocationResponse.protocol, StringComparison.OrdinalIgnoreCase));

if (protocol == null)
{
if (string.Equals(InvocationResponse.protocol, "messagepack", StringComparison.OrdinalIgnoreCase) &&
_protocolResolver.AllProtocols.Count == 1 &&
_protocolResolver.AllProtocols[0] is JsonObjectSerializerHubProtocol jsonObjectSerializerHubProtocol)
{
// The hub protocol is the default one. Service will convert it to MessagePack and keep backward compatibility as users may depend on this feature for MessagePack client support.
protocol = new MessagePackHubProtocol();
}
else
{
throw new NotSupportedException($"The protocol '{InvocationResponse.protocol}' is not supported.");
}
}

// 3. Use SimpleInvocationBinder with typeof(T)
var binder = new SimpleInvocationBinder(typeof(T));

// 4. Parse the payload bytes into CompletionMessage
var messageBytes = Convert.FromBase64String(InvocationResponse.Result);
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
var sequence = new ReadOnlySequence<byte>(messageBytes);
var local = sequence;
if (!protocol.TryParseMessage(ref local, binder, out var hubMessage))
{
throw new HubException("Failed to parse invocation response.");
}

responseMessage = (CompletionMessage)hubMessage!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.InternalServerError;
},
(request) =>
{
// Add Accept header for SignalR client to recognize the request from management SDK
request.Headers.Accept!.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/octet-stream"));
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
},
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}
if ( responseMessage == null)
{
throw new HubException("Response message is null.");
}
if (responseMessage.Error != null)
{
throw new HubException(responseMessage.Error);
}

return (T)responseMessage!.Result!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
case "messagepack":
return true;
default:
return false;
}
}

#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Loading
Loading