From de968269e0f37ddd30299fb76343365997874b1c Mon Sep 17 00:00:00 2001 From: Sam Willis Date: Wed, 10 Jun 2026 17:59:47 +0100 Subject: [PATCH 1/2] Add streaming tool call argument support --- .changeset/streaming-tool-call-args.md | 5 + packages/agents-runtime/src/entity-schema.ts | 43 +++++- .../agents-runtime/src/entity-timeline.ts | 24 ++- .../agents-runtime/src/outbound-bridge.ts | 143 +++++++++++++++--- packages/agents-runtime/src/pi-adapter.ts | 82 +++++++++- packages/agents-runtime/src/types.ts | 16 +- .../test/outbound-bridge.test.ts | 56 ++++++- .../agents-runtime/test/pi-adapter.test.ts | 135 +++++++++++++++++ 8 files changed, 477 insertions(+), 27 deletions(-) create mode 100644 .changeset/streaming-tool-call-args.md diff --git a/.changeset/streaming-tool-call-args.md b/.changeset/streaming-tool-call-args.md new file mode 100644 index 0000000000..88d56049b4 --- /dev/null +++ b/.changeset/streaming-tool-call-args.md @@ -0,0 +1,5 @@ +--- +"@electric-ax/agents-runtime": minor +--- + +Add runtime support for streaming tool call arguments from Pi model events. diff --git a/packages/agents-runtime/src/entity-schema.ts b/packages/agents-runtime/src/entity-schema.ts index 7d70d3cef2..0e072659a4 100644 --- a/packages/agents-runtime/src/entity-schema.ts +++ b/packages/agents-runtime/src/entity-schema.ts @@ -171,12 +171,28 @@ type ToolCallValue = { run_id?: string tool_call_id?: string tool_name: string - status: `started` | `args_complete` | `executing` | `completed` | `failed` + status: + | `started` + | `args_streaming` + | `args_complete` + | `executing` + | `completed` + | `failed` args?: unknown + args_preview?: unknown result?: unknown error?: string duration_ms?: number } +type ToolArgDeltaValue = { + key?: string + tool_call_key: string + tool_call_id?: string + run_id?: string + seq: number + delta: string + content_index?: number +} type ReasoningValue = { key?: string status: `streaming` | `completed` @@ -502,18 +518,33 @@ function createToolCallSchema(): Schema { tool_name: z.string(), status: z.enum([ `started`, + `args_streaming`, `args_complete`, `executing`, `completed`, `failed`, ]), args: z.unknown().optional(), + args_preview: z.unknown().optional(), result: z.unknown().optional(), error: z.string().optional(), duration_ms: z.number().int().optional(), }) } +function createToolArgDeltaSchema(): Schema { + return z.object({ + key: z.string().optional(), + ...timelineOrderField, + tool_call_key: z.string(), + tool_call_id: z.string().optional(), + run_id: z.string().optional(), + seq: z.number().int(), + delta: z.string(), + content_index: z.number().int().optional(), + }) +} + function createReasoningSchema(): Schema { return z.object({ key: z.string().optional(), @@ -848,6 +879,7 @@ export type Step = SequencedPersistedRow export type Text = SequencedPersistedRow export type TextDelta = SequencedPersistedRow export type ToolCall = SequencedPersistedRow +export type ToolArgDelta = SequencedPersistedRow export type Reasoning = SequencedPersistedRow export type ErrorEvent = SequencedPersistedRow export type MessageReceived = SequencedPersistedRow @@ -961,6 +993,8 @@ export const BUILT_IN_EVENT_SCHEMAS = { text_delta: createTextDeltaSchema() as unknown as BuiltInEntitySchema, tool_call: createToolCallSchema() as unknown as BuiltInEntitySchema, + tool_arg_delta: + createToolArgDeltaSchema() as unknown as BuiltInEntitySchema, reasoning: createReasoningSchema() as unknown as BuiltInEntitySchema, error: createErrorEventSchema() as unknown as BuiltInEntitySchema, @@ -997,6 +1031,7 @@ type EntityCollectionsDefinition = { texts: CollectionDefinition textDeltas: CollectionDefinition toolCalls: CollectionDefinition + toolArgDeltas: CollectionDefinition reasoning: CollectionDefinition errors: CollectionDefinition inbox: CollectionDefinition @@ -1045,6 +1080,12 @@ export const builtInCollections: EntityCollectionsDefinition = { type: `tool_call`, primaryKey: `key`, }, + toolArgDeltas: { + schema: + BUILT_IN_EVENT_SCHEMAS.tool_arg_delta as StandardSchemaV1, + type: `tool_arg_delta`, + primaryKey: `key`, + }, reasoning: { schema: BUILT_IN_EVENT_SCHEMAS.reasoning as StandardSchemaV1, type: `reasoning`, diff --git a/packages/agents-runtime/src/entity-timeline.ts b/packages/agents-runtime/src/entity-timeline.ts index 0520982298..295ec24c73 100644 --- a/packages/agents-runtime/src/entity-timeline.ts +++ b/packages/agents-runtime/src/entity-timeline.ts @@ -38,7 +38,13 @@ export type EntityTimelineContentItem = toolCallId: string toolName: string args: Record - status: `started` | `args_complete` | `executing` | `completed` | `failed` + status: + | `started` + | `args_streaming` + | `args_complete` + | `executing` + | `completed` + | `failed` result?: string error?: string isError: boolean @@ -89,7 +95,13 @@ export interface IncludesToolCall { run_id: string order: TimelineOrder tool_name: string - status: `started` | `args_complete` | `executing` | `completed` | `failed` + status: + | `started` + | `args_streaming` + | `args_complete` + | `executing` + | `completed` + | `failed` args?: unknown result?: unknown error?: string @@ -202,7 +214,13 @@ export interface EntityTimelineToolCallItem { order: TimelineOrder tool_call_id?: string tool_name: string - status: `started` | `args_complete` | `executing` | `completed` | `failed` + status: + | `started` + | `args_streaming` + | `args_complete` + | `executing` + | `completed` + | `failed` args?: unknown result?: unknown error?: string diff --git a/packages/agents-runtime/src/outbound-bridge.ts b/packages/agents-runtime/src/outbound-bridge.ts index 2c81851df1..94995d95de 100644 --- a/packages/agents-runtime/src/outbound-bridge.ts +++ b/packages/agents-runtime/src/outbound-bridge.ts @@ -110,6 +110,18 @@ export interface OutboundBridge { onTextStart: () => void onTextDelta: (delta: string) => void onTextEnd: () => void + onToolCallArgsStart( + toolCallId: string, + name: string, + argsPreview?: unknown + ): void + onToolCallArgsDelta( + toolCallId: string, + name: string, + delta: string, + opts?: { contentIndex?: number; argsPreview?: unknown } + ): void + onToolCallArgsEnd(toolCallId: string, name: string, args: unknown): void onToolCallStart(toolCallId: string, name: string, args: unknown): void onToolCallStart(name: string, args: unknown): void onToolCallEnd( @@ -154,7 +166,7 @@ export function createOutboundBridge( let currentTextRunKey: string | null = null const toolCallsById = new Map< string, - { key: string; runKey: string; args: unknown } + { key: string; runKey: string; args: unknown; argSeq: number } >() const legacyToolCallIdsByName = new Map>() const requireActiveRun = (action: string): string => { @@ -165,6 +177,65 @@ export function createOutboundBridge( } return currentRunKey } + const ensureToolCall = ( + toolCallId: string, + name: string, + opts?: { + args?: unknown + argsPreview?: unknown + status?: `started` | `args_streaming` | `args_complete` | `executing` + } + ): { key: string; runKey: string; args: unknown; argSeq: number } => { + const runKey = requireActiveRun(`ensureToolCall`) + const existing = toolCallsById.get(toolCallId) + if (existing) { + if (opts && (`args` in opts || `argsPreview` in opts || opts.status)) { + const nextArgs = `args` in opts ? opts.args : existing.args + if (`args` in opts) existing.args = opts.args + writeEvent( + entityStateSchema.toolCalls.update({ + key: existing.key, + value: { + tool_call_id: toolCallId, + tool_name: name, + status: opts.status ?? `args_streaming`, + args: nextArgs, + ...(opts.argsPreview !== undefined && { + args_preview: opts.argsPreview, + }), + run_id: existing.runKey, + } as never, + }) as ChangeEvent + ) + } + return existing + } + const key = `tc-${counters.tc++}` + persistSeed() + const created = { + key, + runKey, + args: opts && `args` in opts ? opts.args : undefined, + argSeq: 0, + } + toolCallsById.set(toolCallId, created) + writeEvent( + entityStateSchema.toolCalls.insert({ + key, + value: { + tool_call_id: toolCallId, + tool_name: name, + status: opts?.status ?? `started`, + args: created.args, + ...(opts?.argsPreview !== undefined && { + args_preview: opts.argsPreview, + }), + run_id: runKey, + } as never, + }) as ChangeEvent + ) + return created + } return { onRunStart() { @@ -277,15 +348,61 @@ export function createOutboundBridge( ) }, + onToolCallArgsStart( + toolCallId: string, + name: string, + argsPreview?: unknown + ) { + ensureToolCall(toolCallId, name, { + status: `args_streaming`, + argsPreview, + }) + }, + + onToolCallArgsDelta( + toolCallId: string, + name: string, + delta: string, + opts?: { contentIndex?: number; argsPreview?: unknown } + ) { + const toolCall = + toolCallsById.get(toolCallId) ?? + ensureToolCall(toolCallId, name, { + status: `args_streaming`, + argsPreview: opts?.argsPreview, + }) + const seq = toolCall.argSeq++ + writeEvent( + entityStateSchema.toolArgDeltas.insert({ + key: `${toolCall.key}:args-${seq}`, + value: { + tool_call_key: toolCall.key, + tool_call_id: toolCallId, + run_id: toolCall.runKey, + seq, + delta, + ...(opts?.contentIndex !== undefined && { + content_index: opts.contentIndex, + }), + } as never, + }) as ChangeEvent + ) + }, + + onToolCallArgsEnd(toolCallId: string, name: string, args: unknown) { + ensureToolCall(toolCallId, name, { + status: `args_complete`, + args, + }) + }, + onToolCallStart( toolCallIdOrName: string, nameOrArgs: string | unknown, maybeArgs?: unknown ) { - const runKey = requireActiveRun(`onToolCallStart`) - const key = `tc-${counters.tc++}` const legacyCall = maybeArgs === undefined - const toolCallId = legacyCall ? key : toolCallIdOrName + const toolCallId = legacyCall ? `tc-${counters.tc}` : toolCallIdOrName const name = legacyCall ? toolCallIdOrName : (nameOrArgs as string) const args = legacyCall ? nameOrArgs : maybeArgs if (legacyCall) { @@ -293,20 +410,10 @@ export function createOutboundBridge( ids.push(toolCallId) legacyToolCallIdsByName.set(name, ids) } - persistSeed() - toolCallsById.set(toolCallId, { key, runKey, args }) - writeEvent( - entityStateSchema.toolCalls.insert({ - key, - value: { - tool_call_id: toolCallId, - tool_name: name, - status: `started`, - args, - run_id: runKey, - } as never, - }) as ChangeEvent - ) + ensureToolCall(toolCallId, name, { + status: `executing`, + args, + }) }, onToolCallEnd( diff --git a/packages/agents-runtime/src/pi-adapter.ts b/packages/agents-runtime/src/pi-adapter.ts index 71c4d0f99d..02918e995e 100644 --- a/packages/agents-runtime/src/pi-adapter.ts +++ b/packages/agents-runtime/src/pi-adapter.ts @@ -17,7 +17,6 @@ import type { ChangeEvent } from '@durable-streams/state' import type { AgentEvent, AgentMessage, - AgentTool, StreamFn, } from '@mariozechner/pi-agent-core' import type { @@ -26,7 +25,12 @@ import type { Provider, SimpleStreamOptions, } from '@mariozechner/pi-ai' -import type { LLMContentBlock, LLMMessage, LLMMessageContent } from './types' +import type { + AgentTool, + LLMContentBlock, + LLMMessage, + LLMMessageContent, +} from './types' // ============================================================================ // Options @@ -284,7 +288,24 @@ export function createPiAgentAdapter( case `message_update`: { const assistantEvent = (event as Record) .assistantMessageEvent as - | { type: string; delta?: string } + | { + type: string + contentIndex?: number + delta?: string + toolCall?: { + id?: string + name?: string + arguments?: Record + } + partial?: { + content?: Array<{ + type?: string + id?: string + name?: string + arguments?: Record + }> + } + } | undefined if (assistantEvent?.type === `text_delta`) { if (!textStarted) { @@ -293,6 +314,61 @@ export function createPiAgentAdapter( } bridge.onTextDelta(assistantEvent.delta ?? ``) textDeltaCount++ + } else if ( + assistantEvent?.type === `toolcall_start` || + assistantEvent?.type === `toolcall_delta` || + assistantEvent?.type === `toolcall_end` + ) { + const contentIndex = assistantEvent.contentIndex + const partialToolCall = + typeof contentIndex === `number` + ? assistantEvent.partial?.content?.[contentIndex] + : undefined + const toolCall = assistantEvent.toolCall ?? partialToolCall + const toolCallId = toolCall?.id + const toolName = toolCall?.name + const argsPreview = toolCall?.arguments + if (toolCallId && toolName) { + if (assistantEvent.type === `toolcall_start`) { + bridge.onToolCallArgsStart( + toolCallId, + toolName, + argsPreview + ) + } else if (assistantEvent.type === `toolcall_delta`) { + const delta = assistantEvent.delta ?? `` + bridge.onToolCallArgsDelta(toolCallId, toolName, delta, { + contentIndex, + argsPreview, + }) + const tool = opts.tools.find( + (candidate) => candidate.name === toolName + ) + if (tool?.onArgsDelta) { + void Promise.resolve( + tool.onArgsDelta({ + toolCallId, + toolName, + contentIndex, + delta, + argsPreview, + }) + ).catch((error) => { + runtimeLog.warn( + logPrefix, + `streaming tool arg hook failed for ${toolName}:`, + error + ) + }) + } + } else { + bridge.onToolCallArgsEnd( + toolCallId, + toolName, + argsPreview + ) + } + } } else { runtimeLog.debug( logPrefix, diff --git a/packages/agents-runtime/src/types.ts b/packages/agents-runtime/src/types.ts index 6ab10402f5..8fec5fe62e 100644 --- a/packages/agents-runtime/src/types.ts +++ b/packages/agents-runtime/src/types.ts @@ -385,6 +385,7 @@ export type TimelineItem = error: string | null status: | `started` + | `args_streaming` | `args_complete` | `executing` | `completed` @@ -877,7 +878,20 @@ export type AgentRunResult = { usage: { tokens: number; duration: number } } -export type AgentTool = PiAgentTool +export interface ToolArgumentDeltaContext { + toolCallId: string + toolName: string + contentIndex?: number + delta: string + argsPreview?: unknown +} + +export type AgentTool = PiAgentTool & { + onArgsDelta?: ( + context: ToolArgumentDeltaContext, + signal?: AbortSignal + ) => Promise | void +} export type AgentModel = string | Model export interface AgentConfig { diff --git a/packages/agents-runtime/test/outbound-bridge.test.ts b/packages/agents-runtime/test/outbound-bridge.test.ts index 0b8094b0ca..e7344d5c27 100644 --- a/packages/agents-runtime/test/outbound-bridge.test.ts +++ b/packages/agents-runtime/test/outbound-bridge.test.ts @@ -97,10 +97,64 @@ describe(`createOutboundBridge`, () => { expect((writes[1]!.value as Record).tool_name).toBe( `search` ) - expect((writes[1]!.value as Record).status).toBe(`started`) + expect((writes[1]!.value as Record).status).toBe( + `executing` + ) expect((writes[1]!.value as Record).run_id).toBe(`run-0`) }) + it(`persists streaming tool call argument deltas`, () => { + const writes: Array = [] + const bridge = createOutboundBridge([], (e) => { + writes.push(e) + }) + + bridge.onRunStart() + bridge.onToolCallArgsStart(`call-draft`, `draft`, { text: `He` }) + bridge.onToolCallArgsDelta(`call-draft`, `draft`, `llo`, { + contentIndex: 1, + argsPreview: { text: `Hello` }, + }) + bridge.onToolCallArgsEnd(`call-draft`, `draft`, { text: `Hello` }) + + expect(writes[1]).toMatchObject({ + type: `tool_call`, + key: `tc-0`, + headers: { operation: `insert` }, + value: { + tool_call_id: `call-draft`, + tool_name: `draft`, + status: `args_streaming`, + args_preview: { text: `He` }, + run_id: `run-0`, + }, + }) + expect(writes[2]).toMatchObject({ + type: `tool_arg_delta`, + key: `tc-0:args-0`, + value: { + tool_call_key: `tc-0`, + tool_call_id: `call-draft`, + run_id: `run-0`, + seq: 0, + delta: `llo`, + content_index: 1, + }, + }) + expect(writes[3]).toMatchObject({ + type: `tool_call`, + key: `tc-0`, + headers: { operation: `update` }, + value: { + tool_call_id: `call-draft`, + tool_name: `draft`, + status: `args_complete`, + args: { text: `Hello` }, + run_id: `run-0`, + }, + }) + }) + it(`maps tool_call_end to tool_call update with result`, () => { const writes: Array = [] const bridge = createOutboundBridge([], (e) => { diff --git a/packages/agents-runtime/test/pi-adapter.test.ts b/packages/agents-runtime/test/pi-adapter.test.ts index 3c33b6cb71..75b94d42c5 100644 --- a/packages/agents-runtime/test/pi-adapter.test.ts +++ b/packages/agents-runtime/test/pi-adapter.test.ts @@ -241,6 +241,141 @@ describe(`createPiAgentAdapter`, () => { ) }) + it(`dispatches streamed tool call arguments to the bridge and tool hook`, async () => { + let streamReadyResolve: + | ((stream: ReturnType) => void) + | null = null + const streamReady = new Promise< + ReturnType + >((resolve) => { + streamReadyResolve = resolve + }) + const partialMessage: AssistantMessage = { + role: `assistant`, + content: [ + { + type: `toolCall`, + id: `call-draft`, + name: `draft`, + arguments: { text: `Hello` }, + }, + ], + api: `anthropic-messages`, + provider: `anthropic`, + model: `claude-sonnet-4-5-20250929`, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + total: 0, + }, + }, + stopReason: `toolUse`, + timestamp: Date.now(), + } + const completedMessage: AssistantMessage = { + ...partialMessage, + content: [{ type: `text`, text: `` }], + stopReason: `stop`, + } + const argDeltas: Array = [] + const events: Array = [] + const factory = createPiAgentAdapter({ + systemPrompt: `Test system prompt`, + model: `claude-sonnet-4-5-20250929`, + tools: [ + { + name: `draft`, + label: `Draft`, + description: `Draft text`, + parameters: { + type: `object`, + properties: { text: { type: `string` } }, + required: [`text`], + } as never, + onArgsDelta: (context) => { + argDeltas.push(context) + }, + execute: async () => ({ + content: [{ type: `text`, text: `ok` }], + details: null, + }), + }, + ], + streamFn: () => { + const stream = createAssistantMessageEventStream() + streamReadyResolve?.(stream) + return stream + }, + }) + const handle = factory({ + entityUrl: `test/entity-1`, + epoch: 1, + messages: [], + outboundIdSeed: { run: 0, step: 0, msg: 0, tc: 0 }, + writeEvent: (event: ChangeEvent) => { + events.push(event) + }, + }) + + const runPromise = handle.run(`hello`) + const stream = await streamReady + stream.push({ + type: `start`, + partial: partialMessage, + }) + stream.push({ + type: `toolcall_start`, + contentIndex: 0, + partial: partialMessage, + }) + stream.push({ + type: `toolcall_delta`, + contentIndex: 0, + delta: `"Hello"`, + partial: partialMessage, + }) + stream.push({ + type: `toolcall_end`, + contentIndex: 0, + toolCall: partialMessage.content[0] as never, + partial: partialMessage, + }) + stream.push({ + type: `done`, + reason: `stop`, + message: completedMessage, + }) + await runPromise + + expect(argDeltas).toEqual([ + { + toolCallId: `call-draft`, + toolName: `draft`, + contentIndex: 0, + delta: `"Hello"`, + argsPreview: { text: `Hello` }, + }, + ]) + expect(events).toContainEqual( + expect.objectContaining({ + type: `tool_arg_delta`, + value: expect.objectContaining({ + tool_call_id: `call-draft`, + delta: `"Hello"`, + content_index: 0, + }), + }) + ) + }) + it(`isRunning returns false initially`, () => { const factory = createPiAgentAdapter({ systemPrompt: `Test system prompt`, From 636d7169f734c8dfe0bffcf395917c16523d5861 Mon Sep 17 00:00:00 2001 From: Sam Willis Date: Wed, 10 Jun 2026 19:00:01 +0100 Subject: [PATCH 2/2] Preserve legacy tool call start status --- packages/agents-runtime/src/outbound-bridge.ts | 3 ++- packages/agents-runtime/test/outbound-bridge.test.ts | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/agents-runtime/src/outbound-bridge.ts b/packages/agents-runtime/src/outbound-bridge.ts index 94995d95de..58b9be83be 100644 --- a/packages/agents-runtime/src/outbound-bridge.ts +++ b/packages/agents-runtime/src/outbound-bridge.ts @@ -410,8 +410,9 @@ export function createOutboundBridge( ids.push(toolCallId) legacyToolCallIdsByName.set(name, ids) } + const existing = toolCallsById.has(toolCallId) ensureToolCall(toolCallId, name, { - status: `executing`, + status: existing ? `executing` : `started`, args, }) }, diff --git a/packages/agents-runtime/test/outbound-bridge.test.ts b/packages/agents-runtime/test/outbound-bridge.test.ts index e7344d5c27..954f196e97 100644 --- a/packages/agents-runtime/test/outbound-bridge.test.ts +++ b/packages/agents-runtime/test/outbound-bridge.test.ts @@ -97,9 +97,7 @@ describe(`createOutboundBridge`, () => { expect((writes[1]!.value as Record).tool_name).toBe( `search` ) - expect((writes[1]!.value as Record).status).toBe( - `executing` - ) + expect((writes[1]!.value as Record).status).toBe(`started`) expect((writes[1]!.value as Record).run_id).toBe(`run-0`) })