From a02f364edc2ed1eae022f88b0e0214842cc77c4a Mon Sep 17 00:00:00 2001 From: ouyuanning <52741323@qq.com> Date: Thu, 11 Jun 2026 12:03:00 +0800 Subject: [PATCH 1/2] fix(frontend): improve role rule merge and set role behavior --- pkg/frontend/authenticate.go | 32 +- pkg/frontend/authenticate_test.go | 56 +++ pkg/frontend/rewrite_rule.go | 361 +++++++--------- pkg/frontend/rewrite_rule_test.go | 395 +++++------------- .../cases/tenant/privilege/owner1.result | 4 +- .../cases/tenant/privilege/owner1.sql | 3 +- .../cases/zz_accesscontrol/role_rule.result | 28 +- .../cases/zz_accesscontrol/role_rule.sql | 28 +- .../cases/zz_accesscontrol/use_role.result | 26 +- .../cases/zz_accesscontrol/use_role.sql | 18 +- 10 files changed, 390 insertions(+), 561 deletions(-) diff --git a/pkg/frontend/authenticate.go b/pkg/frontend/authenticate.go index 5444cd80b4daf..cc2e3fed2a533 100644 --- a/pkg/frontend/authenticate.go +++ b/pkg/frontend/authenticate.go @@ -3498,22 +3498,16 @@ func doAlterAccount(ctx context.Context, ses *Session, aa *alterAccount) (err er return err } -// doSetSecondaryRoleAll set the session role of the user with smallness role_id +// doSetSecondaryRoleAll validates user role metadata before enabling all secondary roles. +// The current primary role must not change; SET SECONDARY ROLE ALL only affects secondary roles. func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) { var sql string var userId uint32 - var erArray []ExecResult - var roleId int64 - var roleName string account := ses.GetTenantInfo() // get current user_id userId = account.GetUserID() - // init role_id and role_name - roleId = publicRoleID - roleName = publicRoleName - // step1:get all roles expect public bh := ses.GetBackgroundExec(ctx) defer bh.Close() @@ -3533,26 +3527,7 @@ func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) { return err } - erArray, err = getResultSet(ctx, bh) - if err != nil { - return err - } - if execResultArrayHasData(erArray) { - roleId, err = erArray[0].GetInt64(ctx, 0, 0) - if err != nil { - return err - } - - roleName, err = erArray[0].GetString(ctx, 0, 1) - if err != nil { - return err - } - } - - // step2 : switch the default role and role id; - account.SetDefaultRoleID(uint32(roleId)) - account.SetDefaultRole(roleName) - + _, err = getResultSet(ctx, bh) return err } @@ -3649,6 +3624,7 @@ func doSwitchRole(ctx context.Context, ses *Session, sr *tree.SetRole) (err erro account.SetDefaultRole(sr.Role.UserName) // then, reset secondary role to none account.SetUseSecondaryRole(false) + ses.InvalidatePrivilegeCache() return err } diff --git a/pkg/frontend/authenticate_test.go b/pkg/frontend/authenticate_test.go index 0ff6484679a35..4bf2c738d32c2 100644 --- a/pkg/frontend/authenticate_test.go +++ b/pkg/frontend/authenticate_test.go @@ -8185,6 +8185,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) { err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses) convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") }) convey.Convey("do set secondary role succ", t, func() { @@ -8224,6 +8226,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) { err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses) convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") }) } @@ -8263,6 +8267,8 @@ func TestDoSwitchRoleSecondaryRoleAllInvalidatesRuleCache(t *testing.T) { }) convey.So(err, convey.ShouldBeNil) convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeTrue) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") ses.ruleCacheMu.RLock() cacheIsNil := ses.ruleCache == nil @@ -8337,6 +8343,56 @@ func TestDoSwitchRoleSecondaryRoleNoneInvalidatesRuleCache(t *testing.T) { }) } +func TestDoSwitchRolePrimaryRoleInvalidatesRuleCache(t *testing.T) { + convey.Convey("set role switches current role and invalidates rule cache", t, func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + bh := &backgroundExecTest{} + bh.init() + + bhStub := gostub.StubFunc(&NewBackgroundExec, bh) + defer bhStub.Reset() + + ses := newSes(&privilege{}, ctrl) + tenant := &TenantInfo{ + Tenant: "test_account", + User: "test_user", + DefaultRole: "role1", + TenantID: 3001, + UserID: 3, + DefaultRoleID: 5, + } + tenant.SetUseSecondaryRole(true) + ses.SetTenantInfo(tenant) + ses.ruleCache = map[string]string{"db1.t1": "select a from db1.t1"} + + bh.sql2result["begin;"] = nil + bh.sql2result["commit;"] = nil + bh.sql2result["rollback;"] = nil + + roleSQL, err := getSqlForRoleIdOfRole(context.Background(), "role2") + convey.So(err, convey.ShouldBeNil) + bh.sql2result[roleSQL] = newMrsForRoleIdOfRole([][]interface{}{{int64(6)}}) + bh.sql2result[getSqlForCheckUserGrant(6, int64(tenant.UserID))] = newMrsForCheckUserGrant([][]interface{}{ + {int64(6), int64(tenant.UserID), false}, + }) + + err = doSwitchRole(ses.GetTxnHandler().GetTxnCtx(), ses, &tree.SetRole{ + Role: &tree.Role{UserName: "role2"}, + }) + convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(6)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role2") + convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeFalse) + + ses.ruleCacheMu.RLock() + cacheIsNil := ses.ruleCache == nil + ses.ruleCacheMu.RUnlock() + convey.So(cacheIsNil, convey.ShouldBeTrue) + }) +} + func TestGetSessionSysVar(t *testing.T) { convey.Convey("get session system variable succ", t, func() { ctrl := gomock.NewController(t) diff --git a/pkg/frontend/rewrite_rule.go b/pkg/frontend/rewrite_rule.go index b396c33867b2f..a311d4a98414e 100644 --- a/pkg/frontend/rewrite_rule.go +++ b/pkg/frontend/rewrite_rule.go @@ -30,7 +30,6 @@ import ( "github.com/matrixorigin/matrixone/pkg/sql/parsers" "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" - "github.com/matrixorigin/matrixone/pkg/sql/plan/function" ) const ( @@ -307,30 +306,24 @@ func getSqlForRoleRulesOfRoleIDs(roleIDs []int64) string { } func mergeRewriteRules(ctx context.Context, leftRule, rightRule string) (string, error) { - leftRule = trimRewriteRuleForUnion(leftRule) - rightRule = trimRewriteRuleForUnion(rightRule) + leftRule = trimRewriteRuleForMerge(leftRule) + rightRule = trimRewriteRuleForMerge(rightRule) - leftColumns, ok, err := rewriteRuleOutputColumns(ctx, leftRule) - if err != nil { - return "", err - } - if !ok { - return rightRule, nil + if leftRule == rightRule { + return leftRule, nil } - rightColumns, ok, err := rewriteRuleOutputColumns(ctx, rightRule) + + mergedRule, ok, err := mergeRewriteRulesSafely(ctx, leftRule, rightRule) if err != nil { return "", err } if !ok { return rightRule, nil } - if !sameRewriteOutputColumns(leftColumns, rightColumns) { - return rightRule, nil - } - return fmt.Sprintf("(%s) union distinct (%s)", leftRule, rightRule), nil + return mergedRule, nil } -func trimRewriteRuleForUnion(rule string) string { +func trimRewriteRuleForMerge(rule string) string { rule = strings.TrimSpace(rule) for strings.HasSuffix(rule, ";") { rule = strings.TrimSpace(strings.TrimSuffix(rule, ";")) @@ -338,265 +331,197 @@ func trimRewriteRuleForUnion(rule string) string { return rule } -type rewriteRuleOutputColumn struct { - name string - expr string +type rewriteRuleMergeShape struct { + stmt *tree.Select + clause *tree.SelectClause + selectList string + table string } -func rewriteRuleOutputColumns(ctx context.Context, rule string) ([]rewriteRuleOutputColumn, bool, error) { - stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) +// Rewrite rules are merged by OR-ing filters only for simple row-preserving +// single-table SELECTs with exactly matching select lists and source tables. +func mergeRewriteRulesSafely(ctx context.Context, leftRule, rightRule string) (string, bool, error) { + leftShape, ok, err := rewriteRuleMergeShapeForRule(ctx, leftRule) if err != nil { - return nil, false, moerr.NewInternalErrorf(ctx, "failed to parse rewrite rule %q while merging rewrite rules: %v", rule, err) + return "", false, err } - columns, ok := outputColumnsFromRewriteStatement(stmt) - return columns, ok, nil -} - -func validateRewriteRuleSQL(ctx context.Context, rule string) error { - if strings.TrimSpace(rule) == "" { - return moerr.NewInvalidInput(ctx, "rewrite rule SQL is empty") + if !ok { + return "", false, nil } - stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + rightShape, ok, err := rewriteRuleMergeShapeForRule(ctx, rightRule) if err != nil { - return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: %v", rule, err) + return "", false, err + } + if !ok { + return "", false, nil } - switch stmt.(type) { - case *tree.Select, *tree.ParenSelect: - return nil - default: - return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: only accept SELECT-like statements as rewrites", rule) + if leftShape.selectList != rightShape.selectList || leftShape.table != rightShape.table { + return "", false, nil + } + + merged := *leftShape.stmt + mergedClause := *leftShape.clause + mergedClause.Where = mergeRewriteRuleWhere(leftShape.clause.Where, rightShape.clause.Where) + merged.Select = &mergedClause + + return tree.String(&merged, dialect.MYSQL), true, nil +} + +func rewriteRuleMergeShapeForRule(ctx context.Context, rule string) (*rewriteRuleMergeShape, bool, error) { + stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + if err != nil { + return nil, false, moerr.NewInternalErrorf(ctx, "failed to parse rewrite rule %q while merging rewrite rules: %v", rule, err) } + shape, ok := rewriteRuleMergeShapeForStatement(stmt) + return shape, ok, nil } -func outputColumnsFromRewriteStatement(stmt tree.Statement) ([]rewriteRuleOutputColumn, bool) { +func rewriteRuleMergeShapeForStatement(stmt tree.Statement) (*rewriteRuleMergeShape, bool) { switch s := stmt.(type) { case *tree.Select: - if len(s.OrderBy) > 0 || s.Limit != nil { - return nil, false - } - return outputColumnsFromRewriteSelectStatement(s.Select) + return rewriteRuleMergeShapeForSelect(s) case *tree.ParenSelect: if s.Select == nil { return nil, false } - return outputColumnsFromRewriteStatement(s.Select) + return rewriteRuleMergeShapeForStatement(s.Select) default: return nil, false } } -func outputColumnsFromRewriteSelectStatement(stmt tree.SelectStatement) ([]rewriteRuleOutputColumn, bool) { - switch s := stmt.(type) { - case *tree.SelectClause: - if !mergeableRewriteSelectClause(s) { - return nil, false - } - columns := make([]rewriteRuleOutputColumn, 0, len(s.Exprs)) - for _, expr := range s.Exprs { - column, ok := outputColumnFromRewriteSelectExpr(expr) - if !ok { - return nil, false - } - columns = append(columns, column) - } - return columns, true - case *tree.UnionClause: - return outputColumnsFromRewriteSelectStatement(s.Left) - case *tree.ParenSelect: - if s.Select == nil { - return nil, false - } - return outputColumnsFromRewriteStatement(s.Select) - case *tree.Select: - if len(s.OrderBy) > 0 || s.Limit != nil { - return nil, false - } - return outputColumnsFromRewriteSelectStatement(s.Select) - default: +func rewriteRuleMergeShapeForSelect(stmt *tree.Select) (*rewriteRuleMergeShape, bool) { + if stmt == nil || len(stmt.OrderBy) > 0 || stmt.Limit != nil || stmt.With != nil || + stmt.TimeWindow != nil || stmt.RankOption != nil || stmt.Ep != nil || stmt.SelectLockInfo != nil { return nil, false } -} -func mergeableRewriteSelectClause(stmt *tree.SelectClause) bool { - if stmt.Distinct || stmt.Option&(tree.QuerySpecOptionDistinct|tree.QuerySpecOptionDistinctRow) != 0 { - return false + clause, ok := stmt.Select.(*tree.SelectClause) + if !ok || !rewriteRuleSelectClauseIsMergeable(clause) { + return nil, false } - if stmt.GroupBy != nil || stmt.Having != nil { - return false + + table, ok := rewriteRuleSingleTableSource(clause.From) + if !ok { + return nil, false } - for _, expr := range stmt.Exprs { - if !rewriteExprIsMergeSafe(expr.Expr) { + + return &rewriteRuleMergeShape{ + stmt: stmt, + clause: clause, + selectList: normalizeRewriteSQL(tree.String(&clause.Exprs, dialect.MYSQL)), + table: table, + }, true +} + +func rewriteRuleSelectClauseIsMergeable(clause *tree.SelectClause) bool { + return clause != nil && + !clause.Distinct && + clause.Option == 0 && + clause.GroupBy == nil && + clause.Having == nil && + rewriteRuleSelectExprsAreMergeable(clause.Exprs) +} + +func rewriteRuleSelectExprsAreMergeable(exprs tree.SelectExprs) bool { + for _, expr := range exprs { + switch e := expr.Expr.(type) { + case tree.UnqualifiedStar: + if expr.As != nil && !expr.As.Empty() { + return false + } + continue + case *tree.UnresolvedName: + if e.Star && expr.As != nil && !expr.As.Empty() { + return false + } + if e.Star || e.NumParts > 0 { + continue + } + return false + default: return false } } return true } -func outputColumnFromRewriteSelectExpr(expr tree.SelectExpr) (rewriteRuleOutputColumn, bool) { - if expr.As != nil && !expr.As.Empty() { - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(expr.As.Compare()), - expr: normalizeRewriteOutputExpr(expr.Expr), - }, true +func rewriteRuleSingleTableSource(from *tree.From) (string, bool) { + if from == nil || len(from.Tables) != 1 { + return "", false } - switch e := expr.Expr.(type) { - case tree.UnqualifiedStar: - return rewriteRuleOutputColumn{}, false - case *tree.UnresolvedName: - if e.Star { - return rewriteRuleOutputColumn{}, false - } - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(e.ColName()), - expr: normalizeRewriteOutputColumn(tree.String(expr.Expr, dialect.MYSQL)), - }, true - default: - exprText := normalizeRewriteOutputExpr(expr.Expr) - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(tree.String(expr.Expr, dialect.MYSQL)), - expr: exprText, - }, true + tableExpr, ok := rewriteRuleSingleTableExpr(from.Tables[0]) + if !ok { + return "", false } -} -func normalizeRewriteOutputColumn(column string) string { - return strings.ToLower(strings.TrimSpace(column)) + return normalizeRewriteSQL(tree.String(tableExpr, dialect.MYSQL)), true } -func normalizeRewriteOutputExpr(expr tree.Expr) string { - if _, ok := expr.(*tree.UnresolvedName); ok { - return normalizeRewriteOutputColumn(tree.String(expr, dialect.MYSQL)) +func rewriteRuleSingleTableExpr(expr tree.TableExpr) (*tree.AliasedTableExpr, bool) { + if joinExpr, ok := expr.(*tree.JoinTableExpr); ok { + // The MySQL parser wraps a lone table reference as a CROSS JoinTableExpr. + if joinExpr.Right != nil || joinExpr.Cond != nil || joinExpr.Option != "" { + return nil, false + } + expr = joinExpr.Left } - return strings.TrimSpace(tree.String(expr, dialect.MYSQL)) -} -func sameRewriteOutputColumns(leftColumns, rightColumns []rewriteRuleOutputColumn) bool { - if len(leftColumns) != len(rightColumns) { - return false + tableExpr, ok := expr.(*tree.AliasedTableExpr) + if !ok || tableExpr.As.Cols != nil || len(tableExpr.IndexHints) > 0 { + return nil, false } - for i := range leftColumns { - if leftColumns[i] != rightColumns[i] { - return false - } + + tableName, ok := tableExpr.Expr.(*tree.TableName) + if !ok || tableName.AtTsExpr != nil { + return nil, false } - return true + + return tableExpr, true } -func rewriteExprIsMergeSafe(expr tree.Expr) bool { - if expr == nil { - return true - } - - switch e := expr.(type) { - case *tree.UnresolvedName, tree.UnqualifiedStar, *tree.NumVal, *tree.StrVal: - return true - case *tree.BinaryExpr: - return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) - case *tree.UnaryExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.ComparisonExpr: - return rewriteExprIsMergeSafe(e.Left) && - rewriteExprIsMergeSafe(e.Right) && - rewriteExprIsMergeSafe(e.Escape) - case *tree.AndExpr: - return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) - case *tree.XorExpr: - return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) - case *tree.OrExpr: - return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) - case *tree.NotExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsNullExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsNotNullExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsUnknownExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsNotUnknownExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsTrueExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsNotTrueExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsFalseExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.IsNotFalseExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.ParenExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.FuncExpr: - name := rewriteFuncExprName(e) - if name == "" || e.Type == tree.FUNC_TYPE_TABLE || - e.WindowSpec != nil || - function.GetFunctionIsAggregateByName(name) || - function.GetFunctionIsWinFunByName(name) { - return false - } - return rewriteExprsAreMergeSafe(e.Exprs) && rewriteOrderByIsMergeSafe(e.OrderBy) - case *tree.SerialExtractExpr: - return rewriteExprIsMergeSafe(e.SerialExpr) && rewriteExprIsMergeSafe(e.IndexExpr) - case *tree.CastExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.BitCastExpr: - return rewriteExprIsMergeSafe(e.Expr) - case *tree.Tuple: - return rewriteExprsAreMergeSafe(e.Exprs) - case *tree.RangeCond: - return rewriteExprIsMergeSafe(e.Left) && - rewriteExprIsMergeSafe(e.From) && - rewriteExprIsMergeSafe(e.To) - case *tree.CaseExpr: - if !rewriteExprIsMergeSafe(e.Expr) || !rewriteExprIsMergeSafe(e.Else) { - return false - } - for _, when := range e.Whens { - if when == nil { - continue - } - if !rewriteExprIsMergeSafe(when.Cond) || !rewriteExprIsMergeSafe(when.Val) { - return false - } - } - return true - case *tree.IntervalExpr: - return rewriteExprIsMergeSafe(e.Expr) +func mergeRewriteRuleWhere(left, right *tree.Where) *tree.Where { + // Mergeability is decided from the top-level SELECT shape only. Predicate + // internals, including subqueries, stay opaque and are only OR-ed together. + switch { + case left == nil || left.Expr == nil: + return nil + case right == nil || right.Expr == nil: + return nil default: - return false + return &tree.Where{ + Type: tree.AstWhere, + Expr: tree.NewOrExpr( + tree.NewParentExpr(left.Expr), + tree.NewParentExpr(right.Expr), + ), + } } } -func rewriteExprsAreMergeSafe(exprs tree.Exprs) bool { - for _, expr := range exprs { - if !rewriteExprIsMergeSafe(expr) { - return false - } - } - return true +func normalizeRewriteSQL(sql string) string { + return strings.ToLower(strings.TrimSpace(sql)) } -func rewriteOrderByIsMergeSafe(orderBy tree.OrderBy) bool { - for _, order := range orderBy { - if order == nil { - continue - } - if !rewriteExprIsMergeSafe(order.Expr) { - return false - } +func validateRewriteRuleSQL(ctx context.Context, rule string) error { + if strings.TrimSpace(rule) == "" { + return moerr.NewInvalidInput(ctx, "rewrite rule SQL is empty") } - return true -} -func rewriteFuncExprName(fn *tree.FuncExpr) string { - if fn.FuncName != nil { - return strings.ToLower(fn.FuncName.Origin()) + stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + if err != nil { + return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: %v", rule, err) } - if name, ok := fn.Func.FunctionReference.(*tree.UnresolvedName); ok { - return strings.ToLower(name.ColName()) + + switch stmt.(type) { + case *tree.Select, *tree.ParenSelect: + return nil + default: + return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: only accept SELECT-like statements as rewrites", rule) } - return "" } // escapeSQLString escapes a string for safe use in SQL literals using writeEscapedSQLString. diff --git a/pkg/frontend/rewrite_rule_test.go b/pkg/frontend/rewrite_rule_test.go index b22b33088551d..3607ea00097b0 100644 --- a/pkg/frontend/rewrite_rule_test.go +++ b/pkg/frontend/rewrite_rule_test.go @@ -18,7 +18,6 @@ import ( "context" "math/rand" "reflect" - "strconv" "strings" "testing" "testing/quick" @@ -29,8 +28,6 @@ import ( "github.com/stretchr/testify/require" "github.com/matrixorigin/matrixone/pkg/defines" - "github.com/matrixorigin/matrixone/pkg/sql/parsers" - "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" ) @@ -475,7 +472,7 @@ func TestLoadRuleCacheIncludesSecondaryRoles(t *testing.T) { rules, err := loadRuleCache(context.Background(), ses) require.NoError(t, err) require.Equal(t, map[string]string{ - "db1.t1": "(select a, age from db1.t1 where age < 3) union distinct (select A, Age from db1.t1 where age > 28)", + "db1.t1": "select a, age from db1.t1 where (age < 3) or (age > 28)", "db2.t2": "select a from db2.t2 where a = 20", }, rules) } @@ -687,11 +684,11 @@ func TestMergeRewriteRules(t *testing.T) { right := "select a from db1.t1 where a = 2" merged, err := mergeRewriteRules(ctx, left, right) require.NoError(t, err) - require.Equal(t, "(select a from db1.t1 where a = 1) union distinct (select a from db1.t1 where a = 2)", merged) + require.Equal(t, "select a from db1.t1 where (a = 1) or (a = 2)", merged) merged, err = mergeRewriteRules(ctx, merged, "select a from db1.t1 where a = 3") require.NoError(t, err) - require.Equal(t, "((select a from db1.t1 where a = 1) union distinct (select a from db1.t1 where a = 2)) union distinct (select a from db1.t1 where a = 3)", merged) + require.Equal(t, "select a from db1.t1 where ((a = 1) or (a = 2)) or (a = 3)", merged) merged, err = mergeRewriteRules(ctx, "select a, age from db1.t1 where age > 28", "select a from db1.t1 where a = 2") require.NoError(t, err) @@ -703,14 +700,14 @@ func TestMergeRewriteRules(t *testing.T) { merged, err = mergeRewriteRules(ctx, "select * from db1.t1 where age > 28", "select * from db1.t1 where age < 3") require.NoError(t, err) require.Equal(t, - "select * from db1.t1 where age < 3", + "select * from db1.t1 where (age > 28) or (age < 3)", merged, ) merged, err = mergeRewriteRules(ctx, "select t.* from db1.t1 as t where age > 28", "select t.* from db1.t1 as t where age < 3") require.NoError(t, err) require.Equal(t, - "select t.* from db1.t1 as t where age < 3", + "select t.* from db1.t1 as t where (age > 28) or (age < 3)", merged, ) @@ -775,11 +772,11 @@ func TestMergeRewriteRules(t *testing.T) { merged, err = mergeRewriteRules(ctx, "select a as x from db1.t1 where age > 28", "select a as x from db1.t1 where age < 3") require.NoError(t, err) - require.Equal(t, "(select a as x from db1.t1 where age > 28) union distinct (select a as x from db1.t1 where age < 3)", merged) + require.Equal(t, "select a as x from db1.t1 where (age > 28) or (age < 3)", merged) merged, err = mergeRewriteRules(ctx, "select a + 1 as b from db1.t1 where age > 28", "select a + 1 as b from db1.t1 where age < 3") require.NoError(t, err) - require.Equal(t, "(select a + 1 as b from db1.t1 where age > 28) union distinct (select a + 1 as b from db1.t1 where age < 3)", merged) + require.Equal(t, "select a + 1 as b from db1.t1 where age < 3", merged) _, err = mergeRewriteRules(ctx, "select a from", "select a from db1.t1") require.Error(t, err) @@ -829,303 +826,131 @@ func TestMergeRewriteRulesFallbackWhenEitherSideIsUnmergeable(t *testing.T) { } } -func TestRewriteRuleOutputColumns(t *testing.T) { - ctx := context.Background() - - columns, ok, err := rewriteRuleOutputColumns(ctx, "select A as x, a + 1 as b from db1.t1") - require.NoError(t, err) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{ - {name: "x", expr: "a"}, - {name: "b", expr: "a + 1"}, - }, columns) - - cases := []struct { - name string - rule string - }{ - {name: "top-level order by", rule: "select a from db1.t1 order by a"}, - {name: "top-level limit", rule: "select a from db1.t1 limit 1"}, - {name: "union order by", rule: "select a from db1.t1 union select a from db1.t2 order by a"}, - {name: "distinct", rule: "select distinct a from db1.t1"}, - {name: "group by", rule: "select a from db1.t1 group by a"}, - {name: "having", rule: "select a from db1.t1 having a > 1"}, - {name: "star", rule: "select * from db1.t1"}, - {name: "qualified star", rule: "select t.* from db1.t1 as t"}, - {name: "aggregate projection", rule: "select sum(a) as s from db1.t1"}, - {name: "window projection", rule: "select row_number() over () as rn from db1.t1"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - columns, ok, err := rewriteRuleOutputColumns(ctx, tc.rule) - require.NoError(t, err) - require.False(t, ok) - require.Nil(t, columns) - }) - } -} - -func TestTrimRewriteRuleForUnion(t *testing.T) { - require.Equal(t, "select a from db1.t1", trimRewriteRuleForUnion(" select a from db1.t1 ;; ; ")) - require.Equal(t, "select a from db1.t1", trimRewriteRuleForUnion("select a from db1.t1")) - require.Equal(t, "", trimRewriteRuleForUnion(" ; ; ")) -} - -func TestOutputColumnsFromRewriteStatementASTBranches(t *testing.T) { - selectClause := &tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - } - - columns, ok := outputColumnsFromRewriteStatement(&tree.ParenSelect{ - Select: &tree.Select{Select: selectClause}, - }) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteStatement(&tree.ParenSelect{}) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteStatement(&tree.Delete{}) - require.False(t, ok) - require.Nil(t, columns) -} - -func TestOutputColumnsFromRewriteSelectStatementASTBranches(t *testing.T) { - selectClause := &tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - } - - columns, ok := outputColumnsFromRewriteSelectStatement(&tree.UnionClause{ - Left: selectClause, - Right: &tree.SelectClause{Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("b")}}}, - }) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.ParenSelect{}) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.Select{ - Select: selectClause, - OrderBy: tree.OrderBy{&tree.Order{Expr: tree.NewUnresolvedColName("a")}}, - }) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.Select{Select: selectClause}) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.ValuesStatement{}) - require.False(t, ok) - require.Nil(t, columns) -} - -func TestOutputColumnFromRewriteSelectExprBranches(t *testing.T) { - column, ok := outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewBinaryExpr(tree.PLUS, tree.NewUnresolvedColName("a"), rewriteRuleNumVal(1)), - }) - require.True(t, ok) - require.Equal(t, rewriteRuleOutputColumn{name: "a + 1", expr: "a + 1"}, column) - - column, ok = outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewUnresolvedColName("A"), - As: tree.NewCStr("AliasA", 1), - }) - require.True(t, ok) - require.Equal(t, rewriteRuleOutputColumn{name: "aliasa", expr: "a"}, column) - - column, ok = outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewUnresolvedNameWithStar(tree.NewCStr("t", 1)), - }) - require.False(t, ok) - require.Equal(t, rewriteRuleOutputColumn{}, column) -} - -func TestMergeableRewriteSelectClauseRejectsDistinctOptions(t *testing.T) { - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Distinct: true})) - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Option: tree.QuerySpecOptionDistinct})) - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Option: tree.QuerySpecOptionDistinctRow})) +func TestTrimRewriteRuleForMerge(t *testing.T) { + require.Equal(t, "select a from db1.t1", trimRewriteRuleForMerge(" select a from db1.t1 ;; ; ")) + require.Equal(t, "select a from db1.t1", trimRewriteRuleForMerge("select a from db1.t1")) + require.Equal(t, "", trimRewriteRuleForMerge(" ; ; ")) } -func TestSameRewriteOutputColumnsComparesNamesAndExpressions(t *testing.T) { +func TestRewriteRuleMergeShapeForRule(t *testing.T) { cases := []struct { - name string - left string - right string - same bool + name string + rule string + ok bool + selectList string + table string }{ { - name: "same aliases and same column expression", - left: "select A as x, age as y from db1.t1", - right: "select a as x, age as y from db1.t1", - same: true, + name: "simple select", + rule: "select A, Age from db1.t1 where age > 28", + ok: true, + selectList: "a, age", + table: "db1.t1", + }, + { + name: "star projection", + rule: "select * from db1.t1 where age > 28", + ok: true, + selectList: "*", + table: "db1.t1", + }, + { + name: "qualified star with alias", + rule: "select t.* from db1.t1 as t where age > 28", + ok: true, + selectList: "t.*", + table: "db1.t1 as t", }, { - name: "same aliases but swapped column expressions", - left: "select a as x, age as y from db1.t1", - right: "select age as x, a as y from db1.t1", - same: false, + name: "aliased column", + rule: "select a as x from db1.t1 where age > 28", + ok: true, + selectList: "a as x", + table: "db1.t1", }, { - name: "same aliases but different scalar expressions", - left: "select a + 1 as x from db1.t1", - right: "select a + 2 as x from db1.t1", - same: false, + name: "where subquery with order by and limit", + rule: "select a from db1.t1 where a in (select a from db1.t2 order by a limit 1)", + ok: true, + selectList: "a", + table: "db1.t1", }, { - name: "same expressions but different aliases", - left: "select a + 1 as x from db1.t1", - right: "select a + 1 as y from db1.t1", - same: false, + name: "parenthesized single table", + rule: "select a from (db1.t1) where age > 28", + ok: true, + selectList: "a", + table: "db1.t1", + }, + { + name: "scalar expression", + rule: "select a + 1 from db1.t1 where age > 28", + ok: false, + }, + { + name: "aggregate", + rule: "select count(*) from db1.t1 where age > 28", + ok: false, + }, + { + name: "window", + rule: "select row_number() over () from db1.t1 where age > 28", + ok: false, + }, + { + name: "top-level order by", + rule: "select a from db1.t1 where age > 28 order by a", + ok: false, + }, + { + name: "top-level limit", + rule: "select a from db1.t1 where age > 28 limit 1", + ok: false, + }, + { + name: "distinct", + rule: "select distinct a from db1.t1 where age > 28", + ok: false, + }, + { + name: "group by", + rule: "select a from db1.t1 where age > 28 group by a", + ok: false, + }, + { + name: "having", + rule: "select a from db1.t1 where age > 28 having a > 1", + ok: false, + }, + { + name: "union", + rule: "select a from db1.t1 where age > 28 union all select a from db1.t1 where age < 3", + ok: false, + }, + { + name: "join", + rule: "select t1.a from db1.t1 join db1.t2 on t1.a = t2.a where t1.age > 28", + ok: false, }, } ctx := context.Background() for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - left, ok, err := rewriteRuleOutputColumns(ctx, tc.left) - require.NoError(t, err) - require.True(t, ok) - right, ok, err := rewriteRuleOutputColumns(ctx, tc.right) + shape, ok, err := rewriteRuleMergeShapeForRule(ctx, tc.rule) require.NoError(t, err) - require.True(t, ok) - require.Equal(t, tc.same, sameRewriteOutputColumns(left, right)) - }) - } -} - -func TestRewriteExprIsMergeSafe(t *testing.T) { - cases := []struct { - name string - sql string - safe bool - }{ - {name: "simple column", sql: "select a from db1.t1", safe: true}, - {name: "binary expression", sql: "select a + age from db1.t1", safe: true}, - {name: "comparison expression", sql: "select a > age from db1.t1", safe: true}, - {name: "case expression", sql: "select case when a > 0 then age else 0 end from db1.t1", safe: true}, - {name: "scalar function", sql: "select abs(a) from db1.t1", safe: true}, - {name: "cast expression", sql: "select cast(a as signed) from db1.t1", safe: true}, - {name: "between expression", sql: "select a between 1 and 3 from db1.t1", safe: true}, - {name: "aggregate function", sql: "select count(*) from db1.t1", safe: false}, - {name: "window function", sql: "select row_number() over () from db1.t1", safe: false}, - {name: "subquery expression", sql: "select exists (select a from db1.t1) from db1.t1", safe: false}, - {name: "system variable expression", sql: "select @@sql_mode from db1.t1", safe: false}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - exprs := parseRewriteRuleSelectExprs(t, tc.sql) - require.Len(t, exprs, 1) - require.Equal(t, tc.safe, rewriteExprIsMergeSafe(exprs[0].Expr)) - }) - } -} - -func TestRewriteExprIsMergeSafeASTBranches(t *testing.T) { - safeColumn := tree.NewUnresolvedColName("a") - safeNumber := rewriteRuleNumVal(1) - unsafeExpr := rewriteRuleUnsafeSubqueryExpr() - - cases := []struct { - name string - expr tree.Expr - safe bool - }{ - {name: "nil expression", expr: nil, safe: true}, - {name: "string literal", expr: tree.NewStrVal("'x'"), safe: true}, - {name: "unqualified star", expr: tree.UnqualifiedStar{}, safe: true}, - {name: "unary expression", expr: tree.NewUnaryExpr(tree.UNARY_MINUS, safeNumber), safe: true}, - {name: "comparison with escape", expr: tree.NewComparisonExprWithEscape(tree.LIKE, safeColumn, tree.NewStrVal("'a%'"), tree.NewStrVal("'\\'")), safe: true}, - {name: "and expression", expr: tree.NewAndExpr(safeColumn, safeNumber), safe: true}, - {name: "xor expression", expr: tree.NewXorExpr(safeColumn, safeNumber), safe: true}, - {name: "or expression", expr: tree.NewOrExpr(safeColumn, safeNumber), safe: true}, - {name: "not expression", expr: tree.NewNotExpr(safeColumn), safe: true}, - {name: "is null expression", expr: tree.NewIsNullExpr(safeColumn), safe: true}, - {name: "is not null expression", expr: tree.NewIsNotNullExpr(safeColumn), safe: true}, - {name: "is unknown expression", expr: tree.NewIsUnknownExpr(safeColumn), safe: true}, - {name: "is not unknown expression", expr: tree.NewIsNotUnknownExpr(safeColumn), safe: true}, - {name: "is true expression", expr: tree.NewIsTrueExpr(safeColumn), safe: true}, - {name: "is not true expression", expr: tree.NewIsNotTrueExpr(safeColumn), safe: true}, - {name: "is false expression", expr: tree.NewIsFalseExpr(safeColumn), safe: true}, - {name: "is not false expression", expr: tree.NewIsNotFalseExpr(safeColumn), safe: true}, - {name: "paren expression", expr: tree.NewParentExpr(safeColumn), safe: true}, - {name: "function with table type", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_TABLE, nil, safeColumn), safe: false}, - {name: "function without a resolvable name", expr: &tree.FuncExpr{}, safe: false}, - {name: "function with unsafe argument", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_DEFAULT, nil, unsafeExpr), safe: false}, - {name: "function with unsafe order by", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_DEFAULT, tree.OrderBy{&tree.Order{Expr: unsafeExpr}}, safeColumn), safe: false}, - {name: "serial extract expression", expr: &tree.SerialExtractExpr{SerialExpr: safeColumn, IndexExpr: safeNumber}, safe: true}, - {name: "serial extract with unsafe index", expr: &tree.SerialExtractExpr{SerialExpr: safeColumn, IndexExpr: unsafeExpr}, safe: false}, - {name: "bit cast expression", expr: &tree.BitCastExpr{Expr: safeColumn}, safe: true}, - {name: "tuple expression", expr: &tree.Tuple{Exprs: tree.Exprs{safeColumn, safeNumber}}, safe: true}, - {name: "tuple with unsafe member", expr: &tree.Tuple{Exprs: tree.Exprs{safeColumn, unsafeExpr}}, safe: false}, - {name: "case with unsafe else", expr: &tree.CaseExpr{Expr: safeColumn, Else: unsafeExpr}, safe: false}, - {name: "case with nil when", expr: &tree.CaseExpr{Whens: []*tree.When{nil, tree.NewWhen(safeColumn, safeNumber)}, Else: safeNumber}, safe: true}, - {name: "case with unsafe when", expr: &tree.CaseExpr{Whens: []*tree.When{tree.NewWhen(unsafeExpr, safeNumber)}, Else: safeNumber}, safe: false}, - {name: "interval expression", expr: &tree.IntervalExpr{Expr: safeNumber}, safe: true}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.safe, rewriteExprIsMergeSafe(tc.expr)) + require.Equal(t, tc.ok, ok) + if !tc.ok { + require.Nil(t, shape) + return + } + require.NotNil(t, shape) + require.Equal(t, tc.selectList, shape.selectList) + require.Equal(t, tc.table, shape.table) }) } } -func TestRewriteExprsAndOrderByAreMergeSafe(t *testing.T) { - safeColumn := tree.NewUnresolvedColName("a") - unsafeExpr := rewriteRuleUnsafeSubqueryExpr() - - require.True(t, rewriteExprsAreMergeSafe(nil)) - require.True(t, rewriteExprsAreMergeSafe(tree.Exprs{safeColumn})) - require.False(t, rewriteExprsAreMergeSafe(tree.Exprs{safeColumn, unsafeExpr})) - - require.True(t, rewriteOrderByIsMergeSafe(tree.OrderBy{nil, &tree.Order{Expr: safeColumn}})) - require.False(t, rewriteOrderByIsMergeSafe(tree.OrderBy{&tree.Order{Expr: unsafeExpr}})) -} - -func TestRewriteFuncExprName(t *testing.T) { - require.Equal(t, "abs", rewriteFuncExprName(&tree.FuncExpr{FuncName: tree.NewCStr("ABS", 1)})) - require.Equal(t, "lower", rewriteFuncExprName(&tree.FuncExpr{ - Func: tree.FuncName2ResolvableFunctionReference(tree.NewUnresolvedColName("LOWER")), - })) - require.Equal(t, "", rewriteFuncExprName(&tree.FuncExpr{})) -} - -func parseRewriteRuleSelectExprs(t *testing.T, sql string) tree.SelectExprs { - t.Helper() - - stmt, err := parsers.ParseOne(context.Background(), dialect.MYSQL, sql, 1) - require.NoError(t, err) - selectStmt, ok := stmt.(*tree.Select) - require.True(t, ok) - selectClause, ok := selectStmt.Select.(*tree.SelectClause) - require.True(t, ok) - return selectClause.Exprs -} - -func rewriteRuleNumVal(v int64) *tree.NumVal { - return tree.NewNumVal[int64](v, strconv.FormatInt(v, 10), false, tree.P_int64) -} - -func rewriteRuleFuncExpr(name string, typ tree.FuncType, orderBy tree.OrderBy, exprs ...tree.Expr) *tree.FuncExpr { - return &tree.FuncExpr{ - FuncName: tree.NewCStr(name, 1), - Type: typ, - Exprs: exprs, - OrderBy: orderBy, - } -} - -func rewriteRuleUnsafeSubqueryExpr() tree.Expr { - return tree.NewSubquery(&tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - }, true) -} - func newMrsForRewriteRules(rows [][]interface{}) *MysqlResultSet { mrs := &MysqlResultSet{} diff --git a/test/distributed/cases/tenant/privilege/owner1.result b/test/distributed/cases/tenant/privilege/owner1.result index 166d5744a97a6..67868289f3996 100644 --- a/test/distributed/cases/tenant/privilege/owner1.result +++ b/test/distributed/cases/tenant/privilege/owner1.result @@ -14,9 +14,7 @@ set secondary role all; create table db1.t3(a int); set role role2; drop table db1.t3; -internal error: do not have privilege to execute the statement drop table db1.t2; -drop table db1.t3; drop database db1; drop account default_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/tenant/privilege/owner1.sql b/test/distributed/cases/tenant/privilege/owner1.sql index de115b62dab68..2bd31cbc3da3b 100644 --- a/test/distributed/cases/tenant/privilege/owner1.sql +++ b/test/distributed/cases/tenant/privilege/owner1.sql @@ -24,8 +24,7 @@ drop table db1.t3; drop table db1.t2; -- @session -- @session:id=5&user=default_1:user1:role1&password=123456 -drop table db1.t3; drop database db1; -- @session drop account default_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.result b/test/distributed/cases/zz_accesscontrol/role_rule.result index 0eea49af706a2..8e4187a30947a 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.result +++ b/test/distributed/cases/zz_accesscontrol/role_rule.result @@ -5,6 +5,7 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -17,6 +18,8 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; drop database if exists db1; drop database if exists db2; create database db1; @@ -64,8 +67,8 @@ create table db2.t2(a int, age int); insert into db2.t2 values (10,10),(20,35),(200,60); create role test_rule_role_multi_a; create role test_rule_role_multi_b; -alter role test_rule_role_multi_a add rule "select a, age from db1.t1 where age > 28" on table db1.t1; -alter role test_rule_role_multi_b add rule "select a, age from db1.t1 where age < 3" on table db1.t1; +alter role test_rule_role_multi_a add rule "select * from db1.t1 where age > 1" on table db1.t1; +alter role test_rule_role_multi_b add rule "select * from db1.t1 where age < 3" on table db1.t1; alter role test_rule_role_multi_b add rule "select * from db2.t2 where age > 30" on table db2.t2; create user test_rule_user_multi identified by '123456' default role test_rule_role_multi_a; grant connect on account * to test_rule_role_multi_a; @@ -76,6 +79,7 @@ grant test_rule_role_multi_b to test_rule_user_multi; set enable_remap_hint = 1; select * from db1.t1 order by a; a age +2 2 100 30 set secondary role all; select * from db1.t1 order by a; @@ -90,6 +94,7 @@ a age set secondary role none; select * from db1.t1 order by a; a age +2 2 100 30 create role test_rule_role_multi_c; create role test_rule_role_multi_d; @@ -153,12 +158,29 @@ set secondary role all; select * from db1.t1 order by a; a age 30 100 +create role test_rule_role_where_a; +create role test_rule_role_where_b; +alter role test_rule_role_where_a add rule "select * from db1.t1 where age > 28" on table db1.t1; +alter role test_rule_role_where_b add rule "select * from db1.t1" on table db1.t1; +create user test_rule_user_where identified by '123456' default role test_rule_role_where_a; +grant connect on account * to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_b; +grant test_rule_role_where_b to test_rule_user_where; +set enable_remap_hint = 1; +set secondary role all; +select * from db1.t1 order by a; +a age +1 1 +2 2 +100 30 drop user if exists test_rule_user; drop user if exists test_rule_user_multi; drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -171,6 +193,8 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.sql b/test/distributed/cases/zz_accesscontrol/role_rule.sql index 05ef4feef105b..7569e9e53e0d1 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.sql +++ b/test/distributed/cases/zz_accesscontrol/role_rule.sql @@ -6,6 +6,7 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -18,6 +19,8 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; drop database if exists db1; drop database if exists db2; create database db1; @@ -67,14 +70,14 @@ set enable_remap_hint = 1; select * from db1.t1; -- @session --- 11. SET SECONDARY ROLE ALL applies rewrite rules from all active roles +-- 11. SET SECONDARY ROLE ALL merges select * rewrite rules from all active roles create database db2; create table db2.t2(a int, age int); insert into db2.t2 values (10,10),(20,35),(200,60); create role test_rule_role_multi_a; create role test_rule_role_multi_b; -alter role test_rule_role_multi_a add rule "select a, age from db1.t1 where age > 28" on table db1.t1; -alter role test_rule_role_multi_b add rule "select a, age from db1.t1 where age < 3" on table db1.t1; +alter role test_rule_role_multi_a add rule "select * from db1.t1 where age > 1" on table db1.t1; +alter role test_rule_role_multi_b add rule "select * from db1.t1 where age < 3" on table db1.t1; alter role test_rule_role_multi_b add rule "select * from db2.t2 where age > 30" on table db2.t2; create user test_rule_user_multi identified by '123456' default role test_rule_role_multi_a; grant connect on account * to test_rule_role_multi_a; @@ -159,6 +162,22 @@ set secondary role all; select * from db1.t1 order by a; -- @session +-- 17. Merging with a no-WHERE rule produces the broader rule +create role test_rule_role_where_a; +create role test_rule_role_where_b; +alter role test_rule_role_where_a add rule "select * from db1.t1 where age > 28" on table db1.t1; +alter role test_rule_role_where_b add rule "select * from db1.t1" on table db1.t1; +create user test_rule_user_where identified by '123456' default role test_rule_role_where_a; +grant connect on account * to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_b; +grant test_rule_role_where_b to test_rule_user_where; +-- @session:id=7&user=sys:test_rule_user_where:test_rule_role_where_a&password=123456 +set enable_remap_hint = 1; +set secondary role all; +select * from db1.t1 order by a; +-- @session + -- cleanup all test resources drop user if exists test_rule_user; drop user if exists test_rule_user_multi; @@ -166,6 +185,7 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -178,6 +198,8 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/use_role.result b/test/distributed/cases/zz_accesscontrol/use_role.result index 4f3c98e735742..5c66b601905dc 100644 --- a/test/distributed/cases/zz_accesscontrol/use_role.result +++ b/test/distributed/cases/zz_accesscontrol/use_role.result @@ -48,12 +48,12 @@ internal error: do not have privilege to execute the statement drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement set secondary role all; -create table use_db_1.use_table_5(a int,b varchar(20),c double ); -insert into use_db_1.use_table_5 values(10,'yellow',99.99); -select * from use_db_1.use_table_5; +insert into use_db_1.use_table_2 values(20,'green',88.88); +select * from use_db_1.use_table_2 order by a; a b c 10 yellow 99.99 -drop table use_db_1.use_table_5; +20 green 88.88 +34 kelly 90.3 set role moadmin; internal error: the role moadmin has not be granted to the user use_user_1 create user use_user_2 identified by '123456'; @@ -72,15 +72,21 @@ internal error: there is no role use_role_test drop role use_role_test; internal error: do not have privilege to execute the statement set secondary role all; -create table use_db_1.use_table_6(a int,b varchar(20),c double); -insert into use_db_1.use_table_6 values(10,'yellow',99.99); +insert into use_db_1.use_table_2 values(30,'blue',77.77); +select * from use_db_1.use_table_2 order by a; +a b c +10 yellow 99.99 +20 green 88.88 +30 blue 77.77 +34 kelly 90.3 create database use_db_test; internal error: do not have privilege to execute the statement drop database use_db_test; internal error: do not have privilege to execute the statement set secondary role none; -insert into use_db_1.use_table_6 values (10, 'yellow', 99.99); -drop table use_db_1.use_table_6; +insert into use_db_1.use_table_2 values (40, 'red', 66.66); +internal error: do not have privilege to execute the statement +drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement create role if not exists use_role_test; internal error: do not have privilege to execute the statement @@ -88,11 +94,11 @@ drop role use_role_test; internal error: do not have privilege to execute the statement set role use_role_5; internal error: the role use_role_5 has not be granted to the user use_user_1 -drop table use_db_1.use_table_6; +drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement create database if not exists use_db_test; internal error: do not have privilege to execute the statement drop role if exists use_role_1,use_role_2,use_role_3,use_role_4,use_role_5; drop user if exists use_user_1,use_user_2; drop database if exists use_db_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/use_role.sql b/test/distributed/cases/zz_accesscontrol/use_role.sql index 88f5812350567..cbf3fa89072f6 100644 --- a/test/distributed/cases/zz_accesscontrol/use_role.sql +++ b/test/distributed/cases/zz_accesscontrol/use_role.sql @@ -37,10 +37,8 @@ insert into use_db_1.use_table_2 values(10,'yellow',99.99); select * from use_db_1.use_table_2; drop table use_db_1.use_table_2; set secondary role all; -create table use_db_1.use_table_5(a int,b varchar(20),c double ); -insert into use_db_1.use_table_5 values(10,'yellow',99.99); -select * from use_db_1.use_table_5; -drop table use_db_1.use_table_5; +insert into use_db_1.use_table_2 values(20,'green',88.88); +select * from use_db_1.use_table_2 order by a; set role moadmin; -- @session @@ -57,21 +55,21 @@ create role use_role_test; set role use_role_test; drop role use_role_test; set secondary role all; -create table use_db_1.use_table_6(a int,b varchar(20),c double); -insert into use_db_1.use_table_6 values(10,'yellow',99.99); +insert into use_db_1.use_table_2 values(30,'blue',77.77); +select * from use_db_1.use_table_2 order by a; create database use_db_test; drop database use_db_test; set secondary role none; -insert into use_db_1.use_table_6 values (10, 'yellow', 99.99); -drop table use_db_1.use_table_6; +insert into use_db_1.use_table_2 values (40, 'red', 66.66); +drop table use_db_1.use_table_2; create role if not exists use_role_test; drop role use_role_test; set role use_role_5; -drop table use_db_1.use_table_6; +drop table use_db_1.use_table_2; create database if not exists use_db_test; -- @session drop role if exists use_role_1,use_role_2,use_role_3,use_role_4,use_role_5; drop user if exists use_user_1,use_user_2; drop database if exists use_db_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; From 8b4440897b80d47acad2a0840ca3a8568d2693b6 Mon Sep 17 00:00:00 2001 From: ouyuanning <52741323@qq.com> Date: Fri, 12 Jun 2026 09:44:10 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=A4=87=E6=B3=A8\ut\bvt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/frontend/rewrite_rule.go | 151 ++++++++++++++++-- pkg/frontend/rewrite_rule_test.go | 21 ++- .../cases/zz_accesscontrol/role_rule.result | 48 ++++++ .../cases/zz_accesscontrol/role_rule.sql | 48 ++++++ 4 files changed, 248 insertions(+), 20 deletions(-) diff --git a/pkg/frontend/rewrite_rule.go b/pkg/frontend/rewrite_rule.go index a311d4a98414e..908ef7d326be4 100644 --- a/pkg/frontend/rewrite_rule.go +++ b/pkg/frontend/rewrite_rule.go @@ -30,6 +30,7 @@ import ( "github.com/matrixorigin/matrixone/pkg/sql/parsers" "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" + "github.com/matrixorigin/matrixone/pkg/sql/plan/function" ) const ( @@ -338,8 +339,12 @@ type rewriteRuleMergeShape struct { table string } -// Rewrite rules are merged by OR-ing filters only for simple row-preserving -// single-table SELECTs with exactly matching select lists and source tables. +// Rewrite rules from active roles are a source-row visibility union: if any +// role allows a base row, that row should remain visible. For simple +// single-table SELECTs with matching select lists and source tables, OR-ing the +// filters expresses that row union directly. This is intentionally not a +// UNION DISTINCT equivalent for partial projections; two visible base rows that +// project to the same values should both remain visible. func mergeRewriteRulesSafely(ctx context.Context, leftRule, rightRule string) (string, bool, error) { leftShape, ok, err := rewriteRuleMergeShapeForRule(ctx, leftRule) if err != nil { @@ -427,27 +432,27 @@ func rewriteRuleSelectClauseIsMergeable(clause *tree.SelectClause) bool { func rewriteRuleSelectExprsAreMergeable(exprs tree.SelectExprs) bool { for _, expr := range exprs { - switch e := expr.Expr.(type) { - case tree.UnqualifiedStar: - if expr.As != nil && !expr.As.Empty() { - return false - } - continue - case *tree.UnresolvedName: - if e.Star && expr.As != nil && !expr.As.Empty() { - return false - } - if e.Star || e.NumParts > 0 { - continue - } + if rewriteRuleSelectExprIsStar(expr.Expr) && expr.As != nil && !expr.As.Empty() { return false - default: + } + if !rewriteExprIsMergeSafe(expr.Expr) { return false } } return true } +func rewriteRuleSelectExprIsStar(expr tree.Expr) bool { + switch e := expr.(type) { + case tree.UnqualifiedStar: + return true + case *tree.UnresolvedName: + return e.Star + default: + return false + } +} + func rewriteRuleSingleTableSource(from *tree.From) (string, bool) { if from == nil || len(from.Tables) != 1 { return "", false @@ -506,6 +511,120 @@ func normalizeRewriteSQL(sql string) string { return strings.ToLower(strings.TrimSpace(sql)) } +func rewriteExprIsMergeSafe(expr tree.Expr) bool { + if expr == nil { + return true + } + + switch e := expr.(type) { + case *tree.UnresolvedName, tree.UnqualifiedStar, *tree.NumVal, *tree.StrVal: + return true + case *tree.BinaryExpr: + return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) + case *tree.UnaryExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.ComparisonExpr: + return rewriteExprIsMergeSafe(e.Left) && + rewriteExprIsMergeSafe(e.Right) && + rewriteExprIsMergeSafe(e.Escape) + case *tree.AndExpr: + return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) + case *tree.XorExpr: + return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) + case *tree.OrExpr: + return rewriteExprIsMergeSafe(e.Left) && rewriteExprIsMergeSafe(e.Right) + case *tree.NotExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsNullExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsNotNullExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsUnknownExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsNotUnknownExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsTrueExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsNotTrueExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsFalseExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.IsNotFalseExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.ParenExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.FuncExpr: + name := rewriteFuncExprName(e) + if name == "" || e.Type == tree.FUNC_TYPE_TABLE || + e.WindowSpec != nil || + function.GetFunctionIsAggregateByName(name) || + function.GetFunctionIsWinFunByName(name) { + return false + } + return rewriteExprsAreMergeSafe(e.Exprs) && rewriteOrderByIsMergeSafe(e.OrderBy) + case *tree.SerialExtractExpr: + return rewriteExprIsMergeSafe(e.SerialExpr) && rewriteExprIsMergeSafe(e.IndexExpr) + case *tree.CastExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.BitCastExpr: + return rewriteExprIsMergeSafe(e.Expr) + case *tree.Tuple: + return rewriteExprsAreMergeSafe(e.Exprs) + case *tree.RangeCond: + return rewriteExprIsMergeSafe(e.Left) && + rewriteExprIsMergeSafe(e.From) && + rewriteExprIsMergeSafe(e.To) + case *tree.CaseExpr: + if !rewriteExprIsMergeSafe(e.Expr) || !rewriteExprIsMergeSafe(e.Else) { + return false + } + for _, when := range e.Whens { + if when == nil { + continue + } + if !rewriteExprIsMergeSafe(when.Cond) || !rewriteExprIsMergeSafe(when.Val) { + return false + } + } + return true + case *tree.IntervalExpr: + return rewriteExprIsMergeSafe(e.Expr) + default: + return false + } +} + +func rewriteExprsAreMergeSafe(exprs tree.Exprs) bool { + for _, expr := range exprs { + if !rewriteExprIsMergeSafe(expr) { + return false + } + } + return true +} + +func rewriteOrderByIsMergeSafe(orderBy tree.OrderBy) bool { + for _, order := range orderBy { + if order == nil { + continue + } + if !rewriteExprIsMergeSafe(order.Expr) { + return false + } + } + return true +} + +func rewriteFuncExprName(fn *tree.FuncExpr) string { + if fn.FuncName != nil { + return strings.ToLower(fn.FuncName.Origin()) + } + if name, ok := fn.Func.FunctionReference.(*tree.UnresolvedName); ok { + return strings.ToLower(name.ColName()) + } + return "" +} + func validateRewriteRuleSQL(ctx context.Context, rule string) error { if strings.TrimSpace(rule) == "" { return moerr.NewInvalidInput(ctx, "rewrite rule SQL is empty") diff --git a/pkg/frontend/rewrite_rule_test.go b/pkg/frontend/rewrite_rule_test.go index 3607ea00097b0..3959550634495 100644 --- a/pkg/frontend/rewrite_rule_test.go +++ b/pkg/frontend/rewrite_rule_test.go @@ -686,6 +686,17 @@ func TestMergeRewriteRules(t *testing.T) { require.NoError(t, err) require.Equal(t, "select a from db1.t1 where (a = 1) or (a = 2)", merged) + // Role rule merging is a base-row visibility union, not UNION DISTINCT over + // projected values. Partial projections are intentionally OR-merged so two + // visible rows with the same projected value are not collapsed here. + rowUnionMerged, err := mergeRewriteRules( + ctx, + "select a from db1.t1 where role_marker = 1", + "select a from db1.t1 where role_marker = 2", + ) + require.NoError(t, err) + require.Equal(t, "select a from db1.t1 where (role_marker = 1) or (role_marker = 2)", rowUnionMerged) + merged, err = mergeRewriteRules(ctx, merged, "select a from db1.t1 where a = 3") require.NoError(t, err) require.Equal(t, "select a from db1.t1 where ((a = 1) or (a = 2)) or (a = 3)", merged) @@ -776,7 +787,7 @@ func TestMergeRewriteRules(t *testing.T) { merged, err = mergeRewriteRules(ctx, "select a + 1 as b from db1.t1 where age > 28", "select a + 1 as b from db1.t1 where age < 3") require.NoError(t, err) - require.Equal(t, "select a + 1 as b from db1.t1 where age < 3", merged) + require.Equal(t, "select a + 1 as b from db1.t1 where (age > 28) or (age < 3)", merged) _, err = mergeRewriteRules(ctx, "select a from", "select a from db1.t1") require.Error(t, err) @@ -883,9 +894,11 @@ func TestRewriteRuleMergeShapeForRule(t *testing.T) { table: "db1.t1", }, { - name: "scalar expression", - rule: "select a + 1 from db1.t1 where age > 28", - ok: false, + name: "scalar expression", + rule: "select a + 1 from db1.t1 where age > 28", + ok: true, + selectList: "a + 1", + table: "db1.t1", }, { name: "aggregate", diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.result b/test/distributed/cases/zz_accesscontrol/role_rule.result index 8e4187a30947a..751178ae44650 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.result +++ b/test/distributed/cases/zz_accesscontrol/role_rule.result @@ -6,6 +6,8 @@ drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -20,11 +22,17 @@ drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; drop role if exists test_rule_role_where_a; drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; create database db1; create table db1.t1(a int, age int); insert into db1.t1 values (1,1),(2,2),(100,30); +create table db1.t_dup(a int, marker int); +insert into db1.t_dup values (1,1),(1,2),(2,3); create role test_rule_role; alter role test_rule_role add rule "select * from db1.t1 where age > 28" on table db1.t1; show rules on role test_rule_role; @@ -174,6 +182,40 @@ a age 1 1 2 2 100 30 +create role test_rule_role_dup_projection_a; +create role test_rule_role_dup_projection_b; +alter role test_rule_role_dup_projection_a add rule "select a from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_dup_projection_b add rule "select a from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_dup_projection identified by '123456' default role test_rule_role_dup_projection_a; +grant connect on account * to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_b; +grant test_rule_role_dup_projection_b to test_rule_user_dup_projection; +set enable_remap_hint = 1; +select count(*) from db1.t_dup; +count(*) +1 +set secondary role all; +select a, count(*) from db1.t_dup group by a order by a; +a count(*) +1 2 +select count(*) from db1.t_dup; +count(*) +2 +create role test_rule_role_expr_projection_a; +create role test_rule_role_expr_projection_b; +alter role test_rule_role_expr_projection_a add rule "select a + 1 as b from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_expr_projection_b add rule "select a + 1 as b from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_expr_projection identified by '123456' default role test_rule_role_expr_projection_a; +grant connect on account * to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_b; +grant test_rule_role_expr_projection_b to test_rule_user_expr_projection; +set enable_remap_hint = 1; +set secondary role all; +select b, count(*) from db1.t_dup group by b order by b; +b count(*) +2 2 drop user if exists test_rule_user; drop user if exists test_rule_user_multi; drop user if exists test_rule_user_multi_diff; @@ -181,6 +223,8 @@ drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -195,6 +239,10 @@ drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; drop role if exists test_rule_role_where_a; drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.sql b/test/distributed/cases/zz_accesscontrol/role_rule.sql index 7569e9e53e0d1..12a997e951def 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.sql +++ b/test/distributed/cases/zz_accesscontrol/role_rule.sql @@ -7,6 +7,8 @@ drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -21,11 +23,17 @@ drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; drop role if exists test_rule_role_where_a; drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; create database db1; create table db1.t1(a int, age int); insert into db1.t1 values (1,1),(2,2),(100,30); +create table db1.t_dup(a int, marker int); +insert into db1.t_dup values (1,1),(1,2),(2,3); -- 1. ADD RULE normal case + SHOW RULES verification create role test_rule_role; @@ -178,6 +186,40 @@ set secondary role all; select * from db1.t1 order by a; -- @session +-- 18. Partial projections keep duplicate projected values from distinct visible rows +create role test_rule_role_dup_projection_a; +create role test_rule_role_dup_projection_b; +alter role test_rule_role_dup_projection_a add rule "select a from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_dup_projection_b add rule "select a from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_dup_projection identified by '123456' default role test_rule_role_dup_projection_a; +grant connect on account * to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_b; +grant test_rule_role_dup_projection_b to test_rule_user_dup_projection; +-- @session:id=8&user=sys:test_rule_user_dup_projection:test_rule_role_dup_projection_a&password=123456 +set enable_remap_hint = 1; +select count(*) from db1.t_dup; +set secondary role all; +select a, count(*) from db1.t_dup group by a order by a; +select count(*) from db1.t_dup; +-- @session + +-- 19. Row-wise expression projections also merge without dropping role visibility +create role test_rule_role_expr_projection_a; +create role test_rule_role_expr_projection_b; +alter role test_rule_role_expr_projection_a add rule "select a + 1 as b from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_expr_projection_b add rule "select a + 1 as b from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_expr_projection identified by '123456' default role test_rule_role_expr_projection_a; +grant connect on account * to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_b; +grant test_rule_role_expr_projection_b to test_rule_user_expr_projection; +-- @session:id=9&user=sys:test_rule_user_expr_projection:test_rule_role_expr_projection_a&password=123456 +set enable_remap_hint = 1; +set secondary role all; +select b, count(*) from db1.t_dup group by b order by b; +-- @session + -- cleanup all test resources drop user if exists test_rule_user; drop user if exists test_rule_user_multi; @@ -186,6 +228,8 @@ drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -200,6 +244,10 @@ drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; drop role if exists test_rule_role_where_a; drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on;