Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def __init__(

# Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp)
self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = []
self._buffer_lock = threading.Lock()
self._message_lock = threading.Lock()

# Agent state buffering - stores all agent state updates: (session_id, agent)
self._agent_state_buffer: list[tuple[str, SessionAgent]] = []
self._agent_state_lock = threading.Lock()

# Cache for agent created_at timestamps to avoid fetching on every update
self._agent_created_at_cache: dict[str, datetime] = {}
Expand Down Expand Up @@ -397,8 +401,14 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
session_agent.created_at = self._agent_created_at_cache[agent_id]

# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
self.create_agent(session_id, session_agent)
if self.config.batch_size > 1:
# Buffer the agent state update
with self._agent_state_lock:
self._agent_state_buffer.append((session_id, session_agent))
else:
# Immediate send create_event without buffering
# Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent`
self.create_agent(session_id, session_agent)

def create_message(
self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any
Expand Down Expand Up @@ -452,7 +462,7 @@ def create_message(
if self.config.batch_size > 1:
# Buffer the pre-processed message
should_flush = False
with self._buffer_lock:
with self._message_lock:
self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp))
should_flush = len(self._message_buffer) >= self.config.batch_size

Expand Down Expand Up @@ -702,27 +712,31 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
# region Batching support

def _flush_messages(self) -> list[dict[str, Any]]:
"""Flush all buffered messages to AgentCore Memory.
"""Flush all buffered messages and agent state to AgentCore Memory.

Call this method to send any remaining buffered messages when batch_size > 1.
Call this method to send any remaining buffered messages and agent state when batch_size > 1.
This is automatically called when the buffer reaches batch_size, but should
also be called explicitly when the session is complete (via close() or context manager).

Messages are batched by session_id - all conversational messages for the same
session are combined into a single create_event() call to reduce API calls.
Blob messages (>9KB) are sent individually as they require a different API path.
Agent state updates are sent after messages.

Returns:
list[dict[str, Any]]: List of created event responses from AgentCore Memory.

Raises:
SessionException: If any message creation fails. On failure, all messages
remain in the buffer to prevent data loss.
SessionException: If any message or agent state creation fails. On failure, all messages
and agent state remain in the buffer to prevent data loss.
"""
with self._buffer_lock:
with self._message_lock:
messages_to_send = list(self._message_buffer)

if not messages_to_send:
with self._agent_state_lock:
agent_states_to_send = list(self._agent_state_buffer)

if not messages_to_send and not agent_states_to_send:
return []

# Group conversational messages by session_id, preserve order
Expand Down Expand Up @@ -772,13 +786,39 @@ def _flush_messages(self) -> list[dict[str, Any]]:
results.append(event)
logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId"))

# Clear buffer only after ALL messages succeed
with self._buffer_lock:
# Flush agent state updates after messages - batch all agent states into a single API call
if agent_states_to_send:
# Convert all agent states to payload format
agent_state_payloads = []
for _session_id, session_agent in agent_states_to_send:
agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())})

# Send all agent states in a single batched create_event call
event = self.memory_client.gmdp_client.create_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=self.config.session_id,
payload=agent_state_payloads,
eventTimestamp=self._get_monotonic_timestamp(),
metadata={
STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value},
},
)
results.append(event)
logger.debug(
"Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId")
)

# Clear buffers only after ALL messages and agent state succeed
with self._message_lock:
self._message_buffer.clear()

with self._agent_state_lock:
self._agent_state_buffer.clear()

except Exception as e:
logger.error("Failed to flush messages to AgentCore Memory for session: %s", e)
raise SessionException(f"Failed to flush messages: {e}") from e
logger.error("Failed to flush messages and agent state to AgentCore Memory: %s", e)
raise SessionException(f"Failed to flush messages and agent state: {e}") from e

logger.info("Flushed %d events to AgentCore Memory", len(results))
return results
Expand All @@ -789,9 +829,18 @@ def pending_message_count(self) -> int:
Returns:
int: Number of buffered messages waiting to be sent.
"""
with self._buffer_lock:
with self._message_lock:
return len(self._message_buffer)

def pending_agent_state_count(self) -> int:
"""Return the number of agent states pending in the buffer.

Returns:
int: Number of buffered agent states waiting to be sent.
"""
with self._agent_state_lock:
return len(self._agent_state_buffer)

def close(self) -> None:
"""Explicitly flush pending messages and close the session manager.

Expand Down Expand Up @@ -860,16 +909,21 @@ def _start_flush_timer(self) -> None:
def _interval_flush_callback(self) -> None:
"""Callback executed by the flush timer.

Flushes the buffer if it contains messages, then reschedules the timer.
Flushes the buffer if it contains messages or agent states, then reschedules the timer.
"""
try:
# Only flush if there are messages in the buffer
pending = self.pending_message_count()
if pending > 0:
logger.debug("Interval flush triggered: %d message(s) pending", pending)
# Only flush if there are messages or agent states in the buffer
pending_messages = self.pending_message_count()
pending_agent_states = self.pending_agent_state_count()
if pending_messages > 0 or pending_agent_states > 0:
logger.debug(
"Interval flush triggered: %d message(s) and %d agent state(s) pending",
pending_messages,
pending_agent_states,
)
self._flush_messages()
else:
logger.debug("Interval flush skipped: buffer is empty")
logger.debug("Interval flush skipped: buffers are empty")

# Reschedule the timer (unless shutdown)
if not self._shutdown and self.config.flush_interval_seconds:
Expand Down
Loading
Loading