diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index ef1beba3..fc7eac3e 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -120,7 +120,11 @@ def __init__( # Batching support - stores pre-processed messages: (session_id, messages, is_blob, timestamp) self._message_buffer: list[tuple[str, list[tuple[str, str]], bool, datetime]] = [] - self._buffer_lock = threading.Lock() + self._message_lock = threading.Lock() + + # Agent state buffering - stores all agent state updates: (session_id, agent) + self._agent_state_buffer: list[tuple[str, SessionAgent]] = [] + self._agent_state_lock = threading.Lock() # Cache for agent created_at timestamps to avoid fetching on every update self._agent_created_at_cache: dict[str, datetime] = {} @@ -397,8 +401,14 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") session_agent.created_at = self._agent_created_at_cache[agent_id] - # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` - self.create_agent(session_id, session_agent) + if self.config.batch_size > 1: + # Buffer the agent state update + with self._agent_state_lock: + self._agent_state_buffer.append((session_id, session_agent)) + else: + # Immediate send create_event without buffering + # Create a new agent as AgentCore Memory is immutable. We always get the latest one in `read_agent` + self.create_agent(session_id, session_agent) def create_message( self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any @@ -452,7 +462,7 @@ def create_message( if self.config.batch_size > 1: # Buffer the pre-processed message should_flush = False - with self._buffer_lock: + with self._message_lock: self._message_buffer.append((session_id, messages, is_blob, monotonic_timestamp)) should_flush = len(self._message_buffer) >= self.config.batch_size @@ -702,27 +712,31 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # region Batching support def _flush_messages(self) -> list[dict[str, Any]]: - """Flush all buffered messages to AgentCore Memory. + """Flush all buffered messages and agent state to AgentCore Memory. - Call this method to send any remaining buffered messages when batch_size > 1. + Call this method to send any remaining buffered messages and agent state when batch_size > 1. This is automatically called when the buffer reaches batch_size, but should also be called explicitly when the session is complete (via close() or context manager). Messages are batched by session_id - all conversational messages for the same session are combined into a single create_event() call to reduce API calls. Blob messages (>9KB) are sent individually as they require a different API path. + Agent state updates are sent after messages. Returns: list[dict[str, Any]]: List of created event responses from AgentCore Memory. Raises: - SessionException: If any message creation fails. On failure, all messages - remain in the buffer to prevent data loss. + SessionException: If any message or agent state creation fails. On failure, all messages + and agent state remain in the buffer to prevent data loss. """ - with self._buffer_lock: + with self._message_lock: messages_to_send = list(self._message_buffer) - if not messages_to_send: + with self._agent_state_lock: + agent_states_to_send = list(self._agent_state_buffer) + + if not messages_to_send and not agent_states_to_send: return [] # Group conversational messages by session_id, preserve order @@ -772,13 +786,39 @@ def _flush_messages(self) -> list[dict[str, Any]]: results.append(event) logger.debug("Flushed blob event for session %s: %s", session_id, event.get("eventId")) - # Clear buffer only after ALL messages succeed - with self._buffer_lock: + # Flush agent state updates after messages - batch all agent states into a single API call + if agent_states_to_send: + # Convert all agent states to payload format + agent_state_payloads = [] + for _session_id, session_agent in agent_states_to_send: + agent_state_payloads.append({"blob": json.dumps(session_agent.to_dict())}) + + # Send all agent states in a single batched create_event call + event = self.memory_client.gmdp_client.create_event( + memoryId=self.config.memory_id, + actorId=self.config.actor_id, + sessionId=self.config.session_id, + payload=agent_state_payloads, + eventTimestamp=self._get_monotonic_timestamp(), + metadata={ + STATE_TYPE_KEY: {"stringValue": StateType.AGENT.value}, + }, + ) + results.append(event) + logger.debug( + "Flushed %d agent states in batched event: %s", len(agent_states_to_send), event.get("eventId") + ) + + # Clear buffers only after ALL messages and agent state succeed + with self._message_lock: self._message_buffer.clear() + with self._agent_state_lock: + self._agent_state_buffer.clear() + except Exception as e: - logger.error("Failed to flush messages to AgentCore Memory for session: %s", e) - raise SessionException(f"Failed to flush messages: {e}") from e + logger.error("Failed to flush messages and agent state to AgentCore Memory: %s", e) + raise SessionException(f"Failed to flush messages and agent state: {e}") from e logger.info("Flushed %d events to AgentCore Memory", len(results)) return results @@ -789,9 +829,18 @@ def pending_message_count(self) -> int: Returns: int: Number of buffered messages waiting to be sent. """ - with self._buffer_lock: + with self._message_lock: return len(self._message_buffer) + def pending_agent_state_count(self) -> int: + """Return the number of agent states pending in the buffer. + + Returns: + int: Number of buffered agent states waiting to be sent. + """ + with self._agent_state_lock: + return len(self._agent_state_buffer) + def close(self) -> None: """Explicitly flush pending messages and close the session manager. @@ -860,16 +909,21 @@ def _start_flush_timer(self) -> None: def _interval_flush_callback(self) -> None: """Callback executed by the flush timer. - Flushes the buffer if it contains messages, then reschedules the timer. + Flushes the buffer if it contains messages or agent states, then reschedules the timer. """ try: - # Only flush if there are messages in the buffer - pending = self.pending_message_count() - if pending > 0: - logger.debug("Interval flush triggered: %d message(s) pending", pending) + # Only flush if there are messages or agent states in the buffer + pending_messages = self.pending_message_count() + pending_agent_states = self.pending_agent_state_count() + if pending_messages > 0 or pending_agent_states > 0: + logger.debug( + "Interval flush triggered: %d message(s) and %d agent state(s) pending", + pending_messages, + pending_agent_states, + ) self._flush_messages() else: - logger.debug("Interval flush skipped: buffer is empty") + logger.debug("Interval flush skipped: buffers are empty") # Reschedule the timer (unless shutdown) if not self._shutdown and self.config.flush_interval_seconds: diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index b793c373..c36cd7ff 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1389,6 +1389,124 @@ def test_create_message_returns_empty_dict_when_buffered(self, batching_session_ assert result == {} + def test_pending_agent_state_count_empty_buffer(self, batching_session_manager): + """Test pending_agent_state_count returns 0 for empty buffer.""" + assert batching_session_manager.pending_agent_state_count() == 0 + + def test_pending_agent_state_count_with_buffered_states(self, batching_session_manager, mock_memory_client): + """Test pending_agent_state_count returns correct count.""" + # First create the agent so update_agent doesn't fail + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # Update agent state multiple times (should buffer with batch_size=5) + for i in range(3): + agent.state["description"] = f"Updated description {i}" + batching_session_manager.update_agent("test-session-456", agent) + + # Should have 3 agent states in buffer (all updates preserved) + assert batching_session_manager.pending_agent_state_count() == 3 + # Verify no additional create_agent calls were made (still buffered) + assert mock_memory_client.gmdp_client.create_event.call_count == 1 # Only the initial create_agent + + def test_agent_state_buffer_keeps_latest_per_agent(self, batching_session_manager, mock_memory_client): + """Test that agent state buffer preserves all agent state updates.""" + # Create two agents + agent1 = SessionAgent( + agent_id="agent-1", + state={"description": "Description 1"}, + conversation_manager_state={}, + ) + agent2 = SessionAgent( + agent_id="agent-2", + state={"description": "Description 2"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent1) + batching_session_manager.create_agent("test-session-456", agent2) + + # Update agent1 multiple times + for i in range(3): + agent1.state["description"] = f"Agent 1 update {i}" + batching_session_manager.update_agent("test-session-456", agent1) + + # Update agent2 once + agent2.state["description"] = "Agent 2 updated" + batching_session_manager.update_agent("test-session-456", agent2) + + # Should have 4 agent states in buffer (all updates preserved: 3 for agent1 + 1 for agent2) + assert batching_session_manager.pending_agent_state_count() == 4 + + def test_agent_state_flushed_with_messages(self, batching_session_manager, mock_memory_client): + """Test that agent states are flushed along with messages.""" + mock_memory_client.create_event.return_value = {"eventId": "event_123"} + + # Create agent + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # Add messages and update agent state + for i in range(3): + message = SessionMessage( + message={"role": "user", "content": [{"text": f"Message {i}"}]}, + message_id=i, + created_at="2024-01-01T12:00:00Z", + ) + batching_session_manager.create_message("test-session-456", "test-agent", message) + + agent.state["description"] = "Updated" + batching_session_manager.update_agent("test-session-456", agent) + + # Verify both are buffered + assert batching_session_manager.pending_message_count() == 3 + assert batching_session_manager.pending_agent_state_count() == 1 + + # Flush + batching_session_manager._flush_messages() + + # Both buffers should be cleared + assert batching_session_manager.pending_message_count() == 0 + assert batching_session_manager.pending_agent_state_count() == 0 + + # Verify create_event was called for messages and agent state + # 1 initial create_agent + 1 batched message call + 1 agent state update + assert mock_memory_client.create_event.call_count == 1 # batched messages + assert mock_memory_client.gmdp_client.create_event.call_count == 2 # initial + update + + def test_agent_state_preserved_on_flush_failure(self, batching_session_manager, mock_memory_client): + """Test that agent states remain in buffer if flush fails.""" + # Create agent + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", agent) + + # Update agent state + agent.state["description"] = "Updated" + batching_session_manager.update_agent("test-session-456", agent) + + assert batching_session_manager.pending_agent_state_count() == 1 + + # Make flush fail + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") + + # Flush should fail + with pytest.raises(SessionException): + batching_session_manager._flush_messages() + + # Agent state should still be in buffer + assert batching_session_manager.pending_agent_state_count() == 1 + class TestBatchingFlush: """Test _flush_messages behavior.""" @@ -2156,7 +2274,7 @@ def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mo batching_session_manager.session_repository = Mock() # Add messages to buffer - with batching_session_manager._buffer_lock: + with batching_session_manager._message_lock: batching_session_manager._message_buffer.append( ("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) ) @@ -2379,7 +2497,7 @@ def test_interval_flush_callback_flushes_when_buffer_has_messages(self): manager = AgentCoreMemorySessionManager(config) # Add messages to buffer - with manager._buffer_lock: + with manager._message_lock: manager._message_buffer.append( ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) ) @@ -2445,6 +2563,124 @@ def tracked_flush(): # Cleanup manager.close() + def test_interval_flush_callback_flushes_when_agent_state_pending(self): + """Test that interval flush callback flushes when agent state is pending.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + mock_client.create_event.return_value = {"eventId": "event_123"} + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_gmdp_client = Mock() + mock_gmdp_client.create_event.return_value = {"eventId": "event_456"} + mock_session.client.return_value = mock_gmdp_client + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + manager.session_id = "test-session" # Set session_id since parent __init__ is mocked + + # Add agent state to buffer (no messages) + from strands.types.session import SessionAgent + + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test"}, + conversation_manager_state={}, + ) + with manager._agent_state_lock: + manager._agent_state_buffer.append(("test-session", agent)) + manager._agent_created_at_cache["test-agent"] = agent.created_at + + assert manager.pending_message_count() == 0 + assert manager.pending_agent_state_count() == 1 + + # Manually trigger interval flush callback + manager._interval_flush_callback() + + # Verify buffer was flushed + assert manager.pending_agent_state_count() == 0 + + # Cleanup + manager.close() + + def test_interval_flush_callback_flushes_when_both_buffers_have_data(self): + """Test that interval flush callback flushes when both messages and agent states are pending.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + mock_client.create_event.return_value = {"eventId": "event_123"} + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_gmdp_client = Mock() + mock_gmdp_client.create_event.return_value = {"eventId": "event_456"} + mock_session.client.return_value = mock_gmdp_client + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + manager.session_id = "test-session" # Set session_id since parent __init__ is mocked + + # Add both messages and agent state to buffers + with manager._message_lock: + manager._message_buffer.append( + ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + ) + + from strands.types.session import SessionAgent + + agent = SessionAgent( + agent_id="test-agent", + state={"description": "Test"}, + conversation_manager_state={}, + ) + with manager._agent_state_lock: + manager._agent_state_buffer.append(("test-session", agent)) + manager._agent_created_at_cache["test-agent"] = agent.created_at + + assert manager.pending_message_count() == 1 + assert manager.pending_agent_state_count() == 1 + + # Manually trigger interval flush callback + manager._interval_flush_callback() + + # Verify both buffers were flushed + assert manager.pending_message_count() == 0 + assert manager.pending_agent_state_count() == 0 + + # Cleanup + manager.close() + def test_config_flush_interval_validation(self): """Test that flush_interval_seconds must be positive.""" # Valid: positive value