Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System;
using System.Collections.Generic;
using System.Net.Http;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand All @@ -20,6 +20,8 @@ public BinaryPayloadContentBuilder(IReadOnlyList<IHubProtocol> hubProtocols)
_hubProtocols = hubProtocols;
}

public ObjectSerializer? ObjectSerializer => null;
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated

public HttpContent? Build(HubMessage? payload, Type? typeHint)
{
return payload == null ? null : (HttpContent)new BinaryPayloadMessageContent(payload, _hubProtocols);
Comment thread
ZSWY666 marked this conversation as resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using System;
using System.Net.Http;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand All @@ -13,4 +13,6 @@ namespace Microsoft.Azure.SignalR.Common;
internal interface IPayloadContentBuilder
{
HttpContent? Build(HubMessage? payload, Type? typeHint);

ObjectSerializer? ObjectSerializer { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer;
}
8 changes: 6 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ internal class RestClient

private readonly IPayloadContentBuilder _payloadContentBuilder;

private static readonly ObjectSerializer DefaultObjectSerializer = new JsonObjectSerializer();

public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder)
{
_httpClientFactory = httpClientFactory;
Expand Down Expand Up @@ -81,10 +83,10 @@ public Task SendMessageWithRetryAsync(
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
Expand All @@ -99,6 +101,8 @@ public Task SendStreamMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken);
}

public ObjectSerializer ObjectSerializer => _payloadContentBuilder.ObjectSerializer ?? DefaultObjectSerializer;
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
{
if (query == null || query.Count == 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Text.Json.Serialization;

namespace Microsoft.Azure.SignalR.Management.ClientInvocation;

sealed class InvocationResponse<T>
{
[JsonPropertyName("result")]
public T Result { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var protocolResolver = _serviceProvider.GetRequiredService<IHubProtocolResolver>();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint GetClientInvocationEndpoint(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
105 changes: 104 additions & 1 deletion src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using System.IO;
using Microsoft.AspNetCore.SignalR.Protocol;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand All @@ -31,13 +35,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly IHubProtocolResolver _protocolResolver;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_protocolResolver = protocolResolver;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -353,6 +359,103 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

var api = _restApiProvider.GetClientInvocationEndpoint(_appName, _hubName, connectionId);
string? errorContent = null;
var isSuccess = false;
T? resultValue = default;

await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
await using var contentStream = await response.Content.ReadAsStreamAsync(cancellationToken);

// Deserialize whole response into JsonDocument (InvocationResponse "shell")
using var doc = await JsonDocument.ParseAsync(contentStream, cancellationToken: cancellationToken);
var root = doc.RootElement;

if (!root.TryGetProperty("result", out var resultProperty))
Comment thread
ZSWY666 marked this conversation as resolved.
Outdated
{
throw new HubException("Response cannot be null or empty.");
}

using var resultStream = new MemoryStream();
await using (var utf8Writer = new Utf8JsonWriter(resultStream))
{
resultProperty.WriteTo(utf8Writer);
await utf8Writer.FlushAsync(cancellationToken);
}

resultStream.Position = 0;

var deserialized = await _restClient.ObjectSerializer.DeserializeAsync(
resultStream,
typeof(T),
cancellationToken);
resultValue = (T)deserialized!;
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.InternalServerError;
},
cancellationToken);

if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}

return resultValue!;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
//case "messagepack":
return true;
default:
return false;
}
}

#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Loading
Loading