Skip to content

Commit

Permalink
Optimize Generic Handler Registration in MediatR
Browse files Browse the repository at this point in the history
Summary
This PR introduces an optimized mechanism for registering generic handlers in MediatR. The current implementation scans all assemblies passed to MediatR during startup to find every possible concrete implementation and service type that satisfies all generic handler constraints. This PR modifies the behavior to scan and register generic handler services on-demand, triggered only when a specific service is requested.

This feature remains opt-in and can be enabled by setting the RegisterGenericHandlers configuration flag to true.

Changes Made
Optimized Generic Handler Registration:

The registration process now scans assemblies only when a specific service is requested, rather than eagerly scanning all possible types during startup.
Once a service is resolved, the registration is cached for future requests.
Dynamic Service Provider Integration:

Introduced dynamic resolution and caching for generic handlers, minimizing the startup overhead.
Backward Compatibility:

The feature remains optional and is controlled via the RegisterGenericHandlers flag in the MediatR configuration.
Code Refactor:

Extracted and modularized key logic for resolving generic handler types.
Improved readability and maintainability by reducing redundant logic and clarifying workflows.
  • Loading branch information
zachpainter77 committed Jan 16, 2025
1 parent db235f8 commit 761256f
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 383 deletions.
1 change: 1 addition & 0 deletions src/MediatR/MediatR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
</PackageReference>
<PackageReference Include="MediatR.Contracts" Version="[2.0.1, 3.0.0)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" Condition="'$(TargetFramework)' == 'netstandard2.0'" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="8.0.0" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.3.0" PrivateAssets="All" />
Expand Down
22 changes: 1 addition & 21 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,7 @@ public class MediatRServiceConfiguration
/// Automatically register processors during assembly scanning
/// </summary>
public bool AutoRegisterRequestProcessors { get; set; }

/// <summary>
/// Configure the maximum number of type parameters that a generic request handler can have. To Disable this constraint, set the value to 0.
/// </summary>
public int MaxGenericTypeParameters { get; set; } = 10;

/// <summary>
/// Configure the maximum number of types that can close a generic request type parameter constraint. To Disable this constraint, set the value to 0.
/// </summary>
public int MaxTypesClosing { get; set; } = 100;

/// <summary>
/// Configure the Maximum Amount of Generic RequestHandler Types MediatR will try to register. To Disable this constraint, set the value to 0.
/// </summary>
public int MaxGenericTypeRegistrations { get; set; } = 125000;

/// <summary>
/// Configure the Timeout in Milliseconds that the GenericHandler Registration Process will exit with error. To Disable this constraint, set the value to 0.
/// </summary>
public int RegistrationTimeout { get; set; } = 15000;


/// <summary>
/// Flag that controlls whether MediatR will attempt to register handlers that containg generic type parameters.
/// </summary>
Expand Down
11 changes: 7 additions & 4 deletions src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}

ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration);
ServiceRegistrar.AddMediatRClasses(services, configuration);

ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration);

ServiceRegistrar.AddRequiredServices(services, configuration);
ServiceRegistrar.AddRequiredServices(services, configuration);

if (configuration.RegisterGenericHandlers)
{
ServiceRegistrar.AddDynamicServiceProvider(services, configuration);
}

return services;
}
Expand Down
210 changes: 210 additions & 0 deletions src/MediatR/Registration/DynamicServiceProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;

namespace MediatR.Registration
{
public class DynamicServiceProvider : IServiceProvider, IDisposable
{
private readonly IServiceProvider _rootProvider;
private ServiceProvider _currentProvider;
private readonly IServiceCollection _services;
private readonly Type[] RequestHandlerTypes = new Type[] { typeof(IRequestHandler<>), typeof(IRequestHandler<,>) };

public DynamicServiceProvider(IServiceProvider rootProvider)
{
_rootProvider = rootProvider;
_services = new ServiceCollection();
_currentProvider = _services.BuildServiceProvider();
}

public IEnumerable<ServiceDescriptor> GetAllServiceDescriptors()
{
return _services;
}

public void AddService(Type serviceType, Type implementationType, ServiceLifetime lifetime = ServiceLifetime.Transient)
{
var constructor = implementationType.GetConstructors().OrderByDescending(c => c.GetParameters().Length).FirstOrDefault();
if (constructor != null)
{
var parameters = constructor.GetParameters();

foreach (var parameter in parameters)
{
// Check if the dependency is already registered
if (_currentProvider.GetService(parameter.ParameterType) != null)
continue;

// Attempt to resolve from the root provider
var dependency = _rootProvider.GetService(parameter.ParameterType);
if (dependency != null)
{
// Dynamically register the dependency in the dynamic registry
_services.Add(new ServiceDescriptor(parameter.ParameterType, _ => dependency, lifetime));
RebuildProvider(); // Rebuild the internal provider to include the new service
}
else
{
throw new InvalidOperationException(
$"Unable to resolve dependency {parameter.ParameterType.FullName} for {serviceType.FullName}");
}
}
}
_services.Add(new ServiceDescriptor(serviceType, implementationType, lifetime));
RebuildProvider();
}

public void AddService(Type serviceType, Func<IServiceProvider, object> implementationFactory, ServiceLifetime lifetime = ServiceLifetime.Transient)
{
if (serviceType == null) throw new ArgumentNullException(nameof(serviceType));
if (implementationFactory == null) throw new ArgumentNullException(nameof(implementationFactory));

// Add the service descriptor with the factory
_services.Add(new ServiceDescriptor(serviceType, implementationFactory, lifetime));
RebuildProvider();
}

//public IServiceProvider RootProvider { get { return _rootProvider; } }
public object? GetService(Type serviceType)
{
// Handle requests for IEnumerable<T>
if (serviceType.IsGenericType && serviceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
// typeof(T) for IEnumerable<T>
var elementType = serviceType.GenericTypeArguments[0];
return CastToEnumerableOfType(GetServices(elementType), elementType);
}

// Try resolving from the current provider
var service = _currentProvider.GetService(serviceType);
if (service != null) return service;

//fall back to root provider
service = _rootProvider.GetService(serviceType);
if(service != null) return service;

//if not found in current or root then try to find the implementation and register it.
if (serviceType.IsGenericType)
{
var genericArguments = serviceType.GetGenericArguments();
var hasResponseType = genericArguments.Length > 1;
var openInterface = hasResponseType ? typeof(IRequestHandler<,>) : typeof(IRequestHandler<>);

if(openInterface != null)
{
var requestType = genericArguments[0];
Type? responseType = hasResponseType ? genericArguments[1] : null;

var implementationType = FindOpenGenericHandlerType(requestType, responseType);
if(implementationType == null)
throw new InvalidOperationException($"No implementation found for {openInterface.FullName}");

AddService(serviceType, implementationType);
}
}

//find the newly registered service
service = _currentProvider.GetService(serviceType);
if (service != null) return service;

// Fallback to the root provider as a last resort
return _rootProvider.GetService(serviceType);
}

public IEnumerable<object> GetServices(Type serviceType)
{
// Collect services from the dynamic provider
var dynamicServices = _services
.Where(d => d.ServiceType == serviceType)
.Select(d => _currentProvider.GetService(d.ServiceType))
.Where(s => s != null);

// Collect services from the root provider
var rootServices = _rootProvider
.GetServices(serviceType)
.Cast<object>();

// Combine results and remove duplicates
return dynamicServices.Concat(rootServices).Distinct()!;
}

private object CastToEnumerableOfType(IEnumerable<object> services, Type elementType)
{
var castMethod = typeof(Enumerable)
.GetMethod(nameof(Enumerable.Cast))
?.MakeGenericMethod(elementType);

var toListMethod = typeof(Enumerable)
.GetMethod(nameof(Enumerable.ToList))
?.MakeGenericMethod(elementType);

if (castMethod == null || toListMethod == null)
throw new InvalidOperationException("Unable to cast services to the specified enumerable type.");

var castedServices = castMethod.Invoke(null, new object[] { services });
return toListMethod.Invoke(null, new[] { castedServices })!;
}

public Type? FindOpenGenericHandlerType(Type requestType, Type? responseType = null)
{
if (!requestType.IsGenericType)
return null;

// Define the target generic handler type
var openHandlerType = responseType == null ? typeof(IRequestHandler<>) : typeof(IRequestHandler<,>);
var genericArguments = responseType == null ? new Type[] { requestType } : new Type[] { requestType, responseType };
var closedHandlerType = openHandlerType.MakeGenericType(genericArguments);

// Get the current assembly
var currentAssembly = Assembly.GetExecutingAssembly();

// Get assemblies that reference the current assembly
var consumingAssemblies = AppDomain.CurrentDomain.GetAssemblies()
.Where(assembly => assembly.GetReferencedAssemblies()
.Any(reference => reference.FullName == currentAssembly.FullName));

// Search for matching types
var types = consumingAssemblies.SelectMany(x => x.GetTypes())
.Where(t => t.IsClass && !t.IsAbstract && t.IsGenericTypeDefinition)
.ToList();

foreach (var type in types)
{
var interfaces = type.GetInterfaces();
if (interfaces.Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == openHandlerType))
{
// Check generic constraints
var concreteHandlerGenericArgs = type.GetGenericArguments();
var concreteHandlerConstraints = concreteHandlerGenericArgs.Select(x => x.GetGenericParameterConstraints());
var concreteRequestTypeGenericArgs = requestType.GetGenericArguments();
//var secondArgConstrants = genericArguments[1];

// Ensure the constraints are compatible
if (concreteHandlerConstraints
.Select((list, i) => new { List = list, Index = i })
.All(x => x.List.All(c => c.IsAssignableFrom(concreteRequestTypeGenericArgs[x.Index]))))
{
return type.MakeGenericType(concreteRequestTypeGenericArgs);
}
}
}

return null; // No matching type found
}

private void RebuildProvider()
{
_currentProvider.Dispose();
_currentProvider = _services.BuildServiceProvider();
}

public void Dispose()
{
_currentProvider.Dispose();
}
}
}
Loading

0 comments on commit 761256f

Please sign in to comment.