Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions samples/SimpleConsole/SimpleConsole.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
<TreatWarningsAsErrors>false</TreatWarningsAsErrors>
<ImplicitUsings>disable</ImplicitUsings>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>Generated</CompilerGeneratedFilesOutputPath>
</PropertyGroup>

<ItemGroup>
<Compile Remove="$(CompilerGeneratedFilesOutputPath)/**/*.cs" />
<None Include="$(CompilerGeneratedFilesOutputPath)/**/*.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(DotNetVersion)" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ CompilationAnalyzer analyzer
}

public string RequestFullName => Symbol.GetTypeSymbolFullName();
public bool ResponseIsValueType => ResponseSymbol!.IsValueType;
public string ResponseFullName => ResponseSymbol!.GetTypeSymbolFullName();
public bool ResponseIsValueType => ResponseSymbol.IsValueType;
public string ResponseFullName => ResponseSymbol.GetTypeSymbolFullName();
public string ResponseFullNameWithoutReferenceNullability =>
ResponseSymbol.GetTypeSymbolFullName(includeReferenceNullability: false);

public void SetHandler(RequestMessageHandler handler) => Handler = handler;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ public static class RoslynExtensions
public static string GetTypeSymbolFullName(
this ITypeSymbol symbol,
bool withGlobalPrefix = true,
bool includeTypeParameters = true
bool includeTypeParameters = true,
bool includeReferenceNullability = true
)
{
var miscOptions = SymbolDisplayMiscellaneousOptions.ExpandNullable;
if (includeReferenceNullability)
miscOptions |= SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier;

return symbol.ToDisplayString(
new SymbolDisplayFormat(
withGlobalPrefix
Expand All @@ -17,7 +22,7 @@ public static string GetTypeSymbolFullName(
includeTypeParameters
? SymbolDisplayGenericsOptions.IncludeTypeParameters
: SymbolDisplayGenericsOptions.None,
miscellaneousOptions: SymbolDisplayMiscellaneousOptions.ExpandNullable
miscellaneousOptions: miscOptions
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#pragma warning disable CS8321 // Unused local function
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

#nullable enable

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using System.Linq;
Expand Down Expand Up @@ -35,7 +37,7 @@ namespace Microsoft.Extensions.DependencyInjection
/// <summary>
/// Adds the Mediator implementation and handlers of your application, with specified options.
/// </summary>
public static IServiceCollection AddMediator(this IServiceCollection services, global::System.Action<global::Mediator.MediatorOptions> options)
public static IServiceCollection AddMediator(this IServiceCollection services, global::System.Action<global::Mediator.MediatorOptions>? options)
{
var opts = new global::Mediator.MediatorOptions();
if (options != null)
Expand Down Expand Up @@ -401,7 +403,7 @@ namespace {{ MediatorNamespace }}
{{~ for message in IRequestMessages ~}}
case {{ message.RequestFullName }} r:
{
if(typeof(TResponse) == typeof({{ message.ResponseFullName }}))
if (typeof(TResponse) == typeof({{ message.ResponseFullNameWithoutReferenceNullability }}))
{
var task = Send({{- message.ParameterModifier -}}r, cancellationToken);
return global::System.Runtime.CompilerServices.Unsafe.As<{{ message.AsyncReturnType }}, global::System.Threading.Tasks.ValueTask<TResponse>>(ref task);
Expand Down Expand Up @@ -515,7 +517,7 @@ namespace {{ MediatorNamespace }}
{{~ for message in ICommandMessages ~}}
case {{ message.RequestFullName }} r:
{
if(typeof(TResponse) == typeof({{ message.ResponseFullName }}))
if (typeof(TResponse) == typeof({{ message.ResponseFullNameWithoutReferenceNullability }}))
{
var task = Send({{- message.ParameterModifier -}}r, cancellationToken);
return global::System.Runtime.CompilerServices.Unsafe.As<{{ message.AsyncReturnType }}, global::System.Threading.Tasks.ValueTask<TResponse>>(ref task);
Expand Down Expand Up @@ -629,7 +631,7 @@ namespace {{ MediatorNamespace }}
{{~ for message in IQueryMessages ~}}
case {{ message.RequestFullName }} r:
{
if(typeof(TResponse) == typeof({{ message.ResponseFullName }}))
if (typeof(TResponse) == typeof({{ message.ResponseFullNameWithoutReferenceNullability }}))
{
var task = Send({{- message.ParameterModifier -}}r, cancellationToken);
return global::System.Runtime.CompilerServices.Unsafe.As<{{ message.AsyncReturnType }}, global::System.Threading.Tasks.ValueTask<TResponse>>(ref task);
Expand Down Expand Up @@ -732,7 +734,7 @@ namespace {{ MediatorNamespace }}
/// <param name="message">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Awaitable task</returns>
public{{ if HasAnyRequest; " async "; else; " "; end; }}global::System.Threading.Tasks.ValueTask<object> Send(
public{{ if HasAnyRequest; " async "; else; " "; end; }}global::System.Threading.Tasks.ValueTask<object?> Send(
object message,
global::System.Threading.CancellationToken cancellationToken = default
)
Expand Down Expand Up @@ -764,7 +766,7 @@ namespace {{ MediatorNamespace }}
/// <param name="message">Incoming message</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Async enumerable</returns>
public global::System.Collections.Generic.IAsyncEnumerable<object> CreateStream(
public global::System.Collections.Generic.IAsyncEnumerable<object?> CreateStream(
object message,
global::System.Threading.CancellationToken cancellationToken = default
)
Expand All @@ -791,7 +793,7 @@ namespace {{ MediatorNamespace }}
}

{{- if HasAnyValueTypeStreamResponse }}
static async global::System.Collections.Generic.IAsyncEnumerable<object> AsyncWrapper<T>(global::System.Collections.Generic.IAsyncEnumerable<T> wrapped, [global::System.Runtime.CompilerServices.EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default) where T : struct
static async global::System.Collections.Generic.IAsyncEnumerable<object?> AsyncWrapper<T>(global::System.Collections.Generic.IAsyncEnumerable<T> wrapped, [global::System.Runtime.CompilerServices.EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default) where T : struct
{
await foreach (var value in global::System.Threading.Tasks.TaskAsyncEnumerableExtensions.WithCancellation(wrapped, cancellationToken))
{
Expand Down Expand Up @@ -880,7 +882,7 @@ namespace {{ MediatorNamespace }}
async global::System.Threading.Tasks.ValueTask Publish({{ message.FullName }} notification, global::Mediator.INotificationHandler<{{ message.FullName }}>[] handlers, global::System.Threading.CancellationToken cancellationToken)
{
// We don't allocate the list if no task throws
global::System.Collections.Generic.List<global::System.Exception> exceptions = null;
global::System.Collections.Generic.List<global::System.Exception>? exceptions = null;

for (int i = 0; i < handlers.Length; i++)
{
Expand Down Expand Up @@ -929,7 +931,7 @@ namespace {{ MediatorNamespace }}
async global::System.Threading.Tasks.ValueTask Publish({{ message.FullName }} notification, global::System.Collections.Generic.IEnumerable<global::Mediator.INotificationHandler<{{ message.FullName }}>> handlers, global::System.Threading.CancellationToken cancellationToken)
{
// We don't allocate the list if no task throws
global::System.Collections.Generic.List<global::System.Exception> exceptions = null;
global::System.Collections.Generic.List<global::System.Exception>? exceptions = null;

foreach (var handler in handlers)
{
Expand All @@ -949,7 +951,7 @@ namespace {{ MediatorNamespace }}
async global::System.Threading.Tasks.ValueTask PublishArr({{ message.FullName }} notification, global::Mediator.INotificationHandler<{{ message.FullName }}>[] handlers, global::System.Threading.CancellationToken cancellationToken)
{
// We don't allocate the list if no task throws
global::System.Collections.Generic.List<global::System.Exception> exceptions = null;
global::System.Collections.Generic.List<global::System.Exception>? exceptions = null;

for (int i = 0; i < handlers.Length; i++)
{
Expand Down Expand Up @@ -1008,7 +1010,7 @@ namespace {{ MediatorNamespace }}
throw new global::Mediator.MissingMessageHandlerException(msg);

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidMessage<T>(T msg, string paramName = null)
private static void ThrowInvalidMessage<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1019,7 +1021,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidRequest<T>(T msg, string paramName = null)
private static void ThrowInvalidRequest<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1030,7 +1032,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidCommand<T>(T msg, string paramName = null)
private static void ThrowInvalidCommand<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1041,7 +1043,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidQuery<T>(T msg, string paramName = null)
private static void ThrowInvalidQuery<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1052,7 +1054,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidStreamMessage<T>(T msg, string paramName = null)
private static void ThrowInvalidStreamMessage<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1063,7 +1065,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidStreamRequest<T>(T msg, string paramName = null)
private static void ThrowInvalidStreamRequest<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1074,7 +1076,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidStreamCommand<T>(T msg, string paramName = null)
private static void ThrowInvalidStreamCommand<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1085,7 +1087,7 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidStreamQuery<T>(T msg, string paramName = null)
private static void ThrowInvalidStreamQuery<T>(T? msg, string? paramName = null)
{
if (msg == null)
ThrowArgumentNull(paramName);
Expand All @@ -1096,20 +1098,20 @@ namespace {{ MediatorNamespace }}
}

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowArgumentNull(string paramName) =>
private static void ThrowArgumentNull(string? paramName) =>
throw new global::System.ArgumentNullException(paramName);

[global::System.Diagnostics.CodeAnalysis.DoesNotReturn]
private static void ThrowInvalidMessage<T>(T msg) =>
throw new global::Mediator.InvalidMessageException(msg);

private static void ThrowIfNull<T>(T argument, string paramName)
private static void ThrowIfNull<T>(T? argument, string paramName)
{
if (argument == null)
ThrowArgumentNull(paramName);
}

private static void ThrowInvalidNotification<T>(T argument, string paramName)
private static void ThrowInvalidNotification<T>(T? argument, string paramName)
{
if (argument == null)
ThrowArgumentNull(paramName);
Expand All @@ -1121,7 +1123,7 @@ namespace {{ MediatorNamespace }}
private static void ThrowAggregateException(global::System.Collections.Generic.List<global::System.Exception> exceptions) =>
throw new global::System.AggregateException(exceptions);

private static void MaybeThrowAggregateException(global::System.Collections.Generic.List<global::System.Exception> exceptions)
private static void MaybeThrowAggregateException(global::System.Collections.Generic.List<global::System.Exception>? exceptions)
{
if (exceptions != null)
{
Expand Down
32 changes: 32 additions & 0 deletions test/Mediator.Tests/NullableResponseTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Mediator.Tests.TestTypes;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Mediator.Tests;

public sealed record RequestWithNullableResponse(Guid Id) : IRequest<SomeResponse?>;

public sealed class RequestWithNullableResponseHandler : IRequestHandler<RequestWithNullableResponse, SomeResponse?>
{
public ValueTask<SomeResponse?> Handle(RequestWithNullableResponse request, CancellationToken cancellationToken) =>
new ValueTask<SomeResponse?>(default(SomeResponse?));
}

public class NullableResponseTests
{
[Fact]
public async Task Test_Request_With_Nullable_Response()
{
var (_, mediator) = Fixture.GetMediator();
var concrete = (Mediator)mediator;

var id = Guid.NewGuid();

var response = await mediator.Send(new RequestWithNullableResponse(id));
Assert.Null(response);

var response2 = await concrete.Send(new RequestWithNullableResponse(id));
Assert.Null(response2);
}
}