diff --git a/pkg/model/provider/anthropic/adapter.go b/pkg/model/provider/anthropic/adapter.go index 7133fdee5..d508366f1 100644 --- a/pkg/model/provider/anthropic/adapter.go +++ b/pkg/model/provider/anthropic/adapter.go @@ -70,7 +70,7 @@ func isContextLengthError(err error) bool { func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { ok, err := a.next() if !ok { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, wrapAnthropicError(err) } event := a.stream.Current() diff --git a/pkg/model/provider/anthropic/beta_adapter.go b/pkg/model/provider/anthropic/beta_adapter.go index ca884b72b..1f9b5780f 100644 --- a/pkg/model/provider/anthropic/beta_adapter.go +++ b/pkg/model/provider/anthropic/beta_adapter.go @@ -34,7 +34,7 @@ func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRaw func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) { ok, err := a.next() if !ok { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, wrapAnthropicError(err) } event := a.stream.Current() diff --git a/pkg/model/provider/anthropic/wrap.go b/pkg/model/provider/anthropic/wrap.go new file mode 100644 index 000000000..fc74eba5d --- /dev/null +++ b/pkg/model/provider/anthropic/wrap.go @@ -0,0 +1,23 @@ +package anthropic + +import ( + "errors" + + "github.com/anthropics/anthropic-sdk-go" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// wrapAnthropicError wraps an Anthropic SDK error in a *modelerrors.StatusError +// to carry HTTP status code and Retry-After metadata for the retry loop. +// Non-Anthropic errors (e.g. io.EOF, network errors) pass through unchanged. 
+func wrapAnthropicError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*anthropic.Error](err) + if !ok { + return err + } + return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err) +} diff --git a/pkg/model/provider/anthropic/wrap_test.go b/pkg/model/provider/anthropic/wrap_test.go new file mode 100644 index 000000000..25f27d4eb --- /dev/null +++ b/pkg/model/provider/anthropic/wrap_test.go @@ -0,0 +1,105 @@ +package anthropic + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// makeTestAnthropicError creates an *anthropic.Error with the given status code and +// optional Retry-After header value for testing. +func makeTestAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // anthropic.Error.Error() dereferences Request, so we must provide a non-nil one. 
+ req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody) + return &anthropic.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +func TestWrapAnthropicError(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + assert.NoError(t, wrapAnthropicError(nil)) + }) + + t.Run("non-anthropic error passes through unchanged", func(t *testing.T) { + t.Parallel() + orig := errors.New("some network error") + result := wrapAnthropicError(orig) + assert.Equal(t, orig, result) + var se *modelerrors.StatusError + assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + // Original error still accessible + assert.ErrorIs(t, result, apiErr) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "20") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 20*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps with correct status code", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(500, "") + result := wrapAnthropicError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + }) + + t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "15") + result := wrapAnthropicError(apiErr) + retryable, rateLimited, retryAfter := 
modelerrors.ClassifyModelError(result) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 15*time.Second, retryAfter) + }) + + t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) { + t.Parallel() + apiErr := makeTestAnthropicError(429, "5") + wrapped := fmt.Errorf("stream error: %w", wrapAnthropicError(apiErr)) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(wrapped) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 5*time.Second, retryAfter) + }) +} diff --git a/pkg/model/provider/gemini/adapter.go b/pkg/model/provider/gemini/adapter.go index e459b8e9d..f2cb64f0d 100644 --- a/pkg/model/provider/gemini/adapter.go +++ b/pkg/model/provider/gemini/adapter.go @@ -143,7 +143,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { } if res.err != nil { - return chat.MessageStreamResponse{}, res.err + return chat.MessageStreamResponse{}, wrapGeminiError(res.err) } // Build response diff --git a/pkg/model/provider/gemini/wrap.go b/pkg/model/provider/gemini/wrap.go new file mode 100644 index 000000000..ea23ee796 --- /dev/null +++ b/pkg/model/provider/gemini/wrap.go @@ -0,0 +1,26 @@ +package gemini + +import ( + "errors" + + "google.golang.org/genai" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// wrapGeminiError wraps a Gemini SDK error in a *modelerrors.StatusError +// to carry HTTP status code metadata for the retry loop. +// Gemini's *genai.APIError does not expose *http.Response, so no Retry-After +// header extraction is possible; the RetryAfter field will be zero. +// Non-Gemini errors (e.g. io.EOF, network errors) pass through unchanged. +func wrapGeminiError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*genai.APIError](err) + if !ok { + return err + } + // Pass nil for resp — Gemini doesn't expose *http.Response. 
+ return modelerrors.WrapHTTPError(apiErr.Code, nil, err) +} diff --git a/pkg/model/provider/oaistream/adapter.go b/pkg/model/provider/oaistream/adapter.go index 3716c79f1..c3e12a03d 100644 --- a/pkg/model/provider/oaistream/adapter.go +++ b/pkg/model/provider/oaistream/adapter.go @@ -35,7 +35,7 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) { if !a.stream.Next() { err := a.stream.Err() if err != nil { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, WrapOpenAIError(err) } return chat.MessageStreamResponse{}, io.EOF } diff --git a/pkg/model/provider/oaistream/wrap.go b/pkg/model/provider/oaistream/wrap.go new file mode 100644 index 000000000..de917233a --- /dev/null +++ b/pkg/model/provider/oaistream/wrap.go @@ -0,0 +1,24 @@ +package oaistream + +import ( + "errors" + + openaisdk "github.com/openai/openai-go/v3" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// WrapOpenAIError wraps an OpenAI SDK error in a *modelerrors.StatusError +// to carry HTTP status code and Retry-After metadata for the retry loop. +// Non-OpenAI errors (e.g. io.EOF, network errors) pass through unchanged. +// Exported so openai/response_stream.go can reuse it without duplication. 
+func WrapOpenAIError(err error) error { + if err == nil { + return nil + } + apiErr, ok := errors.AsType[*openaisdk.Error](err) + if !ok { + return err + } + return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err) +} diff --git a/pkg/model/provider/oaistream/wrap_test.go b/pkg/model/provider/oaistream/wrap_test.go new file mode 100644 index 000000000..eeb43608d --- /dev/null +++ b/pkg/model/provider/oaistream/wrap_test.go @@ -0,0 +1,103 @@ +package oaistream + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + openaisdk "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/modelerrors" +) + +// makeTestOpenAIError creates an *openai.Error with the given status code and +// optional Retry-After header value for testing. +func makeTestOpenAIError(statusCode int, retryAfterValue string) *openaisdk.Error { + header := http.Header{} + if retryAfterValue != "" { + header.Set("Retry-After", retryAfterValue) + } + resp := httptest.NewRecorder().Result() + resp.StatusCode = statusCode + resp.Header = header + // openai.Error.Error() dereferences Request, so we must provide a non-nil one. 
+ req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody) + return &openaisdk.Error{ + StatusCode: statusCode, + Response: resp, + Request: req, + } +} + +func TestWrapOpenAIError(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + assert.NoError(t, WrapOpenAIError(nil)) + }) + + t.Run("non-openai error passes through unchanged", func(t *testing.T) { + t.Parallel() + orig := errors.New("some network error") + result := WrapOpenAIError(orig) + assert.Equal(t, orig, result) + var se *modelerrors.StatusError + assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + // Original error still accessible + assert.ErrorIs(t, result, apiErr) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "30") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 30*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps with correct status code", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(500, "") + result := WrapOpenAIError(apiErr) + var se *modelerrors.StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + }) + + t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(429, "10") + result := WrapOpenAIError(apiErr) + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result) + assert.False(t, retryable) + assert.True(t, rateLimited) + 
assert.Equal(t, 10*time.Second, retryAfter) + }) + + t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) { + t.Parallel() + apiErr := makeTestOpenAIError(500, "") + wrapped := fmt.Errorf("stream error: %w", WrapOpenAIError(apiErr)) + retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped) + assert.True(t, retryable) + assert.False(t, rateLimited) + }) +} diff --git a/pkg/model/provider/openai/response_stream.go b/pkg/model/provider/openai/response_stream.go index 7e9040024..065cba32c 100644 --- a/pkg/model/provider/openai/response_stream.go +++ b/pkg/model/provider/openai/response_stream.go @@ -9,6 +9,7 @@ import ( "github.com/openai/openai-go/v3/responses" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider/oaistream" "github.com/docker/docker-agent/pkg/tools" ) @@ -33,7 +34,7 @@ func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamE func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) { if !a.stream.Next() { if err := a.stream.Err(); err != nil { - return chat.MessageStreamResponse{}, err + return chat.MessageStreamResponse{}, oaistream.WrapOpenAIError(err) } return chat.MessageStreamResponse{}, io.EOF } diff --git a/pkg/modelerrors/modelerrors.go b/pkg/modelerrors/modelerrors.go index 3efd4d705..4283262e7 100644 --- a/pkg/modelerrors/modelerrors.go +++ b/pkg/modelerrors/modelerrors.go @@ -11,7 +11,9 @@ import ( "log/slog" "math/rand" "net" + "net/http" "regexp" + "strconv" "strings" "time" @@ -19,14 +21,56 @@ import ( "google.golang.org/genai" ) -// Backoff configuration constants. +// Backoff and retry-after configuration constants. const ( backoffBaseDelay = 200 * time.Millisecond backoffMaxDelay = 2 * time.Second backoffFactor = 2.0 backoffJitter = 0.1 + + // MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent + // a misbehaving server from blocking the agent for an unreasonable amount of time. 
+ MaxRetryAfterWait = 60 * time.Second ) +// StatusError wraps an HTTP API error with structured metadata for retry decisions. +// Providers wrap SDK errors in this type so the retry loop can use errors.As +// to extract status code and Retry-After without importing provider-specific SDKs. +type StatusError struct { + // StatusCode is the HTTP status code from the provider's API response. + StatusCode int + // RetryAfter is the parsed Retry-After header duration. Zero if absent. + RetryAfter time.Duration + // Err is the original error from the provider SDK. + Err error +} + +func (e *StatusError) Error() string { + return e.Err.Error() +} + +func (e *StatusError) Unwrap() error { + return e.Err +} + +// WrapHTTPError wraps err in a *StatusError carrying the HTTP status code and +// parsed Retry-After header from resp. Returns err unchanged if statusCode < 400 +// or err is nil. Pass resp=nil when no *http.Response is available. +func WrapHTTPError(statusCode int, resp *http.Response, err error) error { + if err == nil || statusCode < 400 { + return err + } + var retryAfter time.Duration + if resp != nil { + retryAfter = ParseRetryAfterHeader(resp.Header.Get("Retry-After")) + } + return &StatusError{ + StatusCode: statusCode, + RetryAfter: retryAfter, + Err: err, + } +} + // Default fallback configuration. const ( // DefaultRetries is the default number of retries per model with exponential @@ -296,6 +340,78 @@ func IsRetryableModelError(err error) bool { return false } +// ParseRetryAfterHeader parses a Retry-After header value. +// Supports both seconds (integer) and HTTP-date formats per RFC 7231 §7.1.3. +// Returns 0 if the value is empty, invalid, or results in a non-positive duration. 
+func ParseRetryAfterHeader(value string) time.Duration { + if value == "" { + return 0 + } + // Try integer seconds first (most common for rate limits) + if seconds, err := strconv.Atoi(value); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + // Try HTTP-date format + if t, err := http.ParseTime(value); err == nil { + d := time.Until(t) + if d > 0 { + return d + } + } + return 0 +} + +// ClassifyModelError classifies an error for the retry/fallback decision. +// +// If the error chain contains a *StatusError (wrapped by provider adapters), +// its StatusCode and RetryAfter fields are used directly — no provider-specific +// imports needed in the caller. +// +// Returns: +// - retryable=true: retry the SAME model with backoff (5xx, timeouts) +// - rateLimited=true: it's a 429 error; caller decides retry vs fallback based on config +// - retryAfter: Retry-After duration from the provider (only set for 429) +// +// When rateLimited=true, retryable is always false — the caller is responsible for +// deciding whether to retry (when no fallback is configured) or skip to the next +// model (when fallbacks are available). +func ClassifyModelError(err error) (retryable, rateLimited bool, retryAfter time.Duration) { + if err == nil { + return false, false, 0 + } + + // Context cancellation and deadline are never retryable. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false, false, 0 + } + + // Context overflow errors are never retryable — retrying the same oversized + // payload will always fail. + if IsContextOverflowError(err) { + return false, false, 0 + } + + // Primary path: typed StatusError wrapped by provider adapters. 
+ var statusErr *StatusError + if errors.As(err, &statusErr) { + if statusErr.StatusCode == http.StatusTooManyRequests { + return false, true, statusErr.RetryAfter + } + return IsRetryableStatusCode(statusErr.StatusCode), false, 0 + } + + // Fallback: providers that don't yet wrap (e.g. Bedrock), or non-provider + // errors (network, pattern matching). + statusCode := ExtractHTTPStatusCode(err) + if statusCode == http.StatusTooManyRequests { + return false, true, 0 // No Retry-After without StatusError + } + if statusCode != 0 { + return IsRetryableStatusCode(statusCode), false, 0 + } + return IsRetryableModelError(err), false, 0 +} + // CalculateBackoff returns the backoff duration for a given attempt (0-indexed). // Uses exponential backoff with jitter. func CalculateBackoff(attempt int) time.Duration { diff --git a/pkg/modelerrors/modelerrors_test.go b/pkg/modelerrors/modelerrors_test.go index 056a30755..361c89f34 100644 --- a/pkg/modelerrors/modelerrors_test.go +++ b/pkg/modelerrors/modelerrors_test.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net" + "net/http" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // mockTimeoutError implements net.Error with Timeout() = true @@ -278,3 +280,188 @@ func TestFormatError(t *testing.T) { assert.Equal(t, "authentication failed", FormatError(err)) }) } + +func TestParseRetryAfterHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + expected time.Duration + }{ + {name: "empty", value: "", expected: 0}, + {name: "zero seconds", value: "0", expected: 0}, + {name: "negative seconds", value: "-1", expected: 0}, + {name: "invalid string", value: "foo", expected: 0}, + {name: "5 seconds", value: "5", expected: 5 * time.Second}, + {name: "30 seconds", value: "30", expected: 30 * time.Second}, + {name: "120 seconds", value: "120", expected: 120 * time.Second}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
t.Parallel() + got := ParseRetryAfterHeader(tt.value) + assert.Equal(t, tt.expected, got, "ParseRetryAfterHeader(%q)", tt.value) + }) + } + + t.Run("HTTP-date in the future", func(t *testing.T) { + t.Parallel() + future := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat) + got := ParseRetryAfterHeader(future) + assert.Greater(t, got, 0*time.Second, "should return positive duration for future HTTP-date") + assert.LessOrEqual(t, got, 11*time.Second, "should not exceed ~10s for near-future date") + }) + + t.Run("HTTP-date in the past", func(t *testing.T) { + t.Parallel() + past := time.Now().Add(-10 * time.Second).UTC().Format(http.TimeFormat) + got := ParseRetryAfterHeader(past) + assert.Equal(t, 0*time.Second, got, "should return 0 for past HTTP-date") + }) +} + +func TestStatusError(t *testing.T) { + t.Parallel() + + t.Run("Error() delegates to wrapped error", func(t *testing.T) { + t.Parallel() + underlying := errors.New("rate limit exceeded") + se := &StatusError{StatusCode: 429, Err: underlying} + assert.Equal(t, underlying.Error(), se.Error()) + }) + + t.Run("Unwrap() allows errors.Is traversal", func(t *testing.T) { + t.Parallel() + sentinel := errors.New("sentinel") + se := &StatusError{StatusCode: 500, Err: sentinel} + assert.ErrorIs(t, se, sentinel) + }) + + t.Run("errors.As finds StatusError in chain", func(t *testing.T) { + t.Parallel() + se := &StatusError{StatusCode: 429, RetryAfter: 10 * time.Second, Err: errors.New("rate limited")} + wrapped := fmt.Errorf("outer: %w", se) + var found *StatusError + require.ErrorAs(t, wrapped, &found) + assert.Equal(t, 429, found.StatusCode) + assert.Equal(t, 10*time.Second, found.RetryAfter) + }) +} + +func TestWrapHTTPError(t *testing.T) { + t.Parallel() + + t.Run("nil error returns nil", func(t *testing.T) { + t.Parallel() + require.NoError(t, WrapHTTPError(429, nil, nil)) + }) + + t.Run("status < 400 passes through unchanged", func(t *testing.T) { + t.Parallel() + origErr := errors.New("original") + 
result := WrapHTTPError(200, nil, origErr) + assert.Equal(t, origErr, result) + var se *StatusError + assert.NotErrorAs(t, result, &se) + }) + + t.Run("429 without response has zero RetryAfter", func(t *testing.T) { + t.Parallel() + origErr := errors.New("rate limited") + result := WrapHTTPError(429, nil, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + assert.Equal(t, origErr.Error(), se.Error()) + }) + + t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) { + t.Parallel() + origErr := errors.New("rate limited") + respHeader := http.Header{} + respHeader.Set("Retry-After", "30") + resp := &http.Response{Header: respHeader} + result := WrapHTTPError(429, resp, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 429, se.StatusCode) + assert.Equal(t, 30*time.Second, se.RetryAfter) + }) + + t.Run("500 wraps correctly", func(t *testing.T) { + t.Parallel() + origErr := errors.New("internal server error") + result := WrapHTTPError(500, nil, origErr) + var se *StatusError + require.ErrorAs(t, result, &se) + assert.Equal(t, 500, se.StatusCode) + assert.Equal(t, time.Duration(0), se.RetryAfter) + }) + + t.Run("original error still accessible via Unwrap", func(t *testing.T) { + t.Parallel() + sentinel := errors.New("sentinel") + result := WrapHTTPError(429, nil, sentinel) + assert.ErrorIs(t, result, sentinel) + }) +} + +func TestClassifyModelError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantRetryable bool + wantRateLimited bool + wantRetryAfter time.Duration + }{ + {name: "nil", err: nil, wantRetryable: false, wantRateLimited: false}, + {name: "context canceled", err: context.Canceled, wantRetryable: false, wantRateLimited: false}, + {name: "context deadline exceeded", err: context.DeadlineExceeded, wantRetryable: false, wantRateLimited: false}, + {name: "context overflow", 
err: errors.New("prompt is too long: 200000 tokens > 100000 maximum"), wantRetryable: false, wantRateLimited: false}, + // 429 without StatusError (fallback message-pattern path) + {name: "429 message fallback, no RetryAfter", err: errors.New("POST /v1/chat: 429 Too Many Requests"), wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, + // 429 via StatusError (primary path) — no Retry-After + {name: "429 StatusError no retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 0, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 0}, + // 429 via StatusError with Retry-After from response header + {name: "429 StatusError with retry-after", err: &StatusError{StatusCode: 429, RetryAfter: 20 * time.Second, Err: errors.New("rate limited")}, wantRetryable: false, wantRateLimited: true, wantRetryAfter: 20 * time.Second}, + // Retryable status codes via StatusError + {name: "500 StatusError", err: &StatusError{StatusCode: 500, Err: errors.New("internal server error")}, wantRetryable: true, wantRateLimited: false}, + {name: "529 StatusError", err: &StatusError{StatusCode: 529, Err: errors.New("overloaded")}, wantRetryable: true, wantRateLimited: false}, + {name: "408 StatusError", err: &StatusError{StatusCode: 408, Err: errors.New("timeout")}, wantRetryable: true, wantRateLimited: false}, + // Retryable fallback path (message-based) + {name: "500 message fallback", err: errors.New("500 internal server error"), wantRetryable: true, wantRateLimited: false}, + {name: "502 message fallback", err: errors.New("502 bad gateway"), wantRetryable: true, wantRateLimited: false}, + // Non-retryable via StatusError + {name: "401 StatusError", err: &StatusError{StatusCode: 401, Err: errors.New("unauthorized")}, wantRetryable: false, wantRateLimited: false}, + {name: "403 StatusError", err: &StatusError{StatusCode: 403, Err: errors.New("forbidden")}, wantRetryable: false, wantRateLimited: false}, + // Non-retryable fallback + 
{name: "401 message fallback", err: errors.New("401 unauthorized"), wantRetryable: false, wantRateLimited: false}, + // Network errors + {name: "network timeout", err: &mockTimeoutError{}, wantRetryable: true, wantRateLimited: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + retryable, rateLimited, retryAfterOut := ClassifyModelError(tt.err) + assert.Equal(t, tt.wantRetryable, retryable, "retryable mismatch") + assert.Equal(t, tt.wantRateLimited, rateLimited, "rateLimited mismatch") + assert.Equal(t, tt.wantRetryAfter, retryAfterOut, "retryAfter mismatch") + }) + } + + t.Run("wrapped StatusError is found by errors.As", func(t *testing.T) { + t.Parallel() + statusErr := &StatusError{StatusCode: 429, RetryAfter: 15 * time.Second, Err: errors.New("rate limited")} + wrapped := fmt.Errorf("model failed: %w", statusErr) + retryable, rateLimited, retryAfterOut := ClassifyModelError(wrapped) + assert.False(t, retryable) + assert.True(t, rateLimited) + assert.Equal(t, 15*time.Second, retryAfterOut) + }) +} diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index b60c6deef..a9caa0456 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -212,12 +212,14 @@ func (r *LocalRuntime) tryModelWithFallback( var lastErr error primaryFailedWithNonRetryable := false + hasFallbacks := len(fallbackModels) > 0 for chainIdx := startIndex; chainIdx < len(modelChain); chainIdx++ { modelEntry := modelChain[chainIdx] // Each model in the chain gets (1 + retries) attempts for retryable errors. - // Non-retryable errors (429, 4xx) skip immediately to the next model. + // Non-retryable errors (4xx, and 429 unless rate-limit retry applies) skip immediately to the next model. + // 429 is retried on the same model only when retryOnRateLimit is enabled and no fallbacks are configured. 
maxAttempts := 1 + fallbackRetries for attempt := range maxAttempts { @@ -270,28 +272,12 @@ func (r *LocalRuntime) tryModelWithFallback( return streamResult{}, nil, err } - // Check if error is retryable - if !modelerrors.IsRetryableModelError(err) { - slog.Error("Non-retryable error creating stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "error", err) - - // Track if primary failed with non-retryable error - if !modelEntry.isFallback { - primaryFailedWithNonRetryable = true - } - - // Skip to next model in chain + decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) + if decision == retryDecisionReturn { + return streamResult{}, nil, ctx.Err() + } else if decision == retryDecisionBreak { break } - - slog.Warn("Retryable error creating stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "max_attempts", maxAttempts, - "error", err) continue } @@ -306,27 +292,12 @@ func (r *LocalRuntime) tryModelWithFallback( return streamResult{}, nil, err } - // Check if stream error is retryable - if !modelerrors.IsRetryableModelError(err) { - slog.Error("Non-retryable error handling stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "error", err) - - // Track if primary failed with non-retryable error - if !modelEntry.isFallback { - primaryFailedWithNonRetryable = true - } - + decision := r.handleModelError(ctx, err, a, modelEntry, attempt, hasFallbacks, &primaryFailedWithNonRetryable) + if decision == retryDecisionReturn { + return streamResult{}, nil, ctx.Err() + } else if decision == retryDecisionBreak { break } - - slog.Warn("Retryable error handling stream", - "agent", a.Name(), - "model", modelEntry.provider.ID(), - "attempt", attempt+1, - "max_attempts", maxAttempts, - "error", err) continue } @@ -359,3 +330,94 @@ func (r *LocalRuntime) tryModelWithFallback( } return streamResult{}, nil, errors.New("all models failed with unknown 
error") } + +// retryDecision is the outcome of handleModelError. +type retryDecision int + +const ( + // retryDecisionContinue means retry the same model (backoff already applied). + retryDecisionContinue retryDecision = iota + // retryDecisionBreak means skip to the next model in the fallback chain. + retryDecisionBreak + // retryDecisionReturn means context was cancelled; return immediately. + retryDecisionReturn +) + +// handleModelError classifies err and decides what to do next: +// - retryDecisionReturn — context cancelled while sleeping; caller returns ctx.Err() +// - retryDecisionBreak — non-retryable error, or 429 when retryOnRateLimit is off or fallbacks exist; skip to next model +// - retryDecisionContinue — retryable error, or 429 with retryOnRateLimit on and no fallbacks (after waiting); retry same model +// +// Side-effect: sets *primaryFailedWithNonRetryable when the primary model fails with a +// non-retryable (or rate-limited-and-not-retried) error. +func (r *LocalRuntime) handleModelError( + ctx context.Context, + err error, + a *agent.Agent, + modelEntry modelWithFallback, + attempt int, + hasFallbacks bool, + primaryFailedWithNonRetryable *bool, +) retryDecision { + retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(err) + + if rateLimited { + // Gate: only retry on 429 if opt-in is enabled AND no fallbacks exist. + // Default behavior (retryOnRateLimit=false) treats 429 as non-retryable, + // identical to the behavior before this feature was added. + if !r.retryOnRateLimit || hasFallbacks { + slog.Warn("Rate limited, treating as non-retryable", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "retry_on_rate_limit_enabled", r.retryOnRateLimit, + "has_fallbacks", hasFallbacks, + "error", err) + if !modelEntry.isFallback { + *primaryFailedWithNonRetryable = true + } + return retryDecisionBreak + } + + // Opt-in enabled, no fallbacks → retry same model after honouring Retry-After (or backoff). 
+ waitDuration := retryAfter + if waitDuration <= 0 { + waitDuration = modelerrors.CalculateBackoff(attempt) + } else if waitDuration > modelerrors.MaxRetryAfterWait { + slog.Warn("Retry-After exceeds maximum, capping", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "retry_after", retryAfter, + "max", modelerrors.MaxRetryAfterWait) + waitDuration = modelerrors.MaxRetryAfterWait + } + slog.Warn("Rate limited, retrying (opt-in enabled)", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "attempt", attempt+1, + "wait", waitDuration, + "retry_after_from_header", retryAfter > 0, + "error", err) + if !modelerrors.SleepWithContext(ctx, waitDuration) { + return retryDecisionReturn + } + return retryDecisionContinue + } + + if !retryable { + slog.Error("Non-retryable error from model", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "error", err) + if !modelEntry.isFallback { + *primaryFailedWithNonRetryable = true + } + return retryDecisionBreak + } + + slog.Warn("Retryable error from model", + "agent", a.Name(), + "model", modelEntry.provider.ID(), + "attempt", attempt+1, + "error", err) + return retryDecisionContinue +} diff --git a/pkg/runtime/fallback_test.go b/pkg/runtime/fallback_test.go index 342f215e2..614642ca4 100644 --- a/pkg/runtime/fallback_test.go +++ b/pkg/runtime/fallback_test.go @@ -468,3 +468,265 @@ func TestFallbackModelsClonedWithThinkingEnabled(t *testing.T) { "BaseConfig() should be called on fallback provider when thinking is enabled") }) } + +func TestFallback429WithFallbacksSkipsToNextModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary gets rate limited; with fallbacks configured it should skip immediately. + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + successStream := newStreamBuilder(). + AddContent("Success from fallback"). + AddStopWithUsage(10, 5). 
+ Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(5), // many retries — 429 should NOT use them + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 With Fallback Skip Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success from fallback" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content from fallback") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — 429 with fallbacks should skip immediately") + }) +} + +func TestFallback429WithoutFallbacksRetriesSameModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary gets rate limited; with no fallbacks configured it should retry + // when the opt-in is enabled. + successStream := newStreamBuilder(). + AddContent("Success after rate limit"). + AddStopWithUsage(10, 5). 
+ Build() + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 2, // fail twice with 429, then succeed + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models configured + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 No Fallback Retry Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after rate limit" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after rate limit retries") + assert.Equal(t, 3, primary.callCount, "primary should be called 3 times: 2 failures + 1 success") + }) +} + +func TestFallback429WithoutFallbacksExhaustsRetries(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary always returns 429, no fallbacks, opt-in enabled — should fail after all retries. 
+ primary := &failingProvider{ + id: "primary/always-rate-limited", + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(2), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 No Fallback Exhaust Test" + + var gotError bool + for ev := range rt.RunStream(t.Context(), sess) { + if _, ok := ev.(*ErrorEvent); ok { + gotError = true + } + } + assert.True(t, gotError, "should receive an error when all retries exhausted") + }) +} + +func TestFallback500RetryableWithBackoff(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Primary returns 500 (retryable), no fallbacks — should retry with backoff. + successStream := newStreamBuilder(). + AddContent("Success after 500"). + AddStopWithUsage(10, 5). 
+ Build() + primary := &countingProvider{ + id: "primary/server-error", + failCount: 1, + err: errors.New("500 internal server error"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(2), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "500 Retry Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after 500" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after 500 retry") + assert.Equal(t, 2, primary.callCount, "primary should be called twice: 1 failure + 1 success") + }) +} + +// --- WithRetryOnRateLimit gate tests --- + +func TestRateLimitGate_DisabledNoFallbacks_FailsImmediately(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // With retryOnRateLimit=false (default) and no fallbacks, a 429 should + // be treated as non-retryable and fail immediately without any retry. 
+ primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models, no WithRetryOnRateLimit opt-in + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + // Note: WithRetryOnRateLimit() is NOT passed — default off + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Disabled Test" + + var gotError bool + for ev := range rt.RunStream(t.Context(), sess) { + if _, ok := ev.(*ErrorEvent); ok { + gotError = true + } + } + assert.True(t, gotError, "should fail immediately with an error") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — no retry without opt-in") + }) +} + +func TestRateLimitGate_EnabledNoFallbacks_RetriesSameModel(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // With retryOnRateLimit=true and no fallbacks, a 429 should retry + // the same model until it succeeds or retries are exhausted. + successStream := newStreamBuilder(). + AddContent("Success after rate limit"). + AddStopWithUsage(10, 5). 
+ Build() + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 2, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + stream: successStream, + } + + root := agent.New("root", "test", + agent.WithModel(primary), + // No fallback models + agent.WithFallbackRetries(3), + ) + + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Enabled No Fallbacks Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success after rate limit" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content after retrying") + assert.Equal(t, 3, primary.callCount, "primary should be called 3 times: 2 failures + 1 success") + }) +} + +func TestRateLimitGate_EnabledWithFallbacks_SkipsToFallback(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Even with retryOnRateLimit=true, when fallbacks are configured + // a 429 should skip to the fallback immediately (fallbacks take priority). + primary := &countingProvider{ + id: "primary/rate-limited", + failCount: 100, + err: errors.New("POST /v1/chat/completions: 429 Too Many Requests"), + } + successStream := newStreamBuilder(). + AddContent("Success from fallback"). + AddStopWithUsage(10, 5). 
+ Build() + fallback := &mockProvider{id: "fallback/success", stream: successStream} + + root := agent.New("root", "test", + agent.WithModel(primary), + agent.WithFallbackModel(fallback), + agent.WithFallbackRetries(5), + ) + + tm := team.New(team.WithAgents(root)) + // opt-in is enabled, but fallbacks are present → should still skip to fallback + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}), WithRetryOnRateLimit()) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("test")) + sess.Title = "429 Gate Enabled With Fallbacks Test" + + var gotContent bool + for ev := range rt.RunStream(t.Context(), sess) { + if choice, ok := ev.(*AgentChoiceEvent); ok && choice.Content == "Success from fallback" { + gotContent = true + } + } + assert.True(t, gotContent, "should receive content from fallback") + assert.Equal(t, 1, primary.callCount, "primary should only be called once — fallbacks take priority over retry") + }) +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 82487b202..fd7e35e7a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -199,6 +199,12 @@ type LocalRuntime struct { env []string // Environment variables for hooks execution modelSwitcherCfg *ModelSwitcherConfig + // retryOnRateLimit enables retry-with-backoff for HTTP 429 (rate limit) errors + // when no fallback models are configured. When false (default), 429 errors are + // treated as non-retryable and immediately fail or skip to the next model. + // Library consumers can enable this via WithRetryOnRateLimit(). + retryOnRateLimit bool + // fallbackCooldowns tracks per-agent cooldown state for sticky fallback behavior fallbackCooldowns map[string]*fallbackCooldownState fallbackCooldownsMux sync.RWMutex @@ -264,6 +270,23 @@ func WithEnv(env []string) Opt { } } +// WithRetryOnRateLimit enables automatic retry with backoff for HTTP 429 (rate limit) +// errors when no fallback models are available. 
When enabled, the runtime will honor +// the Retry-After header from the provider's response to determine wait time before +// retrying, falling back to exponential backoff if the header is absent. +// +// This is off by default. It is intended for library consumers that run agents +// programmatically and prefer to wait for rate limits to clear rather than fail +// immediately. +// +// When fallback models are configured, 429 errors always skip to the next model +// regardless of this setting. +func WithRetryOnRateLimit() Opt { + return func(r *LocalRuntime) { + r.retryOnRateLimit = true + } +} + // NewLocalRuntime creates a new LocalRuntime without the persistence wrapper. // This is useful for testing or when persistence is handled externally. func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {