Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ public static class Headers
public const string AsrsMessageTracingId = AsrsInternalHeaderPrefix + "Message-Tracing-Id";

public const string MicrosoftErrorCode = "x-ms-error-code";

public const string AsrsManagementSDKClientInvocationProtocol = AsrsInternalHeaderPrefix + "Protocol";
}

public static class ErrorCodes
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;

internal sealed class SimpleInvocationBinder : IInvocationBinder
{
private readonly Type _returnType;

public SimpleInvocationBinder(Type returnType)
{
_returnType = returnType ?? throw new ArgumentNullException(nameof(returnType));
}

public Type GetReturnType(string invocationId)
{
return _returnType;
}

public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
throw new NotImplementedException();
}

public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
112 changes: 110 additions & 2 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Net;
using System.Net.Http;

using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -15,6 +16,8 @@

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using System.Buffers;
using System.Net.Http.Headers;
using Microsoft.AspNetCore.SignalR.Protocol;
#endif
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -359,10 +362,115 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
CompletionMessage? responseMessage = null;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
// 1. Get protocol from header (e.g. "json" or "messagepack")
if (!response.Headers.TryGetValues(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, out var protocolHeaderValues))
{
throw new HubException("Response is missing protocol header.");
}
var protocolName = protocolHeaderValues.FirstOrDefault();
if (string.IsNullOrEmpty(protocolName))
{
throw new HubException("Response protocol header is empty.");
}

// 2. Pick the hub protocol that matches X-Protocol
var protocol = _protocolResolver.GetProtocol(protocolName, supportedProtocols: null);

if (protocol == null)
{
throw new InvalidOperationException(
$"The protocol '{protocolName}' is not configured. " +
$"Add the missing protocol using ServiceManagerBuilder.AddHubProtocol() or ServiceManagerBuilder.WithHubProtocols().");
}

// 3. Read raw completion payload from response body

var buffer = await response.Content.ReadAsByteArrayAsync(cancellationToken)
.ConfigureAwait(false);

if (buffer.Length == 0)
{
throw new HubException("Response payload is empty.");
}

// 4. Use SimpleInvocationBinder with typeof(T)
var binder = new SimpleInvocationBinder(typeof(T));

// 5. Parse the payload bytes into CompletionMessage
var sequence = new ReadOnlySequence<byte>(buffer);
var local = sequence;
if (!protocol.TryParseMessage(ref local, binder, out var hubMessage))
{
throw new HubException("Failed to parse invocation response.");
}

responseMessage = (CompletionMessage)hubMessage!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.BadRequest;
},
new MediaTypeWithQualityHeaderValue("application/octet-stream"),
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}
if (responseMessage == null)
{
throw new HubException("Response message is null.");
}
if (responseMessage.Error != null)
{
throw new HubException(responseMessage.Error);
}

return (T)responseMessage!.Result!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

#pragma warning disable IDE0051 // Will be used in the future updates
private static bool IsInvocationSupported(IHubProtocol protocol)
#pragma warning restore IDE0051 // Will be used in the future updates
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
Expand Down
Loading