From 7c22ff16f912c0b25b19f1053d3dd1ff89cbe871 Mon Sep 17 00:00:00 2001 From: Go Expert Coding agent Date: Fri, 13 Mar 2026 10:17:35 +0000 Subject: [PATCH 1/3] fix: make HTTP 429 retryable when no fallback model, respect Retry-After header MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a model returns HTTP 429 (Too Many Requests): - If fallback models are configured: skip to next fallback immediately (existing behavior preserved) - If no fallbacks are configured: retry the same model using the Retry-After header duration (capped at 60s), falling back to exponential backoff if the header is absent HTTP 500 was already handled as retryable via IsRetryableStatusCode — no change needed. New additions to pkg/modelerrors: - ExtractRetryAfter(err): extracts Retry-After duration from *anthropic.Error and *openai.Error responses - parseRetryAfterHeader(value): parses integer-seconds or HTTP-date Retry-After values - ClassifyModelError(err): unified classifier returning (retryable, rateLimited bool, retryAfter time.Duration) - MaxRetryAfterWait constant (60s cap) pkg/runtime/fallback.go changes: - Replace both IsRetryableModelError call sites with ClassifyModelError - Add handleModelError helper to avoid duplicating classification logic - 429 with fallbacks → break to next model (unchanged behavior) - 429 without fallbacks → sleep(min(retryAfter, 60s) or backoff), then continue Fixes #2095 Assisted-By: docker-agent --- pkg/modelerrors/modelerrors.go | 97 ++++++++++++++ pkg/modelerrors/modelerrors_test.go | 193 ++++++++++++++++++++++++++++ pkg/runtime/fallback.go | 137 ++++++++++++++------ pkg/runtime/fallback_test.go | 145 +++++++++++++++++++++ 4 files changed, 532 insertions(+), 40 deletions(-) diff --git a/pkg/modelerrors/modelerrors.go b/pkg/modelerrors/modelerrors.go index 3efd4d705..340b31c8a 100644 --- a/pkg/modelerrors/modelerrors.go +++ b/pkg/modelerrors/modelerrors.go @@ -11,11 +11,14 @@ import ( "log/slog" "math/rand" "net" + "net/http" "regexp" + "strconv" "strings" "time" "github.com/anthropics/anthropic-sdk-go" + openai "github.com/openai/openai-go/v3" "google.golang.org/genai" ) @@ -27,6 +30,14 @@ const ( backoffJitter = 0.1 ) +// maxRetryAfterWait caps how long we'll honor a Retry-After header to prevent +// a misbehaving server from blocking the agent for an unreasonable amount of time. +const maxRetryAfterWait = 60 * time.Second + +// MaxRetryAfterWait is the exported cap for Retry-After header values. +// See maxRetryAfterWait. +const MaxRetryAfterWait = maxRetryAfterWait + // Default fallback configuration. const ( // DefaultRetries is the default number of retries per model with exponential @@ -296,6 +307,92 @@ func IsRetryableModelError(err error) bool { return false } +// ExtractRetryAfter extracts the Retry-After duration from an HTTP error response. +// Works with Anthropic and OpenAI SDK error types that expose *http.Response. +// Returns 0 if no Retry-After header is present or the error type is unsupported. +func ExtractRetryAfter(err error) time.Duration { + var resp *http.Response + + if anthropicErr, ok := errors.AsType[*anthropic.Error](err); ok { + resp = anthropicErr.Response + } else if openaiErr, ok := errors.AsType[*openai.Error](err); ok { + resp = openaiErr.Response + } + + if resp == nil { + return 0 + } + + return parseRetryAfterHeader(resp.Header.Get("Retry-After")) +} + +// parseRetryAfterHeader parses the Retry-After header value. +// Supports both seconds (integer) and HTTP-date formats per RFC 7231 §7.1.3. +// Returns 0 if the value is empty, invalid, or results in a non-positive duration. +func parseRetryAfterHeader(value string) time.Duration { + if value == "" { + return 0 + } + // Try integer seconds first (most common for rate limits) + if seconds, err := strconv.Atoi(value); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + // Try HTTP-date format + if t, err := http.ParseTime(value); err == nil { + d := time.Until(t) + if d > 0 { + return d + } + } + return 0 +} + +// ClassifyModelError classifies an error for the retry/fallback decision. +// +// Returns: +// - retryable=true: retry the SAME model with backoff (5xx, timeouts) +// - rateLimited=true: it's a 429 error; caller decides retry vs fallback based on config +// - retryAfter: suggested wait from Retry-After header (only set when rateLimited=true) +// +// When rateLimited=true, retryable is always false — the caller is responsible for +// deciding whether to retry (when no fallback is configured) or skip to the next +// model (when fallbacks are available). +// +// IsRetryableModelError and IsRetryableStatusCode are kept unchanged for backward +// compatibility. This function is the authoritative classifier used by the retry loop. +func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time.Duration) { + if err == nil { + return false, false, 0 + } + + // Context cancellation and deadline are never retryable. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false, false, 0 + } + + // Context overflow errors are never retryable — retrying the same oversized + // payload will always fail. + if IsContextOverflowError(err) { + return false, false, 0 + } + + statusCode := ExtractHTTPStatusCode(err) + + // 429: rate limited — caller decides retry-vs-fallback based on config. + if statusCode == http.StatusTooManyRequests { + return false, true, ExtractRetryAfter(err) + } + + // Known retryable status codes (5xx, 408, 529). + if statusCode != 0 { + return IsRetryableStatusCode(statusCode), false, 0 + } + + // No structured status code — fall back to IsRetryableModelError for net.Error + // and message-pattern matching. + return IsRetryableModelError(err), false, 0 +} + // CalculateBackoff returns the backoff duration for a given attempt (0-indexed). // Uses exponential backoff with jitter. func CalculateBackoff(attempt int) time.Duration { diff --git a/pkg/modelerrors/modelerrors_test.go b/pkg/modelerrors/modelerrors_test.go index 056a30755..acb0ee2db 100644 --- a/pkg/modelerrors/modelerrors_test.go +++ b/pkg/modelerrors/modelerrors_test.go @@ -5,9 +5,13 @@ import ( "errors" "fmt" "net" + "net/http" + "net/http/httptest" "testing" "time" + "github.com/anthropics/anthropic-sdk-go" + openai "github.com/openai/openai-go/v3" "github.com/stretchr/testify/assert" ) @@ -278,3 +282,192 @@ func TestFormatError(t *testing.T) { assert.Equal(t, "authentication failed", FormatError(err)) }) } + +// makeAnthropicError creates an *anthropic.Error with the given status code and +// optional Retry-After header value. Used for testing ExtractRetryAfter. +func makeAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // anthropic.Error.Error() dereferences Request, so we must provide a non-nil one. + req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody) + return &anthropic.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +// makeOpenAIError creates an *openai.Error with the given status code and +// optional Retry-After header value. Used for testing ExtractRetryAfter. +func makeOpenAIError(statusCode int, retryAfterValue string) *openai.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // openai.Error.Error() dereferences Request, so we must provide a non-nil one. + req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody) + return &openai.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +func TestParseRetryAfterHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + expected time.Duration + }{ + {name: "empty", value: "", expected: 0}, + {name: "zero seconds", value: "0", expected: 0}, + {name: "negative seconds", value: "-1", expected: 0}, + {name: "invalid string", value: "foo", expected: 0}, + {name: "5 seconds", value: "5", expected: 5 * time.Second}, + {name: "30 seconds", value: "30", expected: 30 * time.Second}, + {name: "120 seconds", value: "120", expected: 120 * time.Second}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseRetryAfterHeader(tt.value) + assert.Equal(t, tt.expected, got, "parseRetryAfterHeader(%q)", tt.value) + }) + } + + t.Run("HTTP-date in the future", func(t *testing.T) { + t.Parallel() + // Use a time 10 seconds in the future + future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) + got := parseRetryAfterHeader(future) + assert.Greater(t, got, 0*time.Second, "should return positive duration for future HTTP-date") + assert.LessOrEqual(t, got, 11*time.Second, "should not exceed ~10s for near-future date") + }) + + t.Run("HTTP-date in the past", func(t *testing.T) { + t.Parallel() + past := time.Now().Add(-10 * time.Second).UTC().Format(http.TimeFormat) + got := parseRetryAfterHeader(past) + assert.Equal(t, 0*time.Second, got, "should return 0 for past HTTP-date") + }) +} + +func TestExtractRetryAfter(t *testing.T) { + t.Parallel() + + t.Run("nil error returns 0", func(t *testing.T) { + t.Parallel() + assert.Equal(t, time.Duration(0), ExtractRetryAfter(nil)) + }) + + t.Run("plain error returns 0", func(t *testing.T) { + t.Parallel() + assert.Equal(t, time.Duration(0), ExtractRetryAfter(errors.New("some error"))) + }) + + t.Run("anthropic error with Retry-After seconds", func(t *testing.T) { + t.Parallel() + err := makeAnthropicError(429, "15") + assert.Equal(t, 15*time.Second, ExtractRetryAfter(err)) + }) + + t.Run("anthropic error without Retry-After header", func(t *testing.T) { + t.Parallel() + err := makeAnthropicError(429, "") + assert.Equal(t, time.Duration(0), ExtractRetryAfter(err)) + }) + + t.Run("openai error with Retry-After seconds", func(t *testing.T) { + t.Parallel() + err := makeOpenAIError(429, "30") + assert.Equal(t, 30*time.Second, ExtractRetryAfter(err)) + }) + + t.Run("openai error without Retry-After header", func(t *testing.T) { + t.Parallel() + err := makeOpenAIError(429, "") + assert.Equal(t, time.Duration(0), ExtractRetryAfter(err)) + }) + + t.Run("wrapped anthropic error", func(t *testing.T) { + t.Parallel() + anthropicErr := makeAnthropicError(429, "5") + wrapped := fmt.Errorf("model failed: %w", anthropicErr) + assert.Equal(t, 5*time.Second, ExtractRetryAfter(wrapped)) + }) +} + +func TestClassifyModelError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantRetryable bool + wantRateLimited bool + wantRetryAfterGT time.Duration // retryAfter should be > this (0 means just checking it's >=0) + }{ + {name: "nil", err: nil, wantRetryable: false, wantRateLimited: false}, + {name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false}, + {name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false}, + {name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false}, + // 429 rate limit cases + {name: "429 message only", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true}, + {name: "429 anthropic error no header", err: makeAnthropicError(429, ""), wantRetryable: false, wantRateLimited: true}, + {name: "429 openai error no header", err: makeOpenAIError(429, ""), wantRetryable: false, wantRateLimited: true}, + {name: "500 openai error", err: makeOpenAIError(500, ""), wantRetryable: true, wantRateLimited: false}, + // Retryable server errors + {name: "500 message", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false}, + {name: "500 anthropic error", err: makeAnthropicError(500, ""), wantRetryable: true, wantRateLimited: false}, + {name: "502 bad gateway", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false}, + {name: "503 service unavailable", err: errors.New("503 service unavailable"), wantRetryable: true, wantRateLimited: false}, + {name: "504 gateway timeout", err: errors.New("504 gateway timeout"), wantRetryable: true, wantRateLimited: false}, + {name: "529 overloaded", err: makeAnthropicError(529, ""), wantRetryable: true, wantRateLimited: false}, + {name: "408 timeout", err: makeAnthropicError(408, ""), wantRetryable: true, wantRateLimited: false}, + // Non-retryable client errors + {name: "400 bad request", err: makeAnthropicError(400, ""), wantRetryable: false, wantRateLimited: false}, + {name: "401 unauthorized", err: makeAnthropicError(401, ""), wantRetryable: false, wantRateLimited: false}, + {name: "403 forbidden", err: makeAnthropicError(403, ""), wantRetryable: false, wantRateLimited: false}, + // Network errors + {name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + retryable, rateLimited, retryAfter := ClassifyModelError(tt.err) + assert.Equal(t, tt.wantRetryable, retryable, "retryable mismatch") + assert.Equal(t, tt.wantRateLimited, rateLimited, "rateLimited mismatch") + assert.GreaterOrEqual(t, retryAfter, time.Duration(0), "retryAfter should never be negative") + }) + } + + t.Run("429 with Retry-After header propagated", func(t *testing.T) { + t.Parallel() + err := makeAnthropicError(429, "20") + retryable, rateLimited, retryAfter := ClassifyModelError(err) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 20*time.Second, retryAfter) + }) + + t.Run("429 openai with Retry-After header", func(t *testing.T) { + t.Parallel() + err := makeOpenAIError(429, "10") + retryable, rateLimited, retryAfter := ClassifyModelError(err) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 10*time.Second, retryAfter) + }) +} diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index b60c6deef..7d1cd86f7 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -212,12 +212,14 @@ func (r *LocalRuntime) tryModelWithFallback( var lastErr error primaryFailedWithNonRetryable := false + hasFallbacks := len(fallbackModels) > 0 for chainIdx := startIndex; chainIdx < len(modelChain); chainIdx++ { modelEntry := modelChain[chainIdx] // Each model in the chain gets (1 + retries) attempts for retryable errors. - // Non-retryable errors (429, 4xx) skip immediately to the next model. + // Non-retryable errors (429 with fallbacks, 4xx) skip immediately to the next model. + // 429 without fallbacks is retried directly on the same model. maxAttempts := 1 + fallbackRetries for attempt := range maxAttempts { @@ -270,28 +272,12 @@ func (r *LocalRuntime) tryModelWithFallback( return streamResult{}, nil, err } - // Check if error is retryable - if !modelerrors.IsRetryableModelError(err) { - slog.Error("Non-retryable error creating stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "error", err) - - // Track if primary failed with non-retryable error - if !modelEntry.isFallback { - primaryFailedWithNonRetryable = true - } - - // Skip to next model in chain + decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) + if decision == retryDecisionReturn { + return streamResult{}, nil, ctx.Err() + } else if decision == retryDecisionBreak { break } - - slog.Warn("Retryable error creating stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "max_attempts", maxAttempts, - "error", err) continue } @@ -306,27 +292,12 @@ func (r *LocalRuntime) tryModelWithFallback( return streamResult{}, nil, err } - // Check if stream error is retryable - if !modelerrors.IsRetryableModelError(err) { - slog.Error("Non-retryable error handling stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "error", err) - - // Track if primary failed with non-retryable error - if !modelEntry.isFallback { - primaryFailedWithNonRetryable = true - } - + decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) + if decision == retryDecisionReturn { + return streamResult{}, nil, ctx.Err() + } else if decision == retryDecisionBreak { break } - - slog.Warn("Retryable error handling stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "max_attempts", maxAttempts, - "error", err) continue } @@ -359,3 +330,89 @@ func (r *LocalRuntime) tryModelWithFallback( } return streamResult{}, nil, errors.New("all models failed with unknown error") } + +// retryDecision is the outcome of handleModelError. +type retryDecision int + +const ( + // retryDecisionContinue means retry the same model (backoff already applied). + retryDecisionContinue retryDecision = iota + // retryDecisionBreak means skip to the next model in the fallback chain. + retryDecisionBreak + // retryDecisionReturn means context was cancelled; return immediately. + retryDecisionReturn +) + +// handleModelError classifies err and decides what to do next: +// - retryDecisionReturn — context cancelled while sleeping; caller returns ctx.Err() +// - retryDecisionBreak — non-retryable error or 429 with fallbacks; skip to next model +// - retryDecisionContinue — retryable error or 429 without fallbacks; retry same model +// +// Side-effect: sets *primaryFailedWithNonRetryable when the primary model fails with a +// non-retryable (or rate-limited-with-fallbacks) error. +func (r *LocalRuntime) handleModelError( + ctx context.Context, + err error, + a *agent.Agent, + modelEntry modelWithFallback, + attempt int, + hasFallbacks bool, + primaryFailedWithNonRetryable *bool, +) retryDecision { + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(err) + + if rateLimited { + if hasFallbacks { + // Fallbacks available → skip to next model immediately (existing behaviour). + slog.Warn("Rate limited with fallbacks available, skipping to next model", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "retry_after", retryAfter) + if !modelEntry.isFallback { + *primaryFailedWithNonRetryable = true + } + return retryDecisionBreak + } + + // No fallbacks → retry same model after honouring Retry-After (or backoff). + waitDuration := retryAfter + if waitDuration <= 0 { + waitDuration = modelerrors.CalculateBackoff(attempt) + } else if waitDuration > modelerrors.MaxRetryAfterWait { + slog.Warn("Retry-After exceeds maximum, capping", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "retry_after", retryAfter, + "max", modelerrors.MaxRetryAfterWait) + waitDuration = modelerrors.MaxRetryAfterWait + } + slog.Warn("Rate limited without fallbacks, retrying with wait", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "attempt", attempt+1, + "wait", waitDuration, + "retry_after_from_header", retryAfter > 0) + if !modelerrors.SleepWithContext(ctx, waitDuration) { + return retryDecisionReturn + } + return retryDecisionContinue + } + + if !retryable { + slog.Error("Non-retryable error from model", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "error", err) + if !modelEntry.isFallback { + *primaryFailedWithNonRetryable = true + } + return retryDecisionBreak + } + + slog.Warn("Retryable error from model", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "attempt", attempt+1, + "error", err) + return retryDecisionContinue +} diff --git a/pkg/runtime/fallback_test.go b/pkg/runtime/fallback_test.go index 342f215e2..d6688bb06 100644 --- a/pkg/runtime/fallback_test.go +++ b/pkg/runtime/fallback_test.go @@ -468,3 +468,148 @@ func TestFallbackModelsClonedWithThinkingEnabled(t *testing.T) { "BaseConfig() should be called on fallback provider when thinking is enabled") }) } + +func TestFallback429WithFallbacksSkipsToNextModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary gets rate limited; with fallbacks configured it should skip immediately. + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + successStream := newStreamBuilder(). + AddContent("Success from fallback"). + AddStopWithUsage(10, 5). + Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(5), // many retries — 429 should NOT use them + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 With Fallback Skip Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success from fallback" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content from fallback") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — 429 with fallbacks should skip immediately") + }) +} + +func TestFallback429WithoutFallbacksRetriesSameModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary gets rate limited; with no fallbacks configured it should retry. + successStream := newStreamBuilder(). + AddContent("Success after rate limit"). + AddStopWithUsage(10, 5). + Build() + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 2, // fail twice with 429, then succeed + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models configured + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 No Fallback Retry Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after rate limit" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after rate limit retries") + assert.Equal(t, 3, primary.callCount, "primary should be called 3 times: 2 failures + 1 success") + }) +} + +func TestFallback429WithoutFallbacksExhaustsRetries(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary always returns 429, no fallbacks — should fail after all retries. + primary := &failingProvider{ + id: "primary/always-rate-limited", + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(2), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 No Fallback Exhaust Test" + + var gotError bool + for ev := range rt.RunStream(t.Context(), sess) { + if _, ok := ev.(*ErrorEvent); ok { + gotError = true + } + } + assert.True(t, gotError, "should receive an error when all retries exhausted") + }) +} + +func TestFallback500RetryableWithBackoff(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary returns 500 (retryable), no fallbacks — should retry with backoff. + successStream := newStreamBuilder(). + AddContent("Success after 500"). + AddStopWithUsage(10, 5). + Build() + primary := &countingProvider{ + id: "primary/server-error", + failCount: 1, + err: errors.New("500 internal server error"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(2), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "500 Retry Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after 500" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after 500 retry") + assert.Equal(t, 2, primary.callCount, "primary should be called twice: 1 failure + 1 success") + }) +} From 3d44b7cd409fe4cff72407d3db8b47829dc07e45 Mon Sep 17 00:00:00 2001 From: Go Expert Coding agent Date: Fri, 13 Mar 2026 11:07:50 +0000 Subject: [PATCH 2/3] refactor: use typed StatusError for retry metadata, providers wrap errors at Recv() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback from rumpl and simonferquel on PR #2096. The previous approach extracted Retry-After metadata in a dispatcher function that type-switched on provider SDK error types (`*anthropic.Error`, `*openai.Error`). This kept all providers coupled to the modelerrors package and duplicated parsing logic. New design: 1. `modelerrors.StatusError` — a thin typed wrapper carrying: - StatusCode int - RetryAfter time.Duration (parsed from Retry-After header) - Err error (original SDK error, accessible via Unwrap) 2. `modelerrors.WrapHTTPError(statusCode, resp, err)` — wraps any provider error in StatusError; parses Retry-After from resp.Header if non-nil. 3. `modelerrors.ParseRetryAfterHeader(value)` — shared RFC 7231 §7.1.3 parser (integer seconds + HTTP-date). No SDK imports. 4. `modelerrors.ClassifyModelError(err)` — single-arg; uses errors.As to find *StatusError in the chain; falls back to message/pattern matching for providers not yet wrapped (e.g. Bedrock). 5. Three new wrap.go files (one per provider family), each a single function: - pkg/model/provider/anthropic/wrap.go — wrapAnthropicError(*anthropic.Error) - pkg/model/provider/oaistream/wrap.go — WrapOpenAIError(*openai.Error) (exported) - pkg/model/provider/gemini/wrap.go — wrapGeminiError(*genai.APIError) 6. Wrap sites — one-line change at Recv() in each adapter: - anthropic/adapter.go, beta_adapter.go - oaistream/adapter.go - openai/response_stream.go (reuses oaistream.WrapOpenAIError) - gemini/adapter.go No changes to fallback.go logic — ClassifyModelError signature stays `(err error)` and retry/rate-limit decisions are unchanged. Assisted-By: docker-agent --- pkg/model/provider/anthropic/adapter.go | 2 +- pkg/model/provider/anthropic/beta_adapter.go | 2 +- pkg/model/provider/anthropic/wrap.go | 23 +++ pkg/model/provider/anthropic/wrap_test.go | 105 ++++++++++ pkg/model/provider/gemini/adapter.go | 2 +- pkg/model/provider/gemini/wrap.go | 26 +++ pkg/model/provider/oaistream/adapter.go | 2 +- pkg/model/provider/oaistream/wrap.go | 24 +++ pkg/model/provider/oaistream/wrap_test.go | 103 ++++++++++ pkg/model/provider/openai/response_stream.go | 3 +- pkg/modelerrors/modelerrors.go | 101 ++++++---- pkg/modelerrors/modelerrors_test.go | 202 +++++++++---------- 12 files changed, 445 insertions(+), 150 deletions(-) create mode 100644 pkg/model/provider/anthropic/wrap.go create mode 100644 pkg/model/provider/anthropic/wrap_test.go create mode 100644 pkg/model/provider/gemini/wrap.go create mode 100644 pkg/model/provider/oaistream/wrap.go create mode 100644 pkg/model/provider/oaistream/wrap_test.go diff --git a/pkg/model/provider/anthropic/adapter.go b/pkg/model/provider/anthropic/adapter.go index 7133fdee5..d508366f1 100644 --- a/pkg/model/provider/anthropic/adapter.go +++ b/pkg/model/provider/anthropic/adapter.go @@ -70,7 +70,7 @@ func isContextLengthError(err error) bool { func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { ok, err := a.next() if !ok { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, wrapAnthropicError(err) } event := a.stream.Current() diff --git a/pkg/model/provider/anthropic/beta_adapter.go b/pkg/model/provider/anthropic/beta_adapter.go index ca884b72b..1f9b5780f 100644 --- a/pkg/model/provider/anthropic/beta_adapter.go +++ b/pkg/model/provider/anthropic/beta_adapter.go @@ -34,7 +34,7 @@ func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRaw func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) { ok, err := a.next() if !ok { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, wrapAnthropicError(err) } event := a.stream.Current() diff --git a/pkg/model/provider/anthropic/wrap.go b/pkg/model/provider/anthropic/wrap.go new file mode 100644 index 000000000..fc74eba5d --- /dev/null +++ b/pkg/model/provider/anthropic/wrap.go @@ -0,0 +1,23 @@ +package anthropic + +import ( + "errors" + + "github.com/anthropics/anthropic-sdk-go" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// wrapAnthropicError wraps an Anthropic SDK error in a *modelerrors.StatusError +// to carry HTTP status code and Retry-After metadata for the retry loop. +// Non-Anthropic errors (e.g. io.EOF, network errors) pass through unchanged. +func wrapAnthropicError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*anthropic.Error](err) + if !ok { + return err + } + return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err) +} diff --git a/pkg/model/provider/anthropic/wrap_test.go b/pkg/model/provider/anthropic/wrap_test.go new file mode 100644 index 000000000..25f27d4eb --- /dev/null +++ b/pkg/model/provider/anthropic/wrap_test.go @@ -0,0 +1,105 @@ +package anthropic + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// makeTestAnthropicError creates an *anthropic.Error with the given status code and +// optional Retry-After header value for testing. +func makeTestAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // anthropic.Error.Error() dereferences Request, so we must provide a non-nil one. + req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody) + return &anthropic.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +func TestWrapAnthropicError(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + assert.NoError(t, wrapAnthropicError(nil)) + }) + + t.Run("non-anthropic error passes through unchanged", func(t *testing.T) { + t.Parallel() + orig := errors.New("some network error") + result := wrapAnthropicError(orig) + assert.Equal(t, orig, result) + var se *modelerrors.StatusError + assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + // Original error still accessible + assert.ErrorIs(t, result, apiErr) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "20") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 20*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps with correct status code", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(500, "") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + }) + + t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "15") + result := wrapAnthropicError(apiErr) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 15*time.Second, retryAfter) + }) + + t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "5") + wrapped := fmt.Errorf("stream error: %w", wrapAnthropicError(apiErr)) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(wrapped) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 5*time.Second, retryAfter) + }) +} diff --git a/pkg/model/provider/gemini/adapter.go b/pkg/model/provider/gemini/adapter.go index e459b8e9d..f2cb64f0d 100644 --- a/pkg/model/provider/gemini/adapter.go +++ b/pkg/model/provider/gemini/adapter.go @@ -143,7 +143,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { } if res.err != nil { - return chat.MessageStreamResponse{}, res.err + return chat.MessageStreamResponse{}, wrapGeminiError(res.err) } // Build response diff --git a/pkg/model/provider/gemini/wrap.go b/pkg/model/provider/gemini/wrap.go new file mode 100644 index 000000000..ea23ee796 --- /dev/null +++ b/pkg/model/provider/gemini/wrap.go @@ -0,0 +1,26 @@ +package gemini + +import ( + "errors" + + "google.golang.org/genai" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// wrapGeminiError wraps a Gemini SDK error in a *modelerrors.StatusError +// to carry HTTP status code metadata for the retry loop. +// Gemini's *genai.APIError does not expose *http.Response, so no Retry-After +// header extraction is possible; the RetryAfter field will be zero. +// Non-Gemini errors (e.g. io.EOF, network errors) pass through unchanged. +func wrapGeminiError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*genai.APIError](err) + if !ok { + return err + } + // Pass nil for resp — Gemini doesn't expose *http.Response. + return modelerrors.WrapHTTPError(apiErr.Code, nil, err) +} diff --git a/pkg/model/provider/oaistream/adapter.go b/pkg/model/provider/oaistream/adapter.go index 3716c79f1..c3e12a03d 100644 --- a/pkg/model/provider/oaistream/adapter.go +++ b/pkg/model/provider/oaistream/adapter.go @@ -35,7 +35,7 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { if !a.stream.Next() { err := a.stream.Err() if err != nil { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, WrapOpenAIError(err) } return chat.MessageStreamResponse{}, io.EOF } diff --git a/pkg/model/provider/oaistream/wrap.go b/pkg/model/provider/oaistream/wrap.go new file mode 100644 index 000000000..de917233a --- /dev/null +++ b/pkg/model/provider/oaistream/wrap.go @@ -0,0 +1,24 @@ +package oaistream + +import ( + "errors" + + openaisdk "github.com/openai/openai-go/v3" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// WrapOpenAIError wraps an OpenAI SDK error in a *modelerrors.StatusError +// to carry HTTP status code and Retry-After metadata for the retry loop. +// Non-OpenAI errors (e.g. io.EOF, network errors) pass through unchanged. +// Exported so openai/response_stream.go can reuse it without duplication. +func WrapOpenAIError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*openaisdk.Error](err) + if !ok { + return err + } + return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err) +} diff --git a/pkg/model/provider/oaistream/wrap_test.go b/pkg/model/provider/oaistream/wrap_test.go new file mode 100644 index 000000000..eeb43608d --- /dev/null +++ b/pkg/model/provider/oaistream/wrap_test.go @@ -0,0 +1,103 @@ +package oaistream + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + openaisdk "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// makeTestOpenAIError creates an *openai.Error with the given status code and +// optional Retry-After header value for testing. +func makeTestOpenAIError(statusCode int, retryAfterValue string) *openaisdk.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // openai.Error.Error() dereferences Request, so we must provide a non-nil one. + req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody) + return &openaisdk.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +func TestWrapOpenAIError(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + assert.NoError(t, WrapOpenAIError(nil)) + }) + + t.Run("non-openai error passes through unchanged", func(t *testing.T) { + t.Parallel() + orig := errors.New("some network error") + result := WrapOpenAIError(orig) + assert.Equal(t, orig, result) + var se *modelerrors.StatusError + assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + // Original error still accessible + assert.ErrorIs(t, result, apiErr) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "30") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 30*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps with correct status code", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(500, "") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + }) + + t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "10") + result := WrapOpenAIError(apiErr) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 10*time.Second, retryAfter) + }) + + t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(500, "") + wrapped := fmt.Errorf("stream error: %w", WrapOpenAIError(apiErr)) + retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped) + assert.True(t, retryable) + assert.False(t, rateLimited) + }) +} diff --git a/pkg/model/provider/openai/response_stream.go b/pkg/model/provider/openai/response_stream.go index 7e9040024..065cba32c 100644 --- a/pkg/model/provider/openai/response_stream.go +++ b/pkg/model/provider/openai/response_stream.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3/responses" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider/oaistream" "github.com/docker/docker-agent/pkg/tools" ) @@ -33,7 +34,7 @@ func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamE func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) { if !a.stream.Next() { if err := a.stream.Err(); err != nil { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, oaistream.WrapOpenAIError(err) } return chat.MessageStreamResponse{}, io.EOF } diff --git a/pkg/modelerrors/modelerrors.go b/pkg/modelerrors/modelerrors.go index 340b31c8a..4283262e7 100644 --- a/pkg/modelerrors/modelerrors.go +++ b/pkg/modelerrors/modelerrors.go @@ -18,25 +18,58 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" - openai "github.com/openai/openai-go/v3" "google.golang.org/genai" ) -// Backoff configuration constants. +// Backoff and retry-after configuration constants. const ( backoffBaseDelay = 200 * time.Millisecond backoffMaxDelay = 2 * time.Second backoffFactor = 2.0 backoffJitter = 0.1 + + // MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent + // a misbehaving server from blocking the agent for an unreasonable amount of time. + MaxRetryAfterWait = 60 * time.Second ) -// maxRetryAfterWait caps how long we'll honor a Retry-After header to prevent -// a misbehaving server from blocking the agent for an unreasonable amount of time. -const maxRetryAfterWait = 60 * time.Second +// StatusError wraps an HTTP API error with structured metadata for retry decisions. +// Providers wrap SDK errors in this type so the retry loop can use errors.As +// to extract status code and Retry-After without importing provider-specific SDKs. +type StatusError struct { + // StatusCode is the HTTP status code from the provider's API response. + StatusCode int + // RetryAfter is the parsed Retry-After header duration. Zero if absent. + RetryAfter time.Duration + // Err is the original error from the provider SDK. + Err error +} + +func (e *StatusError) Error() string { + return e.Err.Error() +} -// MaxRetryAfterWait is the exported cap for Retry-After header values. -// See maxRetryAfterWait. -const MaxRetryAfterWait = maxRetryAfterWait +func (e *StatusError) Unwrap() error { + return e.Err +} + +// WrapHTTPError wraps err in a *StatusError carrying the HTTP status code and +// parsed Retry-After header from resp. Returns err unchanged if statusCode < 400 +// or err is nil. Pass resp=nil when no *http.Response is available. +func WrapHTTPError(statusCode int, resp *http.Response, err error) error { + if err == nil || statusCode < 400 { + return err + } + var retryAfter time.Duration + if resp != nil { + retryAfter = ParseRetryAfterHeader(resp.Header.Get("Retry-After")) + } + return &StatusError{ + StatusCode: statusCode, + RetryAfter: retryAfter, + Err: err, + } +} // Default fallback configuration. const ( @@ -307,29 +340,10 @@ func IsRetryableModelError(err error) bool { return false } -// ExtractRetryAfter extracts the Retry-After duration from an HTTP error response. -// Works with Anthropic and OpenAI SDK error types that expose *http.Response. -// Returns 0 if no Retry-After header is present or the error type is unsupported. -func ExtractRetryAfter(err error) time.Duration { - var resp *http.Response - - if anthropicErr, ok := errors.AsType[*anthropic.Error](err); ok { - resp = anthropicErr.Response - } else if openaiErr, ok := errors.AsType[*openai.Error](err); ok { - resp = openaiErr.Response - } - - if resp == nil { - return 0 - } - - return parseRetryAfterHeader(resp.Header.Get("Retry-After")) -} - -// parseRetryAfterHeader parses the Retry-After header value. +// ParseRetryAfterHeader parses a Retry-After header value. // Supports both seconds (integer) and HTTP-date formats per RFC 7231 §7.1.3. // Returns 0 if the value is empty, invalid, or results in a non-positive duration. -func parseRetryAfterHeader(value string) time.Duration { +func ParseRetryAfterHeader(value string) time.Duration { if value == "" { return 0 } @@ -349,17 +363,18 @@ func parseRetryAfterHeader(value string) time.Duration { // ClassifyModelError classifies an error for the retry/fallback decision. // +// If the error chain contains a *StatusError (wrapped by provider adapters), +// its StatusCode and RetryAfter fields are used directly — no provider-specific +// imports needed in the caller. +// // Returns: // - retryable=true: retry the SAME model with backoff (5xx, timeouts) // - rateLimited=true: it's a 429 error; caller decides retry vs fallback based on config -// - retryAfter: suggested wait from Retry-After header (only set when rateLimited=true) +// - retryAfter: Retry-After duration from the provider (only set for 429) // // When rateLimited=true, retryable is always false — the caller is responsible for // deciding whether to retry (when no fallback is configured) or skip to the next // model (when fallbacks are available). -// -// IsRetryableModelError and IsRetryableStatusCode are kept unchanged for backward -// compatibility. This function is the authoritative classifier used by the retry loop. func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time.Duration) { if err == nil { return false, false, 0 @@ -376,20 +391,24 @@ func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time return false, false, 0 } - statusCode := ExtractHTTPStatusCode(err) + // Primary path: typed StatusError wrapped by provider adapters. + var statusErr *StatusError + if errors.As(err, &statusErr) { + if statusErr.StatusCode == http.StatusTooManyRequests { + return false, true, statusErr.RetryAfter + } + return IsRetryableStatusCode(statusErr.StatusCode), false, 0 + } - // 429: rate limited — caller decides retry-vs-fallback based on config. + // Fallback: providers that don't yet wrap (e.g. Bedrock), or non-provider + // errors (network, pattern matching). + statusCode := ExtractHTTPStatusCode(err) if statusCode == http.StatusTooManyRequests { - return false, true, ExtractRetryAfter(err) + return false, true, 0 // No Retry-After without StatusError } - - // Known retryable status codes (5xx, 408, 529). if statusCode != 0 { return IsRetryableStatusCode(statusCode), false, 0 } - - // No structured status code — fall back to IsRetryableModelError for net.Error - // and message-pattern matching. return IsRetryableModelError(err), false, 0 } diff --git a/pkg/modelerrors/modelerrors_test.go b/pkg/modelerrors/modelerrors_test.go index acb0ee2db..361c89f34 100644 --- a/pkg/modelerrors/modelerrors_test.go +++ b/pkg/modelerrors/modelerrors_test.go @@ -6,13 +6,11 @@ import ( "fmt" "net" "net/http" - "net/http/httptest" "testing" "time" - "github.com/anthropics/anthropic-sdk-go" - openai "github.com/openai/openai-go/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // mockTimeoutError implements net.Error with Timeout() = true @@ -283,44 +281,6 @@ func TestFormatError(t *testing.T) { }) } -// makeAnthropicError creates an *anthropic.Error with the given status code and -// optional Retry-After header value. Used for testing ExtractRetryAfter. -func makeAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error { - header := http.Header{} - if retryAfterValue != "" { - header.Set("Retry-After", retryAfterValue) - } - resp := httptest.NewRecorder().Result() - resp.StatusCode = statusCode - resp.Header = header - // anthropic.Error.Error() dereferences Request, so we must provide a non-nil one. - req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody) - return &anthropic.Error{ - StatusCode: statusCode, - Response: resp, - Request: req, - } -} - -// makeOpenAIError creates an *openai.Error with the given status code and -// optional Retry-After header value. Used for testing ExtractRetryAfter. -func makeOpenAIError(statusCode int, retryAfterValue string) *openai.Error { - header := http.Header{} - if retryAfterValue != "" { - header.Set("Retry-After", retryAfterValue) - } - resp := httptest.NewRecorder().Result() - resp.StatusCode = statusCode - resp.Header = header - // openai.Error.Error() dereferences Request, so we must provide a non-nil one. - req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody) - return &openai.Error{ - StatusCode: statusCode, - Response: resp, - Request: req, - } -} - func TestParseRetryAfterHeader(t *testing.T) { t.Parallel() @@ -341,16 +301,15 @@ func TestParseRetryAfterHeader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := parseRetryAfterHeader(tt.value) - assert.Equal(t, tt.expected, got, "parseRetryAfterHeader(%q)", tt.value) + got := ParseRetryAfterHeader(tt.value) + assert.Equal(t, tt.expected, got, "ParseRetryAfterHeader(%q)", tt.value) }) } t.Run("HTTP-date in the future", func(t *testing.T) { t.Parallel() - // Use a time 10 seconds in the future future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) - got := parseRetryAfterHeader(future) + got := ParseRetryAfterHeader(future) assert.Greater(t, got, 0*time.Second, "should return positive duration for future HTTP-date") assert.LessOrEqual(t, got, 11*time.Second, "should not exceed ~10s for near-future date") }) @@ -358,53 +317,95 @@ func TestParseRetryAfterHeader(t *testing.T) { t.Run("HTTP-date in the past", func(t *testing.T) { t.Parallel() past := time.Now().Add(-10 * time.Second).UTC().Format(http.TimeFormat) - got := parseRetryAfterHeader(past) + got := ParseRetryAfterHeader(past) assert.Equal(t, 0*time.Second, got, "should return 0 for past HTTP-date") }) } -func TestExtractRetryAfter(t *testing.T) { +func TestStatusError(t *testing.T) { t.Parallel() - t.Run("nil error returns 0", func(t *testing.T) { + t.Run("Error() delegates to wrapped error", func(t *testing.T) { + t.Parallel() + underlying := errors.New("rate limit exceeded") + se := &StatusError{StatusCode: 429, Err: underlying} + assert.Equal(t, underlying.Error(), se.Error()) + }) + + t.Run("Unwrap() allows errors.Is traversal", func(t *testing.T) { t.Parallel() - assert.Equal(t, time.Duration(0), ExtractRetryAfter(nil)) + sentinel := errors.New("sentinel") + se := &StatusError{StatusCode: 500, Err: sentinel} + assert.ErrorIs(t, se, sentinel) }) - t.Run("plain error returns 0", func(t *testing.T) { + t.Run("errors.As finds StatusError in chain", func(t *testing.T) { t.Parallel() - assert.Equal(t, time.Duration(0), ExtractRetryAfter(errors.New("some error"))) + se := &StatusError{StatusCode: 429, RetryAfter: 10 * time.Second, Err: errors.New("rate limited")} + wrapped := fmt.Errorf("outer: %w", se) + var found *StatusError + require.ErrorAs(t, wrapped, &found) + assert.Equal(t, 429, found.StatusCode) + assert.Equal(t, 10*time.Second, found.RetryAfter) + }) +} + +func TestWrapHTTPError(t *testing.T) { + t.Parallel() + + t.Run("nil error returns nil", func(t *testing.T) { + t.Parallel() + require.NoError(t, WrapHTTPError(429, nil, nil)) }) - t.Run("anthropic error with Retry-After seconds", func(t *testing.T) { + t.Run("status < 400 passes through unchanged", func(t *testing.T) { t.Parallel() - err := makeAnthropicError(429, "15") - assert.Equal(t, 15*time.Second, ExtractRetryAfter(err)) + origErr := errors.New("original") + result := WrapHTTPError(200, nil, origErr) + assert.Equal(t, origErr, result) + var se *StatusError + assert.NotErrorAs(t, result, &se) }) - t.Run("anthropic error without Retry-After header", func(t *testing.T) { + t.Run("429 without response has zero RetryAfter", func(t *testing.T) { t.Parallel() - err := makeAnthropicError(429, "") - assert.Equal(t, time.Duration(0), ExtractRetryAfter(err)) + origErr := errors.New("rate limited") + result := WrapHTTPError(429, nil, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + assert.Equal(t, origErr.Error(), se.Error()) }) - t.Run("openai error with Retry-After seconds", func(t *testing.T) { + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { t.Parallel() - err := makeOpenAIError(429, "30") - assert.Equal(t, 30*time.Second, ExtractRetryAfter(err)) + origErr := errors.New("rate limited") + respHeader := http.Header{} + respHeader.Set("Retry-After", "30") + resp := &http.Response{Header: respHeader} + result := WrapHTTPError(429, resp, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 30*time.Second, se.RetryAfter) }) - t.Run("openai error without Retry-After header", func(t *testing.T) { + t.Run("500 wraps correctly", func(t *testing.T) { t.Parallel() - err := makeOpenAIError(429, "") - assert.Equal(t, time.Duration(0), ExtractRetryAfter(err)) + origErr := errors.New("internal server error") + result := WrapHTTPError(500, nil, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) }) - t.Run("wrapped anthropic error", func(t *testing.T) { + t.Run("original error still accessible via Unwrap", func(t *testing.T) { t.Parallel() - anthropicErr := makeAnthropicError(429, "5") - wrapped := fmt.Errorf("model failed: %w", anthropicErr) - assert.Equal(t, 5*time.Second, ExtractRetryAfter(wrapped)) + sentinel := errors.New("sentinel") + result := WrapHTTPError(429, nil, sentinel) + assert.ErrorIs(t, result, sentinel) }) } @@ -412,33 +413,34 @@ func TestClassifyModelError(t *testing.T) { t.Parallel() tests := []struct { - name string - err error - wantRetryable bool - wantRateLimited bool - wantRetryAfterGT time.Duration // retryAfter should be > this (0 means just checking it's >=0) + name string + err error + wantRetryable bool + wantRateLimited bool + wantRetryAfter time.Duration }{ {name: "nil", err: nil, wantRetryable: false, wantRateLimited: false}, {name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false}, {name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false}, {name: "context overflow", err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false}, - // 429 rate limit cases - {name: "429 message only", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true}, - {name: "429 anthropic error no header", err: makeAnthropicError(429, ""), wantRetryable: false, wantRateLimited: true}, - {name: "429 openai error no header", err: makeOpenAIError(429, ""), wantRetryable: false, wantRateLimited: true}, - {name: "500 openai error", err: makeOpenAIError(500, ""), wantRetryable: true, wantRateLimited: false}, - // Retryable server errors - {name: "500 message", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false}, - {name: "500 anthropic error", err: makeAnthropicError(500, ""), wantRetryable: true, wantRateLimited: false}, - {name: "502 bad gateway", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false}, - {name: "503 service unavailable", err: errors.New("503 service unavailable"), wantRetryable: true, wantRateLimited: false}, - {name: "504 gateway timeout", err: errors.New("504 gateway timeout"), wantRetryable: true, wantRateLimited: false}, - {name: "529 overloaded", err: makeAnthropicError(529, ""), wantRetryable: true, wantRateLimited: false}, - {name: "408 timeout", err: makeAnthropicError(408, ""), wantRetryable: true, wantRateLimited: false}, - // Non-retryable client errors - {name: "400 bad request", err: makeAnthropicError(400, ""), wantRetryable: false, wantRateLimited: false}, - {name: "401 unauthorized", err: makeAnthropicError(401, ""), wantRetryable: false, wantRateLimited: false}, - {name: "403 forbidden", err: makeAnthropicError(403, ""), wantRetryable: false, wantRateLimited: false}, + // 429 without StatusError (fallback message-pattern path) + {name: "429 message fallback, no RetryAfter", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, + // 429 via StatusError (primary path) — no Retry-After + {name: "429 StatusError no retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 0, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, + // 429 via StatusError with Retry-After from response header + {name: "429 StatusError with retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 20 * time.Second, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 20 * time.Second}, + // Retryable status codes via StatusError + {name: "500 StatusError", err: &StatusError{StatusCode: 500, Err: errors.New("internal server error")}, wantRetryable: true, wantRateLimited: false}, + {name: "529 StatusError", err: &StatusError{StatusCode: 529, Err: errors.New("overloaded")}, wantRetryable: true, wantRateLimited: false}, + {name: "408 StatusError", err: &StatusError{StatusCode: 408, Err: errors.New("timeout")}, wantRetryable: true, wantRateLimited: false}, + // Retryable fallback path (message-based) + {name: "500 message fallback", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false}, + {name: "502 message fallback", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false}, + // Non-retryable via StatusError + {name: "401 StatusError", err: &StatusError{StatusCode: 401, Err: errors.New("unauthorized")}, wantRetryable: false, wantRateLimited: false}, + {name: "403 StatusError", err: &StatusError{StatusCode: 403, Err: errors.New("forbidden")}, wantRetryable: false, wantRateLimited: false}, + // Non-retryable fallback + {name: "401 message fallback", err: errors.New("401 unauthorized"), wantRetryable: false, wantRateLimited: false}, // Network errors {name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false}, } @@ -446,28 +448,20 @@ func TestClassifyModelError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - retryable, rateLimited, retryAfter := ClassifyModelError(tt.err) + retryable, rateLimited, retryAfterOut := ClassifyModelError(tt.err) assert.Equal(t, tt.wantRetryable, retryable, "retryable mismatch") assert.Equal(t, tt.wantRateLimited, rateLimited, "rateLimited mismatch") - assert.GreaterOrEqual(t, retryAfter, time.Duration(0), "retryAfter should never be negative") + assert.Equal(t, tt.wantRetryAfter, retryAfterOut, "retryAfter mismatch") }) } - t.Run("429 with Retry-After header propagated", func(t *testing.T) { - t.Parallel() - err := makeAnthropicError(429, "20") - retryable, rateLimited, retryAfter := ClassifyModelError(err) - assert.False(t, retryable) - assert.True(t, rateLimited) - assert.Equal(t, 20*time.Second, retryAfter) - }) - - t.Run("429 openai with Retry-After header", func(t *testing.T) { + t.Run("wrapped StatusError is found by errors.As", func(t *testing.T) { t.Parallel() - err := makeOpenAIError(429, "10") - retryable, rateLimited, retryAfter := ClassifyModelError(err) + statusErr := &StatusError{StatusCode: 429, RetryAfter: 15 * time.Second, Err: errors.New("rate limited")} + wrapped := fmt.Errorf("model failed: %w", statusErr) + retryable, rateLimited, retryAfterOut := ClassifyModelError(wrapped) assert.False(t, retryable) assert.True(t, rateLimited) - assert.Equal(t, 10*time.Second, retryAfter) + assert.Equal(t, 15*time.Second, retryAfterOut) }) } From e529f3dfaed079700332d697f7704c4328a23898 Mon Sep 17 00:00:00 2001 From: Go Expert Coding agent Date: Fri, 13 Mar 2026 13:36:00 +0000 Subject: [PATCH 3/3] fix: gate 429 retry behavior behind WithRetryOnRateLimit() opt-in option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses simonferquel's review on PR #2096 — the sleep-and-retry path for HTTP 429 is now opt-in via runtime.WithRetryOnRateLimit(), disabled by default. Existing behavior (fall through to next model or fail) is preserved for all current callers (CLI, API server, A2A, MCP). Decision table for rateLimited=true in handleModelError: retryOnRateLimit=false (default), any fallbacks → break (today's behavior) retryOnRateLimit=true, has fallbacks → break (fallbacks take priority) retryOnRateLimit=true, no fallbacks → sleep(Retry-After or backoff), retry Changes: pkg/runtime/runtime.go: add retryOnRateLimit bool field to LocalRuntime; add WithRetryOnRateLimit() Opt following existing patterns pkg/runtime/fallback.go: gate the 429 sleep-and-retry path behind r.retryOnRateLimit pkg/runtime/fallback_test.go: update existing 429-no-fallback tests to opt in; add 3 gate tests covering all branches of the decision table Assisted-By: docker-agent --- pkg/runtime/fallback.go | 19 ++++-- pkg/runtime/fallback_test.go | 125 +++++++++++++++++++++++++++++++++-- pkg/runtime/runtime.go | 23 +++++++ 3 files changed, 156 insertions(+), 11 deletions(-) diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index 7d1cd86f7..a9caa0456 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -362,19 +362,23 @@ func (r *LocalRuntime) handleModelError( retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(err) if rateLimited { - if hasFallbacks { - // Fallbacks available → skip to next model immediately (existing behaviour). - slog.Warn("Rate limited with fallbacks available, skipping to next model", + // Gate: only retry on 429 if opt-in is enabled AND no fallbacks exist. + // Default behavior (retryOnRateLimit=false) treats 429 as non-retryable, + // identical to today's behavior before this feature was added. + if !r.retryOnRateLimit || hasFallbacks { + slog.Warn("Rate limited, treating as non-retryable", "agent", a.Name(), "model", modelEntry.provider.ID(), - "retry_after", retryAfter) + "retry_on_rate_limit_enabled", r.retryOnRateLimit, + "has_fallbacks", hasFallbacks, + "error", err) if !modelEntry.isFallback { *primaryFailedWithNonRetryable = true } return retryDecisionBreak } - // No fallbacks → retry same model after honouring Retry-After (or backoff). + // Opt-in enabled, no fallbacks → retry same model after honouring Retry-After (or backoff). waitDuration := retryAfter if waitDuration <= 0 { waitDuration = modelerrors.CalculateBackoff(attempt) @@ -386,12 +390,13 @@ func (r *LocalRuntime) handleModelError( "max", modelerrors.MaxRetryAfterWait) waitDuration = modelerrors.MaxRetryAfterWait } - slog.Warn("Rate limited without fallbacks, retrying with wait", + slog.Warn("Rate limited, retrying (opt-in enabled)", "agent", a.Name(), "model", modelEntry.provider.ID(), "attempt", attempt+1, "wait", waitDuration, - "retry_after_from_header", retryAfter > 0) + "retry_after_from_header", retryAfter > 0, + "error", err) if !modelerrors.SleepWithContext(ctx, waitDuration) { return retryDecisionReturn } diff --git a/pkg/runtime/fallback_test.go b/pkg/runtime/fallback_test.go index d6688bb06..614642ca4 100644 --- a/pkg/runtime/fallback_test.go +++ b/pkg/runtime/fallback_test.go @@ -509,7 +509,8 @@ func TestFallback429WithFallbacksSkipsToNextModel(t *testing.T) { func TestFallback429WithoutFallbacksRetriesSameModel(t *testing.T) { synctest.Test(t, func(t *testing.T) { - // Primary gets rate limited; with no fallbacks configured it should retry. + // Primary gets rate limited; with no fallbacks configured it should retry + // when the opt-in is enabled. successStream := newStreamBuilder(). AddContent("Success after rate limit"). AddStopWithUsage(10, 5). @@ -528,7 +529,7 @@ func TestFallback429WithoutFallbacksRetriesSameModel(t *testing.T) { ) tm := team.New(team.WithAgents(root)) - rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) require.NoError(t, err) sess := session.New(session.WithUserMessage("test")) @@ -547,7 +548,7 @@ func TestFallback429WithoutFallbacksRetriesSameModel(t *testing.T) { func TestFallback429WithoutFallbacksExhaustsRetries(t *testing.T) { synctest.Test(t, func(t *testing.T) { - // Primary always returns 429, no fallbacks — should fail after all retries. + // Primary always returns 429, no fallbacks, opt-in enabled — should fail after all retries. primary := &failingProvider{ id: "primary/always-rate-limited", err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), @@ -560,7 +561,7 @@ func TestFallback429WithoutFallbacksExhaustsRetries(t *testing.T) { ) tm := team.New(team.WithAgents(root)) - rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) require.NoError(t, err) sess := session.New(session.WithUserMessage("test")) @@ -613,3 +614,119 @@ func TestFallback500RetryableWithBackoff(t *testing.T) { assert.Equal(t, 2, primary.callCount, "primary should be called twice: 1 failure + 1 success") }) } + +// --- WithRetryOnRateLimit gate tests --- + +func TestRateLimitGate_DisabledNoFallbacks_FailsImmediately(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // With retryOnRateLimit=false (default) and no fallbacks, a 429 should + // be treated as non-retryable and fail immediately without any retry. + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models, no WithRetryOnRateLimit opt-in + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + // Note: WithRetryOnRateLimit() is NOT passed — default off + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Disabled Test" + + var gotError bool + for ev := range rt.RunStream(t.Context(), sess) { + if _, ok := ev.(*ErrorEvent); ok { + gotError = true + } + } + assert.True(t, gotError, "should fail immediately with an error") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — no retry without opt-in") + }) +} + +func TestRateLimitGate_EnabledNoFallbacks_RetriesSameModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // With retryOnRateLimit=true and no fallbacks, a 429 should retry + // the same model until it succeeds or retries are exhausted. + successStream := newStreamBuilder(). + AddContent("Success after rate limit"). + AddStopWithUsage(10, 5). + Build() + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 2, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Enabled No Fallbacks Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after rate limit" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after retrying") + assert.Equal(t, 3, primary.callCount, "primary should be called 3 times: 2 failures + 1 success") + }) +} + +func TestRateLimitGate_EnabledWithFallbacks_SkipsToFallback(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Even with retryOnRateLimit=true, when fallbacks are configured + // a 429 should skip to the fallback immediately (fallbacks take priority). + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + successStream := newStreamBuilder(). + AddContent("Success from fallback"). + AddStopWithUsage(10, 5). + Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(5), + ) + + tm := team.New(team.WithAgents(root)) + // opt-in is enabled, but fallbacks are present → should still skip to fallback + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Enabled With Fallbacks Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success from fallback" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content from fallback") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — fallbacks take priority over retry") + }) +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 82487b202..fd7e35e7a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -199,6 +199,12 @@ type LocalRuntime struct { env []string // Environment variables for hooks execution modelSwitcherCfg *ModelSwitcherConfig + // retryOnRateLimit enables retry-with-backoff for HTTP 429 (rate limit) errors + // when no fallback models are configured. When false (default), 429 errors are + // treated as non-retryable and immediately fail or skip to the next model. + // Library consumers can enable this via WithRetryOnRateLimit(). + retryOnRateLimit bool + // fallbackCooldowns tracks per-agent cooldown state for sticky fallback behavior fallbackCooldowns map[string]*fallbackCooldownState fallbackCooldownsMux sync.RWMutex @@ -264,6 +270,23 @@ func WithEnv(env []string) Opt { } } +// WithRetryOnRateLimit enables automatic retry with backoff for HTTP 429 (rate limit) +// errors when no fallback models are available. When enabled, the runtime will honor +// the Retry-After header from the provider's response to determine wait time before +// retrying, falling back to exponential backoff if the header is absent. +// +// This is off by default. It is intended for library consumers that run agents +// programmatically and prefer to wait for rate limits to clear rather than fail +// immediately. +// +// When fallback models are configured, 429 errors always skip to the next model +// regardless of this setting. +func WithRetryOnRateLimit() Opt { + return func(r *LocalRuntime) { + r.retryOnRateLimit = true + } +} + // NewLocalRuntime creates a new LocalRuntime without the persistence wrapper. // This is useful for testing or when persistence is handled externally. func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {