Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions pkg/model/provider/anthropic/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ type streamAdapter struct {

trackUsage bool
toolCall bool
toolID string
// toolIDByBlock maps a content block index to its tool_use block ID.
// Anthropic emits each tool_use in its own content block; subsequent
// input_json_delta events carry the block index (not the tool ID), so
// we must remember the ID per block to route partial JSON correctly
// when multiple tool calls stream in parallel.
toolIDByBlock map[int64]string
}

func (c *Client) newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
return &streamAdapter{
retryableStream: retryableStream[anthropic.MessageStreamEventUnion]{stream: stream},
trackUsage: trackUsage,
toolIDByBlock: map[int64]string{},
}
}

Expand Down Expand Up @@ -92,10 +98,13 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
case anthropic.ContentBlockStartEvent:
switch block := eventVariant.ContentBlock.AsAny().(type) {
case anthropic.ToolUseBlock:
a.toolID = block.ID
if a.toolIDByBlock == nil {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[LOW] Redundant nil guard on toolIDByBlock inside Recv() — slightly misleading

newStreamAdapter always initialises toolIDByBlock to a non-nil map, so the defensive if a.toolIDByBlock == nil check inside the ContentBlockStartEvent branch (line 101) will never fire for adapters created via the constructor.

The check exists because the test file constructs the adapter as a struct literal without calling the constructor (see parallel_toolcall_test.go line 129: adapter := &streamAdapter{retryableStream: ...}), leaving the map nil. The nil-guard silently rescues that path.

Two minor concerns:

  1. The guard implies the constructor-initialised map can somehow become nil, which is confusing for future readers.
  2. The test is actually exercising the nil-initialised (non-constructor) path, which diverges slightly from production usage.

Neither is a functional bug — Go map reads on a nil map safely return the zero value, and the lazy initialisation via the guard works correctly. Just worth being aware of if the guard is ever removed or the constructor changes.

a.toolIDByBlock = map[int64]string{}
}
a.toolIDByBlock[eventVariant.Index] = block.ID
a.toolCall = true
toolCall := tools.ToolCall{
ID: a.toolID,
ID: block.ID,
Type: "function",
Function: tools.FunctionCall{
Name: block.Name,
Expand All @@ -122,7 +131,7 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
case anthropic.InputJSONDelta:
inputBytes := deltaVariant.PartialJSON
toolCall := tools.ToolCall{
ID: a.toolID,
ID: a.toolIDByBlock[eventVariant.Index],
Type: "function",
Function: tools.FunctionCall{
Arguments: inputBytes,
Expand Down
15 changes: 11 additions & 4 deletions pkg/model/provider/anthropic/beta_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ type betaStreamAdapter struct {

trackUsage bool
toolCall bool
toolID string
// toolIDByBlock maps a content block index to its tool_use block ID.
// See streamAdapter.toolIDByBlock for the same rationale (parallel
// tool calls require per-block routing of input_json deltas).
toolIDByBlock map[int64]string
}

// newBetaStreamAdapter creates a new Beta stream adapter
func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion], trackUsage bool) *betaStreamAdapter {
return &betaStreamAdapter{
retryableStream: retryableStream[anthropic.BetaRawMessageStreamEventUnion]{stream: stream},
trackUsage: trackUsage,
toolIDByBlock: map[int64]string{},
}
}

Expand Down Expand Up @@ -56,10 +60,13 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
case anthropic.BetaRawContentBlockStartEvent:
switch block := eventVariant.ContentBlock.AsAny().(type) {
case anthropic.BetaToolUseBlock:
a.toolID = block.ID
if a.toolIDByBlock == nil {
a.toolIDByBlock = map[int64]string{}
}
a.toolIDByBlock[eventVariant.Index] = block.ID
a.toolCall = true
toolCall := tools.ToolCall{
ID: a.toolID,
ID: block.ID,
Type: "function",
Function: tools.FunctionCall{
Name: block.Name,
Expand All @@ -84,7 +91,7 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
case anthropic.BetaInputJSONDelta:
inputBytes := deltaVariant.PartialJSON
toolCall := tools.ToolCall{
ID: a.toolID,
ID: a.toolIDByBlock[eventVariant.Index],
Type: "function",
Function: tools.FunctionCall{
Arguments: inputBytes,
Expand Down
282 changes: 282 additions & 0 deletions pkg/model/provider/anthropic/parallel_toolcall_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
package anthropic

import (
"encoding/json"
"strings"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)

// fakeDecoder feeds a fixed sequence of SSE events to ssestream.Stream.
type fakeDecoder struct {
events []ssestream.Event
i int
closed bool
}

func (d *fakeDecoder) Next() bool {
if d.i >= len(d.events) {
return false
}
d.i++
return true
}

func (d *fakeDecoder) Event() ssestream.Event { return d.events[d.i-1] }
func (d *fakeDecoder) Close() error { d.closed = true; return nil }
func (d *fakeDecoder) Err() error { return nil }

func sseEvent(t string, payload any) ssestream.Event {
data, _ := json.Marshal(payload)
return ssestream.Event{Type: t, Data: data}
}

// TestParallelToolCallIDsAreNotCrossWired reproduces a bug where the Anthropic
// stream adapter loses track of which tool_use block a given input_json delta
// belongs to when two or more tool calls stream in parallel.
//
// Anthropic's streaming protocol emits a content_block_start for each
// tool_use block (each with its own block index and its own tool ID), then
// emits content_block_delta events of type input_json_delta carrying partial
// JSON for each block. Every event carries the block's index. The adapter
// must use that index to route partial JSON back to the correct tool call.
//
// The current adapter stores the most recently seen tool ID in a single
// streamAdapter.toolID field. When a second tool_use block starts, that
// field is overwritten. Subsequent input_json deltas for the FIRST block
// then carry the SECOND block's ID, and the runtime accumulator
// (keyed by tool call ID in pkg/runtime/streaming.go) concatenates both
// calls' argument fragments into the same buffer, producing malformed JSON
// that surfaces upstream as Go json errors like
//
// "invalid character 's' looking for beginning of value"
// "invalid character '-' after object key:value pair"
//
// This test demonstrates the bug. With the fix in place (route by block
// index, not by a single shared toolID), both tool calls' arguments end up
// in their own buffers and parse cleanly.
func TestParallelToolCallIDsAreNotCrossWired(t *testing.T) {
// Event sequence: two parallel tool_use blocks with interleaved
// input_json_delta events. This mirrors what Anthropic emits when the
// model issues parallel tool calls.
events := []ssestream.Event{
// message_start (minimal — we only care about content blocks below)
sseEvent("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{"id": "msg_test", "model": "claude-test", "role": "assistant", "type": "message"},
}),

// content_block_start, index 0: tool A (memory_refresh_complete)
sseEvent("content_block_start", map[string]any{
"type": "content_block_start",
"index": 0,
"content_block": map[string]any{
"type": "tool_use",
"id": "toolu_AAA",
"name": "memory_refresh_complete",
"input": map[string]any{},
},
}),

// content_block_start, index 1: tool B (memory_learning_add)
sseEvent("content_block_start", map[string]any{
"type": "content_block_start",
"index": 1,
"content_block": map[string]any{
"type": "tool_use",
"id": "toolu_BBB",
"name": "memory_learning_add",
"input": map[string]any{},
},
}),

// Interleaved input_json_delta events. Each carries its block index.
// Tool A args: {"refresh_id":"abc-def"}
// Tool B args: {"category":"tool_failure","summary":"x"}
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `{"refresh_id":"abc-`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 1,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `{"category":"tool_failure",`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `def"}`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 1,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `"summary":"x"}`},
}),

sseEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": 0}),
sseEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": 1}),
sseEvent("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "tool_use"},
}),
sseEvent("message_stop", map[string]any{"type": "message_stop"}),
}

stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](&fakeDecoder{events: events}, nil)
adapter := &streamAdapter{retryableStream: retryableStream[anthropic.MessageStreamEventUnion]{stream: stream}}

// Replicate the runtime's tool-call accumulator (pkg/runtime/streaming.go).
// It keys by ToolCall.ID and concatenates Arguments fragments. This is the
// downstream layer that turns malformed concatenation into a JSON parse
// error when the tool is invoked.
argsByID := map[string]*strings.Builder{}
nameByID := map[string]string{}

for {
resp, err := adapter.Recv()
if err != nil {
break
}
if len(resp.Choices) == 0 {
continue
}
for _, tc := range resp.Choices[0].Delta.ToolCalls {
if _, ok := argsByID[tc.ID]; !ok {
argsByID[tc.ID] = &strings.Builder{}
}
if tc.Function.Name != "" {
nameByID[tc.ID] = tc.Function.Name
}
argsByID[tc.ID].WriteString(tc.Function.Arguments)
}
}

// Expected behaviour: tool A and tool B each get exactly their own args,
// and both parse as valid JSON.
expectedA := `{"refresh_id":"abc-def"}`
expectedB := `{"category":"tool_failure","summary":"x"}`

gotA := argsByID["toolu_AAA"].String()
gotB := argsByID["toolu_BBB"].String()

t.Logf("toolu_AAA name=%q args=%q", nameByID["toolu_AAA"], gotA)
t.Logf("toolu_BBB name=%q args=%q", nameByID["toolu_BBB"], gotB)

if gotA != expectedA {
t.Errorf("tool A (toolu_AAA, memory_refresh_complete) args wrong\n want: %s\n got: %s", expectedA, gotA)
}
if gotB != expectedB {
t.Errorf("tool B (toolu_BBB, memory_learning_add) args wrong\n want: %s\n got: %s", expectedB, gotB)
}

// As a sharper assertion: both buffers must individually parse as JSON.
// With the bug, one or both fail with the exact Go errors observed in
// production ("invalid character ... looking for beginning of value" or
// "invalid character ... after object key:value pair").
var sink any
if err := json.Unmarshal([]byte(gotA), &sink); err != nil {
t.Errorf("tool A args failed to parse as JSON: %v\n buffer was: %s", err, gotA)
}
if err := json.Unmarshal([]byte(gotB), &sink); err != nil {
t.Errorf("tool B args failed to parse as JSON: %v\n buffer was: %s", err, gotB)
}
}

// TestBetaParallelToolCallIDsAreNotCrossWired is the same scenario but for
// the Beta stream adapter. The bug and fix are identical.
func TestBetaParallelToolCallIDsAreNotCrossWired(t *testing.T) {
events := []ssestream.Event{
sseEvent("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{"id": "msg_test", "model": "claude-test", "role": "assistant", "type": "message"},
}),
sseEvent("content_block_start", map[string]any{
"type": "content_block_start",
"index": 0,
"content_block": map[string]any{
"type": "tool_use",
"id": "toolu_AAA",
"name": "memory_refresh_complete",
"input": map[string]any{},
},
}),
sseEvent("content_block_start", map[string]any{
"type": "content_block_start",
"index": 1,
"content_block": map[string]any{
"type": "tool_use",
"id": "toolu_BBB",
"name": "memory_learning_add",
"input": map[string]any{},
},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `{"refresh_id":"abc-`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 1,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `{"category":"tool_failure",`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `def"}`},
}),
sseEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 1,
"delta": map[string]any{"type": "input_json_delta", "partial_json": `"summary":"x"}`},
}),
sseEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": 0}),
sseEvent("content_block_stop", map[string]any{"type": "content_block_stop", "index": 1}),
sseEvent("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "tool_use"},
}),
sseEvent("message_stop", map[string]any{"type": "message_stop"}),
}

stream := ssestream.NewStream[anthropic.BetaRawMessageStreamEventUnion](&fakeDecoder{events: events}, nil)
adapter := &betaStreamAdapter{retryableStream: retryableStream[anthropic.BetaRawMessageStreamEventUnion]{stream: stream}}

argsByID := map[string]*strings.Builder{}
for {
resp, err := adapter.Recv()
if err != nil {
break
}
if len(resp.Choices) == 0 {
continue
}
for _, tc := range resp.Choices[0].Delta.ToolCalls {
if _, ok := argsByID[tc.ID]; !ok {
argsByID[tc.ID] = &strings.Builder{}
}
argsByID[tc.ID].WriteString(tc.Function.Arguments)
}
}

gotA := argsByID["toolu_AAA"].String()
gotB := argsByID["toolu_BBB"].String()
expectedA := `{"refresh_id":"abc-def"}`
expectedB := `{"category":"tool_failure","summary":"x"}`
if gotA != expectedA {
t.Errorf("[beta] tool A args wrong\n want: %s\n got: %s", expectedA, gotA)
}
if gotB != expectedB {
t.Errorf("[beta] tool B args wrong\n want: %s\n got: %s", expectedB, gotB)
}
var sink any
if err := json.Unmarshal([]byte(gotA), &sink); err != nil {
t.Errorf("[beta] tool A args failed to parse: %v\n buffer was: %s", err, gotA)
}
if err := json.Unmarshal([]byte(gotB), &sink); err != nil {
t.Errorf("[beta] tool B args failed to parse: %v\n buffer was: %s", err, gotB)
}
}
Loading