Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/streaming-tool-call-args.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@electric-ax/agents-runtime": minor
---

Add runtime support for streaming tool call arguments from Pi model events.
43 changes: 42 additions & 1 deletion packages/agents-runtime/src/entity-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,28 @@ type ToolCallValue = {
run_id?: string
tool_call_id?: string
tool_name: string
status: `started` | `args_complete` | `executing` | `completed` | `failed`
status:
| `started`
| `args_streaming`
| `args_complete`
| `executing`
| `completed`
| `failed`
args?: unknown
args_preview?: unknown
result?: unknown
error?: string
duration_ms?: number
}
type ToolArgDeltaValue = {
key?: string
tool_call_key: string
tool_call_id?: string
run_id?: string
seq: number
delta: string
content_index?: number
}
type ReasoningValue = {
key?: string
status: `streaming` | `completed`
Expand Down Expand Up @@ -502,18 +518,33 @@ function createToolCallSchema(): Schema<ToolCallValue> {
tool_name: z.string(),
status: z.enum([
`started`,
`args_streaming`,
`args_complete`,
`executing`,
`completed`,
`failed`,
]),
args: z.unknown().optional(),
args_preview: z.unknown().optional(),
result: z.unknown().optional(),
error: z.string().optional(),
duration_ms: z.number().int().optional(),
})
}

function createToolArgDeltaSchema(): Schema<ToolArgDeltaValue> {
return z.object({
key: z.string().optional(),
...timelineOrderField,
tool_call_key: z.string(),
tool_call_id: z.string().optional(),
run_id: z.string().optional(),
seq: z.number().int(),
delta: z.string(),
content_index: z.number().int().optional(),
})
}

function createReasoningSchema(): Schema<ReasoningValue> {
return z.object({
key: z.string().optional(),
Expand Down Expand Up @@ -848,6 +879,7 @@ export type Step = SequencedPersistedRow<StepValue>
export type Text = SequencedPersistedRow<TextValue>
export type TextDelta = SequencedPersistedRow<TextDeltaValue>
export type ToolCall = SequencedPersistedRow<ToolCallValue>
export type ToolArgDelta = SequencedPersistedRow<ToolArgDeltaValue>
export type Reasoning = SequencedPersistedRow<ReasoningValue>
export type ErrorEvent = SequencedPersistedRow<ErrorEventValue>
export type MessageReceived = SequencedPersistedRow<MessageReceivedValue>
Expand Down Expand Up @@ -961,6 +993,8 @@ export const BUILT_IN_EVENT_SCHEMAS = {
text_delta:
createTextDeltaSchema() as unknown as BuiltInEntitySchema<TextDelta>,
tool_call: createToolCallSchema() as unknown as BuiltInEntitySchema<ToolCall>,
tool_arg_delta:
createToolArgDeltaSchema() as unknown as BuiltInEntitySchema<ToolArgDelta>,
reasoning:
createReasoningSchema() as unknown as BuiltInEntitySchema<Reasoning>,
error: createErrorEventSchema() as unknown as BuiltInEntitySchema<ErrorEvent>,
Expand Down Expand Up @@ -997,6 +1031,7 @@ type EntityCollectionsDefinition = {
texts: CollectionDefinition<Text>
textDeltas: CollectionDefinition<TextDelta>
toolCalls: CollectionDefinition<ToolCall>
toolArgDeltas: CollectionDefinition<ToolArgDelta>
reasoning: CollectionDefinition<Reasoning>
errors: CollectionDefinition<ErrorEvent>
inbox: CollectionDefinition<MessageReceived>
Expand Down Expand Up @@ -1045,6 +1080,12 @@ export const builtInCollections: EntityCollectionsDefinition = {
type: `tool_call`,
primaryKey: `key`,
},
toolArgDeltas: {
schema:
BUILT_IN_EVENT_SCHEMAS.tool_arg_delta as StandardSchemaV1<ToolArgDelta>,
type: `tool_arg_delta`,
primaryKey: `key`,
},
reasoning: {
schema: BUILT_IN_EVENT_SCHEMAS.reasoning as StandardSchemaV1<Reasoning>,
type: `reasoning`,
Expand Down
24 changes: 21 additions & 3 deletions packages/agents-runtime/src/entity-timeline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ export type EntityTimelineContentItem =
toolCallId: string
toolName: string
args: Record<string, unknown>
status: `started` | `args_complete` | `executing` | `completed` | `failed`
status:
| `started`
| `args_streaming`
| `args_complete`
| `executing`
| `completed`
| `failed`
result?: string
error?: string
isError: boolean
Expand Down Expand Up @@ -89,7 +95,13 @@ export interface IncludesToolCall {
run_id: string
order: TimelineOrder
tool_name: string
status: `started` | `args_complete` | `executing` | `completed` | `failed`
status:
| `started`
| `args_streaming`
| `args_complete`
| `executing`
| `completed`
| `failed`
args?: unknown
result?: unknown
error?: string
Expand Down Expand Up @@ -202,7 +214,13 @@ export interface EntityTimelineToolCallItem {
order: TimelineOrder
tool_call_id?: string
tool_name: string
status: `started` | `args_complete` | `executing` | `completed` | `failed`
status:
| `started`
| `args_streaming`
| `args_complete`
| `executing`
| `completed`
| `failed`
args?: unknown
result?: unknown
error?: string
Expand Down
144 changes: 126 additions & 18 deletions packages/agents-runtime/src/outbound-bridge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ export interface OutboundBridge {
onTextStart: () => void
onTextDelta: (delta: string) => void
onTextEnd: () => void
onToolCallArgsStart(
toolCallId: string,
name: string,
argsPreview?: unknown
): void
onToolCallArgsDelta(
toolCallId: string,
name: string,
delta: string,
opts?: { contentIndex?: number; argsPreview?: unknown }
): void
onToolCallArgsEnd(toolCallId: string, name: string, args: unknown): void
onToolCallStart(toolCallId: string, name: string, args: unknown): void
onToolCallStart(name: string, args: unknown): void
onToolCallEnd(
Expand Down Expand Up @@ -154,7 +166,7 @@ export function createOutboundBridge(
let currentTextRunKey: string | null = null
const toolCallsById = new Map<
string,
{ key: string; runKey: string; args: unknown }
{ key: string; runKey: string; args: unknown; argSeq: number }
>()
const legacyToolCallIdsByName = new Map<string, Array<string>>()
const requireActiveRun = (action: string): string => {
Expand All @@ -165,6 +177,65 @@ export function createOutboundBridge(
}
return currentRunKey
}
const ensureToolCall = (
toolCallId: string,
name: string,
opts?: {
args?: unknown
argsPreview?: unknown
status?: `started` | `args_streaming` | `args_complete` | `executing`
}
): { key: string; runKey: string; args: unknown; argSeq: number } => {
const runKey = requireActiveRun(`ensureToolCall`)
const existing = toolCallsById.get(toolCallId)
if (existing) {
if (opts && (`args` in opts || `argsPreview` in opts || opts.status)) {
const nextArgs = `args` in opts ? opts.args : existing.args
if (`args` in opts) existing.args = opts.args
writeEvent(
entityStateSchema.toolCalls.update({
key: existing.key,
value: {
tool_call_id: toolCallId,
tool_name: name,
status: opts.status ?? `args_streaming`,
args: nextArgs,
...(opts.argsPreview !== undefined && {
args_preview: opts.argsPreview,
}),
run_id: existing.runKey,
} as never,
}) as ChangeEvent
)
}
return existing
}
const key = `tc-${counters.tc++}`
persistSeed()
const created = {
key,
runKey,
args: opts && `args` in opts ? opts.args : undefined,
argSeq: 0,
}
toolCallsById.set(toolCallId, created)
writeEvent(
entityStateSchema.toolCalls.insert({
key,
value: {
tool_call_id: toolCallId,
tool_name: name,
status: opts?.status ?? `started`,
args: created.args,
...(opts?.argsPreview !== undefined && {
args_preview: opts.argsPreview,
}),
run_id: runKey,
} as never,
}) as ChangeEvent
)
return created
}

return {
onRunStart() {
Expand Down Expand Up @@ -277,36 +348,73 @@ export function createOutboundBridge(
)
},

onToolCallArgsStart(
toolCallId: string,
name: string,
argsPreview?: unknown
) {
ensureToolCall(toolCallId, name, {
status: `args_streaming`,
argsPreview,
})
},

onToolCallArgsDelta(
toolCallId: string,
name: string,
delta: string,
opts?: { contentIndex?: number; argsPreview?: unknown }
) {
const toolCall =
toolCallsById.get(toolCallId) ??
ensureToolCall(toolCallId, name, {
status: `args_streaming`,
argsPreview: opts?.argsPreview,
})
const seq = toolCall.argSeq++
writeEvent(
entityStateSchema.toolArgDeltas.insert({
key: `${toolCall.key}:args-${seq}`,
value: {
tool_call_key: toolCall.key,
tool_call_id: toolCallId,
run_id: toolCall.runKey,
seq,
delta,
...(opts?.contentIndex !== undefined && {
content_index: opts.contentIndex,
}),
} as never,
}) as ChangeEvent
)
},

onToolCallArgsEnd(toolCallId: string, name: string, args: unknown) {
ensureToolCall(toolCallId, name, {
status: `args_complete`,
args,
})
},

onToolCallStart(
toolCallIdOrName: string,
nameOrArgs: string | unknown,
maybeArgs?: unknown
) {
const runKey = requireActiveRun(`onToolCallStart`)
const key = `tc-${counters.tc++}`
const legacyCall = maybeArgs === undefined
const toolCallId = legacyCall ? key : toolCallIdOrName
const toolCallId = legacyCall ? `tc-${counters.tc}` : toolCallIdOrName
const name = legacyCall ? toolCallIdOrName : (nameOrArgs as string)
const args = legacyCall ? nameOrArgs : maybeArgs
if (legacyCall) {
const ids = legacyToolCallIdsByName.get(name) ?? []
ids.push(toolCallId)
legacyToolCallIdsByName.set(name, ids)
}
persistSeed()
toolCallsById.set(toolCallId, { key, runKey, args })
writeEvent(
entityStateSchema.toolCalls.insert({
key,
value: {
tool_call_id: toolCallId,
tool_name: name,
status: `started`,
args,
run_id: runKey,
} as never,
}) as ChangeEvent
)
const existing = toolCallsById.has(toolCallId)
ensureToolCall(toolCallId, name, {
status: existing ? `executing` : `started`,
args,
})
},

onToolCallEnd(
Expand Down
Loading
Loading