Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions internal/migration_acceptance_tests/policy_cases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,108 @@ var policyAcceptanceTestCases = []acceptanceTestCase{
},
expectedHazardTypes: nil,
},
{
// Test case for policy-to-function dependency: adding both a function and a policy
// that uses that function. The function must be created before the policy.
// NOTE: We use a simple function that doesn't reference tables to avoid
// function→table dependency ordering issues (which is a separate feature).
name: "Add policy referencing new function (policy-to-function dependency)",
oldSchemaDDL: []string{
`
CREATE TABLE items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID
);
`,
},
newSchemaDDL: []string{
`
CREATE TABLE items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID
);
-- Simple function that doesn't reference any table
CREATE FUNCTION is_valid_owner(owner_id UUID) RETURNS BOOLEAN AS $$
SELECT owner_id IS NOT NULL;
$$ LANGUAGE SQL IMMUTABLE;
ALTER TABLE items ENABLE ROW LEVEL SECURITY;
CREATE POLICY owner_policy ON items
FOR ALL
USING (is_valid_owner(owner_id));
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
// Test case for policy-to-function dependency in non-public schema
name: "Add policy referencing new function in non-public schema",
roles: []string{
"authenticated",
},
oldSchemaDDL: []string{
`
CREATE SCHEMA app;
CREATE TABLE app.items (
id UUID NOT NULL,
priority INT NOT NULL
);
`,
},
newSchemaDDL: []string{
`
CREATE SCHEMA app;
CREATE TABLE app.items (
id UUID NOT NULL,
priority INT NOT NULL
);
-- Simple function that doesn't reference any table
CREATE FUNCTION app.is_high_priority(p INT) RETURNS BOOLEAN AS $$
SELECT p > 5;
$$ LANGUAGE SQL IMMUTABLE;
ALTER TABLE app.items ENABLE ROW LEVEL SECURITY;
CREATE POLICY priority_policy ON app.items
FOR SELECT
TO authenticated
USING (app.is_high_priority(priority));
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
// Test case for dropping both a policy and function it references.
// The policy must be dropped before the function.
name: "Drop policy and function it references",
oldSchemaDDL: []string{
`
CREATE TABLE items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID
);
CREATE FUNCTION is_valid_owner(owner_id UUID) RETURNS BOOLEAN AS $$
SELECT owner_id IS NOT NULL;
$$ LANGUAGE SQL IMMUTABLE;
ALTER TABLE items ENABLE ROW LEVEL SECURITY;
CREATE POLICY owner_policy ON items
FOR ALL
USING (is_valid_owner(owner_id));
`,
},
newSchemaDDL: []string{
`
CREATE TABLE items (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_id UUID
);
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
}

func TestPolicyCases(t *testing.T) {
Expand Down
23 changes: 22 additions & 1 deletion internal/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,28 @@ SELECT
AND d.refclassid = 'pg_class'::REGCLASS
AND dep_c.oid != table_c.oid
AND dep_ns.nspname NOT IN ('pg_catalog', 'information_schema')
)::TEXT[] AS table_dependencies
)::TEXT[] AS table_dependencies,
-- Function dependencies: functions referenced in USING/CHECK expressions.
-- This is needed for correct statement ordering when a policy calls
-- user-defined functions in its expressions.
(SELECT
ARRAY_AGG(DISTINCT JSONB_BUILD_OBJECT(
'schema', proc_ns.nspname,
'name', proc.proname,
'identity_arguments', pg_catalog.pg_get_function_identity_arguments(proc.oid)
))
FROM pg_catalog.pg_depend AS d
INNER JOIN pg_catalog.pg_proc AS proc
ON d.refobjid = proc.oid
INNER JOIN pg_catalog.pg_namespace AS proc_ns
ON proc.pronamespace = proc_ns.oid
WHERE
d.objid = pol.oid
AND d.classid = 'pg_policy'::REGCLASS
AND d.refclassid = 'pg_proc'::REGCLASS
AND d.deptype = 'n'
AND proc_ns.nspname NOT IN ('pg_catalog', 'information_schema')
)::TEXT[] AS function_dependencies
FROM pg_catalog.pg_policy AS pol
INNER JOIN pg_catalog.pg_class AS table_c ON pol.polrelid = table_c.oid
INNER JOIN
Expand Down
25 changes: 24 additions & 1 deletion internal/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 36 additions & 8 deletions internal/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func normalizeTable(t Table) Table {
p.Columns = sortByKey(p.Columns, func(s string) string {
return s
})
p.FunctionDependencies = sortSchemaObjectsByName(p.FunctionDependencies)
normPolicies = append(normPolicies, p)
}
t.Policies = normPolicies
Expand Down Expand Up @@ -500,6 +501,10 @@ type Policy struct {
// references in its USING/CHECK expressions. This is used for correct
// statement ordering when a policy references columns from other tables.
TableDependencies []TableDependency
// FunctionDependencies are functions that the policy references in its
// USING/CHECK expressions. This is used for correct statement ordering
// when a policy calls user-defined functions in its expressions.
FunctionDependencies []SchemaQualifiedName
}

func (p Policy) GetName() string {
Expand Down Expand Up @@ -1397,16 +1402,21 @@ func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, er
if err != nil {
return nil, fmt.Errorf("parsing table dependencies for policy %s: %w", rp.PolicyName, err)
}
functionDependencies, err := parseJSONFunctionDependencies(rp.FunctionDependencies)
if err != nil {
return nil, fmt.Errorf("parsing function dependencies for policy %s: %w", rp.PolicyName, err)
}
policies = append(policies, policyAndTable{
policy: Policy{
EscapedName: EscapeIdentifier(rp.PolicyName),
IsPermissive: rp.IsPermissive,
AppliesTo: rp.AppliesTo,
Cmd: PolicyCmd(rp.Cmd),
CheckExpression: rp.CheckExpression,
UsingExpression: rp.UsingExpression,
Columns: rp.ColumnNames,
TableDependencies: tableDependencies,
EscapedName: EscapeIdentifier(rp.PolicyName),
IsPermissive: rp.IsPermissive,
AppliesTo: rp.AppliesTo,
Cmd: PolicyCmd(rp.Cmd),
CheckExpression: rp.CheckExpression,
UsingExpression: rp.UsingExpression,
Columns: rp.ColumnNames,
TableDependencies: tableDependencies,
FunctionDependencies: functionDependencies,
},
table: buildNameFromUnescaped(rp.OwningTableName, rp.OwningTableSchemaName),
})
Expand Down Expand Up @@ -1592,6 +1602,24 @@ func parseJSONTableDependencies(vals []string) ([]TableDependency, error) {
return out, nil
}

// parseJSONFunctionDependencies takes a slice of JSON values with schema,
// `schema: string; name: string; identity_arguments: string` and unmarshals them into SchemaQualifiedName.
func parseJSONFunctionDependencies(vals []string) ([]SchemaQualifiedName, error) {
var out []SchemaQualifiedName
for _, v := range vals {
var s struct {
Schema string `json:"schema"`
Name string `json:"name"`
IdentityArguments string `json:"identity_arguments"`
}
if err := json.Unmarshal([]byte(v), &s); err != nil {
return nil, fmt.Errorf("json.Unmarshal(%q, function dependency): %w", string(v), err)
}
out = append(out, buildProcName(s.Name, s.IdentityArguments, s.Schema))
}
return out, nil
}

// buildProcName is used to build the schema qualified name for a proc (function, procedure), i.e., anything
// identified by a name AND its arguments.
func buildProcName(name, identityArguments, schemaName string) SchemaQualifiedName {
Expand Down
60 changes: 59 additions & 1 deletion internal/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ var (
GRANT SELECT ON schema_2.foo TO some_role_1;
GRANT INSERT ON schema_2.foo TO some_role_2 WITH GRANT OPTION;
`},
expectedHash: "43388964f7bede0",
expectedHash: "69342de97351df15",
expectedSchema: Schema{
NamedSchemas: []NamedSchema{
{Name: "public"},
Expand Down Expand Up @@ -1384,6 +1384,64 @@ var (
},
},
},
{
name: "Policy-to-function dependency",
ddl: []string{`
CREATE TABLE user_roles (
user_id UUID NOT NULL,
role TEXT NOT NULL
);

CREATE FUNCTION has_role(uid UUID, required_role TEXT) RETURNS BOOLEAN AS $$
SELECT EXISTS (
SELECT 1 FROM user_roles WHERE user_id = uid AND role = required_role
);
$$ LANGUAGE SQL STABLE SECURITY DEFINER;

ALTER TABLE user_roles ENABLE ROW LEVEL SECURITY;

CREATE POLICY admin_view_policy ON user_roles
FOR SELECT
TO PUBLIC
USING (has_role(user_id, 'admin'));
`},
expectedSchema: Schema{
NamedSchemas: []NamedSchema{
{Name: "public"},
},
Tables: []Table{
{
SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"user_roles\""},
Columns: []Column{
{Name: "user_id", Type: "uuid", Size: 16},
{Name: "role", Type: "text", Size: -1, Collation: defaultCollation},
},
Policies: []Policy{
{
EscapedName: "\"admin_view_policy\"",
IsPermissive: true,
AppliesTo: []string{"PUBLIC"},
Cmd: SelectPolicyCmd,
UsingExpression: "has_role(user_id, 'admin'::text)",
Columns: []string{"user_id"},
FunctionDependencies: []SchemaQualifiedName{
{SchemaName: "public", EscapedName: "\"has_role\"(uid uuid, required_role text)"},
},
},
},
ReplicaIdentity: ReplicaIdentityDefault,
RLSEnabled: true,
},
},
Functions: []Function{
{
SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"has_role\"(uid uuid, required_role text)"},
FunctionDef: "CREATE OR REPLACE FUNCTION public.has_role(uid uuid, required_role text)\n RETURNS boolean\n LANGUAGE sql\n STABLE SECURITY DEFINER\nAS $function$\n\t\t\t\tSELECT EXISTS (\n\t\t\t\t\tSELECT 1 FROM user_roles WHERE user_id = uid AND role = required_role\n\t\t\t\t);\n\t\t\t$function$\n",
Language: "sql",
},
},
},
},
}
)

Expand Down
19 changes: 17 additions & 2 deletions pkg/diff/policy_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,11 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error)
oldCopy.CheckExpression = diff.new.CheckExpression
}
oldCopy.Columns = diff.new.Columns
// TableDependencies is a derived field based on the USING/CHECK expressions, so when
// expressions change the dependencies naturally update. No special handling needed.
// TableDependencies and FunctionDependencies are derived fields based on the
// USING/CHECK expressions, so when expressions change the dependencies naturally
// update. No special handling needed.
oldCopy.TableDependencies = diff.new.TableDependencies
oldCopy.FunctionDependencies = diff.new.FunctionDependencies

if diff := cmp.Diff(oldCopy, diff.new); diff != "" {
return nil, fmt.Errorf("unsupported diff %s: %w", diff, ErrNotImplemented)
Expand Down Expand Up @@ -295,6 +297,13 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter)))
}

// Run after any dependent functions are added/altered.
// This handles policy-to-function dependencies where the policy's USING/CHECK
// expression calls user-defined functions.
for _, f := range newPolicy.FunctionDependencies {
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildFunctionVertexId(f, diffTypeAddAlter)))
}

if !cmp.Equal(oldPolicy, schema.Policy{}) {
// Run before the old columns are deleted (if they are deleted)
oldTargetColumns, err := getTargetColumns(oldPolicy.Columns, psg.oldSchemaColumnsByName)
Expand Down Expand Up @@ -332,5 +341,11 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter)))
}

// The policy needs to be deleted before any dependent functions are deleted.
// This handles policy-to-function dependencies.
for _, f := range pol.FunctionDependencies {
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildFunctionVertexId(f, diffTypeDelete)))
}

return deps, nil
}