Skip to content
Open
Show file tree
Hide file tree
Changes from 28 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Net.Http;

using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;

internal sealed class SimpleInvocationBinder : IInvocationBinder
{
private readonly Type _returnType;

public SimpleInvocationBinder(Type returnType)
{
_returnType = returnType ?? throw new ArgumentNullException(nameof(returnType));
}

public Type GetReturnType(string invocationId)
{
return _returnType;
}

public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
throw new NotImplementedException();
}

public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
5 changes: 3 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ public Task SendMessageWithRetryAsync(
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
Expand Down Expand Up @@ -198,6 +198,7 @@ private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMeth
private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues>? query, HttpMethod httpMethod, HubMessage? body, Type? typeHint)
{
var request = new HttpRequestMessage(httpMethod, GetUri(url, query));
request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/octet-stream"));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the right place. BinaryPayloadContent already contains it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the BinaryPayloadContent, we just add a contentType instead of Accept, or maybe I lost something?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I got the two mixed up. But it's still an appropriate place to put this in BinaryPayloadContent. RestClient is a general entry point, and we don't want JsonPayloadContent to receive binary response payload, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit tricky here as this header is used to identify that "this request is come from ManagementSDK", so runtime will directly return raw bytes and let SDK to deserialize the result from client.
In this case, JsonPayloadContent will also receive the binary response payload

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does JsonPayloadContent also receive the binary response payload?

request.Content = _payloadContentBuilder.Build(body, typeHint);
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var protocolResolver = _serviceProvider.GetRequiredService<IHubProtocolResolver>();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

<ItemGroup Condition=" '$(TargetFrameworkIdentifier)' != '.NETStandard' ">
<FrameworkReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.SignalR.Protocols.MessagePack" Version="8.0.11" />
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
137 changes: 136 additions & 1 deletion src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.SignalR.Protocol;
using System.Buffers;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand All @@ -31,13 +35,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly IHubProtocolResolver _protocolResolver;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_protocolResolver = protocolResolver;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -353,6 +359,135 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
CompletionMessage? responseMessage = null;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
// 1. Get protocol from header (e.g. "json" or "messagepack")
if (!response.Headers.TryGetValues("X-Protocol", out var protocolHeaderValues))
{
throw new HubException("Response is missing protocol header.");
}
var protocolName = protocolHeaderValues.FirstOrDefault();
if (string.IsNullOrEmpty(protocolName))
{
throw new HubException("Response protocol header is empty.");
}

// 2. Pick the hub protocol that matches X-Protocol
var protocol = _protocolResolver.AllProtocols
.FirstOrDefault(p => string.Equals(p.Name, protocolName, StringComparison.OrdinalIgnoreCase));

if (protocol == null)
{
if (string.Equals(protocolName, "messagepack", StringComparison.OrdinalIgnoreCase) &&
_protocolResolver.AllProtocols.Count == 1 &&
_protocolResolver.AllProtocols[0] is JsonObjectSerializerHubProtocol jsonObjectSerializerHubProtocol)
{
// The hub protocol is the default one. Service will convert it to MessagePack and keep backward compatibility as users may depend on this feature for MessagePack client support.
protocol = new MessagePackHubProtocol();
}
else
{
throw new NotSupportedException($"The protocol '{protocolName}' is not supported.");
}
}

// 3. Read raw completion payload from response body
var payloadBytes = await response.Content.ReadAsByteArrayAsync(cancellationToken);
if (payloadBytes == null || payloadBytes.Length == 0)
{
throw new HubException("Response payload is empty.");
}

// 4. Use SimpleInvocationBinder with typeof(T)
var binder = new SimpleInvocationBinder(typeof(T));

// 5. Parse the payload bytes into CompletionMessage
var sequence = new ReadOnlySequence<byte>(payloadBytes);
var local = sequence;
if (!protocol.TryParseMessage(ref local, binder, out var hubMessage))
{
throw new HubException("Failed to parse invocation response.");
}

responseMessage = (CompletionMessage)hubMessage!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.InternalServerError;
},
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}
if ( responseMessage == null)
{
throw new HubException("Response message is null.");
}
if (responseMessage.Error != null)
{
throw new HubException(responseMessage.Error);
}

return (T)responseMessage!.Result!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
case "messagepack":
return true;
default:
return false;
}
}

#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Loading
Loading