Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Strands integration for Bedrock AgentCore Memory."""

from .converters import MemoryConverter, OpenAIConverseConverter

__all__ = ["MemoryConverter", "OpenAIConverseConverter"]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

logger = logging.getLogger(__name__)

CONVERSATIONAL_MAX_SIZE = 9000
# Bedrock AgentCore Data Plane conversational payload text max is 100000 chars.
# Ref: https://docs.aws.amazon.com/cli/latest/reference/bedrock-agentcore/create-event.html
CONVERSATIONAL_MAX_SIZE = 100000


class AgentCoreMemoryConverter:
Expand Down
3 changes: 3 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class AgentCoreMemoryConfig(BaseModel):
Default of 1 means immediate sending (no batching). Max 100.
context_tag: XML tag name used to wrap retrieved memory context injected into messages.
Default is "user_context".
filter_restored_tool_context: When True, strip historical toolUse/toolResult blocks from
restored messages before loading them into Strands runtime memory. Default is False.
"""

memory_id: str = Field(min_length=1)
Expand All @@ -41,3 +43,4 @@ class AgentCoreMemoryConfig(BaseModel):
retrieval_config: Optional[Dict[str, RetrievalConfig]] = None
batch_size: int = Field(default=1, ge=1, le=100)
context_tag: str = Field(default="user_context", min_length=1)
filter_restored_tool_context: bool = Field(default=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Converters for Strands <-> STM message formats."""

from .openai import OpenAIConverseConverter
from .protocol import MemoryConverter

__all__ = [
"OpenAIConverseConverter",
"MemoryConverter",
]
190 changes: 190 additions & 0 deletions src/bedrock_agentcore/memory/integrations/strands/converters/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""OpenAI-format converter for AgentCore Memory.

Converts between Strands SessionMessages (Strands-native message shape) and OpenAI message format
stored in AgentCore Memory STM events.
"""

import json
import logging
from typing import Any, Tuple

from strands.types.session import SessionMessage

from .protocol import exceeds_conversational_limit

logger = logging.getLogger(__name__)


def _bedrock_to_openai(message: dict) -> dict:
"""Convert a Strands-native message dict to OpenAI message format."""
role = message.get("role", "user")
content = message.get("content", [])

if content and "toolResult" in content[0]:
tool_result = content[0]["toolResult"]
text_parts = [c.get("text", "") for c in tool_result.get("content", []) if "text" in c]
result = {
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": "\n".join(text_parts),
}
if "status" in tool_result:
result["status"] = tool_result["status"]
return result

text_parts = []
tool_calls = []
reasoning_blocks: list[dict[str, Any]] = []
for item in content:
if "text" in item:
text_value = item.get("text")
if isinstance(text_value, str):
text = text_value.strip()
if text:
text_parts.append(text)
elif "reasoningContent" in item:
# OpenAI message shape does not have a stable multi-turn reasoning block field.
# Preserve original block(s) in storage-only extension field for lossless restore.
reasoning_blocks.append(item)
elif "toolUse" in item:
tu = item["toolUse"]
tool_calls.append(
{
"id": tu["toolUseId"],
"type": "function",
"function": {
"name": tu["name"],
"arguments": json.dumps(tu.get("input", {})),
},
}
)

result: dict[str, Any] = {"role": role}

if tool_calls:
result["content"] = "\n".join(text_parts) if text_parts else None
result["tool_calls"] = tool_calls
else:
result["content"] = "\n".join(text_parts) if text_parts else ""

if reasoning_blocks:
result["_strands_reasoning_content"] = reasoning_blocks

return result


def _openai_to_bedrock(openai_msg: dict) -> dict:
"""Convert an OpenAI message dict to Strands-native message shape."""
role = openai_msg.get("role", "user")
content_items: list[dict[str, Any]] = []

if role == "tool":
tool_result: dict[str, Any] = {
"toolUseId": openai_msg["tool_call_id"],
"content": [{"text": openai_msg.get("content", "")}],
}
if "status" in openai_msg:
tool_result["status"] = openai_msg["status"]
return {
"role": "user",
"content": [{"toolResult": tool_result}],
}

if role == "system":
return {
"role": "user",
"content": [{"text": openai_msg.get("content", "")}],
}

text_content = openai_msg.get("content")
if text_content and isinstance(text_content, str):
content_items.append({"text": text_content})

for tc in openai_msg.get("tool_calls", []):
fn = tc.get("function", {})
args_str = fn.get("arguments", "{}")
try:
args = json.loads(args_str)
except (json.JSONDecodeError, ValueError):
args = {}
content_items.append(
{
"toolUse": {
"toolUseId": tc["id"],
"name": fn["name"],
"input": args,
}
}
)

for rc in openai_msg.get("_strands_reasoning_content", []):
if isinstance(rc, dict) and "reasoningContent" in rc:
content_items.append(rc)

bedrock_role = "assistant" if role == "assistant" else "user"

return {"role": bedrock_role, "content": content_items}


class OpenAIConverseConverter:
    """Converts between Strands SessionMessages and OpenAI message format in STM."""

    @staticmethod
    def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]:
        """Convert a SessionMessage (Strands-native shape) to OpenAI-format STM payload.

        Returns an empty list when the message carries no persistable content
        (no non-blank text, toolUse, or toolResult blocks).
        """
        message = session_message.message
        blocks = message.get("content", [])
        if not blocks:
            return []

        def _carries_payload(block: dict) -> bool:
            # A block is worth persisting if it is tool activity or non-blank text.
            if "toolUse" in block or "toolResult" in block:
                return True
            text = block.get("text")
            return isinstance(text, str) and bool(text.strip())

        if not any(_carries_payload(block) for block in blocks):
            return []

        converted = _bedrock_to_openai(message)
        return [(json.dumps(converted), converted.get("role", "user"))]

    @staticmethod
    def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
        """Convert STM events containing OpenAI-format messages to SessionMessages."""
        restored: list[SessionMessage] = []

        # Events arrive newest-first; reverse to rebuild chronological order.
        for event in reversed(events):
            for item in event.get("payload", []):
                parsed = None

                if "conversational" in item:
                    try:
                        parsed = json.loads(item["conversational"]["content"]["text"])
                    except (json.JSONDecodeError, KeyError, ValueError):
                        logger.error("Failed to parse conversational payload as OpenAI message")
                        continue

                elif "blob" in item:
                    try:
                        decoded = json.loads(item["blob"])
                        # Blob payloads are stored as a (message_json, role) pair.
                        if isinstance(decoded, (tuple, list)) and len(decoded) == 2:
                            parsed = json.loads(decoded[0])
                    except (json.JSONDecodeError, ValueError):
                        logger.error("Failed to parse blob payload: %s", item)
                        continue

                if parsed and isinstance(parsed, dict):
                    native = _openai_to_bedrock(parsed)
                    # Skip messages whose converted content came back empty.
                    if native.get("content"):
                        restored.append(SessionMessage(message=native, message_id=0))

        return restored

    @staticmethod
    def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
        """Check if message exceeds conversational payload size limit."""
        return exceeds_conversational_limit(message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Shared protocol and utilities for memory converters."""

from typing import Any, Protocol, Tuple

from strands.types.session import SessionMessage

# Max character size for a conversational STM payload; larger messages take the
# blob path. Mirrors the Bedrock AgentCore Data Plane create-event limit.
CONVERSATIONAL_MAX_SIZE = 100000


class MemoryConverter(Protocol):
    """Protocol for converting between Strands messages and STM event payloads.

    Implementations are stateless: every member is a static method, so a class
    object (rather than an instance) can be supplied wherever a converter is
    expected.
    """

    @staticmethod
    def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]]:
        """Convert SessionMessage to STM event payload format.

        Returns a list of (serialized_message, role) tuples; an empty list
        signals that the message has nothing worth persisting.
        """

    @staticmethod
    def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]:
        """Convert STM events (newest-first) back to SessionMessages."""

    @staticmethod
    def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
        """Check if a (text, role) payload tuple exceeds the conversational size limit."""


def exceeds_conversational_limit(message: tuple[str, str]) -> bool:
    """Check if message exceeds the conversational payload size limit.

    The combined length of every tuple part (serialized message text plus
    role) is compared against the limit; a payload exactly at the limit is
    treated as exceeding it, erring toward the blob path.
    """
    combined_length = 0
    for part in message:
        combined_length += len(part)
    return combined_length >= CONVERSATIONAL_MAX_SIZE
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@
from typing_extensions import override

from bedrock_agentcore.memory.client import MemoryClient
from bedrock_agentcore.memory.models.filters import EventMetadataFilter, LeftExpression, OperatorType, RightExpression
from bedrock_agentcore.memory.models.filters import (
EventMetadataFilter,
LeftExpression,
OperatorType,
RightExpression,
)

from .bedrock_converter import AgentCoreMemoryConverter
from .config import AgentCoreMemoryConfig, RetrievalConfig
from .converters import MemoryConverter

if TYPE_CHECKING:
from strands.agent.agent import Agent
Expand Down Expand Up @@ -98,6 +104,7 @@ def _get_monotonic_timestamp(cls, desired_timestamp: Optional[datetime] = None)
def __init__(
self,
agentcore_memory_config: AgentCoreMemoryConfig,
converter: Optional[type[MemoryConverter]] = None,
region_name: Optional[str] = None,
boto_session: Optional[boto3.Session] = None,
boto_client_config: Optional[BotocoreConfig] = None,
Expand All @@ -107,12 +114,15 @@ def __init__(

Args:
agentcore_memory_config (AgentCoreMemoryConfig): Configuration for AgentCore Memory integration.
converter (Optional[type[MemoryConverter]], optional): Optional custom converter.
If None, native Bedrock/Strands converter is used.
region_name (Optional[str], optional): AWS region for Bedrock AgentCore Memory. Defaults to None.
boto_session (Optional[boto3.Session], optional): Optional boto3 session. Defaults to None.
boto_client_config (Optional[BotocoreConfig], optional): Optional boto3 client configuration.
Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
self.converter = converter
self.config = agentcore_memory_config
self.memory_client = MemoryClient(region_name=region_name)
session = boto_session or boto3.Session(region_name=region_name)
Expand Down Expand Up @@ -417,11 +427,12 @@ def create_message(
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")

# Convert and check size ONCE (not again at flush)
messages = AgentCoreMemoryConverter.message_to_payload(session_message)
converter = self.converter or AgentCoreMemoryConverter
messages = converter.message_to_payload(session_message)
if not messages:
return None

is_blob = AgentCoreMemoryConverter.exceeds_conversational_limit(messages[0])
is_blob = converter.exceeds_conversational_limit(messages[0])

# Parse the original timestamp and use it as desired timestamp
original_timestamp = datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00"))
Expand Down Expand Up @@ -545,7 +556,10 @@ def list_messages(
session_id=session_id,
max_results=max_results,
)
messages = AgentCoreMemoryConverter.events_to_messages(events)
converter = self.converter or AgentCoreMemoryConverter
messages = converter.events_to_messages(events)
if self.config.filter_restored_tool_context:
messages = self._filter_restored_tool_context(messages)
if limit is not None:
return messages[offset : offset + limit]
else:
Expand All @@ -555,6 +569,33 @@ def list_messages(
logger.error("Failed to list messages from AgentCore Memory: %s", e)
return []

def _filter_restored_tool_context(self, messages: list[SessionMessage]) -> list[SessionMessage]:
    """Strip historical toolUse/toolResult context from restored messages.

    Messages whose content becomes empty after filtering are dropped
    entirely; all other SessionMessage metadata is carried over unchanged.
    """
    kept: list[SessionMessage] = []
    for original in messages:
        raw = original.to_message()

        remaining = []
        for block in raw.get("content", []):
            if "toolUse" in block or "toolResult" in block:
                continue
            remaining.append(block)

        # Nothing left once tool context is removed; drop the message.
        if not remaining:
            continue

        rebuilt: Message = {"role": raw["role"], "content": remaining}
        kept.append(
            SessionMessage(
                message=rebuilt,
                message_id=original.message_id,
                redact_message=original.redact_message,
                created_at=original.created_at,
                updated_at=original.updated_at,
            )
        )

    return kept

# endregion SessionRepository interface implementation

# region RepositorySessionManager overrides
Expand Down Expand Up @@ -682,7 +723,8 @@ def _flush_messages(self) -> list[dict[str, Any]]:

Messages are batched by session_id - all conversational messages for the same
session are combined into a single create_event() call to reduce API calls.
Blob messages (>9KB) are sent individually as they require a different API path.
Messages that exceed the conversational payload limit are sent as blob events individually
as they require a different API path.

Returns:
list[dict[str, Any]]: List of created event responses from AgentCore Memory.
Expand Down
Loading