diff --git a/pkg/frontend/authenticate.go b/pkg/frontend/authenticate.go index 5444cd80b4daf..cc2e3fed2a533 100644 --- a/pkg/frontend/authenticate.go +++ b/pkg/frontend/authenticate.go @@ -3498,22 +3498,16 @@ func doAlterAccount(ctx context.Context, ses *Session, aa *alterAccount) (err er return err } -// doSetSecondaryRoleAll set the session role of the user with smallness role_id +// doSetSecondaryRoleAll validates user role metadata before enabling all secondary roles. +// The current primary role must not change; SET SECONDARY ROLE ALL only affects secondary roles. func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) { var sql string var userId uint32 - var erArray []ExecResult - var roleId int64 - var roleName string account := ses.GetTenantInfo() // get current user_id userId = account.GetUserID() - // init role_id and role_name - roleId = publicRoleID - roleName = publicRoleName - // step1:get all roles expect public bh := ses.GetBackgroundExec(ctx) defer bh.Close() @@ -3533,26 +3527,7 @@ func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) { return err } - erArray, err = getResultSet(ctx, bh) - if err != nil { - return err - } - if execResultArrayHasData(erArray) { - roleId, err = erArray[0].GetInt64(ctx, 0, 0) - if err != nil { - return err - } - - roleName, err = erArray[0].GetString(ctx, 0, 1) - if err != nil { - return err - } - } - - // step2 : switch the default role and role id; - account.SetDefaultRoleID(uint32(roleId)) - account.SetDefaultRole(roleName) - + _, err = getResultSet(ctx, bh) return err } @@ -3649,6 +3624,7 @@ func doSwitchRole(ctx context.Context, ses *Session, sr *tree.SetRole) (err erro account.SetDefaultRole(sr.Role.UserName) // then, reset secondary role to none account.SetUseSecondaryRole(false) + ses.InvalidatePrivilegeCache() return err } diff --git a/pkg/frontend/authenticate_test.go b/pkg/frontend/authenticate_test.go index 0ff6484679a35..4bf2c738d32c2 100644 --- a/pkg/frontend/authenticate_test.go +++ b/pkg/frontend/authenticate_test.go @@ -8185,6 +8185,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) { err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses) convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") }) convey.Convey("do set secondary role succ", t, func() { @@ -8224,6 +8226,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) { err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses) convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") }) } @@ -8263,6 +8267,8 @@ func TestDoSwitchRoleSecondaryRoleAllInvalidatesRuleCache(t *testing.T) { }) convey.So(err, convey.ShouldBeNil) convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeTrue) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1") ses.ruleCacheMu.RLock() cacheIsNil := ses.ruleCache == nil @@ -8337,6 +8343,56 @@ func TestDoSwitchRoleSecondaryRoleNoneInvalidatesRuleCache(t *testing.T) { }) } +func TestDoSwitchRolePrimaryRoleInvalidatesRuleCache(t *testing.T) { + convey.Convey("set role switches current role and invalidates rule cache", t, func() { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + bh := &backgroundExecTest{} + bh.init() + + bhStub := gostub.StubFunc(&NewBackgroundExec, bh) + defer bhStub.Reset() + + ses := newSes(&privilege{}, ctrl) + tenant := &TenantInfo{ + Tenant: "test_account", + User: "test_user", + DefaultRole: "role1", + TenantID: 3001, + UserID: 3, + DefaultRoleID: 5, + } + tenant.SetUseSecondaryRole(true) + ses.SetTenantInfo(tenant) + ses.ruleCache = map[string]string{"db1.t1": "select a from db1.t1"} + + bh.sql2result["begin;"] = nil + bh.sql2result["commit;"] = nil + bh.sql2result["rollback;"] = nil + + roleSQL, err := getSqlForRoleIdOfRole(context.Background(), "role2") + convey.So(err, convey.ShouldBeNil) + bh.sql2result[roleSQL] = newMrsForRoleIdOfRole([][]interface{}{{int64(6)}}) + bh.sql2result[getSqlForCheckUserGrant(6, int64(tenant.UserID))] = newMrsForCheckUserGrant([][]interface{}{ + {int64(6), int64(tenant.UserID), false}, + }) + + err = doSwitchRole(ses.GetTxnHandler().GetTxnCtx(), ses, &tree.SetRole{ + Role: &tree.Role{UserName: "role2"}, + }) + convey.So(err, convey.ShouldBeNil) + convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(6)) + convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role2") + convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeFalse) + + ses.ruleCacheMu.RLock() + cacheIsNil := ses.ruleCache == nil + ses.ruleCacheMu.RUnlock() + convey.So(cacheIsNil, convey.ShouldBeTrue) + }) +} + func TestGetSessionSysVar(t *testing.T) { convey.Convey("get session system variable succ", t, func() { ctrl := gomock.NewController(t) diff --git a/pkg/frontend/rewrite_rule.go b/pkg/frontend/rewrite_rule.go index b396c33867b2f..908ef7d326be4 100644 --- a/pkg/frontend/rewrite_rule.go +++ b/pkg/frontend/rewrite_rule.go @@ -307,30 +307,24 @@ func getSqlForRoleRulesOfRoleIDs(roleIDs []int64) string { } func mergeRewriteRules(ctx context.Context, leftRule, rightRule string) (string, error) { - leftRule = trimRewriteRuleForUnion(leftRule) - rightRule = trimRewriteRuleForUnion(rightRule) + leftRule = trimRewriteRuleForMerge(leftRule) + rightRule = trimRewriteRuleForMerge(rightRule) - leftColumns, ok, err := rewriteRuleOutputColumns(ctx, leftRule) - if err != nil { - return "", err + if leftRule == rightRule { + return leftRule, nil } - if !ok { - return rightRule, nil - } - rightColumns, ok, err := rewriteRuleOutputColumns(ctx, rightRule) + + mergedRule, ok, err := mergeRewriteRulesSafely(ctx, leftRule, rightRule) if err != nil { return "", err } if !ok { return rightRule, nil } - if !sameRewriteOutputColumns(leftColumns, rightColumns) { - return rightRule, nil - } - return fmt.Sprintf("(%s) union distinct (%s)", leftRule, rightRule), nil + return mergedRule, nil } -func trimRewriteRuleForUnion(rule string) string { +func trimRewriteRuleForMerge(rule string) string { rule = strings.TrimSpace(rule) for strings.HasSuffix(rule, ";") { rule = strings.TrimSpace(strings.TrimSuffix(rule, ";")) @@ -338,95 +332,109 @@ func trimRewriteRuleForUnion(rule string) string { return rule } -type rewriteRuleOutputColumn struct { - name string - expr string +type rewriteRuleMergeShape struct { + stmt *tree.Select + clause *tree.SelectClause + selectList string + table string } -func rewriteRuleOutputColumns(ctx context.Context, rule string) ([]rewriteRuleOutputColumn, bool, error) { - stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) +// Rewrite rules from active roles are a source-row visibility union: if any +// role allows a base row, that row should remain visible. For simple +// single-table SELECTs with matching select lists and source tables, OR-ing the +// filters expresses that row union directly. This is intentionally not a +// UNION DISTINCT equivalent for partial projections; two visible base rows that +// project to the same values should both remain visible. +func mergeRewriteRulesSafely(ctx context.Context, leftRule, rightRule string) (string, bool, error) { + leftShape, ok, err := rewriteRuleMergeShapeForRule(ctx, leftRule) if err != nil { - return nil, false, moerr.NewInternalErrorf(ctx, "failed to parse rewrite rule %q while merging rewrite rules: %v", rule, err) + return "", false, err } - columns, ok := outputColumnsFromRewriteStatement(stmt) - return columns, ok, nil -} - -func validateRewriteRuleSQL(ctx context.Context, rule string) error { - if strings.TrimSpace(rule) == "" { - return moerr.NewInvalidInput(ctx, "rewrite rule SQL is empty") + if !ok { + return "", false, nil } - stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + rightShape, ok, err := rewriteRuleMergeShapeForRule(ctx, rightRule) if err != nil { - return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: %v", rule, err) + return "", false, err + } + if !ok { + return "", false, nil } - switch stmt.(type) { - case *tree.Select, *tree.ParenSelect: - return nil - default: - return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: only accept SELECT-like statements as rewrites", rule) + if leftShape.selectList != rightShape.selectList || leftShape.table != rightShape.table { + return "", false, nil } + + merged := *leftShape.stmt + mergedClause := *leftShape.clause + mergedClause.Where = mergeRewriteRuleWhere(leftShape.clause.Where, rightShape.clause.Where) + merged.Select = &mergedClause + + return tree.String(&merged, dialect.MYSQL), true, nil } -func outputColumnsFromRewriteStatement(stmt tree.Statement) ([]rewriteRuleOutputColumn, bool) { +func rewriteRuleMergeShapeForRule(ctx context.Context, rule string) (*rewriteRuleMergeShape, bool, error) { + stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + if err != nil { + return nil, false, moerr.NewInternalErrorf(ctx, "failed to parse rewrite rule %q while merging rewrite rules: %v", rule, err) + } + shape, ok := rewriteRuleMergeShapeForStatement(stmt) + return shape, ok, nil +} + +func rewriteRuleMergeShapeForStatement(stmt tree.Statement) (*rewriteRuleMergeShape, bool) { switch s := stmt.(type) { case *tree.Select: - if len(s.OrderBy) > 0 || s.Limit != nil { - return nil, false - } - return outputColumnsFromRewriteSelectStatement(s.Select) + return rewriteRuleMergeShapeForSelect(s) case *tree.ParenSelect: if s.Select == nil { return nil, false } - return outputColumnsFromRewriteStatement(s.Select) + return rewriteRuleMergeShapeForStatement(s.Select) default: return nil, false } } -func outputColumnsFromRewriteSelectStatement(stmt tree.SelectStatement) ([]rewriteRuleOutputColumn, bool) { - switch s := stmt.(type) { - case *tree.SelectClause: - if !mergeableRewriteSelectClause(s) { - return nil, false - } - columns := make([]rewriteRuleOutputColumn, 0, len(s.Exprs)) - for _, expr := range s.Exprs { - column, ok := outputColumnFromRewriteSelectExpr(expr) - if !ok { - return nil, false - } - columns = append(columns, column) - } - return columns, true - case *tree.UnionClause: - return outputColumnsFromRewriteSelectStatement(s.Left) - case *tree.ParenSelect: - if s.Select == nil { - return nil, false - } - return outputColumnsFromRewriteStatement(s.Select) - case *tree.Select: - if len(s.OrderBy) > 0 || s.Limit != nil { - return nil, false - } - return outputColumnsFromRewriteSelectStatement(s.Select) - default: +func rewriteRuleMergeShapeForSelect(stmt *tree.Select) (*rewriteRuleMergeShape, bool) { + if stmt == nil || len(stmt.OrderBy) > 0 || stmt.Limit != nil || stmt.With != nil || + stmt.TimeWindow != nil || stmt.RankOption != nil || stmt.Ep != nil || stmt.SelectLockInfo != nil { return nil, false } -} -func mergeableRewriteSelectClause(stmt *tree.SelectClause) bool { - if stmt.Distinct || stmt.Option&(tree.QuerySpecOptionDistinct|tree.QuerySpecOptionDistinctRow) != 0 { - return false + clause, ok := stmt.Select.(*tree.SelectClause) + if !ok || !rewriteRuleSelectClauseIsMergeable(clause) { + return nil, false } - if stmt.GroupBy != nil || stmt.Having != nil { - return false + + table, ok := rewriteRuleSingleTableSource(clause.From) + if !ok { + return nil, false } - for _, expr := range stmt.Exprs { + + return &rewriteRuleMergeShape{ + stmt: stmt, + clause: clause, + selectList: normalizeRewriteSQL(tree.String(&clause.Exprs, dialect.MYSQL)), + table: table, + }, true +} + +func rewriteRuleSelectClauseIsMergeable(clause *tree.SelectClause) bool { + return clause != nil && + !clause.Distinct && + clause.Option == 0 && + clause.GroupBy == nil && + clause.Having == nil && + rewriteRuleSelectExprsAreMergeable(clause.Exprs) +} + +func rewriteRuleSelectExprsAreMergeable(exprs tree.SelectExprs) bool { + for _, expr := range exprs { + if rewriteRuleSelectExprIsStar(expr.Expr) && expr.As != nil && !expr.As.Empty() { + return false + } if !rewriteExprIsMergeSafe(expr.Expr) { return false } @@ -434,55 +442,73 @@ func mergeableRewriteSelectClause(stmt *tree.SelectClause) bool { return true } -func outputColumnFromRewriteSelectExpr(expr tree.SelectExpr) (rewriteRuleOutputColumn, bool) { - if expr.As != nil && !expr.As.Empty() { - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(expr.As.Compare()), - expr: normalizeRewriteOutputExpr(expr.Expr), - }, true - } - - switch e := expr.Expr.(type) { +func rewriteRuleSelectExprIsStar(expr tree.Expr) bool { + switch e := expr.(type) { case tree.UnqualifiedStar: - return rewriteRuleOutputColumn{}, false + return true case *tree.UnresolvedName: - if e.Star { - return rewriteRuleOutputColumn{}, false - } - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(e.ColName()), - expr: normalizeRewriteOutputColumn(tree.String(expr.Expr, dialect.MYSQL)), - }, true + return e.Star default: - exprText := normalizeRewriteOutputExpr(expr.Expr) - return rewriteRuleOutputColumn{ - name: normalizeRewriteOutputColumn(tree.String(expr.Expr, dialect.MYSQL)), - expr: exprText, - }, true + return false } } -func normalizeRewriteOutputColumn(column string) string { - return strings.ToLower(strings.TrimSpace(column)) -} +func rewriteRuleSingleTableSource(from *tree.From) (string, bool) { + if from == nil || len(from.Tables) != 1 { + return "", false + } -func normalizeRewriteOutputExpr(expr tree.Expr) string { - if _, ok := expr.(*tree.UnresolvedName); ok { - return normalizeRewriteOutputColumn(tree.String(expr, dialect.MYSQL)) + tableExpr, ok := rewriteRuleSingleTableExpr(from.Tables[0]) + if !ok { + return "", false } - return strings.TrimSpace(tree.String(expr, dialect.MYSQL)) + + return normalizeRewriteSQL(tree.String(tableExpr, dialect.MYSQL)), true } -func sameRewriteOutputColumns(leftColumns, rightColumns []rewriteRuleOutputColumn) bool { - if len(leftColumns) != len(rightColumns) { - return false +func rewriteRuleSingleTableExpr(expr tree.TableExpr) (*tree.AliasedTableExpr, bool) { + if joinExpr, ok := expr.(*tree.JoinTableExpr); ok { + // The MySQL parser wraps a lone table reference as a CROSS JoinTableExpr. + if joinExpr.Right != nil || joinExpr.Cond != nil || joinExpr.Option != "" { + return nil, false + } + expr = joinExpr.Left } - for i := range leftColumns { - if leftColumns[i] != rightColumns[i] { - return false + + tableExpr, ok := expr.(*tree.AliasedTableExpr) + if !ok || tableExpr.As.Cols != nil || len(tableExpr.IndexHints) > 0 { + return nil, false + } + + tableName, ok := tableExpr.Expr.(*tree.TableName) + if !ok || tableName.AtTsExpr != nil { + return nil, false + } + + return tableExpr, true +} + +func mergeRewriteRuleWhere(left, right *tree.Where) *tree.Where { + // Mergeability is decided from the top-level SELECT shape only. Predicate + // internals, including subqueries, stay opaque and are only OR-ed together. + switch { + case left == nil || left.Expr == nil: + return nil + case right == nil || right.Expr == nil: + return nil + default: + return &tree.Where{ + Type: tree.AstWhere, + Expr: tree.NewOrExpr( + tree.NewParentExpr(left.Expr), + tree.NewParentExpr(right.Expr), + ), } } - return true +} + +func normalizeRewriteSQL(sql string) string { + return strings.ToLower(strings.TrimSpace(sql)) } func rewriteExprIsMergeSafe(expr tree.Expr) bool { @@ -599,6 +625,24 @@ func rewriteFuncExprName(fn *tree.FuncExpr) string { return "" } +func validateRewriteRuleSQL(ctx context.Context, rule string) error { + if strings.TrimSpace(rule) == "" { + return moerr.NewInvalidInput(ctx, "rewrite rule SQL is empty") + } + + stmt, err := parsers.ParseOne(ctx, dialect.MYSQL, rule, 1) + if err != nil { + return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: %v", rule, err) + } + + switch stmt.(type) { + case *tree.Select, *tree.ParenSelect: + return nil + default: + return moerr.NewInvalidInputf(ctx, "invalid rewrite rule SQL %q: only accept SELECT-like statements as rewrites", rule) + } +} + // escapeSQLString escapes a string for safe use in SQL literals using writeEscapedSQLString. func escapeSQLString(s string) string { var buf bytes.Buffer diff --git a/pkg/frontend/rewrite_rule_test.go b/pkg/frontend/rewrite_rule_test.go index b22b33088551d..3959550634495 100644 --- a/pkg/frontend/rewrite_rule_test.go +++ b/pkg/frontend/rewrite_rule_test.go @@ -18,7 +18,6 @@ import ( "context" "math/rand" "reflect" - "strconv" "strings" "testing" "testing/quick" @@ -29,8 +28,6 @@ import ( "github.com/stretchr/testify/require" "github.com/matrixorigin/matrixone/pkg/defines" - "github.com/matrixorigin/matrixone/pkg/sql/parsers" - "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" ) @@ -475,7 +472,7 @@ func TestLoadRuleCacheIncludesSecondaryRoles(t *testing.T) { rules, err := loadRuleCache(context.Background(), ses) require.NoError(t, err) require.Equal(t, map[string]string{ - "db1.t1": "(select a, age from db1.t1 where age < 3) union distinct (select A, Age from db1.t1 where age > 28)", + "db1.t1": "select a, age from db1.t1 where (age < 3) or (age > 28)", "db2.t2": "select a from db2.t2 where a = 20", }, rules) } @@ -687,11 +684,22 @@ func TestMergeRewriteRules(t *testing.T) { right := "select a from db1.t1 where a = 2" merged, err := mergeRewriteRules(ctx, left, right) require.NoError(t, err) - require.Equal(t, "(select a from db1.t1 where a = 1) union distinct (select a from db1.t1 where a = 2)", merged) + require.Equal(t, "select a from db1.t1 where (a = 1) or (a = 2)", merged) + + // Role rule merging is a base-row visibility union, not UNION DISTINCT over + // projected values. Partial projections are intentionally OR-merged so two + // visible rows with the same projected value are not collapsed here. + rowUnionMerged, err := mergeRewriteRules( + ctx, + "select a from db1.t1 where role_marker = 1", + "select a from db1.t1 where role_marker = 2", + ) + require.NoError(t, err) + require.Equal(t, "select a from db1.t1 where (role_marker = 1) or (role_marker = 2)", rowUnionMerged) merged, err = mergeRewriteRules(ctx, merged, "select a from db1.t1 where a = 3") require.NoError(t, err) - require.Equal(t, "((select a from db1.t1 where a = 1) union distinct (select a from db1.t1 where a = 2)) union distinct (select a from db1.t1 where a = 3)", merged) + require.Equal(t, "select a from db1.t1 where ((a = 1) or (a = 2)) or (a = 3)", merged) merged, err = mergeRewriteRules(ctx, "select a, age from db1.t1 where age > 28", "select a from db1.t1 where a = 2") require.NoError(t, err) @@ -703,14 +711,14 @@ func TestMergeRewriteRules(t *testing.T) { merged, err = mergeRewriteRules(ctx, "select * from db1.t1 where age > 28", "select * from db1.t1 where age < 3") require.NoError(t, err) require.Equal(t, - "select * from db1.t1 where age < 3", + "select * from db1.t1 where (age > 28) or (age < 3)", merged, ) merged, err = mergeRewriteRules(ctx, "select t.* from db1.t1 as t where age > 28", "select t.* from db1.t1 as t where age < 3") require.NoError(t, err) require.Equal(t, - "select t.* from db1.t1 as t where age < 3", + "select t.* from db1.t1 as t where (age > 28) or (age < 3)", merged, ) @@ -775,11 +783,11 @@ func TestMergeRewriteRules(t *testing.T) { merged, err = mergeRewriteRules(ctx, "select a as x from db1.t1 where age > 28", "select a as x from db1.t1 where age < 3") require.NoError(t, err) - require.Equal(t, "(select a as x from db1.t1 where age > 28) union distinct (select a as x from db1.t1 where age < 3)", merged) + require.Equal(t, "select a as x from db1.t1 where (age > 28) or (age < 3)", merged) merged, err = mergeRewriteRules(ctx, "select a + 1 as b from db1.t1 where age > 28", "select a + 1 as b from db1.t1 where age < 3") require.NoError(t, err) - require.Equal(t, "(select a + 1 as b from db1.t1 where age > 28) union distinct (select a + 1 as b from db1.t1 where age < 3)", merged) + require.Equal(t, "select a + 1 as b from db1.t1 where (age > 28) or (age < 3)", merged) _, err = mergeRewriteRules(ctx, "select a from", "select a from db1.t1") require.Error(t, err) @@ -829,303 +837,133 @@ func TestMergeRewriteRulesFallbackWhenEitherSideIsUnmergeable(t *testing.T) { } } -func TestRewriteRuleOutputColumns(t *testing.T) { - ctx := context.Background() - - columns, ok, err := rewriteRuleOutputColumns(ctx, "select A as x, a + 1 as b from db1.t1") - require.NoError(t, err) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{ - {name: "x", expr: "a"}, - {name: "b", expr: "a + 1"}, - }, columns) - - cases := []struct { - name string - rule string - }{ - {name: "top-level order by", rule: "select a from db1.t1 order by a"}, - {name: "top-level limit", rule: "select a from db1.t1 limit 1"}, - {name: "union order by", rule: "select a from db1.t1 union select a from db1.t2 order by a"}, - {name: "distinct", rule: "select distinct a from db1.t1"}, - {name: "group by", rule: "select a from db1.t1 group by a"}, - {name: "having", rule: "select a from db1.t1 having a > 1"}, - {name: "star", rule: "select * from db1.t1"}, - {name: "qualified star", rule: "select t.* from db1.t1 as t"}, - {name: "aggregate projection", rule: "select sum(a) as s from db1.t1"}, - {name: "window projection", rule: "select row_number() over () as rn from db1.t1"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - columns, ok, err := rewriteRuleOutputColumns(ctx, tc.rule) - require.NoError(t, err) - require.False(t, ok) - require.Nil(t, columns) - }) - } -} - -func TestTrimRewriteRuleForUnion(t *testing.T) { - require.Equal(t, "select a from db1.t1", trimRewriteRuleForUnion(" select a from db1.t1 ;; ; ")) - require.Equal(t, "select a from db1.t1", trimRewriteRuleForUnion("select a from db1.t1")) - require.Equal(t, "", trimRewriteRuleForUnion(" ; ; ")) -} - -func TestOutputColumnsFromRewriteStatementASTBranches(t *testing.T) { - selectClause := &tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - } - - columns, ok := outputColumnsFromRewriteStatement(&tree.ParenSelect{ - Select: &tree.Select{Select: selectClause}, - }) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteStatement(&tree.ParenSelect{}) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteStatement(&tree.Delete{}) - require.False(t, ok) - require.Nil(t, columns) +func TestTrimRewriteRuleForMerge(t *testing.T) { + require.Equal(t, "select a from db1.t1", trimRewriteRuleForMerge(" select a from db1.t1 ;; ; ")) + require.Equal(t, "select a from db1.t1", trimRewriteRuleForMerge("select a from db1.t1")) + require.Equal(t, "", trimRewriteRuleForMerge(" ; ; ")) } -func TestOutputColumnsFromRewriteSelectStatementASTBranches(t *testing.T) { - selectClause := &tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - } - - columns, ok := outputColumnsFromRewriteSelectStatement(&tree.UnionClause{ - Left: selectClause, - Right: &tree.SelectClause{Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("b")}}}, - }) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.ParenSelect{}) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.Select{ - Select: selectClause, - OrderBy: tree.OrderBy{&tree.Order{Expr: tree.NewUnresolvedColName("a")}}, - }) - require.False(t, ok) - require.Nil(t, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.Select{Select: selectClause}) - require.True(t, ok) - require.Equal(t, []rewriteRuleOutputColumn{{name: "a", expr: "a"}}, columns) - - columns, ok = outputColumnsFromRewriteSelectStatement(&tree.ValuesStatement{}) - require.False(t, ok) - require.Nil(t, columns) -} - -func TestOutputColumnFromRewriteSelectExprBranches(t *testing.T) { - column, ok := outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewBinaryExpr(tree.PLUS, tree.NewUnresolvedColName("a"), rewriteRuleNumVal(1)), - }) - require.True(t, ok) - require.Equal(t, rewriteRuleOutputColumn{name: "a + 1", expr: "a + 1"}, column) - - column, ok = outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewUnresolvedColName("A"), - As: tree.NewCStr("AliasA", 1), - }) - require.True(t, ok) - require.Equal(t, rewriteRuleOutputColumn{name: "aliasa", expr: "a"}, column) - - column, ok = outputColumnFromRewriteSelectExpr(tree.SelectExpr{ - Expr: tree.NewUnresolvedNameWithStar(tree.NewCStr("t", 1)), - }) - require.False(t, ok) - require.Equal(t, rewriteRuleOutputColumn{}, column) -} - -func TestMergeableRewriteSelectClauseRejectsDistinctOptions(t *testing.T) { - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Distinct: true})) - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Option: tree.QuerySpecOptionDistinct})) - require.False(t, mergeableRewriteSelectClause(&tree.SelectClause{Option: tree.QuerySpecOptionDistinctRow})) -} - -func TestSameRewriteOutputColumnsComparesNamesAndExpressions(t *testing.T) { +func TestRewriteRuleMergeShapeForRule(t *testing.T) { cases := []struct { - name string - left string - right string - same bool + name string + rule string + ok bool + selectList string + table string }{ { - name: "same aliases and same column expression", - left: "select A as x, age as y from db1.t1", - right: "select a as x, age as y from db1.t1", - same: true, + name: "simple select", + rule: "select A, Age from db1.t1 where age > 28", + ok: true, + selectList: "a, age", + table: "db1.t1", + }, + { + name: "star projection", + rule: "select * from db1.t1 where age > 28", + ok: true, + selectList: "*", + table: "db1.t1", + }, + { + name: "qualified star with alias", + rule: "select t.* from db1.t1 as t where age > 28", + ok: true, + selectList: "t.*", + table: "db1.t1 as t", + }, + { + name: "aliased column", + rule: "select a as x from db1.t1 where age > 28", + ok: true, + selectList: "a as x", + table: "db1.t1", + }, + { + name: "where subquery with order by and limit", + rule: "select a from db1.t1 where a in (select a from db1.t2 order by a limit 1)", + ok: true, + selectList: "a", + table: "db1.t1", + }, + { + name: "parenthesized single table", + rule: "select a from (db1.t1) where age > 28", + ok: true, + selectList: "a", + table: "db1.t1", }, { - name: "same aliases but swapped column expressions", - left: "select a as x, age as y from db1.t1", - right: "select age as x, a as y from db1.t1", - same: false, + name: "scalar expression", + rule: "select a + 1 from db1.t1 where age > 28", + ok: true, + selectList: "a + 1", + table: "db1.t1", }, { - name: "same aliases but different scalar expressions", - left: "select a + 1 as x from db1.t1", - right: "select a + 2 as x from db1.t1", - same: false, + name: "aggregate", + rule: "select count(*) from db1.t1 where age > 28", + ok: false, }, { - name: "same expressions but different aliases", - left: "select a + 1 as x from db1.t1", - right: "select a + 1 as y from db1.t1", - same: false, + name: "window", + rule: "select row_number() over () from db1.t1 where age > 28", + ok: false, + }, + { + name: "top-level order by", + rule: "select a from db1.t1 where age > 28 order by a", + ok: false, + }, + { + name: "top-level limit", + rule: "select a from db1.t1 where age > 28 limit 1", + ok: false, + }, + { + name: "distinct", + rule: "select distinct a from db1.t1 where age > 28", + ok: false, + }, + { + name: "group by", + rule: "select a from db1.t1 where age > 28 group by a", + ok: false, + }, + { + name: "having", + rule: "select a from db1.t1 where age > 28 having a > 1", + ok: false, + }, + { + name: "union", + rule: "select a from db1.t1 where age > 28 union all select a from db1.t1 where age < 3", + ok: false, + }, + { + name: "join", + rule: "select t1.a from db1.t1 join db1.t2 on t1.a = t2.a where t1.age > 28", + ok: false, }, } ctx := context.Background() for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - left, ok, err := rewriteRuleOutputColumns(ctx, tc.left) + shape, ok, err := rewriteRuleMergeShapeForRule(ctx, tc.rule) require.NoError(t, err) - require.True(t, ok) - right, ok, err := rewriteRuleOutputColumns(ctx, tc.right) - require.NoError(t, err) - require.True(t, ok) - require.Equal(t, tc.same, sameRewriteOutputColumns(left, right)) - }) - } -} - -func TestRewriteExprIsMergeSafe(t *testing.T) { - cases := []struct { - name string - sql string - safe bool - }{ - {name: "simple column", sql: "select a from db1.t1", safe: true}, - {name: "binary expression", sql: "select a + age from db1.t1", safe: true}, - {name: "comparison expression", sql: "select a > age from db1.t1", safe: true}, - {name: "case expression", sql: "select case when a > 0 then age else 0 end from db1.t1", safe: true}, - {name: "scalar function", sql: "select abs(a) from db1.t1", safe: true}, - {name: "cast expression", sql: "select cast(a as signed) from db1.t1", safe: true}, - {name: "between expression", sql: "select a between 1 and 3 from db1.t1", safe: true}, - {name: "aggregate function", sql: "select count(*) from db1.t1", safe: false}, - {name: "window function", sql: "select row_number() over () from db1.t1", safe: false}, - {name: "subquery expression", sql: "select exists (select a from db1.t1) from db1.t1", safe: false}, - {name: "system variable expression", sql: "select @@sql_mode from db1.t1", safe: false}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - exprs := parseRewriteRuleSelectExprs(t, tc.sql) - require.Len(t, exprs, 1) - require.Equal(t, tc.safe, rewriteExprIsMergeSafe(exprs[0].Expr)) - }) - } -} - -func TestRewriteExprIsMergeSafeASTBranches(t *testing.T) { - safeColumn := tree.NewUnresolvedColName("a") - safeNumber := rewriteRuleNumVal(1) - unsafeExpr := rewriteRuleUnsafeSubqueryExpr() - - cases := []struct { - name string - expr tree.Expr - safe bool - }{ - {name: "nil expression", expr: nil, safe: true}, - {name: "string literal", expr: tree.NewStrVal("'x'"), safe: true}, - {name: "unqualified star", expr: tree.UnqualifiedStar{}, safe: true}, - {name: "unary expression", expr: tree.NewUnaryExpr(tree.UNARY_MINUS, safeNumber), safe: true}, - {name: "comparison with escape", expr: tree.NewComparisonExprWithEscape(tree.LIKE, safeColumn, tree.NewStrVal("'a%'"), tree.NewStrVal("'\\'")), safe: true}, - {name: "and expression", expr: tree.NewAndExpr(safeColumn, safeNumber), safe: true}, - {name: "xor expression", expr: tree.NewXorExpr(safeColumn, safeNumber), safe: true}, - {name: "or expression", expr: tree.NewOrExpr(safeColumn, safeNumber), safe: true}, - {name: "not expression", expr: tree.NewNotExpr(safeColumn), safe: true}, - {name: "is null expression", expr: tree.NewIsNullExpr(safeColumn), safe: true}, - {name: "is not null expression", expr: tree.NewIsNotNullExpr(safeColumn), safe: true}, - {name: "is unknown expression", expr: tree.NewIsUnknownExpr(safeColumn), safe: true}, - {name: "is not unknown expression", expr: tree.NewIsNotUnknownExpr(safeColumn), safe: true}, - {name: "is true expression", expr: tree.NewIsTrueExpr(safeColumn), safe: true}, - {name: "is not true expression", expr: tree.NewIsNotTrueExpr(safeColumn), safe: true}, - {name: "is false expression", expr: tree.NewIsFalseExpr(safeColumn), safe: true}, - {name: "is not false expression", expr: tree.NewIsNotFalseExpr(safeColumn), safe: true}, - {name: "paren expression", expr: tree.NewParentExpr(safeColumn), safe: true}, - {name: "function with table type", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_TABLE, nil, safeColumn), safe: false}, - {name: "function without a resolvable name", expr: &tree.FuncExpr{}, safe: false}, - {name: "function with unsafe argument", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_DEFAULT, nil, unsafeExpr), safe: false}, - {name: "function with unsafe order by", expr: rewriteRuleFuncExpr("abs", tree.FUNC_TYPE_DEFAULT, tree.OrderBy{&tree.Order{Expr: unsafeExpr}}, safeColumn), safe: false}, - {name: "serial extract expression", expr: &tree.SerialExtractExpr{SerialExpr: safeColumn, IndexExpr: safeNumber}, safe: true}, - {name: "serial extract with unsafe index", expr: &tree.SerialExtractExpr{SerialExpr: safeColumn, IndexExpr: unsafeExpr}, safe: false}, - {name: "bit cast expression", expr: &tree.BitCastExpr{Expr: safeColumn}, safe: true}, - {name: "tuple expression", expr: &tree.Tuple{Exprs: tree.Exprs{safeColumn, safeNumber}}, safe: true}, - {name: "tuple with unsafe member", expr: &tree.Tuple{Exprs: tree.Exprs{safeColumn, unsafeExpr}}, safe: false}, - {name: "case with unsafe else", expr: &tree.CaseExpr{Expr: safeColumn, Else: unsafeExpr}, safe: false}, - {name: "case with nil when", expr: &tree.CaseExpr{Whens: []*tree.When{nil, tree.NewWhen(safeColumn, safeNumber)}, Else: safeNumber}, safe: true}, - {name: "case with unsafe when", expr: &tree.CaseExpr{Whens: []*tree.When{tree.NewWhen(unsafeExpr, safeNumber)}, Else: safeNumber}, safe: false}, - {name: "interval expression", expr: &tree.IntervalExpr{Expr: safeNumber}, safe: true}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.safe, rewriteExprIsMergeSafe(tc.expr)) + require.Equal(t, tc.ok, ok) + if !tc.ok { + require.Nil(t, shape) + return + } + require.NotNil(t, shape) + require.Equal(t, tc.selectList, shape.selectList) + require.Equal(t, tc.table, shape.table) }) } } -func TestRewriteExprsAndOrderByAreMergeSafe(t *testing.T) { - safeColumn := tree.NewUnresolvedColName("a") - unsafeExpr := rewriteRuleUnsafeSubqueryExpr() - - require.True(t, rewriteExprsAreMergeSafe(nil)) - require.True(t, rewriteExprsAreMergeSafe(tree.Exprs{safeColumn})) - require.False(t, rewriteExprsAreMergeSafe(tree.Exprs{safeColumn, unsafeExpr})) - - require.True(t, rewriteOrderByIsMergeSafe(tree.OrderBy{nil, &tree.Order{Expr: safeColumn}})) - require.False(t, rewriteOrderByIsMergeSafe(tree.OrderBy{&tree.Order{Expr: unsafeExpr}})) -} - -func TestRewriteFuncExprName(t *testing.T) { - require.Equal(t, "abs", rewriteFuncExprName(&tree.FuncExpr{FuncName: tree.NewCStr("ABS", 1)})) - require.Equal(t, "lower", rewriteFuncExprName(&tree.FuncExpr{ - Func: tree.FuncName2ResolvableFunctionReference(tree.NewUnresolvedColName("LOWER")), - })) - require.Equal(t, "", rewriteFuncExprName(&tree.FuncExpr{})) -} - -func parseRewriteRuleSelectExprs(t *testing.T, sql string) tree.SelectExprs { - t.Helper() - - stmt, err := parsers.ParseOne(context.Background(), dialect.MYSQL, sql, 1) - require.NoError(t, err) - selectStmt, ok := stmt.(*tree.Select) - require.True(t, ok) - selectClause, ok := selectStmt.Select.(*tree.SelectClause) - require.True(t, ok) - return selectClause.Exprs -} - -func rewriteRuleNumVal(v int64) *tree.NumVal { - return tree.NewNumVal[int64](v, strconv.FormatInt(v, 10), false, tree.P_int64) -} - -func rewriteRuleFuncExpr(name string, typ tree.FuncType, orderBy tree.OrderBy, exprs ...tree.Expr) *tree.FuncExpr { - return &tree.FuncExpr{ - FuncName: tree.NewCStr(name, 1), - Type: typ, - Exprs: exprs, - OrderBy: orderBy, - } -} - -func rewriteRuleUnsafeSubqueryExpr() tree.Expr { - return tree.NewSubquery(&tree.SelectClause{ - Exprs: tree.SelectExprs{{Expr: tree.NewUnresolvedColName("a")}}, - }, true) -} - func newMrsForRewriteRules(rows [][]interface{}) *MysqlResultSet { mrs := &MysqlResultSet{} diff --git a/test/distributed/cases/tenant/privilege/owner1.result b/test/distributed/cases/tenant/privilege/owner1.result index 166d5744a97a6..67868289f3996 100644 --- a/test/distributed/cases/tenant/privilege/owner1.result +++ b/test/distributed/cases/tenant/privilege/owner1.result @@ -14,9 +14,7 @@ set secondary role all; create table db1.t3(a int); set role role2; drop table db1.t3; -internal error: do not have privilege to execute the statement drop table db1.t2; -drop table db1.t3; drop database db1; drop account default_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/tenant/privilege/owner1.sql b/test/distributed/cases/tenant/privilege/owner1.sql index de115b62dab68..2bd31cbc3da3b 100644 --- a/test/distributed/cases/tenant/privilege/owner1.sql +++ b/test/distributed/cases/tenant/privilege/owner1.sql @@ -24,8 +24,7 @@ drop table db1.t3; drop table db1.t2; -- @session -- @session:id=5&user=default_1:user1:role1&password=123456 -drop table db1.t3; drop database db1; -- @session drop account default_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.result b/test/distributed/cases/zz_accesscontrol/role_rule.result index 0eea49af706a2..751178ae44650 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.result +++ b/test/distributed/cases/zz_accesscontrol/role_rule.result @@ -5,6 +5,9 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -17,11 +20,19 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; create database db1; create table db1.t1(a int, age int); insert into db1.t1 values (1,1),(2,2),(100,30); +create table db1.t_dup(a int, marker int); +insert into db1.t_dup values (1,1),(1,2),(2,3); create role test_rule_role; alter role test_rule_role add rule "select * from db1.t1 where age > 28" on table db1.t1; show rules on role test_rule_role; @@ -64,8 +75,8 @@ create table db2.t2(a int, age int); insert into db2.t2 values (10,10),(20,35),(200,60); create role test_rule_role_multi_a; create role test_rule_role_multi_b; -alter role test_rule_role_multi_a add rule "select a, age from db1.t1 where age > 28" on table db1.t1; -alter role test_rule_role_multi_b add rule "select a, age from db1.t1 where age < 3" on table db1.t1; +alter role test_rule_role_multi_a add rule "select * from db1.t1 where age > 1" on table db1.t1; +alter role test_rule_role_multi_b add rule "select * from db1.t1 where age < 3" on table db1.t1; alter role test_rule_role_multi_b add rule "select * from db2.t2 where age > 30" on table db2.t2; create user test_rule_user_multi identified by '123456' default role test_rule_role_multi_a; grant connect on account * to test_rule_role_multi_a; @@ -76,6 +87,7 @@ grant test_rule_role_multi_b to test_rule_user_multi; set enable_remap_hint = 1; select * from db1.t1 order by a; a age +2 2 100 30 set secondary role all; select * from db1.t1 order by a; @@ -90,6 +102,7 @@ a age set secondary role none; select * from db1.t1 order by a; a age +2 2 100 30 create role test_rule_role_multi_c; create role test_rule_role_multi_d; @@ -153,12 +166,65 @@ set secondary role all; select * from db1.t1 order by a; a age 30 100 +create role test_rule_role_where_a; +create role test_rule_role_where_b; +alter role test_rule_role_where_a add rule "select * from db1.t1 where age > 28" on table db1.t1; +alter role test_rule_role_where_b add rule "select * from db1.t1" on table db1.t1; +create user test_rule_user_where identified by '123456' default role test_rule_role_where_a; +grant connect on account * to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_b; +grant test_rule_role_where_b to test_rule_user_where; +set enable_remap_hint = 1; +set secondary role all; +select * from db1.t1 order by a; +a age +1 1 +2 2 +100 30 +create role test_rule_role_dup_projection_a; +create role test_rule_role_dup_projection_b; +alter role test_rule_role_dup_projection_a add rule "select a from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_dup_projection_b add rule "select a from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_dup_projection identified by '123456' default role test_rule_role_dup_projection_a; +grant connect on account * to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_b; +grant test_rule_role_dup_projection_b to test_rule_user_dup_projection; +set enable_remap_hint = 1; +select count(*) from db1.t_dup; +count(*) +1 +set secondary role all; +select a, count(*) from db1.t_dup group by a order by a; +a count(*) +1 2 +select count(*) from db1.t_dup; +count(*) +2 +create role test_rule_role_expr_projection_a; +create role test_rule_role_expr_projection_b; +alter role test_rule_role_expr_projection_a add rule "select a + 1 as b from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_expr_projection_b add rule "select a + 1 as b from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_expr_projection identified by '123456' default role test_rule_role_expr_projection_a; +grant connect on account * to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_b; +grant test_rule_role_expr_projection_b to test_rule_user_expr_projection; +set enable_remap_hint = 1; +set secondary role all; +select b, count(*) from db1.t_dup group by b order by b; +b count(*) +2 2 drop user if exists test_rule_user; drop user if exists test_rule_user_multi; drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -171,6 +237,12 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/role_rule.sql b/test/distributed/cases/zz_accesscontrol/role_rule.sql index 05ef4feef105b..12a997e951def 100644 --- a/test/distributed/cases/zz_accesscontrol/role_rule.sql +++ b/test/distributed/cases/zz_accesscontrol/role_rule.sql @@ -6,6 +6,9 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -18,11 +21,19 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; create database db1; create table db1.t1(a int, age int); insert into db1.t1 values (1,1),(2,2),(100,30); +create table db1.t_dup(a int, marker int); +insert into db1.t_dup values (1,1),(1,2),(2,3); -- 1. ADD RULE normal case + SHOW RULES verification create role test_rule_role; @@ -67,14 +78,14 @@ set enable_remap_hint = 1; select * from db1.t1; -- @session --- 11. SET SECONDARY ROLE ALL applies rewrite rules from all active roles +-- 11. SET SECONDARY ROLE ALL merges select * rewrite rules from all active roles create database db2; create table db2.t2(a int, age int); insert into db2.t2 values (10,10),(20,35),(200,60); create role test_rule_role_multi_a; create role test_rule_role_multi_b; -alter role test_rule_role_multi_a add rule "select a, age from db1.t1 where age > 28" on table db1.t1; -alter role test_rule_role_multi_b add rule "select a, age from db1.t1 where age < 3" on table db1.t1; +alter role test_rule_role_multi_a add rule "select * from db1.t1 where age > 1" on table db1.t1; +alter role test_rule_role_multi_b add rule "select * from db1.t1 where age < 3" on table db1.t1; alter role test_rule_role_multi_b add rule "select * from db2.t2 where age > 30" on table db2.t2; create user test_rule_user_multi identified by '123456' default role test_rule_role_multi_a; grant connect on account * to test_rule_role_multi_a; @@ -159,6 +170,56 @@ set secondary role all; select * from db1.t1 order by a; -- @session +-- 17. Merging with a no-WHERE rule produces the broader rule +create role test_rule_role_where_a; +create role test_rule_role_where_b; +alter role test_rule_role_where_a add rule "select * from db1.t1 where age > 28" on table db1.t1; +alter role test_rule_role_where_b add rule "select * from db1.t1" on table db1.t1; +create user test_rule_user_where identified by '123456' default role test_rule_role_where_a; +grant connect on account * to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_a; +grant select on table db1.t1 to test_rule_role_where_b; +grant test_rule_role_where_b to test_rule_user_where; +-- @session:id=7&user=sys:test_rule_user_where:test_rule_role_where_a&password=123456 +set enable_remap_hint = 1; +set secondary role all; +select * from db1.t1 order by a; +-- @session + +-- 18. Partial projections keep duplicate projected values from distinct visible rows +create role test_rule_role_dup_projection_a; +create role test_rule_role_dup_projection_b; +alter role test_rule_role_dup_projection_a add rule "select a from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_dup_projection_b add rule "select a from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_dup_projection identified by '123456' default role test_rule_role_dup_projection_a; +grant connect on account * to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_a; +grant select on table db1.t_dup to test_rule_role_dup_projection_b; +grant test_rule_role_dup_projection_b to test_rule_user_dup_projection; +-- @session:id=8&user=sys:test_rule_user_dup_projection:test_rule_role_dup_projection_a&password=123456 +set enable_remap_hint = 1; +select count(*) from db1.t_dup; +set secondary role all; +select a, count(*) from db1.t_dup group by a order by a; +select count(*) from db1.t_dup; +-- @session + +-- 19. Row-wise expression projections also merge without dropping role visibility +create role test_rule_role_expr_projection_a; +create role test_rule_role_expr_projection_b; +alter role test_rule_role_expr_projection_a add rule "select a + 1 as b from db1.t_dup where marker = 1" on table db1.t_dup; +alter role test_rule_role_expr_projection_b add rule "select a + 1 as b from db1.t_dup where marker = 2" on table db1.t_dup; +create user test_rule_user_expr_projection identified by '123456' default role test_rule_role_expr_projection_a; +grant connect on account * to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_a; +grant select on table db1.t_dup to test_rule_role_expr_projection_b; +grant test_rule_role_expr_projection_b to test_rule_user_expr_projection; +-- @session:id=9&user=sys:test_rule_user_expr_projection:test_rule_role_expr_projection_a&password=123456 +set enable_remap_hint = 1; +set secondary role all; +select b, count(*) from db1.t_dup group by b order by b; +-- @session + -- cleanup all test resources drop user if exists test_rule_user; drop user if exists test_rule_user_multi; @@ -166,6 +227,9 @@ drop user if exists test_rule_user_multi_diff; drop user if exists test_rule_user_inherit; drop user if exists test_rule_user_unmergeable; drop user if exists test_rule_user_alias; +drop user if exists test_rule_user_where; +drop user if exists test_rule_user_dup_projection; +drop user if exists test_rule_user_expr_projection; drop role if exists test_rule_role; drop role if exists test_rule_role_multi_a; drop role if exists test_rule_role_multi_b; @@ -178,6 +242,12 @@ drop role if exists test_rule_role_unmergeable_a; drop role if exists test_rule_role_unmergeable_b; drop role if exists test_rule_role_alias_a; drop role if exists test_rule_role_alias_b; +drop role if exists test_rule_role_where_a; +drop role if exists test_rule_role_where_b; +drop role if exists test_rule_role_dup_projection_a; +drop role if exists test_rule_role_dup_projection_b; +drop role if exists test_rule_role_expr_projection_a; +drop role if exists test_rule_role_expr_projection_b; drop database if exists db1; drop database if exists db2; set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/use_role.result b/test/distributed/cases/zz_accesscontrol/use_role.result index 4f3c98e735742..5c66b601905dc 100644 --- a/test/distributed/cases/zz_accesscontrol/use_role.result +++ b/test/distributed/cases/zz_accesscontrol/use_role.result @@ -48,12 +48,12 @@ internal error: do not have privilege to execute the statement drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement set secondary role all; -create table use_db_1.use_table_5(a int,b varchar(20),c double ); -insert into use_db_1.use_table_5 values(10,'yellow',99.99); -select * from use_db_1.use_table_5; +insert into use_db_1.use_table_2 values(20,'green',88.88); +select * from use_db_1.use_table_2 order by a; a b c 10 yellow 99.99 -drop table use_db_1.use_table_5; +20 green 88.88 +34 kelly 90.3 set role moadmin; internal error: the role moadmin has not be granted to the user use_user_1 create user use_user_2 identified by '123456'; @@ -72,15 +72,21 @@ internal error: there is no role use_role_test drop role use_role_test; internal error: do not have privilege to execute the statement set secondary role all; -create table use_db_1.use_table_6(a int,b varchar(20),c double); -insert into use_db_1.use_table_6 values(10,'yellow',99.99); +insert into use_db_1.use_table_2 values(30,'blue',77.77); +select * from use_db_1.use_table_2 order by a; +a b c +10 yellow 99.99 +20 green 88.88 +30 blue 77.77 +34 kelly 90.3 create database use_db_test; internal error: do not have privilege to execute the statement drop database use_db_test; internal error: do not have privilege to execute the statement set secondary role none; -insert into use_db_1.use_table_6 values (10, 'yellow', 99.99); -drop table use_db_1.use_table_6; +insert into use_db_1.use_table_2 values (40, 'red', 66.66); +internal error: do not have privilege to execute the statement +drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement create role if not exists use_role_test; internal error: do not have privilege to execute the statement @@ -88,11 +94,11 @@ drop role use_role_test; internal error: do not have privilege to execute the statement set role use_role_5; internal error: the role use_role_5 has not be granted to the user use_user_1 -drop table use_db_1.use_table_6; +drop table use_db_1.use_table_2; internal error: do not have privilege to execute the statement create database if not exists use_db_test; internal error: do not have privilege to execute the statement drop role if exists use_role_1,use_role_2,use_role_3,use_role_4,use_role_5; drop user if exists use_user_1,use_user_2; drop database if exists use_db_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on; diff --git a/test/distributed/cases/zz_accesscontrol/use_role.sql b/test/distributed/cases/zz_accesscontrol/use_role.sql index 88f5812350567..cbf3fa89072f6 100644 --- a/test/distributed/cases/zz_accesscontrol/use_role.sql +++ b/test/distributed/cases/zz_accesscontrol/use_role.sql @@ -37,10 +37,8 @@ insert into use_db_1.use_table_2 values(10,'yellow',99.99); select * from use_db_1.use_table_2; drop table use_db_1.use_table_2; set secondary role all; -create table use_db_1.use_table_5(a int,b varchar(20),c double ); -insert into use_db_1.use_table_5 values(10,'yellow',99.99); -select * from use_db_1.use_table_5; -drop table use_db_1.use_table_5; +insert into use_db_1.use_table_2 values(20,'green',88.88); +select * from use_db_1.use_table_2 order by a; set role moadmin; -- @session @@ -57,21 +55,21 @@ create role use_role_test; set role use_role_test; drop role use_role_test; set secondary role all; -create table use_db_1.use_table_6(a int,b varchar(20),c double); -insert into use_db_1.use_table_6 values(10,'yellow',99.99); +insert into use_db_1.use_table_2 values(30,'blue',77.77); +select * from use_db_1.use_table_2 order by a; create database use_db_test; drop database use_db_test; set secondary role none; -insert into use_db_1.use_table_6 values (10, 'yellow', 99.99); -drop table use_db_1.use_table_6; +insert into use_db_1.use_table_2 values (40, 'red', 66.66); +drop table use_db_1.use_table_2; create role if not exists use_role_test; drop role use_role_test; set role use_role_5; -drop table use_db_1.use_table_6; +drop table use_db_1.use_table_2; create database if not exists use_db_test; -- @session drop role if exists use_role_1,use_role_2,use_role_3,use_role_4,use_role_5; drop user if exists use_user_1,use_user_2; drop database if exists use_db_1; -set global enable_privilege_cache = on; \ No newline at end of file +set global enable_privilege_cache = on;