Skip to content

Commit ba7e8de

Browse files
[API] Allow to multihtread register and close. (#26)
1 parent 5948943 commit ba7e8de

File tree

5 files changed

+192
-56
lines changed

5 files changed

+192
-56
lines changed

Marille.Tests/CancellationTests.cs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ namespace Marille.Tests;
33
public class CancellationTests {
44

55
Hub _hub;
6+
SemaphoreSlim _semaphoreSlim;
67
TopicConfiguration configuration;
78

89
public CancellationTests ()
910
{
10-
_hub = new ();
11+
_semaphoreSlim = new (1);
12+
_hub = new (_semaphoreSlim);
1113
configuration = new();
1214
}
1315
[Fact]
@@ -100,4 +102,54 @@ public async Task CloseAllWorkersNoEvents ()
100102
Assert.Equal (0, worker1.ConsumedCount);
101103
Assert.Equal (0, worker2.ConsumedCount);
102104
}
105+
106+
[Fact]
107+
public async Task MultithreadedClose ()
108+
{
109+
var threadCount = 100;
110+
var results = new List<Task<bool>> (100);
111+
112+
Random random = new Random ();
113+
114+
// create the topic and then try to close if from several threads ensuring that only one of them
115+
// closes the channel.
116+
configuration.Mode = ChannelDeliveryMode.AtLeastOnce;
117+
var topic = nameof (MultithreadedClose);
118+
await _hub.CreateAsync<WorkQueuesEvent> (topic, configuration);
119+
120+
// block the closing until we have created all the needed threads
121+
await _semaphoreSlim.WaitAsync ();
122+
123+
for (var index = 0; index < threadCount; index++) {
124+
var tcs = new TaskCompletionSource<bool> ();
125+
results.Add (tcs.Task);
126+
// try to register from diff threads and ensure there are no unexpected issues
127+
// this means that we DO NOT have two true values
128+
// DO NOT AWAIT THE TASKS OR ELSE YOU WILL DEADLOCK
129+
Task.Run (async () => {
130+
// random sleep to ensure that the other thread is also trying to create
131+
var sleep = random.Next (1000);
132+
await Task.Delay (TimeSpan.FromMilliseconds (sleep));
133+
var closed = await _hub.CloseAsync <WorkQueuesEvent> (topic);
134+
tcs.TrySetResult (closed);
135+
});
136+
}
137+
138+
_semaphoreSlim.Release ();
139+
var closed = await Task.WhenAll (results);
140+
bool? positive = null;
141+
var finalResult = true;
142+
// ensure that we have a true and a false, that means that an && should be false
143+
for (var index = 0; index < threadCount; index++) {
144+
finalResult &= closed[index];
145+
if (closed[index] && positive is null) {
146+
positive = true;
147+
continue;
148+
}
149+
if (closed[index] && positive is true) {
150+
Assert.Fail ("More than one close happened.");
151+
}
152+
}
153+
Assert.False (finalResult);
154+
}
103155
}

Marille.Tests/RegistrationTests.cs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ namespace Marille.Tests;
33
public class RegistrationTests {
44

55
Hub _hub;
6+
SemaphoreSlim _semaphoreSlim;
67

78
public RegistrationTests ()
89
{
9-
_hub = new ();
10+
_semaphoreSlim = new(1);
11+
_hub = new (_semaphoreSlim);
1012
}
1113

1214
[Fact]
@@ -70,4 +72,50 @@ public async Task MultipleOneToOneRegistrationWithLambda ()
7072
Assert.True (await _hub.RegisterAsync (topic, worker1));
7173
Assert.False(await _hub.RegisterAsync (topic, action));
7274
}
75+
76+
[Fact]
77+
public async Task MutithreadCreate ()
78+
{
79+
var threadCount = 100;
80+
var results = new List<Task<bool>> (100);
81+
82+
Random random = new Random ();
83+
var topic = nameof (MutithreadCreate );
84+
85+
TopicConfiguration configuration = new() { Mode = ChannelDeliveryMode.AtMostOnceAsync };
86+
87+
await _semaphoreSlim.WaitAsync ();
88+
for (var index = 0; index < threadCount; index++) {
89+
var tcs = new TaskCompletionSource<bool> ();
90+
results.Add (tcs.Task);
91+
// try to register from diff threads and ensure there are no unexpected issues
92+
// this means that we DO NOT have two true values
93+
// DO NOT AWAIT THE TASKS OR ELSE YOU WILL DEADLOCK
94+
Task.Run (async () => {
95+
// random sleep to ensure that the other thread is also trying to create
96+
var sleep = random.Next (1000);
97+
await Task.Delay (TimeSpan.FromMilliseconds (sleep));
98+
var created = await _hub.CreateAsync<WorkQueuesEvent> (topic, configuration);
99+
tcs.TrySetResult (created);
100+
});
101+
}
102+
103+
// release the semaphore so that we can move on
104+
_semaphoreSlim.Release ();
105+
var added = await Task.WhenAll (results);
106+
bool? positive = null;
107+
var finalResult = true;
108+
// ensure that we have a true and a false, that means that an && should be false
109+
for (var index = 0; index < threadCount; index++) {
110+
finalResult &= added[index];
111+
if (added[index] && positive is null) {
112+
positive = true;
113+
continue;
114+
}
115+
if (added[index] && positive is true) {
116+
Assert.Fail ("More than one addition happened.");
117+
}
118+
}
119+
Assert.False (finalResult);
120+
}
73121
}

Marille/Hub.cs

Lines changed: 79 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@ namespace Marille;
77
///
88
/// </summary>
99
public class Hub : IHub {
10+
readonly SemaphoreSlim semaphoreSlim = new(1);
1011
readonly Dictionary<string, Topic> topics = new();
1112
readonly Dictionary<(string Topic, Type Type), (CancellationTokenSource CancellationToken, Task? ConsumeTask)> cancellationTokenSources = new();
1213
readonly Dictionary<(string Topic, Type type), List<object>> workers = new();
13-
1414
public Channel<WorkerError> WorkersExceptions { get; } = Channel.CreateUnbounded<WorkerError> ();
15+
16+
public Hub () : this (new SemaphoreSlim (1)) { }
17+
internal Hub (SemaphoreSlim semaphore)
18+
{
19+
semaphoreSlim = semaphore;
20+
}
1521

1622
void DeliverAtLeastOnce<T> (Channel<Message<T>> channel, IWorker<T> [] workersArray, Message<T> item, TimeSpan? timeout)
1723
where T : struct
@@ -125,24 +131,27 @@ void StopConsuming<T> (string topicName) where T : struct
125131
if (!cancellationTokenSources.TryGetValue ((topicName, type), out var consumingInfo))
126132
return;
127133

128-
if (!TryGetChannel<T> (topicName, out var ch))
134+
if (!TryGetChannel<T> (topicName, out var topic, out _))
129135
return;
130136

131137
// complete the channels, this wont throw an cancellation exception, it will stop the channels from writing
132138
// and the consuming task will finish when it is done with the current message, therefore we can
133139
// use that to know when we are done
134-
ch.Channel.Writer.Complete ();
140+
topic.CloseChannel<T> ();
135141
if (consumingInfo.ConsumeTask is not null)
136142
await consumingInfo.ConsumeTask;
137143

138144
// clean behind us
145+
topic.RemoveChannel<T> ();
146+
workers.Remove ((topicName, type));
139147
cancellationTokenSources.Remove ((topicName, type));
140148
}
141149

142-
bool TryGetChannel<T> (string topicName, [NotNullWhen(true)] out TopicInfo<T>? ch) where T : struct
150+
bool TryGetChannel<T> (string topicName, [NotNullWhen(true)] out Topic? topic, [NotNullWhen(true)] out TopicInfo<T>? ch) where T : struct
143151
{
152+
topic = null;
144153
ch = null;
145-
if (!topics.TryGetValue (topicName, out Topic? topic)) {
154+
if (!topics.TryGetValue (topicName, out topic)) {
146155
return false;
147156
}
148157

@@ -173,21 +182,27 @@ public async Task<bool> CreateAsync<T> (string topicName, TopicConfiguration con
173182

174183
// the topic might already have the channel, in that case, do nothing
175184
Type type = typeof (T);
176-
if (!topics.TryGetValue (topicName, out Topic? topic)) {
177-
topic = new(topicName);
178-
topics [topicName] = topic;
179-
}
185+
await semaphoreSlim.WaitAsync ();
186+
try {
187+
if (!topics.TryGetValue (topicName, out Topic? topic)) {
188+
topic = new(topicName);
189+
topics [topicName] = topic;
190+
}
180191

181-
if (!workers.ContainsKey ((topicName, type))) {
182-
workers [(topicName, type)] = new(initialWorkers);
183-
}
192+
if (!workers.ContainsKey ((topicName, type))) {
193+
workers [(topicName, type)] = new(initialWorkers);
194+
}
184195

185-
if (topic.TryGetChannel<T> (out _)) {
186-
return false;
196+
if (topic.TryGetChannel<T> (out _)) {
197+
return false;
198+
}
199+
200+
var ch = topic.CreateChannel<T> (configuration);
201+
await StartConsuming (topicName, configuration, ch);
202+
return true;
203+
} finally {
204+
semaphoreSlim.Release ( );
187205
}
188-
var ch = topic.CreateChannel<T> (configuration);
189-
await StartConsuming (topicName, configuration, ch);
190-
return true;
191206
}
192207

193208
/// <summary>
@@ -258,23 +273,28 @@ public Task<bool> CreateAsync<T> (string topicName, TopicConfiguration configura
258273
/// <remarks>Workers can be added to channels that are already being processed. The Hub will pause the consumtion
259274
/// of the messages while it adds the worker and will resume the processing after. Producer can be sending
260275
/// messages while this operation takes place because messages will be buffered by the channel.</remarks>
261-
public Task<bool> RegisterAsync<T> (string topicName, params IWorker<T>[] newWorkers) where T : struct
276+
public async Task<bool> RegisterAsync<T> (string topicName, params IWorker<T>[] newWorkers) where T : struct
262277
{
263278
var type = typeof (T);
264-
// we only allow the client to register to an existing topic
265-
// in this API we will not create it, there are other APIs for that
266-
if (!TryGetChannel<T> (topicName, out var ch))
267-
return Task.FromResult(false);
268-
269-
// do not allow to add more than one worker ig we are in AtMostOnce mode.
270-
if (ch.Configuration.Mode == ChannelDeliveryMode.AtMostOnceAsync && workers [(topicName, type)].Count >= 1)
271-
return Task.FromResult (false);
272-
273-
// we will have to stop consuming while we add the new worker
274-
// but we do not need to close the channel, the API will buffer
275-
StopConsuming<T> (topicName);
276-
workers [(topicName, type)].AddRange (newWorkers);
277-
return StartConsuming (topicName, ch.Configuration, ch.Channel);
279+
await semaphoreSlim.WaitAsync ();
280+
try {
281+
// we only allow the client to register to an existing topic
282+
// in this API we will not create it, there are other APIs for that
283+
if (!TryGetChannel<T> (topicName, out _, out var topicInfo))
284+
return false;
285+
286+
// do not allow to add more than one worker ig we are in AtMostOnce mode.
287+
if (topicInfo.Configuration.Mode == ChannelDeliveryMode.AtMostOnceAsync && workers [(topicName, type)].Count >= 1)
288+
return false;
289+
290+
// we will have to stop consuming while we add the new worker
291+
// but we do not need to close the channel, the API will buffer
292+
StopConsuming<T> (topicName);
293+
workers [(topicName, type)].AddRange (newWorkers);
294+
return await StartConsuming (topicName, topicInfo.Configuration, topicInfo.Channel);
295+
} finally {
296+
semaphoreSlim.Release ();
297+
}
278298
}
279299

280300
/// <summary>
@@ -302,11 +322,11 @@ public Task<bool> RegisterAsync<T> (string topicName, Func<T, CancellationToken,
302322
/// (topicName, messageType) combination.</exception>
303323
public ValueTask Publish<T> (string topicName, T publishedEvent) where T : struct
304324
{
305-
if (!TryGetChannel<T> (topicName, out var ch))
325+
if (!TryGetChannel<T> (topicName, out _, out var topicInfo))
306326
throw new InvalidOperationException (
307327
$"Channel with topic {topicName} for event type {typeof(T)} not found");
308328
var message = new Message<T> (MessageType.Data, publishedEvent);
309-
return ch.Channel.Writer.WriteAsync (message);
329+
return topicInfo.Channel.Writer.WriteAsync (message);
310330
}
311331

312332
/// <summary>
@@ -325,17 +345,21 @@ public async Task CloseAllAsync ()
325345
// `Task.WhenAll (consumingTasks!);`
326346
//
327347
// suppressing the warning is ugly when we do know how to help the compiler ;)
328-
329-
var consumingTasks = from consumeInfo in cancellationTokenSources.Values
330-
where consumeInfo.ConsumeTask is not null
331-
select consumeInfo.ConsumeTask;
332-
// remove the need of a second loop by getting the cancellation tokens cancelled
333-
var cancellationTasks = cancellationTokenSources.Values
334-
.Select (x => x.CancellationToken.CancelAsync ());
335-
336-
// we could do a nested Task.WhenAll but we want to ensure that the cancellation tasks are done before
337-
await Task.WhenAll (cancellationTasks);
338-
await Task.WhenAll (consumingTasks);
348+
await semaphoreSlim.WaitAsync ();
349+
try {
350+
var consumingTasks = from consumeInfo in cancellationTokenSources.Values
351+
where consumeInfo.ConsumeTask is not null
352+
select consumeInfo.ConsumeTask;
353+
// remove the need of a second loop by getting the cancellation tokens cancelled
354+
var cancellationTasks = cancellationTokenSources.Values
355+
.Select (x => x.CancellationToken.CancelAsync ());
356+
357+
// we could do a nested Task.WhenAll but we want to ensure that the cancellation tasks are done before
358+
await Task.WhenAll (cancellationTasks);
359+
await Task.WhenAll (consumingTasks);
360+
} finally {
361+
semaphoreSlim.Release ();
362+
}
339363
}
340364

341365
/// <summary>
@@ -347,10 +371,15 @@ where consumeInfo.ConsumeTask is not null
347371
/// <returns>A task that will be completed once the channel has been flushed.</returns>
348372
public async Task<bool> CloseAsync<T> (string topicName) where T : struct
349373
{
350-
// ensure that the channels does exist, if not, return false
351-
if (!TryGetChannel<T> (topicName, out _))
352-
return false;
353-
await StopConsumingAsync<T> (topicName);
354-
return true;
374+
await semaphoreSlim.WaitAsync ();
375+
try {
376+
// ensure that the channels does exist, if not, return false
377+
if (!TryGetChannel<T> (topicName, out _, out _))
378+
return false;
379+
await StopConsumingAsync<T> (topicName);
380+
return true;
381+
} finally {
382+
semaphoreSlim.Release ();
383+
}
355384
}
356385
}

Marille/Marille.csproj

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99
<NoWarn>$(NoWarn);CS1591</NoWarn>
1010
</PropertyGroup>
1111

12+
<ItemGroup>
13+
<AssemblyAttribute Include="System.Runtime.CompilerServices.InternalsVisibleToAttribute">
14+
<_Parameter1>Marille.Tests</_Parameter1>
15+
</AssemblyAttribute>
16+
</ItemGroup>
17+
1218
<PropertyGroup>
1319
<Title>Marille</Title>
1420
<PackageId>Marille</PackageId>
15-
<Version>0.4.1</Version>
21+
<Version>0.4.2</Version>
1622
<Authors>Manuel de la Peña Saenz</Authors>
1723
<Owners>Manuel de la Peña Saenz</Owners>
1824
<Copyright>Manuel de la Peña Saenz</Copyright>

Marille/Topic.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ public void CloseChannel<T> () where T : struct
4040
}
4141

4242
public bool ContainsChannel<T> ()
43-
{
44-
return channels.ContainsKey (typeof (T));
45-
}
43+
=> channels.ContainsKey (typeof (T));
44+
45+
public bool RemoveChannel<T> ()
46+
=> channels.Remove (typeof (T));
4647
}

0 commit comments

Comments
 (0)