Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions core/membership/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1113,11 +1113,9 @@ type Member struct {
Roles []role.Role
}

// ListPrincipalsByResource returns the principals (users, service users, groups)
// that have at least one policy on the given resource, optionally filtered by
// principal type and/or role, and optionally enriched with the full list of
// roles each principal holds on the resource.
func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter MemberFilter) ([]Member, error) {
// resourcePolicyFilter builds the policy filter that scopes a listing to the
// given resource. Returns ErrInvalidResourceType for unsupported namespaces.
func resourcePolicyFilter(resourceID, resourceType string, filter MemberFilter) (policy.Filter, error) {
flt := policy.Filter{
PrincipalType: filter.PrincipalType,
RoleIDs: filter.RoleIDs,
Expand All @@ -1131,7 +1129,19 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso
case schema.GroupNamespace:
flt.GroupID = resourceID
default:
return nil, ErrInvalidResourceType
return policy.Filter{}, ErrInvalidResourceType
}
return flt, nil
}

// ListPrincipalsByResource returns the principals (users, service users, groups)
// that have at least one policy on the given resource, optionally filtered by
// principal type and/or role, and optionally enriched with the full list of
// roles each principal holds on the resource.
func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter MemberFilter) ([]Member, error) {
flt, err := resourcePolicyFilter(resourceID, resourceType, filter)
if err != nil {
return nil, err
}

policies, err := s.policyService.List(ctx, flt)
Expand Down Expand Up @@ -1214,6 +1224,36 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso
return members, nil
}

// ListPrincipalIDsByResource returns the IDs of principals of the given type
// that have at least one policy on the resource. It is a primitive-typed,
// ID-only variant of ListPrincipalsByResource: it skips role enrichment
// entirely (a single policy query) and exists for consumer packages that
// cannot import membership types without creating an import cycle
// (e.g. core/serviceuser, which this package itself imports).
func (s *Service) ListPrincipalIDsByResource(ctx context.Context, resourceID, resourceType, principalType string) ([]string, error) {

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now adding this wrapper method, as other wise we are seeing cyclic dependecy when trying to import MemberFilter type.

flt, err := resourcePolicyFilter(resourceID, resourceType, MemberFilter{PrincipalType: principalType})
if err != nil {
return nil, err
}

policies, err := s.policyService.List(ctx, flt)
if err != nil {
return nil, fmt.Errorf("list policies: %w", err)
}
policies = excludePATAllProjects(policies, resourceType)

ids := make([]string, 0, len(policies))
seen := make(map[string]struct{}, len(policies))
for _, pol := range policies {
if _, ok := seen[pol.PrincipalID]; ok {
continue
}
seen[pol.PrincipalID] = struct{}{}
ids = append(ids, pol.PrincipalID)
}
return ids, nil
}

// SetGroupMemberRole upserts the role assignment for a principal in a group:
// if the principal has no existing group policy, they are added with the
// requested role; otherwise their existing role is replaced with the
Expand Down
66 changes: 66 additions & 0 deletions core/membership/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,72 @@ func TestService_ListPrincipalsByResource(t *testing.T) {
}
}

func TestService_ListPrincipalIDsByResource(t *testing.T) {
ctx := context.Background()
orgID := uuid.New().String()
su1, su2 := uuid.New().String(), uuid.New().String()
roleID := uuid.New().String()

tests := []struct {
name string
setup func(*mocks.PolicyService, *mocks.RoleService)
want []string
wantErrMsg string
}{
{
name: "returns deduplicated principal IDs from a single policy query",
setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) {
suPolicies := []policy.Policy{
{PrincipalID: su1, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID},
{PrincipalID: su1, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID},
{PrincipalID: su2, PrincipalType: schema.ServiceUserPrincipal, RoleID: roleID},
}
ps.EXPECT().List(ctx, policy.Filter{
OrgID: orgID,
PrincipalType: schema.ServiceUserPrincipal,
ResourceType: schema.OrganizationNamespace,
}).Return(suPolicies, nil).Once()
// no role-service calls: the ID-only path skips role enrichment
},
want: []string{su1, su2},
},
{
name: "returns empty when no principals",
setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) {
ps.EXPECT().List(ctx, mock.Anything).Return([]policy.Policy{}, nil)
},
want: []string{},
},
{
name: "propagates errors",
setup: func(ps *mocks.PolicyService, rs *mocks.RoleService) {
ps.EXPECT().List(ctx, mock.Anything).Return(nil, errors.New("db down"))
},
wantErrMsg: "db down",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockPolicySvc := mocks.NewPolicyService(t)
mockRoleSvc := mocks.NewRoleService(t)
if tt.setup != nil {
tt.setup(mockPolicySvc, mockRoleSvc)
}

svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t))

got, err := svc.ListPrincipalIDsByResource(ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal)
if tt.wantErrMsg != "" {
assert.ErrorContains(t, err, tt.wantErrMsg)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func TestService_SetGroupMemberRole(t *testing.T) {
ctx := context.Background()
orgID := uuid.New().String()
Expand Down
61 changes: 61 additions & 0 deletions core/serviceuser/mocks/membership_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 0 additions & 59 deletions core/serviceuser/mocks/relation_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 2 additions & 11 deletions core/serviceuser/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ type CredentialRepository interface {
type RelationService interface {
Create(ctx context.Context, rel relation.Relation) (relation.Relation, error)
Delete(ctx context.Context, rel relation.Relation) error
LookupSubjects(ctx context.Context, rel relation.Relation) ([]string, error)
CheckPermission(ctx context.Context, rel relation.Relation) (bool, error)
BatchCheckPermission(ctx context.Context, rel []relation.Relation) ([]relation.CheckPair, error)
}

type MembershipService interface {
AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error
RemoveOrganizationMember(ctx context.Context, orgID, principalID, principalType string) error
ListPrincipalIDsByResource(ctx context.Context, resourceID, resourceType, principalType string) ([]string, error)
}

type Service struct {
Expand Down Expand Up @@ -123,16 +123,7 @@ func (s Service) GetByIDs(ctx context.Context, ids []string) ([]ServiceUser, err
}

func (s Service) ListByOrg(ctx context.Context, orgID string) ([]ServiceUser, error) {
userIDs, err := s.relationService.LookupSubjects(ctx, relation.Relation{
Object: relation.Object{
ID: orgID,
Namespace: schema.OrganizationNamespace,
},
Subject: relation.Subject{
Namespace: schema.ServiceUserPrincipal,
},
RelationName: schema.MembershipPermission,
})
userIDs, err := s.membershipService.ListPrincipalIDsByResource(ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal)
if err != nil {
return nil, err
}
Expand Down
64 changes: 64 additions & 0 deletions core/serviceuser/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,67 @@ func TestService_Get(t *testing.T) {
})
}
}

func TestService_ListByOrg(t *testing.T) {
ctx := context.Background()
const orgID = "org-id"

tests := []struct {
name string
setup func(*mocks.Repository, *mocks.MembershipService)
want int
wantErr bool
}{
{
name: "members found are fetched from the repo",
setup: func(repo *mocks.Repository, mem *mocks.MembershipService) {
mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal).
Return([]string{"su-1", "su-2"}, nil)
repo.On("GetByIDs", ctx, []string{"su-1", "su-2"}).
Return([]serviceuser.ServiceUser{{ID: "su-1"}, {ID: "su-2"}}, nil)
},
want: 2,
},
{
name: "no members returns empty list without hitting the repo",
setup: func(repo *mocks.Repository, mem *mocks.MembershipService) {
mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal).
Return([]string{}, nil)
// repo.GetByIDs must NOT be called
},
want: 0,
},
{
name: "membership error is propagated",
setup: func(repo *mocks.Repository, mem *mocks.MembershipService) {
mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal).
Return(nil, errors.New("policy store unavailable"))
},
wantErr: true,
},
{
name: "repo error after successful ID lookup is propagated",
setup: func(repo *mocks.Repository, mem *mocks.MembershipService) {
mem.On("ListPrincipalIDsByResource", ctx, orgID, schema.OrganizationNamespace, schema.ServiceUserPrincipal).
Return([]string{"su-1"}, nil)
repo.On("GetByIDs", ctx, []string{"su-1"}).
Return(nil, errors.New("db unavailable"))
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc, repo, _, _, mem := newTestService(t)
tt.setup(repo, mem)

got, err := svc.ListByOrg(ctx, orgID)
if (err != nil) != tt.wantErr {
t.Fatalf("ListByOrg() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && len(got) != tt.want {
t.Errorf("ListByOrg() returned %d service users, want %d", len(got), tt.want)
}
})
}
}
Loading