Skip to content

Commit bc00bea

Browse files
authored
Merge pull request #767 from FDUdannychen/invoke-once
Each matching exception handler/action is invoked at most once
2 parents 03f73ae + adc374c commit bc00bea

File tree

5 files changed

+156
-47
lines changed

5 files changed

+156
-47
lines changed

src/MediatR.Contracts/MediatR.Contracts.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<Authors>Jimmy Bogard</Authors>
66
<Description>Contracts package for requests, responses, and notifications</Description>
77
<Copyright>Copyright Jimmy Bogard</Copyright>
8-
<TargetFrameworks>netstandard2.0;net461;</TargetFrameworks>
8+
<TargetFramework>netstandard2.0</TargetFramework>
99
<Nullable>enable</Nullable>
1010
<Features>strict</Features>
1111
<PackageTags>mediator;request;response;queries;commands;notifications</PackageTags>

src/MediatR/Pipeline/RequestExceptionActionProcessorBehavior.cs

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
namespace MediatR.Pipeline;
22

3-
using MediatR.Internal;
3+
using Internal;
44
using System;
55
using System.Collections.Generic;
66
using System.Linq;
@@ -31,38 +31,62 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
3131
}
3232
catch (Exception exception)
3333
{
34-
for (Type exceptionType = exception.GetType(); exceptionType != typeof(object); exceptionType = exceptionType.BaseType)
35-
{
36-
var actionsForException = GetActionsForException(exceptionType, request, out MethodInfo actionMethod);
34+
var exceptionTypes = GetExceptionTypes(exception.GetType());
35+
36+
var actionsForException = exceptionTypes
37+
.SelectMany(exceptionType => GetActionsForException(exceptionType, request))
38+
.GroupBy(actionForException => actionForException.Action.GetType())
39+
.Select(actionForException => actionForException.First())
40+
.Select(actionForException => (MethodInfo: GetMethodInfoForAction(actionForException.ExceptionType), actionForException.Action))
41+
.ToList();
3742

38-
foreach (var actionForException in actionsForException)
43+
foreach (var actionForException in actionsForException)
44+
{
45+
try
46+
{
47+
await ((Task)(actionForException.MethodInfo.Invoke(actionForException.Action, new object[] { request, exception, cancellationToken })
48+
?? throw new InvalidOperationException($"Could not create task for action method {actionForException.MethodInfo}."))).ConfigureAwait(false);
49+
}
50+
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
3951
{
40-
try
41-
{
42-
await ((Task)(actionMethod.Invoke(actionForException, new object[] { request, exception, cancellationToken })
43-
?? throw new InvalidOperationException($"Could not create task for action method {actionMethod}."))).ConfigureAwait(false);
44-
}
45-
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
46-
{
47-
// Unwrap invocation exception to throw the actual error
48-
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
49-
}
52+
// Unwrap invocation exception to throw the actual error
53+
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
5054
}
5155
}
5256

5357
throw;
5458
}
5559
}
5660

57-
private IList<object> GetActionsForException(Type exceptionType, TRequest request, out MethodInfo actionMethodInfo)
61+
private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
62+
{
63+
while (exceptionType != null && exceptionType != typeof(object))
64+
{
65+
yield return exceptionType;
66+
exceptionType = exceptionType.BaseType;
67+
}
68+
}
69+
70+
private IEnumerable<(Type ExceptionType, object Action)> GetActionsForException(Type exceptionType, TRequest request)
5871
{
5972
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);
6073
var enumerableExceptionActionInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionActionInterfaceType);
61-
actionMethodInfo = exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
62-
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");
6374

6475
var actionsForException = (IEnumerable<object>)_serviceFactory(enumerableExceptionActionInterfaceType);
6576

66-
return HandlersOrderer.Prioritize(actionsForException.ToList(), request);
77+
return HandlersOrderer.Prioritize(actionsForException.ToList(), request)
78+
.Select(action => (exceptionType, action));
79+
}
80+
81+
private static MethodInfo GetMethodInfoForAction(Type exceptionType)
82+
{
83+
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);
84+
85+
var actionMethodInfo =
86+
exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
87+
?? throw new InvalidOperationException(
88+
$"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");
89+
90+
return actionMethodInfo;
6791
}
6892
}

src/MediatR/Pipeline/RequestExceptionProcessorBehavior.cs

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,32 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
3232
catch (Exception exception)
3333
{
3434
var state = new RequestExceptionHandlerState<TResponse>();
35-
Type? exceptionType = null;
3635

37-
while (!state.Handled && exceptionType != typeof(Exception))
36+
var exceptionTypes = GetExceptionTypes(exception.GetType());
37+
38+
var handlersForException = exceptionTypes
39+
.SelectMany(exceptionType => GetHandlersForException(exceptionType, request))
40+
.GroupBy(handlerForException => handlerForException.Handler.GetType())
41+
.Select(handlerForException => handlerForException.First())
42+
.Select(handlerForException => (MethodInfo: GetMethodInfoForHandler(handlerForException.ExceptionType), handlerForException.Handler))
43+
.ToList();
44+
45+
foreach (var handlerForException in handlersForException)
3846
{
39-
exceptionType = exceptionType == null ? exception.GetType() : exceptionType.BaseType
40-
?? throw new InvalidOperationException("Could not determine exception base type.");
41-
var exceptionHandlers = GetExceptionHandlers(request, exceptionType, out MethodInfo handleMethod);
47+
try
48+
{
49+
await ((Task) (handlerForException.MethodInfo.Invoke(handlerForException.Handler, new object[] { request, exception, state, cancellationToken })
50+
?? throw new InvalidOperationException("Did not return a Task from the exception handler."))).ConfigureAwait(false);
51+
}
52+
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
53+
{
54+
// Unwrap invocation exception to throw the actual error
55+
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
56+
}
4257

43-
foreach (var exceptionHandler in exceptionHandlers)
58+
if (state.Handled)
4459
{
45-
try
46-
{
47-
await ((Task)(handleMethod.Invoke(exceptionHandler, new object[] { request, exception, state, cancellationToken })
48-
?? throw new InvalidOperationException("Did not return a Task from the exception handler."))).ConfigureAwait(false);
49-
}
50-
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
51-
{
52-
// Unwrap invocation exception to throw the actual error
53-
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
54-
}
55-
56-
if (state.Handled)
57-
{
58-
break;
59-
}
60+
break;
6061
}
6162
}
6263

@@ -73,16 +74,33 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
7374
return state.Response; //cannot be null if Handled
7475
}
7576
}
77+
private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
78+
{
79+
while (exceptionType != null && exceptionType != typeof(object))
80+
{
81+
yield return exceptionType;
82+
exceptionType = exceptionType.BaseType;
83+
}
84+
}
7685

77-
private IList<object> GetExceptionHandlers(TRequest request, Type exceptionType, out MethodInfo handleMethodInfo)
86+
private IEnumerable<(Type ExceptionType, object Handler)> GetHandlersForException(Type exceptionType, TRequest request)
7887
{
7988
var exceptionHandlerInterfaceType = typeof(IRequestExceptionHandler<,,>).MakeGenericType(typeof(TRequest), typeof(TResponse), exceptionType);
8089
var enumerableExceptionHandlerInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionHandlerInterfaceType);
81-
handleMethodInfo = exceptionHandlerInterfaceType.GetMethod(nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle))
82-
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle)} on type {exceptionHandlerInterfaceType}");
8390

84-
var exceptionHandlers = (IEnumerable<object>)_serviceFactory.Invoke(enumerableExceptionHandlerInterfaceType);
91+
var exceptionHandlers = (IEnumerable<object>) _serviceFactory(enumerableExceptionHandlerInterfaceType);
92+
93+
return HandlersOrderer.Prioritize(exceptionHandlers.ToList(), request)
94+
.Select(handler => (exceptionType, action: handler));
95+
}
96+
97+
private static MethodInfo GetMethodInfoForHandler(Type exceptionType)
98+
{
99+
var exceptionHandlerInterfaceType = typeof(IRequestExceptionHandler<,,>).MakeGenericType(typeof(TRequest), typeof(TResponse), exceptionType);
100+
101+
var handleMethodInfo = exceptionHandlerInterfaceType.GetMethod(nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle))
102+
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle)} on type {exceptionHandlerInterfaceType}");
85103

86-
return HandlersOrderer.Prioritize(exceptionHandlers.ToList(), request);
104+
return handleMethodInfo;
87105
}
88106
}

test/MediatR.Tests/Pipeline/RequestExceptionActionTests.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ public Task<Pong> Handle(Ping request, CancellationToken cancellationToken)
4949
}
5050
}
5151

52+
public class GenericExceptionAction<TRequest> : IRequestExceptionAction<TRequest>
53+
{
54+
public int ExecutionCount { get; private set; }
55+
56+
public Task Execute(TRequest request, Exception exception, CancellationToken cancellationToken)
57+
{
58+
ExecutionCount++;
59+
return Task.CompletedTask;
60+
}
61+
}
62+
5263
public class PingPongExceptionAction<TRequest> : IRequestExceptionAction<TRequest, PingPongException>
5364
{
5465
public bool Executed { get; private set; }
@@ -83,7 +94,7 @@ public Task Execute(Ping request, PongException exception, CancellationToken can
8394
}
8495

8596
[Fact]
86-
public async Task Should_run_all_exception_handlers_that_match_base_type()
97+
public async Task Should_run_all_exception_actions_that_match_base_type()
8798
{
8899
var pingExceptionAction = new PingExceptionAction();
89100
var pongExceptionAction = new PongExceptionAction();
@@ -108,4 +119,25 @@ public async Task Should_run_all_exception_handlers_that_match_base_type()
108119
pingPongExceptionAction.Executed.ShouldBeTrue();
109120
pongExceptionAction.Executed.ShouldBeFalse();
110121
}
122+
123+
[Fact]
124+
public async Task Should_run_matching_exception_actions_only_once()
125+
{
126+
var genericExceptionAction = new GenericExceptionAction<Ping>();
127+
var container = new Container(cfg =>
128+
{
129+
cfg.For<IRequestHandler<Ping, Pong>>().Use<PingHandler>();
130+
cfg.For<IRequestExceptionAction<Ping>>().Use(_ => genericExceptionAction);
131+
cfg.For(typeof(IPipelineBehavior<,>)).Add(typeof(RequestExceptionActionProcessorBehavior<,>));
132+
cfg.For<ServiceFactory>().Use<ServiceFactory>(ctx => t => ctx.GetInstance(t));
133+
cfg.For<IMediator>().Use<Mediator>();
134+
});
135+
136+
var mediator = container.GetInstance<IMediator>();
137+
138+
var request = new Ping { Message = "Ping!" };
139+
await Assert.ThrowsAsync<PingException>(() => mediator.Send(request));
140+
141+
genericExceptionAction.ExecutionCount.ShouldBe(1);
142+
}
111143
}

test/MediatR.Tests/Pipeline/RequestExceptionHandlerTests.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ public Task<Pong> Handle(Ping request, CancellationToken cancellationToken)
3535
}
3636
}
3737

38+
public class GenericPingExceptionHandler : IRequestExceptionHandler<Ping, Pong>
39+
{
40+
public int ExecutionCount { get; private set; }
41+
42+
public Task Handle(Ping request, Exception exception, RequestExceptionHandlerState<Pong> state, CancellationToken cancellationToken)
43+
{
44+
ExecutionCount++;
45+
return Task.CompletedTask;
46+
}
47+
}
48+
3849
public class PingPongExceptionHandlerForType : IRequestExceptionHandler<Ping, Pong, PingException>
3950
{
4051
public Task Handle(Ping request, PingException exception, RequestExceptionHandlerState<Pong> state, CancellationToken cancellationToken)
@@ -133,4 +144,28 @@ await Should.ThrowAsync<ApplicationException>(async () =>
133144
});
134145
}
135146

147+
[Fact]
148+
public async Task Should_run_matching_exception_handlers_only_once()
149+
{
150+
var genericPingExceptionHandler = new GenericPingExceptionHandler();
151+
var container = new Container(cfg =>
152+
{
153+
cfg.For<IRequestHandler<Ping, Pong>>().Use<PingHandler>();
154+
cfg.For<IRequestExceptionHandler<Ping, Pong>>().Use(genericPingExceptionHandler);
155+
cfg.For(typeof(IPipelineBehavior<,>)).Add(typeof(RequestExceptionProcessorBehavior<,>));
156+
cfg.For<ServiceFactory>().Use<ServiceFactory>(ctx => t => ctx.GetInstance(t));
157+
cfg.For<IMediator>().Use<Mediator>();
158+
});
159+
160+
var mediator = container.GetInstance<IMediator>();
161+
162+
var request = new Ping { Message = "Ping" };
163+
await Should.ThrowAsync<PingException>(async () =>
164+
{
165+
await mediator.Send(request);
166+
});
167+
168+
genericPingExceptionHandler.ExecutionCount.ShouldBe(1);
169+
}
170+
136171
}

0 commit comments

Comments
 (0)