diff --git a/commands/auth/api_token_login.go b/commands/auth/api_token_login.go new file mode 100644 index 000000000..3a9cd00c9 --- /dev/null +++ b/commands/auth/api_token_login.go @@ -0,0 +1,96 @@ +package auth + +import ( + "bufio" + "errors" + "fmt" + "os" + "strings" + + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func NewAPITokenLoginCommand(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "auth:api-token-login", + Short: "Log in using an API token", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + // Block if TOKEN env var is already set via config. + if os.Getenv(cfg.Application.EnvPrefix+"TOKEN") != "" { + fmt.Fprintln(cmd.ErrOrStderr(), "An API token is already set via config") + return fmt.Errorf("an API token is already set via config") + } + // Non-interactive guard only when no arg (stdin would be needed). + if len(args) == 0 && os.Getenv(cfg.Application.EnvPrefix+"NO_INTERACTION") != "" { + fmt.Fprintln(cmd.ErrOrStderr(), "Non-interactive use of this command is not supported.") + return fmt.Errorf("non-interactive use of this command is not supported") + } + + var ( + apiToken string + s *session.Session + ) + if len(args) > 0 { + apiToken = strings.TrimSpace(args[0]) + var err error + s, err = exchangeAPIToken(cmd.Context(), cfg, apiToken) + if err != nil { + return fmt.Errorf("login failed: %w", err) + } + } else { + const maxAttempts = 5 + scanner := bufio.NewScanner(cmd.InOrStdin()) + for attempt := 1; attempt <= maxAttempts; attempt++ { + fmt.Fprint(cmd.ErrOrStderr(), "Enter your API token: ") + if !scanner.Scan() { + return fmt.Errorf("read API token: %w", scanner.Err()) + } + apiToken = strings.TrimSpace(scanner.Text()) + if apiToken == "" { + fmt.Fprintln(cmd.ErrOrStderr(), "The token cannot be empty") + continue + } + var err error + s, err = exchangeAPIToken(cmd.Context(), cfg, apiToken) + if err == nil { + break + } + if errors.Is(err, ErrInvalidAPIToken) { + fmt.Fprintln(cmd.ErrOrStderr(), ErrInvalidAPIToken.Error()) + if attempt == maxAttempts { + return fmt.Errorf("login failed after %d attempts", maxAttempts) + } + continue + } + return fmt.Errorf("login failed: %w", err) + } + } + fmt.Fprintln(cmd.ErrOrStderr(), "The API token is valid.") + + mgr, err := session.New(cfg) + if err != nil { + return err + } + if err := mgr.SetAPIToken(apiToken); err != nil { + return err + } + if err := mgr.Save(s); err != nil { + return err + } + + fmt.Fprintln(cmd.ErrOrStderr(), "You are logged in.") + if err := printUserInfo(cmd.Context(), mgr, cfg, cmd.ErrOrStderr()); err != nil { + return err + } + delegateSSHFinalization(cmd.Context(), cfg, cmd) + return nil + }, + } + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/commands/auth/browser_login.go b/commands/auth/browser_login.go new file mode 100644 index 000000000..9df358763 --- /dev/null +++ b/commands/auth/browser_login.go @@ -0,0 +1,122 @@ +// commands/auth/browser_login.go +package auth + +import ( + "bufio" + "fmt" + "os" + "strings" + "time" + + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + internalauth "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func NewBrowserLoginCommand(cfg *config.Config) *cobra.Command { + var ( + force bool + methods []string + maxAge int + ) + cmd := &cobra.Command{ + Use: "auth:browser-login", + Aliases: []string{"login"}, + Short: "Log in via a browser", + RunE: func(cmd *cobra.Command, _ []string) error { + // If an API token is configured, browser login is not applicable. + if apiToken := os.Getenv(cfg.Application.EnvPrefix + "TOKEN"); apiToken != "" { + return fmt.Errorf("cannot log in via the browser while an API token is set (%sTOKEN)", cfg.Application.EnvPrefix) + } + + mgr, err := session.New(cfg) + if err != nil { + return err + } + + // Also check for an API token in the session. + if storedToken, err := mgr.GetAPIToken(); err == nil && storedToken != "" { + return fmt.Errorf("cannot log in via the browser while an API token is configured") + } + + // Non-interactive guard. + if os.Getenv(cfg.Application.EnvPrefix+"NO_INTERACTION") != "" { + fmt.Fprintln(cmd.ErrOrStderr(), "Non-interactive use of this command is not supported.") + return fmt.Errorf("non-interactive use of this command is not supported") + } + + printSessionID(cmd.ErrOrStderr(), cfg, mgr) + + hasMaxAge := cmd.Flags().Changed("max-age") + + // Check if already logged in (unless --force). + if !force && len(methods) == 0 && !hasMaxAge { + s, err := mgr.Load() + if err == nil && s != nil && s.AccessToken != "" && time.Now().Unix() < s.Expires { + fmt.Fprintf(cmd.ErrOrStderr(), "You are already logged in as a user.\n") + fmt.Fprint(cmd.ErrOrStderr(), "Log in anyway? [y/N] ") + scanner := bufio.NewScanner(cmd.InOrStdin()) + scanner.Scan() + answer := strings.TrimSpace(strings.ToLower(scanner.Text())) + if answer != "y" && answer != "yes" { + return fmt.Errorf("login canceled") + } + force = true + } + } + + flow := internalauth.NewBrowserFlow(cfg) + opts := internalauth.BrowserFlowOptions{ + Force: force, + Methods: methods, + Stderr: cmd.ErrOrStderr(), + OnCodeReceived: func() { + fmt.Fprintln(cmd.ErrOrStderr(), "Login information received. Verifying...") + }, + } + if hasMaxAge { + opts.MaxAge = &maxAge + } + + fmt.Fprintf(cmd.ErrOrStderr(), + "\nHelp:\n Leave this command running during login.\n If you need to quit, use Ctrl+C.\n\n") + + s, err := flow.Run(cmd.Context(), opts) + if err != nil { + return err + } + + if err := mgr.Save(s); err != nil { + return err + } + + if s.RefreshToken == "" { + clientID := cfg.API.OAuth2ClientID + fmt.Fprintln(cmd.ErrOrStderr(), "") + fmt.Fprintln(cmd.ErrOrStderr(), "Warning:") + fmt.Fprintln(cmd.ErrOrStderr(), "No refresh token is available. This will cause frequent login errors.") + fmt.Fprintln(cmd.ErrOrStderr(), "Please contact support.") + fmt.Fprintf(cmd.ErrOrStderr(), + "For internal use: the OAuth 2 client is probably misconfigured (client ID: %s).\n", + clientID) + } + + fmt.Fprintln(cmd.ErrOrStderr(), "You are logged in.") + + if err := printUserInfo(cmd.Context(), mgr, cfg, cmd.ErrOrStderr()); err != nil { + fmt.Fprintf(cmd.ErrOrStderr(), "Warning: could not retrieve user info: %v\n", err) + } + + delegateSSHFinalization(cmd.Context(), cfg, cmd) + return nil + }, + } + cmd.Flags().BoolVarP(&force, "force", "f", false, "Log in again, even if already logged in") + cmd.Flags().StringArrayVar(&methods, "method", nil, "Require specific authentication method(s)") + cmd.Flags().IntVar(&maxAge, "max-age", 0, "Maximum age (seconds) of the web authentication session") + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/commands/auth/helpers.go b/commands/auth/helpers.go new file mode 100644 index 000000000..1f431c4dc --- /dev/null +++ b/commands/auth/helpers.go @@ -0,0 +1,156 @@ +package auth + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "strings" + + "golang.org/x/oauth2" + + "github.com/upsun/cli/internal/api" + internalauth "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/legacy" + "github.com/upsun/cli/internal/session" +) + +// InjectSessionCredentials injects stored credentials into a legacy PHP CLI wrapper so PHP +// can authenticate without relying on the OS credential helper (e.g. macOS Keychain). +// Respects an already-set TOKEN env var and is a no-op if session loading fails. +func InjectSessionCredentials(cfg *config.Config, wrapper *legacy.CLIWrapper) { + envPrefix := cfg.Application.EnvPrefix + if os.Getenv(envPrefix+"TOKEN") != "" { + return + } + mgr, err := session.New(cfg) + if err != nil { + return + } + if apiToken, err := mgr.GetAPIToken(); err == nil && apiToken != "" { + wrapper.ExtraEnv = append(wrapper.ExtraEnv, envPrefix+"TOKEN="+apiToken) + return + } + if s, err := mgr.Load(); err == nil && s != nil && s.AccessToken != "" { + wrapper.ExtraEnv = append(wrapper.ExtraEnv, envPrefix+"API_TOKEN="+s.AccessToken) + } +} + +// httpClient is used for all outbound HTTP requests in this package. +// Can be replaced in tests to inject a custom transport. +var httpClient *http.Client = http.DefaultClient + +// resolveBaseURL returns the API base URL, preferring the env var override. +func resolveBaseURL(cfg *config.Config) string { + if v := os.Getenv(cfg.Application.EnvPrefix + "API_URL"); v != "" { + return v + } + return cfg.API.BaseURL +} + +// newAPIClient creates an authenticated API client for commands. +// +// Auth priority: +// 1. API token from env var ({EnvPrefix}TOKEN) or session storage — exchanged for OAuth access token. +// 2. Session OAuth token — used directly. +func newAPIClient(ctx context.Context, mgr *session.Manager, cfg *config.Config) (*api.Client, error) { + // Check for API token in env or session storage. + apiToken := os.Getenv(cfg.Application.EnvPrefix + "TOKEN") + if apiToken == "" { + var err error + apiToken, err = mgr.GetAPIToken() + if err != nil { + return nil, err + } + } + + if apiToken != "" { + // Exchange the API token for an OAuth2 access token. + s, err := exchangeAPIToken(ctx, cfg, apiToken) + if err != nil { + return nil, err + } + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: s.AccessToken}) + return api.NewClient(resolveBaseURL(cfg), oauth2.NewClient(ctx, ts)) + } + + // Fall back to session-based OAuth token source. + authClient, err := internalauth.NewClient(ctx, mgr, cfg) + if err != nil { + return nil, err + } + return api.NewClient(resolveBaseURL(cfg), authClient.HTTPClient) +} + +// printSessionID prints the current session ID hint when the session is non-default or +// multiple sessions exist, followed by a blank line. No-ops otherwise. +func printSessionID(w io.Writer, cfg *config.Config, mgr *session.Manager) { + sessionID := mgr.SessionID() + ids, _ := mgr.List() + if sessionID == "default" && len(ids) <= 1 { + return + } + fmt.Fprintf(w, "The current session ID is: %s\n", sessionID) + if os.Getenv(cfg.Application.EnvPrefix+"SESSION_ID") == "" { + fmt.Fprintf(w, "Change this using: %s session:switch\n", cfg.Application.Executable) + } + fmt.Fprintln(w) +} + +// printUserInfo fetches and prints the current user's info to w (used post-login). +func printUserInfo(ctx context.Context, mgr *session.Manager, cfg *config.Config, w io.Writer) error { + apiClient, err := newAPIClient(ctx, mgr, cfg) + if err != nil { + return err + } + info, err := apiClient.GetMyUser(ctx, false) + if err != nil { + return err + } + username, _ := info["username"].(string) + email, _ := info["email"].(string) + fmt.Fprintf(w, "Logged in as: %s (%s)\n", username, email) + return nil +} + +// printTable writes a two-column property/value table to w. +func printTable(w io.Writer, properties []string, data map[string]interface{}) { + col1 := len("Property") + col2 := len("Value") + for _, p := range properties { + if len(p) > col1 { + col1 = len(p) + } + v := formatValue(data[p]) + if len(v) > col2 { + col2 = len(v) + } + } + sep := "+" + strings.Repeat("-", col1+2) + "+" + strings.Repeat("-", col2+2) + "+" + fmt.Fprintln(w, sep) + fmt.Fprintf(w, "| %-*s | %-*s |\n", col1, "Property", col2, "Value") + fmt.Fprintln(w, sep) + for _, p := range properties { + v := formatValue(data[p]) + fmt.Fprintf(w, "| %-*s | %-*s |\n", col1, p, col2, v) + } + fmt.Fprintln(w, sep) +} + +// formatValue converts an interface{} to a display string. +func formatValue(v interface{}) string { + if v == nil { + return "" + } + switch val := v.(type) { + case bool: + if val { + return "true" + } + return "false" + default: + return strings.TrimSpace(fmt.Sprintf("%v", val)) + } +} diff --git a/commands/auth/helpers_test.go b/commands/auth/helpers_test.go new file mode 100644 index 000000000..088cac2c3 --- /dev/null +++ b/commands/auth/helpers_test.go @@ -0,0 +1,178 @@ +package auth + +import ( + "bytes" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/legacy" + "github.com/upsun/cli/internal/session" +) + +// ---- formatValue ---- + +func TestFormatValue_Nil(t *testing.T) { + assert.Equal(t, "", formatValue(nil)) +} + +func TestFormatValue_BoolTrue(t *testing.T) { + assert.Equal(t, "true", formatValue(true)) +} + +func TestFormatValue_BoolFalse(t *testing.T) { + assert.Equal(t, "false", formatValue(false)) +} + +func TestFormatValue_String(t *testing.T) { + assert.Equal(t, "hello", formatValue("hello")) +} + +func TestFormatValue_StringWithWhitespace(t *testing.T) { + assert.Equal(t, "hello", formatValue(" hello ")) +} + +func TestFormatValue_Number(t *testing.T) { + assert.Equal(t, "42", formatValue(42)) +} + +// ---- printTable ---- + +func TestPrintTable_RendersHeaderAndRows(t *testing.T) { + var buf bytes.Buffer + data := map[string]interface{}{ + "id": "user-1", + "email": "x@example.com", + } + printTable(&buf, []string{"id", "email"}, data) + out := buf.String() + + assert.Contains(t, out, "Property") + assert.Contains(t, out, "Value") + assert.Contains(t, out, "id") + assert.Contains(t, out, "user-1") + assert.Contains(t, out, "email") + assert.Contains(t, out, "x@example.com") +} + +func TestPrintTable_ColumnsWideEnoughForContent(t *testing.T) { + var buf bytes.Buffer + longKey := "a_very_long_property_name" + printTable(&buf, []string{longKey}, map[string]interface{}{longKey: "v"}) + out := buf.String() + // Each data row must contain the key without truncation + assert.Contains(t, out, longKey) +} + +// ---- InjectSessionCredentials ---- + +func testCfg(t *testing.T) *config.Config { + t.Helper() + data, err := os.ReadFile("../../integration-tests/config.yaml") + require.NoError(t, err) + cfg, err := config.FromYAML(data) + require.NoError(t, err) + return cfg +} + +func TestInjectSessionCredentials_EnvTokenAlreadySet(t *testing.T) { + cfg := testCfg(t) + t.Setenv(cfg.Application.EnvPrefix+"TOKEN", "env-token") + + wrapper := &legacy.CLIWrapper{} + InjectSessionCredentials(cfg, wrapper) + + assert.Empty(t, wrapper.ExtraEnv, "should not inject when TOKEN env var is set") +} + +func TestInjectSessionCredentials_InjectsAPITokenFromSession(t *testing.T) { + cfg := testCfg(t) + mgr, err := session.New(cfg) + require.NoError(t, err) + require.NoError(t, mgr.SetAPIToken("stored-api-token")) + t.Cleanup(func() { _ = mgr.DeleteAPIToken() }) + + wrapper := &legacy.CLIWrapper{} + InjectSessionCredentials(cfg, wrapper) + + require.Len(t, wrapper.ExtraEnv, 1) + assert.Equal(t, cfg.Application.EnvPrefix+"TOKEN=stored-api-token", wrapper.ExtraEnv[0]) +} + +func TestInjectSessionCredentials_InjectsOAuthAccessToken(t *testing.T) { + cfg := testCfg(t) + mgr, err := session.New(cfg) + require.NoError(t, err) + require.NoError(t, mgr.Save(&session.Session{ + AccessToken: "oauth-access-token", + Expires: time.Now().Add(time.Hour).Unix(), + })) + t.Cleanup(func() { _ = mgr.Delete() }) + + wrapper := &legacy.CLIWrapper{} + InjectSessionCredentials(cfg, wrapper) + + require.Len(t, wrapper.ExtraEnv, 1) + assert.Equal(t, cfg.Application.EnvPrefix+"API_TOKEN=oauth-access-token", wrapper.ExtraEnv[0]) +} + +func TestInjectSessionCredentials_NoOpWhenNoCredentials(t *testing.T) { + cfg := testCfg(t) + + wrapper := &legacy.CLIWrapper{} + InjectSessionCredentials(cfg, wrapper) + + assert.Empty(t, wrapper.ExtraEnv) +} + +// ---- printSessionID ---- + +func TestPrintSessionID_DefaultSingleSession_NoOutput(t *testing.T) { + cfg := testCfg(t) + mgr := session.NewWithStore(cfg, session.NewMemStore()) + + var buf bytes.Buffer + printSessionID(&buf, cfg, mgr) + + assert.Empty(t, buf.String()) +} + +func TestPrintSessionID_NonDefaultSession_PrintsHint(t *testing.T) { + cfg := testCfg(t) + // Use a non-default session ID + cfg2 := *cfg + cfg2.API.SessionID = "work" + mgr := session.NewWithStore(&cfg2, session.NewMemStore()) + + var buf bytes.Buffer + printSessionID(&buf, &cfg2, mgr) + + out := buf.String() + assert.Contains(t, out, "work") + assert.Contains(t, out, "session:switch") +} + +func TestPrintSessionID_MultipleSessionsShowsHint(t *testing.T) { + // Load config first to get the env prefix, then set HOME before any dir is computed. + base := testCfg(t) + t.Setenv(base.Application.EnvPrefix+"HOME", t.TempDir()) + + // Fresh config load after env var is set, so WritableUserDir cache is clean. + cfg := testCfg(t) + + mgr, err := session.New(cfg) + require.NoError(t, err) + require.NoError(t, mgr.Save(&session.Session{AccessToken: "tok", Expires: time.Now().Add(time.Hour).Unix()})) + + mgr2 := session.NewWithID(cfg, "other") + require.NoError(t, mgr2.Save(&session.Session{AccessToken: "tok2", Expires: time.Now().Add(time.Hour).Unix()})) + + var buf bytes.Buffer + printSessionID(&buf, cfg, mgr) + + assert.Contains(t, buf.String(), "session ID") +} diff --git a/commands/auth/info.go b/commands/auth/info.go new file mode 100644 index 000000000..fe26a1b10 --- /dev/null +++ b/commands/auth/info.go @@ -0,0 +1,164 @@ +package auth + +import ( + "bufio" + "fmt" + "os" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/spf13/viper" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + internalauth "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func NewInfoCommand(cfg *config.Config) *cobra.Command { + var ( + noAutoLogin bool + property string + refresh bool + ) + cmd := &cobra.Command{ + Use: "auth:info [property]", + Short: "Display your account information", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + if property != "" { + return fmt.Errorf("cannot use both the argument and --property option") + } + property = args[0] + } + + ctx := cmd.Context() + mgr, err := session.New(cfg) + if err != nil { + return err + } + + // Determine login state; distinguish expired sessions from never-logged-in. + envToken := os.Getenv(cfg.Application.EnvPrefix + "TOKEN") + loggedIn := envToken != "" + var sessionExpired bool + if !loggedIn { + if apiToken, _ := mgr.GetAPIToken(); apiToken != "" { + loggedIn = true + } else if s, _ := mgr.Load(); s != nil && s.AccessToken != "" { + if time.Now().Unix() < s.Expires { + loggedIn = true + } else { + sessionExpired = true + } + } + } + + if noAutoLogin && !loggedIn { + return nil + } + + if !loggedIn { + if sessionExpired { + fmt.Fprintln(cmd.ErrOrStderr(), "Your session has expired. You have been logged out.") + fmt.Fprintln(cmd.ErrOrStderr(), "") + } + + // --yes auto-accepts the browser login prompt. + // --no-interaction (without --yes) suppresses the prompt entirely. + yesFlag := viper.GetBool("yes") + noInteraction := !yesFlag && (os.Getenv(cfg.Application.EnvPrefix+"NO_INTERACTION") != "" || + viper.GetBool("no-interaction")) + + shouldLogin := false + if yesFlag { + fmt.Fprintln(cmd.ErrOrStderr(), "Authentication is required.") + shouldLogin = true + } else if !noInteraction { + fmt.Fprintln(cmd.ErrOrStderr(), "Authentication is required.") + fmt.Fprint(cmd.ErrOrStderr(), "Log in via a browser? [Y/n] ") + scanner := bufio.NewScanner(cmd.InOrStdin()) + scanner.Scan() + answer := strings.TrimSpace(strings.ToLower(scanner.Text())) + shouldLogin = answer == "" || answer == "y" || answer == "yes" + } + + if shouldLogin { + fmt.Fprintf(cmd.ErrOrStderr(), + "\nHelp:\n Leave this command running during login.\n If you need to quit, use Ctrl+C.\n\n") + flow := internalauth.NewBrowserFlow(cfg) + opts := internalauth.BrowserFlowOptions{ + Stderr: cmd.ErrOrStderr(), + OnCodeReceived: func() { + fmt.Fprintln(cmd.ErrOrStderr(), "Login information received. Verifying...") + }, + } + s, err := flow.Run(ctx, opts) + if err != nil { + return err + } + if err := mgr.Save(s); err != nil { + return err + } + loggedIn = true + } + + if !loggedIn { + return fmt.Errorf("not logged in. Run '%s login' to authenticate", cfg.Application.Executable) + } + } + + apiClient, err := newAPIClient(ctx, mgr, cfg) + if err != nil { + return err + } + + info, err := apiClient.GetMyUser(ctx, refresh) + if err != nil { + return err + } + + // Handle deprecated property aliases. + if property == "display_name" { + fmt.Fprintln(cmd.ErrOrStderr(), + "Deprecated: the \"display_name\" property has been replaced by \"first_name\" and \"last_name\".") + firstName, _ := info["first_name"].(string) + lastName, _ := info["last_name"].(string) + fmt.Fprintln(cmd.OutOrStdout(), firstName+" "+lastName) + return nil + } + if property == "mail" { + fmt.Fprintln(cmd.ErrOrStderr(), "Deprecated: the \"mail\" property is now named \"email\".") + property = "email" + } + if property == "uuid" { + fmt.Fprintln(cmd.ErrOrStderr(), "Deprecated: the \"uuid\" property is now named \"id\".") + property = "id" + } + + if property != "" { + val, ok := info[property] + if !ok { + return fmt.Errorf("property not found: %s", property) + } + fmt.Fprintln(cmd.OutOrStdout(), formatValue(val)) + return nil + } + + // Table output. + properties := []string{"id", "first_name", "last_name", "username", "email", "phone_number_verified"} + printTable(cmd.OutOrStdout(), properties, info) + + printSessionID(cmd.ErrOrStderr(), cfg, mgr) + return nil + }, + } + cmd.Flags().BoolVar(&noAutoLogin, "no-auto-login", false, "Skip auto login; exit 0 if not logged in") + cmd.Flags().StringVarP(&property, "property", "P", "", "The account property to view") + cmd.Flags().BoolVar(&refresh, "refresh", false, "Refresh the cache") + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/commands/auth/logout.go b/commands/auth/logout.go new file mode 100644 index 000000000..b577c36b1 --- /dev/null +++ b/commands/auth/logout.go @@ -0,0 +1,129 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "strings" + + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + internalauth "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +// revokeSession POSTs the access token to the OAuth2 revocation endpoint. +// Network or server errors are printed as warnings — local cleanup always proceeds. +func revokeSession(ctx context.Context, mgr *session.Manager, cfg *config.Config, warn func(string)) { + s, err := mgr.Load() + if err != nil || s == nil || s.AccessToken == "" { + return + } + revokeURL := internalauth.OAuth2RevokeURL(cfg) + if revokeURL == "" { + return + } + body := url.Values{ + "token": {s.AccessToken}, + "token_type_hint": {"access_token"}, + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, revokeURL, strings.NewReader(body.Encode())) + if err != nil { + warn(fmt.Sprintf("Warning: could not build revoke request: %v", err)) + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := httpClient.Do(req) + if err != nil { + warn(fmt.Sprintf("Warning: could not revoke token: %v", err)) + return + } + resp.Body.Close() +} + +func NewLogoutCommand(cfg *config.Config) *cobra.Command { + var ( + all bool + other bool + ) + cmd := &cobra.Command{ + Use: "auth:logout", + Aliases: []string{"logout"}, + Short: "Log out", + RunE: func(cmd *cobra.Command, _ []string) error { + mgr, err := session.New(cfg) + if err != nil { + return err + } + + if os.Getenv(cfg.Application.EnvPrefix+"TOKEN") != "" { + fmt.Fprintln(cmd.ErrOrStderr(), "Warning: an API token is set via config") + } + + warn := func(msg string) { fmt.Fprintln(cmd.ErrOrStderr(), msg) } + + if other && !all { + currentID := mgr.SessionID() + fmt.Fprintf(cmd.ErrOrStderr(), "The current session ID is: %s\n", currentID) + ids, err := mgr.List() + if err != nil { + return err + } + var others []string + for _, id := range ids { + if id != currentID { + others = append(others, id) + } + } + if len(others) == 0 { + fmt.Fprintln(cmd.ErrOrStderr(), "No other sessions exist.") + return nil + } + fmt.Fprintln(cmd.ErrOrStderr()) + for _, id := range others { + sub := session.NewWithID(cfg, id) + revokeSession(cmd.Context(), sub, cfg, warn) + if err := sub.Delete(); err != nil { + return fmt.Errorf("delete session %q: %w", id, err) + } + fmt.Fprintf(cmd.ErrOrStderr(), "Logged out from session: %s\n", id) + } + fmt.Fprintln(cmd.ErrOrStderr()) + fmt.Fprintln(cmd.ErrOrStderr(), "All other sessions have been deleted.") + return nil + } + + revokeSession(cmd.Context(), mgr, cfg, warn) + if err := mgr.Delete(); err != nil { + return err + } + fmt.Fprintln(cmd.ErrOrStderr(), "You are now logged out.") + + if all { + if err := mgr.DeleteAll(); err != nil { + return err + } + fmt.Fprintln(cmd.ErrOrStderr(), "All sessions have been deleted.") + return nil + } + + ids, err := mgr.List() + if err != nil { + return err + } + if len(ids) > 0 { + fmt.Fprintf(cmd.ErrOrStderr(), "\nOther sessions exist. Log out of all sessions with: %s logout --all\n", + cfg.Application.Executable) + } + return nil + }, + } + cmd.Flags().BoolVarP(&all, "all", "a", false, "Log out from all local sessions") + cmd.Flags().BoolVar(&other, "other", false, "Log out from other local sessions") + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/commands/auth/ssh.go b/commands/auth/ssh.go new file mode 100644 index 000000000..e5d573e35 --- /dev/null +++ b/commands/auth/ssh.go @@ -0,0 +1,27 @@ +// commands/auth/ssh.go +package auth + +import ( + "context" + "io" + + "github.com/spf13/cobra" + + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/legacy" +) + +// delegateSSHFinalization calls the legacy PHP CLI to run post-login SSH setup. +// This is best-effort — errors are intentionally ignored. +func delegateSSHFinalization(ctx context.Context, cfg *config.Config, cmd *cobra.Command) { + wrapper := &legacy.CLIWrapper{ + Config: cfg, + Version: config.Version, + DisableInteraction: true, + Stdout: io.Discard, + Stderr: cmd.ErrOrStderr(), + Stdin: cmd.InOrStdin(), + } + InjectSessionCredentials(cfg, wrapper) + _ = wrapper.Exec(ctx, "ssh-cert:load", "--no-interaction") +} diff --git a/commands/auth/token.go b/commands/auth/token.go new file mode 100644 index 000000000..b054795b6 --- /dev/null +++ b/commands/auth/token.go @@ -0,0 +1,142 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + internalauth "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func NewTokenCommand(cfg *config.Config) *cobra.Command { + var ( + header bool + noWarn bool + ) + cmd := &cobra.Command{ + Use: "auth:token", + Short: "Obtain an OAuth 2 access token for API requests", + Hidden: true, + RunE: func(cmd *cobra.Command, _ []string) error { + if !noWarn { + fmt.Fprintln(cmd.ErrOrStderr(), "Warning: keep access tokens secret.") + } + + mgr, err := session.New(cfg) + if err != nil { + return err + } + + // Check for an API token in the environment or session. + apiToken := os.Getenv(cfg.Application.EnvPrefix + "TOKEN") + if apiToken == "" { + apiToken, err = mgr.GetAPIToken() + if err != nil { + return err + } + } + + var accessToken string + if apiToken != "" { + s, err := exchangeAPIToken(cmd.Context(), cfg, apiToken) + if err != nil { + return err + } + accessToken = s.AccessToken + } else { + ts := internalauth.NewSessionTokenSource(mgr, cfg) + tok, err := ts.TokenContext(cmd.Context()) + if err != nil { + return err + } + accessToken = tok.AccessToken + } + + out := accessToken + if header { + out = "Authorization: Bearer " + out + } + fmt.Fprintln(cmd.OutOrStdout(), out) + return nil + }, + } + cmd.Flags().BoolVarP(&header, "header", "H", false, `Output the token as an HTTP "Authorization: Bearer" header`) + cmd.Flags().BoolVarP(&noWarn, "no-warn", "W", false, "Suppress the warning message") + cobrahelp.SetPhpStyle(cmd) + return cmd +} + +// ErrInvalidAPIToken is returned by exchangeAPIToken when the server rejects the token (400/401). +// Callers can use errors.Is to detect this and re-prompt. +var ErrInvalidAPIToken = errors.New("invalid API token") + +// exchangeAPIToken exchanges an API token for OAuth2 tokens and returns the resulting session. +func exchangeAPIToken(ctx context.Context, cfg *config.Config, apiToken string) (*session.Session, error) { + tokenURL := internalauth.OAuth2TokenURL(cfg) + if tokenURL == "" { + return nil, fmt.Errorf("no OAuth2 token URL configured") + } + + data := url.Values{ + "grant_type": {"api_token"}, + "api_token": {apiToken}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("exchange API token: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if cfg.API.OAuth2ClientID != "" { + req.SetBasicAuth(cfg.API.OAuth2ClientID, "") + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("exchange API token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusBadRequest || resp.StatusCode == http.StatusUnauthorized { + return nil, ErrInvalidAPIToken + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("exchange API token: server returned %d", resp.StatusCode) + } + + var result struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Error string `json:"error"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("exchange API token: decode response: %w", err) + } + if result.Error != "" { + return nil, fmt.Errorf("exchange API token: %s", result.Error) + } + if result.AccessToken == "" { + return nil, fmt.Errorf("exchange API token: no access token in response") + } + + expiry := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second).Unix() + return &session.Session{ + AccessToken: result.AccessToken, + TokenType: result.TokenType, + Expires: expiry, + RefreshToken: result.RefreshToken, + }, nil +} diff --git a/commands/auth/verify_phone.go b/commands/auth/verify_phone.go new file mode 100644 index 000000000..6bca046f6 --- /dev/null +++ b/commands/auth/verify_phone.go @@ -0,0 +1,153 @@ +package auth + +import ( + "bufio" + "fmt" + "strings" + "unicode" + + "github.com/nyaruka/phonenumbers" + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func NewVerifyPhoneNumberCommand(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "auth:verify-phone-number", + Short: "Verify your phone number interactively", + Hidden: true, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + mgr, err := session.New(cfg) + if err != nil { + return err + } + apiClient, err := newAPIClient(ctx, mgr, cfg) + if err != nil { + return err + } + + info, err := apiClient.GetMyUser(ctx, true) + if err != nil { + return err + } + if verified, _ := info["phone_number_verified"].(bool); verified { + fmt.Fprintln(cmd.ErrOrStderr(), "Your user account already has a verified phone number.") + return nil + } + userID, _ := info["id"].(string) + if userID == "" { + return fmt.Errorf("could not determine user ID") + } + country, _ := info["country"].(string) + + scanner := bufio.NewScanner(cmd.InOrStdin()) + + // Choose method. + fmt.Fprintln(cmd.ErrOrStderr(), "Choose a verification method:") + fmt.Fprintln(cmd.ErrOrStderr(), " [0] SMS (default)") + fmt.Fprintln(cmd.ErrOrStderr(), " [1] WhatsApp message") + fmt.Fprintln(cmd.ErrOrStderr(), " [2] Call") + fmt.Fprint(cmd.ErrOrStderr(), "Enter a number (default: 0): ") + scanner.Scan() + if err := scanner.Err(); err != nil { + return fmt.Errorf("read input: %w", err) + } + choice := strings.TrimSpace(scanner.Text()) + var method string + switch choice { + case "", "0": + method = "sms" + case "1": + method = "whatsapp" + case "2": + method = "call" + default: + method = choice + } + + // Get phone number — re-prompt up to 5 times on invalid input (PHP parity: askInput with setMaxAttempts(5)). + const maxAttempts = 5 + var e164 string + for attempt := 1; attempt <= maxAttempts; attempt++ { + fmt.Fprint(cmd.ErrOrStderr(), "Enter your phone number (international format, e.g. +1 415 555 0100): ") + scanner.Scan() + if err := scanner.Err(); err != nil { + return fmt.Errorf("read input: %w", err) + } + rawNumber := strings.TrimSpace(scanner.Text()) + num, parseErr := phonenumbers.Parse(rawNumber, country) + if rawNumber == "" || parseErr != nil || !phonenumbers.IsValidNumber(num) { + fmt.Fprintln(cmd.ErrOrStderr(), "The phone number is not valid.") + if attempt == maxAttempts { + return fmt.Errorf("too many invalid phone numbers") + } + continue + } + e164 = phonenumbers.Format(num, phonenumbers.E164) + break + } + + sid, err := apiClient.SendPhoneVerification(ctx, userID, e164, method) + if err != nil { + return fmt.Errorf("send verification: %w", err) + } + + switch method { + case "call": + fmt.Fprintf(cmd.ErrOrStderr(), "Calling the number %s with a verification code.\n", e164) + case "sms": + fmt.Fprintf(cmd.ErrOrStderr(), "A verification code has been sent using SMS to the number: %s\n", e164) + case "whatsapp": + fmt.Fprintf(cmd.ErrOrStderr(), "A verification code has been sent using WhatsApp to the number: %s\n", e164) + } + fmt.Fprintln(cmd.ErrOrStderr()) + + // Get verification code — re-prompt up to 5 times on invalid input or rejected code (PHP parity). + for attempt := 1; attempt <= maxAttempts; attempt++ { + fmt.Fprint(cmd.ErrOrStderr(), "Enter the verification code: ") + scanner.Scan() + if err := scanner.Err(); err != nil { + return fmt.Errorf("read input: %w", err) + } + code := strings.TrimSpace(scanner.Text()) + isNumeric := code != "" + for _, c := range code { + if !unicode.IsDigit(c) { + isNumeric = false + break + } + } + if !isNumeric { + fmt.Fprintln(cmd.ErrOrStderr(), "Invalid verification code") + if attempt == maxAttempts { + return fmt.Errorf("too many invalid verification codes") + } + continue + } + if err := apiClient.VerifyPhone(ctx, userID, sid, code); err != nil { + fmt.Fprintln(cmd.ErrOrStderr(), "Invalid verification code") + if attempt == maxAttempts { + return fmt.Errorf("too many invalid verification codes") + } + continue + } + break + } + + // Verify the status was actually updated. + if err := apiClient.CheckVerificationStatus(ctx); err != nil { + fmt.Fprintln(cmd.ErrOrStderr(), "Phone verification succeeded but the status check failed.") + return fmt.Errorf("verification status check failed: %w", err) + } + + fmt.Fprintln(cmd.ErrOrStderr(), "Your phone number has been successfully verified.") + return nil + }, + } + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/commands/cobrahelp/help.go b/commands/cobrahelp/help.go new file mode 100644 index 000000000..3f91c9e99 --- /dev/null +++ b/commands/cobrahelp/help.go @@ -0,0 +1,74 @@ +// commands/cobrahelp/help.go +package cobrahelp + +import ( + "bytes" + "fmt" + "strings" + "text/tabwriter" + + "github.com/fatih/color" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// SetPhpStyle sets a PHP CLI-style SetHelpFunc on cmd. +// Call this on each new Go command after creating it. +func SetPhpStyle(cmd *cobra.Command) { + cmd.SetHelpFunc(func(c *cobra.Command, _ []string) { + c.Print(RenderHelp(c)) + }) +} + +// RenderHelp renders a PHP CLI-style help page for the given Cobra command. +func RenderHelp(cmd *cobra.Command) string { + var b bytes.Buffer + w := tabwriter.NewWriter(&b, 0, 8, 1, ' ', 0) + + // Command name: first word of Use, then rest of Use is the signature. + name := cmd.Use + if idx := strings.Index(name, " "); idx != -1 { + name = name[:idx] + } + fmt.Fprintln(w, color.YellowString("Command: ")+name) + fmt.Fprintln(w, color.YellowString("Description: ")+cmd.Short) + fmt.Fprintln(w, "") + + // Usage line: executable + Use. + root := cmd + for root.HasParent() { + root = root.Parent() + } + binary := root.Use + fmt.Fprintln(w, color.YellowString("Usage:")) + fmt.Fprintln(w, " "+binary+" "+cmd.Use) + if len(cmd.Aliases) > 0 { + for _, alias := range cmd.Aliases { + fmt.Fprintln(w, " "+binary+" "+alias) + } + } + fmt.Fprintln(w, "") + + // Options: local flags only (not inherited). + hasFlags := false + cmd.Flags().VisitAll(func(_ *pflag.Flag) { + hasFlags = true + }) + if hasFlags { + fmt.Fprintln(w, color.YellowString("Options:")) + cmd.Flags().VisitAll(func(f *pflag.Flag) { + shorthand := "" + if f.Shorthand != "" { + shorthand = color.GreenString("-"+f.Shorthand) + "," + } else { + shorthand = " " + } + longName := color.GreenString("--" + f.Name) + fmt.Fprintf(w, " %s %s\t%s\n", shorthand, longName, f.Usage) + }) + fmt.Fprintln(w, "") + } + + w.Flush() + return b.String() +} diff --git a/commands/init.go b/commands/init.go index f2fab75fc..8412fa54e 100644 --- a/commands/init.go +++ b/commands/init.go @@ -27,6 +27,7 @@ import ( "github.com/upsun/cli/internal/auth" "github.com/upsun/cli/internal/config" _init "github.com/upsun/cli/internal/init" + "github.com/upsun/cli/internal/session" ) func newInitCommand(cnf *config.Config, assets *vendorization.VendorAssets) *cobra.Command { @@ -108,8 +109,12 @@ func runInitCommand( cnf := config.FromContext(cmd.Context()) - legacyCLIClient, err := auth.NewLegacyCLIClient(cmd.Context(), - makeLegacyCLIWrapper(cnf, cmd.OutOrStdout(), cmd.ErrOrStderr(), cmd.InOrStdin())) + mgr, err := session.New(cnf) + if err != nil { + return err + } + + authClient, err := auth.NewClient(cmd.Context(), mgr, cnf) if err != nil { return err } @@ -130,7 +135,7 @@ func runInitCommand( var isInteractive = !viper.GetBool("no-interaction") debugLogf("Checking selected organization") - org, err := handleOrganizations(cmd.Context(), cnf, legacyCLIClient, initOptions) + org, err := handleOrganizations(cmd.Context(), cnf, authClient, initOptions) if err != nil { return err } @@ -168,7 +173,7 @@ func runInitCommand( "Note: AI configuration is only compatible with `%s` organizations\n", api.OrgTypeFlexible)) } - if err := legacyCLIClient.EnsureAuthenticated(cmd.Context()); err != nil { + if err := authClient.EnsureAuthenticated(cmd.Context()); err != nil { return err } @@ -178,7 +183,7 @@ func runInitCommand( return err } - initOptions.HTTPClient = legacyCLIClient.HTTPClient + initOptions.HTTPClient = authClient.HTTPClient initOptions.APIURL = cnf.API.BaseURL initOptions.UserAgent = cnf.UserAgent() initOptions.IsInteractive = isInteractive @@ -192,13 +197,13 @@ func runInitCommand( // handleOrganizations manages organization selection and validation. // It modifies initOptions.OrganizationID and initOptions.ProjectID. func handleOrganizations( - ctx context.Context, cnf *config.Config, legacyCLIClient *auth.LegacyCLIClient, initOptions *_init.Options, + ctx context.Context, cnf *config.Config, authClient *auth.Client, initOptions *_init.Options, ) (*api.Organization, error) { if !cnf.API.EnableOrganizations { return nil, nil } - apiClient, err := api.NewClient(cnf.API.BaseURL, legacyCLIClient.HTTPClient) + apiClient, err := api.NewClient(cnf.API.BaseURL, authClient.HTTPClient) if err != nil { return nil, err } @@ -210,7 +215,7 @@ func handleOrganizations( return nil, nil } - if err := legacyCLIClient.EnsureAuthenticated(ctx); err != nil { + if err := authClient.EnsureAuthenticated(ctx); err != nil { return nil, err } diff --git a/commands/root.go b/commands/root.go index 491788296..1fb82a671 100644 --- a/commands/root.go +++ b/commands/root.go @@ -18,6 +18,8 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + authcmds "github.com/upsun/cli/commands/auth" + sessioncmds "github.com/upsun/cli/commands/session" "github.com/upsun/cli/internal" "github.com/upsun/cli/internal/config" "github.com/upsun/cli/internal/config/alt" @@ -88,6 +90,7 @@ func newRootCommand(cnf *config.Config, assets *vendorization.VendorAssets) *cob }, Run: func(cmd *cobra.Command, _ []string) { c := makeLegacyCLIWrapper(cnf, cmd.OutOrStdout(), cmd.ErrOrStderr(), cmd.InOrStdin()) + authcmds.InjectSessionCredentials(cnf, c) if err := c.Exec(cmd.Context(), os.Args[1:]...); err != nil { exitWithError(err) } @@ -143,6 +146,13 @@ func newRootCommand(cnf *config.Config, assets *vendorization.VendorAssets) *cob // Add subcommands. cmd.AddCommand( + authcmds.NewAPITokenLoginCommand(cnf), + authcmds.NewBrowserLoginCommand(cnf), + authcmds.NewInfoCommand(cnf), + authcmds.NewLogoutCommand(cnf), + authcmds.NewTokenCommand(cnf), + authcmds.NewVerifyPhoneNumberCommand(cnf), + sessioncmds.NewSwitchCommand(cnf), newConfigInstallCommand(), newCompletionCommand(cnf), newHelpCommand(cnf), diff --git a/commands/session/switch.go b/commands/session/switch.go new file mode 100644 index 000000000..5b7f5b242 --- /dev/null +++ b/commands/session/switch.go @@ -0,0 +1,91 @@ +// commands/session/switch.go +package session + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + + "github.com/spf13/cobra" + + cobrahelp "github.com/upsun/cli/commands/cobrahelp" + "github.com/upsun/cli/internal/config" + internalsession "github.com/upsun/cli/internal/session" +) + +var validSessionID = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +func NewSwitchCommand(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "session:switch [id]", + Short: "Switch between sessions", + Hidden: true, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + // Blocked if session ID comes from env. + envKey := cfg.Application.EnvPrefix + "SESSION_ID" + if id := os.Getenv(envKey); id != "" { + return fmt.Errorf( + "the session ID is set via the environment variable %s; it cannot be changed using this command", + envKey) + } + + mgr, err := internalsession.New(cfg) + if err != nil { + return err + } + previousID := mgr.SessionID() + + var newID string + switch { + case len(args) > 0: + newID = args[0] + case os.Getenv(cfg.Application.EnvPrefix+"NO_INTERACTION") != "": + return fmt.Errorf("the new session ID is required") + default: + // Interactive prompt. + ids, err := mgr.List() + if err != nil { + return err + } + fmt.Fprintf(cmd.ErrOrStderr(), "Current session ID: %s\n", previousID) + if len(ids) > 0 { + fmt.Fprintf(cmd.ErrOrStderr(), "Existing sessions: %s\n", strings.Join(ids, ", ")) + } + fmt.Fprint(cmd.ErrOrStderr(), "Enter new session ID: ") + scanner := bufio.NewScanner(cmd.InOrStdin()) + if scanner.Scan() { + newID = strings.TrimSpace(scanner.Text()) + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("read input: %w", err) + } + } + + if newID == "" { + return fmt.Errorf("session ID cannot be empty") + } + if strings.HasPrefix(newID, "api-token-") { + return fmt.Errorf("invalid session ID: %q", newID) + } + if !validSessionID.MatchString(newID) { + return fmt.Errorf("invalid session ID %q: must match [a-zA-Z0-9_-]+", newID) + } + + if newID == previousID { + fmt.Fprintf(cmd.ErrOrStderr(), "Session ID is already set as %q\n", newID) + return nil + } + + if err := mgr.SetActiveSessionID(newID); err != nil { + return err + } + fmt.Fprintf(cmd.ErrOrStderr(), "Session ID changed from %q to %q\n", previousID, newID) + return nil + }, + } + cobrahelp.SetPhpStyle(cmd) + return cmd +} diff --git a/go.mod b/go.mod index 6c5c820ff..fc25023fe 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,12 @@ require ( github.com/go-chi/chi/v5 v5.2.3 github.com/go-playground/validator/v10 v10.27.0 github.com/gofrs/flock v0.12.1 + github.com/mattn/go-isatty v0.0.20 + github.com/nyaruka/phonenumbers v1.7.1 github.com/oklog/ulid/v2 v2.1.1 github.com/platformsh/platformify v0.5.0 github.com/spf13/cobra v1.10.1 + github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/symfony-cli/terminal v1.0.7 @@ -84,7 +87,6 @@ require ( github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mholt/archives v0.1.2 // indirect @@ -109,7 +111,6 @@ require ( github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect - github.com/spf13/pflag v1.0.10 // indirect github.com/stoewer/go-strcase v1.3.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tetratelabs/wazero v1.9.0 // indirect @@ -132,6 +133,6 @@ require ( golang.org/x/text v0.29.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect - google.golang.org/protobuf v1.36.8 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index 9f64acbc3..ac8346018 100644 --- a/go.sum +++ b/go.sum @@ -265,6 +265,8 @@ github.com/muesli/termenv v0.15.1 h1:UzuTb/+hhlBugQz28rpzey4ZuKcZ03MeKsoG7IJZIxs github.com/muesli/termenv v0.15.1/go.mod h1:HeAQPTzpfs016yGtA4g00CsdYnVLJvxsS4ANqrZs2sQ= github.com/nwaples/rardecode/v2 v2.1.0 h1:JQl9ZoBPDy+nIZGb1mx8+anfHp/LV3NE2MjMiv0ct/U= github.com/nwaples/rardecode/v2 v2.1.0/go.mod h1:7uz379lSxPe6j9nvzxUZ+n7mnJNgjsRNb6IbvGVHRmw= +github.com/nyaruka/phonenumbers v1.7.1 h1:k8FHBMLegwW2tEIhsurC5YJk5Dix++H1k6liu1LUruY= +github.com/nyaruka/phonenumbers v1.7.1/go.mod h1:fsKPJ70O9JetEA4ggnJadYTFWwtGPvu/lETTXNXq6Cs= github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= @@ -565,8 +567,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/integration-tests/auth_api_token_login_test.go b/integration-tests/auth_api_token_login_test.go new file mode 100644 index 000000000..574a52a4a --- /dev/null +++ b/integration-tests/auth_api_token_login_test.go @@ -0,0 +1,143 @@ +package tests + +import ( + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/pkg/mockapi" +) + +func TestAuthAPITokenLogin_Valid(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", Username: "testuser", Email: "test@example.com"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // Clear the pre-set TOKEN so we are testing login from scratch. + // Clear NO_INTERACTION so the PHP CLI can run interactively. + // Set TEST_CLI_AUTH_URL so the PHP CLI can reach the mock auth server. + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", + EnvPrefix+"NO_INTERACTION=", + EnvPrefix+"AUTH_URL="+authServer.URL, + "SHELL_INTERACTIVE=1", + ) + // Pipe stdin: first the API token, then "n" to reject any browser login prompt + // that the PHP CLI may trigger when initializing the API client. + f.stdin = strings.NewReader(mockapi.ValidAPITokens[0] + "\nn\n" + mockapi.ValidAPITokens[0] + "\n") + _, stderr, err := f.RunCombinedOutput("auth:api-token-login") + require.NoError(t, err) + assert.Contains(t, stderr, "logged in") +} + +// TestAuthAPITokenLogin_PHPCommandAfterLogin verifies that after logging in via +// the Go auth:api-token-login command, subsequent PHP commands can authenticate +// using the stored session (via injectSessionAuth), without TOKEN being pre-set. +// This is a regression test for the credential-helper incompatibility bug. +func TestAuthAPITokenLogin_PHPCommandAfterLogin(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + + myUserID := "u1" + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: myUserID, Username: "testuser", Email: "test@example.com"}) + apiHandler.SetOrgs([]*mockapi.Org{ + { + ID: "org-id-1", + Name: "acme", + Label: "ACME Inc.", + Owner: myUserID, + Type: "flexible", + Capabilities: []string{}, + Links: mockapi.MakeHALLinks("self=/organizations/" + url.PathEscape("org-id-1")), + }, + }) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // Clear the pre-set TOKEN so that subsequent commands must rely on the stored session. + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", + EnvPrefix+"AUTH_URL="+authServer.URL, + "SHELL_INTERACTIVE=1", + ) + + // Step 1: Log in via the Go auth:api-token-login command. + // Token is passed as an argument to avoid interactive stdin complications. + _, stderr, err := f.RunCombinedOutput("auth:api-token-login", mockapi.ValidAPITokens[0]) + require.NoError(t, err, "login must succeed; stderr: %s", stderr) + assert.Contains(t, stderr, "logged in") + + // Step 2: Run a PHP-backed command (orgs) without TOKEN in env. + // injectSessionAuth must read the stored API token and inject it into the PHP subprocess. + out, errOut, err := f.RunCombinedOutput("orgs", "--format", "csv", "--columns", "name", "--no-header") + require.NoError(t, err, "php command must succeed after login; stderr: %s", errOut) + assert.Contains(t, out, "acme") +} + +func TestAuthAPITokenLogin_Invalid(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, EnvPrefix+"TOKEN=") + _, _, err := f.RunCombinedOutput("auth:api-token-login", "bad-token") + assert.Error(t, err) +} + +// TestAuthAPITokenLogin_RetryOnInvalid: feeding an invalid token then a valid one +// should succeed on the second attempt (retry up to 5 times on invalid token). +func TestAuthAPITokenLogin_RetryOnInvalid(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", Username: "testuser", Email: "test@example.com"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", + EnvPrefix+"NO_INTERACTION=", + "SHELL_INTERACTIVE=1", + ) + // First line is an invalid token; second is the valid one. + f.stdin = strings.NewReader("bad-token\n" + mockapi.ValidAPITokens[0] + "\n") + + _, stderr, err := f.RunCombinedOutput("auth:api-token-login") + require.NoError(t, err, "expected success after retry; stderr: %s", stderr) + assert.Contains(t, stderr, "invalid API token") + assert.Contains(t, stderr, "The API token is valid.") +} + +// TestAuthAPITokenLogin_ExhaustsRetries: 5 consecutive invalid tokens should fail. +func TestAuthAPITokenLogin_ExhaustsRetries(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", + EnvPrefix+"NO_INTERACTION=", + "SHELL_INTERACTIVE=1", + ) + // 5 invalid tokens — all should be rejected, command exits non-zero. + f.stdin = strings.NewReader("bad1\nbad2\nbad3\nbad4\nbad5\n") + + _, stderr, err := f.RunCombinedOutput("auth:api-token-login") + require.Error(t, err, "expected failure after 5 invalid tokens") + assert.Contains(t, stderr, "invalid API token") +} diff --git a/integration-tests/auth_browser_login_test.go b/integration-tests/auth_browser_login_test.go new file mode 100644 index 000000000..5f4e5401a --- /dev/null +++ b/integration-tests/auth_browser_login_test.go @@ -0,0 +1,167 @@ +// integration-tests/auth_browser_login_test.go +package tests + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/pkg/mockapi" +) + +func TestAuthBrowserLogin_Success(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", Username: "testuser", Email: "test@example.com"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // Clear API token so the command doesn't bail out with "Cannot log in via the browser". + // Clear NO_INTERACTION so the command doesn't bail out with "Non-interactive use of this command is not supported". + // Set SHELL_INTERACTIVE so the PHP CLI treats stdin as interactive even without a real TTY. + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", + EnvPrefix+"NO_INTERACTION=", + "SHELL_INTERACTIVE=1", + ) + + cmd := f.buildCommand("auth:browser-login") + // Override any stderr set by buildCommand (e.g. in verbose mode) so we can pipe it. + cmd.Stderr = nil + stderrPipe, err := cmd.StderrPipe() + require.NoError(t, err) + require.NoError(t, cmd.Start()) + + // Read stderr lines until we find the local server port. + portCh := make(chan string, 1) + go func() { + re := regexp.MustCompile(`127\.0\.0\.1:(\d+)`) + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + line := scanner.Text() + if m := re.FindStringSubmatch(line); m != nil { + portCh <- m[1] + // Drain remaining stderr. + for scanner.Scan() { + } + return + } + } + }() + + select { + case port := <-portCh: + // Simulate the browser hitting the callback with a valid auth code. + // The CLI's local server is at 127.0.0.1:. + // We need to get the state parameter first by following the authorize redirect. + // Use a non-redirecting client to capture the state. + localURL := fmt.Sprintf("http://127.0.0.1:%s", port) + + // Give the local server a moment to start. + time.Sleep(100 * time.Millisecond) + + // Fetch the local page to get the redirect URL with state. + noRedirect := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + resp, err := noRedirect.Get(localURL) + require.NoError(t, err) + loc := resp.Header.Get("Location") + _ = resp.Body.Close() + + // Extract state from the authorize redirect. + stateRe := regexp.MustCompile(`[?&]state=([^&]+)`) + stateMatch := stateRe.FindStringSubmatch(loc) + require.NotEmpty(t, stateMatch, "expected state in redirect URL, got: %s", loc) + + // Hit the authorize endpoint to get the auth code. + authResp, err := noRedirect.Get(loc) + require.NoError(t, err) + callbackLoc := authResp.Header.Get("Location") + _ = authResp.Body.Close() + + // The auth server redirects to our local callback — GET to simulate browser. + callbackResp, err := http.Get(callbackLoc) //nolint:gosec + require.NoError(t, err) + _, _ = io.Copy(io.Discard, callbackResp.Body) + _ = callbackResp.Body.Close() + + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for local server to start") + } + + err = cmd.Wait() + require.NoError(t, err) +} + +// writeOAuthSession writes a pre-populated OAuth session directly to the filesystem for a given +// homeDir and session ID. This bypasses the session.Manager so integration tests can set up +// authenticated state without running a full login flow. +func writeOAuthSession(t *testing.T, homeDir, sessionID string, s map[string]interface{}) { + t.Helper() + base := filepath.Join(homeDir, ".platform-test-cli", ".session") + sessDir := filepath.Join(base, "sess-"+sessionID) + cliDir := filepath.Join(base, "sess-cli-"+sessionID) + require.NoError(t, os.MkdirAll(sessDir, 0o700)) + require.NoError(t, os.MkdirAll(cliDir, 0o700)) + data, err := json.Marshal(s) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(sessDir, "sess-"+sessionID+".json"), data, 0o600)) +} + +func TestAuthBrowserLogin_AlreadyLoggedInOAuth_DeclineRelogin(t *testing.T) { + // No servers needed: the command exits before reaching the browser flow. + f := newCommandFactory(t, "", "") + f.extraEnv = append(f.extraEnv, + EnvPrefix+"NO_INTERACTION=", // allow interactive mode (testEnv sets it to 1) + "SHELL_INTERACTIVE=1", + ) + + // Pre-populate a valid, non-expired OAuth session. + writeOAuthSession(t, f.homeDir, "default", map[string]interface{}{ + "accessToken": "test-oauth-token", + "tokenType": "bearer", + "expires": time.Now().Add(time.Hour).Unix(), + "refreshToken": "test-refresh", + }) + + // Pipe "n" to decline re-login. + f.stdin = strings.NewReader("n\n") + + _, stderr, err := f.RunCombinedOutput("auth:browser-login") + require.Error(t, err, "expected exit 1 when user declines re-login, stderr: %s", stderr) + assert.Contains(t, stderr, "You are already logged in") +} + +func TestAuthBrowserLogin_AlreadyLoggedIn(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", Username: "testuser", Email: "test@example.com"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // TOKEN is set by the factory → hasApiToken returns true → command refuses to start browser login. + _, stderr, err := f.RunCombinedOutput("auth:browser-login") + // The command exits non-zero when an API token is configured. + require.Error(t, err) + assert.True(t, + strings.Contains(stderr, "Cannot log in via the browser") || strings.Contains(stderr, "log in"), + "unexpected stderr: %s", stderr, + ) +} diff --git a/integration-tests/auth_info_test.go b/integration-tests/auth_info_test.go index a305b5735..0e312cd00 100644 --- a/integration-tests/auth_info_test.go +++ b/integration-tests/auth_info_test.go @@ -1,10 +1,18 @@ package tests import ( + "bufio" + "fmt" + "io" + "net/http" "net/http/httptest" + "regexp" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upsun/cli/pkg/mockapi" ) @@ -49,3 +57,158 @@ func TestAuthInfo(t *testing.T) { assert.Equal(t, "my-user-id\n", f.Run("auth:info", "-P", "id")) } + +func TestAuthInfo_NoAutoLogin_NotLoggedIn(t *testing.T) { + f := newCommandFactory(t, "", "") + // No auth configured — with --no-auto-login should exit 0 and produce no stdout. + out, stderr, err := f.RunCombinedOutput("auth:info", "--no-auto-login", "-P", "id") + t.Log("stderr:", stderr) + require.NoError(t, err) + assert.Empty(t, strings.TrimSpace(out)) +} + +func TestAuthInfo_NotLoggedIn_DeclineRelogin(t *testing.T) { + f := newCommandFactory(t, "", "") + f.extraEnv = append(f.extraEnv, + EnvPrefix+"NO_INTERACTION=", // allow interactive mode (testEnv sets it to 1) + "SHELL_INTERACTIVE=1", + ) + f.stdin = strings.NewReader("n\n") + + _, stderr, err := f.RunCombinedOutput("auth:info") + require.Error(t, err, "expected exit 1 when user declines re-login") + assert.Contains(t, stderr, "Authentication is required.") + assert.Contains(t, stderr, "not logged in") +} + +func TestAuthInfo_ExpiredSession_DeclineRelogin(t *testing.T) { + f := newCommandFactory(t, "", "") + f.extraEnv = append(f.extraEnv, + EnvPrefix+"NO_INTERACTION=", // allow interactive mode (testEnv sets it to 1) + "SHELL_INTERACTIVE=1", + ) + + // Pre-populate an expired OAuth session. + writeOAuthSession(t, f.homeDir, "default", map[string]interface{}{ + "accessToken": "expired-token", + "tokenType": "bearer", + "expires": time.Now().Add(-time.Hour).Unix(), + "refreshToken": "expired-refresh", + }) + + f.stdin = strings.NewReader("n\n") + + _, stderr, err := f.RunCombinedOutput("auth:info") + require.Error(t, err, "expected exit 1 when user declines re-login after session expiry") + assert.Contains(t, stderr, "Your session has expired. You have been logged out.") + assert.Contains(t, stderr, "Authentication is required.") + assert.Contains(t, stderr, "not logged in") +} + +func TestAuthInfo_NotLoggedIn_NoInteraction(t *testing.T) { + // testEnv sets NO_INTERACTION=1 via env var — no prompt should appear. + f := newCommandFactory(t, "", "") + + _, stderr, err := f.RunCombinedOutput("auth:info") + require.Error(t, err, "expected exit 1 when not logged in (non-interactive via env)") + assert.NotContains(t, stderr, "Log in via a browser") + assert.Contains(t, stderr, "not logged in") +} + +func TestAuthInfo_NotLoggedIn_FlagNoInteraction(t *testing.T) { + // --no-interaction flag (via Viper) must also suppress the prompt. + f := newCommandFactory(t, "", "") + f.extraEnv = append(f.extraEnv, EnvPrefix+"NO_INTERACTION=") + + _, stderr, err := f.RunCombinedOutput("auth:info", "--no-interaction") + require.Error(t, err, "expected exit 1 when not logged in (--no-interaction flag)") + assert.NotContains(t, stderr, "Log in via a browser") + assert.Contains(t, stderr, "not logged in") +} + +func TestAuthInfo_NotLoggedIn_FlagYes(t *testing.T) { + // --yes must auto-accept the browser login prompt and complete the full flow. + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", Username: "testuser", Email: "test@example.com"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, + EnvPrefix+"TOKEN=", // clear API token so browser flow is used + EnvPrefix+"NO_INTERACTION=", // clear env var; --yes must override + "SHELL_INTERACTIVE=1", + ) + + cmd := f.buildCommand("auth:info", "--yes") + cmd.Stderr = nil + stderrPipe, err := cmd.StderrPipe() + require.NoError(t, err) + require.NoError(t, cmd.Start()) + + // Read stderr until we find the local callback server port. + portCh := make(chan string, 1) + go func() { + re := regexp.MustCompile(`127\.0\.0\.1:(\d+)`) + scanner := bufio.NewScanner(stderrPipe) + for scanner.Scan() { + if m := re.FindStringSubmatch(scanner.Text()); m != nil { + portCh <- m[1] + for scanner.Scan() { // drain + } + return + } + } + }() + + select { + case port := <-portCh: + localURL := fmt.Sprintf("http://127.0.0.1:%s", port) + time.Sleep(100 * time.Millisecond) + + noRedirect := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + resp, err := noRedirect.Get(localURL) + require.NoError(t, err) + loc := resp.Header.Get("Location") + _ = resp.Body.Close() + + authResp, err := noRedirect.Get(loc) + require.NoError(t, err) + callbackLoc := authResp.Header.Get("Location") + _ = authResp.Body.Close() + + callbackResp, err := http.Get(callbackLoc) //nolint:noctx,gosec + require.NoError(t, err) + _, _ = io.Copy(io.Discard, callbackResp.Body) + _ = callbackResp.Body.Close() + + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for browser callback server to start") + } + + require.NoError(t, cmd.Wait()) +} + +func TestAuthInfo_DeprecatedAliases(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ + ID: "uid-1", FirstName: "Foo", LastName: "Bar", Email: "foo@example.com", + }) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + f := newCommandFactory(t, apiServer.URL, authServer.URL) + + // display_name is deprecated but must still work. + out := f.Run("auth:info", "-P", "display_name") + assert.Equal(t, "Foo Bar\n", out) + + // mail is deprecated alias for email. + out = f.Run("auth:info", "-P", "mail") + assert.Equal(t, "foo@example.com\n", out) +} diff --git a/integration-tests/auth_logout_test.go b/integration-tests/auth_logout_test.go new file mode 100644 index 000000000..5cb67e5e2 --- /dev/null +++ b/integration-tests/auth_logout_test.go @@ -0,0 +1,99 @@ +package tests + +import ( + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/pkg/mockapi" +) + +func TestAuthLogout_Single(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + + _, stderr, err := f.RunCombinedOutput("auth:logout") + require.NoError(t, err) + assert.Contains(t, stderr, "logged out") +} + +// TestAuthLogout_OtherSessionsHint: when other sessions still exist after a single logout, +// the message must match the PHP wording exactly. +func TestAuthLogout_OtherSessionsHint(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + + f := newCommandFactory(t, "", authServer.URL) + + // Pre-populate two sessions so "other sessions exist" branch fires. + future := time.Now().Add(time.Hour).Unix() + writeOAuthSession(t, f.homeDir, "default", map[string]interface{}{ + "accessToken": "token-default", "tokenType": "bearer", "expires": future, + }) + writeOAuthSession(t, f.homeDir, "other", map[string]interface{}{ + "accessToken": "token-other", "tokenType": "bearer", "expires": future, + }) + + _, stderr, err := f.RunCombinedOutput("auth:logout") + require.NoError(t, err, "stderr: %s", stderr) + assert.Contains(t, stderr, "Log out of all sessions with:") +} + +func TestAuthLogout_All(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + + _, stderr, err := f.RunCombinedOutput("auth:logout", "--all") + require.NoError(t, err) + assert.Contains(t, stderr, "logged out") +} + +func TestAuthLogout_Other(t *testing.T) { + // Use the auth server so the revoke POST has a valid endpoint to hit. + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + + // Factory sets TEST_CLI_AUTH_URL and TEST_CLI_TOKEN (both needed: AUTH_URL for revoke, TOKEN is just a warning). + f := newCommandFactory(t, "", authServer.URL) + + // Pre-populate two sessions: "default" (current) and "other". + future := time.Now().Add(time.Hour).Unix() + writeOAuthSession(t, f.homeDir, "default", map[string]interface{}{ + "accessToken": "token-default", + "tokenType": "bearer", + "expires": future, + }) + writeOAuthSession(t, f.homeDir, "other", map[string]interface{}{ + "accessToken": "token-other", + "tokenType": "bearer", + "expires": future, + }) + + _, stderr, err := f.RunCombinedOutput("auth:logout", "--other") + require.NoError(t, err, "stderr: %s", stderr) + assert.Contains(t, stderr, "All other sessions have been deleted") + + // "default" session file must still exist. + defaultSessFile := filepath.Join(f.homeDir, ".platform-test-cli", ".session", "sess-default", "sess-default.json") + assert.FileExists(t, defaultSessFile) + + // "other" session dir must be gone. + otherSessDir := filepath.Join(f.homeDir, ".platform-test-cli", ".session", "sess-other") + _, statErr := os.Stat(otherSessDir) + assert.True(t, os.IsNotExist(statErr), "expected sess-other dir to be deleted, but it still exists") +} diff --git a/integration-tests/auth_token_test.go b/integration-tests/auth_token_test.go new file mode 100644 index 000000000..2095c1b12 --- /dev/null +++ b/integration-tests/auth_token_test.go @@ -0,0 +1,46 @@ +package tests + +import ( + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/pkg/mockapi" +) + +func TestAuthToken_PrintsToken(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + out := f.Run("auth:token", "--no-warn") + assert.Equal(t, "access-token-1", strings.TrimSpace(out)) +} + +func TestAuthToken_HeaderFlag(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + out := f.Run("auth:token", "--no-warn", "--header") + assert.Equal(t, "Authorization: Bearer access-token-1", strings.TrimSpace(out)) +} + +func TestAuthToken_WarnsByDefault(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiServer := httptest.NewServer(mockapi.NewHandler(t)) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + _, stderr, err := f.RunCombinedOutput("auth:token") + require.NoError(t, err) + assert.Contains(t, stderr, "Warning") +} diff --git a/integration-tests/auth_verify_phone_test.go b/integration-tests/auth_verify_phone_test.go new file mode 100644 index 000000000..d403a1a33 --- /dev/null +++ b/integration-tests/auth_verify_phone_test.go @@ -0,0 +1,133 @@ +// integration-tests/auth_verify_phone_test.go +package tests + +import ( + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/pkg/mockapi" +) + +func TestAuthVerifyPhoneNumber_AlreadyVerified(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: true}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // The PHP command checks isInteractive() before checking phone status, so we must + // disable NO_INTERACTION. SHELL_INTERACTIVE makes the PHP CLI treat stdin as a TTY. + f.extraEnv = append(f.extraEnv, + EnvPrefix+"NO_INTERACTION=0", + "SHELL_INTERACTIVE=1", + ) + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.NoError(t, err) + assert.Contains(t, stderr, "already has a verified phone number") +} + +// TestAuthVerifyPhoneNumber_InvalidPhoneRetry: an invalid phone number is re-prompted +// (not an immediate exit), then a valid number succeeds. +func TestAuthVerifyPhoneNumber_InvalidPhoneRetry(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: false, Country: "US"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, EnvPrefix+"NO_INTERACTION=0", "SHELL_INTERACTIVE=1") + // First phone number is invalid; second is valid. + f.stdin = strings.NewReader("0\nnot-a-number\n+12015550123\n" + mockapi.TestPhoneVerificationCode + "\n") + + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.NoError(t, err, "expected success after retrying with valid number; stderr: %s", stderr) + assert.Contains(t, stderr, "verified") +} + +// TestAuthVerifyPhoneNumber_InvalidCodeRetry: an invalid verification code is re-prompted +// (not an immediate exit), then the correct code succeeds. +func TestAuthVerifyPhoneNumber_InvalidCodeRetry(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: false, Country: "US"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, EnvPrefix+"NO_INTERACTION=0", "SHELL_INTERACTIVE=1") + // Valid phone, then wrong code, then correct code. + f.stdin = strings.NewReader("0\n+12015550123\n000000\n" + mockapi.TestPhoneVerificationCode + "\n") + + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.NoError(t, err, "expected success after retrying with correct code; stderr: %s", stderr) + assert.Contains(t, stderr, "verified") +} + +// TestAuthVerifyPhoneNumber_ExhaustPhoneAttempts: 5 consecutive invalid phone numbers exit non-zero. +func TestAuthVerifyPhoneNumber_ExhaustPhoneAttempts(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: false, Country: "US"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, EnvPrefix+"NO_INTERACTION=0", "SHELL_INTERACTIVE=1") + f.stdin = strings.NewReader("0\nbad1\nbad2\nbad3\nbad4\nbad5\n") + + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.Error(t, err, "expected failure after 5 invalid phone numbers") + assert.Contains(t, stderr, "The phone number is not valid.") +} + +// TestAuthVerifyPhoneNumber_ExhaustCodeAttempts: 5 consecutive wrong codes exit non-zero. +func TestAuthVerifyPhoneNumber_ExhaustCodeAttempts(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: false, Country: "US"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + f.extraEnv = append(f.extraEnv, EnvPrefix+"NO_INTERACTION=0", "SHELL_INTERACTIVE=1") + f.stdin = strings.NewReader("0\n+12015550123\n000001\n000002\n000003\n000004\n000005\n") + + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.Error(t, err, "expected failure after 5 wrong codes") + assert.Contains(t, stderr, "Invalid verification code") +} + +func TestAuthVerifyPhoneNumber_SMSFlow(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + apiHandler := mockapi.NewHandler(t) + apiHandler.SetMyUser(&mockapi.User{ID: "u1", PhoneNumberVerified: false, Country: "US"}) + apiServer := httptest.NewServer(apiHandler) + defer apiServer.Close() + + f := newCommandFactory(t, apiServer.URL, authServer.URL) + // Disable NO_INTERACTION so the command accepts interactive prompts. + // Set SHELL_INTERACTIVE so the PHP CLI treats stdin as interactive even without a real TTY. + f.extraEnv = append(f.extraEnv, + EnvPrefix+"NO_INTERACTION=0", + "SHELL_INTERACTIVE=1", + ) + // Use the factory stdin field so buildCommand wires it up before CombinedOutput locks stderr. + // Provide stdin: method=sms (choice 0), phone number, verification code. + // The number must be a valid E.164 number parseable with the user's country (US). + f.stdin = strings.NewReader("0\n+12015550123\n" + mockapi.TestPhoneVerificationCode + "\n") + _, stderr, err := f.RunCombinedOutput("auth:verify-phone-number") + require.NoError(t, err, stderr) + assert.Contains(t, stderr, "verified") +} diff --git a/integration-tests/session_switch_test.go b/integration-tests/session_switch_test.go new file mode 100644 index 000000000..5b955c06a --- /dev/null +++ b/integration-tests/session_switch_test.go @@ -0,0 +1,33 @@ +package tests + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSessionSwitch_WritesIDFile(t *testing.T) { + // session:switch does not require API or auth server access. + f := newCommandFactory(t, "", "") + + _, _, err := f.RunCombinedOutput("session:switch", "work") + require.NoError(t, err) + + idFile := filepath.Join(f.homeDir, ".platform-test-cli", "session-id") + data, err := os.ReadFile(idFile) + require.NoError(t, err) + assert.Equal(t, "work", strings.TrimSpace(string(data))) +} + +func TestSessionSwitch_BlockedByEnvVar(t *testing.T) { + f := newCommandFactory(t, "", "") + f.extraEnv = append(f.extraEnv, EnvPrefix+"SESSION_ID=env-session") + + _, stderr, err := f.RunCombinedOutput("session:switch", "other") + assert.Error(t, err) + assert.Contains(t, stderr, "environment variable") +} diff --git a/integration-tests/tests.go b/integration-tests/tests.go index f63335bf1..3269b4064 100644 --- a/integration-tests/tests.go +++ b/integration-tests/tests.go @@ -47,7 +47,7 @@ func getCommandName(t *testing.T) string { candidate = c } versionCmd := exec.Command(candidate, "--version") - versionCmd.Env = testEnv() + versionCmd.Env = testEnv(os.TempDir()) output, err := versionCmd.Output() require.NoError(t, err, "running '--version' must succeed under the CLI at: %s", candidate) require.Contains(t, string(output), "Platform Test CLI ") @@ -60,11 +60,13 @@ type cmdFactory struct { t *testing.T apiURL string authURL string + homeDir string extraEnv []string + stdin io.Reader } func newCommandFactory(t *testing.T, apiURL, authURL string) *cmdFactory { - return &cmdFactory{t: t, apiURL: apiURL, authURL: authURL} + return &cmdFactory{t: t, apiURL: apiURL, authURL: authURL, homeDir: t.TempDir()} } // Run runs a command, asserts that it did not error, and returns its normal (stdout) output. @@ -93,18 +95,21 @@ func (f *cmdFactory) RunCombinedOutput(args ...string) (stdOut, stdErr string, e func (f *cmdFactory) buildCommand(args ...string) *exec.Cmd { cmd := exec.Command(getCommandName(f.t), args...) //nolint:gosec - cmd.Env = testEnv() - cmd.Dir = os.TempDir() + cmd.Env = testEnv(f.homeDir) + cmd.Dir = f.homeDir if testing.Verbose() { cmd.Stderr = os.Stderr } if f.apiURL != "" { - cmd.Env = append(cmd.Env, EnvPrefix+"API_BASE_URL="+f.apiURL) + cmd.Env = append(cmd.Env, EnvPrefix+"API_URL="+f.apiURL) } if f.authURL != "" { - cmd.Env = append(cmd.Env, EnvPrefix+"API_AUTH_URL="+f.authURL, EnvPrefix+"TOKEN="+mockapi.ValidAPITokens[0]) + cmd.Env = append(cmd.Env, EnvPrefix+"AUTH_URL="+f.authURL, EnvPrefix+"TOKEN="+mockapi.ValidAPITokens[0]) } cmd.Env = append(cmd.Env, f.extraEnv...) + if f.stdin != nil { + cmd.Stdin = f.stdin + } return cmd } @@ -114,7 +119,7 @@ func assertTrimmed(t *testing.T, expected, actual string) { const EnvPrefix = "TEST_CLI_" -func testEnv() []string { +func testEnv(homeDir string) []string { configPath, err := filepath.Abs("config.yaml") if err != nil { panic(err) @@ -125,7 +130,7 @@ func testEnv() []string { "CLI_CONFIG_FILE="+configPath, EnvPrefix+"NO_INTERACTION=1", EnvPrefix+"VERSION=1.0.0", - EnvPrefix+"HOME="+os.TempDir(), + EnvPrefix+"HOME="+homeDir, "TZ=UTC", ) } diff --git a/internal/api/users.go b/internal/api/users.go new file mode 100644 index 000000000..ecd95b8b0 --- /dev/null +++ b/internal/api/users.go @@ -0,0 +1,112 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +// GetMyUser fetches the current user's account information from GET /users/me. +func (c *Client) GetMyUser(ctx context.Context, _ bool) (map[string]interface{}, error) { + u, err := c.resolveURL("users/me") + if err != nil { + return nil, err + } + var result map[string]interface{} + if err := c.getResource(ctx, u.String(), &result); err != nil { + return nil, err + } + return result, nil +} + +// SendPhoneVerification initiates phone verification for the given user. +// Returns the verification session ID (sid) needed for VerifyPhone. +func (c *Client) SendPhoneVerification(ctx context.Context, userID, phoneNumber, channel string) (string, error) { + u, err := c.baseURLWithSegments("users", userID, "phonenumber") + if err != nil { + return "", fmt.Errorf("send phone verification: %w", err) + } + body, err := json.Marshal(map[string]string{"phone_number": phoneNumber, "channel": channel}) + if err != nil { + return "", fmt.Errorf("send phone verification: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("send phone verification: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.HTTPClient.Do(req) + if err != nil { + return "", fmt.Errorf("send phone verification: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("send phone verification: server returned %d", resp.StatusCode) + } + var result struct { + SID string `json:"sid"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("send phone verification: decode response: %w", err) + } + return result.SID, nil +} + +// CheckVerificationStatus posts to /me/verification?force_refresh=1 and returns an error +// if the user's verification type is still "phone" (i.e., phone verification did not complete). +func (c *Client) CheckVerificationStatus(ctx context.Context) error { + u, err := c.resolveURL("me/verification") + if err != nil { + return err + } + q := u.Query() + q.Set("force_refresh", "1") + u.RawQuery = q.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), http.NoBody) + if err != nil { + return err + } + resp, err := c.HTTPClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + var result struct { + Type string `json:"type"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + if result.Type == "phone" { + return fmt.Errorf("phone verification status is still pending") + } + return nil +} + +// VerifyPhone confirms the verification code for the given user and session ID. +func (c *Client) VerifyPhone(ctx context.Context, userID, sid, code string) error { + u, err := c.baseURLWithSegments("users", userID, "phonenumber", sid) + if err != nil { + return fmt.Errorf("verify phone: %w", err) + } + body, err := json.Marshal(map[string]string{"code": code}) + if err != nil { + return fmt.Errorf("verify phone: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("verify phone: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("verify phone: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("verify phone: server returned %d", resp.StatusCode) + } + return nil +} diff --git a/internal/auth/client.go b/internal/auth/client.go index 8d1f77129..a50b2aa00 100644 --- a/internal/auth/client.go +++ b/internal/auth/client.go @@ -2,36 +2,30 @@ package auth import ( "context" - "fmt" "net/http" "golang.org/x/oauth2" - "github.com/upsun/cli/internal/legacy" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" ) -type LegacyCLIClient struct { +// Client is an authenticated HTTP client for the Upsun API. +type Client struct { HTTPClient *http.Client - tokenSource oauth2.TokenSource + tokenSource *sessionTokenSource } -func (c *LegacyCLIClient) EnsureAuthenticated(_ context.Context) error { +// EnsureAuthenticated checks that a valid token is available. +func (c *Client) EnsureAuthenticated(_ context.Context) error { _, err := c.tokenSource.Token() return err } -// NewLegacyCLIClient creates an HTTP client authenticated through the legacy CLI. -// The wrapper argument must be a dedicated wrapper, not used by other callers. -func NewLegacyCLIClient(ctx context.Context, wrapper *legacy.CLIWrapper) (*LegacyCLIClient, error) { - ts, err := NewLegacyCLITokenSource(ctx, wrapper) - if err != nil { - return nil, fmt.Errorf("oauth2: create token source: %w", err) - } +// NewClient creates an HTTP client authenticated via the session Manager. +func NewClient(ctx context.Context, mgr *session.Manager, cfg *config.Config) (*Client, error) { + ts := NewSessionTokenSource(mgr, cfg) - refresher, ok := ts.(refresher) - if !ok { - return nil, fmt.Errorf("token source does not implement refresher") - } baseRT := http.DefaultTransport if rt, ok := TransportFromContext(ctx); ok && rt != nil { baseRT = rt @@ -39,7 +33,7 @@ func NewLegacyCLIClient(ctx context.Context, wrapper *legacy.CLIWrapper) (*Legac httpClient := &http.Client{ Transport: &Transport{ - refresher: refresher, + refresher: ts, base: &oauth2.Transport{ Source: ts, Base: baseRT, @@ -47,7 +41,7 @@ func NewLegacyCLIClient(ctx context.Context, wrapper *legacy.CLIWrapper) (*Legac }, } - return &LegacyCLIClient{ + return &Client{ HTTPClient: httpClient, tokenSource: ts, }, nil diff --git a/internal/auth/flow.go b/internal/auth/flow.go new file mode 100644 index 000000000..a6fb3e2b5 --- /dev/null +++ b/internal/auth/flow.go @@ -0,0 +1,230 @@ +// internal/auth/flow.go +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +// BrowserFlowOptions configures the browser login flow. +type BrowserFlowOptions struct { + Force bool + Methods []string + MaxAge *int + // Stderr is the writer used for user-facing messages (local URL, instructions). + // Defaults to os.Stderr if nil. + Stderr io.Writer + // OnCodeReceived is called just before the authorization code is exchanged. + // May be nil. + OnCodeReceived func() +} + +// BrowserFlow orchestrates the OAuth2 PKCE browser login flow. +type BrowserFlow struct { + cfg *config.Config + OpenURL func(string) error // override for testing; defaults to opening the system browser + HTTPClient *http.Client // override for testing; defaults to http.DefaultClient +} + +// NewBrowserFlow creates a BrowserFlow with system browser support. +func NewBrowserFlow(cfg *config.Config) *BrowserFlow { + return &BrowserFlow{cfg: cfg, OpenURL: openSystemBrowser, HTTPClient: http.DefaultClient} +} + +// Run performs the full PKCE flow and returns a session on success. +func (f *BrowserFlow) Run(ctx context.Context, opts BrowserFlowOptions) (*session.Session, error) { + verifier, err := GenerateVerifier() + if err != nil { + return nil, err + } + challenge := VerifierToChallenge(verifier) + state, err := GenerateVerifier() // state is any random string + if err != nil { + return nil, err + } + + // Find an available port in 5000–5010. + listener, port, err := findPort(5000, 5010) + if err != nil { + return nil, fmt.Errorf("find available port (5000-5010): %w", err) + } + localURL := fmt.Sprintf("http://127.0.0.1:%d", port) + + // Resolve the writer for user-facing messages. + w := opts.Stderr + if w == nil { + w = os.Stderr + } + + // Build the authorization URL ahead of time. + authURL := f.buildAuthURL(localURL, challenge, state, opts) + + // Channel to receive the auth code from the callback handler. + codeCh := make(chan callbackResult, 1) + mux := http.NewServeMux() + mux.HandleFunc("/", func(hw http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + // If no code param, redirect to the authorization server. + if q.Get("code") == "" && q.Get("error") == "" { + http.Redirect(hw, r, authURL, http.StatusFound) + return + } + if errParam := q.Get("error"); errParam != "" { + select { + case codeCh <- callbackResult{err: fmt.Errorf("OAuth error: %s — %s", errParam, q.Get("error_description"))}: + default: + } + fmt.Fprintln(hw, "Login failed. You may close this tab.") + return + } + if q.Get("state") != state { + select { + case codeCh <- callbackResult{err: fmt.Errorf("state mismatch")}: + default: + } + fmt.Fprintln(hw, "Login failed (invalid state). You may close this tab.") + return + } + select { + case codeCh <- callbackResult{code: q.Get("code"), redirectURI: localURL}: + default: + } + fmt.Fprintln(hw, "Login successful. You may close this tab.") + }) + + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} + go func() { _ = srv.Serve(listener) }() + defer func() { + shutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(shutCtx) + }() + + // Open browser or print URL. + if err := f.OpenURL(localURL); err != nil { + fmt.Fprintf(w, "Please open the following URL in a browser:\n%s\n", localURL) + } else { + fmt.Fprintf(w, "Opened URL: %s\nPlease use the browser to log in.\n", localURL) + } + + // Wait for callback (30-minute timeout). + select { + case result := <-codeCh: + if result.err != nil { + return nil, result.err + } + if opts.OnCodeReceived != nil { + opts.OnCodeReceived() + } + return f.exchangeCode(ctx, result.code, verifier, result.redirectURI) + case <-time.After(30 * time.Minute): + return nil, fmt.Errorf("login timed out after 30 minutes") + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +type callbackResult struct { + code string + redirectURI string + err error +} + +func (f *BrowserFlow) buildAuthURL(localURL, challenge, state string, opts BrowserFlowOptions) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {f.cfg.API.OAuth2ClientID}, + "redirect_uri": {localURL}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + "scope": {"offline_access"}, + } + prompt := "consent" + if opts.Force { + prompt = "consent select_account" + } + params.Set("prompt", prompt) + if len(opts.Methods) > 0 { + params.Set("acr_values", strings.Join(opts.Methods, " ")) + } + if opts.MaxAge != nil { + params.Set("max_age", fmt.Sprintf("%d", *opts.MaxAge)) + } + return OAuth2AuthorizeURL(f.cfg) + "?" + params.Encode() +} + +func (f *BrowserFlow) exchangeCode(ctx context.Context, code, verifier, redirectURI string) (*session.Session, error) { + tokenURL := OAuth2TokenURL(f.cfg) + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "code_verifier": {verifier}, + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("code exchange: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if f.cfg.API.OAuth2ClientID != "" { + req.SetBasicAuth(f.cfg.API.OAuth2ClientID, "") + } + + resp, err := f.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("code exchange: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("code exchange: server returned %d", resp.StatusCode) + } + + var result struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Error string `json:"error"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("code exchange: decode response: %w", err) + } + if result.Error != "" { + return nil, fmt.Errorf("code exchange: %s", result.Error) + } + + expiry := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second).Unix() + return &session.Session{ + AccessToken: result.AccessToken, + TokenType: result.TokenType, + Expires: expiry, + RefreshToken: result.RefreshToken, + }, nil +} + +func findPort(start, end int) (net.Listener, int, error) { + for port := start; port <= end; port++ { + l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err == nil { + return l, port, nil + } + } + return nil, 0, fmt.Errorf("failed to find available port between %d and %d", start, end) +} + +func openSystemBrowser(u string) error { + return openBrowser(u) +} diff --git a/internal/auth/flow_unix.go b/internal/auth/flow_unix.go new file mode 100644 index 000000000..f466fa059 --- /dev/null +++ b/internal/auth/flow_unix.go @@ -0,0 +1,25 @@ +//go:build !windows + +package auth + +import ( + "fmt" + "os" + "os/exec" + + "github.com/mattn/go-isatty" +) + +func openBrowser(url string) error { + // Only attempt to open a browser if stderr is a real terminal. + // In piped/CI environments, skip the browser open and let the user copy the URL. + if !isatty.IsTerminal(os.Stderr.Fd()) && !isatty.IsCygwinTerminal(os.Stderr.Fd()) { + return fmt.Errorf("not a terminal: browser not opened") + } + for _, cmd := range []string{"xdg-open", "open"} { + if err := exec.Command(cmd, url).Start(); err == nil { + return nil + } + } + return fmt.Errorf("no browser opener found") +} diff --git a/internal/auth/flow_windows.go b/internal/auth/flow_windows.go new file mode 100644 index 000000000..509d91c94 --- /dev/null +++ b/internal/auth/flow_windows.go @@ -0,0 +1,9 @@ +//go:build windows + +package auth + +import "os/exec" + +func openBrowser(url string) error { + return exec.Command("cmd", "/c", "start", url).Start() +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go deleted file mode 100644 index 3760a9f27..000000000 --- a/internal/auth/jwt.go +++ /dev/null @@ -1,42 +0,0 @@ -package auth - -import ( - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "strings" - "time" -) - -// unsafeGetJWTExpiry parses a JWT without verifying its signature and returns its expiry time. -// WARNING: This is intentionally unsafe and must not be used for trust decisions. -func unsafeGetJWTExpiry(token string) (time.Time, error) { - if token == "" { - return time.Time{}, errors.New("jwt: empty token") - } - parts := strings.Split(token, ".") - if len(parts) < 2 { - return time.Time{}, fmt.Errorf("jwt: malformed token, expected 3 parts, got %d", len(parts)) - } - payloadSeg := parts[1] - - // Base64 URL decode without padding as per RFC 7515. - payloadBytes, err := base64.RawURLEncoding.DecodeString(payloadSeg) - if err != nil { - return time.Time{}, fmt.Errorf("jwt: decode payload: %w", err) - } - - var claims struct { - ExpiresAt *int64 `json:"exp,omitempty"` - } - if err := json.Unmarshal(payloadBytes, &claims); err != nil { - return time.Time{}, fmt.Errorf("jwt: unmarshal claims: %w", err) - } - - if claims.ExpiresAt == nil { - return time.Time{}, errors.New("jwt: no expiry time found") - } - - return time.Unix(*claims.ExpiresAt, 0), nil -} diff --git a/internal/auth/legacy.go b/internal/auth/legacy.go deleted file mode 100644 index e07ee662f..000000000 --- a/internal/auth/legacy.go +++ /dev/null @@ -1,108 +0,0 @@ -package auth - -import ( - "bytes" - "context" - "fmt" - "io" - "sync" - - "golang.org/x/oauth2" - - "github.com/upsun/cli/internal/legacy" -) - -type legacyCLITokenSource struct { - ctx context.Context - cached *oauth2.Token - wrapper *legacy.CLIWrapper - mu sync.Mutex -} - -func (ts *legacyCLITokenSource) unsafeGetLegacyCLIToken() (*oauth2.Token, error) { - bt := bytes.NewBuffer(nil) - ts.wrapper.Stdout = bt - if err := ts.wrapper.Exec(ts.ctx, "auth:token", "-W"); err != nil { - return nil, fmt.Errorf("cannot retrieve token: %w", err) - } - - expiry, err := unsafeGetJWTExpiry(bt.String()) - - if err != nil { - return nil, fmt.Errorf("cannot parse token: %w", err) - } - - return &oauth2.Token{ - AccessToken: bt.String(), - TokenType: "Bearer", - Expiry: expiry, - }, nil -} - -func (ts *legacyCLITokenSource) refreshToken() error { - ts.mu.Lock() - defer ts.mu.Unlock() - - return ts.unsafeRefreshToken() -} - -func (ts *legacyCLITokenSource) unsafeRefreshToken() error { - ts.cached = nil - ts.wrapper.Stdout = io.Discard - if err := ts.wrapper.Exec(ts.ctx, "auth:info", "--refresh"); err != nil { - return fmt.Errorf("cannot refresh token: %w", err) - } - - return nil -} - -func (ts *legacyCLITokenSource) invalidateToken() error { - ts.mu.Lock() - defer ts.mu.Unlock() - - return ts.unsafeInvalidateToken() -} - -func (ts *legacyCLITokenSource) unsafeInvalidateToken() error { - if ts.cached != nil { - ts.cached.AccessToken = "" - } - - return nil -} - -func (ts *legacyCLITokenSource) Token() (*oauth2.Token, error) { - ts.mu.Lock() - defer ts.mu.Unlock() - - if ts.cached == nil { - tok, err := ts.unsafeGetLegacyCLIToken() - if err != nil { - return nil, err - } - ts.cached = tok - } - - if ts.cached != nil && ts.cached.Valid() { - return ts.cached, nil - } - - if err := ts.unsafeRefreshToken(); err != nil { - return nil, err - } - - tok, err := ts.unsafeGetLegacyCLIToken() - if err != nil { - return nil, err - } - - ts.cached = tok - return ts.cached, nil -} - -func NewLegacyCLITokenSource(ctx context.Context, wrapper *legacy.CLIWrapper) (oauth2.TokenSource, error) { - return &legacyCLITokenSource{ - ctx: ctx, - wrapper: wrapper, - }, nil -} diff --git a/internal/auth/pkce.go b/internal/auth/pkce.go new file mode 100644 index 000000000..1c00630b8 --- /dev/null +++ b/internal/auth/pkce.go @@ -0,0 +1,23 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +// GenerateVerifier creates a random PKCE code verifier (RFC 7636 §4.1). +// 32 random bytes base64url-encoded without padding = 43 characters. +func GenerateVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// VerifierToChallenge derives the S256 code challenge from a verifier (RFC 7636 §4.2). +func VerifierToChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} diff --git a/internal/auth/pkce_test.go b/internal/auth/pkce_test.go new file mode 100644 index 000000000..0111fba44 --- /dev/null +++ b/internal/auth/pkce_test.go @@ -0,0 +1,29 @@ +package auth_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/internal/auth" +) + +func TestGenerateVerifier(t *testing.T) { + v1, err := auth.GenerateVerifier() + require.NoError(t, err) + v2, err := auth.GenerateVerifier() + require.NoError(t, err) + + assert.Len(t, v1, 43) // 32 bytes base64url = 43 chars (no padding) + assert.NotEqual(t, v1, v2) + assert.False(t, strings.ContainsAny(v1, "+/="), "must be base64url, not standard base64") +} + +func TestVerifierToChallenge(t *testing.T) { + // RFC 7636 §B test vector + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + challenge := auth.VerifierToChallenge(verifier) + assert.Equal(t, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", challenge) +} diff --git a/internal/auth/token_source.go b/internal/auth/token_source.go new file mode 100644 index 000000000..ed8b075fc --- /dev/null +++ b/internal/auth/token_source.go @@ -0,0 +1,146 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/oauth2" + + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +// sessionTokenSource implements oauth2.TokenSource and the refresher interface +// using the session Manager for persistence. +type sessionTokenSource struct { + mgr *session.Manager + cfg *config.Config + httpClient *http.Client + mu sync.Mutex + cached *oauth2.Token +} + +// NewSessionTokenSource creates a token source backed by session files. +// +//nolint:revive // intentionally returns unexported type; callers use := and only call Token/TokenContext +func NewSessionTokenSource(mgr *session.Manager, cfg *config.Config) *sessionTokenSource { + return &sessionTokenSource{mgr: mgr, cfg: cfg, httpClient: http.DefaultClient} +} + +// Token returns a valid access token, refreshing if necessary. +// Implements oauth2.TokenSource. Uses context.Background() for the refresh request; +// use TokenContext for cancellable refresh. +func (ts *sessionTokenSource) Token() (*oauth2.Token, error) { + return ts.TokenContext(context.Background()) +} + +// TokenContext is like Token but uses the provided context for any refresh request. +func (ts *sessionTokenSource) TokenContext(ctx context.Context) (*oauth2.Token, error) { + ts.mu.Lock() + defer ts.mu.Unlock() + + if ts.cached != nil && ts.cached.Valid() { + return ts.cached, nil + } + + s, err := ts.mgr.Load() + if err != nil { + return nil, fmt.Errorf("load session: %w", err) + } + if s == nil || s.AccessToken == "" { + return nil, fmt.Errorf("not logged in: run 'login' to authenticate") + } + + expiry := time.Unix(s.Expires, 0) + tok := &oauth2.Token{ + AccessToken: s.AccessToken, + TokenType: s.TokenType, + RefreshToken: s.RefreshToken, + Expiry: expiry, + } + + if tok.Valid() { + ts.cached = tok + return tok, nil + } + + // Token expired — refresh using the already-loaded session (avoids a second Load). + if err := ts.unsafeRefreshToken(ctx, s); err != nil { + return nil, err + } + return ts.cached, nil +} + +func (ts *sessionTokenSource) unsafeRefreshToken(ctx context.Context, s *session.Session) error { + ts.cached = nil + + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {s.RefreshToken}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, OAuth2TokenURL(ts.cfg), strings.NewReader(data.Encode())) + if err != nil { + return fmt.Errorf("refresh token: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if ts.cfg.API.OAuth2ClientID != "" { + req.SetBasicAuth(ts.cfg.API.OAuth2ClientID, "") + } + + resp, err := ts.httpClient.Do(req) + if err != nil { + return fmt.Errorf("refresh token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("refresh token: server returned %d", resp.StatusCode) + } + + var result struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return fmt.Errorf("refresh token: decode response: %w", err) + } + + expiry := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second).Unix() + if result.RefreshToken == "" { + result.RefreshToken = s.RefreshToken // keep existing if not rotated + } + + newSession := &session.Session{ + AccessToken: result.AccessToken, + TokenType: result.TokenType, + Expires: expiry, + RefreshToken: result.RefreshToken, + } + if err := ts.mgr.Save(newSession); err != nil { + return fmt.Errorf("save refreshed session: %w", err) + } + + ts.cached = &oauth2.Token{ + AccessToken: result.AccessToken, + TokenType: result.TokenType, + RefreshToken: result.RefreshToken, + Expiry: time.Unix(expiry, 0), + } + return nil +} + +func (ts *sessionTokenSource) invalidateToken() error { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.cached = nil + return nil +} diff --git a/internal/auth/token_source_test.go b/internal/auth/token_source_test.go new file mode 100644 index 000000000..8e15435a5 --- /dev/null +++ b/internal/auth/token_source_test.go @@ -0,0 +1,161 @@ +package auth_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func TestSessionTokenSource_ValidToken(t *testing.T) { + cfg := loadTestConfig(t) + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + future := time.Now().Add(time.Hour).Unix() + require.NoError(t, mgr.Save(&session.Session{ + AccessToken: "valid-token", + TokenType: "bearer", + Expires: future, + RefreshToken: "refresh-token", + })) + + ts := auth.NewSessionTokenSource(mgr, cfg) + tok, err := ts.Token() + require.NoError(t, err) + assert.Equal(t, "valid-token", tok.AccessToken) +} + +func TestSessionTokenSource_RefreshExpired(t *testing.T) { + var refreshCalled bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "old-refresh", r.FormValue("refresh_token")) + refreshCalled = true + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "new-token", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "new-refresh", + }) + })) + defer server.Close() + + cfg := loadTestConfig(t) + cfg.API.AuthURL = "" + cfg.API.OAuth2TokenURL = server.URL + "/token" + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + past := time.Now().Add(-time.Hour).Unix() + require.NoError(t, mgr.Save(&session.Session{ + AccessToken: "old-token", + TokenType: "bearer", + Expires: past, + RefreshToken: "old-refresh", + })) + + ts := auth.NewSessionTokenSource(mgr, cfg) + tok, err := ts.Token() + require.NoError(t, err) + assert.True(t, refreshCalled) + assert.Equal(t, "new-token", tok.AccessToken) +} + +// TestSessionTokenSource_NoSession: Token() returns error when no session has been saved. +func TestSessionTokenSource_NoSession(t *testing.T) { + cfg := loadTestConfig(t) + mgr := session.NewWithStore(cfg, session.NewMemStore()) + + ts := auth.NewSessionTokenSource(mgr, cfg) + _, err := ts.Token() + require.Error(t, err) + assert.Contains(t, err.Error(), "not logged in") +} + +// TestSessionTokenSource_EmptyAccessToken: Token() returns error when session has no access token. +func TestSessionTokenSource_EmptyAccessToken(t *testing.T) { + cfg := loadTestConfig(t) + mgr := session.NewWithStore(cfg, session.NewMemStore()) + require.NoError(t, mgr.Save(&session.Session{AccessToken: "", RefreshToken: "r"})) + + ts := auth.NewSessionTokenSource(mgr, cfg) + _, err := ts.Token() + require.Error(t, err) + assert.Contains(t, err.Error(), "not logged in") +} + +// TestSessionTokenSource_RefreshServerError: when the token server returns non-200, refresh fails with an error. +func TestSessionTokenSource_RefreshServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + cfg := loadTestConfig(t) + cfg.API.AuthURL = "" + cfg.API.OAuth2TokenURL = server.URL + "/token" + mgr := session.NewWithStore(cfg, session.NewMemStore()) + + past := time.Now().Add(-time.Hour).Unix() + require.NoError(t, mgr.Save(&session.Session{ + AccessToken: "old", + Expires: past, + RefreshToken: "old-refresh", + })) + + ts := auth.NewSessionTokenSource(mgr, cfg) + _, err := ts.Token() + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +// TestSessionTokenSource_RefreshKeepsExistingRefreshToken: when the server omits refresh_token, +// the original refresh token is preserved. +func TestSessionTokenSource_RefreshKeepsExistingRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "new-access", + "token_type": "bearer", + "expires_in": 3600, + // refresh_token intentionally omitted + }) + })) + defer server.Close() + + cfg := loadTestConfig(t) + cfg.API.AuthURL = "" + cfg.API.OAuth2TokenURL = server.URL + "/token" + mgr := session.NewWithStore(cfg, session.NewMemStore()) + + past := time.Now().Add(-time.Hour).Unix() + require.NoError(t, mgr.Save(&session.Session{ + AccessToken: "old", + Expires: past, + RefreshToken: "keep-me", + })) + + ts := auth.NewSessionTokenSource(mgr, cfg) + tok, err := ts.Token() + require.NoError(t, err) + assert.Equal(t, "new-access", tok.AccessToken) + assert.Equal(t, "keep-me", tok.RefreshToken) +} + +func loadTestConfig(t *testing.T) *config.Config { + t.Helper() + data, err := os.ReadFile("../../integration-tests/config.yaml") + require.NoError(t, err) + cfg, err := config.FromYAML(data) + require.NoError(t, err) + return cfg +} diff --git a/internal/auth/transport.go b/internal/auth/transport.go index 3c2a169e2..82ee56a1d 100644 --- a/internal/auth/transport.go +++ b/internal/auth/transport.go @@ -9,7 +9,6 @@ import ( ) type refresher interface { - refreshToken() error invalidateToken() error } @@ -30,7 +29,9 @@ type Transport struct { // RoundTrip adds Authorization via the underlying oauth2.Transport. If the // response is 401 Unauthorized, it clears the cached token and retries once. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - req.Body = wrapReader(req.Body) + if err := wrapRequest(req); err != nil { + return nil, fmt.Errorf("buffer request body: %w", err) + } resp, err := t.base.RoundTrip(req) @@ -41,6 +42,12 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("failed to invalidate token: %w", err) } flushReader(resp.Body) + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to rewind request body: %w", err) + } + } resp, err = t.base.RoundTrip(req) } @@ -77,13 +84,25 @@ func TransportFromContext(ctx context.Context) (http.RoundTripper, bool) { return rt, true } -func wrapReader(r io.ReadCloser) io.ReadCloser { - if r == nil { +// wrapRequest buffers req.Body so that it can be replayed on retry. +// It stores the bytes in req.GetBody so RoundTrip can restore the body +// before the second attempt (bytes.Buffer is drained after the first read). +// If GetBody is already set the caller already provides replay capability, +// so no buffering is needed. +func wrapRequest(req *http.Request) error { + if req.Body == nil || req.GetBody != nil { return nil } - bodyBytes, _ := io.ReadAll(r) - _ = r.Close() - return io.NopCloser(bytes.NewBuffer(bodyBytes)) + b, err := io.ReadAll(req.Body) + if err != nil { + return err + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(b)), nil + } + return nil } func flushReader(r io.ReadCloser) { diff --git a/internal/auth/urls.go b/internal/auth/urls.go new file mode 100644 index 000000000..e99674d55 --- /dev/null +++ b/internal/auth/urls.go @@ -0,0 +1,41 @@ +package auth + +import ( + "os" + "strings" + + "github.com/upsun/cli/internal/config" +) + +// resolveAuthBase returns the OAuth2 server base URL. +// Priority: {EnvPrefix}AUTH_URL env → cfg.API.AuthURL +func resolveAuthBase(cfg *config.Config) string { + if v := os.Getenv(cfg.Application.EnvPrefix + "AUTH_URL"); v != "" { + return strings.TrimRight(v, "/") + } + return strings.TrimRight(cfg.API.AuthURL, "/") +} + +// OAuth2TokenURL resolves the OAuth2 token endpoint. +func OAuth2TokenURL(cfg *config.Config) string { + if base := resolveAuthBase(cfg); base != "" { + return base + "/oauth2/token" + } + return cfg.API.OAuth2TokenURL +} + +// OAuth2AuthorizeURL resolves the OAuth2 authorize endpoint. +func OAuth2AuthorizeURL(cfg *config.Config) string { + if base := resolveAuthBase(cfg); base != "" { + return base + "/oauth2/authorize" + } + return cfg.API.OAuth2AuthorizeURL +} + +// OAuth2RevokeURL resolves the OAuth2 revocation endpoint. +func OAuth2RevokeURL(cfg *config.Config) string { + if base := resolveAuthBase(cfg); base != "" { + return base + "/oauth2/revoke" + } + return cfg.API.OAuth2RevokeURL +} diff --git a/internal/auth/urls_test.go b/internal/auth/urls_test.go new file mode 100644 index 000000000..f4291805a --- /dev/null +++ b/internal/auth/urls_test.go @@ -0,0 +1,89 @@ +package auth_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/upsun/cli/internal/auth" + "github.com/upsun/cli/internal/config" +) + +// minCfg returns a minimal config with the given env prefix and auth URLs set. +func minCfg(envPrefix, authURL, tokenURL, authorizeURL, revokeURL string) *config.Config { + cfg := &config.Config{} + cfg.Application.EnvPrefix = envPrefix + cfg.API.AuthURL = authURL + cfg.API.OAuth2TokenURL = tokenURL + cfg.API.OAuth2AuthorizeURL = authorizeURL + cfg.API.OAuth2RevokeURL = revokeURL + return cfg +} + +// TestOAuth2TokenURL_EnvOverride: AUTH_URL env var takes priority over config. +func TestOAuth2TokenURL_EnvOverride(t *testing.T) { + t.Setenv("UPSUN_CLI_AUTH_URL", "https://env-auth.example.com") + cfg := minCfg("UPSUN_CLI_", "https://config-auth.example.com", "https://fallback.example.com/token", "", "") + + assert.Equal(t, "https://env-auth.example.com/oauth2/token", auth.OAuth2TokenURL(cfg)) +} + +// TestOAuth2TokenURL_ConfigAuthURL: cfg.API.AuthURL used when no env var. +func TestOAuth2TokenURL_ConfigAuthURL(t *testing.T) { + cfg := minCfg("UPSUN_CLI_", "https://config-auth.example.com", "https://fallback.example.com/token", "", "") + + assert.Equal(t, "https://config-auth.example.com/oauth2/token", auth.OAuth2TokenURL(cfg)) +} + +// TestOAuth2TokenURL_FallbackField: falls back to cfg.API.OAuth2TokenURL when no base URL. +func TestOAuth2TokenURL_FallbackField(t *testing.T) { + cfg := minCfg("UPSUN_CLI_", "", "https://fallback.example.com/token", "", "") + + assert.Equal(t, "https://fallback.example.com/token", auth.OAuth2TokenURL(cfg)) +} + +// TestOAuth2TokenURL_TrailingSlashStripped: trailing slash on AUTH_URL is stripped before appending path. +func TestOAuth2TokenURL_TrailingSlashStripped(t *testing.T) { + t.Setenv("UPSUN_CLI_AUTH_URL", "https://auth.example.com/") + cfg := minCfg("UPSUN_CLI_", "", "", "", "") + + assert.Equal(t, "https://auth.example.com/oauth2/token", auth.OAuth2TokenURL(cfg)) +} + +// TestOAuth2AuthorizeURL_EnvOverride mirrors the token URL env precedence. +func TestOAuth2AuthorizeURL_EnvOverride(t *testing.T) { + t.Setenv("UPSUN_CLI_AUTH_URL", "https://env-auth.example.com") + cfg := minCfg("UPSUN_CLI_", "", "", "https://fallback.example.com/authorize", "") + + assert.Equal(t, "https://env-auth.example.com/oauth2/authorize", auth.OAuth2AuthorizeURL(cfg)) +} + +// TestOAuth2AuthorizeURL_Fallback: falls back to cfg.API.OAuth2AuthorizeURL. +func TestOAuth2AuthorizeURL_Fallback(t *testing.T) { + cfg := minCfg("UPSUN_CLI_", "", "", "https://fallback.example.com/authorize", "") + + assert.Equal(t, "https://fallback.example.com/authorize", auth.OAuth2AuthorizeURL(cfg)) +} + +// TestOAuth2RevokeURL_EnvOverride mirrors the token URL env precedence. +func TestOAuth2RevokeURL_EnvOverride(t *testing.T) { + t.Setenv("UPSUN_CLI_AUTH_URL", "https://env-auth.example.com") + cfg := minCfg("UPSUN_CLI_", "", "", "", "https://fallback.example.com/revoke") + + assert.Equal(t, "https://env-auth.example.com/oauth2/revoke", auth.OAuth2RevokeURL(cfg)) +} + +// TestOAuth2RevokeURL_Fallback: falls back to cfg.API.OAuth2RevokeURL. +func TestOAuth2RevokeURL_Fallback(t *testing.T) { + cfg := minCfg("UPSUN_CLI_", "", "", "", "https://fallback.example.com/revoke") + + assert.Equal(t, "https://fallback.example.com/revoke", auth.OAuth2RevokeURL(cfg)) +} + +// TestOAuth2TokenURL_EnvPrefixIsolation: a different prefix's env var must NOT affect resolution. +func TestOAuth2TokenURL_EnvPrefixIsolation(t *testing.T) { + t.Setenv("OTHER_CLI_AUTH_URL", "https://wrong.example.com") + cfg := minCfg("UPSUN_CLI_", "", "https://correct.example.com/token", "", "") + + assert.Equal(t, "https://correct.example.com/token", auth.OAuth2TokenURL(cfg)) +} diff --git a/internal/legacy/legacy.go b/internal/legacy/legacy.go index 2c30c9531..0865b31e1 100644 --- a/internal/legacy/legacy.go +++ b/internal/legacy/legacy.go @@ -39,6 +39,8 @@ type CLIWrapper struct { ForceColor bool DebugLogFunc func(string, ...any) + ExtraEnv []string + initOnce sync.Once _cacheDir string } @@ -154,6 +156,7 @@ func (c *CLIWrapper) Exec(ctx context.Context, args ...string) error { c.Version, PHPVersion, )) + cmd.Env = append(cmd.Env, c.ExtraEnv...) if err := cmd.Run(); err != nil { return fmt.Errorf("could not run PHP CLI command: %w", err) } diff --git a/internal/session/manager.go b/internal/session/manager.go new file mode 100644 index 000000000..bd56c7346 --- /dev/null +++ b/internal/session/manager.go @@ -0,0 +1,229 @@ +package session + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/upsun/cli/internal/config" +) + +// Manager is the single entry point for all session operations. +type Manager struct { + cfg *config.Config + store Store + id string // cached resolved session ID +} + +// New creates a Manager backed by the filesystem. +func New(cfg *config.Config) (*Manager, error) { + id, err := ResolveSessionID(cfg) + if err != nil { + return nil, err + } + return &Manager{cfg: cfg, store: &FileStore{}, id: id}, nil +} + +// NewWithStore creates a Manager with an injected Store (for testing). +func NewWithStore(cfg *config.Config, store Store) *Manager { + id, err := ResolveSessionID(cfg) + if err != nil { + // WritableUserDir is misconfigured; fall back to "default" and warn. + fmt.Fprintf(os.Stderr, "session: could not resolve session ID: %v\n", err) + id = "default" + } + return &Manager{cfg: cfg, store: store, id: id} +} + +// NewWithID creates a Manager for a specific session ID (used by logout --other). +func NewWithID(cfg *config.Config, id string) *Manager { + return &Manager{cfg: cfg, store: &FileStore{}, id: id} +} + +// SessionID returns the resolved session ID. +func (m *Manager) SessionID() string { return m.id } + +func (m *Manager) sessionBaseDir() (string, error) { + writableDir, err := m.cfg.WritableUserDir() //nolint:staticcheck // backwards compatibility needed for session files + if err != nil { + return "", err + } + return filepath.Join(writableDir, ".session"), nil +} + +// sessionPath returns the path to the OAuth session JSON file. +// Pattern: /.session/sess-/sess-.json +// Matches the PHP platformsh/client File storage format. +func (m *Manager) sessionPath() (string, error) { + base, err := m.sessionBaseDir() + if err != nil { + return "", err + } + slug := sessionDirName(m.id) + return filepath.Join(base, slug, slug+".json"), nil +} + +// cliDir returns the path to the CLI artifact directory for this session. +// Pattern: /.session/sess-cli-/ +// Used for API token storage and as a session existence marker. +func (m *Manager) cliDir() (string, error) { + base, err := m.sessionBaseDir() + if err != nil { + return "", err + } + return filepath.Join(base, cliDirName(m.id)), nil +} + +func (m *Manager) tokenPath() (string, error) { + dir, err := m.cliDir() + if err != nil { + return "", err + } + return filepath.Join(dir, "api-token"), nil +} + +// Load reads the current session from disk. Returns (nil, nil) if no session exists. +func (m *Manager) Load() (*Session, error) { + path, err := m.sessionPath() + if err != nil { + return nil, err + } + return m.store.Load(path) +} + +// Save writes the session to disk and creates the sess-cli-/ marker directory. +func (m *Manager) Save(s *Session) error { + path, err := m.sessionPath() + if err != nil { + return err + } + if err := m.store.Save(path, s); err != nil { + return err + } + // Create the sess-cli-/ marker so List() can discover this session. + dir, err := m.cliDir() + if err != nil { + return err + } + return m.store.MkdirAll(dir) +} + +// Delete removes the current session (both OAuth file and CLI artifact dir). +func (m *Manager) Delete() error { + path, err := m.sessionPath() + if err != nil { + return err + } + if err := m.store.Delete(filepath.Dir(path)); err != nil { + return err + } + dir, err := m.cliDir() + if err != nil { + return err + } + return m.store.Delete(dir) +} + +// DeleteAll removes all sessions. +func (m *Manager) DeleteAll() error { + base, err := m.sessionBaseDir() + if err != nil { + return err + } + ids, err := m.List() + if err != nil { + return err + } + for _, id := range ids { + sub := &Manager{cfg: m.cfg, store: m.store, id: id} + if err := sub.Delete(); err != nil { + return fmt.Errorf("delete session %q: %w", id, err) + } + } + // Also remove any sess- dirs (oauth session dirs not covered by List). + entries, err := os.ReadDir(base) + if err != nil && !os.IsNotExist(err) { + return err + } + for _, e := range entries { + if e.IsDir() && strings.HasPrefix(e.Name(), "sess-") && !strings.HasPrefix(e.Name(), "sess-cli-") { + if err := os.RemoveAll(filepath.Join(base, e.Name())); err != nil { + return err + } + } + } + return nil +} + +// List returns all session IDs discovered via sess-cli-* directories. +func (m *Manager) List() ([]string, error) { + base, err := m.sessionBaseDir() + if err != nil { + return nil, err + } + ids, err := m.store.List(base) + if err != nil { + return nil, err + } + // Exclude api-token-specific session IDs (PHP convention). + var filtered []string + for _, id := range ids { + if !strings.HasPrefix(id, "api-token-") { + filtered = append(filtered, id) + } + } + return filtered, nil +} + +// GetAPIToken reads the stored API token for the current session. +// Returns ("", nil) if no token is stored. +func (m *Manager) GetAPIToken() (string, error) { + path, err := m.tokenPath() + if err != nil { + return "", err + } + data, err := m.store.ReadFile(path) + if os.IsNotExist(err) { + return "", nil + } + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +// SetAPIToken writes an API token to disk. +func (m *Manager) SetAPIToken(token string) error { + path, err := m.tokenPath() + if err != nil { + return err + } + if err := m.store.MkdirAll(filepath.Dir(path)); err != nil { + return err + } + return m.store.WriteFile(path, []byte(token)) +} + +// DeleteAPIToken removes the stored API token. +func (m *Manager) DeleteAPIToken() error { + path, err := m.tokenPath() + if err != nil { + return err + } + return m.store.RemoveFile(path) +} + +// SetActiveSessionID writes the session ID to the session-id file, +// persisting the active session across invocations. +func (m *Manager) SetActiveSessionID(id string) error { + writableDir, err := m.cfg.WritableUserDir() //nolint:staticcheck // backwards compatibility needed for session files + if err != nil { + return err + } + path := filepath.Join(writableDir, "session-id") + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return err + } + return os.WriteFile(path, []byte(id), 0o600) +} diff --git a/internal/session/manager_test.go b/internal/session/manager_test.go new file mode 100644 index 000000000..326ce2d9c --- /dev/null +++ b/internal/session/manager_test.go @@ -0,0 +1,188 @@ +// internal/session/manager_test.go +package session_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upsun/cli/internal/config" + "github.com/upsun/cli/internal/session" +) + +func TestFileStore_SaveAndLoad(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sess-default", "sess-default.json") + + fs := &session.FileStore{} + s := &session.Session{ + AccessToken: "tok", + TokenType: "bearer", + Expires: 9999999999, + RefreshToken: "ref", + } + require.NoError(t, fs.Save(path, s)) + + loaded, err := fs.Load(path) + require.NoError(t, err) + assert.Equal(t, s, loaded) +} + +func TestFileStore_LoadMissing(t *testing.T) { + fs := &session.FileStore{} + loaded, err := fs.Load("/nonexistent/path.json") + require.NoError(t, err) + assert.Nil(t, loaded) +} + +func TestFileStore_Delete(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sess-default", "sess-default.json") + fs := &session.FileStore{} + require.NoError(t, fs.Save(path, &session.Session{AccessToken: "tok"})) + require.NoError(t, fs.Delete(filepath.Join(dir, "sess-default"))) + _, err := os.Stat(filepath.Join(dir, "sess-default")) + assert.True(t, os.IsNotExist(err)) +} + +func TestFileStore_List(t *testing.T) { + dir := t.TempDir() + for _, name := range []string{"sess-cli-default", "sess-cli-work", "sess-default", "other-dir"} { + require.NoError(t, os.MkdirAll(filepath.Join(dir, name), 0o700)) + } + fs := &session.FileStore{} + ids, err := fs.List(dir) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"default", "work"}, ids) +} + +func TestResolveSessionID_EnvVar(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_SESSION_ID", "my-env-session") + id, err := session.ResolveSessionID(cfg) + require.NoError(t, err) + assert.Equal(t, "my-env-session", id) +} + +func TestResolveSessionID_File(t *testing.T) { + cfg := testConfig(t) + dir := t.TempDir() + t.Setenv("TEST_CLI_HOME", dir) + idFile := filepath.Join(dir, ".platform-test-cli", "session-id") + require.NoError(t, os.MkdirAll(filepath.Dir(idFile), 0o700)) + require.NoError(t, os.WriteFile(idFile, []byte("file-session\n"), 0o600)) + id, err := session.ResolveSessionID(cfg) + require.NoError(t, err) + assert.Equal(t, "file-session", id) +} + +func TestResolveSessionID_Default(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + id, err := session.ResolveSessionID(cfg) + require.NoError(t, err) + assert.Equal(t, "default", id) +} + +func TestResolveSessionID_ConfigField(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + cfg.API.SessionID = "config-session" + id, err := session.ResolveSessionID(cfg) + require.NoError(t, err) + assert.Equal(t, "config-session", id) +} + +func TestManager_SaveAndLoad(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + s := &session.Session{AccessToken: "tok", TokenType: "bearer", Expires: 9999999999, RefreshToken: "ref"} + require.NoError(t, mgr.Save(s)) + + loaded, err := mgr.Load() + require.NoError(t, err) + require.NotNil(t, loaded) + assert.Equal(t, "tok", loaded.AccessToken) +} + +func TestManager_Delete(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + require.NoError(t, mgr.Save(&session.Session{AccessToken: "tok"})) + require.NoError(t, mgr.Delete()) + + loaded, err := mgr.Load() + require.NoError(t, err) + assert.Nil(t, loaded) +} + +func TestManager_APIToken(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + require.NoError(t, mgr.SetAPIToken("my-api-token")) + tok, err := mgr.GetAPIToken() + require.NoError(t, err) + assert.Equal(t, "my-api-token", tok) + + require.NoError(t, mgr.DeleteAPIToken()) + tok, err = mgr.GetAPIToken() + require.NoError(t, err) + assert.Equal(t, "", tok) +} + +func TestManager_List(t *testing.T) { + cfg := testConfig(t) + t.Setenv("TEST_CLI_HOME", t.TempDir()) + store := session.NewMemStore() + + mgr1 := session.NewWithStore(cfg, store) + require.NoError(t, mgr1.Save(&session.Session{AccessToken: "tok1"})) + + // Second session + cfg2 := testConfig(t) + cfg2.API.SessionID = "work" + mgr2 := session.NewWithStore(cfg2, store) + require.NoError(t, mgr2.Save(&session.Session{AccessToken: "tok2"})) + + ids, err := mgr1.List() + require.NoError(t, err) + assert.ElementsMatch(t, []string{"default", "work"}, ids) +} + +func TestManager_SetActiveSessionID(t *testing.T) { + cfg := testConfig(t) + dir := t.TempDir() + t.Setenv("TEST_CLI_HOME", dir) + store := session.NewMemStore() + mgr := session.NewWithStore(cfg, store) + + require.NoError(t, mgr.SetActiveSessionID("work")) + + idFile := filepath.Join(dir, ".platform-test-cli", "session-id") + data, err := os.ReadFile(idFile) + require.NoError(t, err) + assert.Equal(t, "work", strings.TrimSpace(string(data))) +} + +// testConfig returns a minimal *config.Config for tests using the integration test config.yaml. +func testConfig(t *testing.T) *config.Config { + t.Helper() + data, err := os.ReadFile("../../integration-tests/config.yaml") + require.NoError(t, err) + cfg, err := config.FromYAML(data) + require.NoError(t, err) + return cfg +} diff --git a/internal/session/resolver.go b/internal/session/resolver.go new file mode 100644 index 000000000..33160e0e7 --- /dev/null +++ b/internal/session/resolver.go @@ -0,0 +1,61 @@ +package session + +import ( + "os" + "path/filepath" + "strings" + + "github.com/upsun/cli/internal/config" +) + +// ResolveSessionID returns the current session ID by checking, in order: +// 1. {APP_ENV_PREFIX}SESSION_ID environment variable +// 2. /session-id file (written by session:switch) +// 3. config.API.SessionID field +// 4. "default" +func ResolveSessionID(cfg *config.Config) (string, error) { + if id := os.Getenv(cfg.Application.EnvPrefix + "SESSION_ID"); id != "" { + return id, nil + } + writableDir, err := cfg.WritableUserDir() //nolint:staticcheck // backwards compatibility needed for session files + if err != nil { + return "", err + } + idFile := filepath.Join(writableDir, "session-id") + if data, err := os.ReadFile(idFile); err == nil { + if id := strings.TrimSpace(string(data)); id != "" { + return id, nil + } + } + if cfg.API.SessionID != "" { + return cfg.API.SessionID, nil + } + return "default", nil +} + +// sanitiseID replaces runs of characters not in [a-zA-Z0-9_-] with a single hyphen, +// matching PHP's preg_replace('/[^\w\-]+/', '-', $id). +func sanitiseID(id string) string { + var b strings.Builder + prevHyphen := false + for _, r := range id { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + b.WriteRune(r) + prevHyphen = false + } else if !prevHyphen { + b.WriteRune('-') + prevHyphen = true + } + } + return b.String() +} + +// sessionDirName returns the directory name for the OAuth session (e.g. "sess-default"). +func sessionDirName(id string) string { + return "sess-" + sanitiseID(id) +} + +// cliDirName returns the directory name for CLI artifacts (e.g. "sess-cli-default"). +func cliDirName(id string) string { + return "sess-cli-" + sanitiseID(id) +} diff --git a/internal/session/resolver_test.go b/internal/session/resolver_test.go new file mode 100644 index 000000000..2dbedb2d4 --- /dev/null +++ b/internal/session/resolver_test.go @@ -0,0 +1,49 @@ +package session + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSanitiseID_AlphanumericPassthrough(t *testing.T) { + assert.Equal(t, "abc123", sanitiseID("abc123")) +} + +func TestSanitiseID_UnderscoreAndHyphenPassthrough(t *testing.T) { + assert.Equal(t, "my_session-1", sanitiseID("my_session-1")) +} + +func TestSanitiseID_SpaceReplacedWithHyphen(t *testing.T) { + assert.Equal(t, "my-session", sanitiseID("my session")) +} + +func TestSanitiseID_ConsecutiveInvalidCollapsedToSingleHyphen(t *testing.T) { + // PHP: preg_replace('/[^\w\-]+/', '-', $id) — multiple invalid chars → one hyphen + assert.Equal(t, "a-b", sanitiseID("a b")) + assert.Equal(t, "a-b", sanitiseID("a@#$b")) + assert.Equal(t, "a-b", sanitiseID("a!@#b")) +} + +func TestSanitiseID_UnicodeReplacedWithHyphen(t *testing.T) { + assert.Equal(t, "caf-au-lait", sanitiseID("café au lait")) +} + +func TestSanitiseID_EmptyString(t *testing.T) { + assert.Equal(t, "", sanitiseID("")) +} + +func TestSanitiseID_OnlyInvalidChars(t *testing.T) { + // All invalid chars collapse to a single hyphen + assert.Equal(t, "-", sanitiseID("@#$")) +} + +func TestSessionDirName(t *testing.T) { + assert.Equal(t, "sess-default", sessionDirName("default")) + assert.Equal(t, "sess-my-session", sessionDirName("my session")) +} + +func TestCLIDirName(t *testing.T) { + assert.Equal(t, "sess-cli-default", cliDirName("default")) + assert.Equal(t, "sess-cli-my-session", cliDirName("my session")) +} diff --git a/internal/session/store.go b/internal/session/store.go new file mode 100644 index 000000000..70e82b6e1 --- /dev/null +++ b/internal/session/store.go @@ -0,0 +1,182 @@ +package session + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" +) + +// Session holds OAuth2 tokens. JSON field names must match the PHP platformsh/client format. +type Session struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + Expires int64 `json:"expires"` + RefreshToken string `json:"refreshToken"` +} + +// Store abstracts session file I/O for testing. +type Store interface { + Load(path string) (*Session, error) + Save(path string, s *Session) error + Delete(dir string) error + List(baseDir string) ([]string, error) + MkdirAll(path string) error + ReadFile(path string) ([]byte, error) + WriteFile(path string, data []byte) error + RemoveFile(path string) error +} + +// FileStore is the production Store backed by the filesystem. +type FileStore struct{} + +func (fs *FileStore) Load(path string) (*Session, error) { + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + var s Session + if err := json.Unmarshal(data, &s); err != nil { + return nil, err + } + return &s, nil +} + +func (fs *FileStore) Save(path string, s *Session) error { + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return err + } + data, err := json.Marshal(s) + if err != nil { + return err + } + return os.WriteFile(path, data, 0o600) +} + +// Delete removes the directory containing the session file. +func (fs *FileStore) Delete(dir string) error { + return os.RemoveAll(dir) +} + +// MkdirAll creates the directory with permissions 0700. +func (fs *FileStore) MkdirAll(path string) error { + return os.MkdirAll(path, 0o700) +} + +func (fs *FileStore) ReadFile(path string) ([]byte, error) { + return os.ReadFile(path) +} + +func (fs *FileStore) WriteFile(path string, data []byte) error { + return os.WriteFile(path, data, 0o600) +} + +func (fs *FileStore) RemoveFile(path string) error { + err := os.Remove(path) + if os.IsNotExist(err) { + return nil + } + return err +} + +// List scans baseDir for sess-cli-* directories and returns the session IDs. +// This matches PHP's listSessionIds() which globs sess-cli-* to discover sessions. +func (fs *FileStore) List(baseDir string) ([]string, error) { + entries, err := os.ReadDir(baseDir) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + var ids []string + for _, e := range entries { + if e.IsDir() && strings.HasPrefix(e.Name(), "sess-cli-") { + id := strings.TrimPrefix(e.Name(), "sess-cli-") + ids = append(ids, id) + } + } + return ids, nil +} + +// MemStore is an in-memory Store for tests. It is not safe for concurrent use. +type MemStore struct { + sessions map[string]*Session + tokens map[string]string + dirs map[string]bool +} + +func NewMemStore() *MemStore { + return &MemStore{ + sessions: make(map[string]*Session), + tokens: make(map[string]string), + dirs: make(map[string]bool), + } +} + +func (m *MemStore) Load(path string) (*Session, error) { + s := m.sessions[path] + if s == nil { + return nil, nil + } + cp := *s + return &cp, nil +} + +func (m *MemStore) Save(path string, s *Session) error { + cp := *s + m.sessions[path] = &cp + m.dirs[filepath.Dir(path)] = true + return nil +} + +func (m *MemStore) Delete(dir string) error { + for k := range m.sessions { + if strings.HasPrefix(k, dir+"/") || k == dir { + delete(m.sessions, k) + } + } + delete(m.dirs, dir) + return nil +} + +func (m *MemStore) List(baseDir string) ([]string, error) { + var ids []string + for dir := range m.dirs { + parent := filepath.Dir(dir) + if parent == baseDir { + base := filepath.Base(dir) + if strings.HasPrefix(base, "sess-cli-") { + ids = append(ids, strings.TrimPrefix(base, "sess-cli-")) + } + } + } + return ids, nil +} + +// MkdirAll records the path in the dirs map. +func (m *MemStore) MkdirAll(path string) error { + m.dirs[path] = true + return nil +} + +func (m *MemStore) ReadFile(path string) ([]byte, error) { + data, ok := m.tokens[path] + if !ok { + return nil, os.ErrNotExist + } + return []byte(data), nil +} + +func (m *MemStore) WriteFile(path string, data []byte) error { + m.tokens[path] = string(data) + return nil +} + +func (m *MemStore) RemoveFile(path string) error { + delete(m.tokens, path) + return nil +} diff --git a/pkg/mockapi/api_server.go b/pkg/mockapi/api_server.go index a054f1c2c..c1fe05c60 100644 --- a/pkg/mockapi/api_server.go +++ b/pkg/mockapi/api_server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "strings" + "sync" "testing" "github.com/go-chi/chi/v5" @@ -12,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +var TestPhoneVerificationCode = "123456" + type Handler struct { *chi.Mux @@ -44,6 +47,41 @@ func NewHandler(t *testing.T) *Handler { _ = json.NewEncoder(w).Encode(map[string]any{"state": false, "type": ""}) }) + var ( + phoneVerifyMu sync.Mutex + pendingPhoneCode = TestPhoneVerificationCode + phoneVerifyPending bool + ) + + // Phone verification endpoints — match the actual Upsun API (same paths the PHP CLI calls). + const phoneSID = "test-sid-1" + h.Post("/users/{user_id}/phonenumber", func(w http.ResponseWriter, _ *http.Request) { + phoneVerifyMu.Lock() + phoneVerifyPending = true + phoneVerifyMu.Unlock() + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"sid": phoneSID}) + }) + h.Post("/users/{user_id}/phonenumber/{sid}", func(w http.ResponseWriter, req *http.Request) { + var body struct { + Code string `json:"code"` + } + _ = json.NewDecoder(req.Body).Decode(&body) + phoneVerifyMu.Lock() + pending := phoneVerifyPending + phoneVerifyMu.Unlock() + if !pending || body.Code != pendingPhoneCode { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid code"}) + return + } + phoneVerifyMu.Lock() + phoneVerifyPending = false + phoneVerifyMu.Unlock() + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "verified"}) + }) + h.Get("/organizations", h.handleListOrgs) h.Post("/organizations", h.handleCreateOrg) h.Get("/organizations/{organization_id}", h.handleGetOrg) diff --git a/pkg/mockapi/auth_server.go b/pkg/mockapi/auth_server.go index fb5871b48..6b8298bdb 100644 --- a/pkg/mockapi/auth_server.go +++ b/pkg/mockapi/auth_server.go @@ -4,10 +4,13 @@ import ( "crypto" "crypto/ed25519" "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" "slices" + "sync" "testing" "time" @@ -20,32 +23,124 @@ import ( var ValidAPITokens = []string{"api-token-1"} var accessTokens = []string{"access-token-1"} +// AuthServer is a mock authentication server for testing. +type AuthServer struct { + *httptest.Server + revokedMu sync.Mutex + revokedTokens []string +} + +// RevokedTokens returns a copy of all tokens that have been revoked. +func (s *AuthServer) RevokedTokens() []string { + s.revokedMu.Lock() + defer s.revokedMu.Unlock() + out := make([]string, len(s.revokedTokens)) + copy(out, s.revokedTokens) + return out +} + // NewAuthServer creates a new mock authentication server. // The caller must call Close() on the server when finished. -func NewAuthServer(t *testing.T) *httptest.Server { +func NewAuthServer(t *testing.T) *AuthServer { mux := chi.NewRouter() if testing.Verbose() { mux.Use(middleware.DefaultLogger) } + type pendingAuth struct { + codeChallenge string + state string + } + var ( + pendingMu sync.Mutex + pendingAuths = map[string]pendingAuth{} // code → pendingAuth + ) + + srv := &AuthServer{} + + mux.Get("/oauth2/authorize", func(w http.ResponseWriter, req *http.Request) { + q := req.URL.Query() + code := "test-auth-code-" + q.Get("state") + pendingMu.Lock() + pendingAuths[code] = pendingAuth{ + codeChallenge: q.Get("code_challenge"), + state: q.Get("state"), + } + pendingMu.Unlock() + redirectURI := q.Get("redirect_uri") + http.Redirect(w, req, redirectURI+"?code="+code+"&state="+q.Get("state"), http.StatusFound) + }) + mux.Post("/oauth2/token", func(w http.ResponseWriter, req *http.Request) { require.NoError(t, req.ParseForm()) - if gt := req.Form.Get("grant_type"); gt != "api_token" { + switch req.Form.Get("grant_type") { + case "api_token": + apiToken := req.Form.Get("api_token") + if slices.Contains(ValidAPITokens, apiToken) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": accessTokens[0], + "expires_in": 3600, + "token_type": "bearer", + "refresh_token": "test-refresh-token", + }) + return + } w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid grant type: " + gt}) - return - } - apiToken := req.Form.Get("api_token") - if slices.Contains(ValidAPITokens, apiToken) { - _ = json.NewEncoder(w).Encode(struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - Type string `json:"token_type"` - }{AccessToken: accessTokens[0], ExpiresIn: 60, Type: "bearer"}) - return + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid API token"}) + + case "authorization_code": + code := req.Form.Get("code") + verifier := req.Form.Get("code_verifier") + pendingMu.Lock() + pending, ok := pendingAuths[code] + delete(pendingAuths, code) + pendingMu.Unlock() + if !ok { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid code"}) + return + } + // Verify PKCE S256 challenge. + h := sha256.Sum256([]byte(verifier)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + if expected != pending.codeChallenge { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid code_verifier"}) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": accessTokens[0], + "expires_in": 3600, + "token_type": "bearer", + "refresh_token": "test-refresh-token", + }) + + case "refresh_token": + if req.Form.Get("refresh_token") == "test-refresh-token" { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": accessTokens[0], + "expires_in": 3600, + "token_type": "bearer", + "refresh_token": "test-refresh-token", + }) + return + } + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid refresh token"}) + + default: + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid grant type: " + req.Form.Get("grant_type")}) } - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid API token"}) + }) + + mux.Post("/oauth2/revoke", func(w http.ResponseWriter, req *http.Request) { + require.NoError(t, req.ParseForm()) + token := req.Form.Get("token") + srv.revokedMu.Lock() + srv.revokedTokens = append(srv.revokedTokens, token) + srv.revokedMu.Unlock() + w.WriteHeader(http.StatusOK) }) mux.Get("/ssh/authority", func(w http.ResponseWriter, _ *http.Request) { @@ -98,7 +193,8 @@ func NewAuthServer(t *testing.T) *httptest.Server { }{string(ssh.MarshalAuthorizedKey(cert))}) }) - return httptest.NewServer(mux) + srv.Server = httptest.NewServer(mux) + return srv } // publicKeys returns the server's public keys, e.g. for SSH certificate generation.