Skip to content

Commit f407703

Browse files
authored
Update RabbitMQ Pipeline to asynchronous programming model (#1003)
## Motivation and Context (Why the change? What's the scenario?) This PR updates the RabbitMQ Pipeline implementation with new asynchronous programming model of RabbitMQ.Client v7.0.0 library. See also rabbitmq/rabbitmq-dotnet-client#1720 for further details about the required changes. Closes #995
1 parent 0961ec7 commit f407703

File tree

3 files changed

+89
-64
lines changed

3 files changed

+89
-64
lines changed

Directory.Packages.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
<PackageVersion Include="OllamaSharp" Version="5.0.6" />
3535
<PackageVersion Include="PdfPig" Version="0.1.9" />
3636
<PackageVersion Include="Polly.Core" Version="8.5.2" />
37-
<PackageVersion Include="RabbitMQ.Client" Version="6.8.1" />
37+
<PackageVersion Include="RabbitMQ.Client" Version="7.0.0" />
3838
<PackageVersion Include="ReadLine" Version="2.0.1" />
3939
<PackageVersion Include="Swashbuckle.AspNetCore" Version="7.2.0" />
4040
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />

extensions/RabbitMQ/RabbitMQ.TestApplication/Program.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public static async Task Main()
4444

4545
await pipeline.ConnectToQueueAsync(QueueName, QueueOptions.PubSub);
4646

47-
ListenToDeadLetterQueue(rabbitMQConfig);
47+
await ListenToDeadLetterQueueAsync(rabbitMQConfig);
4848

4949
// Change ConcurrentThreads and PrefetchCount to 1 to see
5050
// how they affect total execution time
@@ -59,7 +59,7 @@ public static async Task Main()
5959
}
6060
}
6161

62-
private static void ListenToDeadLetterQueue(RabbitMQConfig config)
62+
private static async Task ListenToDeadLetterQueueAsync(RabbitMQConfig config)
6363
{
6464
var factory = new ConnectionFactory
6565
{
@@ -68,19 +68,18 @@ private static void ListenToDeadLetterQueue(RabbitMQConfig config)
6868
UserName = config.Username,
6969
Password = config.Password,
7070
VirtualHost = !string.IsNullOrWhiteSpace(config.VirtualHost) ? config.VirtualHost : "/",
71-
DispatchConsumersAsync = true,
7271
Ssl = new SslOption
7372
{
7473
Enabled = config.SslEnabled,
7574
ServerName = config.Host,
7675
}
7776
};
7877

79-
var connection = factory.CreateConnection();
80-
var channel = connection.CreateModel();
78+
var connection = await factory.CreateConnectionAsync();
79+
var channel = await connection.CreateChannelAsync();
8180
var consumer = new AsyncEventingBasicConsumer(channel);
8281

83-
consumer.Received += async (object sender, BasicDeliverEventArgs args) =>
82+
consumer.ReceivedAsync += async (object _, BasicDeliverEventArgs args) =>
8483
{
8584
byte[] body = args.Body.ToArray();
8685
string message = Encoding.UTF8.GetString(body);
@@ -89,7 +88,7 @@ private static void ListenToDeadLetterQueue(RabbitMQConfig config)
8988
await Task.Delay(0);
9089
};
9190

92-
channel.BasicConsume(queue: $"{QueueName}{config.PoisonQueueSuffix}",
91+
await channel.BasicConsumeAsync(queue: $"{QueueName}{config.PoisonQueueSuffix}",
9392
autoAck: true,
9493
consumer: consumer);
9594
}

extensions/RabbitMQ/RabbitMQ/RabbitMQPipeline.cs

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,20 @@
1717
namespace Microsoft.KernelMemory.Orchestration.RabbitMQ;
1818

1919
[Experimental("KMEXP04")]
20-
public sealed class RabbitMQPipeline : IQueue
20+
public sealed class RabbitMQPipeline : IQueue, IAsyncDisposable
2121
{
2222
private readonly ILogger<RabbitMQPipeline> _log;
23-
private readonly IConnection _connection;
24-
private readonly IModel _channel;
25-
private readonly AsyncEventingBasicConsumer _consumer;
23+
24+
private readonly ConnectionFactory _factory;
2625
private readonly RabbitMQConfig _config;
26+
27+
private IConnection? _connection;
28+
private IChannel? _channel;
29+
private AsyncEventingBasicConsumer? _consumer;
30+
31+
// The action that will be executed when a new message is received.
32+
private Func<string, Task<ReturnType>>? _processMessageAction;
33+
2734
private readonly int _messageTTLMsecs;
2835
private readonly int _delayBeforeRetryingMsecs;
2936
private readonly int _maxAttempts;
@@ -40,16 +47,14 @@ public RabbitMQPipeline(RabbitMQConfig config, ILoggerFactory? loggerFactory = n
4047
this._config = config;
4148
this._config.Validate(this._log);
4249

43-
// see https://www.rabbitmq.com/dotnet-api-guide.html#consuming-async
44-
var factory = new ConnectionFactory
50+
this._factory = new ConnectionFactory
4551
{
4652
ClientProvidedName = "KernelMemory",
4753
HostName = config.Host,
4854
Port = config.Port,
4955
UserName = config.Username,
5056
Password = config.Password,
5157
VirtualHost = !string.IsNullOrWhiteSpace(config.VirtualHost) ? config.VirtualHost : "/",
52-
DispatchConsumersAsync = true,
5358
ConsumerDispatchConcurrency = config.ConcurrentThreads,
5459
Ssl = new SslOption
5560
{
@@ -59,22 +64,20 @@ public RabbitMQPipeline(RabbitMQConfig config, ILoggerFactory? loggerFactory = n
5964
};
6065

6166
this._messageTTLMsecs = config.MessageTTLSecs * 1000;
62-
this._connection = factory.CreateConnection();
63-
this._channel = this._connection.CreateModel();
64-
this._channel.BasicQos(prefetchSize: 0, prefetchCount: config.PrefetchCount, global: false);
65-
this._consumer = new AsyncEventingBasicConsumer(this._channel);
6667

6768
this._delayBeforeRetryingMsecs = Math.Max(0, this._config.DelayBeforeRetryingMsecs);
6869
this._maxAttempts = Math.Max(0, this._config.MaxRetriesBeforePoisonQueue) + 1;
6970
}
7071

7172
/// <inheritdoc />
7273
/// About poison queue and dead letters, see https://www.rabbitmq.com/docs/dlx
73-
public Task<IQueue> ConnectToQueueAsync(string queueName, QueueOptions options = default, CancellationToken cancellationToken = default)
74+
public async Task<IQueue> ConnectToQueueAsync(string queueName, QueueOptions options = default, CancellationToken cancellationToken = default)
7475
{
7576
ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(queueName, nameof(queueName), "The queue name is empty");
7677
ArgumentExceptionEx.ThrowIf(queueName.StartsWith("amq.", StringComparison.OrdinalIgnoreCase), nameof(queueName), "The queue name cannot start with 'amq.'");
7778

79+
await this.InitializeAsync().ConfigureAwait(false);
80+
7881
var poisonExchangeName = $"{queueName}.dlx";
7982
var poisonQueueName = $"{queueName}{this._config.PoisonQueueSuffix}";
8083

@@ -94,17 +97,13 @@ public Task<IQueue> ConnectToQueueAsync(string queueName, QueueOptions options =
9497
this._queueName = queueName;
9598
try
9699
{
97-
this._channel.QueueDeclare(
98-
queue: this._queueName,
99-
durable: true,
100-
exclusive: false,
101-
autoDelete: false,
102-
arguments: new Dictionary<string, object>
103-
{
104-
["x-queue-type"] = "quorum",
105-
["x-delivery-limit"] = this._config.MaxRetriesBeforePoisonQueue,
106-
["x-dead-letter-exchange"] = poisonExchangeName
107-
});
100+
await this._channel!.QueueDeclareAsync(queue: this._queueName, durable: true, exclusive: false, autoDelete: false, arguments: new Dictionary<string, object?>
101+
{
102+
["x-queue-type"] = "quorum",
103+
["x-delivery-limit"] = this._config.MaxRetriesBeforePoisonQueue,
104+
["x-dead-letter-exchange"] = poisonExchangeName
105+
}, cancellationToken: cancellationToken).ConfigureAwait(false);
106+
108107
this._log.LogTrace("Queue name: {0}", this._queueName);
109108
}
110109
#pragma warning disable CA2254
@@ -129,60 +128,94 @@ public Task<IQueue> ConnectToQueueAsync(string queueName, QueueOptions options =
129128

130129
// Define poison queue where failed messages are stored
131130
this._poisonQueueName = poisonQueueName;
132-
this._channel.QueueDeclare(
131+
await this._channel.QueueDeclareAsync(
133132
queue: this._poisonQueueName,
134133
durable: true,
135134
exclusive: false,
136135
autoDelete: false,
137-
arguments: null);
136+
arguments: null,
137+
cancellationToken: cancellationToken).ConfigureAwait(false);
138138

139139
// Define exchange to route failed messages to poison queue
140-
this._channel.ExchangeDeclare(poisonExchangeName, "fanout", durable: true, autoDelete: false);
141-
this._channel.QueueBind(this._poisonQueueName, poisonExchangeName, routingKey: string.Empty, arguments: null);
140+
await this._channel.ExchangeDeclareAsync(poisonExchangeName, "fanout", durable: true, autoDelete: false, cancellationToken: cancellationToken).ConfigureAwait(false);
141+
await this._channel.QueueBindAsync(this._poisonQueueName, poisonExchangeName, routingKey: string.Empty, arguments: null, cancellationToken: cancellationToken).ConfigureAwait(false);
142142
this._log.LogTrace("Poison queue name '{0}' bound to exchange '{1}' for queue '{2}'", this._poisonQueueName, poisonExchangeName, this._queueName);
143143

144144
// Activate consumer
145145
if (options.DequeueEnabled)
146146
{
147-
this._channel.BasicConsume(queue: this._queueName, autoAck: false, consumer: this._consumer);
147+
await this._channel.BasicConsumeAsync(queue: this._queueName, autoAck: false, consumer: this._consumer!, cancellationToken: cancellationToken).ConfigureAwait(false);
148148
this._log.LogTrace("Enabling dequeue on queue `{0}`", this._queueName);
149149
}
150150

151-
return Task.FromResult<IQueue>(this);
151+
return this;
152152
}
153153

154154
/// <inheritdoc />
155-
public Task EnqueueAsync(string message, CancellationToken cancellationToken = default)
155+
public async Task EnqueueAsync(string message, CancellationToken cancellationToken = default)
156156
{
157157
if (cancellationToken.IsCancellationRequested)
158158
{
159-
return Task.FromCanceled(cancellationToken);
159+
return;
160160
}
161161

162162
if (string.IsNullOrEmpty(this._queueName))
163163
{
164164
throw new InvalidOperationException("The client must be connected to a queue first");
165165
}
166166

167-
this.PublishMessage(
167+
await this.PublishMessageAsync(
168168
queueName: this._queueName,
169169
body: Encoding.UTF8.GetBytes(message),
170170
messageId: Guid.NewGuid().ToString("N"),
171-
expirationMsecs: this._messageTTLMsecs);
172-
173-
return Task.CompletedTask;
171+
expirationMsecs: this._messageTTLMsecs).ConfigureAwait(false);
174172
}
175173

176174
/// <inheritdoc />
177175
public void OnDequeue(Func<string, Task<ReturnType>> processMessageAction)
178176
{
179-
this._consumer.Received += async (object sender, BasicDeliverEventArgs args) =>
177+
// We just store the action to be executed when a message is received.
178+
// The actual message processing is registered only when the consumer is created.
179+
this._processMessageAction = processMessageAction;
180+
}
181+
182+
public void Dispose()
183+
{
184+
// Note: Start from v7.0, Synchronous Close methods are not available anymore in the library, so we just call Dispose.
185+
((IDisposable)this._channel!).Dispose();
186+
((IDisposable)this._connection!).Dispose();
187+
}
188+
189+
public async ValueTask DisposeAsync()
190+
{
191+
await this._channel!.CloseAsync().ConfigureAwait(false);
192+
await this._connection!.CloseAsync().ConfigureAwait(false);
193+
194+
await this._channel!.DisposeAsync().ConfigureAwait(false);
195+
await this._connection!.DisposeAsync().ConfigureAwait(false);
196+
}
197+
198+
private async Task InitializeAsync()
199+
{
200+
if (this._connection is not null)
201+
{
202+
// The client is already connected.
203+
return;
204+
}
205+
206+
this._connection = await this._factory.CreateConnectionAsync().ConfigureAwait(false);
207+
208+
this._channel = await this._connection.CreateChannelAsync().ConfigureAwait(false);
209+
await this._channel.BasicQosAsync(prefetchSize: 0, prefetchCount: this._config.PrefetchCount, global: false).ConfigureAwait(false);
210+
211+
this._consumer = new AsyncEventingBasicConsumer(this._channel);
212+
this._consumer.ReceivedAsync += async (object _, BasicDeliverEventArgs args) =>
180213
{
181214
// Just for logging, extract the attempt number from the message headers
182215
var attemptNumber = 1;
183216
if (args.BasicProperties?.Headers != null && args.BasicProperties.Headers.TryGetValue("x-delivery-count", out object? value))
184217
{
185-
attemptNumber = int.TryParse(value.ToString(), out var parsedResult) ? ++parsedResult : -1;
218+
attemptNumber = int.TryParse(value!.ToString(), out var parsedResult) ? ++parsedResult : -1;
186219
}
187220

188221
try
@@ -193,12 +226,13 @@ public void OnDequeue(Func<string, Task<ReturnType>> processMessageAction)
193226
byte[] body = args.Body.ToArray();
194227
string message = Encoding.UTF8.GetString(body);
195228

196-
var returnType = await processMessageAction.Invoke(message).ConfigureAwait(false);
229+
// Invokes the action that has been stored in the OnDequeue method.
230+
var returnType = await this._processMessageAction!.Invoke(message).ConfigureAwait(false);
197231
switch (returnType)
198232
{
199233
case ReturnType.Success:
200234
this._log.LogTrace("Message '{0}' successfully processed, deleting message", args.BasicProperties?.MessageId);
201-
this._channel.BasicAck(args.DeliveryTag, multiple: false);
235+
await this._channel.BasicAckAsync(args.DeliveryTag, multiple: false, cancellationToken: args.CancellationToken).ConfigureAwait(false);
202236
break;
203237

204238
case ReturnType.TransientError:
@@ -217,12 +251,12 @@ public void OnDequeue(Func<string, Task<ReturnType>> processMessageAction)
217251
args.BasicProperties?.MessageId, attemptNumber, this._maxAttempts);
218252
}
219253

220-
this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true);
254+
await this._channel.BasicNackAsync(args.DeliveryTag, multiple: false, requeue: true, cancellationToken: args.CancellationToken).ConfigureAwait(false);
221255
break;
222256

223257
case ReturnType.FatalError:
224258
this._log.LogError("Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId);
225-
this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false);
259+
await this._channel.BasicNackAsync(args.DeliveryTag, multiple: false, requeue: false, cancellationToken: args.CancellationToken).ConfigureAwait(false);
226260
break;
227261

228262
default:
@@ -232,7 +266,7 @@ public void OnDequeue(Func<string, Task<ReturnType>> processMessageAction)
232266
catch (KernelMemoryException e) when (e.IsTransient.HasValue && !e.IsTransient.Value)
233267
{
234268
this._log.LogError(e, "Message '{0}' failed to process due to a non-recoverable error, moving to poison queue", args.BasicProperties?.MessageId);
235-
this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: false);
269+
await this._channel.BasicNackAsync(args.DeliveryTag, multiple: false, requeue: false, cancellationToken: args.CancellationToken).ConfigureAwait(false);
236270
}
237271
#pragma warning disable CA1031 // Must catch all to handle queue properly
238272
catch (Exception e)
@@ -258,28 +292,19 @@ public void OnDequeue(Func<string, Task<ReturnType>> processMessageAction)
258292
}
259293

260294
// TODO: verify and document what happens if this fails. RabbitMQ should automatically unlock messages.
261-
this._channel.BasicNack(args.DeliveryTag, multiple: false, requeue: true);
295+
await this._channel.BasicNackAsync(args.DeliveryTag, multiple: false, requeue: true, cancellationToken: args.CancellationToken).ConfigureAwait(false);
262296
}
263297
#pragma warning restore CA1031
264298
};
265299
}
266300

267-
public void Dispose()
268-
{
269-
this._channel.Close();
270-
this._connection.Close();
271-
272-
this._channel.Dispose();
273-
this._connection.Dispose();
274-
}
275-
276-
private void PublishMessage(
301+
private async Task PublishMessageAsync(
277302
string queueName,
278303
ReadOnlyMemory<byte> body,
279304
string messageId,
280305
int? expirationMsecs)
281306
{
282-
var properties = this._channel.CreateBasicProperties();
307+
var properties = new BasicProperties();
283308
properties.Persistent = true;
284309
properties.MessageId = messageId;
285310

@@ -291,11 +316,12 @@ private void PublishMessage(
291316
this._log.LogDebug("Sending message to {0}: {1} (TTL: {2} secs)...",
292317
queueName, properties.MessageId, expirationMsecs.HasValue ? expirationMsecs / 1000 : "infinite");
293318

294-
this._channel.BasicPublish(
319+
await this._channel!.BasicPublishAsync(
295320
routingKey: queueName,
296321
body: body,
297322
exchange: string.Empty,
298-
basicProperties: properties);
323+
basicProperties: properties,
324+
mandatory: true).ConfigureAwait(false);
299325

300326
this._log.LogDebug("Message sent: {0} (TTL: {1} secs)", properties.MessageId, expirationMsecs.HasValue ? expirationMsecs / 1000 : "infinite");
301327
}

0 commit comments

Comments
 (0)