From f2c49b480d4d51e9a3b40005c3c5176e90de0d91 Mon Sep 17 00:00:00 2001 From: vshulev Date: Thu, 22 Jan 2026 11:15:46 +0100 Subject: [PATCH] Add policy-to-function dependency tracking This adds support for tracking dependencies between RLS policies and user-defined functions that they reference in USING/CHECK expressions. Previously, pg-schema-diff could create policies before the functions they depend on, causing migration failures when a policy's USING or CHECK expression calls a user-defined function. Changes: - Query pg_depend to extract function dependencies from policies - Add FunctionDependencies field to the Policy struct - Update dependency graph to order function creation before policies - Update dependency graph to order policy deletion before function deletion This follows the same pattern used for cross-table policy dependencies (issue #266). Co-Authored-By: Claude Opus 4.5 --- .../policy_cases_test.go | 102 ++++++++++++++++++ internal/queries/queries.sql | 23 +++- internal/queries/queries.sql.go | 25 ++++- internal/schema/schema.go | 44 ++++++-- internal/schema/schema_test.go | 60 ++++++++++- pkg/diff/policy_sql_generator.go | 19 +++- 6 files changed, 260 insertions(+), 13 deletions(-) diff --git a/internal/migration_acceptance_tests/policy_cases_test.go b/internal/migration_acceptance_tests/policy_cases_test.go index f2fd2a3..39db062 100644 --- a/internal/migration_acceptance_tests/policy_cases_test.go +++ b/internal/migration_acceptance_tests/policy_cases_test.go @@ -770,6 +770,108 @@ var policyAcceptanceTestCases = []acceptanceTestCase{ }, expectedHazardTypes: nil, }, + { + // Test case for policy-to-function dependency: adding both a function and a policy + // that uses that function. The function must be created before the policy. + // NOTE: We use a simple function that doesn't reference tables to avoid + // function→table dependency ordering issues (which is a separate feature). + name: "Add policy referencing new function (policy-to-function dependency)", + oldSchemaDDL: []string{ + ` + CREATE TABLE items ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE items ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID + ); + -- Simple function that doesn't reference any table + CREATE FUNCTION is_valid_owner(owner_id UUID) RETURNS BOOLEAN AS $$ + SELECT owner_id IS NOT NULL; + $$ LANGUAGE SQL IMMUTABLE; + ALTER TABLE items ENABLE ROW LEVEL SECURITY; + CREATE POLICY owner_policy ON items + FOR ALL + USING (is_valid_owner(owner_id)); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + // Test case for policy-to-function dependency in non-public schema + name: "Add policy referencing new function in non-public schema", + roles: []string{ + "authenticated", + }, + oldSchemaDDL: []string{ + ` + CREATE SCHEMA app; + CREATE TABLE app.items ( + id UUID NOT NULL, + priority INT NOT NULL + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA app; + CREATE TABLE app.items ( + id UUID NOT NULL, + priority INT NOT NULL + ); + -- Simple function that doesn't reference any table + CREATE FUNCTION app.is_high_priority(p INT) RETURNS BOOLEAN AS $$ + SELECT p > 5; + $$ LANGUAGE SQL IMMUTABLE; + ALTER TABLE app.items ENABLE ROW LEVEL SECURITY; + CREATE POLICY priority_policy ON app.items + FOR SELECT + TO authenticated + USING (app.is_high_priority(priority)); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + // Test case for dropping both a policy and function it references. + // The policy must be dropped before the function. + name: "Drop policy and function it references", + oldSchemaDDL: []string{ + ` + CREATE TABLE items ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID + ); + CREATE FUNCTION is_valid_owner(owner_id UUID) RETURNS BOOLEAN AS $$ + SELECT owner_id IS NOT NULL; + $$ LANGUAGE SQL IMMUTABLE; + ALTER TABLE items ENABLE ROW LEVEL SECURITY; + CREATE POLICY owner_policy ON items + FOR ALL + USING (is_valid_owner(owner_id)); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE items ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, } func TestPolicyCases(t *testing.T) { diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index 63c4daf..d401c5e 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -509,7 +509,28 @@ SELECT AND d.refclassid = 'pg_class'::REGCLASS AND dep_c.oid != table_c.oid AND dep_ns.nspname NOT IN ('pg_catalog', 'information_schema') - )::TEXT[] AS table_dependencies + )::TEXT[] AS table_dependencies, + -- Function dependencies: functions referenced in USING/CHECK expressions. + -- This is needed for correct statement ordering when a policy calls + -- user-defined functions in its expressions. + (SELECT + ARRAY_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'schema', proc_ns.nspname, + 'name', proc.proname, + 'identity_arguments', pg_catalog.pg_get_function_identity_arguments(proc.oid) + )) + FROM pg_catalog.pg_depend AS d + INNER JOIN pg_catalog.pg_proc AS proc + ON d.refobjid = proc.oid + INNER JOIN pg_catalog.pg_namespace AS proc_ns + ON proc.pronamespace = proc_ns.oid + WHERE + d.objid = pol.oid + AND d.classid = 'pg_policy'::REGCLASS + AND d.refclassid = 'pg_proc'::REGCLASS + AND d.deptype = 'n' + AND proc_ns.nspname NOT IN ('pg_catalog', 'information_schema') + )::TEXT[] AS function_dependencies FROM pg_catalog.pg_policy AS pol INNER JOIN pg_catalog.pg_class AS table_c ON pol.polrelid = table_c.oid INNER JOIN diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index c56e638..08f4449 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -770,7 +770,28 @@ SELECT AND d.refclassid = 'pg_class'::REGCLASS AND dep_c.oid != table_c.oid AND dep_ns.nspname NOT IN ('pg_catalog', 'information_schema') - )::TEXT[] AS table_dependencies + )::TEXT[] AS table_dependencies, + -- Function dependencies: functions referenced in USING/CHECK expressions. + -- This is needed for correct statement ordering when a policy calls + -- user-defined functions in its expressions. + (SELECT + ARRAY_AGG(DISTINCT JSONB_BUILD_OBJECT( + 'schema', proc_ns.nspname, + 'name', proc.proname, + 'identity_arguments', pg_catalog.pg_get_function_identity_arguments(proc.oid) + )) + FROM pg_catalog.pg_depend AS d + INNER JOIN pg_catalog.pg_proc AS proc + ON d.refobjid = proc.oid + INNER JOIN pg_catalog.pg_namespace AS proc_ns + ON proc.pronamespace = proc_ns.oid + WHERE + d.objid = pol.oid + AND d.classid = 'pg_policy'::REGCLASS + AND d.refclassid = 'pg_proc'::REGCLASS + AND d.deptype = 'n' + AND proc_ns.nspname NOT IN ('pg_catalog', 'information_schema') + )::TEXT[] AS function_dependencies FROM pg_catalog.pg_policy AS pol INNER JOIN pg_catalog.pg_class AS table_c ON pol.polrelid = table_c.oid INNER JOIN @@ -793,6 +814,7 @@ type GetPoliciesRow struct { UsingExpression string ColumnNames []string TableDependencies []string + FunctionDependencies []string } func (q *Queries) GetPolicies(ctx context.Context) ([]GetPoliciesRow, error) { @@ -815,6 +837,7 @@ func (q *Queries) GetPolicies(ctx context.Context) ([]GetPoliciesRow, error) { &i.UsingExpression, pq.Array(&i.ColumnNames), pq.Array(&i.TableDependencies), + pq.Array(&i.FunctionDependencies), ); err != nil { return nil, err } diff --git a/internal/schema/schema.go b/internal/schema/schema.go index bc38431..6dffb7e 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -129,6 +129,7 @@ func normalizeTable(t Table) Table { p.Columns = sortByKey(p.Columns, func(s string) string { return s }) + p.FunctionDependencies = sortSchemaObjectsByName(p.FunctionDependencies) normPolicies = append(normPolicies, p) } t.Policies = normPolicies @@ -500,6 +501,10 @@ type Policy struct { // references in its USING/CHECK expressions. This is used for correct // statement ordering when a policy references columns from other tables. TableDependencies []TableDependency + // FunctionDependencies are functions that the policy references in its + // USING/CHECK expressions. This is used for correct statement ordering + // when a policy calls user-defined functions in its expressions. + FunctionDependencies []SchemaQualifiedName } func (p Policy) GetName() string { @@ -1397,16 +1402,21 @@ func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, er if err != nil { return nil, fmt.Errorf("parsing table dependencies for policy %s: %w", rp.PolicyName, err) } + functionDependencies, err := parseJSONFunctionDependencies(rp.FunctionDependencies) + if err != nil { + return nil, fmt.Errorf("parsing function dependencies for policy %s: %w", rp.PolicyName, err) + } policies = append(policies, policyAndTable{ policy: Policy{ - EscapedName: EscapeIdentifier(rp.PolicyName), - IsPermissive: rp.IsPermissive, - AppliesTo: rp.AppliesTo, - Cmd: PolicyCmd(rp.Cmd), - CheckExpression: rp.CheckExpression, - UsingExpression: rp.UsingExpression, - Columns: rp.ColumnNames, - TableDependencies: tableDependencies, + EscapedName: EscapeIdentifier(rp.PolicyName), + IsPermissive: rp.IsPermissive, + AppliesTo: rp.AppliesTo, + Cmd: PolicyCmd(rp.Cmd), + CheckExpression: rp.CheckExpression, + UsingExpression: rp.UsingExpression, + Columns: rp.ColumnNames, + TableDependencies: tableDependencies, + FunctionDependencies: functionDependencies, }, table: buildNameFromUnescaped(rp.OwningTableName, rp.OwningTableSchemaName), }) @@ -1592,6 +1602,24 @@ func parseJSONTableDependencies(vals []string) ([]TableDependency, error) { return out, nil } +// parseJSONFunctionDependencies takes a slice of JSON values with schema, +// `schema: string; name: string; identity_arguments: string` and unmarshals them into SchemaQualifiedName. +func parseJSONFunctionDependencies(vals []string) ([]SchemaQualifiedName, error) { + var out []SchemaQualifiedName + for _, v := range vals { + var s struct { + Schema string `json:"schema"` + Name string `json:"name"` + IdentityArguments string `json:"identity_arguments"` + } + if err := json.Unmarshal([]byte(v), &s); err != nil { + return nil, fmt.Errorf("json.Unmarshal(%q, function dependency): %w", string(v), err) + } + out = append(out, buildProcName(s.Name, s.IdentityArguments, s.Schema)) + } + return out, nil +} + // buildProcName is used to build the schema qualified name for a proc (function, procedure), i.e., anything // identified by a name AND its arguments. func buildProcName(name, identityArguments, schemaName string) SchemaQualifiedName { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 3bbedd8..334a026 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -239,7 +239,7 @@ var ( GRANT SELECT ON schema_2.foo TO some_role_1; GRANT INSERT ON schema_2.foo TO some_role_2 WITH GRANT OPTION; `}, - expectedHash: "43388964f7bede0", + expectedHash: "69342de97351df15", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -1384,6 +1384,64 @@ var ( }, }, }, + { + name: "Policy-to-function dependency", + ddl: []string{` + CREATE TABLE user_roles ( + user_id UUID NOT NULL, + role TEXT NOT NULL + ); + + CREATE FUNCTION has_role(uid UUID, required_role TEXT) RETURNS BOOLEAN AS $$ + SELECT EXISTS ( + SELECT 1 FROM user_roles WHERE user_id = uid AND role = required_role + ); + $$ LANGUAGE SQL STABLE SECURITY DEFINER; + + ALTER TABLE user_roles ENABLE ROW LEVEL SECURITY; + + CREATE POLICY admin_view_policy ON user_roles + FOR SELECT + TO PUBLIC + USING (has_role(user_id, 'admin')); + `}, + expectedSchema: Schema{ + NamedSchemas: []NamedSchema{ + {Name: "public"}, + }, + Tables: []Table{ + { + SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"user_roles\""}, + Columns: []Column{ + {Name: "user_id", Type: "uuid", Size: 16}, + {Name: "role", Type: "text", Size: -1, Collation: defaultCollation}, + }, + Policies: []Policy{ + { + EscapedName: "\"admin_view_policy\"", + IsPermissive: true, + AppliesTo: []string{"PUBLIC"}, + Cmd: SelectPolicyCmd, + UsingExpression: "has_role(user_id, 'admin'::text)", + Columns: []string{"user_id"}, + FunctionDependencies: []SchemaQualifiedName{ + {SchemaName: "public", EscapedName: "\"has_role\"(uid uuid, required_role text)"}, + }, + }, + }, + ReplicaIdentity: ReplicaIdentityDefault, + RLSEnabled: true, + }, + }, + Functions: []Function{ + { + SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"has_role\"(uid uuid, required_role text)"}, + FunctionDef: "CREATE OR REPLACE FUNCTION public.has_role(uid uuid, required_role text)\n RETURNS boolean\n LANGUAGE sql\n STABLE SECURITY DEFINER\nAS $function$\n\t\t\t\tSELECT EXISTS (\n\t\t\t\t\tSELECT 1 FROM user_roles WHERE user_id = uid AND role = required_role\n\t\t\t\t);\n\t\t\t$function$\n", + Language: "sql", + }, + }, + }, + }, } ) diff --git a/pkg/diff/policy_sql_generator.go b/pkg/diff/policy_sql_generator.go index 354d00e..d490fab 100644 --- a/pkg/diff/policy_sql_generator.go +++ b/pkg/diff/policy_sql_generator.go @@ -240,9 +240,11 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error) oldCopy.CheckExpression = diff.new.CheckExpression } oldCopy.Columns = diff.new.Columns - // TableDependencies is a derived field based on the USING/CHECK expressions, so when - // expressions change the dependencies naturally update. No special handling needed. + // TableDependencies and FunctionDependencies are derived fields based on the + // USING/CHECK expressions, so when expressions change the dependencies naturally + // update. No special handling needed. oldCopy.TableDependencies = diff.new.TableDependencies + oldCopy.FunctionDependencies = diff.new.FunctionDependencies if diff := cmp.Diff(oldCopy, diff.new); diff != "" { return nil, fmt.Errorf("unsupported diff %s: %w", diff, ErrNotImplemented) @@ -295,6 +297,13 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) } + // Run after any dependent functions are added/altered. + // This handles policy-to-function dependencies where the policy's USING/CHECK + // expression calls user-defined functions. + for _, f := range newPolicy.FunctionDependencies { + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildFunctionVertexId(f, diffTypeAddAlter))) + } + if !cmp.Equal(oldPolicy, schema.Policy{}) { // Run before the old columns are deleted (if they are deleted) oldTargetColumns, err := getTargetColumns(oldPolicy.Columns, psg.oldSchemaColumnsByName) @@ -332,5 +341,11 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([ deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildTableVertexId(t.SchemaQualifiedName, diffTypeAddAlter))) } + // The policy needs to be deleted before any dependent functions are deleted. + // This handles policy-to-function dependencies. + for _, f := range pol.FunctionDependencies { + deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildFunctionVertexId(f, diffTypeDelete))) + } + return deps, nil }