From 88004441db5b157f9fd84871a2175ffd834cfd92 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Fri, 29 May 2026 15:28:44 -0700 Subject: [PATCH] implement alter default privileges --- postgres/parser/parser/sql.y | 8 +- postgres/parser/sem/tree/alter_index.go | 4 +- postgres/parser/sem/tree/alter_table.go | 8 +- server/ast/alter_default_privileges.go | 51 ++- server/auth/database.go | 11 + server/auth/default_privileges.go | 291 +++++++++++++++++ server/auth/serialization.go | 39 ++- server/node/alter_default_privileges.go | 168 ++++++++++ server/node/create_function.go | 12 + server/node/create_procedure.go | 12 + server/node/create_sequence.go | 12 + server/node/create_table.go | 43 ++- server/tables/pgcatalog/pg_default_acl.go | 36 ++- testing/go/auth_test.go | 367 ++++++++++++++++++++++ testing/go/import_dumps_test.go | 2 + 15 files changed, 1049 insertions(+), 15 deletions(-) create mode 100644 server/auth/default_privileges.go create mode 100644 server/node/alter_default_privileges.go diff --git a/postgres/parser/parser/sql.y b/postgres/parser/parser/sql.y index 6ace140439..16a9f770bb 100644 --- a/postgres/parser/parser/sql.y +++ b/postgres/parser/parser/sql.y @@ -2026,9 +2026,9 @@ alter_oneindex_stmt: { $$.val = &tree.AlterIndex{Index: $5.tableIndexName(), IfExists: true, Cmd: $6.alterIndexCmd()} } -| ALTER INDEX table_index_name ATTACH PARTITION index_name +| ALTER INDEX table_index_name ATTACH PARTITION db_object_name { - $$.val = &tree.AlterIndex{Index: $3.tableIndexName(), Cmd: &tree.AlterIndexAttachPartition{Index: tree.UnrestrictedName($6)}} + $$.val = &tree.AlterIndex{Index: $3.tableIndexName(), Cmd: &tree.AlterIndexAttachPartition{Index: $6.unresolvedObjectName()}} } | ALTER INDEX table_index_name opt_no DEPENDS ON EXTENSION name { @@ -2128,9 +2128,9 @@ alter_table_action: } } // ALTER TABLE ADD CONSTRAINT ... USING INDEX -| ADD CONSTRAINT constraint_name unique_or_primary USING INDEX index_name opt_deferrable_mode opt_initially +| ADD CONSTRAINT constraint_name unique_or_primary USING INDEX db_object_name opt_deferrable_mode opt_initially { - $$.val = tree.AlterTableConstraintUsingIndex{Constraint: tree.Name($3), IsUnique: $4.bool(), Index: tree.Name($7), Deferrable: $8.deferrableMode(), Initially: $9.initiallyMode()} + $$.val = tree.AlterTableConstraintUsingIndex{Constraint: tree.Name($3), IsUnique: $4.bool(), Index: $7.unresolvedObjectName(), Deferrable: $8.deferrableMode(), Initially: $9.initiallyMode()} } // ALTER TABLE ALTER CONSTRAINT ... | ALTER CONSTRAINT constraint_name opt_deferrable_mode opt_initially diff --git a/postgres/parser/sem/tree/alter_index.go b/postgres/parser/sem/tree/alter_index.go index 34d4ecace3..4e88720f9a 100644 --- a/postgres/parser/sem/tree/alter_index.go +++ b/postgres/parser/sem/tree/alter_index.go @@ -92,13 +92,13 @@ var _ AlterIndexCmd = &AlterIndexSetTablespace{} // AlterIndexAttachPartition represents an ALTER INDEX ... ATTACH PARTITION statement. type AlterIndexAttachPartition struct { - Index UnrestrictedName + Index *UnresolvedObjectName } // Format implements the NodeFormatter interface. func (node *AlterIndexAttachPartition) Format(ctx *FmtCtx) { ctx.WriteString(" ATTACH PARTITION ") - ctx.FormatNode(&node.Index) + node.Index.Format(ctx) } // AlterIndexExtension represents an ALTER INDEX ... [NO] DEPENDS ON EXTENSION statement. diff --git a/postgres/parser/sem/tree/alter_table.go b/postgres/parser/sem/tree/alter_table.go index b97fcba067..857d4669d4 100644 --- a/postgres/parser/sem/tree/alter_table.go +++ b/postgres/parser/sem/tree/alter_table.go @@ -404,7 +404,7 @@ func (node *AlterTableComputed) GetColumn() Name { type AlterTableConstraintUsingIndex struct { Constraint Name IsUnique bool - Index Name + Index *UnresolvedObjectName Deferrable DeferrableMode Initially InitiallyMode } @@ -421,7 +421,7 @@ func (node *AlterTableConstraintUsingIndex) Format(ctx *FmtCtx) { ctx.WriteString(" PRIMARY KEY") } ctx.WriteString(" USING INDEX") - ctx.FormatNode(&node.Index) + node.Index.Format(ctx) switch node.Deferrable { case Deferrable: ctx.WriteString(" DEFERRABLE") @@ -966,7 +966,7 @@ func (node *AlterTablePartition) Format(ctx *FmtCtx) { node.Name.Format(ctx) if node.IsDetach { ctx.WriteString(" DETACH PARTITION ") - node.Name.Format(ctx) + node.Partition.Format(ctx) switch node.DetachType { case DetachPartitionNone: case DetachPartitionConcurrently: @@ -976,7 +976,7 @@ func (node *AlterTablePartition) Format(ctx *FmtCtx) { } } else { ctx.WriteString(" ATTACH PARTITION ") - node.Name.Format(ctx) + node.Partition.Format(ctx) ctx.WriteByte(' ') ctx.FormatNode(&node.Spec) } diff --git a/server/ast/alter_default_privileges.go b/server/ast/alter_default_privileges.go index 8d1cb74f39..8ac234f478 100644 --- a/server/ast/alter_default_privileges.go +++ b/server/ast/alter_default_privileges.go @@ -15,12 +15,61 @@ package ast import ( + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/postgres/parser/privilege" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + "github.com/dolthub/doltgresql/server/auth" + pgnodes "github.com/dolthub/doltgresql/server/node" ) // nodeAlterDefaultPrivileges handles *tree.AlterDefaultPrivileges nodes. func nodeAlterDefaultPrivileges(ctx *Context, node *tree.AlterDefaultPrivileges) (vitess.Statement, error) { - return NotYetSupportedError("ALTER DEFAULT PRIVILEGES statement is not yet supported") + if node == nil { + return nil, nil + } + + objType, err := convertDefaultPrivilegeObjectType(node.Target.TargetType) + if err != nil { + return nil, err + } + + privileges, err := convertPrivilegeKinds(objType, node.Privileges) + if err != nil { + return nil, err + } + + return vitess.InjectedStatement{ + Statement: &pgnodes.AlterDefaultPrivileges{ + OwnerRoles: node.TargetRoles, + Schemas: node.Target.InSchema, + ObjectType: objType, + Privileges: privileges, + Grantees: node.Grantees, + Grant: node.Grant, + GrantOption: node.GrantOption, + Cascade: node.DropBehavior == tree.DropCascade, + }, + Children: nil, + }, nil +} + +// convertDefaultPrivilegeObjectType converts a privilege.ObjectType to an auth.PrivilegeObject for use in default +// privileges. Only the object types valid for ALTER DEFAULT PRIVILEGES are accepted. +func convertDefaultPrivilegeObjectType(objType privilege.ObjectType) (auth.PrivilegeObject, error) { + switch objType { + case privilege.Table: + return auth.PrivilegeObject_TABLE, nil + case privilege.Sequence: + return auth.PrivilegeObject_SEQUENCE, nil + case privilege.Function, privilege.Procedure, privilege.Routine: + return auth.PrivilegeObject_FUNCTION, nil + case privilege.Schema: + return auth.PrivilegeObject_SCHEMA, nil + case privilege.Type: + return auth.PrivilegeObject_TYPE, nil + default: + return 0, errors.Errorf("object type %q is not supported in ALTER DEFAULT PRIVILEGES", string(objType)) + } } diff --git a/server/auth/database.go b/server/auth/database.go index 9e22fbdf7e..ec97cd65f0 100644 --- a/server/auth/database.go +++ b/server/auth/database.go @@ -45,6 +45,7 @@ type Database struct { sequencePrivileges *SequencePrivileges routinePrivileges *RoutinePrivileges roleMembership *RoleMembership + defaultPrivileges *DefaultPrivileges } // ClearDatabase clears the internal database, leaving only the default users. This is primarily for use by tests. @@ -57,6 +58,7 @@ func ClearDatabase() { clear(globalDatabase.sequencePrivileges.Data) clear(globalDatabase.routinePrivileges.Data) clear(globalDatabase.roleMembership.Data) + clear(globalDatabase.defaultPrivileges.Data) dbInitDefault() } @@ -96,6 +98,14 @@ func RoleExists(name string) bool { return ok } +// GetRoleName returns the name of the role with the given ID. Returns an empty string if the role does not exist. +func GetRoleName(id RoleID) string { + if role, ok := globalDatabase.rolesByID[id]; ok { + return role.Name + } + return "" +} + // SetRole sets the role matching the given name. This will add a role that does not yet exist, and overwrite an // existing role. func SetRole(role Role) { @@ -143,6 +153,7 @@ func dbInit(dEnv *env.DoltEnv, cfg Config) { sequencePrivileges: NewSequencePrivileges(), routinePrivileges: NewRoutinePrivileges(), roleMembership: NewRoleMembership(), + defaultPrivileges: NewDefaultPrivileges(), } globalLock = &sync.RWMutex{} if dEnv != nil { diff --git a/server/auth/default_privileges.go b/server/auth/default_privileges.go new file mode 100644 index 0000000000..de83a40b09 --- /dev/null +++ b/server/auth/default_privileges.go @@ -0,0 +1,291 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + + "github.com/dolthub/doltgresql/utils" +) + +// DefaultPrivileges stores the default privileges automatically applied when objects are created. +type DefaultPrivileges struct { + Data map[DefaultPrivilegeKey]DefaultPrivilegeValue +} + +// DefaultPrivilegeKey identifies the context for a set of default privileges: +// the owner role, the optional schema scope, and the object type. +type DefaultPrivilegeKey struct { + OwnerRole RoleID + Schema string // empty = applicable to any schema + ObjectType PrivilegeObject // TABLE, SEQUENCE, FUNCTION, SCHEMA, TYPE +} + +// DefaultPrivilegeValue stores the grantee ACL entries for a given DefaultPrivilegeKey. +type DefaultPrivilegeValue struct { + Key DefaultPrivilegeKey + Grantees map[RoleID]DefaultPrivilegeGranteeValue +} + +// DefaultPrivilegeGranteeValue stores the privileges granted to a specific role within a default ACL. +type DefaultPrivilegeGranteeValue struct { + Grantee RoleID + Privileges map[Privilege]map[GrantedPrivilege]bool +} + +// NewDefaultPrivileges returns a new *DefaultPrivileges. +func NewDefaultPrivileges() *DefaultPrivileges { + return &DefaultPrivileges{make(map[DefaultPrivilegeKey]DefaultPrivilegeValue)} +} + +// AddDefaultPrivilege adds a default privilege entry to the global database. +func AddDefaultPrivilege(key DefaultPrivilegeKey, grantee RoleID, privilege GrantedPrivilege, withGrantOption bool) { + dpv, ok := globalDatabase.defaultPrivileges.Data[key] + if !ok { + dpv = DefaultPrivilegeValue{ + Key: key, + Grantees: make(map[RoleID]DefaultPrivilegeGranteeValue), + } + } + granteeValue, ok := dpv.Grantees[grantee] + if !ok { + granteeValue = DefaultPrivilegeGranteeValue{ + Grantee: grantee, + Privileges: make(map[Privilege]map[GrantedPrivilege]bool), + } + } + privilegeMap, ok := granteeValue.Privileges[privilege.Privilege] + if !ok { + privilegeMap = make(map[GrantedPrivilege]bool) + granteeValue.Privileges[privilege.Privilege] = privilegeMap + } + privilegeMap[privilege] = withGrantOption + dpv.Grantees[grantee] = granteeValue + globalDatabase.defaultPrivileges.Data[key] = dpv +} + +// RemoveDefaultPrivilege removes a default privilege entry from the global database. +// If grantOptionOnly is true, only the WITH GRANT OPTION flag is revoked. +func RemoveDefaultPrivilege(key DefaultPrivilegeKey, grantee RoleID, privilege GrantedPrivilege, grantOptionOnly bool) { + dpv, ok := globalDatabase.defaultPrivileges.Data[key] + if !ok { + return + } + granteeValue, ok := dpv.Grantees[grantee] + if !ok { + return + } + privilegeMap, ok := granteeValue.Privileges[privilege.Privilege] + if !ok { + return + } + if grantOptionOnly { + if privilege.GrantedBy.IsValid() { + if _, ok = privilegeMap[privilege]; ok { + privilegeMap[privilege] = false + } + } else { + for k := range privilegeMap { + privilegeMap[k] = false + } + } + } else { + if privilege.GrantedBy.IsValid() { + delete(privilegeMap, privilege) + } else { + clear(privilegeMap) + } + if len(privilegeMap) == 0 { + delete(granteeValue.Privileges, privilege.Privilege) + } + } + if len(granteeValue.Privileges) == 0 { + delete(dpv.Grantees, grantee) + } else { + dpv.Grantees[grantee] = granteeValue + } + if len(dpv.Grantees) == 0 { + delete(globalDatabase.defaultPrivileges.Data, key) + } else { + globalDatabase.defaultPrivileges.Data[key] = dpv + } +} + +// GetAllDefaultPrivileges returns all default privilege entries. +func GetAllDefaultPrivileges() []DefaultPrivilegeValue { + result := make([]DefaultPrivilegeValue, 0, len(globalDatabase.defaultPrivileges.Data)) + for _, v := range globalDatabase.defaultPrivileges.Data { + result = append(result, v) + } + return result +} + +// ApplyDefaultPrivilegesForNewTable applies any matching default privileges to a newly created table. +// Must be called under LockWrite. +func ApplyDefaultPrivilegesForNewTable(ownerRoleID RoleID, schemaName, tableName string) { + for key, dpv := range globalDatabase.defaultPrivileges.Data { + if key.OwnerRole != ownerRoleID || key.ObjectType != PrivilegeObject_TABLE { + continue + } + if key.Schema != "" && key.Schema != schemaName { + continue + } + for granteeID, granteeValue := range dpv.Grantees { + for _, privilegeMap := range granteeValue.Privileges { + for grantedPriv, withGrantOption := range privilegeMap { + AddTablePrivilege(TablePrivilegeKey{ + Role: granteeID, + Table: doltdb.TableName{Name: tableName, Schema: schemaName}, + }, grantedPriv, withGrantOption) + } + } + } + } +} + +// ApplyDefaultPrivilegesForNewSequence applies any matching default privileges to a newly created sequence. +// Must be called under LockWrite. +func ApplyDefaultPrivilegesForNewSequence(ownerRoleID RoleID, schemaName, seqName string) { + for key, dpv := range globalDatabase.defaultPrivileges.Data { + if key.OwnerRole != ownerRoleID || key.ObjectType != PrivilegeObject_SEQUENCE { + continue + } + if key.Schema != "" && key.Schema != schemaName { + continue + } + for granteeID, granteeValue := range dpv.Grantees { + for _, privilegeMap := range granteeValue.Privileges { + for grantedPriv, withGrantOption := range privilegeMap { + AddSequencePrivilege(SequencePrivilegeKey{ + Role: granteeID, + Schema: schemaName, + Name: seqName, + }, grantedPriv, withGrantOption) + } + } + } + } +} + +// ApplyDefaultPrivilegesForNewRoutine applies any matching default privileges to a newly created function or procedure. +// Must be called under LockWrite. +func ApplyDefaultPrivilegesForNewRoutine(ownerRoleID RoleID, schemaName, routineName string) { + for key, dpv := range globalDatabase.defaultPrivileges.Data { + if key.OwnerRole != ownerRoleID || key.ObjectType != PrivilegeObject_FUNCTION { + continue + } + if key.Schema != "" && key.Schema != schemaName { + continue + } + for granteeID, granteeValue := range dpv.Grantees { + for _, privilegeMap := range granteeValue.Privileges { + for grantedPriv, withGrantOption := range privilegeMap { + AddRoutinePrivilege(RoutinePrivilegeKey{ + Role: granteeID, + Schema: schemaName, + Name: routineName, + }, grantedPriv, withGrantOption) + } + } + } + } +} + +// DefaultPrivilegeObjTypeChar returns the PostgreSQL pg_default_acl defaclobjtype character for a PrivilegeObject. +func DefaultPrivilegeObjTypeChar(objType PrivilegeObject) string { + switch objType { + case PrivilegeObject_TABLE: + return "r" + case PrivilegeObject_SEQUENCE: + return "S" + case PrivilegeObject_FUNCTION: + return "f" + case PrivilegeObject_TYPE: + return "T" + case PrivilegeObject_SCHEMA: + return "n" + default: + return "?" + } +} + +// serialize writes the DefaultPrivileges to the given writer. +func (dp *DefaultPrivileges) serialize(writer *utils.Writer) { + // Version 2 + // Write the total number of values + writer.Uint64(uint64(len(dp.Data))) + for _, value := range dp.Data { + writer.Uint64(uint64(value.Key.OwnerRole)) + writer.String(value.Key.Schema) + writer.Uint8(uint8(value.Key.ObjectType)) + writer.Uint64(uint64(len(value.Grantees))) + for _, granteeValue := range value.Grantees { + writer.Uint64(uint64(granteeValue.Grantee)) + writer.Uint64(uint64(len(granteeValue.Privileges))) + for priv, privilegeMap := range granteeValue.Privileges { + writer.String(string(priv)) + writer.Uint32(uint32(len(privilegeMap))) + for grantedPrivilege, withGrantOption := range privilegeMap { + writer.Uint64(uint64(grantedPrivilege.GrantedBy)) + writer.Bool(withGrantOption) + } + } + } + } +} + +// deserialize reads the DefaultPrivileges from the given reader. +func (dp *DefaultPrivileges) deserialize(version uint32, reader *utils.Reader) { + dp.Data = make(map[DefaultPrivilegeKey]DefaultPrivilegeValue) + switch version { + case 0: + case 1: + case 2: + dataCount := reader.Uint64() + for i := uint64(0); i < dataCount; i++ { + dpv := DefaultPrivilegeValue{ + Grantees: make(map[RoleID]DefaultPrivilegeGranteeValue), + } + dpv.Key.OwnerRole = RoleID(reader.Uint64()) + dpv.Key.Schema = reader.String() + dpv.Key.ObjectType = PrivilegeObject(reader.Uint8()) + granteeCount := reader.Uint64() + for j := uint64(0); j < granteeCount; j++ { + granteeValue := DefaultPrivilegeGranteeValue{ + Grantee: RoleID(reader.Uint64()), + Privileges: make(map[Privilege]map[GrantedPrivilege]bool), + } + privCount := reader.Uint64() + for k := uint64(0); k < privCount; k++ { + priv := Privilege(reader.String()) + grantedCount := reader.Uint32() + grantedMap := make(map[GrantedPrivilege]bool) + for l := uint32(0); l < grantedCount; l++ { + gp := GrantedPrivilege{ + Privilege: priv, + GrantedBy: RoleID(reader.Uint64()), + } + grantedMap[gp] = reader.Bool() + } + granteeValue.Privileges[priv] = grantedMap + } + dpv.Grantees[granteeValue.Grantee] = granteeValue + } + dp.Data[dpv.Key] = dpv + } + default: + panic("unexpected version in SequencePrivileges") + } +} diff --git a/server/auth/serialization.go b/server/auth/serialization.go index b6700ddd0d..b8c0e034ae 100644 --- a/server/auth/serialization.go +++ b/server/auth/serialization.go @@ -33,7 +33,7 @@ func PersistChanges() error { func (db *Database) serialize() []byte { writer := utils.NewWriter(16384) // Write the version - writer.Uint32(1) + writer.Uint32(2) // Write the roles writer.Uint32(uint32(len(db.rolesByID))) for _, role := range db.rolesByID { @@ -51,6 +51,8 @@ func (db *Database) serialize() []byte { db.routinePrivileges.serialize(writer) // Write the role chain db.roleMembership.serialize(writer) + // Write the default privileges + db.defaultPrivileges.serialize(writer) return writer.Data() } @@ -66,6 +68,8 @@ func (db *Database) deserialize(data []byte) error { return db.deserializeV0(reader) case 1: return db.deserializeV1(reader) + case 2: + return db.deserializeV2(reader) default: return errors.Errorf("Authorization database format %d is not supported, please upgrade Doltgres", version) } @@ -95,6 +99,8 @@ func (db *Database) deserializeV0(reader *utils.Reader) error { db.routinePrivileges.deserialize(0, reader) // Read the role membership db.roleMembership.deserialize(0, reader) + // V0 has no default privileges; initialize empty + db.defaultPrivileges = NewDefaultPrivileges() return nil } @@ -122,5 +128,36 @@ func (db *Database) deserializeV1(reader *utils.Reader) error { db.routinePrivileges.deserialize(1, reader) // Read the role membership db.roleMembership.deserialize(1, reader) + // V1 has no default privileges; initialize empty + db.defaultPrivileges.deserialize(1, reader) + return nil +} + +// deserializeV2 creates a Database from a byte slice. Expects a reader that has already read the version. +func (db *Database) deserializeV2(reader *utils.Reader) error { + // Read the roles + clear(db.rolesByName) + clear(db.rolesByID) + roleCount := reader.Uint32() + for i := uint32(0); i < roleCount; i++ { + r := Role{} + r.deserialize(1, reader) + db.rolesByName[r.Name] = r.id + db.rolesByID[r.id] = r + } + // Read the database privileges + db.databasePrivileges.deserialize(2, reader) + // Read the schema privileges + db.schemaPrivileges.deserialize(2, reader) + // Read the table privileges + db.tablePrivileges.deserialize(2, reader) + // Read the sequence privileges + db.sequencePrivileges.deserialize(2, reader) + // Read the routine privileges + db.routinePrivileges.deserialize(2, reader) + // Read the role membership + db.roleMembership.deserialize(2, reader) + // Read the default privileges + db.defaultPrivileges.deserialize(2, reader) return nil } diff --git a/server/node/alter_default_privileges.go b/server/node/alter_default_privileges.go new file mode 100644 index 0000000000..30d1ebd8ae --- /dev/null +++ b/server/node/alter_default_privileges.go @@ -0,0 +1,168 @@ +// Copyright 2026 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package node + +import ( + "context" + + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/doltgresql/server/auth" +) + +// AlterDefaultPrivileges handles the ALTER DEFAULT PRIVILEGES statement. +type AlterDefaultPrivileges struct { + OwnerRoles []string + Schemas []string + ObjectType auth.PrivilegeObject + Privileges []auth.Privilege + Grantees []string + Grant bool // false = REVOKE + GrantOption bool + Cascade bool +} + +var _ sql.ExecSourceRel = (*AlterDefaultPrivileges)(nil) +var _ vitess.Injectable = (*AlterDefaultPrivileges)(nil) + +// Children implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) Children() []sql.Node { + return nil +} + +// IsReadOnly implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) IsReadOnly() bool { + return false +} + +// Resolved implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) Resolved() bool { + return true +} + +// RowIter implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { + if n.Cascade { + return nil, errors.New("ALTER DEFAULT PRIVILEGES does not yet support CASCADE") + } + var err error + auth.LockWrite(func() { + err = n.execute(ctx) + if err != nil { + return + } + err = auth.PersistChanges() + }) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +// Schema implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) Schema(ctx *sql.Context) sql.Schema { + return nil +} + +// String implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) String() string { + return "ALTER DEFAULT PRIVILEGES" +} + +// WithChildren implements the interface sql.ExecSourceRel. +func (n *AlterDefaultPrivileges) WithChildren(ctx *sql.Context, children ...sql.Node) (sql.Node, error) { + return plan.NillaryWithChildren(n, children...) +} + +// WithResolvedChildren implements the interface vitess.Injectable. +func (n *AlterDefaultPrivileges) WithResolvedChildren(ctx context.Context, children []any) (any, error) { + if len(children) != 0 { + return nil, ErrVitessChildCount.New(0, len(children)) + } + return n, nil +} + +// execute performs the actual default privilege changes. +func (n *AlterDefaultPrivileges) execute(ctx *sql.Context) error { + ownerRoles, err := n.resolveOwnerRoles(ctx) + if err != nil { + return err + } + + granteeRoles := make([]auth.Role, len(n.Grantees)) + for i, name := range n.Grantees { + role := auth.GetRole(name) + if !role.IsValid() { + return errors.Errorf(`role "%s" does not exist`, name) + } + granteeRoles[i] = role + } + + schemas := n.Schemas + // empty list means all schemas + if len(schemas) == 0 { + // TODO: get all schemas + schemas = []string{""} + } + + for _, ownerRole := range ownerRoles { + for _, schema := range schemas { + key := auth.DefaultPrivilegeKey{ + OwnerRole: ownerRole.ID(), + Schema: schema, + ObjectType: n.ObjectType, + } + for _, granteeRole := range granteeRoles { + for _, priv := range n.Privileges { + grantedPrivilege := auth.GrantedPrivilege{ + Privilege: priv, + GrantedBy: ownerRole.ID(), + } + if n.Grant { + auth.AddDefaultPrivilege(key, granteeRole.ID(), grantedPrivilege, n.GrantOption) + } else { + auth.RemoveDefaultPrivilege(key, granteeRole.ID(), grantedPrivilege, n.GrantOption) + } + } + } + } + } + return nil +} + +// resolveOwnerRoles returns the roles that own the default privileges being modified. +// When no roles are explicitly specified, the current session user is used. +func (n *AlterDefaultPrivileges) resolveOwnerRoles(ctx *sql.Context) ([]auth.Role, error) { + // empty means current user + if len(n.OwnerRoles) == 0 { + userRole := auth.GetRole(ctx.Client().User) + if !userRole.IsValid() { + return nil, errors.Errorf(`role "%s" does not exist`, ctx.Client().User) + } + return []auth.Role{userRole}, nil + } + roles := make([]auth.Role, len(n.OwnerRoles)) + for i, name := range n.OwnerRoles { + role := auth.GetRole(name) + if !role.IsValid() { + return nil, errors.Errorf(`role "%s" does not exist`, name) + } + roles[i] = role + } + return roles, nil +} diff --git a/server/node/create_function.go b/server/node/create_function.go index b639c53969..7b37dac2f3 100644 --- a/server/node/create_function.go +++ b/server/node/create_function.go @@ -28,6 +28,7 @@ import ( "github.com/dolthub/doltgresql/core/functions" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/procedures" + "github.com/dolthub/doltgresql/server/auth" "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -167,6 +168,17 @@ func (c *CreateFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, erro if err != nil { return nil, err } + var authErr error + auth.LockWrite(func() { + ownerRole := auth.GetRole(ctx.Client().User) + if ownerRole.IsValid() { + auth.ApplyDefaultPrivilegesForNewRoutine(ownerRole.ID(), schemaName, c.FunctionName) + } + authErr = auth.PersistChanges() + }) + if authErr != nil { + return nil, authErr + } return sql.RowsToRowIter(), nil } diff --git a/server/node/create_procedure.go b/server/node/create_procedure.go index d1633ef82a..fbae74e647 100644 --- a/server/node/create_procedure.go +++ b/server/node/create_procedure.go @@ -26,6 +26,7 @@ import ( "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/procedures" + "github.com/dolthub/doltgresql/server/auth" "github.com/dolthub/doltgresql/server/plpgsql" ) @@ -144,6 +145,17 @@ func (c *CreateProcedure) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, err if err != nil { return nil, err } + var authErr error + auth.LockWrite(func() { + ownerRole := auth.GetRole(ctx.Client().User) + if ownerRole.IsValid() { + auth.ApplyDefaultPrivilegesForNewRoutine(ownerRole.ID(), schemaName, c.ProcedureName) + } + authErr = auth.PersistChanges() + }) + if authErr != nil { + return nil, authErr + } return sql.RowsToRowIter(), nil } diff --git a/server/node/create_sequence.go b/server/node/create_sequence.go index 7258d827c3..16beb344f8 100644 --- a/server/node/create_sequence.go +++ b/server/node/create_sequence.go @@ -29,6 +29,7 @@ import ( "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/sequences" + "github.com/dolthub/doltgresql/server/auth" pgexprs "github.com/dolthub/doltgresql/server/expression" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -208,6 +209,17 @@ func (c *CreateSequence) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, erro return nil, err } } + var authErr error + auth.LockWrite(func() { + ownerRole := auth.GetRole(ctx.Client().User) + if ownerRole.IsValid() { + auth.ApplyDefaultPrivilegesForNewSequence(ownerRole.ID(), c.sequence.Id.SchemaName(), c.sequence.Id.SequenceName()) + } + authErr = auth.PersistChanges() + }) + if authErr != nil { + return nil, authErr + } return sql.RowsToRowIter(), nil } diff --git a/server/node/create_table.go b/server/node/create_table.go index 3be0f7ff40..4743e10815 100644 --- a/server/node/create_table.go +++ b/server/node/create_table.go @@ -16,12 +16,14 @@ package node import ( "fmt" + "io" "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/server/auth" ) // CreateTable is a node that implements functionality specifically relevant to Doltgres' table creation needs. @@ -93,7 +95,46 @@ func (c *CreateTable) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder, r sq return nil, err } } - return createTableIter, err + + ownerRole := auth.GetRole(ctx.Client().User) + if ownerRole.IsValid() { + return &createTableDefaultPrivsIter{ + inner: createTableIter, + ownerID: ownerRole.ID(), + schemaName: schemaName, + tableName: c.gmsCreateTable.Name(), + }, nil + } + return createTableIter, nil +} + +// createTableDefaultPrivsIter wraps the create table iter to apply default privileges after creation. +type createTableDefaultPrivsIter struct { + inner sql.RowIter + ownerID auth.RoleID + schemaName string + tableName string + applied bool +} + +func (i *createTableDefaultPrivsIter) Next(ctx *sql.Context) (sql.Row, error) { + row, err := i.inner.Next(ctx) + if err == io.EOF && !i.applied { + i.applied = true + var applyErr error + auth.LockWrite(func() { + auth.ApplyDefaultPrivilegesForNewTable(i.ownerID, i.schemaName, i.tableName) + applyErr = auth.PersistChanges() + }) + if applyErr != nil { + return nil, applyErr + } + } + return row, err +} + +func (i *createTableDefaultPrivsIter) Close(ctx *sql.Context) error { + return i.inner.Close(ctx) } // Schema implements the interface sql.ExecBuilderNode. diff --git a/server/tables/pgcatalog/pg_default_acl.go b/server/tables/pgcatalog/pg_default_acl.go index ed7fd78088..64e85312da 100644 --- a/server/tables/pgcatalog/pg_default_acl.go +++ b/server/tables/pgcatalog/pg_default_acl.go @@ -15,10 +15,12 @@ package pgcatalog import ( + "fmt" "io" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/server/auth" "github.com/dolthub/doltgresql/server/tables" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -43,8 +45,38 @@ func (p PgDefaultAclHandler) Name() string { // RowIter implements the interface tables.Handler. func (p PgDefaultAclHandler) RowIter(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { - // TODO: Implement pg_default_acl row iter - return emptyRowIter() + var rows []sql.Row + auth.LockRead(func() { + entries := auth.GetAllDefaultPrivileges() + for oid, entry := range entries { + // Build the aclitem array: each element is "grantee=privs/grantor" + var aclItems []interface{} + for _, granteeValue := range entry.Grantees { + granteeName := auth.GetRoleName(granteeValue.Grantee) + for priv, grantedMap := range granteeValue.Privileges { + for grantedPriv, withGrantOption := range grantedMap { + grantorName := auth.GetRoleName(grantedPriv.GrantedBy) + privStr := string(priv) + if withGrantOption { + privStr += "*" + } + aclItems = append(aclItems, fmt.Sprintf("%s=%s/%s", granteeName, privStr, grantorName)) + } + } + } + // Namespace OID: 0 means any schema; non-zero would need a real OID lookup + namespaceOid := uint32(0) + _ = namespaceOid + rows = append(rows, sql.Row{ + uint32(oid + 1), // oid (synthetic, 1-based index) + uint32(entry.Key.OwnerRole), + uint32(0), // defaclnamespace: 0 = any schema (schema OID lookup not yet implemented) + auth.DefaultPrivilegeObjTypeChar(entry.Key.ObjectType), + aclItems, + }) + } + }) + return sql.RowsToRowIter(rows...), nil } // Schema implements the interface tables.Handler. diff --git a/testing/go/auth_test.go b/testing/go/auth_test.go index d03a873759..90d6b2470b 100644 --- a/testing/go/auth_test.go +++ b/testing/go/auth_test.go @@ -38,6 +38,373 @@ var ( authTestCreateBasicUser = fmt.Sprintf("create user if not exists '%s' with password '%s'", authTestBasicUser, authTestBasicPass) ) +func TestAlterDefaultPrivileges(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "ALTER DEFAULT PRIVILEGES basic grant", + SetUpScript: []string{ + "CREATE ROLE testrole;", + "CREATE ROLE grantee1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE testrole GRANT SELECT ON TABLES TO grantee1;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{1}}, + }, + { + Query: "SELECT defaclobjtype FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{"r"}}, + }, + }, + }, + { + Name: "ALTER DEFAULT PRIVILEGES FOR ROLE multiple privileges", + SetUpScript: []string{ + "CREATE ROLE owner1;", + "CREATE ROLE recipient1;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE owner1 GRANT SELECT, INSERT ON TABLES TO recipient1;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{1}}, + }, + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE owner1 GRANT USAGE ON SEQUENCES TO recipient1;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{2}}, + }, + }, + }, + { + Name: "ALTER DEFAULT PRIVILEGES FOR USER (synonym for FOR ROLE)", + SetUpScript: []string{ + "CREATE ROLE userrole;", + "CREATE ROLE grantee2;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES FOR USER userrole GRANT EXECUTE ON FUNCTIONS TO grantee2;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT defaclobjtype FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{"f"}}, + }, + }, + }, + { + Name: "ALTER DEFAULT PRIVILEGES without FOR ROLE uses current user", + SetUpScript: []string{ + "CREATE ROLE grantee3;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES GRANT SELECT ON TABLES TO grantee3;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{1}}, + }, + }, + }, + { + Name: "ALTER DEFAULT PRIVILEGES REVOKE removes the entry", + SetUpScript: []string{ + "CREATE ROLE ownerrole;", + "CREATE ROLE granteex;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE ownerrole GRANT SELECT ON TABLES TO granteex;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{1}}, + }, + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE ownerrole REVOKE SELECT ON TABLES FROM granteex;", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM pg_catalog.pg_default_acl;", + Expected: []sql.Row{{0}}, + }, + }, + }, + { + Name: "ALTER DEFAULT PRIVILEGES nonexistent role and grantee returns error", + SetUpScript: []string{ + "CREATE ROLE ownerrole2;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE no_such_role GRANT SELECT ON TABLES TO postgres;", + ExpectedErr: `role "no_such_role" does not exist`, + }, + { + Query: "ALTER DEFAULT PRIVILEGES FOR ROLE ownerrole2 GRANT SELECT ON TABLES TO no_such_grantee;", + ExpectedErr: `role "no_such_grantee" does not exist`, + }, + }, + }, + { + Name: `ALTER DEFAULT PRIVILEGES`, + SetUpScript: []string{ + authTestCreateSuperUser, + `CREATE USER readonly_user PASSWORD 'a';`, + `GRANT USAGE ON SCHEMA public TO readonly_user;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE TABLE test (pk INT4 PRIMARY KEY);`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM test;`, + Username: `readonly_user`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `ALTER DEFAULT PRIVILEGES FOR USER auth_test_super IN SCHEMA public GRANT SELECT ON TABLES TO readonly_user;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM test;`, + Username: `readonly_user`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `CREATE TABLE another_table (pk INT4 PRIMARY KEY);`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM another_table;`, + Username: `readonly_user`, + Password: `a`, + Expected: []sql.Row{}, + }, + + { + Query: `create table user_table (i int);`, + Username: `readonly_user`, + Password: `a`, + ExpectedErr: `denied`, + }, + }, + }, + { + Name: `ALTER DEFAULT PRIVILEGES applies to new sequences`, + SetUpScript: []string{ + authTestCreateSuperUser, + `CREATE USER seq_reader PASSWORD 'a';`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE SEQUENCE old_seq START WITH 1;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT nextval('old_seq');`, + Username: `seq_reader`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `ALTER DEFAULT PRIVILEGES FOR USER auth_test_super IN SCHEMA public GRANT USAGE ON SEQUENCES TO seq_reader;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT nextval('old_seq');`, + Username: `seq_reader`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `CREATE SEQUENCE new_seq START WITH 10;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT nextval('new_seq');`, + Username: `seq_reader`, + Password: `a`, + Expected: []sql.Row{{10}}, + }, + }, + }, + { + Name: `ALTER DEFAULT PRIVILEGES applies to new functions`, + SetUpScript: []string{ + authTestCreateSuperUser, + `CREATE USER func_reader PASSWORD 'a';`, + `GRANT USAGE ON SCHEMA public TO func_reader;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE FUNCTION old_func() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT old_func();`, + Username: `func_reader`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `ALTER DEFAULT PRIVILEGES FOR USER auth_test_super IN SCHEMA public GRANT EXECUTE ON FUNCTIONS TO func_reader;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT old_func();`, + Username: `func_reader`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `CREATE FUNCTION new_func() RETURNS int AS $$ BEGIN RETURN 42; END; $$ LANGUAGE plpgsql;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT new_func();`, + Username: `func_reader`, + Password: `a`, + Expected: []sql.Row{{42}}, + }, + }, + }, + { + Name: `ALTER DEFAULT PRIVILEGES FOR ROLE`, + SetUpScript: []string{ + authTestCreateSuperUser, + `create user another_super with superuser password 'another';`, + `CREATE USER user1 PASSWORD 'a';`, + `CREATE TABLE test (pk INT4 PRIMARY KEY);`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT * FROM test;`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + // It only applies to tables created after this command is executed. + Query: `ALTER DEFAULT PRIVILEGES FOR ROLE auth_test_super IN SCHEMA public GRANT SELECT ON TABLES TO user1;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM test;`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `CREATE TABLE new_table (pk INT4 PRIMARY KEY);`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM new_table;`, + Username: `user1`, + Password: `a`, + Expected: []sql.Row{}, + }, + { + Query: `CREATE TABLE by_another (pk INT4 PRIMARY KEY);`, + Username: `another_super`, + Password: `another`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM by_another;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + // cannot select from tables created by `another` user + Query: `SELECT * FROM by_another;`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `INSERT INTO test VALUES (1), (5), (6);`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + // It only applies to tables created after this command is executed. + Query: `ALTER DEFAULT PRIVILEGES FOR ROLE auth_test_super IN SCHEMA public GRANT INSERT ON TABLES TO user1;`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO test VALUES (1), (5), (6);`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, + { + Query: `CREATE TABLE different_test (pk INT4 PRIMARY KEY);`, + Username: authTestSuperUser, + Password: authTestSuperPass, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO different_test VALUES (1), (5), (6);`, + Username: `user1`, + Password: `a`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM different_test;`, + Username: `user1`, + Password: `a`, + Expected: []sql.Row{{1}, {5}, {6}}, + }, + }, + }, + }) +} + func TestAuthTests(t *testing.T) { RunScripts(t, []ScriptTest{ { diff --git a/testing/go/import_dumps_test.go b/testing/go/import_dumps_test.go index 3d32028b00..47505b38d4 100644 --- a/testing/go/import_dumps_test.go +++ b/testing/go/import_dumps_test.go @@ -101,6 +101,7 @@ func TestImportingDumps(t *testing.T) { }, { SetUpScript: []string{ + "CREATE USER crisisresolver WITH SUPERUSER PASSWORD 'password';", "CREATE USER crisisresolver_visitor WITH SUPERUSER PASSWORD 'password';", }, Name: "blacktscoder/CrisisSolver", @@ -352,6 +353,7 @@ func TestImportingDumps(t *testing.T) { { SetUpScript: []string{ `CREATE USER neondb_owner WITH SUPERUSER PASSWORD 'password';`, + `CREATE USER cloud_admin WITH SUPERUSER PASSWORD 'password';`, }, Name: "mvnp/start-dashboard-v3-backend", SQLFilename: "mvnp_start-dashboard-v3-backend.sql",