From 48e1da90617391521679a695036ad2917a633299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Luthi?= Date: Fri, 28 Nov 2025 22:28:29 +0100 Subject: [PATCH] Fix rewriter when the sync method is in the same extensions class --- .../AsyncToSyncRewriter.cs | 13 ++++++++++++- .../EntityFrameworkQueryableExtensions.cs | 17 +++++++++++------ tests/Generator.Tests/ExtensionMethodTests.cs | 11 ++++++++--- ...nsions.QueryableExtensionAsync.g.verified.cs | 10 ++++++++-- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index cb5cba5..9ab918b 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -1847,7 +1847,18 @@ private InvocationExpressionSyntax UnwrapExtension(InvocationExpressionSyntax ie var newName = reducedFrom.Name; newName = changeMemoryToSpan ? GetNewName(reducedFrom) : RemoveAsync(newName); - var fullyQualifiedName = $"{MakeType(reducedFrom.ContainingType)}.{newName}"; + var newNameExistsInContainingType = semanticModel.Compilation.References + .Select(semanticModel.Compilation.GetAssemblyOrModuleSymbol) + .Append(semanticModel.Compilation.Assembly) + .OfType() + .Select(assemblySymbol => assemblySymbol.GetTypeByMetadataName(reducedFrom.ContainingType.ToString())) + .OfType() + .SelectMany(symbol => symbol.GetMembers(newName)) + .Any(); + + var fullyQualifiedName = newNameExistsInContainingType + ? $"{reducedFrom.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{newName}" + : $"{MakeType(reducedFrom.ContainingType)}.{newName}"; var es = (ies.Expression switch { diff --git a/tests/GenerationSandbox.Tests/EntityFrameworkQueryableExtensions.cs b/tests/GenerationSandbox.Tests/EntityFrameworkQueryableExtensions.cs index 57495f1..15bc36f 100644 --- a/tests/GenerationSandbox.Tests/EntityFrameworkQueryableExtensions.cs +++ b/tests/GenerationSandbox.Tests/EntityFrameworkQueryableExtensions.cs @@ -1,8 +1,7 @@ -using System.Linq; -using System.Threading; +using System.Threading; using System.Threading.Tasks; -namespace Zomp.SyncMethodGenerator.IntegrationTests +namespace GenerationSandbox.Tests { using Microsoft.EntityFrameworkCore; @@ -14,13 +13,19 @@ public partial class EntityFrameworkQueryableExtensions /// /// Test method. /// - /// The source. + /// The db context. /// The cancellation token. /// The result. [Zomp.SyncMethodGenerator.CreateSyncVersion] - public async Task QueryableExtensionAsync(IQueryable source, CancellationToken cancellationToken) + public async Task QueryableExtensionAsync(DbContext dbContext, CancellationToken cancellationToken) { - return await source.AnyAsync(cancellationToken); + var dbSet = dbContext.Set(); + if (await dbSet.AnyAsync(cancellationToken)) + { + return await dbSet.ExecuteDeleteAsync(cancellationToken); + } + + return 0; } } } diff --git a/tests/Generator.Tests/ExtensionMethodTests.cs b/tests/Generator.Tests/ExtensionMethodTests.cs index 1ec4c22..c7b216b 100644 --- a/tests/Generator.Tests/ExtensionMethodTests.cs +++ b/tests/Generator.Tests/ExtensionMethodTests.cs @@ -121,7 +121,6 @@ public async IAsyncEnumerable WhereLessThan(T threshold) [Fact] public Task EntityFrameworkQueryableExtensions() => """ -using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -132,9 +131,15 @@ namespace Zomp.SyncMethodGenerator.IntegrationTests public partial class EntityFrameworkQueryableExtensions { [Zomp.SyncMethodGenerator.CreateSyncVersion] - public async Task QueryableExtensionAsync(IQueryable source, CancellationToken cancellationToken) + public async Task QueryableExtensionAsync(DbContext dbContext, CancellationToken cancellationToken) { - return await source.AnyAsync(cancellationToken); + var dbSet = dbContext.Set(); + if (await dbSet.AnyAsync(cancellationToken)) + { + return await dbSet.ExecuteDeleteAsync(cancellationToken); + } + + return 0; } } } diff --git a/tests/Generator.Tests/Snapshots/ExtensionMethodTests.EntityFrameworkQueryableExtensions#Zomp.SyncMethodGenerator.IntegrationTests.EntityFrameworkQueryableExtensions.QueryableExtensionAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/ExtensionMethodTests.EntityFrameworkQueryableExtensions#Zomp.SyncMethodGenerator.IntegrationTests.EntityFrameworkQueryableExtensions.QueryableExtensionAsync.g.verified.cs index b79f037..0e1bb7c 100644 --- a/tests/Generator.Tests/Snapshots/ExtensionMethodTests.EntityFrameworkQueryableExtensions#Zomp.SyncMethodGenerator.IntegrationTests.EntityFrameworkQueryableExtensions.QueryableExtensionAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/ExtensionMethodTests.EntityFrameworkQueryableExtensions#Zomp.SyncMethodGenerator.IntegrationTests.EntityFrameworkQueryableExtensions.QueryableExtensionAsync.g.verified.cs @@ -5,9 +5,15 @@ namespace Zomp.SyncMethodGenerator.IntegrationTests { public partial class EntityFrameworkQueryableExtensions { - public bool QueryableExtension(global::System.Linq.IQueryable source) + public int QueryableExtension(global::Microsoft.EntityFrameworkCore.DbContext dbContext) { - return global::System.Linq.Queryable.Any(source); + var dbSet = dbContext.Set(); + if (global::System.Linq.Queryable.Any(dbSet)) + { + return global::Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions.ExecuteDelete(dbSet); + } + + return 0; } } }