Skip to content

Add keyed registration and resolvation of handlers #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions Rebus.ServiceProvider.Tests/Examples/KeyedServices.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using NUnit.Framework;
using Rebus.Config;
using Rebus.Handlers;
using Rebus.Tests.Contracts;
using Rebus.Transport.InMem;
#pragma warning disable CS1998

namespace Rebus.ServiceProvider.Tests.Examples;

#if NET8_0_OR_GREATER
[TestFixture]
[Description("Demonstrates how a service key can be used to resolve Rebus handlers")]
public class KeyedServices : FixtureBase
{
[Test]
public async Task ItWorks()
{
var services = new ServiceCollection();

var bus1Key = "bus1";
var bus2Key = "bus2";

var log = new List<string>();

services.AddRebusHandler<Handler1>(bus1Key, (serviceProvider, serviceKey) => new((string)serviceKey, log));
services.AddRebusHandler<Handler2>(bus2Key, (serviceProvider, serviceKey) => new((string)serviceKey, log));

services.AddRebus(
configure => configure.Transport(t => t.UseInMemoryTransport(new InMemNetwork(), $"queue-{Guid.NewGuid():N}")),
isDefaultBus: true,
key: bus1Key,
serviceKey: bus1Key);

services.AddRebus(
configure => configure.Transport(t => t.UseInMemoryTransport(new InMemNetwork(), $"queue-{Guid.NewGuid():N}")),
isDefaultBus: false,
key: bus2Key,
serviceKey: bus2Key);

await using var provider = services.BuildServiceProvider();
provider.StartRebus();

var bus1 = provider.GetRequiredService<IBusRegistry>().GetBus(bus1Key);
var bus2 = provider.GetRequiredService<IBusRegistry>().GetBus(bus2Key);

await bus1.SendLocal("Hej!");
await bus2.SendLocal("Hej!");

await Task.Delay(TimeSpan.FromSeconds(3));

Assert.That(log, Is.EquivalentTo(new[] { bus1Key, bus2Key }));
}

class Handler1(string serviceKey, List<string> log) : IHandleMessages<string>
{
public async Task Handle(string _) => log.Add(serviceKey);
}

class Handler2(string serviceKey, List<string> log) : IHandleMessages<string>
{
public async Task Handle(string _) => log.Add(serviceKey);
}
}
#endif
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net48;net8.0</TargetFrameworks>
<LangVersion>11</LangVersion>
<LangVersion>latest</LangVersion>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Rebus.ServiceProvider\Rebus.ServiceProvider.csproj" />
Expand Down
128 changes: 122 additions & 6 deletions Rebus.ServiceProvider/Config/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ public static class ServiceCollectionExtensions
/// </param>
public static IServiceCollection AddRebus(this IServiceCollection services,
Func<RebusConfigurer, RebusConfigurer> configure, bool isDefaultBus = true, Func<IBus, Task> onCreated = null,
string key = null, bool startAutomatically = true)
string key = null, bool startAutomatically = true
#if NET8_0_OR_GREATER
, object serviceKey = null
#endif
)
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (configure == null) throw new ArgumentNullException(nameof(configure));
Expand All @@ -62,6 +66,9 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
onCreated: (bus, _) => onCreated?.Invoke(bus) ?? Task.CompletedTask,
key: key,
startAutomatically: startAutomatically
#if NET8_0_OR_GREATER
, serviceKey: serviceKey
#endif
);
}

Expand Down Expand Up @@ -93,12 +100,26 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
/// Configures whether this bus should be started automatically (i.e. whether message consumption should begin) when the host starts up (or when StartRebus() is called on the service provider).
/// Setting this to false should be combined with providing a <paramref name="key"/>, because the bus can then be started by resolving <see cref="IBusRegistry"/> and calling <see cref="IBusRegistry.StartBus"/> on it.
/// </param>
public static IServiceCollection AddRebus(this IServiceCollection services, Func<RebusConfigurer, RebusConfigurer> configure, Func<IBus, IServiceProvider, Task> onCreated, bool isDefaultBus = true, string key = null, bool startAutomatically = true)
public static IServiceCollection AddRebus(this IServiceCollection services, Func<RebusConfigurer, RebusConfigurer> configure, Func<IBus, IServiceProvider, Task> onCreated, bool isDefaultBus = true, string key = null, bool startAutomatically = true
#if NET8_0_OR_GREATER
, object serviceKey = null
#endif
)
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (configure == null) throw new ArgumentNullException(nameof(configure));

return AddRebus(services, (configurer, _) => configure(configurer), isDefaultBus: isDefaultBus, onCreated: onCreated, key: key, startAutomatically: startAutomatically);
return AddRebus(
services,
(configurer, _) => configure(configurer),
isDefaultBus: isDefaultBus,
onCreated: onCreated,
key: key,
startAutomatically: startAutomatically
#if NET8_0_OR_GREATER
, serviceKey: serviceKey
#endif
);
}

/// <summary>
Expand Down Expand Up @@ -130,7 +151,11 @@ public static IServiceCollection AddRebus(this IServiceCollection services, Func
/// </param>
public static IServiceCollection AddRebus(this IServiceCollection services,
Func<RebusConfigurer, IServiceProvider, RebusConfigurer> configure, bool isDefaultBus = true,
Func<IBus, Task> onCreated = null, string key = null, bool startAutomatically = true)
Func<IBus, Task> onCreated = null, string key = null, bool startAutomatically = true
#if NET8_0_OR_GREATER
, object serviceKey = null
#endif
)
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (configure == null) throw new ArgumentNullException(nameof(configure));
Expand All @@ -142,6 +167,9 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
onCreated: (bus, _) => onCreated?.Invoke(bus) ?? Task.CompletedTask,
key: key,
startAutomatically: startAutomatically
#if NET8_0_OR_GREATER
, serviceKey: serviceKey
#endif
);
}

Expand Down Expand Up @@ -173,8 +201,12 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
/// Setting this to false should be combined with providing a <paramref name="key"/>, because the bus can then be started by resolving <see cref="IBusRegistry"/> and calling <see cref="IBusRegistry.StartBus"/> on it.
/// </param>
public static IServiceCollection AddRebus(this IServiceCollection services,
Func<RebusConfigurer, IServiceProvider, RebusConfigurer> configure, Func<IBus, IServiceProvider, Task> onCreated,
bool isDefaultBus = true, string key = null, bool startAutomatically = true)
Func<RebusConfigurer, IServiceProvider, RebusConfigurer> configure, Func<IBus, IServiceProvider, Task> onCreated,
bool isDefaultBus = true, string key = null, bool startAutomatically = true
#if NET8_0_OR_GREATER
, object serviceKey = null
#endif
)
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (configure == null) throw new ArgumentNullException(nameof(configure));
Expand Down Expand Up @@ -208,6 +240,9 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
serviceProvider: p,
isDefaultBus: true,
lifetime: p.GetService<IHostApplicationLifetime>()
#if NET8_0_OR_GREATER
, serviceKey: serviceKey
#endif
));

services.AddSingleton(p =>
Expand Down Expand Up @@ -236,6 +271,9 @@ public static IServiceCollection AddRebus(this IServiceCollection services,
serviceProvider: p,
isDefaultBus: false,
lifetime: p.GetService<IHostApplicationLifetime>()
#if NET8_0_OR_GREATER
, serviceKey: serviceKey
#endif
);

return new RebusBackgroundService(rebusInitializer);
Expand Down Expand Up @@ -270,6 +308,16 @@ public static IServiceCollection AddRebusHandler<THandler>(this IServiceCollecti
return AddRebusHandler(services, typeof(THandler));
}

#if NET8_0_OR_GREATER
/// <summary>
/// Registers the given <typeparamref name="THandler"/> with a transient lifestyle using given <paramref name="serviceKey"/>
/// </summary>
public static IServiceCollection AddRebusHandler<THandler>(this IServiceCollection services, object serviceKey) where THandler : IHandleMessages
{
return AddRebusHandler(services, typeof(THandler), serviceKey);
}
#endif

/// <summary>
/// Register the given <paramref name="typeToRegister"/> with a transient lifestyle
/// </summary>
Expand All @@ -283,18 +331,51 @@ public static IServiceCollection AddRebusHandler(this IServiceCollection service
return services;
}

#if NET8_0_OR_GREATER
/// <summary>
/// Register the given <paramref name="typeToRegister"/> with a transient lifestyle using given <paramref name="serviceKey"/>
/// </summary>
public static IServiceCollection AddRebusHandler(this IServiceCollection services, Type typeToRegister, object serviceKey)
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (typeToRegister == null) throw new ArgumentNullException(nameof(typeToRegister));
if (serviceKey == null) throw new ArgumentNullException(nameof(serviceKey));

RegisterType(services, typeToRegister, serviceKey);

return services;
}
#endif

/// <summary>
/// Registers the given <typeparamref name="THandler"/> with a transient lifestyle
/// </summary>
public static IServiceCollection AddRebusHandler<THandler>(this IServiceCollection services, Func<IServiceProvider, THandler> factory) where THandler : IHandleMessages
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (factory == null) throw new ArgumentNullException(nameof(factory));

RegisterFactory(services, typeof(THandler), provider => factory(provider));

return services;
}

#if NET8_0_OR_GREATER
/// <summary>
/// Registers the given <typeparamref name="THandler"/> with a transient lifestyle using given <paramref name="serviceKey"/>
/// </summary>
public static IServiceCollection AddRebusHandler<THandler>(this IServiceCollection services, object serviceKey, Func<IServiceProvider, object, THandler> factory) where THandler : IHandleMessages
{
if (services == null) throw new ArgumentNullException(nameof(services));
if (serviceKey == null) throw new ArgumentNullException(nameof(serviceKey));
if (factory == null) throw new ArgumentNullException(nameof(factory));

RegisterFactory(services, typeof(THandler), serviceKey, (provider, serviceKey) => factory(provider, serviceKey));

return services;
}
#endif

/// <summary>
/// Automatically picks up all handler types from the assembly containing <typeparamref name="THandler"/> and registers them in the container
/// </summary>
Expand Down Expand Up @@ -379,7 +460,11 @@ static IEnumerable<Type> GetImplementedHandlerInterfaces(Type type) =>
type.GetInterfaces()
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandleMessages<>));

#if NET8_0_OR_GREATER
static void RegisterAssembly(IServiceCollection services, Assembly assemblyToRegister, string namespaceFilter = null, Func<Type, bool> predicate = null, object serviceKey = null)
#else
static void RegisterAssembly(IServiceCollection services, Assembly assemblyToRegister, string namespaceFilter = null, Func<Type, bool> predicate = null)
#endif
{
var typesToAutoRegister = assemblyToRegister.GetTypes()
.Where(IsClass)
Expand All @@ -403,7 +488,14 @@ static void RegisterAssembly(IServiceCollection services, Assembly assemblyToReg

foreach (var type in typesToAutoRegister)
{
#if NET8_0_OR_GREATER
if (serviceKey != null)
RegisterType(services, type.Type, serviceKey);
else
RegisterType(services, type.Type);
#else
RegisterType(services, type.Type);
#endif
}
}

Expand All @@ -419,6 +511,18 @@ static void RegisterFactory(IServiceCollection services, Type typeToRegister, Fu
}
}

#if NET8_0_OR_GREATER
static void RegisterFactory(IServiceCollection services, Type typeToRegister, object serviceKey, Func<IServiceProvider, object, object> factory)
{
var implementedHandlerInterfaces = GetImplementedHandlerInterfaces(typeToRegister).ToArray();

foreach (var handlerInterface in implementedHandlerInterfaces)
{
services.AddKeyedTransient(handlerInterface, serviceKey, factory);
}
}
#endif

static void RegisterType(IServiceCollection services, Type typeToRegister)
{
var implementedHandlerInterfaces = GetImplementedHandlerInterfaces(typeToRegister).ToArray();
Expand All @@ -428,4 +532,16 @@ static void RegisterType(IServiceCollection services, Type typeToRegister)
services.AddTransient(handlerInterface, typeToRegister);
}
}

#if NET8_0_OR_GREATER
static void RegisterType(IServiceCollection services, Type typeToRegister, object serviceKey)
{
var implementedHandlerInterfaces = GetImplementedHandlerInterfaces(typeToRegister).ToArray();

foreach (var handlerInterface in implementedHandlerInterfaces)
{
services.AddKeyedTransient(handlerInterface, serviceKey, typeToRegister);
}
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ public class DependencyInjectionHandlerActivator : IHandlerActivator
{
readonly ConcurrentDictionary<Type, Type[]> _typesToResolveByMessage = new();
readonly IServiceProvider _provider;
readonly object _serviceKey;

/// <summary>
/// Initializes a new instance of the <see cref="DependencyInjectionHandlerActivator"/> class.
/// </summary>
/// <param name="provider">The service provider used to yield handler instances.</param>
public DependencyInjectionHandlerActivator(IServiceProvider provider) => _provider = provider ?? throw new ArgumentNullException(nameof(provider));
public DependencyInjectionHandlerActivator(IServiceProvider provider, object serviceKey = null)
{
_provider = provider ?? throw new ArgumentNullException(nameof(provider));
_serviceKey = serviceKey;
}

/// <summary>
/// Resolves all handlers for the given <typeparamref name="TMessage"/> message type
Expand Down Expand Up @@ -81,7 +86,11 @@ IReadOnlyList<IHandleMessages<TMessage>> GetMessageHandlersForMessage<TMessage>(
var typesToResolve = _typesToResolveByMessage.GetOrAdd(typeof(TMessage), FigureOutTypesToResolve);

return typesToResolve
#if NET8_0_OR_GREATER
.SelectMany(type => (_serviceKey == null ? serviceProvider.GetServices(type) : serviceProvider.GetKeyedServices(type, _serviceKey)).Cast<IHandleMessages>())
#else
.SelectMany(type => serviceProvider.GetServices(type).Cast<IHandleMessages>())
#endif
.Distinct(new TypeEqualityComparer())
.Cast<IHandleMessages<TMessage>>()
.ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class RebusInitializer
readonly IServiceProvider _serviceProvider;
readonly bool _isDefaultBus;
readonly CancellationToken? _cancellationToken;
readonly object _serviceKey;

public RebusInitializer(
bool startAutomatically,
Expand All @@ -30,7 +31,11 @@ public RebusInitializer(
Func<IBus, IServiceProvider, Task> onCreated,
IServiceProvider serviceProvider,
bool isDefaultBus,
IHostApplicationLifetime lifetime)
IHostApplicationLifetime lifetime
#if NET8_0_OR_GREATER
, object serviceKey = null
#endif
)
{
_startAutomatically = startAutomatically;
_key = key;
Expand All @@ -39,6 +44,9 @@ public RebusInitializer(
_serviceProvider = serviceProvider;
_isDefaultBus = isDefaultBus;
_cancellationToken = lifetime?.ApplicationStopping;
#if NET8_0_OR_GREATER
_serviceKey = serviceKey;
#endif

_busAndEvents = GetLazyInitializer();
}
Expand All @@ -54,7 +62,7 @@ public RebusInitializer(
BusLifetimeEvents busLifetimeEventsHack = null;

var rebusConfigurer = Configure
.With(new DependencyInjectionHandlerActivator(_serviceProvider))
.With(new DependencyInjectionHandlerActivator(_serviceProvider, _serviceKey))
.Options(o => o.Decorate(c =>
{
// snatch events here
Expand Down