Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 4 additions & 28 deletions pkg/frontend/authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3498,22 +3498,16 @@ func doAlterAccount(ctx context.Context, ses *Session, aa *alterAccount) (err er
return err
}

// doSetSecondaryRoleAll set the session role of the user with smallness role_id
// doSetSecondaryRoleAll validates user role metadata before enabling all secondary roles.
// The current primary role must not change; SET SECONDARY ROLE ALL only affects secondary roles.
func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) {
var sql string
var userId uint32
var erArray []ExecResult
var roleId int64
var roleName string

account := ses.GetTenantInfo()
// get current user_id
userId = account.GetUserID()

// init role_id and role_name
roleId = publicRoleID
roleName = publicRoleName

// step1:get all roles expect public
bh := ses.GetBackgroundExec(ctx)
defer bh.Close()
Expand All @@ -3533,26 +3527,7 @@ func doSetSecondaryRoleAll(ctx context.Context, ses *Session) (err error) {
return err
}

erArray, err = getResultSet(ctx, bh)
if err != nil {
return err
}
if execResultArrayHasData(erArray) {
roleId, err = erArray[0].GetInt64(ctx, 0, 0)
if err != nil {
return err
}

roleName, err = erArray[0].GetString(ctx, 0, 1)
if err != nil {
return err
}
}

// step2 : switch the default role and role id;
account.SetDefaultRoleID(uint32(roleId))
account.SetDefaultRole(roleName)

_, err = getResultSet(ctx, bh)
return err
}

Expand Down Expand Up @@ -3649,6 +3624,7 @@ func doSwitchRole(ctx context.Context, ses *Session, sr *tree.SetRole) (err erro
account.SetDefaultRole(sr.Role.UserName)
// then, reset secondary role to none
account.SetUseSecondaryRole(false)
ses.InvalidatePrivilegeCache()

return err
}
Expand Down
56 changes: 56 additions & 0 deletions pkg/frontend/authenticate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8185,6 +8185,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) {

err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses)
convey.So(err, convey.ShouldBeNil)
convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5))
convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1")
})

convey.Convey("do set secondary role succ", t, func() {
Expand Down Expand Up @@ -8224,6 +8226,8 @@ func TestDoSetSecondaryRoleAll(t *testing.T) {

err := doSetSecondaryRoleAll(ses.GetTxnHandler().GetTxnCtx(), ses)
convey.So(err, convey.ShouldBeNil)
convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5))
convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1")
})
}

Expand Down Expand Up @@ -8263,6 +8267,8 @@ func TestDoSwitchRoleSecondaryRoleAllInvalidatesRuleCache(t *testing.T) {
})
convey.So(err, convey.ShouldBeNil)
convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeTrue)
convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(5))
convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role1")

ses.ruleCacheMu.RLock()
cacheIsNil := ses.ruleCache == nil
Expand Down Expand Up @@ -8337,6 +8343,56 @@ func TestDoSwitchRoleSecondaryRoleNoneInvalidatesRuleCache(t *testing.T) {
})
}

func TestDoSwitchRolePrimaryRoleInvalidatesRuleCache(t *testing.T) {
convey.Convey("set role switches current role and invalidates rule cache", t, func() {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

bh := &backgroundExecTest{}
bh.init()

bhStub := gostub.StubFunc(&NewBackgroundExec, bh)
defer bhStub.Reset()

ses := newSes(&privilege{}, ctrl)
tenant := &TenantInfo{
Tenant: "test_account",
User: "test_user",
DefaultRole: "role1",
TenantID: 3001,
UserID: 3,
DefaultRoleID: 5,
}
tenant.SetUseSecondaryRole(true)
ses.SetTenantInfo(tenant)
ses.ruleCache = map[string]string{"db1.t1": "select a from db1.t1"}

bh.sql2result["begin;"] = nil
bh.sql2result["commit;"] = nil
bh.sql2result["rollback;"] = nil

roleSQL, err := getSqlForRoleIdOfRole(context.Background(), "role2")
convey.So(err, convey.ShouldBeNil)
bh.sql2result[roleSQL] = newMrsForRoleIdOfRole([][]interface{}{{int64(6)}})
bh.sql2result[getSqlForCheckUserGrant(6, int64(tenant.UserID))] = newMrsForCheckUserGrant([][]interface{}{
{int64(6), int64(tenant.UserID), false},
})

err = doSwitchRole(ses.GetTxnHandler().GetTxnCtx(), ses, &tree.SetRole{
Role: &tree.Role{UserName: "role2"},
})
convey.So(err, convey.ShouldBeNil)
convey.So(tenant.GetDefaultRoleID(), convey.ShouldEqual, uint32(6))
convey.So(tenant.GetDefaultRole(), convey.ShouldEqual, "role2")
convey.So(tenant.GetUseSecondaryRole(), convey.ShouldBeFalse)

ses.ruleCacheMu.RLock()
cacheIsNil := ses.ruleCache == nil
ses.ruleCacheMu.RUnlock()
convey.So(cacheIsNil, convey.ShouldBeTrue)
})
}

func TestGetSessionSysVar(t *testing.T) {
convey.Convey("get session system variable succ", t, func() {
ctrl := gomock.NewController(t)
Expand Down
Loading
Loading