Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
GeneratedImageUpdated,
HiddenContextItem,
SDKHiddenContextItem,
StructuredInputItem,
Task,
TaskItem,
ThoughtTask,
Expand Down Expand Up @@ -552,14 +553,19 @@ def end_workflow(item: WorkflowItem):

if event.type == "run_item_stream_event":
event = event.item
if (
event.type == "tool_call_item"
and event.raw_item.type == "function_call"
):
current_tool_call = event.raw_item.call_id
current_item_id = event.raw_item.id
assert current_item_id
produced_items.add(current_item_id)
if event.type == "tool_call_item":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why these changes?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newer versions of openai-agents allow the raw item to be a dict, so handling that here (pyright caught it)!

raw_item = event.raw_item
if isinstance(raw_item, dict):
if raw_item.get("type") == "function_call":
current_tool_call = event.call_id
current_item_id = raw_item.get("id")
assert current_item_id
produced_items.add(current_item_id)
elif raw_item.type == "function_call":
current_tool_call = event.call_id
current_item_id = raw_item.id
assert current_item_id
produced_items.add(current_item_id)
Comment on lines +556 to +568
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newer versions of openai-agents allow the raw item to be a dict, so handling that here (pyright caught it).

continue

if event.type != "raw_response_event":
Expand Down Expand Up @@ -964,6 +970,39 @@ async def task_to_input(
role="user",
)

async def structured_input_to_input(
    self, item: StructuredInputItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Render a StructuredInputItem as a user message for the model.

    Each question in ``item.inputs`` becomes one bullet line reporting it as
    "unanswered" (no answer recorded), "skipped", or its answer values joined
    with ", ". The bullets are wrapped in a ``<StructuredInput>`` block that
    is preceded by the item's overall status.
    """

    def summarize(entry) -> str:
        # One bullet per question; the three outcomes are mutually exclusive.
        recorded = entry.answer
        if recorded is None:
            return f"- {entry.question}: unanswered"
        if recorded.skipped:
            return f"- {entry.question}: skipped"
        return f"- {entry.question}: {', '.join(recorded.values)}"

    rendered = "\n".join(summarize(entry) for entry in item.inputs)
    text = (
        "A structured input request was displayed to the user with the following "
        f"status: {item.status}\n<StructuredInput>\n"
        + rendered
        + "\n</StructuredInput>"
    )

    text_part = ResponseInputTextParam(type="input_text", text=text)
    return Message(type="message", content=[text_part], role="user")

async def workflow_to_input(
self, item: WorkflowItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
Expand Down Expand Up @@ -1172,6 +1211,9 @@ async def _thread_item_to_input_item(
case TaskItem():
out = await self.task_to_input(item) or []
return out if isinstance(out, list) else [out]
case StructuredInputItem():
out = await self.structured_input_to_input(item) or []
return out if isinstance(out, list) else [out]
case HiddenContextItem():
out = await self.hidden_context_to_input(item) or []
return out if isinstance(out, list) else [out]
Expand Down
82 changes: 82 additions & 0 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
StreamingReq,
StreamOptions,
StreamOptionsEvent,
StructuredInputAnswer,
StructuredInputItem,
StructuredInputMultipleChoice,
StructuredInputSubmission,
SyncCustomActionResponse,
Thread,
ThreadCreatedEvent,
Expand All @@ -63,6 +67,7 @@
ThreadItemUpdatedEvent,
ThreadMetadata,
ThreadsAddClientToolOutputReq,
ThreadsAddStructuredInputReq,
ThreadsAddUserMessageReq,
ThreadsCreateReq,
ThreadsCustomActionReq,
Expand Down Expand Up @@ -616,6 +621,41 @@ async def _process_streaming_impl(
):
yield event

case ThreadsAddStructuredInputReq():
thread = await self.store.load_thread(
request.params.thread_id, context=context
)
item = await self.store.load_item(
request.params.thread_id,
request.params.item_id,
context=context,
)
if not isinstance(item, StructuredInputItem):
raise ValueError(
f"Item {request.params.item_id} is not a StructuredInputItem"
)

updated_item = await self._apply_structured_input_submission(
item,
request.params.input,
)

async def stream_structured_input_response() -> AsyncIterator[
ThreadStreamEvent
]:
# Keep this replacement inside _process_events so it is persisted
# through the same event pipeline as other thread item replacements.
yield ThreadItemReplacedEvent(item=updated_item)
async for event in self.respond(thread, None, context):
yield event

async for event in self._process_events(
thread,
context,
stream_structured_input_response,
):
yield event

case ThreadsRetryAfterItemReq():
thread_metadata = await self.store.load_thread(
request.params.thread_id, context=context
Expand Down Expand Up @@ -715,6 +755,48 @@ async def _process_sync_custom_action(
)
)

async def _apply_structured_input_submission(
    self,
    item: StructuredInputItem,
    submission: StructuredInputSubmission,
) -> StructuredInputItem:
    """
    Return a deep copy of ``item`` with the submitted answers applied.

    The original ``item`` is left untouched. Submitted answers whose key does
    not match any question id on the item are logged and ignored. A question
    is recorded as skipped when the whole submission is skipped, when no
    answer was submitted for it, when its answer is flagged skipped, or when
    its answer has no values. For a multiple-choice question that does not
    allow multiple values, only the first submitted value is kept.

    Raises:
        ValueError: if ``item`` is not in the "pending" status.
    """
    if item.status != "pending":
        raise ValueError(f"Structured input item {item.id} is not pending")

    result = item.model_copy(deep=True)

    # Report (but tolerate) answers that reference questions we do not have.
    known_ids = {entry.id for entry in result.inputs}
    stray_ids = set(submission.answers) - known_ids
    if stray_ids:
        unknown = ", ".join(sorted(stray_ids))
        logger.warning(
            f"Structured input item {item.id} received unknown question id(s), ignoring: {unknown}"
        )

    whole_form_skipped = submission.status == "skipped"
    for entry in result.inputs:
        submitted = submission.answers.get(entry.id)
        has_usable_answer = (
            not whole_form_skipped
            and submitted is not None
            and not submitted.skipped
            and submitted.values
        )

        if has_usable_answer:
            chosen = submitted.values
            single_choice = (
                isinstance(entry, StructuredInputMultipleChoice)
                and not entry.multiple
            )
            if single_choice:
                # Single-select questions keep only the first submitted value.
                chosen = chosen[:1]
            entry.answer = StructuredInputAnswer(values=chosen)
        else:
            entry.answer = StructuredInputAnswer(skipped=True)

    result.status = submission.status
    return result

async def _cleanup_pending_client_tool_call(
self, thread: ThreadMetadata, context: TContext
) -> None:
Expand Down
96 changes: 96 additions & 0 deletions chatkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,39 @@ class ThreadAddClientToolOutputParams(BaseModel):
result: Any


class StructuredInputAnswerSubmission(BaseModel):
    """Client-submitted answer for a single structured input.

    Both fields have defaults, so an empty payload validates as an answer
    with no values that is not flagged as skipped.
    """

    # default_factory gives every instance its own empty list.
    values: list[str] = Field(default_factory=list)
    """Text answer values submitted for the structured input."""
    skipped: bool = False
    """Whether this structured input answer was skipped."""


class StructuredInputSubmission(BaseModel):
    """Client-submitted answers for a pending structured input item.

    Carries an overall status for the submission plus a mapping of
    per-input answers; both fields are optional in the payload.
    """

    # Only "answered" or "skipped" are accepted; omitted means "answered".
    status: Literal["answered", "skipped"] = "answered"
    """Overall status for the structured input submission."""
    # Omitted means an empty mapping (a fresh dict per instance).
    answers: dict[str, StructuredInputAnswerSubmission] = Field(default_factory=dict)
    """Answers keyed by structured input id."""


class ThreadsAddStructuredInputReq(BaseReq):
    """Request to submit answers for a pending structured input item.

    The request is identified by the fixed ``type`` literal
    "threads.add_structured_input"; the payload lives in ``params``.
    """

    type: Literal["threads.add_structured_input"] = "threads.add_structured_input"
    # NOTE(review): ThreadAddStructuredInputParams is declared after this class;
    # presumably the module relies on postponed annotation evaluation — confirm.
    params: ThreadAddStructuredInputParams


class ThreadAddStructuredInputParams(BaseModel):
    """Parameters for adding structured input to a thread.

    All three fields are required.
    """

    # Id of the thread that owns the structured input item.
    thread_id: str
    # Id of the thread item the submission applies to.
    item_id: str
    # The submitted answers and overall submission status.
    input: StructuredInputSubmission


class ThreadsCustomActionReq(BaseReq):
"""Request to execute a custom action within a thread."""

Expand Down Expand Up @@ -272,6 +305,7 @@ class ThreadDeleteParams(BaseModel):
ThreadsCreateReq
| ThreadsAddUserMessageReq
| ThreadsAddClientToolOutputReq
| ThreadsAddStructuredInputReq
| ThreadsRetryAfterItemReq
| ThreadsCustomActionReq
)
Expand Down Expand Up @@ -308,6 +342,7 @@ def is_streaming_req(request: ChatKitReq) -> TypeIs[StreamingReq]:
ThreadsAddUserMessageReq,
ThreadsRetryAfterItemReq,
ThreadsAddClientToolOutputReq,
ThreadsAddStructuredInputReq,
ThreadsCustomActionReq,
),
)
Expand Down Expand Up @@ -660,6 +695,66 @@ class GeneratedImageItem(ThreadItemBase):
image: GeneratedImage | None = None


class StructuredInputAnswer(BaseModel):
    """Answer recorded for a structured input.

    Both fields have defaults: an empty list of values and ``skipped=False``.
    """

    # default_factory gives every instance its own empty list.
    values: list[str] = Field(default_factory=list)
    """Text answer values recorded for the structured input."""
    skipped: bool = False
    """Whether this structured input answer was skipped."""


class StructuredInputBase(BaseModel):
    """Base fields shared by structured inputs.

    ``id`` and ``question`` are required; ``answer`` stays ``None`` until an
    answer is recorded.
    """

    id: str
    """Stable id for this structured input."""
    question: str
    """Question shown to the user."""
    # None means no answer has been recorded yet.
    answer: StructuredInputAnswer | None = None
    """Answer recorded for this structured input, if available."""


class StructuredInputMultipleChoiceOption(BaseModel):
    """Option shown for a multiple-choice structured input.

    An option carries a single required text ``value``.
    """

    value: str
    """Text value submitted when this option is selected."""


class StructuredInputMultipleChoice(StructuredInputBase):
    """Structured input answered by choosing one or more options.

    Extends the shared base fields with the list of options and a flag that
    controls whether more than one value may be submitted.
    """

    # Fixed tag identifying this variant of structured input.
    type: Literal["multiple_choice"] = "multiple_choice"
    # Required; there is no default list of options.
    options: list[StructuredInputMultipleChoiceOption]
    """Suggested choices to display before the freeform custom answer affordance."""
    # Single-select unless explicitly set to True.
    multiple: bool = False
    """Whether the user may submit more than one text value."""


class StructuredInputFreeform(StructuredInputBase):
    """Structured input answered with freeform text.

    Extends the shared base fields with an optional description.
    """

    # Fixed tag identifying this variant of structured input.
    type: Literal["freeform"] = "freeform"
    # Optional; None when no supporting text is provided.
    description: str | None = None
    """Supporting text shown with this input."""


# Union of the two structured input variants, discriminated on each
# variant's literal "type" field ("multiple_choice" or "freeform").
StructuredInput = Annotated[
    StructuredInputMultipleChoice | StructuredInputFreeform,
    Field(discriminator="type"),
]
"""Structured input variants supported by a structured input item."""


class StructuredInputItem(ThreadItemBase):
    """Thread item requesting structured input from the user.

    Holds the list of structured inputs to present and an overall status
    that starts out as "pending".
    """

    # Fixed tag identifying this kind of thread item.
    type: Literal["structured_input"] = "structured_input"
    # One of "pending", "answered" or "skipped"; new items default to "pending".
    status: Literal["pending", "answered", "skipped"] = "pending"
    # Required list of inputs; each entry is a multiple-choice or freeform variant.
    inputs: list[StructuredInput]


class TaskItem(ThreadItemBase):
"""Thread item containing a task."""

Expand Down Expand Up @@ -706,6 +801,7 @@ class SDKHiddenContextItem(ThreadItemBase):
| ClientToolCallItem
| WidgetItem
| GeneratedImageItem
| StructuredInputItem
| WorkflowItem
| TaskItem
| HiddenContextItem
Expand Down
Loading
Loading