Skip to content

Commit d27da98

Browse files
committed
feat: introduce ToolOptions and update tool execution context
1 parent 11b9e71 commit d27da98

File tree

9 files changed

+137
-69
lines changed

9 files changed

+137
-69
lines changed

packages/typescript/ai-client/src/chat-client.ts

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@ import type {
1010
ToolCallPart,
1111
UIMessage,
1212
} from './types'
13-
import type { AnyClientTool, ModelMessage, StreamChunk } from '@tanstack/ai'
13+
import type {
14+
AnyClientTool,
15+
ModelMessage,
16+
StreamChunk,
17+
ToolOptions,
18+
} from '@tanstack/ai'
1419
import type { ConnectionAdapter } from './connection-adapters'
1520
import type { ChatClientEventEmitter } from './events'
1621

@@ -29,7 +34,7 @@ export class ChatClient<
2934
private clientToolsRef: { current: Map<string, AnyClientTool> }
3035
private currentStreamId: string | null = null
3136
private currentMessageId: string | null = null
32-
private context?: TContext
37+
private options: Partial<ToolOptions<TContext>>
3338

3439
private callbacksRef: {
3540
current: {
@@ -46,7 +51,7 @@ export class ChatClient<
4651
constructor(options: ChatClientOptions<TTools, TContext>) {
4752
this.uniqueId = options.id || this.generateUniqueId('chat')
4853
this.body = options.body || {}
49-
this.context = options.context
54+
this.options = { context: options.context }
5055
this.connection = options.connection
5156
this.events = new DefaultChatClientEventEmitter(this.uniqueId)
5257

@@ -140,7 +145,9 @@ export class ChatClient<
140145
const clientTool = this.clientToolsRef.current.get(args.toolName)
141146
if (clientTool?.execute) {
142147
try {
143-
const output = await clientTool.execute(args.input, this.context)
148+
const output = await clientTool.execute(args.input, {
149+
context: this.options.context,
150+
})
144151
await this.addToolResult({
145152
toolCallId: args.toolCallId,
146153
tool: args.toolName,
@@ -307,7 +314,9 @@ export class ChatClient<
307314
const bodyWithConversationId = {
308315
...this.body,
309316
conversationId: this.uniqueId,
310-
...(this.context !== undefined && { context: this.context }),
317+
...(this.options.context !== undefined && {
318+
context: this.options.context,
319+
}),
311320
}
312321

313322
// Connect and stream

packages/typescript/ai-client/tests/chat-client.test.ts

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
createThinkingChunks,
99
createToolCallChunks,
1010
} from './test-utils'
11+
import type { ToolOptions } from '@tanstack/ai'
1112
import type { UIMessage } from '../src/types'
1213

1314
describe('ChatClient', () => {
@@ -608,17 +609,19 @@ describe('ChatClient', () => {
608609
localStorage: mockStorage,
609610
}
610611

611-
const executeFn = vi.fn(async (_args: any, context?: unknown) => {
612-
const ctx = context as TestContext | undefined
613-
if (ctx) {
612+
const executeFn = vi.fn(
613+
async <TContext = unknown>(
614+
_args: any,
615+
options: ToolOptions<TContext>,
616+
) => {
617+
const ctx = options.context as TestContext
614618
ctx.localStorage.setItem(
615619
`pref_${ctx.userId}_${_args.key}`,
616620
_args.value,
617621
)
618622
return { success: true }
619-
}
620-
return { success: false }
621-
})
623+
},
624+
)
622625

623626
const toolDef = toolDefinition({
624627
name: 'savePreference',
@@ -658,7 +661,7 @@ describe('ChatClient', () => {
658661
expect(executeFn).toHaveBeenCalled()
659662
const lastCall = executeFn.mock.calls[0]
660663
expect(lastCall?.[0]).toEqual({ key: 'theme', value: 'dark' })
661-
expect(lastCall?.[1]).toEqual(testContext)
664+
expect(lastCall?.[1]).toEqual({ context: testContext })
662665

663666
// localStorage should have been called
664667
expect(mockStorage.setItem).toHaveBeenCalledWith('pref_123_theme', 'dark')
@@ -721,7 +724,10 @@ describe('ChatClient', () => {
721724
await client.sendMessage('Test')
722725

723726
// Tool should have been called without context
724-
expect(executeFn).toHaveBeenCalledWith({ value: 'test' }, undefined)
727+
expect(executeFn).toHaveBeenCalledWith(
728+
{ value: 'test' },
729+
{ context: undefined },
730+
)
725731
})
726732
})
727733
})

packages/typescript/ai/src/core/chat.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import type {
1616
StreamChunk,
1717
Tool,
1818
ToolCall,
19+
ToolOptions,
1920
} from '../types'
2021

2122
interface ChatEngineConfig<
@@ -45,7 +46,7 @@ class ChatEngine<
4546
private readonly streamId: string
4647
private readonly effectiveRequest?: Request | RequestInit
4748
private readonly effectiveSignal?: AbortSignal
48-
private readonly context?: TParams['context']
49+
private readonly options: Partial<ToolOptions<TParams['context']>>
4950

5051
private messages: Array<ModelMessage>
5152
private iterationCount = 0
@@ -76,7 +77,7 @@ class ChatEngine<
7677
? { signal: config.params.abortController.signal }
7778
: undefined
7879
this.effectiveSignal = config.params.abortController?.signal
79-
this.context = config.params.context
80+
this.options = { context: config.params.context }
8081
}
8182

8283
async *chat(): AsyncGenerator<StreamChunk> {
@@ -383,6 +384,7 @@ class ChatEngine<
383384
this.tools,
384385
approvals,
385386
clientToolResults,
387+
this.options,
386388
)
387389

388390
if (
@@ -451,7 +453,7 @@ class ChatEngine<
451453
this.tools,
452454
approvals,
453455
clientToolResults,
454-
this.context,
456+
this.options,
455457
)
456458

457459
if (

packages/typescript/ai/src/tools/tool-calls.ts

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type {
33
ModelMessage,
44
Tool,
55
ToolCall,
6+
ToolOptions,
67
ToolResultStreamChunk,
78
} from '../types'
89

@@ -32,7 +33,7 @@ import type {
3233
*
3334
* // After stream completes, execute tools
3435
* if (manager.hasToolCalls()) {
35-
* const toolResults = yield* manager.executeTools(doneChunk);
36+
* const toolResults = yield* manager.executeTools(doneChunk, { context });
3637
* messages = [...messages, ...toolResults];
3738
* manager.clear();
3839
* }
@@ -110,7 +111,7 @@ export class ToolCallManager {
110111
*/
111112
async *executeTools(
112113
doneChunk: DoneStreamChunk,
113-
context?: unknown,
114+
options: Partial<ToolOptions<unknown>> = {},
114115
): AsyncGenerator<ToolResultStreamChunk, Array<ModelMessage>, void> {
115116
const toolCallsArray = this.getToolCalls()
116117
const toolResults: Array<ModelMessage> = []
@@ -142,8 +143,10 @@ export class ToolCallManager {
142143
}
143144
}
144145

145-
// Execute the tool with context if available
146-
let result = await tool.execute(args, context)
146+
// Execute the tool with options
147+
let result = await tool.execute(args, {
148+
context: options.context as any,
149+
})
147150

148151
// Validate output against outputSchema if provided
149152
if (tool.outputSchema && result !== undefined && result !== null) {
@@ -239,14 +242,14 @@ interface ExecuteToolCallsResult {
239242
* @param tools - Available tools with their configurations
240243
* @param approvals - Map of approval decisions (approval.id -> approved boolean)
241244
* @param clientResults - Map of client-side execution results (toolCallId -> result)
242-
* @param context - Optional context object to pass to tool execute functions
245+
* @param options - Options object containing context to pass to tool execute functions
243246
*/
244247
export async function executeToolCalls(
245248
toolCalls: Array<ToolCall>,
246249
tools: ReadonlyArray<Tool>,
247250
approvals: Map<string, boolean> = new Map(),
248251
clientResults: Map<string, any> = new Map(),
249-
context?: unknown,
252+
options: Partial<ToolOptions<unknown>> = {},
250253
): Promise<ExecuteToolCallsResult> {
251254
const results: Array<ToolResult> = []
252255
const needsApproval: Array<ApprovalRequest> = []
@@ -378,7 +381,9 @@ export async function executeToolCalls(
378381
// Execute after approval
379382
const startTime = Date.now()
380383
try {
381-
let result = await tool.execute(input, context)
384+
let result = await tool.execute(input, {
385+
context: options.context as any,
386+
})
382387
const duration = Date.now() - startTime
383388

384389
// Validate output against outputSchema if provided
@@ -436,7 +441,9 @@ export async function executeToolCalls(
436441
// CASE 3: Normal server tool - execute immediately
437442
const startTime = Date.now()
438443
try {
439-
let result = await tool.execute(input, context)
444+
let result = await tool.execute(input, {
445+
context: options.context as any,
446+
})
440447
const duration = Date.now() - startTime
441448

442449
// Validate output against outputSchema if provided

packages/typescript/ai/src/tools/tool-definition.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { z } from 'zod'
2-
import type { Tool } from '../types'
2+
import type { Tool, ToolOptions } from '../types'
33

44
/**
55
* Marker type for server-side tools
@@ -33,7 +33,7 @@ export interface ClientTool<
3333
metadata?: Record<string, any>
3434
execute?: (
3535
args: z.infer<TInput>,
36-
context?: TContext,
36+
options: ToolOptions<TContext>,
3737
) => Promise<z.infer<TOutput>> | z.infer<TOutput>
3838
}
3939

@@ -108,7 +108,7 @@ export interface ToolDefinition<
108108
server: <TContext = unknown>(
109109
execute: (
110110
args: z.infer<TInput>,
111-
context?: TContext,
111+
options: ToolOptions<TContext>,
112112
) => Promise<z.infer<TOutput>> | z.infer<TOutput>,
113113
) => ServerTool<TInput, TOutput, TName, TContext>
114114

@@ -118,7 +118,7 @@ export interface ToolDefinition<
118118
client: <TContext = unknown>(
119119
execute?: (
120120
args: z.infer<TInput>,
121-
context?: TContext,
121+
options: ToolOptions<TContext>,
122122
) => Promise<z.infer<TOutput>> | z.infer<TOutput>,
123123
) => ClientTool<TInput, TOutput, TName, TContext>
124124
}
@@ -187,7 +187,7 @@ export function toolDefinition<
187187
server<TContext = unknown>(
188188
execute: (
189189
args: z.infer<TInput>,
190-
context?: TContext,
190+
options: ToolOptions<TContext>,
191191
) => Promise<z.infer<TOutput>> | z.infer<TOutput>,
192192
): ServerTool<TInput, TOutput, TName, TContext> {
193193
return {
@@ -201,7 +201,7 @@ export function toolDefinition<
201201
client<TContext = unknown>(
202202
execute?: (
203203
args: z.infer<TInput>,
204-
context?: TContext,
204+
options: ToolOptions<TContext>,
205205
) => Promise<z.infer<TOutput>> | z.infer<TOutput>,
206206
): ClientTool<TInput, TOutput, TName, TContext> {
207207
return {

packages/typescript/ai/src/types.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ export interface ToolCall {
1111
}
1212
}
1313

14+
/**
15+
* Options object passed to tool execute functions
16+
* @template TContext - The type of context object
17+
*/
18+
export interface ToolOptions<TContext = unknown> {
19+
context: TContext
20+
}
21+
1422
// ============================================================================
1523
// Multimodal Content Types
1624
// ============================================================================
@@ -324,7 +332,7 @@ export interface Tool<
324332
* Can return any value - will be automatically stringified if needed.
325333
*
326334
* @param args - The arguments parsed from the model's tool call (validated against inputSchema)
327-
* @param context - Optional context object passed from chat() options (if provided)
335+
* @param options - Optional options object passed from chat() options (if provided)
328336
* @returns Result to send back to the model (validated against outputSchema if provided)
329337
*
330338
* @example
@@ -335,14 +343,14 @@ export interface Tool<
335343
* }
336344
*
337345
* // With context:
338-
* execute: async (args, context) => {
339-
* const user = await context.db.users.find({ id: context.userId });
346+
* execute: async (args, options) => {
347+
* const user = await options.context.db.users.find({ id: options.context.userId });
340348
* return user;
341349
* }
342350
*/
343351
execute?: <TContext = unknown>(
344352
args: any,
345-
context?: TContext,
353+
options: ToolOptions<TContext>,
346354
) => Promise<any> | any
347355

348356
/** If true, tool execution requires user approval before running. Works with both server and client tools. */
@@ -529,9 +537,9 @@ export interface ChatOptions<
529537
* });
530538
*
531539
* // In tool definition:
532-
* const getUserData = getUserDataDef.server(async (args, context) => {
533-
* // context.userId and context.db are available
534-
* return await context.db.users.find({ userId: context.userId });
540+
* const getUserData = getUserDataDef.server(async (args, options) => {
541+
* // options.context.userId and options.context.db are available
542+
* return await options.context.db.users.find({ userId: options.context.userId });
535543
* });
536544
*/
537545
context?: TContext

0 commit comments

Comments
 (0)