diff --git a/core/membership/service.go b/core/membership/service.go index 5076c6997..2eeb3586a 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1113,11 +1113,9 @@ type Member struct { Roles []role.Role } -// ListPrincipalsByResource returns the principals (users, service users, groups) -// that have at least one policy on the given resource, optionally filtered by -// principal type and/or role, and optionally enriched with the full list of -// roles each principal holds on the resource. -func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter MemberFilter) ([]Member, error) { +// resourcePolicyFilter builds the policy filter that scopes a listing to the +// given resource. Returns ErrInvalidResourceType for unsupported namespaces. +func resourcePolicyFilter(resourceID, resourceType string, filter MemberFilter) (policy.Filter, error) { flt := policy.Filter{ PrincipalType: filter.PrincipalType, RoleIDs: filter.RoleIDs, @@ -1131,7 +1129,19 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso case schema.GroupNamespace: flt.GroupID = resourceID default: - return nil, ErrInvalidResourceType + return policy.Filter{}, ErrInvalidResourceType + } + return flt, nil +} + +// ListPrincipalsByResource returns the principals (users, service users, groups) +// that have at least one policy on the given resource, optionally filtered by +// principal type and/or role, and optionally enriched with the full list of +// roles each principal holds on the resource. +func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter MemberFilter) ([]Member, error) { + flt, err := resourcePolicyFilter(resourceID, resourceType, filter) + if err != nil { + return nil, err } policies, err := s.policyService.List(ctx, flt) @@ -1214,6 +1224,36 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso return members, nil } +// ListPrincipalIDsByResource returns the IDs of principals of the given type +// that have at least one policy on the resource. It is a primitive-typed, +// ID-only variant of ListPrincipalsByResource: it skips role enrichment +// entirely (a single policy query) and exists for consumer packages that +// cannot import membership types without creating an import cycle +// (e.g. core/serviceuser, which this package itself imports). +func (s *Service) ListPrincipalIDsByResource(ctx context.Context, resourceID, resourceType, principalType string) ([]string, error) { + flt, err := resourcePolicyFilter(resourceID, resourceType, MemberFilter{PrincipalType: principalType}) + if err != nil { + return nil, err + } + + policies, err := s.policyService.List(ctx, flt) + if err != nil { + return nil, fmt.Errorf("list policies: %w", err) + } + policies = excludePATAllProjects(policies, resourceType) + + ids := make([]string, 0, len(policies)) + seen := make(map[string]struct{}, len(policies)) + for _, pol := range policies { + if _, ok := seen[pol.PrincipalID]; ok { + continue + } + seen[pol.PrincipalID] = struct{}{} + ids = append(ids, pol.PrincipalID) + } + return ids, nil +} + // SetGroupMemberRole upserts the role assignment for a principal in a group: // if the principal has no existing group policy, they are added with the // requested role; otherwise their existing role is replaced with the diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 7ee7576bf..087a31bd3 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -1659,6 +1659,72 @@ func TestService_ListPrincipalsByResource(t *testing.T) { } } +func TestService_ListPrincipalIDsByResource(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + su1, su2 := uuid.New().String(), uuid.New().String() + roleID := uuid.New().String() + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RoleService) + want []string + wantErrMsg string + }{ + { + name: "returns deduplicated principal IDs from a single policy query", + setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) { + suPolicies := []policy.Policy{ + {PrincipalID: su1, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID}, + {PrincipalID: su1, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID}, + {PrincipalID: su2, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID}, + } + ps.EXPECT().List(ctx, policy.Filter{ + OrgID: orgID, + PrincipalType: schema.ServiceUserPrincipal, + ResourceType: schema.OrganizationNamespace, + }).Return(suPolicies, nil).Once() + // no role-service calls: the ID-only path skips role enrichment + }, + want: []string{su1, su2}, + }, + { + name: "returns empty when no principals", + setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) { + ps.EXPECT().List(ctx, mock.Anything).Return([]policy.Policy{}, nil) + }, + want: []string{}, + }, + { + name: "propagates errors", + setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) { + ps.EXPECT().List(ctx, mock.Anything).Return(nil, errors.New("db down")) + }, + wantErrMsg: "db down", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRoleSvc) + } + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + got, err := svc.ListPrincipalIDsByResource(ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal) + if tt.wantErrMsg != "" { + assert.ErrorContains(t, err, tt.wantErrMsg) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestService_SetGroupMemberRole(t *testing.T) { ctx := context.Background() orgID := uuid.New().String() diff --git a/core/serviceuser/mocks/membership_service.go b/core/serviceuser/mocks/membership_service.go index a5ad7c519..b2f48a533 100644 --- a/core/serviceuser/mocks/membership_service.go +++ b/core/serviceuser/mocks/membership_service.go @@ -71,6 +71,67 @@ func (_c *MembershipService_AddOrganizationMember_Call) RunAndReturn(run func(co return _c } +// ListPrincipalIDsByResource provides a mock function with given fields: ctx, resourceID, resourceType, principalType +func (_m *MembershipService) ListPrincipalIDsByResource(ctx context.Context, resourceID string, resourceType string, principalType string) ([]string, error) { + ret := _m.Called(ctx, resourceID, resourceType, principalType) + + if len(ret) == 0 { + panic("no return value specified for ListPrincipalIDsByResource") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) ([]string, error)); ok { + return rf(ctx, resourceID, resourceType, principalType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) []string); ok { + r0 = rf(ctx, resourceID, resourceType, principalType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, resourceID, resourceType, principalType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListPrincipalIDsByResource_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPrincipalIDsByResource' +type MembershipService_ListPrincipalIDsByResource_Call struct { + *mock.Call +} + +// ListPrincipalIDsByResource is a helper method to define mock.On call +// - ctx context.Context +// - resourceID string +// - resourceType string +// - principalType string +func (_e *MembershipService_Expecter) ListPrincipalIDsByResource(ctx interface{}, resourceID interface{}, resourceType interface{}, principalType interface{}) *MembershipService_ListPrincipalIDsByResource_Call { + return &MembershipService_ListPrincipalIDsByResource_Call{Call: _e.mock.On("ListPrincipalIDsByResource", ctx, resourceID, resourceType, principalType)} +} + +func (_c *MembershipService_ListPrincipalIDsByResource_Call) Run(run func(ctx context.Context, resourceID string, resourceType string, principalType string)) *MembershipService_ListPrincipalIDsByResource_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MembershipService_ListPrincipalIDsByResource_Call) Return(_a0 []string, _a1 error) *MembershipService_ListPrincipalIDsByResource_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListPrincipalIDsByResource_Call) RunAndReturn(run func(context.Context, string, string, string) ([]string, error)) *MembershipService_ListPrincipalIDsByResource_Call { + _c.Call.Return(run) + return _c +} + // RemoveOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType func (_m *MembershipService) RemoveOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string) error { ret := _m.Called(ctx, orgID, principalID, principalType) diff --git a/core/serviceuser/mocks/relation_service.go b/core/serviceuser/mocks/relation_service.go index 027c253ac..38646b244 100644 --- a/core/serviceuser/mocks/relation_service.go +++ b/core/serviceuser/mocks/relation_service.go @@ -242,65 +242,6 @@ func (_c *RelationService_Delete_Call) RunAndReturn(run func(context.Context, re return _c } -// LookupSubjects provides a mock function with given fields: ctx, rel -func (_m *RelationService) LookupSubjects(ctx context.Context, rel relation.Relation) ([]string, error) { - ret := _m.Called(ctx, rel) - - if len(ret) == 0 { - panic("no return value specified for LookupSubjects") - } - - var r0 []string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) ([]string, error)); ok { - return rf(ctx, rel) - } - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) []string); ok { - r0 = rf(ctx, rel) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { - r1 = rf(ctx, rel) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RelationService_LookupSubjects_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LookupSubjects' -type RelationService_LookupSubjects_Call struct { - *mock.Call -} - -// LookupSubjects is a helper method to define mock.On call -// - ctx context.Context -// - rel relation.Relation -func (_e *RelationService_Expecter) LookupSubjects(ctx interface{}, rel interface{}) *RelationService_LookupSubjects_Call { - return &RelationService_LookupSubjects_Call{Call: _e.mock.On("LookupSubjects", ctx, rel)} -} - -func (_c *RelationService_LookupSubjects_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_LookupSubjects_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(relation.Relation)) - }) - return _c -} - -func (_c *RelationService_LookupSubjects_Call) Return(_a0 []string, _a1 error) *RelationService_LookupSubjects_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *RelationService_LookupSubjects_Call) RunAndReturn(run func(context.Context, relation.Relation) ([]string, error)) *RelationService_LookupSubjects_Call { - _c.Call.Return(run) - return _c -} - // NewRelationService creates a new instance of RelationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewRelationService(t interface { diff --git a/core/serviceuser/service.go b/core/serviceuser/service.go index bebeb7e08..f9a8f8df6 100644 --- a/core/serviceuser/service.go +++ b/core/serviceuser/service.go @@ -41,7 +41,6 @@ type CredentialRepository interface { type RelationService interface { Create(ctx context.Context, rel relation.Relation) (relation.Relation, error) Delete(ctx context.Context, rel relation.Relation) error - LookupSubjects(ctx context.Context, rel relation.Relation) ([]string, error) CheckPermission(ctx context.Context, rel relation.Relation) (bool, error) BatchCheckPermission(ctx context.Context, rel []relation.Relation) ([]relation.CheckPair, error) } @@ -49,6 +48,7 @@ type RelationService interface { type MembershipService interface { AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error RemoveOrganizationMember(ctx context.Context, orgID, principalID, principalType string) error + ListPrincipalIDsByResource(ctx context.Context, resourceID, resourceType, principalType string) ([]string, error) } type Service struct { @@ -123,16 +123,7 @@ func (s Service) GetByIDs(ctx context.Context, ids []string) ([]ServiceUser, err } func (s Service) ListByOrg(ctx context.Context, orgID string) ([]ServiceUser, error) { - userIDs, err := s.relationService.LookupSubjects(ctx, relation.Relation{ - Object: relation.Object{ - ID: orgID, - Namespace: schema.OrganizationNamespace, - }, - Subject: relation.Subject{ - Namespace: schema.ServiceUserPrincipal, - }, - RelationName: schema.MembershipPermission, - }) + userIDs, err := s.membershipService.ListPrincipalIDsByResource(ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal) if err != nil { return nil, err } diff --git a/core/serviceuser/service_test.go b/core/serviceuser/service_test.go index d65f7f91a..8c295b0fd 100644 --- a/core/serviceuser/service_test.go +++ b/core/serviceuser/service_test.go @@ -163,3 +163,67 @@ func TestService_Get(t *testing.T) { }) } } + +func TestService_ListByOrg(t *testing.T) { + ctx := context.Background() + const orgID = "org-id" + + tests := []struct { + name string + setup func(*mocks.Repository, *mocks.MembershipService) + want int + wantErr bool + }{ + { + name: "members found are fetched from the repo", + setup: func(repo *mocks.Repository, mem *mocks.MembershipService) { + mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal). + Return([]string{"su-1", "su-2"}, nil) + repo.On("GetByIDs", ctx, []string{"su-1", "su-2"}). + Return([]serviceuser.ServiceUser{{ID: "su-1"}, {ID: "su-2"}}, nil) + }, + want: 2, + }, + { + name: "no members returns empty list without hitting the repo", + setup: func(repo *mocks.Repository, mem *mocks.MembershipService) { + mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal). + Return([]string{}, nil) + // repo.GetByIDs must NOT be called + }, + want: 0, + }, + { + name: "membership error is propagated", + setup: func(repo *mocks.Repository, mem *mocks.MembershipService) { + mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal). + Return(nil, errors.New("policy store unavailable")) + }, + wantErr: true, + }, + { + name: "repo error after successful ID lookup is propagated", + setup: func(repo *mocks.Repository, mem *mocks.MembershipService) { + mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal). + Return([]string{"su-1"}, nil) + repo.On("GetByIDs", ctx, []string{"su-1"}). + Return(nil, errors.New("db unavailable")) + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, repo, _, _, mem := newTestService(t) + tt.setup(repo, mem) + + got, err := svc.ListByOrg(ctx, orgID) + if (err != nil) != tt.wantErr { + t.Fatalf("ListByOrg() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && len(got) != tt.want { + t.Errorf("ListByOrg() returned %d service users, want %d", len(got), tt.want) + } + }) + } +}