From c32edb3fa599156b65be882e9e09e33017eb086d Mon Sep 17 00:00:00 2001 From: Qasim Date: Fri, 27 Mar 2026 09:22:13 -0400 Subject: [PATCH] [TW-4729] refactor(dashboard,email): reduce code duplication in auth and pagination Extract shared helpers to eliminate duplicated logic across dashboard auth and email list/search commands. Dashboard auth: - Extract loadDashboardTokens() into session_store.go (was duplicated in AuthService and AppService) - Fix multi-org bug: only auto-store org when exactly one exists - SyncSessionOrg now surfaces errors; CLI wraps with warning - New persistActiveOrg() and syncSessionOrgWithWarning() shared helpers - Exported SyncSessionOrg() changed to best-effort (no error return) Email pagination: - Extract shared fetchMessages() with messagesClient interface - Remove ~60 lines of duplicated pagination boilerplate from list.go Release workflow: - Add snapshot validation build before tagging on workflow_dispatch - Restrict real release/DMG steps to tag push events only --- .github/workflows/release.yml | 34 ++-- internal/app/dashboard/app_service.go | 7 +- internal/app/dashboard/auth_service.go | 15 +- internal/app/dashboard/auth_service_test.go | 105 ++++++++++- internal/app/dashboard/session_store.go | 19 ++ internal/cli/dashboard/dashboard_test.go | 198 ++++++++++++++++++++ internal/cli/dashboard/exports.go | 11 +- internal/cli/dashboard/helpers.go | 25 ++- internal/cli/dashboard/login.go | 10 +- internal/cli/dashboard/sso.go | 9 +- internal/cli/email/list.go | 84 +-------- internal/cli/email/search.go | 74 +++++--- internal/cli/email/search_test.go | 127 +++++++++++++ 13 files changed, 562 insertions(+), 156 deletions(-) create mode 100644 internal/app/dashboard/session_store.go create mode 100644 internal/cli/dashboard/dashboard_test.go create mode 100644 internal/cli/email/search_test.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b71ac82..bce124c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -60,14 +60,6 @@ jobs: exit 1 fi - - name: Create tag - if: github.event_name == 'workflow_dispatch' && !inputs.dry_run - env: - TAG: ${{ steps.version.outputs.VERSION }} - run: | - git tag "$TAG" - git push origin "$TAG" - - name: Set up Go uses: actions/setup-go@v5 with: @@ -76,24 +68,44 @@ jobs: - name: Run tests run: go test ./... -short + - name: Validate release build + if: github.event_name == 'workflow_dispatch' + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: latest + args: release --clean --snapshot --skip=publish + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GORELEASER_CURRENT_TAG: ${{ steps.version.outputs.VERSION }} + + - name: Create tag + if: github.event_name == 'workflow_dispatch' && !inputs.dry_run + env: + TAG: ${{ steps.version.outputs.VERSION }} + run: | + git tag "$TAG" + git push origin "$TAG" + - name: Run GoReleaser + if: github.event_name == 'push' uses: goreleaser/goreleaser-action@v6 with: distribution: goreleaser version: latest - args: release --clean ${{ (github.event_name == 'workflow_dispatch' && inputs.dry_run) && '--snapshot' || '' }} + args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GORELEASER_CURRENT_TAG: ${{ steps.version.outputs.VERSION }} - name: Create DMG files - if: ${{ !(github.event_name == 'workflow_dispatch' && inputs.dry_run) }} + if: github.event_name == 'push' env: TAG: ${{ steps.version.outputs.VERSION }} run: ./scripts/create-dmg.sh "$TAG" - name: Upload DMG to release - if: ${{ !(github.event_name == 'workflow_dispatch' && inputs.dry_run) }} + if: github.event_name == 'push' uses: softprops/action-gh-release@v2 with: tag_name: ${{ steps.version.outputs.VERSION }} diff --git a/internal/app/dashboard/app_service.go b/internal/app/dashboard/app_service.go index 37e4ff5..59b198a 100644 --- a/internal/app/dashboard/app_service.go +++ b/internal/app/dashboard/app_service.go @@ -129,10 +129,5 @@ func deduplicateApps(apps []domain.GatewayApplication) []domain.GatewayApplicati // loadTokens retrieves the stored dashboard tokens. func (s *AppService) loadTokens() (userToken, orgToken string, err error) { - userToken, err = s.secrets.Get(ports.KeyDashboardUserToken) - if err != nil || userToken == "" { - return "", "", fmt.Errorf("%w", domain.ErrDashboardNotLoggedIn) - } - orgToken, _ = s.secrets.Get(ports.KeyDashboardOrgToken) - return userToken, orgToken, nil + return loadDashboardTokens(s.secrets) } diff --git a/internal/app/dashboard/auth_service.go b/internal/app/dashboard/auth_service.go index 84b5c1a..597f1ee 100644 --- a/internal/app/dashboard/auth_service.go +++ b/internal/app/dashboard/auth_service.go @@ -209,10 +209,12 @@ func (s *AuthService) SwitchOrg(ctx context.Context, orgPublicID string) (*domai func (s *AuthService) SyncSessionOrg(ctx context.Context) error { session, err := s.GetCurrentSession(ctx) if err != nil { - return nil // best effort — login already succeeded + return fmt.Errorf("failed to fetch current dashboard session: %w", err) } if session.CurrentOrg != "" { - _ = s.secrets.Set(ports.KeyDashboardOrgPublicID, session.CurrentOrg) + if err := s.secrets.Set(ports.KeyDashboardOrgPublicID, session.CurrentOrg); err != nil { + return fmt.Errorf("failed to store active organization: %w", err) + } } return nil } @@ -232,7 +234,7 @@ func (s *AuthService) storeTokens(resp *domain.DashboardAuthResponse) error { return err } } - if len(resp.Organizations) > 0 { + if len(resp.Organizations) == 1 { if err := s.secrets.Set(ports.KeyDashboardOrgPublicID, resp.Organizations[0].PublicID); err != nil { return err } @@ -258,10 +260,5 @@ func (s *AuthService) clearTokens() { // loadTokens retrieves the stored tokens. func (s *AuthService) loadTokens() (userToken, orgToken string, err error) { - userToken, err = s.secrets.Get(ports.KeyDashboardUserToken) - if err != nil || userToken == "" { - return "", "", fmt.Errorf("%w", domain.ErrDashboardNotLoggedIn) - } - orgToken, _ = s.secrets.Get(ports.KeyDashboardOrgToken) - return userToken, orgToken, nil + return loadDashboardTokens(s.secrets) } diff --git a/internal/app/dashboard/auth_service_test.go b/internal/app/dashboard/auth_service_test.go index 43b0274..076991e 100644 --- a/internal/app/dashboard/auth_service_test.go +++ b/internal/app/dashboard/auth_service_test.go @@ -45,6 +45,18 @@ func (m *memSecretStore) IsAvailable() bool { return true } func (m *memSecretStore) Name() string { return "mem" } +type failingSecretStore struct { + *memSecretStore + failSetKey string +} + +func (f *failingSecretStore) Set(key, value string) error { + if key == f.failSetKey { + return errors.New("set failed") + } + return f.memSecretStore.Set(key, value) +} + // seedTokens pre-populates userToken (and optionally orgToken) so that // loadTokens() succeeds without going through a full Login flow. func seedTokens(s ports.SecretStore, userToken, orgToken string) { @@ -335,8 +347,10 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { name string seedUser string seedOrg string + storeFactory func() ports.SecretStore mockFn func(ctx context.Context, userToken, orgToken string) (*domain.DashboardSessionResponse, error) - wantErr bool // SyncSessionOrg is best-effort; always returns nil + wantErr bool + wantErrIs error wantOrgPublicID string wantNoOrgPublicID bool }{ @@ -351,18 +365,18 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { wantOrgPublicID: "org-synced", }, { - name: "returns nil when GetCurrentSession fails (best-effort)", + name: "returns error when GetCurrentSession fails", seedUser: "ut-abc", mockFn: func(_ context.Context, _, _ string) (*domain.DashboardSessionResponse, error) { return nil, errors.New("session fetch failed") }, - // wantErr is false — SyncSessionOrg is best-effort. + wantErr: true, wantNoOrgPublicID: true, }, { - name: "returns nil when not logged in (best-effort)", - // No seedUser means loadTokens returns ErrDashboardNotLoggedIn. - // SyncSessionOrg should still return nil. + name: "returns error when not logged in", + wantErr: true, + wantErrIs: domain.ErrDashboardNotLoggedIn, wantNoOrgPublicID: true, }, { @@ -386,6 +400,23 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { }, wantOrgPublicID: "org-from-server", }, + { + name: "returns error when storing synced org fails", + seedUser: "ut-abc", + storeFactory: func() ports.SecretStore { + return &failingSecretStore{ + memSecretStore: newMemSecretStore(), + failSetKey: ports.KeyDashboardOrgPublicID, + } + }, + mockFn: func(_ context.Context, _, _ string) (*domain.DashboardSessionResponse, error) { + return &domain.DashboardSessionResponse{ + CurrentOrg: "org-synced", + }, nil + }, + wantErr: true, + wantNoOrgPublicID: true, + }, } for _, tt := range tests { @@ -393,7 +424,12 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - store := newMemSecretStore() + var store ports.SecretStore + if tt.storeFactory != nil { + store = tt.storeFactory() + } else { + store = newMemSecretStore() + } seedTokens(store, tt.seedUser, tt.seedOrg) mock := &dashboardadapter.MockAccountClient{ @@ -403,8 +439,14 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { svc := NewAuthService(mock, store) err := svc.SyncSessionOrg(context.Background()) - // SyncSessionOrg is always best-effort — it must never return an error. - require.NoError(t, err) + if tt.wantErr { + require.Error(t, err) + if tt.wantErrIs != nil { + assert.ErrorIs(t, err, tt.wantErrIs) + } + } else { + require.NoError(t, err) + } stored, _ := store.Get(ports.KeyDashboardOrgPublicID) if tt.wantOrgPublicID != "" { @@ -416,3 +458,48 @@ func TestAuthService_SyncSessionOrg(t *testing.T) { }) } } + +func TestAuthServiceStoreTokens(t *testing.T) { + t.Parallel() + + t.Run("stores org when there is exactly one organization", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + svc := NewAuthService(&dashboardadapter.MockAccountClient{}, store) + + err := svc.storeTokens(&domain.DashboardAuthResponse{ + UserToken: "user-token", + User: domain.DashboardUser{PublicID: "user-1"}, + Organizations: []domain.DashboardOrganization{ + {PublicID: "org-only"}, + }, + }) + + require.NoError(t, err) + + storedOrgID, _ := store.Get(ports.KeyDashboardOrgPublicID) + assert.Equal(t, "org-only", storedOrgID) + }) + + t.Run("does not guess active org when multiple organizations exist", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + svc := NewAuthService(&dashboardadapter.MockAccountClient{}, store) + + err := svc.storeTokens(&domain.DashboardAuthResponse{ + UserToken: "user-token", + User: domain.DashboardUser{PublicID: "user-1"}, + Organizations: []domain.DashboardOrganization{ + {PublicID: "org-1"}, + {PublicID: "org-2"}, + }, + }) + + require.NoError(t, err) + + storedOrgID, _ := store.Get(ports.KeyDashboardOrgPublicID) + assert.Empty(t, storedOrgID) + }) +} diff --git a/internal/app/dashboard/session_store.go b/internal/app/dashboard/session_store.go new file mode 100644 index 0000000..08f1cc9 --- /dev/null +++ b/internal/app/dashboard/session_store.go @@ -0,0 +1,19 @@ +package dashboard + +import ( + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +// loadDashboardTokens retrieves the stored dashboard access and session tokens. +// Returns ErrDashboardNotLoggedIn when no user token is present. +func loadDashboardTokens(secrets ports.SecretStore) (userToken, orgToken string, err error) { + userToken, err = secrets.Get(ports.KeyDashboardUserToken) + if err != nil || userToken == "" { + return "", "", fmt.Errorf("%w", domain.ErrDashboardNotLoggedIn) + } + orgToken, _ = secrets.Get(ports.KeyDashboardOrgToken) + return userToken, orgToken, nil +} diff --git a/internal/cli/dashboard/dashboard_test.go b/internal/cli/dashboard/dashboard_test.go new file mode 100644 index 0000000..8ac710d --- /dev/null +++ b/internal/cli/dashboard/dashboard_test.go @@ -0,0 +1,198 @@ +package dashboard + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/nylas/cli/internal/domain" +) + +func TestResolveAuthMethod(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + google bool + microsoft bool + github bool + email bool + action string + want string + wantErr string + }{ + { + name: "google flag wins", + google: true, + action: "log in", + want: methodGoogle, + }, + { + name: "microsoft flag wins", + microsoft: true, + action: "log in", + want: methodMicrosoft, + }, + { + name: "github flag wins", + github: true, + action: "log in", + want: methodGitHub, + }, + { + name: "email login is allowed", + email: true, + action: "log in", + want: methodEmailPassword, + }, + { + name: "email registration is rejected", + email: true, + action: "register", + wantErr: "temporarily disabled", + }, + { + name: "multiple flags are rejected", + google: true, + github: true, + action: "log in", + wantErr: "only one auth method flag allowed", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := resolveAuthMethod(tt.google, tt.microsoft, tt.github, tt.email, tt.action) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Empty(t, got) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetDashboardAccountBaseURL(t *testing.T) { + t.Parallel() + + origURL := os.Getenv("NYLAS_DASHBOARD_ACCOUNT_URL") + defer func() { + if origURL == "" { + _ = os.Unsetenv("NYLAS_DASHBOARD_ACCOUNT_URL") + return + } + _ = os.Setenv("NYLAS_DASHBOARD_ACCOUNT_URL", origURL) + }() + + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_ACCOUNT_URL", "https://dashboard.example.com")) + assert.Equal(t, "https://dashboard.example.com", getDashboardAccountBaseURL(nil)) +} + +func TestMapProvider(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider string + want string + wantErr string + }{ + {name: "google", provider: "google", want: domain.SSOLoginTypeGoogle}, + {name: "microsoft", provider: "microsoft", want: domain.SSOLoginTypeMicrosoft}, + {name: "github", provider: "github", want: domain.SSOLoginTypeGitHub}, + {name: "case insensitive", provider: "GitHub", want: domain.SSOLoginTypeGitHub}, + {name: "unsupported", provider: "okta", wantErr: "unsupported SSO provider"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := mapProvider(tt.provider) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Empty(t, got) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestPresentAbsent(t *testing.T) { + t.Parallel() + + assert.Equal(t, "present", presentAbsent(true)) + assert.Equal(t, "absent", presentAbsent(false)) +} + +func TestFormatOrgLabel(t *testing.T) { + t.Parallel() + + assert.Equal(t, "Acme (org-123)", formatOrgLabel("org-123", "Acme")) + assert.Equal(t, "org-123", formatOrgLabel("org-123", "")) +} + +func TestFormatSessionOrg(t *testing.T) { + t.Parallel() + + session := &domain.DashboardSessionResponse{ + CurrentOrg: "org-1", + Relations: []domain.DashboardSessionRelation{ + {OrgPublicID: "org-1", OrgName: "Acme"}, + {OrgPublicID: "org-2", OrgName: "Beta"}, + }, + } + + assert.Equal(t, "Acme (org-1)", formatSessionOrg(session, "org-1")) + assert.Equal(t, "org-missing", formatSessionOrg(session, "org-missing")) +} + +func TestToAppRows(t *testing.T) { + t.Parallel() + + apps := []domain.GatewayApplication{ + { + ApplicationID: "app-1", + Region: "us", + Environment: "", + Branding: &domain.GatewayApplicationBrand{Name: "Primary"}, + }, + { + ApplicationID: "app-2", + Region: "eu", + Environment: "sandbox", + }, + } + + rows := toAppRows(apps) + + require.Len(t, rows, 2) + assert.Equal(t, appRow{ + ApplicationID: "app-1", + Region: "us", + Environment: "production", + Name: "Primary", + }, rows[0]) + assert.Equal(t, appRow{ + ApplicationID: "app-2", + Region: "eu", + Environment: "sandbox", + Name: "", + }, rows[1]) +} diff --git a/internal/cli/dashboard/exports.go b/internal/cli/dashboard/exports.go index 87d0816..6ca84b1 100644 --- a/internal/cli/dashboard/exports.go +++ b/internal/cli/dashboard/exports.go @@ -37,14 +37,15 @@ func GetActiveOrgID() (string, error) { } // SyncSessionOrg syncs the active org from the server session (exported for setup wizard). -func SyncSessionOrg() error { +// Failures are logged as warnings rather than returned, since this is a +// best-effort step that should not block an otherwise successful login. +func SyncSessionOrg() { authSvc, _, err := createAuthService() if err != nil { - return err + common.PrintWarning("failed to create auth service for org sync: %v", err) + return } - ctx, cancel := common.CreateContext() - defer cancel() - return authSvc.SyncSessionOrg(ctx) + syncSessionOrgWithWarning(authSvc) } // ReadLine prompts for a line of text input (exported for setup wizard). diff --git a/internal/cli/dashboard/helpers.go b/internal/cli/dashboard/helpers.go index 7e941ad..d0ff528 100644 --- a/internal/cli/dashboard/helpers.go +++ b/internal/cli/dashboard/helpers.go @@ -178,6 +178,20 @@ func selectOrg(orgs []domain.DashboardOrganization) string { return selected } +func persistActiveOrg(authSvc *dashboardapp.AuthService, auth *domain.DashboardAuthResponse, orgPublicID string) error { + selectedOrgID := orgPublicID + if selectedOrgID == "" && len(auth.Organizations) > 1 { + selectedOrgID = selectOrg(auth.Organizations) + } + if selectedOrgID == "" { + return nil + } + if err := authSvc.SetActiveOrg(selectedOrgID); err != nil { + return fmt.Errorf("failed to store selected organization: %w", err) + } + return nil +} + // printAuthSuccess prints the standard post-login success message. // It reads the stored active org from the keyring (set by SyncSessionOrg) // so it reflects the server's actual current org. @@ -189,7 +203,7 @@ func printAuthSuccess(auth *domain.DashboardAuthResponse) { if _, secrets, err := createDPoPService(); err == nil { orgID, _ = secrets.Get(ports.KeyDashboardOrgPublicID) } - if orgID == "" && len(auth.Organizations) > 0 { + if orgID == "" && len(auth.Organizations) == 1 { orgID = auth.Organizations[0].PublicID } @@ -210,6 +224,15 @@ func printAuthSuccess(auth *domain.DashboardAuthResponse) { } } +func syncSessionOrgWithWarning(authSvc *dashboardapp.AuthService) { + syncCtx, syncCancel := common.CreateContext() + defer syncCancel() + + if err := authSvc.SyncSessionOrg(syncCtx); err != nil { + common.PrintWarning("authenticated, but failed to sync the active dashboard organization: %v", err) + } +} + // acceptPrivacyPolicy prompts for or validates privacy policy acceptance. func acceptPrivacyPolicy() error { accepted, err := common.ConfirmPrompt("Accept Nylas Privacy Policy?", true) diff --git a/internal/cli/dashboard/login.go b/internal/cli/dashboard/login.go index 5f39bc3..35658ff 100644 --- a/internal/cli/dashboard/login.go +++ b/internal/cli/dashboard/login.go @@ -135,15 +135,11 @@ func runEmailLogin(userFlag, passFlag, orgPublicID string) error { } } - if orgPublicID == "" && len(auth.Organizations) > 1 { - orgID := selectOrg(auth.Organizations) - _ = authSvc.SetActiveOrg(orgID) + if err := persistActiveOrg(authSvc, auth, orgPublicID); err != nil { + return wrapDashboardError(err) } - // Sync the actual active org from the server session - syncCtx, syncCancel := common.CreateContext() - defer syncCancel() - _ = authSvc.SyncSessionOrg(syncCtx) + syncSessionOrgWithWarning(authSvc) printAuthSuccess(auth) return nil diff --git a/internal/cli/dashboard/sso.go b/internal/cli/dashboard/sso.go index ec5acd8..ab8d47e 100644 --- a/internal/cli/dashboard/sso.go +++ b/internal/cli/dashboard/sso.go @@ -125,10 +125,11 @@ func runSSO(provider, mode string, privacyPolicyAccepted bool, orgPublicIDs ...s return wrapDashboardError(err) } - // Sync the actual active org from the server session - syncCtx, syncCancel := common.CreateContext() - defer syncCancel() - _ = authSvc.SyncSessionOrg(syncCtx) + if err := persistActiveOrg(authSvc, auth, orgPublicID); err != nil { + return wrapDashboardError(err) + } + + syncSessionOrgWithWarning(authSvc) printAuthSuccess(auth) return nil diff --git a/internal/cli/email/list.go b/internal/cli/email/list.go index 37b0327..938eaba 100644 --- a/internal/cli/email/list.go +++ b/internal/cli/email/list.go @@ -57,9 +57,10 @@ Use --max to limit total messages when using --all.`, // Auto-paginate when limit exceeds API maximum if limit > common.MaxAPILimit && !all { - all = true maxItems = limit limit = common.MaxAPILimit + } else if !all { + maxItems = -1 // single-page fetch } // Traditional formatted output @@ -110,43 +111,9 @@ Use --max to limit total messages when using --all.`, } } - var messages []domain.Message - var err error - - if all { - // Use pagination to fetch all messages - pageSize := min(limit, common.MaxAPILimit) - if pageSize <= 0 { - pageSize = common.MaxAPILimit - } - params.Limit = pageSize - - fetcher := func(ctx context.Context, cursor string) (common.PageResult[domain.Message], error) { - params.PageToken = cursor - resp, err := client.GetMessagesWithCursor(ctx, grantID, params) - if err != nil { - return common.PageResult[domain.Message]{}, err - } - return common.PageResult[domain.Message]{ - Data: resp.Data, - NextCursor: resp.Pagination.NextCursor, - }, nil - } - - config := common.DefaultPaginationConfig() - config.PageSize = pageSize - config.MaxItems = maxItems - - messages, err = common.FetchAllPages(ctx, config, fetcher) - if err != nil { - return struct{}{}, common.WrapFetchError("messages", err) - } - } else { - // Standard single-page fetch - messages, err = client.GetMessagesWithParams(ctx, grantID, params) - if err != nil { - return struct{}{}, common.WrapGetError("messages", err) - } + messages, err := fetchMessages(ctx, client, grantID, params, maxItems) + if err != nil { + return struct{}{}, common.WrapFetchError("messages", err) } if len(messages) == 0 { @@ -235,9 +202,10 @@ func runListStructured(cmd *cobra.Command, args []string, limit int, unread, sta // Auto-paginate when limit exceeds API maximum if limit > common.MaxAPILimit && !all { - all = true maxItems = limit limit = common.MaxAPILimit + } else if !all { + maxItems = -1 // single-page fetch } _, err := common.WithClient(args, func(ctx context.Context, client ports.NylasClient, grantID string) (struct{}, error) { @@ -281,41 +249,9 @@ func runListStructured(cmd *cobra.Command, args []string, limit int, unread, sta } } - var messages []domain.Message - var err error - - if all { - pageSize := min(limit, common.MaxAPILimit) - if pageSize <= 0 { - pageSize = common.MaxAPILimit - } - params.Limit = pageSize - - fetcher := func(ctx context.Context, cursor string) (common.PageResult[domain.Message], error) { - params.PageToken = cursor - resp, err := client.GetMessagesWithCursor(ctx, grantID, params) - if err != nil { - return common.PageResult[domain.Message]{}, err - } - return common.PageResult[domain.Message]{ - Data: resp.Data, - NextCursor: resp.Pagination.NextCursor, - }, nil - } - - config := common.DefaultPaginationConfig() - config.PageSize = pageSize - config.MaxItems = maxItems - - messages, err = common.FetchAllPages(ctx, config, fetcher) - if err != nil { - return struct{}{}, common.WrapFetchError("messages", err) - } - } else { - messages, err = client.GetMessagesWithParams(ctx, grantID, params) - if err != nil { - return struct{}{}, common.WrapGetError("messages", err) - } + messages, err := fetchMessages(ctx, client, grantID, params, maxItems) + if err != nil { + return struct{}{}, common.WrapFetchError("messages", err) } // Output structured data diff --git a/internal/cli/email/search.go b/internal/cli/email/search.go index 4a8a218..4ac2e6f 100644 --- a/internal/cli/email/search.go +++ b/internal/cli/email/search.go @@ -11,6 +11,11 @@ import ( "github.com/spf13/cobra" ) +type messagesClient interface { + GetMessagesWithParams(ctx context.Context, grantID string, params *domain.MessageQueryParams) ([]domain.Message, error) + GetMessagesWithCursor(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) +} + func newSearchCmd() *cobra.Command { var ( limit int @@ -55,15 +60,14 @@ Examples: remainingArgs := args[1:] _, err := common.WithClient(remainingArgs, func(ctx context.Context, client ports.NylasClient, grantID string) (struct{}, error) { - // Auto-paginate when limit exceeds API maximum - needsPagination := limit > common.MaxAPILimit - apiLimit := limit - if needsPagination { - apiLimit = common.MaxAPILimit + // maxItems >= 0 triggers auto-pagination; < 0 means single-page fetch + maxItems := -1 + if limit > common.MaxAPILimit { + maxItems = limit } params := &domain.MessageQueryParams{ - Limit: apiLimit, + Limit: limit, } // Use query as subject search unless it's a wildcard @@ -109,30 +113,7 @@ Examples: params.ReceivedBefore = t.Unix() } - var messages []domain.Message - var err error - - if needsPagination { - fetcher := func(ctx context.Context, cursor string) (common.PageResult[domain.Message], error) { - params.PageToken = cursor - resp, fetchErr := client.GetMessagesWithCursor(ctx, grantID, params) - if fetchErr != nil { - return common.PageResult[domain.Message]{}, fetchErr - } - return common.PageResult[domain.Message]{ - Data: resp.Data, - NextCursor: resp.Pagination.NextCursor, - }, nil - } - - config := common.DefaultPaginationConfig() - config.PageSize = apiLimit - config.MaxItems = limit - - messages, err = common.FetchAllPages(ctx, config, fetcher) - } else { - messages, err = client.GetMessagesWithParams(ctx, grantID, params) - } + messages, err := fetchMessages(ctx, client, grantID, params, maxItems) if err != nil { return struct{}{}, common.WrapSearchError("messages", err) } @@ -172,6 +153,39 @@ Examples: return cmd } +// fetchMessages retrieves messages, using automatic pagination when maxItems >= 0. +// Pass maxItems = 0 for unlimited pagination, >0 for a capped fetch, or <0 to +// skip pagination and perform a single-page request via GetMessagesWithParams. +func fetchMessages(ctx context.Context, client messagesClient, grantID string, params *domain.MessageQueryParams, maxItems int) ([]domain.Message, error) { + if maxItems < 0 { + return client.GetMessagesWithParams(ctx, grantID, params) + } + + pageSize := min(params.Limit, common.MaxAPILimit) + if pageSize <= 0 { + pageSize = common.MaxAPILimit + } + params.Limit = pageSize + + fetcher := func(ctx context.Context, cursor string) (common.PageResult[domain.Message], error) { + params.PageToken = cursor + resp, err := client.GetMessagesWithCursor(ctx, grantID, params) + if err != nil { + return common.PageResult[domain.Message]{}, err + } + return common.PageResult[domain.Message]{ + Data: resp.Data, + NextCursor: resp.Pagination.NextCursor, + }, nil + } + + config := common.DefaultPaginationConfig() + config.PageSize = pageSize + config.MaxItems = maxItems + + return common.FetchAllPages(ctx, config, fetcher) +} + // parseDate parses a date string in YYYY-MM-DD format using local timezone. func parseDate(s string) (time.Time, error) { return time.ParseInLocation("2006-01-02", s, time.Local) diff --git a/internal/cli/email/search_test.go b/internal/cli/email/search_test.go new file mode 100644 index 0000000..079f041 --- /dev/null +++ b/internal/cli/email/search_test.go @@ -0,0 +1,127 @@ +package email + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/domain" +) + +type stubMessagesClient struct { + getMessagesWithParamsFunc func(ctx context.Context, grantID string, params *domain.MessageQueryParams) ([]domain.Message, error) + getMessagesWithCursorFunc func(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) + + getMessagesWithParamsCalls int + getMessagesWithCursorCalls int + pageTokens []string + limits []int +} + +func (s *stubMessagesClient) GetMessagesWithParams(ctx context.Context, grantID string, params *domain.MessageQueryParams) ([]domain.Message, error) { + s.getMessagesWithParamsCalls++ + s.limits = append(s.limits, params.Limit) + if s.getMessagesWithParamsFunc != nil { + return s.getMessagesWithParamsFunc(ctx, grantID, params) + } + return nil, nil +} + +func (s *stubMessagesClient) GetMessagesWithCursor(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + s.getMessagesWithCursorCalls++ + s.pageTokens = append(s.pageTokens, params.PageToken) + s.limits = append(s.limits, params.Limit) + if s.getMessagesWithCursorFunc != nil { + return s.getMessagesWithCursorFunc(ctx, grantID, params) + } + return nil, nil +} + +func TestFetchMessages(t *testing.T) { + common.ResetLogger() + common.InitLogger(false, true) + defer common.ResetLogger() + + t.Run("uses direct fetch when maxItems is negative", func(t *testing.T) { + expected := []domain.Message{{ID: "msg-1"}, {ID: "msg-2"}} + client := &stubMessagesClient{ + getMessagesWithParamsFunc: func(_ context.Context, grantID string, params *domain.MessageQueryParams) ([]domain.Message, error) { + assert.Equal(t, "grant-123", grantID) + assert.Equal(t, 50, params.Limit) + return expected, nil + }, + } + + params := &domain.MessageQueryParams{Limit: 50} + messages, err := fetchMessages(context.Background(), client, "grant-123", params, -1) + + require.NoError(t, err) + assert.Equal(t, expected, messages) + assert.Equal(t, 1, client.getMessagesWithParamsCalls) + assert.Zero(t, client.getMessagesWithCursorCalls) + }) + + t.Run("auto paginates when limit exceeds API maximum", func(t *testing.T) { + client := &stubMessagesClient{ + getMessagesWithCursorFunc: func(_ context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + assert.Equal(t, "grant-123", grantID) + switch params.PageToken { + case "": + return &domain.MessageListResponse{ + Data: makeMessages(200, "page-1"), + Pagination: domain.Pagination{ + NextCursor: "cursor-2", + }, + }, nil + case "cursor-2": + return &domain.MessageListResponse{ + Data: makeMessages(100, "page-2"), + }, nil + default: + t.Fatalf("unexpected page token %q", params.PageToken) + return nil, nil + } + }, + } + + params := &domain.MessageQueryParams{Limit: common.MaxAPILimit} + messages, err := fetchMessages(context.Background(), client, "grant-123", params, 250) + + require.NoError(t, err) + assert.Len(t, messages, 250) + assert.Equal(t, "page-1-0", messages[0].ID) + assert.Equal(t, "page-2-49", messages[len(messages)-1].ID) + assert.Equal(t, []string{"", "cursor-2"}, client.pageTokens) + assert.Equal(t, []int{common.MaxAPILimit, common.MaxAPILimit}, client.limits) + assert.Zero(t, client.getMessagesWithParamsCalls) + assert.Equal(t, 2, client.getMessagesWithCursorCalls) + }) + + t.Run("returns pagination errors", func(t *testing.T) { + client := &stubMessagesClient{ + getMessagesWithCursorFunc: func(_ context.Context, _ string, _ *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + return nil, errors.New("boom") + }, + } + + params := &domain.MessageQueryParams{Limit: common.MaxAPILimit} + messages, err := fetchMessages(context.Background(), client, "grant-123", params, 250) + + require.Error(t, err) + assert.Nil(t, messages) + assert.Contains(t, err.Error(), "failed to fetch page 1") + }) +} + +func makeMessages(count int, prefix string) []domain.Message { + messages := make([]domain.Message, count) + for i := range count { + messages[i] = domain.Message{ID: prefix + "-" + strconv.Itoa(i)} + } + return messages +}