Skip to content

Commit 739b53c

Browse files
authored
Feature/rework websocketbase (#190)
* Rework logic in live controller * Implement into base * Fix hub swapping and other edge cases * Remove unused class * No need to call abort on a websocket we know is null
1 parent 8b9cee8 commit 739b53c

8 files changed

Lines changed: 220 additions & 202 deletions

File tree

Common/Websocket/WebsockBaseController.cs

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,12 @@ public abstract class WebsocketBaseController<T> : OpenShockControllerBase, IAsy
2828
/// </summary>
2929
protected readonly ILogger<WebsocketBaseController<T>> Logger;
3030

31-
/// <summary>
32-
/// Close cancellation token to be called manually when termination of the current websocket is requested. Called on Dispose as well.
33-
/// </summary>
34-
protected readonly CancellationTokenSource Close = new();
35-
3631
/// <summary>
3732
/// When passing a cancellation token, pass this Linked token, it is a Link from ApplicationStopping and Close.
3833
/// </summary>
39-
protected readonly CancellationTokenSource LinkedSource;
34+
private CancellationTokenSource? _linkedSource;
4035

41-
protected readonly CancellationToken LinkedToken;
36+
protected CancellationToken LinkedToken;
4237

4338
/// <summary>
4439
/// Channel for multithreading thread safety of the websocket, MessageLoop is the only reader for this channel
@@ -53,12 +48,9 @@ public abstract class WebsocketBaseController<T> : OpenShockControllerBase, IAsy
5348
/// DI
5449
/// </summary>
5550
/// <param name="logger"></param>
56-
/// <param name="lifetime"></param>
57-
public WebsocketBaseController(ILogger<WebsocketBaseController<T>> logger, IHostApplicationLifetime lifetime)
51+
protected WebsocketBaseController(ILogger<WebsocketBaseController<T>> logger)
5852
{
5953
Logger = logger;
60-
LinkedSource = CancellationTokenSource.CreateLinkedTokenSource(Close.Token, lifetime.ApplicationStopping);
61-
LinkedToken = LinkedSource.Token;
6254
}
6355

6456

@@ -89,10 +81,9 @@ public virtual async ValueTask DisposeAsync()
8981
await UnregisterConnection();
9082

9183
Channel.Writer.TryComplete();
92-
await Close.CancelAsync();
9384

9485
WebSocket?.Dispose();
95-
LinkedSource.Dispose();
86+
_linkedSource?.Dispose();
9687

9788
GC.SuppressFinalize(this);
9889
Logger.LogTrace("Disposed websocket controller");
@@ -110,18 +101,25 @@ public virtual async ValueTask DisposeAsync()
110101
/// </summary>
111102
[ApiExplorerSettings(IgnoreApi = true)]
112103
[HttpGet]
113-
public async Task Get()
104+
public async Task Get([FromServices] IHostApplicationLifetime lifetime, CancellationToken cancellationToken)
114105
{
106+
#pragma warning disable IDISP003
107+
_linkedSource = CancellationTokenSource.CreateLinkedTokenSource(lifetime.ApplicationStopping, cancellationToken);
108+
#pragma warning restore IDISP003
109+
LinkedToken = _linkedSource.Token;
110+
115111
if (!HttpContext.WebSockets.IsWebSocketRequest)
116112
{
117113
var jsonOptions = HttpContext.RequestServices.GetRequiredService<IOptions<JsonOptions>>();
118114
HttpContext.Response.StatusCode = StatusCodes.Status400BadRequest;
119115
var response = WebsocketError.NonWebsocketRequest;
120116
response.AddContext(HttpContext);
121117
// ReSharper disable once MethodSupportsCancellation
122-
await HttpContext.Response.WriteAsJsonAsync(response, jsonOptions.Value.SerializerOptions,
123-
contentType: MediaTypeNames.Application.ProblemJson);
124-
await Close.CancelAsync();
118+
await HttpContext.Response.WriteAsJsonAsync(
119+
response,
120+
jsonOptions.Value.SerializerOptions,
121+
contentType: MediaTypeNames.Application.ProblemJson,
122+
cancellationToken: cancellationToken);
125123
return;
126124
}
127125

@@ -133,16 +131,19 @@ await HttpContext.Response.WriteAsJsonAsync(response, jsonOptions.Value.Serializ
133131
HttpContext.Response.StatusCode = response.Status ?? StatusCodes.Status400BadRequest;
134132
response.AddContext(HttpContext);
135133
// ReSharper disable once MethodSupportsCancellation
136-
await HttpContext.Response.WriteAsJsonAsync(response, jsonOptions.Value.SerializerOptions,
137-
contentType: MediaTypeNames.Application.ProblemJson);
138-
139-
await Close.CancelAsync();
134+
await HttpContext.Response.WriteAsJsonAsync(
135+
response,
136+
jsonOptions.Value.SerializerOptions,
137+
contentType: MediaTypeNames.Application.ProblemJson,
138+
cancellationToken: cancellationToken);
140139
return;
141140
}
142141

143142
Logger.LogInformation("Opening websocket connection");
144-
WebSocket?.Dispose(); // This should never happen, suppresses warning
143+
144+
#pragma warning disable IDISP003
145145
WebSocket = await HttpContext.WebSockets.AcceptWebSocketAsync();
146+
#pragma warning restore IDISP003
146147

147148
#pragma warning disable CS4014
148149
OsTask.Run(MessageLoop);
@@ -154,8 +155,14 @@ await HttpContext.Response.WriteAsJsonAsync(response, jsonOptions.Value.Serializ
154155
// Logic ended
155156

156157
await UnregisterConnection();
157-
158-
await Close.CancelAsync();
158+
159+
// Only send close if the socket is still open, this allows us to close the websocket from inside the logic
160+
// We send close if the client sent a close message though
161+
if (WebSocket is { State: WebSocketState.Open or WebSocketState.CloseReceived })
162+
{
163+
await WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Normal closure",
164+
LinkedToken);
165+
}
159166
}
160167

161168
#region Send Loop
@@ -198,7 +205,54 @@ protected virtual Task SendWebSocketMessage(T message, WebSocket websocket, Canc
198205
/// </summary>
199206
/// <returns></returns>
200207
[NonAction]
201-
protected abstract Task Logic();
208+
private async Task Logic()
209+
{
210+
while (!LinkedToken.IsCancellationRequested)
211+
{
212+
try
213+
{
214+
if (WebSocket == null)
215+
{
216+
Logger.LogWarning("WebSocket is null, aborting");
217+
return;
218+
}
219+
220+
if (WebSocket.State is WebSocketState.CloseReceived or WebSocketState.CloseSent or WebSocketState.Closed)
221+
{
222+
// Client or we sent close message or both, we will close the connection after this
223+
return;
224+
}
225+
226+
if (WebSocket!.State != WebSocketState.Open)
227+
{
228+
Logger.LogWarning("WebSocket is not open [{State}], aborting", WebSocket.State);
229+
WebSocket?.Abort();
230+
return;
231+
}
232+
233+
await HandleReceive();
234+
235+
}
236+
catch (OperationCanceledException)
237+
{
238+
Logger.LogWarning("WebSocket connection terminated due to close or shutdown");
239+
return;
240+
}
241+
catch (Exception ex)
242+
{
243+
Logger.LogError(ex, "Exception while processing websocket request");
244+
WebSocket?.Abort();
245+
return;
246+
}
247+
}
248+
}
249+
250+
/// <summary>
251+
///
252+
/// </summary>
253+
/// <returns>True if you want to continue the receiver loop, false if you want to terminate</returns>
254+
[NonAction]
255+
protected abstract Task<bool> HandleReceive();
202256

203257
/// <summary>
204258
/// Send initial data to the client

LiveControlGateway/Controllers/HubControllerBase.cs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using FlatSharp;
1+
using System.Net.WebSockets;
2+
using FlatSharp;
23
using Microsoft.AspNetCore.Mvc;
34
using Microsoft.AspNetCore.Mvc.Filters;
45
using Microsoft.Extensions.Options;
@@ -78,30 +79,36 @@ public void OnActionExecuted(ActionExecutedContext context)
7879
/// <summary>
7980
/// Base for hub websocket controllers
8081
/// </summary>
81-
/// <param name="lifetime"></param>
8282
/// <param name="incomingSerializer"></param>
8383
/// <param name="outgoingSerializer"></param>
8484
/// <param name="hubLifetimeManager"></param>
8585
/// <param name="serviceProvider"></param>
8686
/// <param name="options"></param>
8787
/// <param name="logger"></param>
8888
protected HubControllerBase(
89-
IHostApplicationLifetime lifetime,
9089
ISerializer<TIn> incomingSerializer,
9190
ISerializer<TOut> outgoingSerializer,
9291
HubLifetimeManager hubLifetimeManager,
9392
IServiceProvider serviceProvider,
9493
IOptions<LcgOptions> options,
9594
ILogger<FlatbuffersWebsocketBaseController<TIn, TOut>> logger
96-
) : base(logger, lifetime, incomingSerializer, outgoingSerializer)
95+
) : base(logger, incomingSerializer, outgoingSerializer)
9796
{
9897
_hubLifetimeManager = hubLifetimeManager;
9998
ServiceProvider = serviceProvider;
10099
_options = options.Value;
101100
_keepAliveTimeoutTimer.Elapsed += async (_, _) =>
102101
{
103-
Logger.LogInformation("Keep alive timeout reached, closing websocket connection");
104-
await Close.CancelAsync();
102+
try
103+
{
104+
Logger.LogInformation("Keep alive timeout reached, closing websocket connection");
105+
await WebSocket!.CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Keep alive timeout reached",
106+
LinkedToken);
107+
}
108+
catch (Exception ex)
109+
{
110+
Logger.LogError(ex, "Error while closing websocket connection from keep alive timeout");
111+
}
105112
};
106113
_keepAliveTimeoutTimer.Start();
107114
}
@@ -167,7 +174,17 @@ protected override async Task UnregisterConnection()
167174

168175
/// <inheritdoc />
169176
public abstract ValueTask OtaInstall(SemVersion version);
170-
177+
178+
/// <inheritdoc />
179+
public async Task DisconnectOld()
180+
{
181+
if (WebSocket == null)
182+
return;
183+
184+
await WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Hub is connecting from a different location",
185+
LinkedToken);
186+
}
187+
171188
private static DateTimeOffset? GetBootedAtFromUptimeMs(ulong uptimeMs)
172189
{
173190
var uptime = TimeSpan.FromMilliseconds(uptimeMs);

LiveControlGateway/Controllers/HubV1Controller.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,19 @@ public sealed class HubV1Controller : HubControllerBase<HubToGatewayMessage, Gat
3131
/// <summary>
3232
/// DI
3333
/// </summary>
34-
/// <param name="lifetime"></param>
3534
/// <param name="hubLifetimeManager"></param>
3635
/// <param name="userHubContext"></param>
3736
/// <param name="serviceProvider"></param>
3837
/// <param name="options"></param>
3938
/// <param name="logger"></param>
4039
public HubV1Controller(
41-
IHostApplicationLifetime lifetime,
4240
HubLifetimeManager hubLifetimeManager,
4341
IHubContext<UserHub, IUserHub> userHubContext,
4442
IServiceProvider serviceProvider,
4543
IOptions<LcgOptions> options,
4644
ILogger<HubV1Controller> logger
4745
)
48-
: base(lifetime, HubToGatewayMessage.Serializer, GatewayToHubMessage.Serializer, hubLifetimeManager, serviceProvider, options, logger)
46+
: base(HubToGatewayMessage.Serializer, GatewayToHubMessage.Serializer, hubLifetimeManager, serviceProvider, options, logger)
4947
{
5048
_userHubContext = userHubContext;
5149
}

LiveControlGateway/Controllers/HubV2Controller.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,19 @@ public sealed class HubV2Controller : HubControllerBase<HubToGatewayMessage, Gat
3636
/// <summary>
3737
/// DI
3838
/// </summary>
39-
/// <param name="lifetime"></param>
4039
/// <param name="hubLifetimeManager"></param>
4140
/// <param name="userHubContext"></param>
4241
/// <param name="serviceProvider"></param>
4342
/// <param name="options"></param>
4443
/// <param name="logger"></param>
4544
public HubV2Controller(
46-
IHostApplicationLifetime lifetime,
4745
HubLifetimeManager hubLifetimeManager,
4846
IHubContext<UserHub, IUserHub> userHubContext,
4947
IServiceProvider serviceProvider,
5048
IOptions<LcgOptions> options,
5149
ILogger<HubV2Controller> logger
5250
)
53-
: base(lifetime, HubToGatewayMessage.Serializer, GatewayToHubMessage.Serializer, hubLifetimeManager, serviceProvider, options, logger)
51+
: base(HubToGatewayMessage.Serializer, GatewayToHubMessage.Serializer, hubLifetimeManager, serviceProvider, options, logger)
5452
{
5553
_userHubContext = userHubContext;
5654
_pingTimer = new Timer(PingTimerElapsed, null, Duration.DevicePingInitialDelay, Duration.DevicePingPeriod);

LiveControlGateway/Controllers/IHubController.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,10 @@ public interface IHubController : IAsyncDisposable
3333
/// <param name="version"></param>
3434
/// <returns></returns>
3535
public ValueTask OtaInstall(SemVersion version);
36+
37+
/// <summary>
38+
/// Disconnect the old connection in favor of the new one
39+
/// </summary>
40+
/// <returns></returns>
41+
public Task DisconnectOld();
3642
}

0 commit comments

Comments
 (0)