Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ public static class Headers
public const string AsrsMessageTracingId = AsrsInternalHeaderPrefix + "Message-Tracing-Id";

public const string MicrosoftErrorCode = "x-ms-error-code";

public const string AsrsManagementSDKClientInvocationProtocol = AsrsInternalHeaderPrefix + "Protocol";
}

public static class ErrorCodes
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;

internal sealed class SimpleInvocationBinder : IInvocationBinder
{
private readonly Type _returnType;

public SimpleInvocationBinder(Type returnType)
{
_returnType = returnType ?? throw new ArgumentNullException(nameof(returnType));
}

public Type GetReturnType(string invocationId)
{
return _returnType;
}

public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
throw new NotImplementedException();
}

public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
112 changes: 110 additions & 2 deletions src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Net;
using System.Net.Http;

using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -15,6 +16,8 @@

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using System.Buffers;
using System.Net.Http.Headers;
using Microsoft.AspNetCore.SignalR.Protocol;
#endif
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -359,10 +362,115 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
CompletionMessage? responseMessage = null;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
// 1. Get protocol from header (e.g. "json" or "messagepack")
if (!response.Headers.TryGetValues(Constants.Headers.AsrsManagementSDKClientInvocationProtocol, out var protocolHeaderValues))
{
throw new HubException("Response is missing protocol header.");
}
var protocolName = protocolHeaderValues.FirstOrDefault();
if (string.IsNullOrEmpty(protocolName))
{
throw new HubException("Response protocol header is empty.");
}

// 2. Pick the hub protocol that matches X-Protocol
var protocol = _protocolResolver.GetProtocol(protocolName, supportedProtocols: null);

if (protocol == null)
{
throw new InvalidOperationException(
$"The protocol '{protocolName}' is not configured. " +
$"Add the missing protocol using ServiceManagerBuilder.AddHubProtocol() or ServiceManagerBuilder.WithHubProtocols().");
}

// 3. Read raw completion payload from response body

var buffer = await response.Content.ReadAsByteArrayAsync(cancellationToken)
.ConfigureAwait(false);

if (buffer.Length == 0)
{
throw new HubException("Response payload is empty.");
}

// 4. Use SimpleInvocationBinder with typeof(T)
var binder = new SimpleInvocationBinder(typeof(T));

// 5. Parse the payload bytes into CompletionMessage
var sequence = new ReadOnlySequence<byte>(buffer);
var local = sequence;
if (!protocol.TryParseMessage(ref local, binder, out var hubMessage))
{
throw new HubException("Failed to parse invocation response.");
}

responseMessage = (CompletionMessage)hubMessage!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.BadRequest;
},
new MediaTypeWithQualityHeaderValue("application/octet-stream"),
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}
if (responseMessage == null)
{
throw new HubException("Response message is null.");
}
if (responseMessage.Error != null)
{
throw new HubException(responseMessage.Error);
}

return (T)responseMessage!.Result!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

#pragma warning disable IDE0051 // Will be used in the future updates
private static bool IsInvocationSupported(IHubProtocol protocol)
#pragma warning restore IDE0051 // Will be used in the future updates
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
Expand Down
Loading
Loading