Skip to content

Commit e9b52df

Browse files
committed
fix: invoke/stream cancellation races and add regression tests
1 parent 367765c commit e9b52df

5 files changed

Lines changed: 627 additions & 51 deletions

File tree

src/Eventa/EventInvoke.cs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,25 @@ void CompleteFaulted(Exception error)
6666
Cleanup();
6767
}
6868

69-
void CompleteCanceled()
69+
void FinishCanceled(bool emitAbort)
7070
{
7171
if (Interlocked.Exchange(ref finished, 1) != 0)
7272
{
7373
return;
7474
}
7575

76+
if (emitAbort)
77+
{
78+
context.Emit(sendAbortEvent, new AbortPayload(invokeId));
79+
}
80+
7681
completion.TrySetCanceled(cancellationToken);
7782
Cleanup();
7883
}
7984

8085
void AbortFromClient()
8186
{
82-
context.Emit(sendAbortEvent, new AbortPayload(invokeId));
83-
CompleteCanceled();
87+
FinishCanceled(emitAbort: true);
8488
}
8589

8690
disposables.Add(context.On(receiveEvent, envelope =>
@@ -118,14 +122,14 @@ void AbortFromClient()
118122

119123
if (cancellationToken.CanBeCanceled)
120124
{
121-
var registration = cancellationToken.Register(AbortFromClient);
122-
disposables.Add(new ActionDisposable(registration.Dispose));
123-
124125
if (cancellationToken.IsCancellationRequested)
125126
{
126127
AbortFromClient();
127128
return completion.Task;
128129
}
130+
131+
var registration = cancellationToken.Register(AbortFromClient);
132+
disposables.Add(new ActionDisposable(registration.Dispose));
129133
}
130134

131135
context.Emit(sendEvent, new SendPayload<TRequest>(invokeId, request));
@@ -143,17 +147,16 @@ public static IDisposable DefineInvokeHandler<TResponse, TRequest>(
143147
ArgumentNullException.ThrowIfNull(handler);
144148

145149
var registry = HandlerRegistries.GetValue(context, static _ => new InvokeHandlerRegistry());
146-
HandlerRegistration? registration;
147150

148151
lock (registry.SyncRoot)
149152
{
150153
if (!registry.Registrations.TryGetValue(eventDefinition.SendEventId, out var handlers))
151154
{
152-
handlers = new Dictionary<Delegate, HandlerRegistration>();
155+
handlers = [];
153156
registry.Registrations[eventDefinition.SendEventId] = handlers;
154157
}
155158

156-
if (!handlers.TryGetValue(handler, out registration))
159+
if (!handlers.TryGetValue(handler, out HandlerRegistration? registration))
157160
{
158161
registration = CreateUnaryHandlerRegistration(context, eventDefinition, handler);
159162
handlers[handler] = registration;
@@ -173,17 +176,16 @@ public static IDisposable DefineInvokeHandler<TResponse, TRequest>(
173176
ArgumentNullException.ThrowIfNull(handler);
174177

175178
var registry = HandlerRegistries.GetValue(context, static _ => new InvokeHandlerRegistry());
176-
HandlerRegistration? registration;
177179

178180
lock (registry.SyncRoot)
179181
{
180182
if (!registry.Registrations.TryGetValue(eventDefinition.SendEventId, out var handlers))
181183
{
182-
handlers = new Dictionary<Delegate, HandlerRegistration>();
184+
handlers = [];
183185
registry.Registrations[eventDefinition.SendEventId] = handlers;
184186
}
185187

186-
if (!handlers.TryGetValue(handler, out registration))
188+
if (!handlers.TryGetValue(handler, out HandlerRegistration? registration))
187189
{
188190
registration = CreateRequestStreamHandlerRegistration(context, eventDefinition, handler);
189191
handlers[handler] = registration;

src/Eventa/EventStream.cs

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public static IAsyncEnumerable<TResponse> DefineStreamInvoke<TResponse, TRequest
1515
context,
1616
eventDefinition,
1717
cancellationToken,
18-
invokeId =>
18+
(invokeId, _) =>
1919
{
2020
var sendEvent = new EventDefinition<SendPayload<TRequest>>(eventDefinition.SendEventId);
2121
context.Emit(sendEvent, new SendPayload<TRequest>(invokeId, request));
@@ -37,29 +37,29 @@ public static IAsyncEnumerable<TResponse> DefineStreamInvoke<TResponse, TRequest
3737
context,
3838
eventDefinition,
3939
cancellationToken,
40-
async invokeId =>
40+
async (invokeId, requestCancellationToken) =>
4141
{
4242
var sendEvent = new EventDefinition<SendPayload<TRequest>>(eventDefinition.SendEventId);
4343
var sendStreamEndEvent = new EventDefinition<StreamEndPayload>(eventDefinition.SendStreamEndId);
4444

4545
try
4646
{
47-
await foreach (var item in request.WithCancellation(cancellationToken).ConfigureAwait(false))
47+
await foreach (var item in request.WithCancellation(requestCancellationToken).ConfigureAwait(false))
4848
{
49-
if (cancellationToken.IsCancellationRequested)
49+
if (requestCancellationToken.IsCancellationRequested)
5050
{
5151
return;
5252
}
5353

5454
context.Emit(sendEvent, new SendPayload<TRequest>(invokeId, item));
5555
}
5656
}
57-
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
57+
catch (OperationCanceledException) when (requestCancellationToken.IsCancellationRequested)
5858
{
5959
return;
6060
}
6161

62-
if (!cancellationToken.IsCancellationRequested)
62+
if (!requestCancellationToken.IsCancellationRequested)
6363
{
6464
context.Emit(sendStreamEndEvent, new StreamEndPayload(invokeId));
6565
}
@@ -109,9 +109,7 @@ async Task HandleInvokeAsync(string invokeId, TRequest request)
109109
context.Emit(receiveStreamEndEvent, new StreamEndPayload(invokeId));
110110
}
111111
}
112-
catch (OperationCanceledException) when (cancellationSource.IsCancellationRequested)
113-
{
114-
}
112+
catch (OperationCanceledException) when (cancellationSource.IsCancellationRequested) { }
115113
catch (Exception error)
116114
{
117115
if (!cancellationSource.IsCancellationRequested)
@@ -220,9 +218,7 @@ async Task HandleInvokeAsync(RequestStreamInvocationState<TRequest> state)
220218
context.Emit(receiveStreamEndEvent, new StreamEndPayload(state.InvokeId));
221219
}
222220
}
223-
catch (OperationCanceledException) when (state.CancellationSource.IsCancellationRequested)
224-
{
225-
}
221+
catch (OperationCanceledException) when (state.CancellationSource.IsCancellationRequested) { }
226222
catch (Exception error)
227223
{
228224
if (!state.CancellationSource.IsCancellationRequested)
@@ -319,14 +315,18 @@ private static IAsyncEnumerable<TResponse> CreateStreamInvoke<TResponse, TReques
319315
IEventContext context,
320316
InvokeEventDefinition<TResponse, TRequest> eventDefinition,
321317
CancellationToken cancellationToken,
322-
Func<string, Task> sendRequest)
318+
Func<string, CancellationToken, Task> sendRequest)
323319
{
324320
var invokeId = IdGenerator.New();
325321
var sendAbortEvent = new EventDefinition<AbortPayload>(eventDefinition.SendAbortId);
326322
var receiveEvent = new EventDefinition<ReceivePayload<TResponse>>(eventDefinition.ReceiveEventId);
327323
var receiveErrorEvent = new EventDefinition<ReceiveErrorPayload>(eventDefinition.ReceiveErrorId);
328324
var receiveStreamEndEvent = new EventDefinition<StreamEndPayload>(eventDefinition.ReceiveStreamEndId);
329325
var responses = new AsyncSignalQueue<TResponse>();
326+
var requestCancellationSource = cancellationToken.CanBeCanceled
327+
? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)
328+
: new CancellationTokenSource();
329+
var requestCancellationToken = requestCancellationSource.Token;
330330
var subscriptions = new List<IDisposable>();
331331
var finished = 0;
332332

@@ -338,37 +338,20 @@ void Cleanup()
338338
}
339339
}
340340

341-
void Complete()
342-
{
343-
if (Interlocked.Exchange(ref finished, 1) != 0)
344-
{
345-
return;
346-
}
347-
348-
responses.Complete();
349-
Cleanup();
350-
}
351-
352-
void Fault(Exception error)
341+
void Finish(Exception? error, bool emitAbort)
353342
{
354343
if (Interlocked.Exchange(ref finished, 1) != 0)
355344
{
356345
return;
357346
}
358347

359-
responses.Fault(error);
360-
Cleanup();
361-
}
348+
requestCancellationSource.Cancel();
362349

363-
void AbortWithCompletion(Exception? error)
364-
{
365-
if (Interlocked.Exchange(ref finished, 1) != 0)
350+
if (emitAbort)
366351
{
367-
return;
352+
context.Emit(sendAbortEvent, new AbortPayload(invokeId));
368353
}
369354

370-
context.Emit(sendAbortEvent, new AbortPayload(invokeId));
371-
372355
if (error is null)
373356
{
374357
responses.Complete();
@@ -379,6 +362,22 @@ void AbortWithCompletion(Exception? error)
379362
}
380363

381364
Cleanup();
365+
requestCancellationSource.Dispose();
366+
}
367+
368+
void Complete()
369+
{
370+
Finish(error: null, emitAbort: false);
371+
}
372+
373+
void Fault(Exception error)
374+
{
375+
Finish(error, emitAbort: false);
376+
}
377+
378+
void AbortWithCompletion(Exception? error)
379+
{
380+
Finish(error, emitAbort: true);
382381
}
383382

384383
subscriptions.Add(context.On(receiveEvent, envelope =>
@@ -428,11 +427,9 @@ void AbortWithCompletion(Exception? error)
428427
{
429428
try
430429
{
431-
await sendRequest(invokeId).ConfigureAwait(false);
432-
}
433-
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
434-
{
430+
await sendRequest(invokeId, requestCancellationToken).ConfigureAwait(false);
435431
}
432+
catch (OperationCanceledException) when (requestCancellationToken.IsCancellationRequested) { }
436433
catch (Exception error)
437434
{
438435
Fault(error);

tests/Eventa.Tests/EventContextTests.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,5 +156,55 @@ public void MatchExpressionSubscriptions_ReceiveOnlyMatchingPayloads()
156156
Assert.Equal(["match-first", "match-second"], matchedValues);
157157
}
158158

159+
[Fact]
160+
public void Emit_CallsAdapterOnSent_AfterLocalListeners()
161+
{
162+
var calls = new List<string>();
163+
using var adapter = new RecordingAdapter(calls);
164+
using var context = new EventContext(adapter);
165+
var definition = new EventDefinition<TestPayload>("test-event");
166+
var expression = new MatchExpression<TestPayload>("match-test-event", _ => true);
167+
168+
using var _ = context.On(definition, _ => calls.Add("listener"));
169+
using var __ = context.On(expression, _ => calls.Add("match"));
170+
171+
context.Emit(definition, new TestPayload("test"));
172+
173+
Assert.Equal(
174+
["listener", "received:test-event", "match", "received:match-test-event", "sent:test-event"],
175+
calls);
176+
}
177+
178+
[Fact]
179+
public void Emit_WhenListenerThrows_DoesNotCallAdapterOnSent()
180+
{
181+
var calls = new List<string>();
182+
using var adapter = new RecordingAdapter(calls);
183+
using var context = new EventContext(adapter);
184+
var definition = new EventDefinition<TestPayload>("test-event");
185+
186+
using var _ = context.On(definition, _ => throw new InvalidOperationException("boom"));
187+
188+
var error = Assert.Throws<InvalidOperationException>(() => context.Emit(definition, new TestPayload("test")));
189+
190+
Assert.Equal("boom", error.Message);
191+
Assert.Empty(calls);
192+
}
193+
159194
private sealed record TestPayload(string Value);
195+
196+
private sealed class RecordingAdapter(List<string> calls) : IEventaAdapter
197+
{
198+
public void OnSent(string eventId, object? _, object? __ = null)
199+
{
200+
calls.Add($"sent:{eventId}");
201+
}
202+
203+
public void OnReceived(string eventId, object? _)
204+
{
205+
calls.Add($"received:{eventId}");
206+
}
207+
208+
public void Dispose() { }
209+
}
160210
}

0 commit comments

Comments
 (0)