Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/function-callable-methods.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": minor
---

Add a function-wrapper form for callable methods, allowing class fields like `greet = callable(fn, metadata)` in addition to the existing `@callable()` decorator.
2 changes: 1 addition & 1 deletion docs/agent-class.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class MyAgent extends Agent {

### RPC and Callable Methods

`agents` take Durable Objects RPC one step forward by implementing RPC through WebSockets, so clients can also call methods on the Agent directly. To make a method callable through WS, developers can use the `@callable` decorator. Methods can return a serializable value or a stream (when using `@callable({ stream: true })`).
`agents` take Durable Objects RPC one step forward by implementing RPC through WebSockets, so clients can also call methods on the Agent directly. To make a method callable through WS, developers can use the `@callable` decorator. Methods can return a serializable value or a stream (when using `@callable({ streaming: true })`).

```ts
class MyAgent extends Agent {
Expand Down
24 changes: 24 additions & 0 deletions docs/callable-methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,30 @@ export class CounterAgent extends Agent<Env, State> {
}
```

You can also wrap a function-valued class field with `callable()` if you do not
want to use decorators:

```typescript
import { Agent, callable, type StreamingResponse } from "agents";

export class CounterAgent extends Agent<Env, State> {
increment = callable((): number => {
this.setState({ ...this.state, count: this.state.count + 1 });
return this.state.count;
});

streamNumbers = callable(
async (stream: StreamingResponse, count: number): Promise<void> => {
for (let i = 0; i < count; i++) {
stream.send(i);
}
stream.end();
},
{ streaming: true, description: "Stream numbers" }
);
}
```

### Calling from the Client

There are two ways to call methods from the client:
Expand Down
140 changes: 117 additions & 23 deletions packages/agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ export type CallableMetadata = {
};

const callableMetadata = new WeakMap<Function, CallableMetadata>();
const agentContextWrappedMethods = new WeakSet<Function>();

/**
* Error class for SQL execution failures, containing the query that failed
Expand Down Expand Up @@ -452,20 +453,52 @@ export type SubAgentStub<T extends Agent> = {
: never;
};

type CallableDecorator = <This, Args extends unknown[], Return>(
target: (this: This, ...args: Args) => Return,
context: ClassMethodDecoratorContext
) => (this: This, ...args: Args) => Return;

function markCallable<T extends Function>(
target: T,
metadata: CallableMetadata
): T {
if (!callableMetadata.has(target)) {
callableMetadata.set(target, metadata);
}

return target;
}

/**
* Decorator that marks a method as callable by clients
* Marks a method as callable by clients.
* Use as a TC39 method decorator (`@callable()`) or as a function wrapper for
* class fields (`method = callable(fn, metadata)`).
* @param metadata Optional metadata about the callable method
*/
export function callable(metadata: CallableMetadata = {}) {
return function callableDecorator<This, Args extends unknown[], Return>(
target: (this: This, ...args: Args) => Return,
export function callable(metadata?: CallableMetadata): CallableDecorator;
export function callable<This, Args extends unknown[], Return>(
target: (this: This, ...args: Args) => Return,
metadata?: CallableMetadata
): (this: This, ...args: Args) => Return;
export function callable<This, Args extends unknown[], Return>(
metadataOrTarget:
| CallableMetadata
| ((this: This, ...args: Args) => Return) = {},
metadata: CallableMetadata = {}
): CallableDecorator | ((this: This, ...args: Args) => Return) {
if (typeof metadataOrTarget === "function") {
return markCallable(metadataOrTarget, metadata);
}

return function callableDecorator<
ThisDecorator,
ArgsDecorator extends unknown[],
ReturnDecorator
>(
target: (this: ThisDecorator, ...args: ArgsDecorator) => ReturnDecorator,
_context: ClassMethodDecoratorContext
) {
if (!callableMetadata.has(target)) {
callableMetadata.set(target, metadata);
}

return target;
return markCallable(target, metadataOrTarget);
};
}

Expand Down Expand Up @@ -1360,7 +1393,14 @@ function withAgentContext<T extends (...args: any[]) => any>(
this: Agent<Cloudflare.Env, unknown>,
...args: Parameters<T>
) => ReturnType<T> {
return function (...args: Parameters<T>): ReturnType<T> {
if (agentContextWrappedMethods.has(method)) {
return method;
}

const wrapped = function (
this: Agent<Cloudflare.Env, unknown>,
...args: Parameters<T>
): ReturnType<T> {
const { agent } = getCurrentAgent();

if (agent === this) {
Expand All @@ -1381,6 +1421,9 @@ function withAgentContext<T extends (...args: any[]) => any>(
}
);
};
agentContextWrappedMethods.add(wrapped);

return wrapped;
}

/**
Expand Down Expand Up @@ -2136,6 +2179,7 @@ export class Agent<
if (isRPCRequest(parsed)) {
try {
const { id, method, args } = parsed;
this._ensureOwnCallableMethodsWrapped();

// Check if method exists and is callable
const methodFn = this[method as keyof this];
Expand Down Expand Up @@ -3126,6 +3170,46 @@ export class Agent<
}
}

private _wrapCallableMethod<T extends Function>(method: T): T {
if (agentContextWrappedMethods.has(method)) {
return method;
}

const metadata = callableMetadata.get(method);

/* oxlint-disable @typescript-eslint/no-explicit-any -- dynamic method wrapping requires any */
const wrappedFunction = withAgentContext(
method as unknown as (...args: any[]) => any
) as unknown as T;
/* oxlint-enable @typescript-eslint/no-explicit-any */

if (metadata) {
callableMetadata.set(wrappedFunction, metadata);
}

return wrappedFunction;
}

private _ensureOwnCallableMethodsWrapped() {
for (const methodName of Object.getOwnPropertyNames(this)) {
const descriptor = Object.getOwnPropertyDescriptor(this, methodName);
if (
!descriptor ||
!!descriptor.get ||
typeof descriptor.value !== "function" ||
!callableMetadata.has(descriptor.value) ||
agentContextWrappedMethods.has(descriptor.value)
) {
continue;
}

Object.defineProperty(this, methodName, {
...descriptor,
value: this._wrapCallableMethod(descriptor.value as Function)
});
}
}

/**
* Automatically wrap custom methods with agent context
* This ensures getCurrentAgent() works in all custom methods without decorators
Expand Down Expand Up @@ -3165,19 +3249,12 @@ export class Agent<

// Now, methodName is confirmed to be a custom method/function
// Wrap the custom method with context
/* oxlint-disable @typescript-eslint/no-explicit-any -- dynamic method wrapping requires any */
const wrappedFunction = withAgentContext(
this[methodName as keyof this] as (...args: any[]) => any
) as any;
/* oxlint-enable @typescript-eslint/no-explicit-any */

// if the method is callable, copy the metadata from the original method
if (this._isCallable(methodName)) {
callableMetadata.set(
wrappedFunction,
callableMetadata.get(this[methodName as keyof this] as Function)!
);
}
const wrappedFunction = this._isCallable(methodName)
? this._wrapCallableMethod(this[methodName as keyof this] as Function)
: withAgentContext(
/* oxlint-disable-next-line @typescript-eslint/no-explicit-any -- dynamic method wrapping requires any */
this[methodName as keyof this] as (...args: any[]) => any
);

// set the wrapped function on the prototype
this.constructor.prototype[methodName as keyof this] = wrappedFunction;
Expand Down Expand Up @@ -9353,8 +9430,25 @@ export class Agent<
* @returns A map of method names to their metadata
*/
getCallableMethods(): Map<string, CallableMetadata> {
this._ensureOwnCallableMethodsWrapped();
const result = new Map<string, CallableMetadata>();

for (const name of Object.getOwnPropertyNames(this)) {
const descriptor = Object.getOwnPropertyDescriptor(this, name);
if (
!descriptor ||
!!descriptor.get ||
typeof descriptor.value !== "function"
) {
continue;
}

const meta = callableMetadata.get(descriptor.value as Function);
if (meta) {
result.set(name, meta);
}
}

// Walk the entire prototype chain to find callable methods from parent classes
let prototype = Object.getPrototypeOf(this);
while (prototype && prototype !== Object.prototype) {
Expand Down
22 changes: 22 additions & 0 deletions packages/agents/src/tests-d/agent-client-stub.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ class MyAgent extends Agent<typeof env, { count: number; name: string }> {
return `Hello, ${name ?? "World"}!`;
}

functionHello = callable((name?: string): string => {
return `Hello, ${name ?? "World"}!`;
});

functionWithThis = callable(function (this: MyAgent, amount: number): number {
return this.state.count + amount;
});

functionStream = callable(
async (_stream: StreamingResponse, _prompt: string): Promise<void> => {},
{ streaming: true }
);

@callable()
async perform(_task: string, _p1?: number): Promise<void> {}

Expand All @@ -124,6 +137,15 @@ const typedClient = new AgentClient<MyAgent>({
typedClient.stub.sayHello() satisfies Promise<string>;
// @ts-expect-error first argument is not a string
await typedClient.stub.sayHello(1);
typedClient.stub.functionHello("Ada") satisfies Promise<string>;
// @ts-expect-error first argument is not a string
await typedClient.stub.functionHello(1);
typedClient.stub.functionWithThis(5) satisfies Promise<number>;
// @ts-expect-error requires parameters
await typedClient.stub.functionWithThis();
typedClient.stub.functionStream("prompt") satisfies Promise<void>;
// @ts-expect-error StreamingResponse is injected server-side
await typedClient.stub.functionStream(undefined, "prompt");
await typedClient.stub.perform("some task", 1);
await typedClient.stub.perform("another task");
// @ts-expect-error requires parameters
Expand Down
42 changes: 41 additions & 1 deletion packages/agents/src/tests/agents/callable.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Agent, callable } from "../../index.ts";
import { Agent, callable, getCurrentAgent } from "../../index.ts";
import type { StreamingResponse } from "../../index.ts";

// Test Agent for @callable decorator tests
Expand All @@ -8,6 +8,46 @@ export class TestCallableAgent extends Agent<
> {
initialState = { value: 0 };

functionAdd = callable(
(a: number, b: number): number => {
return a + b;
},
{ description: "Function property add" }
);

functionAsync = callable(async (value: string): Promise<string> => {
await new Promise((r) => setTimeout(r, 10));
return `async:${value}`;
});

functionUsesThis = callable(function (
this: TestCallableAgent,
value: number
): number {
this.setState({ value });
return this.state.value;
});

functionContext = callable(
(): {
hasAgent: boolean;
hasConnection: boolean;
} => {
const { agent, connection } = getCurrentAgent<TestCallableAgent>();
return { hasAgent: agent === this, hasConnection: !!connection };
}
);

functionStream = callable(
(stream: StreamingResponse, count: number) => {
for (let i = 0; i < count; i++) {
stream.send(`function-${i}`);
}
stream.end("function-complete");
},
{ streaming: true, description: "Function property stream" }
);

// Basic sync method
@callable()
add(a: number, b: number): number {
Expand Down
Loading
Loading