diff --git a/cuopt-agent/cuopt_agent/configs/config-deepagent.yml b/cuopt-agent/cuopt_agent/configs/config-deepagent.yml index 2de65c9..910b620 100755 --- a/cuopt-agent/cuopt_agent/configs/config-deepagent.yml +++ b/cuopt-agent/cuopt_agent/configs/config-deepagent.yml @@ -32,6 +32,11 @@ general: # api_key: ${LANGSMITH_API_KEY} front_end: _type: fastapi + # Suppress intermediate_data: SSE lines on /v1/chat/completions (API Catalog expects data: chunks only). + # Workflow/tool traces remain available via Phoenix tracing. Use step_adaptor mode custom/default + # only for interactive UIs that consume intermediate_data. + step_adaptor: + mode: off endpoints: - path: /health method: GET diff --git a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py index 7f879c7..eda9bfd 100755 --- a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py +++ b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py @@ -13,24 +13,267 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime +import json import logging import os import re +import uuid +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager from pathlib import Path from tempfile import TemporaryDirectory +from typing import Literal from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.builder.function_info import FunctionInfo from nat.cli.register_workflow import register_function -from nat.data_models.api_server import ChatRequest, ChatRequestOrMessage, ChatResponse, Usage +from nat.data_models.api_server import ( + ChatRequest, + ChatRequestOrMessage, + ChatResponse, + ChatResponseChunk, + Usage, + UserMessageContentRoleType, +) from nat.data_models.component_ref import FunctionRef, LLMRef from nat.data_models.function import FunctionBaseConfig from nat.utils.type_converter import GlobalTypeConverter -from pydantic import Field +from pydantic import Field, PrivateAttr logger = logging.getLogger(__name__) +# Built via concat so file tooling does not strip XML-like tag literals. +_THINKING_OPEN_TAG = "<" + "redacted_thinking" + ">" +_THINKING_CLOSE_TAG = "" +_StreamKind = Literal["content", "reasoning"] + +_DEFAULT_STRIP_REASONING_PATTERN = ( + rf"{_THINKING_OPEN_TAG}.*?{_THINKING_CLOSE_TAG}\s*|{_THINKING_OPEN_TAG}.*" +) + +# Streaming tuning (not NAT workflow YAML keys — adjust here, not in config-deepagent.yml). +_STREAM_MAX_SEGMENT_CHARS = 48 +_STREAM_PROGRESS_UPDATES = True + + +def _split_partial_marker_suffix(text: str, marker: str) -> tuple[str, str]: + """Split *text* so a trailing prefix of *marker* is held back (incomplete tag).""" + if not text or not marker: + return text, "" + for k in range(min(len(text), len(marker) - 1), 0, -1): + if marker.startswith(text[-k:]): + return text[:-k], text[-k:] + return text, "" + + +class _ThinkingTagParser: + """Split an LLM token stream into visible content vs minimax thinking blocks.""" + + def __init__( + self, + open_tag: str = _THINKING_OPEN_TAG, + close_tag: str = _THINKING_CLOSE_TAG, + ) -> None: + self._open_tag = open_tag + self._close_tag = close_tag + self._in_thinking = False + self._carry = "" + + def feed(self, text: str) -> list[tuple[_StreamKind, str]]: + if not text: + return [] + stream = self._carry + text + self._carry = "" + out: list[tuple[_StreamKind, str]] = [] + i = 0 + while i < len(stream): + if not self._in_thinking: + open_at = stream.find(self._open_tag, i) + if open_at == -1: + tail = stream[i:] + emit, self._carry = _split_partial_marker_suffix(tail, self._open_tag) + if emit: + out.append(("content", emit)) + break + if open_at > i: + out.append(("content", stream[i:open_at])) + i = open_at + len(self._open_tag) + self._in_thinking = True + else: + close_at = stream.find(self._close_tag, i) + if close_at == -1: + tail = stream[i:] + emit, self._carry = _split_partial_marker_suffix(tail, self._close_tag) + if emit: + out.append(("reasoning", emit)) + break + if close_at > i: + out.append(("reasoning", stream[i:close_at])) + i = close_at + len(self._close_tag) + self._in_thinking = False + return out + + def flush(self) -> list[tuple[_StreamKind, str]]: + out: list[tuple[_StreamKind, str]] = [] + if self._carry: + kind: _StreamKind = "reasoning" if self._in_thinking else "content" + out.append((kind, self._carry)) + self._carry = "" + self._in_thinking = False + return out + + +class _SegmentBuffer: + """Rolling buffer that emits fixed-size segments as data arrives.""" + + def __init__(self, max_chars: int) -> None: + self._max_chars = max(1, max_chars) + self._pending = "" + + def push(self, text: str) -> list[str]: + if not text: + return [] + self._pending += text + return self._take(full_only=True) + + def finish(self) -> list[str]: + segments = self._take(full_only=False) + if self._pending: + segments.append(self._pending) + self._pending = "" + return segments + + def _take(self, *, full_only: bool) -> list[str]: + segments: list[str] = [] + while self._pending: + if full_only and len(self._pending) <= self._max_chars: + break + end = min(self._max_chars, len(self._pending)) + if end < len(self._pending) and self._pending[end - 1] not in " \n\t": + boundary = self._pending.rfind(" ", 0, end) + if boundary > 0: + end = boundary + 1 + segment = self._pending[:end] + if not segment: + segment = self._pending[:1] + end = 1 + segments.append(segment) + self._pending = self._pending[end:] + return segments + + +class _StreamSegmentEmitter: + """Parse thinking tags and emit segmented content / reasoning text pieces.""" + + def __init__(self, max_chars: int) -> None: + self._parser = _ThinkingTagParser() + self._content_buf = _SegmentBuffer(max_chars) + self._reasoning_buf = _SegmentBuffer(max_chars) + + def feed(self, text: str) -> list[tuple[_StreamKind, str]]: + segments: list[tuple[_StreamKind, str]] = [] + for kind, piece in self._parser.feed(text): + segments.extend(self._push_piece(kind, piece)) + return segments + + def finish(self) -> list[tuple[_StreamKind, str]]: + segments: list[tuple[_StreamKind, str]] = [] + for kind, piece in self._parser.flush(): + segments.extend(self._push_piece(kind, piece)) + for segment in self._content_buf.finish(): + segments.append(("content", segment)) + for segment in self._reasoning_buf.finish(): + segments.append(("reasoning", segment)) + return segments + + def _push_piece(self, kind: _StreamKind, piece: str) -> list[tuple[_StreamKind, str]]: + buf = self._reasoning_buf if kind == "reasoning" else self._content_buf + return [(kind, segment) for segment in buf.push(piece)] + + +class _CatalogStreamChunk(ChatResponseChunk): + """``ChatResponseChunk`` that serializes ``delta.reasoning_content`` for API Catalog.""" + + _reasoning_delta: str | None = PrivateAttr(default=None) + + def get_stream_data(self) -> str: + payload = json.loads(super().get_stream_data().removeprefix("data:").strip()) + if self._reasoning_delta is not None: + choice = payload["choices"][0] + delta = dict(choice.get("delta") or {}) + delta["reasoning_content"] = self._reasoning_delta + if not delta.get("content"): + delta.pop("content", None) + choice["delta"] = delta + return f"data: {json.dumps(payload, separators=(',', ':'))}\n\n" + + +def _catalog_stream_chunk( + *, + stream_id: str, + created: datetime.datetime, + model: str, + content: str | None = None, + reasoning_content: str | None = None, + role: str | None = None, + finish_reason: str | None = None, + usage: Usage | None = None, +) -> ChatResponseChunk: + base = ChatResponseChunk.create_streaming_chunk( + content if content is not None else "", + id_=stream_id, + created=created, + model=model, + role=role, + finish_reason=finish_reason, + usage=usage, + ) + if not reasoning_content: + return base + chunk = _CatalogStreamChunk.model_validate(base.model_dump()) + chunk._reasoning_delta = reasoning_content + return chunk + + +class _StreamDeltaWriter: + """Map segmented stream pieces to API Catalog ``ChatResponseChunk`` objects.""" + + def __init__( + self, + *, + stream_id: str, + created: datetime.datetime, + model: str, + max_chars: int, + ) -> None: + self._stream_id = stream_id + self._created = created + self._model = model + self._emitter = _StreamSegmentEmitter(max_chars) + + def feed(self, text: str) -> list[ChatResponseChunk]: + return [self._to_chunk(kind, segment) for kind, segment in self._emitter.feed(text)] + + def finish(self) -> list[ChatResponseChunk]: + return [self._to_chunk(kind, segment) for kind, segment in self._emitter.finish()] + + def _to_chunk(self, kind: _StreamKind, segment: str) -> ChatResponseChunk: + if kind == "reasoning": + return _catalog_stream_chunk( + stream_id=self._stream_id, + created=self._created, + model=self._model, + reasoning_content=segment, + ) + return _catalog_stream_chunk( + stream_id=self._stream_id, + created=self._created, + model=self._model, + content=segment, + ) + class DeepAgentConfig(FunctionBaseConfig, name="deepagent_fn"): """Langchain DeepAgents agent that delegates to subagents via create_deep_agent. @@ -121,15 +364,13 @@ class DeepAgentConfig(FunctionBaseConfig, name="deepagent_fn"): description="Maximum delay cap in seconds between retries.", ) strip_reasoning_pattern: str = Field( - default=r".*?\s*|.*", + default=_DEFAULT_STRIP_REASONING_PATTERN, description=( "Regex pattern (re.DOTALL) to strip from the final response. " "Matches are removed before returning to the caller. " "Set to empty string to disable stripping." ), ) - - @register_function(config_type=DeepAgentConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def deep_agent(config: DeepAgentConfig, builder: Builder): import psutil @@ -183,33 +424,21 @@ async def deep_agent(config: DeepAgentConfig, builder: Builder): # Workaround to strip reasoning patterns from the final response with minimax model strip_re = re.compile(config.strip_reasoning_pattern, re.DOTALL) if config.strip_reasoning_pattern else None - # Inner function that handles the agent invocation and response processing - async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str: - """Inner function that handles the agent invocation and response processing. - Args: - chat_request_or_message: The chat request or message to process. - Returns: - A chat response or string. - """ - chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + @asynccontextmanager + async def _agent_session( + chat_request: ChatRequest, + ) -> AsyncIterator[tuple[object, list]]: + """Yield (agent, messages_dict_list) inside a sandbox; cleans up child processes on exit.""" messages = [m.model_dump() for m in chat_request.messages] - - # Create a temporary sandbox directory for the agent - # Note execute tool will create files on host, a more robust sandbox should be used for production. with TemporaryDirectory() as sandbox_dir: sandbox = Path(sandbox_dir) - populate_sandbox(sandbox, skills_src_dirs, agents_md_src, config.workspace_dirs) - - # Create a local shell backend for the agent backend = LocalShellBackend( root_dir=sandbox, virtual_mode=True, inherit_env=True, env=env, ) - - # create subagent dictionaries sub_agent_dicts: list[dict] = [] for ref in config.subagents: fn = await builder.get_function(ref) @@ -222,7 +451,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse "Resolved %d subagent(s): %s", len(sub_agent_dicts), [sa.get("name", "?") for sa in sub_agent_dicts] ) - # Create a middleware chain for the agent to improve reliability and performance middleware = [ FixToolNamesMiddleware(), ToolRetryMiddleware(), @@ -236,7 +464,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse ), ] - # Create a dictionary of agent configuration arguments, including subagents if configured agent_kwargs: dict = dict( tools=config.tools, model=llm, @@ -252,29 +479,208 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse agent_kwargs["memory"] = effective_memory agent = create_deep_agent(**agent_kwargs) - - # Ensure child/orphaned processes are cleaned up pre_children = {c.pid for c in psutil.Process().children(recursive=True)} try: - agent_result = await agent.ainvoke({"messages": messages}) - - result_messages = agent_result["messages"] - content = result_messages[-1].content if result_messages else "" - content = strip_pattern(content, strip_re) + yield agent, messages finally: kill_orphaned_children(pre_children) - # Calculate usage metrics + def _usage_for_content(chat_request: ChatRequest, content: str) -> Usage: prompt_tokens = sum(len(str(m.content).split()) for m in chat_request.messages) completion_tokens = len(content.split()) if content else 0 - usage = Usage( + return Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - response = ChatResponse.from_string(content, usage=usage) - if chat_request_or_message.is_string: - return GlobalTypeConverter.get().convert(response, to_type=str) - return response - yield FunctionInfo.from_fn(_inner, description=config.description) + def _response_model(chat_request: ChatRequest) -> str: + return (chat_request.model or "").strip() or "unknown-model" + + async def _single(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse: + """Non-streaming OpenAI chat completion (root JSON object, no ``value`` wrapper).""" + chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + async with _agent_session(chat_request) as (agent, messages): + agent_result = await agent.ainvoke({"messages": messages}) + result_messages = agent_result["messages"] + content = result_messages[-1].content if result_messages else "" + content = strip_pattern(content, strip_re) + usage = _usage_for_content(chat_request, content) + return ChatResponse.from_string(content, usage=usage, model=_response_model(chat_request)) + + def _extract_text_content(content: object) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict) and block.get("type") == "text": + parts.append(str(block.get("text", ""))) + elif hasattr(block, "text"): + parts.append(str(block.text)) + return "".join(parts) + return str(content) + + def _namespace_tuple(ns: object) -> tuple: + if isinstance(ns, str): + return (ns,) + if ns is None: + return () + return tuple(ns) + + def _is_subagent_namespace(ns: tuple) -> bool: + return any(isinstance(s, str) and s.startswith("tools:") for s in ns) + + def _message_token_text(token: object) -> str: + if getattr(token, "type", None) not in ("ai", None): + return "" + if getattr(token, "tool_call_chunks", None): + return "" + return _extract_text_content(getattr(token, "content", None)) + + def _progress_from_update(chunk: dict) -> str | None: + ns = _namespace_tuple(chunk.get("ns")) + if _is_subagent_namespace(ns): + return None + data = chunk.get("data") + if not isinstance(data, dict): + return None + if "tools" in data: + return "Running tools…\n" + if "model_request" in data and not ns: + return None + return None + + async def _stream_llm_chunks(agent: object, messages: list) -> AsyncGenerator[str, None]: + """Yield main-agent assistant text (LLM tokens and optional progress lines).""" + + async def _yield_from_astream_events() -> AsyncGenerator[str, None]: + astream_events = getattr(agent, "astream_events", None) + if astream_events is None: + return + async for event in astream_events({"messages": messages}, version="v2"): + if not isinstance(event, dict) or event.get("event") != "on_chat_model_stream": + continue + data = event.get("data") or {} + llm_chunk = data.get("chunk") + text = _message_token_text(llm_chunk) + if text: + yield text + + emitted = False + try: + async for text in _yield_from_astream_events(): + emitted = True + yield text + except Exception: + logger.debug("astream_events token stream unavailable", exc_info=True) + + if emitted: + return + + stream_modes: list[str] = ["messages"] + if _STREAM_PROGRESS_UPDATES: + stream_modes.append("updates") + + try: + astream = agent.astream( + {"messages": messages}, + stream_mode=stream_modes, + subgraphs=True, + version="v2", + ) + except TypeError: + astream = agent.astream( + {"messages": messages}, + stream_mode="messages", + subgraphs=True, + ) + + async for chunk in astream: + if not isinstance(chunk, dict): + continue + chunk_type = chunk.get("type") + ns = _namespace_tuple(chunk.get("ns")) + + if chunk_type == "updates" and _STREAM_PROGRESS_UPDATES: + if _is_subagent_namespace(ns): + continue + progress = _progress_from_update(chunk) + if progress: + yield progress + continue + + if chunk_type != "messages": + continue + if _is_subagent_namespace(ns): + continue + payload = chunk.get("data") + if not isinstance(payload, (list, tuple)) or len(payload) < 1: + continue + token = payload[0] + text = _message_token_text(token) + if text: + yield text + + async def _stream(chat_request_or_message: ChatRequestOrMessage) -> AsyncGenerator[ChatResponseChunk, None]: + """OpenAI-style SSE chunks via NAT ``ChatResponseChunk`` (``data:`` lines when framed by NAT).""" + chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + response_model = _response_model(chat_request) + stream_id = str(uuid.uuid4()) + created = datetime.datetime.now(datetime.UTC) + assembled_raw: list[str] = [] + writer = _StreamDeltaWriter( + stream_id=stream_id, + created=created, + model=response_model, + max_chars=_STREAM_MAX_SEGMENT_CHARS, + ) + + async with _agent_session(chat_request) as (agent, messages): + yield ChatResponseChunk.create_streaming_chunk( + "", + id_=stream_id, + created=created, + model=response_model, + role=UserMessageContentRoleType.ASSISTANT, + ) + try: + async for text in _stream_llm_chunks(agent, messages): + assembled_raw.append(text) + for chunk in writer.feed(text): + yield chunk + except Exception: + logger.exception("Token streaming failed; falling back to buffered completion") + agent_result = await agent.ainvoke({"messages": messages}) + result_messages = agent_result["messages"] + content = result_messages[-1].content if result_messages else "" + content = _extract_text_content(content) + assembled_raw.clear() + assembled_raw.append(content) + for chunk in writer.feed(content): + yield chunk + + for chunk in writer.finish(): + yield chunk + + content = strip_pattern("".join(assembled_raw), strip_re) + + usage = _usage_for_content(chat_request, content) + yield ChatResponseChunk.create_streaming_chunk( + "", + id_=stream_id, + created=created, + model=response_model, + finish_reason="stop", + usage=usage, + ) + + yield FunctionInfo.create( + single_fn=_single, + stream_fn=_stream, + description=config.description, + )