Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ public override void Initialize(AnalysisContext context)
return;
}

context.Compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemLinqExpressionsExpression1, out INamedTypeSymbol? linqExpressionType);

context.RegisterOperationAction(
context => AnalyzeInvocation(context, cancellationTokenSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol),
context => AnalyzeInvocation(context, cancellationTokenSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, linqExpressionType),
OperationKind.Invocation);
});
}
Expand All @@ -69,7 +71,8 @@ private static void AnalyzeInvocation(
INamedTypeSymbol testContextSymbol,
INamedTypeSymbol classCleanupAttributeSymbol,
INamedTypeSymbol assemblyCleanupAttributeSymbol,
INamedTypeSymbol testMethodAttributeSymbol)
INamedTypeSymbol testMethodAttributeSymbol,
INamedTypeSymbol? linqExpressionType)
{
var invocationOperation = (IInvocationOperation)context.Operation;
IMethodSymbol method = invocationOperation.TargetMethod;
Expand Down Expand Up @@ -102,6 +105,12 @@ private static void AnalyzeInvocation(
cancellationTokenParameterName = cancellationTokenParameter.Name;
}

// Skip diagnostics inside expression trees where the code fix cannot be applied.
if (IsInsideExpressionTree(invocationOperation, linqExpressionType))
{
return;
}

context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState, cancellationTokenParameterName)));
return;
}
Expand All @@ -119,6 +128,12 @@ private static void AnalyzeInvocation(
cancellationTokenParameterName = cancellationTokenParameterFromDifferentOverload.Name;
}

// Skip diagnostics inside expression trees where the code fix cannot be applied.
if (IsInsideExpressionTree(invocationOperation, linqExpressionType))
{
return;
}

context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState, cancellationTokenParameterName)));
}

Expand All @@ -138,6 +153,28 @@ private static void AnalyzeInvocation(
}
}

private static bool IsInsideExpressionTree(IOperation operation, INamedTypeSymbol? linqExpressionType)
{
if (linqExpressionType is null)
{
return false;
}

IOperation? current = operation.Parent;
while (current is not null)
{
if (current is IAnonymousFunctionOperation or ILocalFunctionOperation)
{
return SymbolEqualityComparer.Default.Equals(
current.Parent?.Type?.OriginalDefinition, linqExpressionType);
}

current = current.Parent;
}

return false;
}

private static IParameterSymbol? GetCancellationTokenParameterOfOverloadWithCancellationToken(IMethodSymbol method, INamedTypeSymbol cancellationTokenSymbol)
{
// Look for overloads of the same method that accept CancellationToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ internal static class WellKnownTypeNames
public const string SystemIAsyncDisposable = "System.IAsyncDisposable";
public const string SystemIDisposable = "System.IDisposable";
public const string SystemLinqEnumerable = "System.Linq.Enumerable";
public const string SystemLinqExpressionsExpression1 = "System.Linq.Expressions.Expression`1";
public const string SystemOperatingSystem = "System.OperatingSystem";
public const string SystemReflectionMethodInfo = "System.Reflection.MethodInfo";
public const string SystemRuntimeCompilerServicesCallerFilePathAttribute = "System.Runtime.CompilerServices.CallerFilePathAttribute";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,4 +830,108 @@ await Task.Delay(

await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[TestMethod]
public async Task WhenInsideExpressionTree_NoDiagnostic()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

[TestClass]
public class MyTestClass
{
public TestContext TestContext { get; set; }

[TestMethod]
public void MyTestMethod()
{
Expression<Func<Task>> expr = () => Task.Delay(1000);
}
}
""";

await VerifyCS.VerifyCodeFixAsync(code, code);
}

[TestMethod]
public async Task WhenInsideExpressionTreeWithOverloadHavingCancellationToken_NoDiagnostic()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

public interface IMyService
{
Task DoWorkAsync(string input);
Task DoWorkAsync(string input, CancellationToken cancellationToken);
}

[TestClass]
public class MyTestClass
{
public TestContext TestContext { get; set; }

[TestMethod]
public void MyTestMethod()
{
Expression<Func<IMyService, Task>> expr = svc => svc.DoWorkAsync("test");
}
}
""";

await VerifyCS.VerifyCodeFixAsync(code, code);
}

[TestMethod]
public async Task WhenInsideLambdaButNotExpressionTree_Diagnostic()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Threading;
using System.Threading.Tasks;

[TestClass]
public class MyTestClass
{
public TestContext TestContext { get; set; }

[TestMethod]
public async Task MyTestMethod()
{
Func<Task> action = () => [|Task.Delay(1000)|];
await action();
}
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Threading;
using System.Threading.Tasks;

[TestClass]
public class MyTestClass
{
public TestContext TestContext { get; set; }

[TestMethod]
public async Task MyTestMethod()
{
Func<Task> action = () => Task.Delay(1000, TestContext.CancellationToken);
await action();
}
}
""";

await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}
}
Loading