From 645a847615e73cfeb528dcea985c8645d4c99980 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Mon, 12 Jan 2026 22:15:02 +0000 Subject: [PATCH] fix: prevent interface implementation methods from being converted to async Methods that implement interface members should NOT have their signatures converted to async Task because this would break the interface implementation contract. The fix uses a two-pass approach: 1. Before syntax modifications, collect all methods that implement interface members using semantic analysis (via semantic model) 2. During async signature rewriting, skip methods identified as interface implementations Also adds ref/out/in parameter check to prevent async conversion of methods with those parameter types (which cannot be async). Fixes #4342 Co-Authored-By: Claude Opus 4.5 --- .../Base/AsyncMethodSignatureRewriter.cs | 106 ++++++++++++++++++ .../Base/BaseMigrationCodeFixProvider.cs | 8 +- .../NUnitMigrationAnalyzerTests.cs | 64 +++++++++++ 3 files changed, 177 insertions(+), 1 deletion(-) diff --git a/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs b/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs index 350227a03d..4e6de0c9cf 100644 --- a/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs +++ b/TUnit.Analyzers.CodeFixers/Base/AsyncMethodSignatureRewriter.cs @@ -10,6 +10,84 @@ namespace TUnit.Analyzers.CodeFixers.Base; /// public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter { + private readonly HashSet _interfaceImplementingMethods; + + public AsyncMethodSignatureRewriter() : this(new HashSet()) + { + } + + public AsyncMethodSignatureRewriter(HashSet interfaceImplementingMethods) + { + _interfaceImplementingMethods = interfaceImplementingMethods; + } + + /// + /// Collects method signatures that implement interface members. + /// This should be called BEFORE syntax modifications while the semantic model is still valid. + /// + public static HashSet CollectInterfaceImplementingMethods( + CompilationUnitSyntax compilationUnit, + SemanticModel semanticModel) + { + var methods = new HashSet(); + + foreach (var methodDecl in compilationUnit.DescendantNodes().OfType()) + { + // Check for explicit interface implementation syntax + if (methodDecl.ExplicitInterfaceSpecifier != null) + { + methods.Add(GetMethodKey(methodDecl)); + continue; + } + + var methodSymbol = semanticModel.GetDeclaredSymbol(methodDecl); + if (methodSymbol == null) + { + continue; + } + + // Check if this method explicitly implements an interface + if (methodSymbol.ExplicitInterfaceImplementations.Length > 0) + { + methods.Add(GetMethodKey(methodDecl)); + continue; + } + + // Check if this method implicitly implements an interface member + var containingType = methodSymbol.ContainingType; + if (containingType != null) + { + foreach (var iface in containingType.AllInterfaces) + { + foreach (var member in iface.GetMembers().OfType()) + { + var impl = containingType.FindImplementationForInterfaceMember(member); + if (SymbolEqualityComparer.Default.Equals(impl, methodSymbol)) + { + methods.Add(GetMethodKey(methodDecl)); + break; + } + } + } + } + } + + return methods; + } + + /// + /// Gets a unique key for a method declaration based on its signature. + /// This key is stable across syntax tree modifications. + /// + private static string GetMethodKey(MethodDeclarationSyntax node) + { + // Build a key from class name, method name, and parameter types + var className = node.Ancestors().OfType().FirstOrDefault()?.Identifier.Text ?? ""; + var methodName = node.Identifier.Text; + var parameters = string.Join(",", node.ParameterList.Parameters.Select(p => p.Type?.ToString() ?? "")); + return $"{className}.{methodName}({parameters})"; + } + public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node) { // First, visit children to ensure nested content is processed @@ -29,6 +107,21 @@ public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter return node; } + // Skip methods with ref/out/in parameters (they can't be async) + if (node.ParameterList.Parameters.Any(p => + p.Modifiers.Any(SyntaxKind.RefKeyword) || + p.Modifiers.Any(SyntaxKind.OutKeyword) || + p.Modifiers.Any(SyntaxKind.InKeyword))) + { + return node; + } + + // Skip if method implements an interface member (changing return type would break the implementation) + if (ImplementsInterfaceMember(node)) + { + return node; + } + // Convert the return type var newReturnType = ConvertReturnType(node.ReturnType); @@ -40,6 +133,19 @@ public class AsyncMethodSignatureRewriter : CSharpSyntaxRewriter .WithModifiers(newModifiers); } + private bool ImplementsInterfaceMember(MethodDeclarationSyntax node) + { + // Check for explicit interface implementation syntax (IFoo.Method) + if (node.ExplicitInterfaceSpecifier != null) + { + return true; + } + + // Check if this method was identified as an interface implementation + var key = GetMethodKey(node); + return _interfaceImplementingMethods.Contains(key); + } + private static TypeSyntax ConvertReturnType(TypeSyntax returnType) { // void -> Task diff --git a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs index c76aaede50..3f2c91f9a7 100644 --- a/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs +++ b/TUnit.Analyzers.CodeFixers/Base/BaseMigrationCodeFixProvider.cs @@ -50,6 +50,11 @@ protected async Task ConvertCodeAsync(Document document, SyntaxNode? r try { + // IMPORTANT: Collect interface-implementing methods BEFORE any syntax modifications + // while the semantic model is still valid for the original syntax tree + var interfaceImplementingMethods = AsyncMethodSignatureRewriter.CollectInterfaceImplementingMethods( + compilationUnit, semanticModel); + // Convert assertions FIRST (while semantic model still matches the syntax tree) var assertionRewriter = CreateAssertionRewriter(semanticModel, compilation); compilationUnit = (CompilationUnitSyntax)assertionRewriter.Visit(compilationUnit); @@ -58,7 +63,8 @@ protected async Task ConvertCodeAsync(Document document, SyntaxNode? r compilationUnit = ApplyFrameworkSpecificConversions(compilationUnit, semanticModel, compilation); // Fix method signatures that now contain await but aren't marked async - var asyncSignatureRewriter = new AsyncMethodSignatureRewriter(); + // Pass the collected interface methods to avoid converting interface implementations + var asyncSignatureRewriter = new AsyncMethodSignatureRewriter(interfaceImplementingMethods); compilationUnit = (CompilationUnitSyntax)asyncSignatureRewriter.Visit(compilationUnit); // Remove unnecessary base classes and interfaces diff --git a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs index 66274c147e..2da07ef128 100644 --- a/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/NUnitMigrationAnalyzerTests.cs @@ -2474,6 +2474,70 @@ public async Task TestMethod() ); } + [Test] + public async Task NUnit_InterfaceImplementation_NotConvertedToAsync() + { + // Methods that implement interface members should NOT be converted to async + // because that would break the interface implementation contract. + // The interface method contains no NUnit assertions, so no await is added. + // Only the test method (which doesn't implement an interface) gets converted to async. + await CodeFixer.VerifyCodeFixAsync( + """ + using NUnit.Framework; + using System.Threading.Tasks; + + public interface ITestRunner + { + void Run(); + } + + {|#0:public class MyClass|} : ITestRunner + { + [Test] + public void TestMethod() + { + Assert.That(true, Is.True); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + } + """, + Verifier.Diagnostic(Rules.NUnitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + using TUnit.Core; + using TUnit.Assertions; + using static TUnit.Assertions.Assert; + using TUnit.Assertions.Extensions; + + public interface ITestRunner + { + void Run(); + } + + public class MyClass : ITestRunner + { + [Test] + public async Task TestMethod() + { + await Assert.That(true).IsTrue(); + } + + public void Run() + { + // This implements ITestRunner.Run() and should stay void + var x = 1; + } + } + """, + ConfigureNUnitTest + ); + } + private static void ConfigureNUnitTest(Verifier.Test test) { test.TestState.AdditionalReferences.Add(typeof(NUnit.Framework.TestAttribute).Assembly);