diff --git a/README.md b/README.md index 771b04b..29181a5 100644 --- a/README.md +++ b/README.md @@ -234,15 +234,21 @@ Examples: # Request and validate a JWT token https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --validation-url $VALIDATION_URL + # Request a JWT token, write it to a file and refresh it before expiration + https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --token-output-file /tmp/token --refresh + Usage: https-wrench jwtinfo [flags] Flags: -h, --help help for jwtinfo + --refresh Run in foreground and automatically refresh the token + --renew-threshold float Token renewal threshold as a percentage of lifetime (default 80) --request-url string HTTP address to use for the JWT token request --request-values-file string File containing the JSON encoded values to use for the JWT token request --request-values-json string JSON encoded values to use for the JWT token request --token-file string File containing the JWT token + --token-output-file string File where the acquired/refreshed token will be written --validation-url string Url of the JSON Web Key Set (JWKS) to use for validating the JWT token Global Flags: @@ -258,6 +264,18 @@ Decode a token from a file: ❯ https-wrench jwtinfo --token-file mytoken.jwt ``` +Request a token and save it to a file: + +```shell +❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt +``` + +Request a token, save it to a file, and keep it refreshed until interrupted: + +```shell +❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt --refresh --renew-threshold 90 +``` + ### HTTPS Wrench jwks `jwks` generates a public JSON Web Key Set from a PEM-encoded public key. This is useful for exposing your public keys at a `.well-known/jwks.json` endpoint. diff --git a/internal/cmd/jwtinfo.go b/internal/cmd/jwtinfo.go index c9e26f6..fe61b00 100644 --- a/internal/cmd/jwtinfo.go +++ b/internal/cmd/jwtinfo.go @@ -5,8 +5,12 @@ Copyright © 2026 Zeno Belli package cmd import ( + "context" "io" "net/http" + "os" + "os/signal" + "syscall" "github.com/MicahParks/keyfunc/v3" "github.com/spf13/cobra" @@ -19,11 +23,17 @@ var ( flagNameRequestURL = "request-url" flagNameTokenFile = "token-file" flagNameJwksURL = "validation-url" + flagNameRefresh = "refresh" + flagNameTokenOutputFile = "token-output-file" + flagNameRenewThreshold = "renew-threshold" requestJSONValues string requestValuesFile string requestURL string tokenFile string jwksURL string + refresh bool + tokenOutputFile string + renewThreshold float64 keyfuncDefOverride keyfunc.Override ) @@ -48,11 +58,16 @@ Examples: # Request and validate a JWT token https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --validation-url $VALIDATION_URL + + # Request a JWT token, write it to a file and refresh it before expiration + https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --token-output-file /tmp/token --refresh `, Run: func(cmd *cobra.Command, _ []string) { var ( - err error - tokenData jwtinfo.JwtTokenData + err error + tokenData *jwtinfo.JwtTokenData + client = &http.Client{} + requestValuesMap = make(map[string]string) ) // TODO: remove global --config option @@ -69,9 +84,6 @@ Examples: } if requestURL != "" { - client := &http.Client{} - requestValuesMap := make(map[string]string) - if requestValuesFile != "" { requestValuesMap, err = jwtinfo.ReadRequestValuesFile( requestValuesFile, @@ -115,7 +127,7 @@ Examples: } } - if tokenData.AccessTokenRaw != "" { + if tokenData != nil && tokenData.AccessTokenRaw != "" { err = tokenData.DecodeBase64() if err != nil { cmd.Printf("DecodeBase64 error: %s\n", err) @@ -135,6 +147,47 @@ Examples: cmd.Printf("error while printing token data: %s\n", err) return } + + if tokenOutputFile != "" { + tokenData.WriteTokenToFile(tokenOutputFile, cmd.OutOrStdout()) + } + + if refresh { + if requestURL == "" { + cmd.Printf("Error: --refresh requires --request-url\n") + return + } + + // Setup graceful shutdown + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + go func() { + <-sigCh + cancel() + }() + + cmd.Printf("Starting refresh loop...\n") + + err := tokenData.RefreshLoop( + ctx, + requestURL, + requestValuesMap, + client, + io.ReadAll, + renewThreshold, + tokenOutputFile, + cmd.OutOrStdout(), + ) + if err != nil { + cmd.Printf("Refresh loop exited with error: %s\n", err) + } else { + cmd.Printf("Refresh loop stopped gracefully.\n") + } + } } else { _ = cmd.Help() } @@ -179,6 +232,27 @@ func init() { "Url of the JSON Web Key Set (JWKS) to use for validating the JWT token", ) + jwtinfoCmd.Flags().BoolVar( + &refresh, + flagNameRefresh, + false, + "Run in foreground and automatically refresh the token", + ) + + jwtinfoCmd.Flags().StringVar( + &tokenOutputFile, + flagNameTokenOutputFile, + "", + "File to write the refreshed token to", + ) + + jwtinfoCmd.Flags().Float64Var( + &renewThreshold, + flagNameRenewThreshold, + 80.0, + "Percentage of token lifetime to wait before refreshing", + ) + // Either read a token from a file or request it from an HTTP address jwtinfoCmd.MarkFlagsMutuallyExclusive(flagNameTokenFile, flagNameRequestURL) jwtinfoCmd.MarkFlagsOneRequired(flagNameTokenFile, flagNameRequestURL) diff --git a/internal/jwtinfo/jwtinfo.go b/internal/jwtinfo/jwtinfo.go index 18b338d..30d6bf2 100644 --- a/internal/jwtinfo/jwtinfo.go +++ b/internal/jwtinfo/jwtinfo.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "strconv" "strings" "time" @@ -57,20 +58,20 @@ type AllReader func(io.Reader) ([]byte, error) // response types. // //nolint:revive -func RequestToken(ctx context.Context, reqURL string, reqValues map[string]string, client *http.Client, readAll AllReader) (JwtTokenData, error) { +func RequestToken(ctx context.Context, reqURL string, reqValues map[string]string, client *http.Client, readAll AllReader) (*JwtTokenData, error) { if readAll == nil { - return JwtTokenData{}, errors.New("nil body reader function") + return nil, errors.New("nil body reader function") } if reqURL == emptyString { - return JwtTokenData{}, errors.New("empty string provided as request URL") + return nil, errors.New("empty string provided as request URL") } if len(reqValues) == 0 { - return JwtTokenData{}, errors.New("empty map provided as request values") + return nil, errors.New("empty map provided as request values") } - var t JwtTokenData + t := &JwtTokenData{} urlReqValues := url.Values{} for k, v := range reqValues { @@ -84,7 +85,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin strings.NewReader(urlReqValues.Encode()), ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "HTTP error while defining token data request: %w", err, ) @@ -96,13 +97,13 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin resp, err := client.Do(req) if err != nil { - return JwtTokenData{}, err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "token request returned the following status code: %d", resp.StatusCode, ) @@ -112,7 +113,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) if errBodyRead != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to read body: %w", errBodyRead, ) @@ -125,7 +126,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin if mediaType == "application/json" { if err = json.NewDecoder(resp.Body).Decode(&t); err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "error validating token request data: %w", err, ) @@ -137,7 +138,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin &jwt.RegisteredClaims{}, ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to parse JWT token from HTTP response: %w", err, ) @@ -148,20 +149,20 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin // ReadTokenFromFile reads a JWT token string from the specified file. // It returns a JwtTokenData struct containing the raw token string. -func ReadTokenFromFile(fileName string) (JwtTokenData, error) { +func ReadTokenFromFile(fileName string) (*JwtTokenData, error) { data, err := os.ReadFile(fileName) if err != nil { - return JwtTokenData{}, fmt.Errorf("unable to read token file: %w", err) + return nil, fmt.Errorf("unable to read token file: %w", err) } - td := JwtTokenData{AccessTokenRaw: strings.TrimSpace(string(data))} + td := &JwtTokenData{AccessTokenRaw: strings.TrimSpace(string(data))} _, _, err = jwt.NewParser().ParseUnverified( td.AccessTokenRaw, &jwt.RegisteredClaims{}, ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to parse JWT token from file: %w", err, ) @@ -238,8 +239,6 @@ func isValidJSON(data []byte) bool { // DecodeBase64 decodes the base64-encoded header and claims of the access and // refresh tokens stored in the JwtTokenData struct. -// -//nolint:revive func (jtd *JwtTokenData) DecodeBase64() error { if jtd.AccessTokenRaw != emptyString { header, claims, err := decodeToken("AccessToken", jtd.AccessTokenRaw) @@ -252,13 +251,18 @@ func (jtd *JwtTokenData) DecodeBase64() error { } if jtd.RefreshTokenRaw != emptyString { - header, claims, err := decodeToken("RefreshToken", jtd.RefreshTokenRaw) - if err != nil { - return err - } + // A refresh token is not strictly required to be a JWT in OAuth2. + // If it has 3 parts, we attempt to decode it as a JWT. + // If it doesn't, we treat it as an opaque token and continue. + if strings.Count(jtd.RefreshTokenRaw, ".") == 2 { + header, claims, err := decodeToken("RefreshToken", jtd.RefreshTokenRaw) + if err != nil { + return err + } - jtd.RefreshTokenHeader = header - jtd.RefreshTokenClaims = claims + jtd.RefreshTokenHeader = header + jtd.RefreshTokenClaims = claims + } } return nil @@ -351,7 +355,7 @@ func (jtd *JwtTokenData) ParseWithJWKS(ctx context.Context, jwksURL string, keyf // to the provided writer in a human-readable format. // //nolint:revive -func PrintTokenInfo(jtd JwtTokenData, w io.Writer) error { +func PrintTokenInfo(jtd *JwtTokenData, w io.Writer) error { sl := style.CertKeyP4.Render sv := style.CertValue.Render sTrue := style.BoolTrue.Render @@ -478,3 +482,229 @@ func unmarshalTokenTimeClaims(claims []byte) (map[string]string, error) { return tokenClaims, nil } + +// GetExpiration extracts the expiration time (exp) from the token claims. +func (jtd *JwtTokenData) GetExpiration() (time.Time, error) { + if jtd.AccessTokenClaims == nil { + return time.Time{}, errors.New("access token claims are empty") + } + + var genericClaims map[string]any + if err := json.Unmarshal(jtd.AccessTokenClaims, &genericClaims); err != nil { + return time.Time{}, fmt.Errorf("unable to unmarshal claims: %w", err) + } + + if v, ok := genericClaims["exp"]; ok { + if vf, ok := v.(float64); ok { + return time.Unix(int64(vf), 0), nil + } + + return time.Time{}, errors.New("exp claim is not a numeric timestamp") + } + + return time.Time{}, errors.New("exp claim missing") +} + +// GetIssuedAt extracts the issued at time (iat) from the token claims. +func (jtd *JwtTokenData) GetIssuedAt() (time.Time, error) { + if jtd.AccessTokenClaims == nil { + return time.Time{}, errors.New("access token claims are empty") + } + + var genericClaims map[string]any + if err := json.Unmarshal(jtd.AccessTokenClaims, &genericClaims); err != nil { + return time.Time{}, fmt.Errorf("unable to unmarshal claims: %w", err) + } + + if v, ok := genericClaims["iat"]; ok { + if vf, ok := v.(float64); ok { + return time.Unix(int64(vf), 0), nil + } + + return time.Time{}, errors.New("iat claim is not a numeric timestamp") + } + + return time.Time{}, errors.New("iat claim missing") +} + +// Refresh attempts to acquire a new token either by using the refresh token or the original request values. +func (jtd *JwtTokenData) Refresh( + ctx context.Context, + reqURL string, + reqValues map[string]string, + client *http.Client, + readAll AllReader, +) error { + refreshValues := maps.Clone(reqValues) + if refreshValues == nil { + refreshValues = make(map[string]string) + } + + if jtd.RefreshTokenRaw != "" { + refreshValues["grant_type"] = "refresh_token" + refreshValues["refresh_token"] = jtd.RefreshTokenRaw + } + + newTokenData, err := RequestToken(ctx, reqURL, refreshValues, client, readAll) + if err != nil { + return fmt.Errorf("failed to request refreshed token: %w", err) + } + + if err := newTokenData.DecodeBase64(); err != nil { + return fmt.Errorf("failed to decode refreshed token: %w", err) + } + + jtd.AccessTokenRaw = newTokenData.AccessTokenRaw + jtd.AccessTokenJwt = newTokenData.AccessTokenJwt + jtd.AccessTokenHeader = newTokenData.AccessTokenHeader + jtd.AccessTokenClaims = newTokenData.AccessTokenClaims + + if newTokenData.RefreshTokenRaw != "" { + jtd.RefreshTokenRaw = newTokenData.RefreshTokenRaw + jtd.RefreshTokenJwt = newTokenData.RefreshTokenJwt + jtd.RefreshTokenHeader = newTokenData.RefreshTokenHeader + jtd.RefreshTokenClaims = newTokenData.RefreshTokenClaims + } + + return nil +} + +// RefreshLoop runs a loop that periodically refreshes the JWT token before it +// expires. +func (jtd *JwtTokenData) RefreshLoop( + ctx context.Context, + reqURL string, + reqValues map[string]string, + client *http.Client, + readAll AllReader, + renewThreshold float64, + outFileName string, + outWriter io.Writer, +) error { + for { + sleepFor, err := jtd.calculateWaitDuration(renewThreshold) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(sleepFor): + } + + if err := jtd.Refresh(ctx, reqURL, reqValues, client, readAll); err != nil { + fmt.Fprintf(outWriter, "Failed to refresh token: %v\n", err) + // Sleep before retrying on failure + select { + case <-ctx.Done(): + return nil + case <-time.After(10 * time.Second): + } + + continue + } + + jtd.WriteTokenToFile(outFileName, outWriter) + } +} + +// calculateWaitDuration determines how long to wait before the next token refresh +// based on the expiration time and the renewal threshold. +func (jtd *JwtTokenData) calculateWaitDuration(renewThreshold float64) (time.Duration, error) { + if renewThreshold < 0 || renewThreshold > 100 { + return 0, fmt.Errorf("renewThreshold must be between 0 and 100, got %.2f", renewThreshold) + } + + exp, err := jtd.GetExpiration() + if err != nil { + return 0, fmt.Errorf("unable to determine expiration: %w", err) + } + + iat, err := jtd.GetIssuedAt() + + var lifetime time.Duration + + var wakeTime time.Time + + if err == nil { + lifetime = exp.Sub(iat) + waitDuration := time.Duration(float64(lifetime) * (renewThreshold / 100.0)) + wakeTime = iat.Add(waitDuration) + } else { + // Fallback if iat is missing, use current time + lifetime = time.Until(exp) + waitDuration := time.Duration(float64(lifetime) * (renewThreshold / 100.0)) + wakeTime = time.Now().Add(waitDuration) + } + + if lifetime <= 0 { + return 0, errors.New("token lifetime is zero or negative") + } + + sleepFor := time.Until(wakeTime) + if sleepFor <= 0 { + // If we are already past the wake time, trigger a refresh immediately. + // But avoid a tight spin loop if refresh fails instantly. + sleepFor = 100 * time.Millisecond + } + + return sleepFor, nil +} + +// WriteTokenToFile handles the persistence or display of a newly +// acquired token, either writing it to a file or printing it to the console. +func (jtd *JwtTokenData) WriteTokenToFile(outFileName string, outWriter io.Writer) { + if outFileName == "" { + fmt.Fprintf(outWriter, "\n--- Token Refreshed at %s ---\n", time.Now().Format(time.RFC3339)) + _ = PrintTokenInfo(jtd, outWriter) + + return + } + + dir := filepath.Dir(outFileName) + + tmp, err := os.CreateTemp(dir, ".token-*") + if err != nil { + fmt.Fprintf(outWriter, "Failed to create temp token file for %s: %v\n", outFileName, err) + return + } + + tmpName := tmp.Name() + + if _, err := tmp.WriteString(jtd.AccessTokenRaw); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to write token to temp file for %s: %v\n", outFileName, err) + + return + } + + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to close temp token file for %s: %v\n", outFileName, err) + + return + } + + if err := os.Chmod(tmpName, 0o600); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to set token file permissions for %s: %v\n", outFileName, err) + + return + } + + if err := os.Rename(tmpName, outFileName); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to replace token file %s: %v\n", outFileName, err) + + return + } + + ts := time.Now().Format(time.RFC3339) + fmt.Fprintf(outWriter, "[%s] Token persisted to %s\n", ts, outFileName) +} diff --git a/internal/jwtinfo/jwtinfo_refresh_test.go b/internal/jwtinfo/jwtinfo_refresh_test.go new file mode 100644 index 0000000..b1169fd --- /dev/null +++ b/internal/jwtinfo/jwtinfo_refresh_test.go @@ -0,0 +1,248 @@ +package jwtinfo + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type mockTransport struct { + roundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTripFunc(req) +} + +func TestJwtTokenData_GetTimeClaims(t *testing.T) { + exp := time.Now().Add(1 * time.Hour).Unix() + iat := time.Now().Unix() + + claims := map[string]any{ + "exp": float64(exp), + "iat": float64(iat), + } + + claimsBytes, err := json.Marshal(claims) + require.NoError(t, err) + + jtd := JwtTokenData{ + AccessTokenClaims: claimsBytes, + } + + gotExp, err := jtd.GetExpiration() + require.NoError(t, err) + require.Equal(t, exp, gotExp.Unix()) + + gotIat, err := jtd.GetIssuedAt() + require.NoError(t, err) + require.Equal(t, iat, gotIat.Unix()) +} + +func TestJwtTokenData_RefreshLoop(t *testing.T) { + jtd, dummyToken := setupRefreshLoopTest(t) + reqCount := 0 + client := setupRefreshClient(t, &reqCount, dummyToken) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(300 * time.Millisecond) + cancel() + }() + + var buf bytes.Buffer + + err := jtd.RefreshLoop( + ctx, + "http://dummy.url", + map[string]string{"client_id": "foo"}, + client, + io.ReadAll, + 0.0, // trigger refresh immediately + "", + &buf, + ) + require.NoError(t, err) + + require.NotZero(t, reqCount, "expected refresh request to be made") + require.Equal(t, "new-refresh-token", jtd.RefreshTokenRaw) +} + +func setupRefreshLoopTest(t *testing.T) (*JwtTokenData, string) { + exp := time.Now().Add(5 * time.Second).Unix() + iat := time.Now().Unix() + claims := map[string]any{"exp": float64(exp), "iat": float64(iat)} + + claimsBytes, err := json.Marshal(claims) + require.NoError(t, err) + + jtd := &JwtTokenData{ + AccessTokenClaims: claimsBytes, + RefreshTokenRaw: "initial-refresh-token", + } + + header := `{"alg":"none"}` + b64Header := base64.RawURLEncoding.EncodeToString([]byte(header)) + newExp := time.Now().Add(1 * time.Hour).Unix() + newClaims := fmt.Sprintf(`{"exp":%d,"iat":%d}`, newExp, iat) + b64Claims := base64.RawURLEncoding.EncodeToString([]byte(newClaims)) + dummyToken := fmt.Sprintf("%s.%s.", b64Header, b64Claims) + + return jtd, dummyToken +} + +func setupRefreshClient(t *testing.T, reqCount *int, dummyToken string) *http.Client { + return &http.Client{ + Transport: &mockTransport{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + *reqCount++ + + if err := req.ParseForm(); err != nil { + t.Fatalf("ParseForm error: %v", err) + } + + if req.Form.Get("grant_type") != "refresh_token" { + t.Errorf("expected grant_type refresh_token, got %s", req.Form.Get("grant_type")) + } + + currentRefreshToken := req.Form.Get("refresh_token") + if currentRefreshToken == "" { + require.NotEmpty(t, currentRefreshToken, "expected refresh_token to be provided") + } + + respBody := fmt.Sprintf(`{"access_token": "%s", "refresh_token": "new-refresh-token"}`, dummyToken) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(respBody)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }, + }, + } +} + +func TestJwtTokenData_WriteTokenToFile(t *testing.T) { + jtd := &JwtTokenData{ + AccessTokenRaw: "initial-token", + } + + tempFile, err := os.CreateTemp("", "token-test-*") + require.NoError(t, err) + tempFile.Close() + + defer os.Remove(tempFile.Name()) + + var buf bytes.Buffer + jtd.WriteTokenToFile(tempFile.Name(), &buf) + + data, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + require.Equal(t, "initial-token", string(data)) + require.Contains(t, buf.String(), "Token persisted to") +} + +func TestJwtTokenData_TimingMethods_Errors(t *testing.T) { + tests := []struct { + name string + jtd *JwtTokenData + wantErr string + }{ + { + name: "nil_claims", + jtd: &JwtTokenData{AccessTokenClaims: nil}, + wantErr: "access token claims are empty", + }, + { + name: "invalid_json", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{invalid}`)}, + wantErr: "unable to unmarshal claims", + }, + { + name: "missing_exp", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{"iat":123}`)}, + wantErr: "exp claim missing", + }, + { + name: "non_numeric_exp", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{"exp":"not-a-number"}`)}, + wantErr: "exp claim is not a numeric timestamp", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.jtd.GetExpiration() + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + + _, err = tt.jtd.GetIssuedAt() + if tt.name == "nil_claims" || tt.name == "invalid_json" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} + +func TestJwtTokenData_RefreshLoop_ErrorRetry(t *testing.T) { + // Create token that expires very soon to trigger refresh + jtd := &JwtTokenData{ + AccessTokenRaw: "initial", + AccessTokenClaims: []byte(fmt.Sprintf(`{"exp":%d, "iat":%d}`, + time.Now().Add(100*time.Millisecond).Unix(), + time.Now().Add(-1*time.Hour).Unix())), + } + + // Mock client that fails + client := &http.Client{ + Transport: &mockTransport{ + roundTripFunc: func(_ *http.Request) (*http.Response, error) { + return nil, errors.New("network error") + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + var buf bytes.Buffer + // This will try to refresh, fail, then wait for 10s or context cancel. + // Since context is short, it should return nil when context expires. + err := jtd.RefreshLoop( + ctx, + "http://dummy.url", + map[string]string{"client_id": "foo"}, + client, + io.ReadAll, + 1.0, // trigger immediately + "", + &buf, + ) + require.NoError(t, err, "RefreshLoop should return nil on context cancel") + require.Contains(t, buf.String(), "Failed to refresh token") + require.Contains(t, buf.String(), "network error") +} + +func TestJwtTokenData_CalculateWaitDuration_Validation(t *testing.T) { + jtd := &JwtTokenData{} + + _, err := jtd.calculateWaitDuration(-1.0) + require.Error(t, err) + require.Contains(t, err.Error(), "renewThreshold must be between 0 and 100") + + _, err = jtd.calculateWaitDuration(101.0) + require.Error(t, err) + require.Contains(t, err.Error(), "renewThreshold must be between 0 and 100") +} diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 90e08e9..f4998f7 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -8,6 +8,7 @@ import ( "io" "maps" "os" + "strings" "testing" "time" @@ -695,7 +696,14 @@ func TestDecodeBase64(t *testing.T) { tdRefreshTokenTest := tdR tdRefreshTokenTest.RefreshTokenRaw = tt.tokenString err = tdRefreshTokenTest.DecodeBase64() - require.ErrorContains(t, err, tt.errMsg) + + // Special case: RefreshToken is allowed to be a non-JWT string. + // It only fails if it *looks* like a JWT (3 parts) but is invalid. + if strings.Count(tt.tokenString, ".") != 2 { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tt.errMsg) + } }) } } @@ -981,7 +989,7 @@ func TestPrintTokenInfo_Errors(t *testing.T) { AccessTokenClaims: []byte(claimsJSON), } - err := PrintTokenInfo(jtd, &buffer) + err := PrintTokenInfo(&jtd, &buffer) require.NoError(t, err) // The json.Indent for header failed and it wrote the raw header. @@ -997,7 +1005,7 @@ func TestPrintTokenInfo_Errors(t *testing.T) { AccessTokenClaims: []byte("invalid json"), } - err := PrintTokenInfo(jtd, &buffer) + err := PrintTokenInfo(&jtd, &buffer) require.Error(t, err) require.ErrorContains(t, err, "unable to unmarshal time claims from AccessToken") })