Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Zomp.SyncMethodGenerator/AnalyzerReleases.Unshipped.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ Rule ID | Category | Severity | Notes
ZSMGEN001 | Preprocessor | Error | DiagnosticMessages
ZSMGEN002 | Preprocessor | Error | DiagnosticMessages
ZSMGEN003 | Preprocessor | Error | DiagnosticMessages
ZSMGEN004 | Preprocessor | Error | DiagnosticMessages
ZSMGEN005 | Preprocessor | Error | DiagnosticMessages
37 changes: 27 additions & 10 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ namespace Zomp.SyncMethodGenerator;
/// <param name="semanticModel">The semantic model.</param>
/// <param name="disableNullable">Instructs the source generator that nullable context should be disabled.</param>
/// <param name="preserveProgress">Instructs the source generator to preserve <see cref="IProgress"/> parameters.</param>
internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disableNullable, bool preserveProgress) : CSharpSyntaxRewriter
/// <param name="userMappings">User defined mappings for custom sync methods.</param>
internal sealed class AsyncToSyncRewriter(
SemanticModel semanticModel,
bool disableNullable,
bool preserveProgress,
UserMappings userMappings) : CSharpSyntaxRewriter
{
public const string SyncOnly = "SYNC_ONLY";

Expand Down Expand Up @@ -70,6 +75,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disa
private readonly SemanticModel semanticModel = semanticModel;
private readonly bool disableNullable = disableNullable;
private readonly bool preserveProgress = preserveProgress;
private readonly UserMappings userMappings = userMappings;
private readonly HashSet<IParameterSymbol> removedParameters = [];

/// <summary>
Expand Down Expand Up @@ -552,6 +558,15 @@ bool InitializedToMemory(SyntaxNode node)

droppingAsync = prevDroppingAsync;

if (userMappings.TryGetValue(methodSymbol, out var result))
{
var args = methodSymbol is { IsExtensionMethod: true, ReducedFrom: not null } && @base.Expression is MemberAccessExpressionSyntax member
? ArgumentList(SeparatedList([Argument(member.Expression.WithoutTrivia()), .. @base.ArgumentList.Arguments]))
: @base.ArgumentList;

return InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(Global(result.Namespace)), IdentifierName(result.Method)), args).WithTriviaFrom(@base);
}

// Assumption here is that if there's a method like GetMemory(), there is also method called GetSpan(). Revisit if this isn't the case.
var endsWithMemory = symbol.Name.EndsWith(Memory, StringComparison.Ordinal);

Expand Down Expand Up @@ -1623,15 +1638,6 @@ private static SyntaxTokenList StripAsyncModifier(SyntaxTokenList list)
private static string RemoveAsync(string original)
=> Regex.Replace(original, "Async", string.Empty);

private static bool HasSyncMethod(IMethodSymbol ms)
=> ms.Name.EndsWith("Async", StringComparison.Ordinal)
&& ms.ContainingType is { } type
&& type.GetMembers(RemoveAsync(ms.Name))
.OfType<IMethodSymbol>()
.Any(m => m.Parameters.Length == ms.Parameters.Length
&& m.Parameters.Zip(ms.Parameters, (p1, p2) => SymbolEqualityComparer.Default.Equals(p1, p2))
.All(z => z));

private static bool CanDropIf(IfStatementSyntax ifStatement)
=> ifStatement.Statement is BlockSyntax { Statements.Count: 0 } or null
&& (ifStatement.Else is null || CanDropElse(ifStatement.Else))
Expand Down Expand Up @@ -2019,6 +2025,17 @@ private bool ShouldRemoveType(ITypeSymbol symbol)
return (IsIProgress(namedSymbol) && !preserveProgress) || IsCancellationToken(namedSymbol);
}

private bool HasSyncMethod(IMethodSymbol ms)
=> userMappings.TryGetValue(ms, out _)
|| (
ms.Name.EndsWith("Async", StringComparison.Ordinal)
&& ms.ContainingType is { } type
&& type.GetMembers(RemoveAsync(ms.Name))
.OfType<IMethodSymbol>()
.Any(m => m.Parameters.Length == ms.Parameters.Length
&& m.Parameters.Zip(ms.Parameters, (p1, p2) => SymbolEqualityComparer.Default.Equals(p1, p2))
.All(z => z)));

private bool ShouldRemoveArgument(ISymbol symbol, bool isNegated = false) => symbol switch
{
IPropertySymbol
Expand Down
16 changes: 16 additions & 0 deletions src/Zomp.SyncMethodGenerator/DiagnosticMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,21 @@ internal static class DiagnosticMessages
DiagnosticSeverity.Error,
isEnabledByDefault: true);

internal static readonly DiagnosticDescriptor DuplicateUserMapping = new(
id: "ZSMGEN004",
title: "Duplicate user mapping",
messageFormat: "User mapping '{0}' is already defined",
category: Preprocessor,
DiagnosticSeverity.Error,
isEnabledByDefault: true);

internal static readonly DiagnosticDescriptor AttributeAndUserMappingConflict = new(
id: "ZSMGEN005",
title: "Attribute and user mapping conflict",
messageFormat: "Method '{0}' has both an attribute and a user mapping defined. The user mapping will be used.",
category: Preprocessor,
DiagnosticSeverity.Error,
isEnabledByDefault: true);

private const string Preprocessor = "Preprocessor";
}
112 changes: 108 additions & 4 deletions src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class SyncMethodSourceGenerator : IIncrementalGenerator

internal const string OmitNullableDirective = "OmitNullableDirective";
internal const string PreserveProgress = "PreserveProgress";
internal static readonly Regex LineRegex = new(@"([^\r\n]+)(\r\n|\r|\n)?", RegexOptions.Compiled);

/// <inheritdoc/>
public void Initialize(IncrementalGeneratorInitializationContext context)
Expand All @@ -38,6 +39,25 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
$"{SkipSyncVersionAttribute}.g.cs", SourceText.From(SourceGenerationHelper.SkipSyncVersionAttributeSource, Encoding.UTF8)));

var userMappings = context.AdditionalTextsProvider
.Where(a => a.Path.Equals("SyncMethods.txt", StringComparison.OrdinalIgnoreCase) ||
a.Path.EndsWith("/SyncMethods.txt", StringComparison.OrdinalIgnoreCase) ||
a.Path.EndsWith(@"\SyncMethods.txt", StringComparison.OrdinalIgnoreCase))
.Select((text, cancellationToken) => (text.Path, text.GetText(cancellationToken)?.ToString() ?? string.Empty))
.Collect()
.Select(GetUserMappings)
.WithTrackingName("GetUserMappings");

context.RegisterSourceOutput(
userMappings,
static (spc, source) =>
{
foreach (var diagnostic in source.Diagnostics)
{
spc.ReportDiagnostic(diagnostic);
}
});

var disableNullable =
context.CompilationProvider.Select((c, _) =>
{
Expand All @@ -52,8 +72,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
predicate: static (s, _) => IsSyntaxTargetForGeneration(s),
transform: static (ctx, ct) => TransformForGeneration(ctx, ct))
.SelectMany((list, ct) => list)
.Combine(userMappings)
.Combine(disableNullable)
.Select((data, ct) => GetMethodToGenerate(data.Left.Context, data.Left.Syntax, data.Right, ct)!)
.Select((data, ct) =>
{
var ((result, userMappingsValue), disableNullableValue) = data;
return GetMethodToGenerate(result.Context, result.Syntax, disableNullableValue, userMappingsValue, ct)!;
})
.WithTrackingName("GetMethodToGenerate")
.Where(static s => s is not null);

Expand All @@ -77,6 +102,71 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static UserMappings GetUserMappings(ImmutableArray<(string Path, string Content)> array, CancellationToken token)
{
var mappings = ImmutableArray.CreateBuilder<(string Key, string Namespace, string Method)>();
var diagnostics = ImmutableArray.CreateBuilder<ReportedDiagnostic>();
var keys = new HashSet<string>(StringComparer.Ordinal);
var lineIndex = -1;
var index = 0;

foreach (var (path, content) in array)
{
foreach (Match lineMatch in LineRegex.Matches(content))
{
lineIndex++;

var line = lineMatch.Groups[1].Value;
var newLineLength = lineMatch.Groups[2] is { Success: true, Value: { } val } ? val.Length : 0;

var startIndex = index;
index += line.Length + newLineLength;

token.ThrowIfCancellationRequested();

var separatorIndex = line.IndexOf('=');

if (separatorIndex < 0)
{
// Invalid line, skip it
continue;
}

var key = line[..separatorIndex].Trim();
var value = line[(separatorIndex + 1)..].Trim();

if (string.IsNullOrWhiteSpace(key) || string.IsNullOrWhiteSpace(value))
{
// Invalid key or value, skip it
continue;
}

if (!keys.Add(key))
{
diagnostics.Add(new ReportedDiagnostic(
DuplicateUserMapping,
path,
new TextSpan(startIndex, line.Length),
new LinePositionSpan(
new LinePosition(lineIndex, 0),
new LinePosition(lineIndex, line.Length)),
key));

continue;
}

var methodIndex = value.LastIndexOf('.');
var @namespace = methodIndex < 0 ? string.Empty : value[..methodIndex].Trim();
var methodName = methodIndex < 0 ? value.Trim() : value[(methodIndex + 1)..].Trim();

mappings.Add(("global::" + key, @namespace, methodName));
}
}

var result = new UserMappings(mappings.ToImmutable(), diagnostics.ToImmutable());
return result;
}

private static bool IsSyntaxTargetForGeneration(SyntaxNode node) => node switch
{
MethodDeclarationSyntax { AttributeLists.Count: > 0 } => true,
Expand Down Expand Up @@ -119,7 +209,7 @@ static string BuildClassName(MethodParentDeclaration c)
return (m, sourcePath, source);
}

private static MethodToGenerate? GetMethodToGenerate(GeneratorAttributeSyntaxContext context, MethodDeclarationSyntax methodDeclarationSyntax, bool disableNullable, CancellationToken ct)
private static MethodToGenerate? GetMethodToGenerate(GeneratorAttributeSyntaxContext context, MethodDeclarationSyntax methodDeclarationSyntax, bool disableNullable, UserMappings userMappings, CancellationToken ct)
{
// stop if we're asked to
ct.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -229,12 +319,26 @@ static string BuildClassName(MethodParentDeclaration c)

var preserveProgress = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == PreserveProgress) is { Value.Value: true };

var rewriter = new AsyncToSyncRewriter(context.SemanticModel, disableNullable, preserveProgress);
var rewriter = new AsyncToSyncRewriter(context.SemanticModel, disableNullable, preserveProgress, userMappings);
var sn = rewriter.Visit(methodDeclarationSyntax);
var content = sn.ToFullString();

var diagnostics = rewriter.Diagnostics;

if (userMappings.TryGetValue(methodSymbol, out _))
{
var fullName = methodSymbol.ContainingType is not null
? $"{methodSymbol.ContainingType.ToDisplayString()}.{methodSymbol.Name}"
: methodSymbol.Name;

diagnostics = [
..diagnostics,
ReportedDiagnostic.Create(
AttributeAndUserMappingConflict,
methodDeclarationSyntax.Identifier.GetLocation(),
fullName)];
}

var hasErrors = false;
foreach (var diagnostic in diagnostics)
{
Expand Down Expand Up @@ -265,7 +369,7 @@ static string BuildClassName(MethodParentDeclaration c)
}
}

var result = new MethodToGenerate(index, namespaces.ToImmutable(), isNamespaceFileScoped, classes.ToImmutable(), methodDeclarationSyntax.Identifier.ValueText, content, disableNullable, rewriter.Diagnostics, hasErrors);
var result = new MethodToGenerate(index, namespaces.ToImmutable(), isNamespaceFileScoped, classes.ToImmutable(), methodDeclarationSyntax.Identifier.ValueText, content, disableNullable, diagnostics, hasErrors);

return result;
}
Expand Down
38 changes: 38 additions & 0 deletions src/Zomp.SyncMethodGenerator/UserMappings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
namespace Zomp.SyncMethodGenerator;

internal sealed class UserMappings(
EquatableArray<(string Key, string Namespace, string Method)> mappings,
EquatableArray<ReportedDiagnostic> diagnostics) : IEquatable<UserMappings>
{
[field: MaybeNull]
public IReadOnlyDictionary<string, (string Namespace, string Method)> Mappings
=> field ??= mappings.ToDictionary(kv => kv.Key, kv => (kv.Namespace, kv.Method));

public EquatableArray<ReportedDiagnostic> Diagnostics { get; } = diagnostics;

public bool Equals(UserMappings? other)
{
return other is not null &&
Mappings.Equals(other.Mappings) &&
Diagnostics.Equals(other.Diagnostics);
}

public override bool Equals(object? obj)
{
return obj is not null && (ReferenceEquals(this, obj) || (obj is UserMappings other && Equals(other)));
}

public override int GetHashCode() => HashCode.Combine(Mappings);

public bool TryGetValue(IMethodSymbol ms, out (string Namespace, string Method) value)
{
if (ms.ContainingType is { } containingType)
{
return Mappings.TryGetValue(
containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + "." + ms.Name,
out value);
}

return Mappings.TryGetValue(ms.Name, out value);
}
}
19 changes: 19 additions & 0 deletions tests/GenerationSandbox.Tests/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace GenerationSandbox.Tests;

internal static partial class EnumerableExtensions
{
public static Task<List<TSource>> ToListAsync<TSource>(this IEnumerable<TSource> source)
{
return Task.FromResult(source.ToList());
}

[Zomp.SyncMethodGenerator.CreateSyncVersion]
public static Task<List<TSource>> ReturnListAsync<TSource>(IEnumerable<TSource> source)
{
return ToListAsync(source);
}
}
4 changes: 4 additions & 0 deletions tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Zomp.SyncMethodGenerator\Zomp.SyncMethodGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
<ItemGroup>
<None Remove="SyncMethods.txt" />
<AdditionalFiles Include="SyncMethods.txt" />
</ItemGroup>

</Project>
1 change: 1 addition & 0 deletions tests/GenerationSandbox.Tests/SyncMethods.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
GenerationSandbox.Tests.EnumerableExtensions.ToListAsync=System.Linq.Enumerable.ToList
9 changes: 9 additions & 0 deletions tests/GenerationSandbox.Tests/SyncTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ public void TestGeneratedAverageRoutine()
public void TestStaticAsyncWithIProgress()
=> AsyncWithIProgress.CallWithIProgress();

[Fact]
public void TestReturnListAsync()
{
var expected = new int[] { 1, 2, 3, 4, 5 };
var myNumbers = new int[] { 1, 2, 3, 4, 5 };
var result = EnumerableExtensions.ReturnList(myNumbers);
Assert.Equal(expected, result);
}

#if NET8_0_OR_GREATER
[Fact]
public void TestIndexOfMaxSoFar()
Expand Down
15 changes: 11 additions & 4 deletions tests/Generator.Tests/IncrementalGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,17 @@ public void CheckGeneratorIsIncremental(
driver = driver.RunGenerators(compilation);

var result = driver.GetRunResult().Results.Single();
var sourceOutputs =
result.TrackedOutputSteps.SelectMany(outputStep => outputStep.Value).SelectMany(output => output.Outputs);
var (value, reason) = Assert.Single(sourceOutputs);
Assert.Equal(sourceStepReason, reason);
var trackedOutput = Assert.Single(result.TrackedOutputSteps);

Assert.Equal(2, trackedOutput.Value.Length);

// User mappings are not changed in this test case, so they should always return 'Cached'
var userMappingsOutput = trackedOutput.Value[0].Outputs;
Assert.Equal(IncrementalStepRunReason.Cached, Assert.Single(userMappingsOutput).Reason);
Assert.Equal(IncrementalStepRunReason.Cached, result.TrackedSteps["GetUserMappings"].Single().Outputs[0].Reason);

var sourceOutputs = trackedOutput.Value[1].Outputs;
Assert.Equal(sourceStepReason, Assert.Single(sourceOutputs).Reason);
Assert.Equal(executeStepReason, result.TrackedSteps["GetMethodToGenerate"].Single().Outputs[0].Reason);
Assert.Equal(combineStepReason, result.TrackedSteps["GenerateSource"].Single().Outputs[0].Reason);
}
Expand Down
13 changes: 13 additions & 0 deletions tests/Generator.Tests/MemoryAdditionalText.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Microsoft.CodeAnalysis.Text;

namespace Generator.Tests;

public class MemoryAdditionalText(string path, string text) : AdditionalText
{
public override string Path { get; } = path;

public override SourceText GetText(CancellationToken cancellationToken = default)
{
return SourceText.From(text);
}
}
Loading
Loading