diff --git a/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs index 388b893439..582ff7cceb 100644 --- a/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs @@ -57,8 +57,10 @@ public override void Initialize(AnalysisContext context) return; } + context.Compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemLinqExpressionsExpression1, out INamedTypeSymbol? linqExpressionType); + context.RegisterOperationAction( - context => AnalyzeInvocation(context, cancellationTokenSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol), + context => AnalyzeInvocation(context, cancellationTokenSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, linqExpressionType), OperationKind.Invocation); }); } @@ -69,7 +71,8 @@ private static void AnalyzeInvocation( INamedTypeSymbol testContextSymbol, INamedTypeSymbol classCleanupAttributeSymbol, INamedTypeSymbol assemblyCleanupAttributeSymbol, - INamedTypeSymbol testMethodAttributeSymbol) + INamedTypeSymbol testMethodAttributeSymbol, + INamedTypeSymbol? linqExpressionType) { var invocationOperation = (IInvocationOperation)context.Operation; IMethodSymbol method = invocationOperation.TargetMethod; @@ -102,6 +105,12 @@ private static void AnalyzeInvocation( cancellationTokenParameterName = cancellationTokenParameter.Name; } + // Skip diagnostics inside expression trees where the code fix cannot be applied. + if (IsInsideExpressionTree(invocationOperation, linqExpressionType)) + { + return; + } + context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState, cancellationTokenParameterName))); return; } @@ -119,6 +128,12 @@ private static void AnalyzeInvocation( cancellationTokenParameterName = cancellationTokenParameterFromDifferentOverload.Name; } + // Skip diagnostics inside expression trees where the code fix cannot be applied. + if (IsInsideExpressionTree(invocationOperation, linqExpressionType)) + { + return; + } + context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState, cancellationTokenParameterName))); } @@ -138,6 +153,28 @@ private static void AnalyzeInvocation( } } + private static bool IsInsideExpressionTree(IOperation operation, INamedTypeSymbol? linqExpressionType) + { + if (linqExpressionType is null) + { + return false; + } + + IOperation? current = operation.Parent; + while (current is not null) + { + if (current is IAnonymousFunctionOperation or ILocalFunctionOperation) + { + return SymbolEqualityComparer.Default.Equals( + current.Parent?.Type?.OriginalDefinition, linqExpressionType); + } + + current = current.Parent; + } + + return false; + } + private static IParameterSymbol? GetCancellationTokenParameterOfOverloadWithCancellationToken(IMethodSymbol method, INamedTypeSymbol cancellationTokenSymbol) { // Look for overloads of the same method that accept CancellationToken diff --git a/src/Analyzers/MSTest.Analyzers/Helpers/WellKnownTypeNames.cs b/src/Analyzers/MSTest.Analyzers/Helpers/WellKnownTypeNames.cs index c04213e414..93500ce963 100644 --- a/src/Analyzers/MSTest.Analyzers/Helpers/WellKnownTypeNames.cs +++ b/src/Analyzers/MSTest.Analyzers/Helpers/WellKnownTypeNames.cs @@ -50,6 +50,7 @@ internal static class WellKnownTypeNames public const string SystemIAsyncDisposable = "System.IAsyncDisposable"; public const string SystemIDisposable = "System.IDisposable"; public const string SystemLinqEnumerable = "System.Linq.Enumerable"; + public const string SystemLinqExpressionsExpression1 = "System.Linq.Expressions.Expression`1"; public const string SystemOperatingSystem = "System.OperatingSystem"; public const string SystemReflectionMethodInfo = "System.Reflection.MethodInfo"; public const string SystemRuntimeCompilerServicesCallerFilePathAttribute = "System.Runtime.CompilerServices.CallerFilePathAttribute"; diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs index a639858e30..a2fcda75d3 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/FlowTestContextCancellationTokenAnalyzerTests.cs @@ -830,4 +830,108 @@ await Task.Delay( await VerifyCS.VerifyCodeFixAsync(code, fixedCode); } + + [TestMethod] + public async Task WhenInsideExpressionTree_NoDiagnostic() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + using System.Linq.Expressions; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + public TestContext TestContext { get; set; } + + [TestMethod] + public void MyTestMethod() + { + Expression> expr = () => Task.Delay(1000); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, code); + } + + [TestMethod] + public async Task WhenInsideExpressionTreeWithOverloadHavingCancellationToken_NoDiagnostic() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + using System.Linq.Expressions; + using System.Threading; + using System.Threading.Tasks; + + public interface IMyService + { + Task DoWorkAsync(string input); + Task DoWorkAsync(string input, CancellationToken cancellationToken); + } + + [TestClass] + public class MyTestClass + { + public TestContext TestContext { get; set; } + + [TestMethod] + public void MyTestMethod() + { + Expression> expr = svc => svc.DoWorkAsync("test"); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, code); + } + + [TestMethod] + public async Task WhenInsideLambdaButNotExpressionTree_Diagnostic() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + public TestContext TestContext { get; set; } + + [TestMethod] + public async Task MyTestMethod() + { + Func action = () => [|Task.Delay(1000)|]; + await action(); + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + using System.Threading; + using System.Threading.Tasks; + + [TestClass] + public class MyTestClass + { + public TestContext TestContext { get; set; } + + [TestMethod] + public async Task MyTestMethod() + { + Func action = () => Task.Delay(1000, TestContext.CancellationToken); + await action(); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } }