@@ -7,11 +7,17 @@ namespace Marille;
77///
88/// </summary>
99public class Hub : IHub {
10+ readonly SemaphoreSlim semaphoreSlim = new ( 1 ) ;
1011 readonly Dictionary < string , Topic > topics = new ( ) ;
1112 readonly Dictionary < ( string Topic , Type Type ) , ( CancellationTokenSource CancellationToken , Task ? ConsumeTask ) > cancellationTokenSources = new ( ) ;
1213 readonly Dictionary < ( string Topic , Type type ) , List < object > > workers = new ( ) ;
13-
1414 public Channel < WorkerError > WorkersExceptions { get ; } = Channel . CreateUnbounded < WorkerError > ( ) ;
15+
16+ public Hub ( ) : this ( new SemaphoreSlim ( 1 ) ) { }
17+ internal Hub ( SemaphoreSlim semaphore )
18+ {
19+ semaphoreSlim = semaphore ;
20+ }
1521
1622 void DeliverAtLeastOnce < T > ( Channel < Message < T > > channel , IWorker < T > [ ] workersArray , Message < T > item , TimeSpan ? timeout )
1723 where T : struct
@@ -125,24 +131,27 @@ void StopConsuming<T> (string topicName) where T : struct
125131 if ( ! cancellationTokenSources . TryGetValue ( ( topicName , type ) , out var consumingInfo ) )
126132 return ;
127133
128- if ( ! TryGetChannel < T > ( topicName , out var ch ) )
134+ if ( ! TryGetChannel < T > ( topicName , out var topic , out _ ) )
129135 return ;
130136
131137 // complete the channels, this wont throw an cancellation exception, it will stop the channels from writing
132138 // and the consuming task will finish when it is done with the current message, therefore we can
133139 // use that to know when we are done
134- ch . Channel . Writer . Complete ( ) ;
140+ topic . CloseChannel < T > ( ) ;
135141 if ( consumingInfo . ConsumeTask is not null )
136142 await consumingInfo . ConsumeTask ;
137143
138144 // clean behind us
145+ topic . RemoveChannel < T > ( ) ;
146+ workers . Remove ( ( topicName , type ) ) ;
139147 cancellationTokenSources . Remove ( ( topicName , type ) ) ;
140148 }
141149
142- bool TryGetChannel < T > ( string topicName , [ NotNullWhen ( true ) ] out TopicInfo < T > ? ch ) where T : struct
150+ bool TryGetChannel < T > ( string topicName , [ NotNullWhen ( true ) ] out Topic ? topic , [ NotNullWhen ( true ) ] out TopicInfo < T > ? ch ) where T : struct
143151 {
152+ topic = null ;
144153 ch = null ;
145- if ( ! topics . TryGetValue ( topicName , out Topic ? topic ) ) {
154+ if ( ! topics . TryGetValue ( topicName , out topic ) ) {
146155 return false ;
147156 }
148157
@@ -173,21 +182,27 @@ public async Task<bool> CreateAsync<T> (string topicName, TopicConfiguration con
173182
174183 // the topic might already have the channel, in that case, do nothing
175184 Type type = typeof ( T ) ;
176- if ( ! topics . TryGetValue ( topicName , out Topic ? topic ) ) {
177- topic = new ( topicName ) ;
178- topics [ topicName ] = topic ;
179- }
185+ await semaphoreSlim . WaitAsync ( ) ;
186+ try {
187+ if ( ! topics . TryGetValue ( topicName , out Topic ? topic ) ) {
188+ topic = new ( topicName ) ;
189+ topics [ topicName ] = topic ;
190+ }
180191
181- if ( ! workers . ContainsKey ( ( topicName , type ) ) ) {
182- workers [ ( topicName , type ) ] = new ( initialWorkers ) ;
183- }
192+ if ( ! workers . ContainsKey ( ( topicName , type ) ) ) {
193+ workers [ ( topicName , type ) ] = new ( initialWorkers ) ;
194+ }
184195
185- if ( topic . TryGetChannel < T > ( out _ ) ) {
186- return false ;
196+ if ( topic . TryGetChannel < T > ( out _ ) ) {
197+ return false ;
198+ }
199+
200+ var ch = topic . CreateChannel < T > ( configuration ) ;
201+ await StartConsuming ( topicName , configuration , ch ) ;
202+ return true ;
203+ } finally {
204+ semaphoreSlim . Release ( ) ;
187205 }
188- var ch = topic . CreateChannel < T > ( configuration ) ;
189- await StartConsuming ( topicName , configuration , ch ) ;
190- return true ;
191206 }
192207
193208 /// <summary>
@@ -258,23 +273,28 @@ public Task<bool> CreateAsync<T> (string topicName, TopicConfiguration configura
258273 /// <remarks>Workers can be added to channels that are already being processed. The Hub will pause the consumtion
259274 /// of the messages while it adds the worker and will resume the processing after. Producer can be sending
260275 /// messages while this operation takes place because messages will be buffered by the channel.</remarks>
261- public Task < bool > RegisterAsync < T > ( string topicName , params IWorker < T > [ ] newWorkers ) where T : struct
276+ public async Task < bool > RegisterAsync < T > ( string topicName , params IWorker < T > [ ] newWorkers ) where T : struct
262277 {
263278 var type = typeof ( T ) ;
264- // we only allow the client to register to an existing topic
265- // in this API we will not create it, there are other APIs for that
266- if ( ! TryGetChannel < T > ( topicName , out var ch ) )
267- return Task . FromResult ( false ) ;
268-
269- // do not allow to add more than one worker ig we are in AtMostOnce mode.
270- if ( ch . Configuration . Mode == ChannelDeliveryMode . AtMostOnceAsync && workers [ ( topicName , type ) ] . Count >= 1 )
271- return Task . FromResult ( false ) ;
272-
273- // we will have to stop consuming while we add the new worker
274- // but we do not need to close the channel, the API will buffer
275- StopConsuming < T > ( topicName ) ;
276- workers [ ( topicName , type ) ] . AddRange ( newWorkers ) ;
277- return StartConsuming ( topicName , ch . Configuration , ch . Channel ) ;
279+ await semaphoreSlim . WaitAsync ( ) ;
280+ try {
281+ // we only allow the client to register to an existing topic
282+ // in this API we will not create it, there are other APIs for that
283+ if ( ! TryGetChannel < T > ( topicName , out _ , out var topicInfo ) )
284+ return false ;
285+
286+ // do not allow to add more than one worker ig we are in AtMostOnce mode.
287+ if ( topicInfo . Configuration . Mode == ChannelDeliveryMode . AtMostOnceAsync && workers [ ( topicName , type ) ] . Count >= 1 )
288+ return false ;
289+
290+ // we will have to stop consuming while we add the new worker
291+ // but we do not need to close the channel, the API will buffer
292+ StopConsuming < T > ( topicName ) ;
293+ workers [ ( topicName , type ) ] . AddRange ( newWorkers ) ;
294+ return await StartConsuming ( topicName , topicInfo . Configuration , topicInfo . Channel ) ;
295+ } finally {
296+ semaphoreSlim . Release ( ) ;
297+ }
278298 }
279299
280300 /// <summary>
@@ -302,11 +322,11 @@ public Task<bool> RegisterAsync<T> (string topicName, Func<T, CancellationToken,
302322 /// (topicName, messageType) combination.</exception>
303323 public ValueTask Publish < T > ( string topicName , T publishedEvent ) where T : struct
304324 {
305- if ( ! TryGetChannel < T > ( topicName , out var ch ) )
325+ if ( ! TryGetChannel < T > ( topicName , out _ , out var topicInfo ) )
306326 throw new InvalidOperationException (
307327 $ "Channel with topic { topicName } for event type { typeof ( T ) } not found") ;
308328 var message = new Message < T > ( MessageType . Data , publishedEvent ) ;
309- return ch . Channel . Writer . WriteAsync ( message ) ;
329+ return topicInfo . Channel . Writer . WriteAsync ( message ) ;
310330 }
311331
312332 /// <summary>
@@ -325,17 +345,21 @@ public async Task CloseAllAsync ()
325345 // `Task.WhenAll (consumingTasks!);`
326346 //
327347 // suppressing the warning is ugly when we do know how to help the compiler ;)
328-
329- var consumingTasks = from consumeInfo in cancellationTokenSources . Values
330- where consumeInfo . ConsumeTask is not null
331- select consumeInfo . ConsumeTask ;
332- // remove the need of a second loop by getting the cancellation tokens cancelled
333- var cancellationTasks = cancellationTokenSources . Values
334- . Select ( x => x . CancellationToken . CancelAsync ( ) ) ;
335-
336- // we could do a nested Task.WhenAll but we want to ensure that the cancellation tasks are done before
337- await Task . WhenAll ( cancellationTasks ) ;
338- await Task . WhenAll ( consumingTasks ) ;
348+ await semaphoreSlim . WaitAsync ( ) ;
349+ try {
350+ var consumingTasks = from consumeInfo in cancellationTokenSources . Values
351+ where consumeInfo . ConsumeTask is not null
352+ select consumeInfo . ConsumeTask ;
353+ // remove the need of a second loop by getting the cancellation tokens cancelled
354+ var cancellationTasks = cancellationTokenSources . Values
355+ . Select ( x => x . CancellationToken . CancelAsync ( ) ) ;
356+
357+ // we could do a nested Task.WhenAll but we want to ensure that the cancellation tasks are done before
358+ await Task . WhenAll ( cancellationTasks ) ;
359+ await Task . WhenAll ( consumingTasks ) ;
360+ } finally {
361+ semaphoreSlim . Release ( ) ;
362+ }
339363 }
340364
341365 /// <summary>
@@ -347,10 +371,15 @@ where consumeInfo.ConsumeTask is not null
347371 /// <returns>A task that will be completed once the channel has been flushed.</returns>
348372 public async Task < bool > CloseAsync < T > ( string topicName ) where T : struct
349373 {
350- // ensure that the channels does exist, if not, return false
351- if ( ! TryGetChannel < T > ( topicName , out _ ) )
352- return false ;
353- await StopConsumingAsync < T > ( topicName ) ;
354- return true ;
374+ await semaphoreSlim . WaitAsync ( ) ;
375+ try {
376+ // ensure that the channels does exist, if not, return false
377+ if ( ! TryGetChannel < T > ( topicName , out _ , out _ ) )
378+ return false ;
379+ await StopConsumingAsync < T > ( topicName ) ;
380+ return true ;
381+ } finally {
382+ semaphoreSlim . Release ( ) ;
383+ }
355384 }
356385}
0 commit comments