diff --git a/sonar/retry.go b/sonar/retry.go index 9b14db8..1dc4fc3 100644 --- a/sonar/retry.go +++ b/sonar/retry.go @@ -28,6 +28,13 @@ type RetryOptions struct { MaxDelay time.Duration // RetryableStatusCodes lists HTTP status codes that should trigger a retry. RetryableStatusCodes []int + // RetryNonIdempotent enables automatic retries for non-idempotent methods + // (POST, PATCH). It is disabled by default: only idempotent methods (GET, + // HEAD, PUT, DELETE, OPTIONS, TRACE) are retried, because resending a + // non-idempotent request after a network error or 5xx can duplicate a + // server-side side effect (for example creating a resource twice). Enable + // this only for endpoints you know are safe to repeat. + RetryNonIdempotent bool } // WithRetry is a ClientOptionFunc that enables opt-in retry with exponential @@ -52,6 +59,10 @@ type retryRoundTripper struct { // errors with exponential backoff and full jitter. Retries stop immediately when // the request context is cancelled or its deadline is exceeded. // +// Only idempotent methods are retried by default; non-idempotent methods (POST, +// PATCH) are passed straight through unless RetryNonIdempotent is set, so a +// resend cannot accidentally duplicate a server-side side effect. +// // When retries occur, the final response carries an X-Retry-Attempts header with // the total number of attempts made. // @@ -61,6 +72,12 @@ func (r *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return r.base.RoundTrip(req) } + // Non-idempotent methods are not retried unless explicitly opted in, to avoid + // duplicating server-side side effects on a resend. + if !r.opts.RetryNonIdempotent && !isIdempotentMethod(req.Method) { + return r.base.RoundTrip(req) + } + hasBody := req.Body != nil && req.Body != http.NoBody if hasBody && req.GetBody == nil { // Non-replayable body: cannot retry safely. @@ -70,6 +87,19 @@ func (r *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return r.retryLoop(req, hasBody) } +// isIdempotentMethod reports whether an HTTP method is safe to retry +// automatically. Per RFC 7231, GET, HEAD, PUT, DELETE, OPTIONS and TRACE are +// idempotent; POST and PATCH are not. +func isIdempotentMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodPut, + http.MethodDelete, http.MethodOptions, http.MethodTrace: + return true + default: + return false + } +} + // retryLoop runs up to MaxAttempts, sleeping between retries. // On the final response, X-Retry-Attempts is set to the total attempt count // when more than one attempt was made. diff --git a/sonar/retry_test.go b/sonar/retry_test.go index c96c8c6..91601a1 100644 --- a/sonar/retry_test.go +++ b/sonar/retry_test.go @@ -151,7 +151,9 @@ func TestRetryRoundTripper_NonReplayableBodyNotRetried(t *testing.T) { opts: RetryOptions{MaxAttempts: 4, RetryableStatusCodes: []int{503}}, } - req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", io.NopCloser(bytes.NewReader([]byte("body")))) + // Use an idempotent method so the non-replayable body is the reason the + // request is not retried, not the method. + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPut, "http://example.com", io.NopCloser(bytes.NewReader([]byte("body")))) // GetBody is nil because we used io.NopCloser directly, not bytes.NewReader via http.NewRequest. require.Nil(t, req.GetBody) @@ -179,7 +181,8 @@ func TestRetryRoundTripper_ReplayableBodyRetried(t *testing.T) { } body := bytes.NewReader([]byte(`{"key":"value"}`)) - req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", body) + // PUT is idempotent and has a replayable body, so it is retried by default. + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPut, "http://example.com", body) // http.NewRequest with bytes.Reader sets GetBody automatically. require.NotNil(t, req.GetBody) @@ -189,6 +192,82 @@ func TestRetryRoundTripper_ReplayableBodyRetried(t *testing.T) { assert.Equal(t, 2, transport.calls) } +func TestRetryRoundTripper_NonIdempotentNotRetriedByDefault(t *testing.T) { + t.Parallel() + + transport := &countingTransport{ + responses: []*http.Response{ + makeResponse(http.StatusServiceUnavailable), + makeResponse(http.StatusOK), + }, + } + rt := &retryRoundTripper{ + base: transport, + opts: RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Millisecond, + MaxDelay: 5 * time.Millisecond, + RetryableStatusCodes: []int{503}, + }, + } + + body := bytes.NewReader([]byte(`{"key":"value"}`)) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", body) + require.NotNil(t, req.GetBody) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + // POST is not retried by default even with a replayable body and a retryable status. + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, 1, transport.calls) +} + +func TestRetryRoundTripper_NonIdempotentRetriedWhenOptedIn(t *testing.T) { + t.Parallel() + + transport := &countingTransport{ + responses: []*http.Response{ + makeResponse(http.StatusServiceUnavailable), + makeResponse(http.StatusOK), + }, + } + rt := &retryRoundTripper{ + base: transport, + opts: RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Millisecond, + MaxDelay: 5 * time.Millisecond, + RetryableStatusCodes: []int{503}, + RetryNonIdempotent: true, + }, + } + + body := bytes.NewReader([]byte(`{"key":"value"}`)) + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", body) + require.NotNil(t, req.GetBody) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 2, transport.calls) +} + +func TestIsIdempotentMethod(t *testing.T) { + t.Parallel() + + idempotent := []string{ + http.MethodGet, http.MethodHead, http.MethodPut, + http.MethodDelete, http.MethodOptions, http.MethodTrace, + } + for _, method := range idempotent { + assert.True(t, isIdempotentMethod(method), "%s should be idempotent", method) + } + + for _, method := range []string{http.MethodPost, http.MethodPatch} { + assert.False(t, isIdempotentMethod(method), "%s should not be idempotent", method) + } +} + func TestWithRetry_ClientIntegration(t *testing.T) { callCount := 0 ts := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) {