diff --git a/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net8.cs b/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net8.cs index e733652..6796c75 100644 --- a/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net8.cs +++ b/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net8.cs @@ -19,7 +19,7 @@ static Offset2RowNumberConvertVisitor() var method = typeof(SelectExpression).GetMethod("GenerateOuterColumn", BindingFlags.NonPublic | BindingFlags.Instance); if (!typeof(ColumnExpression).IsAssignableFrom(method?.ReturnType)) { - throw new InvalidOperationException("SelectExpression.GenerateOuterColum() was not found"); + throw new InvalidOperationException("SelectExpression.GenerateOuterColumn() was not found"); } TableReferenceExpressionType = method.GetParameters().First().ParameterType; @@ -43,7 +43,7 @@ protected override Expression VisitExtension(Expression node) } if (node is SelectExpression se) { - return VisitSelect(se); + node = VisitSelect(se); } return base.VisitExtension(node); } diff --git a/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net9.cs b/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net9.cs index b5bda52..479b8a8 100644 --- a/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net9.cs +++ b/EntityFrameworkCore.UseRowNumberForPaging/Offset2RowNumberConvertVisitor.net9.cs @@ -1,5 +1,5 @@ #if NET9_0_OR_GREATER -using System.Collections.Generic; +using System; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -14,105 +14,88 @@ internal class Offset2RowNumberConvertVisitor( SqlAliasManager sqlAliasManager ) : ExpressionVisitor { + private static readonly MethodInfo GenerateOuterColumnAccessor; + + static Offset2RowNumberConvertVisitor() + { + var method = typeof(SelectExpression).GetMethod("GenerateOuterColumn", BindingFlags.NonPublic | BindingFlags.Instance); + if (!typeof(ColumnExpression).IsAssignableFrom(method?.ReturnType)) + { + throw new InvalidOperationException("SelectExpression.GenerateOuterColumn() was not found"); + } + GenerateOuterColumnAccessor = method; + } + + private readonly Expression root = root; private readonly ISqlExpressionFactory sqlExpressionFactory = sqlExpressionFactory; private readonly SqlAliasManager sqlAliasManager = sqlAliasManager; - protected override Expression VisitExtension(Expression node) => node switch - { - ShapedQueryExpression shapedQueryExpression => shapedQueryExpression.Update(Visit(shapedQueryExpression.QueryExpression), Visit(shapedQueryExpression.ShaperExpression)), - SelectExpression se => VisitSelect(se), - _ => base.VisitExtension(node), - }; - - private SelectExpression VisitSelect(SelectExpression selectExpression) + protected override Expression VisitExtension(Expression node) { - // if we have no offset, we do not need to use ROW_NUMBER for offset calculations - if (selectExpression.Offset == null) + if (node is ShapedQueryExpression shapedQueryExpression) { - return selectExpression; + return shapedQueryExpression.Update(Visit(shapedQueryExpression.QueryExpression), Visit(shapedQueryExpression.ShaperExpression)); } - var isRootQuery = selectExpression == root; + if (node is SelectExpression se) + { + node = VisitSelect(se); + } + return base.VisitExtension(node); + } - // store offset, limit and orderings + private Expression VisitSelect(SelectExpression selectExpression) + { var oldOffset = selectExpression.Offset; + if (oldOffset == null) + return selectExpression; var oldLimit = selectExpression.Limit; var oldOrderings = selectExpression.Orderings; + var newOrderings = oldOrderings.Count > 0 && (oldLimit != null || selectExpression == root) + ? oldOrderings.ToList() + : []; + // Change SelectExpression + selectExpression = selectExpression.Update(projections: selectExpression.Projection.ToList(), + tables: selectExpression.Tables.ToList(), + predicate: selectExpression.Predicate, + groupBy: selectExpression.GroupBy.ToList(), + having: selectExpression.Having, + orderings: newOrderings, + limit: null, + offset: null); + var rowOrderings = oldOrderings.Count != 0 ? oldOrderings + : [new OrderingExpression(new SqlFragmentExpression("(SELECT 1)"), true)]; - // remove offset and limit by creating new select expression from old one - // we can't use SelectExpression.Update because that breaks PushDownIntoSubquery - var enhancedSelect = new SelectExpression( - alias: null, - tables: new(selectExpression.Tables), - predicate: selectExpression.Predicate, - groupBy: new(selectExpression.GroupBy), - having: selectExpression.Having, - projections: new(selectExpression.Projection), - distinct: selectExpression.IsDistinct, - orderings: isRootQuery ? [] : new(selectExpression.Orderings), - offset: null, - limit: null, - tags: selectExpression.Tags, - annotations: null, - sqlAliasManager: sqlAliasManager, - isMutable: true - ); - // set up row_number expression - var rowNumber = new RowNumberExpression([], isRootQuery ? [ new(new SqlFragmentExpression("(SELECT 1)"), true) ] : oldOrderings, oldOffset.TypeMapping); - enhancedSelect.AddToProjection(rowNumber); - enhancedSelect.PushdownIntoSubquery(); + // restore sql alias manager in updated expression + typeof(SelectExpression) + .GetField("_sqlAliasManager", BindingFlags.Instance | BindingFlags.NonPublic) + .SetValue(selectExpression, sqlAliasManager); - // restore ordering to outer select after earlier removal - if (isRootQuery) - { - foreach (var orderingClause in oldOrderings) - { - selectExpression.AppendOrdering(orderingClause); - } - } + selectExpression.PushdownIntoSubquery(); - // generate subselect rownumber access expression - var innerTable = enhancedSelect.Tables[0]; - var rowNumberColname = enhancedSelect.Projection[enhancedSelect.Projection.Count - 1].Alias; - var rowNumberAlias = enhancedSelect.CreateColumnExpression(innerTable, rowNumberColname, typeof(int), null, false); + var subQuery = (SelectExpression)selectExpression.Tables[0]; + var projection = new RowNumberExpression([], rowOrderings, oldOffset.TypeMapping); + var left = GenerateOuterColumnAccessor.Invoke( + subQuery, + [ + subQuery.Alias, + projection, + sqlAliasManager.GenerateTableAlias("row"), + ]) as ColumnExpression; + selectExpression.ApplyPredicate(sqlExpressionFactory.GreaterThan(left!, oldOffset)); - // apply offset and limit - var rowNumberGtOffset = sqlExpressionFactory.GreaterThan(rowNumberAlias, oldOffset); - enhancedSelect.ApplyPredicate(rowNumberGtOffset); if (oldLimit != null) { if (oldOrderings.Count == 0) { - var rowNumberLimiting = sqlExpressionFactory.LessThanOrEqual(rowNumberAlias, sqlExpressionFactory.Add(oldOffset, oldLimit)); - enhancedSelect.ApplyPredicate(rowNumberLimiting); + selectExpression.ApplyPredicate(sqlExpressionFactory.LessThanOrEqual(left, sqlExpressionFactory.Add(oldOffset, oldLimit))); } else { - enhancedSelect.ApplyLimit(oldLimit); + selectExpression.ApplyLimit(oldLimit); } } - - enhancedSelect.ApplyProjection(); // to make immutable - var restoredProjections = enhancedSelect.Projection - .Where(p => p.Alias != rowNumberColname) - .ToList(); - var result = enhancedSelect.Update( - enhancedSelect.Tables, - enhancedSelect.Predicate, - enhancedSelect.GroupBy, - enhancedSelect.Having, - restoredProjections, - enhancedSelect.Orderings, - enhancedSelect.Offset, - enhancedSelect.Limit - ); - - // restore projection member binding lookup capabilities via reflection magic - var clientProjections = typeof(SelectExpression).GetField("_clientProjections", BindingFlags.NonPublic | BindingFlags.Instance); - clientProjections.SetValue(result, clientProjections.GetValue(selectExpression)); - var projectionMapping = typeof(SelectExpression).GetField("_projectionMapping", BindingFlags.NonPublic | BindingFlags.Instance); - projectionMapping.SetValue(result, projectionMapping.GetValue(selectExpression)); - return result; + return selectExpression; } } #endif \ No newline at end of file