Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/model/provider/anthropic/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func isContextLengthError(err error) bool {
func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
ok, err := a.next()
if !ok {
return chat.MessageStreamResponse{}, err
return chat.MessageStreamResponse{}, wrapAnthropicError(err)
}

event := a.stream.Current()
Expand Down
2 changes: 1 addition & 1 deletion pkg/model/provider/anthropic/beta_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRaw
func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
ok, err := a.next()
if !ok {
return chat.MessageStreamResponse{}, err
return chat.MessageStreamResponse{}, wrapAnthropicError(err)
}

event := a.stream.Current()
Expand Down
23 changes: 23 additions & 0 deletions pkg/model/provider/anthropic/wrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package anthropic

import (
"errors"

"github.com/anthropics/anthropic-sdk-go"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// wrapAnthropicError wraps an Anthropic SDK error in a *modelerrors.StatusError
// to carry HTTP status code and Retry-After metadata for the retry loop.
// Non-Anthropic errors (e.g. io.EOF, network errors) pass through unchanged.
func wrapAnthropicError(err error) error {
if err == nil {
return nil
}
apiErr, ok := errors.AsType[*anthropic.Error](err)
if !ok {
return err
}
return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err)
}
105 changes: 105 additions & 0 deletions pkg/model/provider/anthropic/wrap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package anthropic

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/anthropics/anthropic-sdk-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// makeTestAnthropicError creates an *anthropic.Error with the given status code and
// optional Retry-After header value for testing.
func makeTestAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error {
header := http.Header{}
if retryAfterValue != "" {
header.Set("Retry-After", retryAfterValue)
}
resp := httptest.NewRecorder().Result()
resp.StatusCode = statusCode
resp.Header = header
// anthropic.Error.Error() dereferences Request, so we must provide a non-nil one.
req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody)
return &anthropic.Error{
StatusCode: statusCode,
Response: resp,
Request: req,
}
}

func TestWrapAnthropicError(t *testing.T) {
t.Parallel()

t.Run("nil returns nil", func(t *testing.T) {
t.Parallel()
assert.NoError(t, wrapAnthropicError(nil))
})

t.Run("non-anthropic error passes through unchanged", func(t *testing.T) {
t.Parallel()
orig := errors.New("some network error")
result := wrapAnthropicError(orig)
assert.Equal(t, orig, result)
var se *modelerrors.StatusError
assert.NotErrorAs(t, result, &se)
})

t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, time.Duration(0), se.RetryAfter)
// Original error still accessible
assert.ErrorIs(t, result, apiErr)
})

t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "20")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, 20*time.Second, se.RetryAfter)
})

t.Run("500 wraps with correct status code", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(500, "")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 500, se.StatusCode)
assert.Equal(t, time.Duration(0), se.RetryAfter)
})

t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "15")
result := wrapAnthropicError(apiErr)
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
assert.False(t, retryable)
assert.True(t, rateLimited)
assert.Equal(t, 15*time.Second, retryAfter)
})

t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "5")
wrapped := fmt.Errorf("stream error: %w", wrapAnthropicError(apiErr))
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(wrapped)
assert.False(t, retryable)
assert.True(t, rateLimited)
assert.Equal(t, 5*time.Second, retryAfter)
})
}
2 changes: 1 addition & 1 deletion pkg/model/provider/gemini/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
}

if res.err != nil {
return chat.MessageStreamResponse{}, res.err
return chat.MessageStreamResponse{}, wrapGeminiError(res.err)
}

// Build response
Expand Down
26 changes: 26 additions & 0 deletions pkg/model/provider/gemini/wrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package gemini

import (
"errors"

"google.golang.org/genai"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// wrapGeminiError wraps a Gemini SDK error in a *modelerrors.StatusError
// to carry HTTP status code metadata for the retry loop.
// Gemini's *genai.APIError does not expose *http.Response, so no Retry-After
// header extraction is possible; the RetryAfter field will be zero.
// Non-Gemini errors (e.g. io.EOF, network errors) pass through unchanged.
func wrapGeminiError(err error) error {
if err == nil {
return nil
}
apiErr, ok := errors.AsType[*genai.APIError](err)
if !ok {
return err
}
// Pass nil for resp — Gemini doesn't expose *http.Response.
return modelerrors.WrapHTTPError(apiErr.Code, nil, err)
}
2 changes: 1 addition & 1 deletion pkg/model/provider/oaistream/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
if !a.stream.Next() {
err := a.stream.Err()
if err != nil {
return chat.MessageStreamResponse{}, err
return chat.MessageStreamResponse{}, WrapOpenAIError(err)
}
return chat.MessageStreamResponse{}, io.EOF
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/model/provider/oaistream/wrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package oaistream

import (
"errors"

openaisdk "github.com/openai/openai-go/v3"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// WrapOpenAIError wraps an OpenAI SDK error in a *modelerrors.StatusError
// to carry HTTP status code and Retry-After metadata for the retry loop.
// Non-OpenAI errors (e.g. io.EOF, network errors) pass through unchanged.
// Exported so openai/response_stream.go can reuse it without duplication.
func WrapOpenAIError(err error) error {
if err == nil {
return nil
}
apiErr, ok := errors.AsType[*openaisdk.Error](err)
if !ok {
return err
}
return modelerrors.WrapHTTPError(apiErr.StatusCode, apiErr.Response, err)
}
103 changes: 103 additions & 0 deletions pkg/model/provider/oaistream/wrap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package oaistream

import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

openaisdk "github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/modelerrors"
)

// makeTestOpenAIError creates an *openai.Error with the given status code and
// optional Retry-After header value for testing.
func makeTestOpenAIError(statusCode int, retryAfterValue string) *openaisdk.Error {
header := http.Header{}
if retryAfterValue != "" {
header.Set("Retry-After", retryAfterValue)
}
resp := httptest.NewRecorder().Result()
resp.StatusCode = statusCode
resp.Header = header
// openai.Error.Error() dereferences Request, so we must provide a non-nil one.
req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody)
return &openaisdk.Error{
StatusCode: statusCode,
Response: resp,
Request: req,
}
}

func TestWrapOpenAIError(t *testing.T) {
t.Parallel()

t.Run("nil returns nil", func(t *testing.T) {
t.Parallel()
assert.NoError(t, WrapOpenAIError(nil))
})

t.Run("non-openai error passes through unchanged", func(t *testing.T) {
t.Parallel()
orig := errors.New("some network error")
result := WrapOpenAIError(orig)
assert.Equal(t, orig, result)
var se *modelerrors.StatusError
assert.NotErrorAs(t, result, &se)
})

t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, time.Duration(0), se.RetryAfter)
// Original error still accessible
assert.ErrorIs(t, result, apiErr)
})

t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "30")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 429, se.StatusCode)
assert.Equal(t, 30*time.Second, se.RetryAfter)
})

t.Run("500 wraps with correct status code", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(500, "")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
assert.Equal(t, 500, se.StatusCode)
})

t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "10")
result := WrapOpenAIError(apiErr)
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
assert.False(t, retryable)
assert.True(t, rateLimited)
assert.Equal(t, 10*time.Second, retryAfter)
})

t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(500, "")
wrapped := fmt.Errorf("stream error: %w", WrapOpenAIError(apiErr))
retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped)
assert.True(t, retryable)
assert.False(t, rateLimited)
})
}
3 changes: 2 additions & 1 deletion pkg/model/provider/openai/response_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/openai/openai-go/v3/responses"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/model/provider/oaistream"
"github.com/docker/docker-agent/pkg/tools"
)

Expand All @@ -33,7 +34,7 @@ func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamE
func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
if !a.stream.Next() {
if err := a.stream.Err(); err != nil {
return chat.MessageStreamResponse{}, err
return chat.MessageStreamResponse{}, oaistream.WrapOpenAIError(err)
}
return chat.MessageStreamResponse{}, io.EOF
}
Expand Down
Loading
Loading