Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ public static class Headers
public const string AsrsMessageTracingId = AsrsInternalHeaderPrefix + "Message-Tracing-Id";

public const string MicrosoftErrorCode = "x-ms-error-code";

public const string AsrsManagementSDKClientInvocationProtocol = AsrsInternalHeaderPrefix + "Protocol";
}

public static class ErrorCodes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand All @@ -12,7 +12,6 @@ public class AzureSignalRInaccessibleEndpointException : AzureSignalRException
private const string ErrorPhenomenon = "Unable to access SignalR service.";
private const string SuggestAction = "Please make sure the endpoint or DNS setting is correct.";


public AzureSignalRInaccessibleEndpointException(string requestUri, Exception innerException) : base(string.IsNullOrEmpty(requestUri) ? $"{ErrorPhenomenon} {innerException.Message} {SuggestAction}" : $"{ErrorPhenomenon} {innerException.Message} {SuggestAction} Request Uri: {requestUri}", innerException)
{
}
Expand All @@ -24,4 +23,4 @@ protected AzureSignalRInaccessibleEndpointException(SerializationInfo info, Stre
{
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;

internal sealed class SimpleInvocationBinder : IInvocationBinder
{
private readonly Type _returnType;

public SimpleInvocationBinder(Type returnType)
{
_returnType = returnType ?? throw new ArgumentNullException(nameof(returnType));
}

public Type GetReturnType(string invocationId)
{
return _returnType;
}

public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
throw new NotImplementedException();
}

public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
28 changes: 22 additions & 6 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public Task SendAsync(
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken);
}

public Task SendWithRetryAsync(
Expand All @@ -73,18 +73,30 @@ public Task SendWithRetryAsync(
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, null, handleExpectedResponseAsync, null, cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, null, cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
string? accepts = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, accepts, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
Expand All @@ -96,7 +108,7 @@ public Task SendStreamMessageWithRetryAsync(
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), null, cancellationToken);
}

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
Expand Down Expand Up @@ -164,11 +176,15 @@ private async Task SendAsyncCore(
HubMessage? body,
Type? typeHint,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
string? accepts = null,
CancellationToken cancellationToken = default)
{
using var httpClient = _httpClientFactory.CreateClient(httpClientName);
using var request = BuildRequest(api, httpMethod, body, typeHint);

if (accepts != null)
{
request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue(accepts));
}
try
{
using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var protocolResolver = _serviceProvider.GetRequiredService<IHubProtocolResolver>();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

<ItemGroup Condition=" '$(TargetFrameworkIdentifier)' != '.NETStandard' ">
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.11" />
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
147 changes: 146 additions & 1 deletion src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using System.Buffers;
using System.IO;
using Microsoft.AspNetCore.SignalR.Protocol;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand All @@ -31,13 +36,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly IHubProtocolResolver _protocolResolver;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_protocolResolver = protocolResolver;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -353,6 +360,144 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
CompletionMessage? responseMessage = null;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
// 1. Get protocol from header (e.g. "json" or "messagepack")
if (!response.Headers.TryGetValues(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, out var protocolHeaderValues))
{
throw new HubException("Response is missing protocol header.");
}
var protocolName = protocolHeaderValues.FirstOrDefault();
if (string.IsNullOrEmpty(protocolName))
{
throw new HubException("Response protocol header is empty.");
}

// 2. Pick the hub protocol that matches X-Protocol
var protocol = _protocolResolver.AllProtocols
.FirstOrDefault(p => string.Equals(p.Name, protocolName, StringComparison.OrdinalIgnoreCase));

if (protocol == null)
{
if (string.Equals(protocolName, "messagepack", StringComparison.OrdinalIgnoreCase) &&
_protocolResolver.AllProtocols.Count == 1 &&
_protocolResolver.AllProtocols[0] is JsonObjectSerializerHubProtocol jsonObjectSerializerHubProtocol)
{
// The hub protocol is the default one. Service will convert it to MessagePack and keep backward compatibility as users may depend on this feature for MessagePack client support.
protocol = new MessagePackHubProtocol();
}
else
{
throw new NotSupportedException($"The protocol '{protocolName}' is not supported.");
}
}

// 3. Read raw completion payload from response body

byte[] buffer;
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken)
.ConfigureAwait(false);
using var ms = new MemoryStream();
await stream.CopyToAsync(ms, cancellationToken).ConfigureAwait(false);

if (ms.Length == 0)
{
throw new HubException("Response payload is empty.");
}

buffer = ms.ToArray();

// 4. Use SimpleInvocationBinder with typeof(T)
var binder = new SimpleInvocationBinder(typeof(T));

// 5. Parse the payload bytes into CompletionMessage
var sequence = new ReadOnlySequence<byte>(buffer);
var local = sequence;
if (!protocol.TryParseMessage(ref local, binder, out var hubMessage))
{
throw new HubException("Failed to parse invocation response.");
}

responseMessage = (CompletionMessage)hubMessage!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.BadRequest;
},
"application/octet-stream",
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}
if (responseMessage == null)
{
throw new HubException("Response message is null.");
}
if (responseMessage.Error != null)
{
throw new HubException(responseMessage.Error);
}

return (T)responseMessage!.Result!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
case "messagepack":
return true;
default:
return false;
}
}

#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Loading
Loading