diff --git a/cuopt-agent/cuopt_agent/configs/config-deepagent.yml b/cuopt-agent/cuopt_agent/configs/config-deepagent.yml
index 2de65c9..910b620 100755
--- a/cuopt-agent/cuopt_agent/configs/config-deepagent.yml
+++ b/cuopt-agent/cuopt_agent/configs/config-deepagent.yml
@@ -32,6 +32,11 @@ general:
# api_key: ${LANGSMITH_API_KEY}
front_end:
_type: fastapi
+ # Suppress intermediate_data: SSE lines on /v1/chat/completions (API Catalog expects data: chunks only).
+ # Workflow/tool traces remain available via Phoenix tracing. Use step_adaptor mode custom/default
+ # only for interactive UIs that consume intermediate_data.
+ step_adaptor:
+ mode: off
endpoints:
- path: /health
method: GET
diff --git a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py
index 7f879c7..eda9bfd 100755
--- a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py
+++ b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py
@@ -13,24 +13,267 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
+import json
import logging
import os
import re
+import uuid
+from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
+from typing import Literal
from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
-from nat.data_models.api_server import ChatRequest, ChatRequestOrMessage, ChatResponse, Usage
+from nat.data_models.api_server import (
+ ChatRequest,
+ ChatRequestOrMessage,
+ ChatResponse,
+ ChatResponseChunk,
+ Usage,
+ UserMessageContentRoleType,
+)
from nat.data_models.component_ref import FunctionRef, LLMRef
from nat.data_models.function import FunctionBaseConfig
from nat.utils.type_converter import GlobalTypeConverter
-from pydantic import Field
+from pydantic import Field, PrivateAttr
logger = logging.getLogger(__name__)
+# Built via concat so file tooling does not strip XML-like tag literals.
+_THINKING_OPEN_TAG = "<" + "redacted_thinking" + ">"
+_THINKING_CLOSE_TAG = "" + "redacted_thinking" + ">"
+_StreamKind = Literal["content", "reasoning"]
+
+_DEFAULT_STRIP_REASONING_PATTERN = (
+ rf"{_THINKING_OPEN_TAG}.*?{_THINKING_CLOSE_TAG}\s*|{_THINKING_OPEN_TAG}.*"
+)
+
+# Streaming tuning (not NAT workflow YAML keys — adjust here, not in config-deepagent.yml).
+_STREAM_MAX_SEGMENT_CHARS = 48
+_STREAM_PROGRESS_UPDATES = True
+
+
+def _split_partial_marker_suffix(text: str, marker: str) -> tuple[str, str]:
+ """Split *text* so a trailing prefix of *marker* is held back (incomplete tag)."""
+ if not text or not marker:
+ return text, ""
+ for k in range(min(len(text), len(marker) - 1), 0, -1):
+ if marker.startswith(text[-k:]):
+ return text[:-k], text[-k:]
+ return text, ""
+
+
+class _ThinkingTagParser:
+ """Split an LLM token stream into visible content vs minimax thinking blocks."""
+
+ def __init__(
+ self,
+ open_tag: str = _THINKING_OPEN_TAG,
+ close_tag: str = _THINKING_CLOSE_TAG,
+ ) -> None:
+ self._open_tag = open_tag
+ self._close_tag = close_tag
+ self._in_thinking = False
+ self._carry = ""
+
+ def feed(self, text: str) -> list[tuple[_StreamKind, str]]:
+ if not text:
+ return []
+ stream = self._carry + text
+ self._carry = ""
+ out: list[tuple[_StreamKind, str]] = []
+ i = 0
+ while i < len(stream):
+ if not self._in_thinking:
+ open_at = stream.find(self._open_tag, i)
+ if open_at == -1:
+ tail = stream[i:]
+ emit, self._carry = _split_partial_marker_suffix(tail, self._open_tag)
+ if emit:
+ out.append(("content", emit))
+ break
+ if open_at > i:
+ out.append(("content", stream[i:open_at]))
+ i = open_at + len(self._open_tag)
+ self._in_thinking = True
+ else:
+ close_at = stream.find(self._close_tag, i)
+ if close_at == -1:
+ tail = stream[i:]
+ emit, self._carry = _split_partial_marker_suffix(tail, self._close_tag)
+ if emit:
+ out.append(("reasoning", emit))
+ break
+ if close_at > i:
+ out.append(("reasoning", stream[i:close_at]))
+ i = close_at + len(self._close_tag)
+ self._in_thinking = False
+ return out
+
+ def flush(self) -> list[tuple[_StreamKind, str]]:
+ out: list[tuple[_StreamKind, str]] = []
+ if self._carry:
+ kind: _StreamKind = "reasoning" if self._in_thinking else "content"
+ out.append((kind, self._carry))
+ self._carry = ""
+ self._in_thinking = False
+ return out
+
+
+class _SegmentBuffer:
+ """Rolling buffer that emits fixed-size segments as data arrives."""
+
+ def __init__(self, max_chars: int) -> None:
+ self._max_chars = max(1, max_chars)
+ self._pending = ""
+
+ def push(self, text: str) -> list[str]:
+ if not text:
+ return []
+ self._pending += text
+ return self._take(full_only=True)
+
+ def finish(self) -> list[str]:
+ segments = self._take(full_only=False)
+ if self._pending:
+ segments.append(self._pending)
+ self._pending = ""
+ return segments
+
+ def _take(self, *, full_only: bool) -> list[str]:
+ segments: list[str] = []
+ while self._pending:
+ if full_only and len(self._pending) <= self._max_chars:
+ break
+ end = min(self._max_chars, len(self._pending))
+ if end < len(self._pending) and self._pending[end - 1] not in " \n\t":
+ boundary = self._pending.rfind(" ", 0, end)
+ if boundary > 0:
+ end = boundary + 1
+ segment = self._pending[:end]
+ if not segment:
+ segment = self._pending[:1]
+ end = 1
+ segments.append(segment)
+ self._pending = self._pending[end:]
+ return segments
+
+
+class _StreamSegmentEmitter:
+ """Parse thinking tags and emit segmented content / reasoning text pieces."""
+
+ def __init__(self, max_chars: int) -> None:
+ self._parser = _ThinkingTagParser()
+ self._content_buf = _SegmentBuffer(max_chars)
+ self._reasoning_buf = _SegmentBuffer(max_chars)
+
+ def feed(self, text: str) -> list[tuple[_StreamKind, str]]:
+ segments: list[tuple[_StreamKind, str]] = []
+ for kind, piece in self._parser.feed(text):
+ segments.extend(self._push_piece(kind, piece))
+ return segments
+
+ def finish(self) -> list[tuple[_StreamKind, str]]:
+ segments: list[tuple[_StreamKind, str]] = []
+ for kind, piece in self._parser.flush():
+ segments.extend(self._push_piece(kind, piece))
+ for segment in self._content_buf.finish():
+ segments.append(("content", segment))
+ for segment in self._reasoning_buf.finish():
+ segments.append(("reasoning", segment))
+ return segments
+
+ def _push_piece(self, kind: _StreamKind, piece: str) -> list[tuple[_StreamKind, str]]:
+ buf = self._reasoning_buf if kind == "reasoning" else self._content_buf
+ return [(kind, segment) for segment in buf.push(piece)]
+
+
+class _CatalogStreamChunk(ChatResponseChunk):
+ """``ChatResponseChunk`` that serializes ``delta.reasoning_content`` for API Catalog."""
+
+ _reasoning_delta: str | None = PrivateAttr(default=None)
+
+ def get_stream_data(self) -> str:
+ payload = json.loads(super().get_stream_data().removeprefix("data:").strip())
+ if self._reasoning_delta is not None:
+ choice = payload["choices"][0]
+ delta = dict(choice.get("delta") or {})
+ delta["reasoning_content"] = self._reasoning_delta
+ if not delta.get("content"):
+ delta.pop("content", None)
+ choice["delta"] = delta
+ return f"data: {json.dumps(payload, separators=(',', ':'))}\n\n"
+
+
+def _catalog_stream_chunk(
+ *,
+ stream_id: str,
+ created: datetime.datetime,
+ model: str,
+ content: str | None = None,
+ reasoning_content: str | None = None,
+ role: str | None = None,
+ finish_reason: str | None = None,
+ usage: Usage | None = None,
+) -> ChatResponseChunk:
+ base = ChatResponseChunk.create_streaming_chunk(
+ content if content is not None else "",
+ id_=stream_id,
+ created=created,
+ model=model,
+ role=role,
+ finish_reason=finish_reason,
+ usage=usage,
+ )
+ if not reasoning_content:
+ return base
+ chunk = _CatalogStreamChunk.model_validate(base.model_dump())
+ chunk._reasoning_delta = reasoning_content
+ return chunk
+
+
+class _StreamDeltaWriter:
+ """Map segmented stream pieces to API Catalog ``ChatResponseChunk`` objects."""
+
+ def __init__(
+ self,
+ *,
+ stream_id: str,
+ created: datetime.datetime,
+ model: str,
+ max_chars: int,
+ ) -> None:
+ self._stream_id = stream_id
+ self._created = created
+ self._model = model
+ self._emitter = _StreamSegmentEmitter(max_chars)
+
+ def feed(self, text: str) -> list[ChatResponseChunk]:
+ return [self._to_chunk(kind, segment) for kind, segment in self._emitter.feed(text)]
+
+ def finish(self) -> list[ChatResponseChunk]:
+ return [self._to_chunk(kind, segment) for kind, segment in self._emitter.finish()]
+
+ def _to_chunk(self, kind: _StreamKind, segment: str) -> ChatResponseChunk:
+ if kind == "reasoning":
+ return _catalog_stream_chunk(
+ stream_id=self._stream_id,
+ created=self._created,
+ model=self._model,
+ reasoning_content=segment,
+ )
+ return _catalog_stream_chunk(
+ stream_id=self._stream_id,
+ created=self._created,
+ model=self._model,
+ content=segment,
+ )
+
class DeepAgentConfig(FunctionBaseConfig, name="deepagent_fn"):
"""Langchain DeepAgents agent that delegates to subagents via create_deep_agent.
@@ -121,15 +364,13 @@ class DeepAgentConfig(FunctionBaseConfig, name="deepagent_fn"):
description="Maximum delay cap in seconds between retries.",
)
strip_reasoning_pattern: str = Field(
- default=r".*?\s*|.*",
+ default=_DEFAULT_STRIP_REASONING_PATTERN,
description=(
"Regex pattern (re.DOTALL) to strip from the final response. "
"Matches are removed before returning to the caller. "
"Set to empty string to disable stripping."
),
)
-
-
@register_function(config_type=DeepAgentConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def deep_agent(config: DeepAgentConfig, builder: Builder):
import psutil
@@ -183,33 +424,21 @@ async def deep_agent(config: DeepAgentConfig, builder: Builder):
# Workaround to strip reasoning patterns from the final response with minimax model
strip_re = re.compile(config.strip_reasoning_pattern, re.DOTALL) if config.strip_reasoning_pattern else None
- # Inner function that handles the agent invocation and response processing
- async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
- """Inner function that handles the agent invocation and response processing.
- Args:
- chat_request_or_message: The chat request or message to process.
- Returns:
- A chat response or string.
- """
- chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+ @asynccontextmanager
+ async def _agent_session(
+ chat_request: ChatRequest,
+ ) -> AsyncIterator[tuple[object, list]]:
+ """Yield (agent, messages_dict_list) inside a sandbox; cleans up child processes on exit."""
messages = [m.model_dump() for m in chat_request.messages]
-
- # Create a temporary sandbox directory for the agent
- # Note execute tool will create files on host, a more robust sandbox should be used for production.
with TemporaryDirectory() as sandbox_dir:
sandbox = Path(sandbox_dir)
-
populate_sandbox(sandbox, skills_src_dirs, agents_md_src, config.workspace_dirs)
-
- # Create a local shell backend for the agent
backend = LocalShellBackend(
root_dir=sandbox,
virtual_mode=True,
inherit_env=True,
env=env,
)
-
- # create subagent dictionaries
sub_agent_dicts: list[dict] = []
for ref in config.subagents:
fn = await builder.get_function(ref)
@@ -222,7 +451,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
"Resolved %d subagent(s): %s", len(sub_agent_dicts), [sa.get("name", "?") for sa in sub_agent_dicts]
)
- # Create a middleware chain for the agent to improve reliability and performance
middleware = [
FixToolNamesMiddleware(),
ToolRetryMiddleware(),
@@ -236,7 +464,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
),
]
- # Create a dictionary of agent configuration arguments, including subagents if configured
agent_kwargs: dict = dict(
tools=config.tools,
model=llm,
@@ -252,29 +479,208 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
agent_kwargs["memory"] = effective_memory
agent = create_deep_agent(**agent_kwargs)
-
- # Ensure child/orphaned processes are cleaned up
pre_children = {c.pid for c in psutil.Process().children(recursive=True)}
try:
- agent_result = await agent.ainvoke({"messages": messages})
-
- result_messages = agent_result["messages"]
- content = result_messages[-1].content if result_messages else ""
- content = strip_pattern(content, strip_re)
+ yield agent, messages
finally:
kill_orphaned_children(pre_children)
- # Calculate usage metrics
+ def _usage_for_content(chat_request: ChatRequest, content: str) -> Usage:
prompt_tokens = sum(len(str(m.content).split()) for m in chat_request.messages)
completion_tokens = len(content.split()) if content else 0
- usage = Usage(
+ return Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
- response = ChatResponse.from_string(content, usage=usage)
- if chat_request_or_message.is_string:
- return GlobalTypeConverter.get().convert(response, to_type=str)
- return response
- yield FunctionInfo.from_fn(_inner, description=config.description)
+ def _response_model(chat_request: ChatRequest) -> str:
+ return (chat_request.model or "").strip() or "unknown-model"
+
+ async def _single(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse:
+ """Non-streaming OpenAI chat completion (root JSON object, no ``value`` wrapper)."""
+ chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+ async with _agent_session(chat_request) as (agent, messages):
+ agent_result = await agent.ainvoke({"messages": messages})
+ result_messages = agent_result["messages"]
+ content = result_messages[-1].content if result_messages else ""
+ content = strip_pattern(content, strip_re)
+ usage = _usage_for_content(chat_request, content)
+ return ChatResponse.from_string(content, usage=usage, model=_response_model(chat_request))
+
+ def _extract_text_content(content: object) -> str:
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for block in content:
+ if isinstance(block, str):
+ parts.append(block)
+ elif isinstance(block, dict) and block.get("type") == "text":
+ parts.append(str(block.get("text", "")))
+ elif hasattr(block, "text"):
+ parts.append(str(block.text))
+ return "".join(parts)
+ return str(content)
+
+ def _namespace_tuple(ns: object) -> tuple:
+ if isinstance(ns, str):
+ return (ns,)
+ if ns is None:
+ return ()
+ return tuple(ns)
+
+ def _is_subagent_namespace(ns: tuple) -> bool:
+ return any(isinstance(s, str) and s.startswith("tools:") for s in ns)
+
+ def _message_token_text(token: object) -> str:
+ if getattr(token, "type", None) not in ("ai", None):
+ return ""
+ if getattr(token, "tool_call_chunks", None):
+ return ""
+ return _extract_text_content(getattr(token, "content", None))
+
+ def _progress_from_update(chunk: dict) -> str | None:
+ ns = _namespace_tuple(chunk.get("ns"))
+ if _is_subagent_namespace(ns):
+ return None
+ data = chunk.get("data")
+ if not isinstance(data, dict):
+ return None
+ if "tools" in data:
+ return "Running tools…\n"
+ if "model_request" in data and not ns:
+ return None
+ return None
+
+ async def _stream_llm_chunks(agent: object, messages: list) -> AsyncGenerator[str, None]:
+ """Yield main-agent assistant text (LLM tokens and optional progress lines)."""
+
+ async def _yield_from_astream_events() -> AsyncGenerator[str, None]:
+ astream_events = getattr(agent, "astream_events", None)
+ if astream_events is None:
+ return
+ async for event in astream_events({"messages": messages}, version="v2"):
+ if not isinstance(event, dict) or event.get("event") != "on_chat_model_stream":
+ continue
+ data = event.get("data") or {}
+ llm_chunk = data.get("chunk")
+ text = _message_token_text(llm_chunk)
+ if text:
+ yield text
+
+ emitted = False
+ try:
+ async for text in _yield_from_astream_events():
+ emitted = True
+ yield text
+ except Exception:
+ logger.debug("astream_events token stream unavailable", exc_info=True)
+
+ if emitted:
+ return
+
+ stream_modes: list[str] = ["messages"]
+ if _STREAM_PROGRESS_UPDATES:
+ stream_modes.append("updates")
+
+ try:
+ astream = agent.astream(
+ {"messages": messages},
+ stream_mode=stream_modes,
+ subgraphs=True,
+ version="v2",
+ )
+ except TypeError:
+ astream = agent.astream(
+ {"messages": messages},
+ stream_mode="messages",
+ subgraphs=True,
+ )
+
+ async for chunk in astream:
+ if not isinstance(chunk, dict):
+ continue
+ chunk_type = chunk.get("type")
+ ns = _namespace_tuple(chunk.get("ns"))
+
+ if chunk_type == "updates" and _STREAM_PROGRESS_UPDATES:
+ if _is_subagent_namespace(ns):
+ continue
+ progress = _progress_from_update(chunk)
+ if progress:
+ yield progress
+ continue
+
+ if chunk_type != "messages":
+ continue
+ if _is_subagent_namespace(ns):
+ continue
+ payload = chunk.get("data")
+ if not isinstance(payload, (list, tuple)) or len(payload) < 1:
+ continue
+ token = payload[0]
+ text = _message_token_text(token)
+ if text:
+ yield text
+
+ async def _stream(chat_request_or_message: ChatRequestOrMessage) -> AsyncGenerator[ChatResponseChunk, None]:
+ """OpenAI-style SSE chunks via NAT ``ChatResponseChunk`` (``data:`` lines when framed by NAT)."""
+ chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+ response_model = _response_model(chat_request)
+ stream_id = str(uuid.uuid4())
+ created = datetime.datetime.now(datetime.UTC)
+ assembled_raw: list[str] = []
+ writer = _StreamDeltaWriter(
+ stream_id=stream_id,
+ created=created,
+ model=response_model,
+ max_chars=_STREAM_MAX_SEGMENT_CHARS,
+ )
+
+ async with _agent_session(chat_request) as (agent, messages):
+ yield ChatResponseChunk.create_streaming_chunk(
+ "",
+ id_=stream_id,
+ created=created,
+ model=response_model,
+ role=UserMessageContentRoleType.ASSISTANT,
+ )
+ try:
+ async for text in _stream_llm_chunks(agent, messages):
+ assembled_raw.append(text)
+ for chunk in writer.feed(text):
+ yield chunk
+ except Exception:
+ logger.exception("Token streaming failed; falling back to buffered completion")
+ agent_result = await agent.ainvoke({"messages": messages})
+ result_messages = agent_result["messages"]
+ content = result_messages[-1].content if result_messages else ""
+ content = _extract_text_content(content)
+ assembled_raw.clear()
+ assembled_raw.append(content)
+ for chunk in writer.feed(content):
+ yield chunk
+
+ for chunk in writer.finish():
+ yield chunk
+
+ content = strip_pattern("".join(assembled_raw), strip_re)
+
+ usage = _usage_for_content(chat_request, content)
+ yield ChatResponseChunk.create_streaming_chunk(
+ "",
+ id_=stream_id,
+ created=created,
+ model=response_model,
+ finish_reason="stop",
+ usage=usage,
+ )
+
+ yield FunctionInfo.create(
+ single_fn=_single,
+ stream_fn=_stream,
+ description=config.description,
+ )