From 1f4681d9e59df07ec233f1daa57947ca3d619233 Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Sat, 2 May 2026 19:50:20 +0200 Subject: [PATCH 1/2] feat(jwtinfo): add options to write the JWT token to a file and refresh it before it expires --- README.md | 18 ++ internal/cmd/jwtinfo.go | 94 ++++++++- internal/jwtinfo/jwtinfo.go | 232 +++++++++++++++++++--- internal/jwtinfo/jwtinfo_refresh_test.go | 236 +++++++++++++++++++++++ internal/jwtinfo/jwtinfo_test.go | 14 +- 5 files changed, 565 insertions(+), 29 deletions(-) create mode 100644 internal/jwtinfo/jwtinfo_refresh_test.go diff --git a/README.md b/README.md index 771b04b..cc66072 100644 --- a/README.md +++ b/README.md @@ -234,15 +234,21 @@ Examples: # Request and validate a JWT token https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --validation-url $VALIDATION_URL + # Request a JWT token, write it to a file and refresh it before expiration + https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --token-output-file /tmp/token --refresh + Usage: https-wrench jwtinfo [flags] Flags: -h, --help help for jwtinfo + --refresh Enable background token refresh before expiration + --renew-threshold float Token renewal threshold as a percentage of lifetime (default 80) --request-url string HTTP address to use for the JWT token request --request-values-file string File containing the JSON encoded values to use for the JWT token request --request-values-json string JSON encoded values to use for the JWT token request --token-file string File containing the JWT token + --token-output-file string File where the acquired/refreshed token will be written --validation-url string Url of the JSON Web Key Set (JWKS) to use for validating the JWT token Global Flags: @@ -258,6 +264,18 @@ Decode a token from a file: ❯ https-wrench jwtinfo --token-file mytoken.jwt ``` +Request a token and save it to a file: + +```shell +❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt +``` + +Request a token, save it to a file, and keep it refreshed in the background: + +```shell +❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt --refresh --renew-threshold 90 +``` + ### HTTPS Wrench jwks `jwks` generates a public JSON Web Key Set from a PEM-encoded public key. This is useful for exposing your public keys at a `.well-known/jwks.json` endpoint. diff --git a/internal/cmd/jwtinfo.go b/internal/cmd/jwtinfo.go index c9e26f6..b9675da 100644 --- a/internal/cmd/jwtinfo.go +++ b/internal/cmd/jwtinfo.go @@ -5,8 +5,12 @@ Copyright © 2026 Zeno Belli package cmd import ( + "context" "io" "net/http" + "os" + "os/signal" + "syscall" "github.com/MicahParks/keyfunc/v3" "github.com/spf13/cobra" @@ -19,11 +23,17 @@ var ( flagNameRequestURL = "request-url" flagNameTokenFile = "token-file" flagNameJwksURL = "validation-url" + flagNameRefresh = "refresh" + flagNameTokenOutputFile = "token-output-file" + flagNameRenewThreshold = "renew-threshold" requestJSONValues string requestValuesFile string requestURL string tokenFile string jwksURL string + refresh bool + tokenOutputFile string + renewThreshold float64 keyfuncDefOverride keyfunc.Override ) @@ -48,11 +58,14 @@ Examples: # Request and validate a JWT token https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --validation-url $VALIDATION_URL + + # Request a JWT token, write it to a file and refresh it before expiration + https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --token-output-file /tmp/token --refresh `, Run: func(cmd *cobra.Command, _ []string) { var ( err error - tokenData jwtinfo.JwtTokenData + tokenData *jwtinfo.JwtTokenData ) // TODO: remove global --config option @@ -115,7 +128,7 @@ Examples: } } - if tokenData.AccessTokenRaw != "" { + if tokenData != nil && tokenData.AccessTokenRaw != "" { err = tokenData.DecodeBase64() if err != nil { cmd.Printf("DecodeBase64 error: %s\n", err) @@ -135,6 +148,62 @@ Examples: cmd.Printf("error while printing token data: %s\n", err) return } + + if tokenOutputFile != "" { + tokenData.WriteTokenToFile(tokenOutputFile, cmd.OutOrStdout()) + } + + if refresh { + if requestURL == "" { + cmd.Printf("Error: --refresh requires --request-url\n") + return + } + + // Setup graceful shutdown + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + go func() { + <-sigCh + cancel() + }() + + // Note: RequestValuesMap and client are recreated here if needed or reused + // Since they were declared inside the if block, we reconstruct them or declare them outside + // But wait, requestValuesMap and client aren't in scope here. + // Let's redefine them for the refresh loop since they are just configured from flags + refreshClient := &http.Client{} + refreshValuesMap := make(map[string]string) + + if requestValuesFile != "" { + refreshValuesMap, _ = jwtinfo.ReadRequestValuesFile(requestValuesFile, refreshValuesMap) + } + + if requestJSONValues != "" { + refreshValuesMap, _ = jwtinfo.ParseRequestJSONValues(requestJSONValues, refreshValuesMap) + } + + cmd.Printf("Starting refresh loop...\n") + + err := tokenData.RefreshLoop( + ctx, + requestURL, + refreshValuesMap, + refreshClient, + io.ReadAll, + renewThreshold, + tokenOutputFile, + cmd.OutOrStdout(), + ) + if err != nil { + cmd.Printf("Refresh loop exited with error: %s\n", err) + } else { + cmd.Printf("Refresh loop stopped gracefully.\n") + } + } } else { _ = cmd.Help() } @@ -179,6 +248,27 @@ func init() { "Url of the JSON Web Key Set (JWKS) to use for validating the JWT token", ) + jwtinfoCmd.Flags().BoolVar( + &refresh, + flagNameRefresh, + false, + "Run in foreground and automatically refresh the token", + ) + + jwtinfoCmd.Flags().StringVar( + &tokenOutputFile, + flagNameTokenOutputFile, + "", + "File to write the refreshed token to", + ) + + jwtinfoCmd.Flags().Float64Var( + &renewThreshold, + flagNameRenewThreshold, + 80.0, + "Percentage of token lifetime to wait before refreshing", + ) + // Either read a token from a file or request it from an HTTP address jwtinfoCmd.MarkFlagsMutuallyExclusive(flagNameTokenFile, flagNameRequestURL) jwtinfoCmd.MarkFlagsOneRequired(flagNameTokenFile, flagNameRequestURL) diff --git a/internal/jwtinfo/jwtinfo.go b/internal/jwtinfo/jwtinfo.go index 18b338d..381cb03 100644 --- a/internal/jwtinfo/jwtinfo.go +++ b/internal/jwtinfo/jwtinfo.go @@ -57,20 +57,20 @@ type AllReader func(io.Reader) ([]byte, error) // response types. // //nolint:revive -func RequestToken(ctx context.Context, reqURL string, reqValues map[string]string, client *http.Client, readAll AllReader) (JwtTokenData, error) { +func RequestToken(ctx context.Context, reqURL string, reqValues map[string]string, client *http.Client, readAll AllReader) (*JwtTokenData, error) { if readAll == nil { - return JwtTokenData{}, errors.New("nil body reader function") + return nil, errors.New("nil body reader function") } if reqURL == emptyString { - return JwtTokenData{}, errors.New("empty string provided as request URL") + return nil, errors.New("empty string provided as request URL") } if len(reqValues) == 0 { - return JwtTokenData{}, errors.New("empty map provided as request values") + return nil, errors.New("empty map provided as request values") } - var t JwtTokenData + t := &JwtTokenData{} urlReqValues := url.Values{} for k, v := range reqValues { @@ -84,7 +84,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin strings.NewReader(urlReqValues.Encode()), ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "HTTP error while defining token data request: %w", err, ) @@ -96,13 +96,13 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin resp, err := client.Do(req) if err != nil { - return JwtTokenData{}, err + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "token request returned the following status code: %d", resp.StatusCode, ) @@ -112,7 +112,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) if errBodyRead != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to read body: %w", errBodyRead, ) @@ -125,7 +125,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin if mediaType == "application/json" { if err = json.NewDecoder(resp.Body).Decode(&t); err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "error validating token request data: %w", err, ) @@ -137,7 +137,7 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin &jwt.RegisteredClaims{}, ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to parse JWT token from HTTP response: %w", err, ) @@ -148,20 +148,20 @@ func RequestToken(ctx context.Context, reqURL string, reqValues map[string]strin // ReadTokenFromFile reads a JWT token string from the specified file. // It returns a JwtTokenData struct containing the raw token string. -func ReadTokenFromFile(fileName string) (JwtTokenData, error) { +func ReadTokenFromFile(fileName string) (*JwtTokenData, error) { data, err := os.ReadFile(fileName) if err != nil { - return JwtTokenData{}, fmt.Errorf("unable to read token file: %w", err) + return nil, fmt.Errorf("unable to read token file: %w", err) } - td := JwtTokenData{AccessTokenRaw: strings.TrimSpace(string(data))} + td := &JwtTokenData{AccessTokenRaw: strings.TrimSpace(string(data))} _, _, err = jwt.NewParser().ParseUnverified( td.AccessTokenRaw, &jwt.RegisteredClaims{}, ) if err != nil { - return JwtTokenData{}, fmt.Errorf( + return nil, fmt.Errorf( "unable to parse JWT token from file: %w", err, ) @@ -238,8 +238,6 @@ func isValidJSON(data []byte) bool { // DecodeBase64 decodes the base64-encoded header and claims of the access and // refresh tokens stored in the JwtTokenData struct. -// -//nolint:revive func (jtd *JwtTokenData) DecodeBase64() error { if jtd.AccessTokenRaw != emptyString { header, claims, err := decodeToken("AccessToken", jtd.AccessTokenRaw) @@ -252,13 +250,18 @@ func (jtd *JwtTokenData) DecodeBase64() error { } if jtd.RefreshTokenRaw != emptyString { - header, claims, err := decodeToken("RefreshToken", jtd.RefreshTokenRaw) - if err != nil { - return err - } + // A refresh token is not strictly required to be a JWT in OAuth2. + // If it has 3 parts, we attempt to decode it as a JWT. + // If it doesn't, we treat it as an opaque token and continue. + if strings.Count(jtd.RefreshTokenRaw, ".") == 2 { + header, claims, err := decodeToken("RefreshToken", jtd.RefreshTokenRaw) + if err != nil { + return err + } - jtd.RefreshTokenHeader = header - jtd.RefreshTokenClaims = claims + jtd.RefreshTokenHeader = header + jtd.RefreshTokenClaims = claims + } } return nil @@ -351,7 +354,7 @@ func (jtd *JwtTokenData) ParseWithJWKS(ctx context.Context, jwksURL string, keyf // to the provided writer in a human-readable format. // //nolint:revive -func PrintTokenInfo(jtd JwtTokenData, w io.Writer) error { +func PrintTokenInfo(jtd *JwtTokenData, w io.Writer) error { sl := style.CertKeyP4.Render sv := style.CertValue.Render sTrue := style.BoolTrue.Render @@ -478,3 +481,184 @@ func unmarshalTokenTimeClaims(claims []byte) (map[string]string, error) { return tokenClaims, nil } + +// GetExpiration extracts the expiration time (exp) from the token claims. +func (jtd *JwtTokenData) GetExpiration() (time.Time, error) { + if jtd.AccessTokenClaims == nil { + return time.Time{}, errors.New("access token claims are empty") + } + + var genericClaims map[string]any + if err := json.Unmarshal(jtd.AccessTokenClaims, &genericClaims); err != nil { + return time.Time{}, fmt.Errorf("unable to unmarshal claims: %w", err) + } + + if v, ok := genericClaims["exp"]; ok { + if vf, ok := v.(float64); ok { + return time.Unix(int64(vf), 0), nil + } + + return time.Time{}, errors.New("exp claim is not a numeric timestamp") + } + + return time.Time{}, errors.New("exp claim missing") +} + +// GetIssuedAt extracts the issued at time (iat) from the token claims. +func (jtd *JwtTokenData) GetIssuedAt() (time.Time, error) { + if jtd.AccessTokenClaims == nil { + return time.Time{}, errors.New("access token claims are empty") + } + + var genericClaims map[string]any + if err := json.Unmarshal(jtd.AccessTokenClaims, &genericClaims); err != nil { + return time.Time{}, fmt.Errorf("unable to unmarshal claims: %w", err) + } + + if v, ok := genericClaims["iat"]; ok { + if vf, ok := v.(float64); ok { + return time.Unix(int64(vf), 0), nil + } + + return time.Time{}, errors.New("iat claim is not a numeric timestamp") + } + + return time.Time{}, errors.New("iat claim missing") +} + +// Refresh attempts to acquire a new token either by using the refresh token or the original request values. +func (jtd *JwtTokenData) Refresh( + ctx context.Context, + reqURL string, + reqValues map[string]string, + client *http.Client, + readAll AllReader, +) error { + refreshValues := maps.Clone(reqValues) + if refreshValues == nil { + refreshValues = make(map[string]string) + } + + if jtd.RefreshTokenRaw != "" { + refreshValues["grant_type"] = "refresh_token" + refreshValues["refresh_token"] = jtd.RefreshTokenRaw + } + + newTokenData, err := RequestToken(ctx, reqURL, refreshValues, client, readAll) + if err != nil { + return fmt.Errorf("failed to request refreshed token: %w", err) + } + + if err := newTokenData.DecodeBase64(); err != nil { + return fmt.Errorf("failed to decode refreshed token: %w", err) + } + + jtd.AccessTokenRaw = newTokenData.AccessTokenRaw + jtd.AccessTokenJwt = newTokenData.AccessTokenJwt + jtd.AccessTokenHeader = newTokenData.AccessTokenHeader + jtd.AccessTokenClaims = newTokenData.AccessTokenClaims + + if newTokenData.RefreshTokenRaw != "" { + jtd.RefreshTokenRaw = newTokenData.RefreshTokenRaw + jtd.RefreshTokenJwt = newTokenData.RefreshTokenJwt + jtd.RefreshTokenHeader = newTokenData.RefreshTokenHeader + jtd.RefreshTokenClaims = newTokenData.RefreshTokenClaims + } + + return nil +} + +// RefreshLoop runs a loop that periodically refreshes the JWT token before it +// expires. +func (jtd *JwtTokenData) RefreshLoop( + ctx context.Context, + reqURL string, + reqValues map[string]string, + client *http.Client, + readAll AllReader, + renewThreshold float64, + outFileName string, + outWriter io.Writer, +) error { + for { + sleepFor, err := jtd.calculateWaitDuration(renewThreshold) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(sleepFor): + } + + if err := jtd.Refresh(ctx, reqURL, reqValues, client, readAll); err != nil { + fmt.Fprintf(outWriter, "Failed to refresh token: %v\n", err) + // Sleep before retrying on failure + select { + case <-ctx.Done(): + return nil + case <-time.After(10 * time.Second): + } + + continue + } + + jtd.WriteTokenToFile(outFileName, outWriter) + } +} + +// calculateWaitDuration determines how long to wait before the next token refresh +// based on the expiration time and the renewal threshold. +func (jtd *JwtTokenData) calculateWaitDuration(renewThreshold float64) (time.Duration, error) { + exp, err := jtd.GetExpiration() + if err != nil { + return 0, fmt.Errorf("unable to determine expiration: %w", err) + } + + iat, err := jtd.GetIssuedAt() + + var lifetime time.Duration + + var wakeTime time.Time + + if err == nil { + lifetime = exp.Sub(iat) + waitDuration := time.Duration(float64(lifetime) * (renewThreshold / 100.0)) + wakeTime = iat.Add(waitDuration) + } else { + // Fallback if iat is missing, use current time + lifetime = time.Until(exp) + waitDuration := time.Duration(float64(lifetime) * (renewThreshold / 100.0)) + wakeTime = time.Now().Add(waitDuration) + } + + if lifetime <= 0 { + return 0, errors.New("token lifetime is zero or negative") + } + + sleepFor := time.Until(wakeTime) + if sleepFor <= 0 { + // If we are already past the wake time, trigger a refresh immediately. + // But avoid a tight spin loop if refresh fails instantly. + sleepFor = 100 * time.Millisecond + } + + return sleepFor, nil +} + +// WriteTokenToFile handles the persistence or display of a newly +// acquired token, either writing it to a file or printing it to the console. +func (jtd *JwtTokenData) WriteTokenToFile(outFileName string, outWriter io.Writer) { + if outFileName != "" { + if err := os.WriteFile(outFileName, []byte(jtd.AccessTokenRaw), 0o600); err != nil { + fmt.Fprintf(outWriter, "Failed to write token to file %s: %v\n", outFileName, err) + } else { + ts := time.Now().Format(time.RFC3339) + fmt.Fprintf(outWriter, "[%s] Token persisted to %s\n", ts, outFileName) + } + } else { + fmt.Fprintf(outWriter, "\n--- Token Refreshed at %s ---\n", time.Now().Format(time.RFC3339)) + _ = PrintTokenInfo(jtd, outWriter) + } +} diff --git a/internal/jwtinfo/jwtinfo_refresh_test.go b/internal/jwtinfo/jwtinfo_refresh_test.go new file mode 100644 index 0000000..830caf5 --- /dev/null +++ b/internal/jwtinfo/jwtinfo_refresh_test.go @@ -0,0 +1,236 @@ +package jwtinfo + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type mockTransport struct { + roundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return m.roundTripFunc(req) +} + +func TestJwtTokenData_GetTimeClaims(t *testing.T) { + exp := time.Now().Add(1 * time.Hour).Unix() + iat := time.Now().Unix() + + claims := map[string]any{ + "exp": float64(exp), + "iat": float64(iat), + } + + claimsBytes, err := json.Marshal(claims) + require.NoError(t, err) + + jtd := JwtTokenData{ + AccessTokenClaims: claimsBytes, + } + + gotExp, err := jtd.GetExpiration() + require.NoError(t, err) + require.Equal(t, exp, gotExp.Unix()) + + gotIat, err := jtd.GetIssuedAt() + require.NoError(t, err) + require.Equal(t, iat, gotIat.Unix()) +} + +func TestJwtTokenData_RefreshLoop(t *testing.T) { + jtd, dummyToken := setupRefreshLoopTest(t) + reqCount := 0 + client := setupRefreshClient(t, &reqCount, dummyToken) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(300 * time.Millisecond) + cancel() + }() + + var buf bytes.Buffer + + err := jtd.RefreshLoop( + ctx, + "http://dummy.url", + map[string]string{"client_id": "foo"}, + client, + io.ReadAll, + 0.0, // trigger refresh immediately + "", + &buf, + ) + require.NoError(t, err) + + require.NotZero(t, reqCount, "expected refresh request to be made") + require.Equal(t, "new-refresh-token", jtd.RefreshTokenRaw) +} + +func setupRefreshLoopTest(t *testing.T) (*JwtTokenData, string) { + exp := time.Now().Add(5 * time.Second).Unix() + iat := time.Now().Unix() + claims := map[string]any{"exp": float64(exp), "iat": float64(iat)} + + claimsBytes, err := json.Marshal(claims) + require.NoError(t, err) + + jtd := &JwtTokenData{ + AccessTokenClaims: claimsBytes, + RefreshTokenRaw: "initial-refresh-token", + } + + header := `{"alg":"none"}` + b64Header := base64.RawURLEncoding.EncodeToString([]byte(header)) + newExp := time.Now().Add(1 * time.Hour).Unix() + newClaims := fmt.Sprintf(`{"exp":%d,"iat":%d}`, newExp, iat) + b64Claims := base64.RawURLEncoding.EncodeToString([]byte(newClaims)) + dummyToken := fmt.Sprintf("%s.%s.", b64Header, b64Claims) + + return jtd, dummyToken +} + +func setupRefreshClient(t *testing.T, reqCount *int, dummyToken string) *http.Client { + return &http.Client{ + Transport: &mockTransport{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + *reqCount++ + + if err := req.ParseForm(); err != nil { + t.Fatalf("ParseForm error: %v", err) + } + + if req.Form.Get("grant_type") != "refresh_token" { + t.Errorf("expected grant_type refresh_token, got %s", req.Form.Get("grant_type")) + } + + currentRefreshToken := req.Form.Get("refresh_token") + if currentRefreshToken == "" { + require.NotEmpty(t, currentRefreshToken, "expected refresh_token to be provided") + } + + respBody := fmt.Sprintf(`{"access_token": "%s", "refresh_token": "new-refresh-token"}`, dummyToken) + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(respBody)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }, + }, + } +} + +func TestJwtTokenData_WriteTokenToFile(t *testing.T) { + jtd := &JwtTokenData{ + AccessTokenRaw: "initial-token", + } + + tempFile, err := os.CreateTemp("", "token-test-*") + require.NoError(t, err) + tempFile.Close() + + defer os.Remove(tempFile.Name()) + + var buf bytes.Buffer + jtd.WriteTokenToFile(tempFile.Name(), &buf) + + data, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + require.Equal(t, "initial-token", string(data)) + require.Contains(t, buf.String(), "Token persisted to") +} + +func TestJwtTokenData_TimingMethods_Errors(t *testing.T) { + tests := []struct { + name string + jtd *JwtTokenData + wantErr string + }{ + { + name: "nil_claims", + jtd: &JwtTokenData{AccessTokenClaims: nil}, + wantErr: "access token claims are empty", + }, + { + name: "invalid_json", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{invalid}`)}, + wantErr: "unable to unmarshal claims", + }, + { + name: "missing_exp", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{"iat":123}`)}, + wantErr: "exp claim missing", + }, + { + name: "non_numeric_exp", + jtd: &JwtTokenData{AccessTokenClaims: []byte(`{"exp":"not-a-number"}`)}, + wantErr: "exp claim is not a numeric timestamp", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.jtd.GetExpiration() + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + + _, err = tt.jtd.GetIssuedAt() + if tt.name == "nil_claims" || tt.name == "invalid_json" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} + +func TestJwtTokenData_RefreshLoop_ErrorRetry(t *testing.T) { + // Create token that expires very soon to trigger refresh + jtd := &JwtTokenData{ + AccessTokenRaw: "initial", + AccessTokenClaims: []byte(fmt.Sprintf(`{"exp":%d, "iat":%d}`, + time.Now().Add(100*time.Millisecond).Unix(), + time.Now().Add(-1*time.Hour).Unix())), + } + + // Mock client that fails + client := &http.Client{ + Transport: &mockTransport{ + roundTripFunc: func(_ *http.Request) (*http.Response, error) { + return nil, errors.New("network error") + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + var buf bytes.Buffer + // This will try to refresh, fail, then wait for 10s or context cancel. + // Since context is short, it should return nil when context expires. + err := jtd.RefreshLoop( + ctx, + "http://dummy.url", + map[string]string{"client_id": "foo"}, + client, + io.ReadAll, + 1.0, // trigger immediately + "", + &buf, + ) + require.NoError(t, err, "RefreshLoop should return nil on context cancel") + require.Contains(t, buf.String(), "Failed to refresh token") + require.Contains(t, buf.String(), "network error") +} diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 90e08e9..f4998f7 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -8,6 +8,7 @@ import ( "io" "maps" "os" + "strings" "testing" "time" @@ -695,7 +696,14 @@ func TestDecodeBase64(t *testing.T) { tdRefreshTokenTest := tdR tdRefreshTokenTest.RefreshTokenRaw = tt.tokenString err = tdRefreshTokenTest.DecodeBase64() - require.ErrorContains(t, err, tt.errMsg) + + // Special case: RefreshToken is allowed to be a non-JWT string. + // It only fails if it *looks* like a JWT (3 parts) but is invalid. + if strings.Count(tt.tokenString, ".") != 2 { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tt.errMsg) + } }) } } @@ -981,7 +989,7 @@ func TestPrintTokenInfo_Errors(t *testing.T) { AccessTokenClaims: []byte(claimsJSON), } - err := PrintTokenInfo(jtd, &buffer) + err := PrintTokenInfo(&jtd, &buffer) require.NoError(t, err) // The json.Indent for header failed and it wrote the raw header. @@ -997,7 +1005,7 @@ func TestPrintTokenInfo_Errors(t *testing.T) { AccessTokenClaims: []byte("invalid json"), } - err := PrintTokenInfo(jtd, &buffer) + err := PrintTokenInfo(&jtd, &buffer) require.Error(t, err) require.ErrorContains(t, err, "unable to unmarshal time claims from AccessToken") }) From 9a25ae045dc8cfdb747725c5128480199c03f49d Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Sat, 2 May 2026 20:29:27 +0200 Subject: [PATCH 2/2] fix: threshold validation, redundant input values, partial output file writes --- README.md | 4 +- internal/cmd/jwtinfo.go | 28 +++-------- internal/jwtinfo/jwtinfo.go | 62 +++++++++++++++++++++--- internal/jwtinfo/jwtinfo_refresh_test.go | 12 +++++ 4 files changed, 74 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index cc66072..29181a5 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ Usage: Flags: -h, --help help for jwtinfo - --refresh Enable background token refresh before expiration + --refresh Run in foreground and automatically refresh the token --renew-threshold float Token renewal threshold as a percentage of lifetime (default 80) --request-url string HTTP address to use for the JWT token request --request-values-file string File containing the JSON encoded values to use for the JWT token request @@ -270,7 +270,7 @@ Request a token and save it to a file: ❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt ``` -Request a token, save it to a file, and keep it refreshed in the background: +Request a token, save it to a file, and keep it refreshed until interrupted: ```shell ❯ https-wrench jwtinfo --request-url https://auth.example.com/token --request-values-json '{"client_id":"foo"}' --token-output-file ./token.jwt --refresh --renew-threshold 90 diff --git a/internal/cmd/jwtinfo.go b/internal/cmd/jwtinfo.go index b9675da..fe61b00 100644 --- a/internal/cmd/jwtinfo.go +++ b/internal/cmd/jwtinfo.go @@ -64,8 +64,10 @@ Examples: `, Run: func(cmd *cobra.Command, _ []string) { var ( - err error - tokenData *jwtinfo.JwtTokenData + err error + tokenData *jwtinfo.JwtTokenData + client = &http.Client{} + requestValuesMap = make(map[string]string) ) // TODO: remove global --config option @@ -82,9 +84,6 @@ Examples: } if requestURL != "" { - client := &http.Client{} - requestValuesMap := make(map[string]string) - if requestValuesFile != "" { requestValuesMap, err = jwtinfo.ReadRequestValuesFile( requestValuesFile, @@ -171,28 +170,13 @@ Examples: cancel() }() - // Note: RequestValuesMap and client are recreated here if needed or reused - // Since they were declared inside the if block, we reconstruct them or declare them outside - // But wait, requestValuesMap and client aren't in scope here. - // Let's redefine them for the refresh loop since they are just configured from flags - refreshClient := &http.Client{} - refreshValuesMap := make(map[string]string) - - if requestValuesFile != "" { - refreshValuesMap, _ = jwtinfo.ReadRequestValuesFile(requestValuesFile, refreshValuesMap) - } - - if requestJSONValues != "" { - refreshValuesMap, _ = jwtinfo.ParseRequestJSONValues(requestJSONValues, refreshValuesMap) - } - cmd.Printf("Starting refresh loop...\n") err := tokenData.RefreshLoop( ctx, requestURL, - refreshValuesMap, - refreshClient, + requestValuesMap, + client, io.ReadAll, renewThreshold, tokenOutputFile, diff --git a/internal/jwtinfo/jwtinfo.go b/internal/jwtinfo/jwtinfo.go index 381cb03..30d6bf2 100644 --- a/internal/jwtinfo/jwtinfo.go +++ b/internal/jwtinfo/jwtinfo.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "strconv" "strings" "time" @@ -611,6 +612,10 @@ func (jtd *JwtTokenData) RefreshLoop( // calculateWaitDuration determines how long to wait before the next token refresh // based on the expiration time and the renewal threshold. func (jtd *JwtTokenData) calculateWaitDuration(renewThreshold float64) (time.Duration, error) { + if renewThreshold < 0 || renewThreshold > 100 { + return 0, fmt.Errorf("renewThreshold must be between 0 and 100, got %.2f", renewThreshold) + } + exp, err := jtd.GetExpiration() if err != nil { return 0, fmt.Errorf("unable to determine expiration: %w", err) @@ -650,15 +655,56 @@ func (jtd *JwtTokenData) calculateWaitDuration(renewThreshold float64) (time.Dur // WriteTokenToFile handles the persistence or display of a newly // acquired token, either writing it to a file or printing it to the console. func (jtd *JwtTokenData) WriteTokenToFile(outFileName string, outWriter io.Writer) { - if outFileName != "" { - if err := os.WriteFile(outFileName, []byte(jtd.AccessTokenRaw), 0o600); err != nil { - fmt.Fprintf(outWriter, "Failed to write token to file %s: %v\n", outFileName, err) - } else { - ts := time.Now().Format(time.RFC3339) - fmt.Fprintf(outWriter, "[%s] Token persisted to %s\n", ts, outFileName) - } - } else { + if outFileName == "" { fmt.Fprintf(outWriter, "\n--- Token Refreshed at %s ---\n", time.Now().Format(time.RFC3339)) _ = PrintTokenInfo(jtd, outWriter) + + return + } + + dir := filepath.Dir(outFileName) + + tmp, err := os.CreateTemp(dir, ".token-*") + if err != nil { + fmt.Fprintf(outWriter, "Failed to create temp token file for %s: %v\n", outFileName, err) + return + } + + tmpName := tmp.Name() + + if _, err := tmp.WriteString(jtd.AccessTokenRaw); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to write token to temp file for %s: %v\n", outFileName, err) + + return } + + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to close temp token file for %s: %v\n", outFileName, err) + + return + } + + if err := os.Chmod(tmpName, 0o600); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to set token file permissions for %s: %v\n", outFileName, err) + + return + } + + if err := os.Rename(tmpName, outFileName); err != nil { + _ = os.Remove(tmpName) + + fmt.Fprintf(outWriter, "Failed to replace token file %s: %v\n", outFileName, err) + + return + } + + ts := time.Now().Format(time.RFC3339) + fmt.Fprintf(outWriter, "[%s] Token persisted to %s\n", ts, outFileName) } diff --git a/internal/jwtinfo/jwtinfo_refresh_test.go b/internal/jwtinfo/jwtinfo_refresh_test.go index 830caf5..b1169fd 100644 --- a/internal/jwtinfo/jwtinfo_refresh_test.go +++ b/internal/jwtinfo/jwtinfo_refresh_test.go @@ -234,3 +234,15 @@ func TestJwtTokenData_RefreshLoop_ErrorRetry(t *testing.T) { require.Contains(t, buf.String(), "Failed to refresh token") require.Contains(t, buf.String(), "network error") } + +func TestJwtTokenData_CalculateWaitDuration_Validation(t *testing.T) { + jtd := &JwtTokenData{} + + _, err := jtd.calculateWaitDuration(-1.0) + require.Error(t, err) + require.Contains(t, err.Error(), "renewThreshold must be between 0 and 100") + + _, err = jtd.calculateWaitDuration(101.0) + require.Error(t, err) + require.Contains(t, err.Error(), "renewThreshold must be between 0 and 100") +}