diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index e41f531f..084589bb 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -31,6 +31,9 @@ class AgentCoreMemoryConfig(BaseModel): retrieval_config: Optional dictionary mapping namespaces to retrieval configurations batch_size: Number of messages to batch before sending to AgentCore Memory. Default of 1 means immediate sending (no batching). Max 100. + flush_interval_seconds: Optional interval in seconds for automatic buffer flushing. + Useful for long-running agents to ensure messages are persisted regularly. + Default is None (disabled). context_tag: XML tag name used to wrap retrieved memory context injected into messages. Default is "user_context". """ @@ -40,4 +43,5 @@ class AgentCoreMemoryConfig(BaseModel): actor_id: str = Field(min_length=1) retrieval_config: Optional[Dict[str, RetrievalConfig]] = None batch_size: int = Field(default=1, ge=1, le=100) + flush_interval_seconds: Optional[float] = Field(default=None, gt=0) context_tag: str = Field(default="user_context", min_length=1) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 2783a498..d279ba98 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -10,7 +10,7 @@ import boto3 from botocore.config import Config as BotocoreConfig -from strands.hooks import MessageAddedEvent +from strands.hooks import AfterInvocationEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -125,6 +125,11 @@ def __init__( # Cache for agent created_at timestamps to avoid fetching on every update self._agent_created_at_cache: dict[str, datetime] = {} + # Interval-based flushing support + self._flush_timer: Optional[threading.Timer] = None + self._timer_lock = threading.Lock() + self._shutdown = False + # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -145,6 +150,10 @@ def __init__( ) super().__init__(session_id=self.config.session_id, session_repository=self) + # Start interval-based flush timer if configured + if self.config.flush_interval_seconds: + self._start_flush_timer() + # region SessionRepository interface implementation def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session in AgentCore Memory. @@ -672,6 +681,10 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None: RepositorySessionManager.register_hooks(self, registry, **kwargs) registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + # Only register AfterInvocationEvent hook when batching is enabled + if self.config.batch_size > 1: + registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: if self.has_existing_agent: @@ -784,6 +797,7 @@ def close(self) -> None: messages are sent to AgentCore Memory. Alternatively, use the context manager protocol (with statement) for automatic cleanup. """ + self._stop_flush_timer() self._flush_messages() def __enter__(self) -> "AgentCoreMemorySessionManager": @@ -803,6 +817,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: exc_tb: Exception traceback if an exception occurred. """ try: + self._stop_flush_timer() self._flush_messages() except Exception as e: if exc_type is not None: @@ -811,3 +826,70 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: raise # endregion Batching support + + # region Interval-based flushing support + + def _start_flush_timer(self) -> None: + """Start the interval-based flush timer. + + This method schedules a recurring timer that flushes the message buffer + at regular intervals if flush_interval_seconds is configured. + """ + with self._timer_lock: + if self._shutdown: + return + + # Cancel existing timer if any + if self._flush_timer is not None: + self._flush_timer.cancel() + + # Schedule next flush + self._flush_timer = threading.Timer( + self.config.flush_interval_seconds, + self._interval_flush_callback, + ) + self._flush_timer.daemon = True + self._flush_timer.start() + logger.debug( + "Scheduled interval flush in %.1f seconds", + self.config.flush_interval_seconds, + ) + + def _interval_flush_callback(self) -> None: + """Callback executed by the flush timer. + + Flushes the buffer if it contains messages, then reschedules the timer. + """ + try: + # Only flush if there are messages in the buffer + pending = self.pending_message_count() + if pending > 0: + logger.debug("Interval flush triggered: %d message(s) pending", pending) + self._flush_messages() + else: + logger.debug("Interval flush skipped: buffer is empty") + + # Reschedule the timer (unless shutdown) + if not self._shutdown and self.config.flush_interval_seconds: + self._start_flush_timer() + + except Exception as e: + logger.error("Error during interval flush: %s", e) + # Attempt to reschedule even after error + if not self._shutdown and self.config.flush_interval_seconds: + self._start_flush_timer() + + def _stop_flush_timer(self) -> None: + """Stop the interval-based flush timer. + + This method cancels the timer and prevents it from rescheduling. + Should be called during cleanup (close() or __exit__). + """ + with self._timer_lock: + self._shutdown = True + if self._flush_timer is not None: + self._flush_timer.cancel() + self._flush_timer = None + logger.debug("Stopped interval flush timer") + + # endregion Interval-based flushing support diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index dd131063..11c964d8 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1,13 +1,15 @@ """Tests for AgentCoreMemorySessionManager.""" import logging +import time from unittest.mock import Mock, patch import pytest from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError from strands.agent.agent import Agent -from strands.hooks import MessageAddedEvent +from strands.hooks import AfterInvocationEvent, MessageAddedEvent +from strands.hooks.registry import HookRegistry from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -2082,3 +2084,350 @@ def test_retrieve_customer_context_default_context_tag(self, mock_memory_client) content = mock_agent.messages[0]["content"] assert "" in content[0]["text"] assert "" in content[0]["text"] + + +class TestAfterInvocationHook: + """Test AfterInvocationEvent hook integration.""" + + def test_after_invocation_hook_registered(self, batching_session_manager): + """Test that AfterInvocationEvent hook is registered when batching is enabled.""" + registry = HookRegistry() + batching_session_manager.register_hooks(registry) + + # Verify AfterInvocationEvent callback is registered (batching is enabled) + assert AfterInvocationEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterInvocationEvent]) > 0 + + def test_after_invocation_hook_flushes_buffer(self, batching_session_manager, mock_memory_client): + """Test that AfterInvocationEvent hook triggers flush.""" + # Mock session_repository to avoid parent class hook issues + batching_session_manager.session_repository = Mock() + + # Add messages to buffer + with batching_session_manager._buffer_lock: + batching_session_manager._message_buffer.append( + ("test-session", [("user", "test message")], False, batching_session_manager._get_monotonic_timestamp()) + ) + + assert batching_session_manager.pending_message_count() == 1 + + # Register hooks and trigger AfterInvocationEvent + registry = HookRegistry() + batching_session_manager.register_hooks(registry) + + # Create mock event with mock agent + mock_agent = Mock() + mock_event = AfterInvocationEvent(agent=mock_agent) + registry.invoke_callbacks(mock_event) + + # Verify buffer was flushed + assert batching_session_manager.pending_message_count() == 0 + + def test_after_invocation_hook_not_registered_when_batching_disabled(self, session_manager): + """Test that AfterInvocationEvent flush hook is NOT registered when batching is disabled.""" + # Spy on the registry to track what gets added + registry = HookRegistry() + original_add = registry.add_callback + added_callbacks = [] + + def spy_add_callback(event_type, callback): + added_callbacks.append((event_type, callback)) + return original_add(event_type, callback) + + registry.add_callback = spy_add_callback + session_manager.register_hooks(registry) + + # Check that no AfterInvocationEvent callback referencing _flush_messages was added + flush_callbacks = [ + cb + for event_type, cb in added_callbacks + if event_type == AfterInvocationEvent + and hasattr(cb, "__code__") + and "_flush_messages" in str(cb.__code__.co_names) + ] + assert len(flush_callbacks) == 0 + + +class TestIntervalFlush: + """Test interval-based flush mechanism for long-running agents.""" + + def test_interval_flush_timer_starts_when_configured(self): + """Test that interval flush timer starts when flush_interval_seconds is set.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + + # Verify timer was started + assert manager._flush_timer is not None + assert manager._flush_timer.is_alive() + assert not manager._shutdown + + # Cleanup + manager.close() + + def test_interval_flush_timer_not_started_when_disabled(self): + """Test that interval flush timer is not started when flush_interval_seconds is None.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + # flush_interval_seconds not set (defaults to None) + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + + # Verify timer was not started + assert manager._flush_timer is None + assert not manager._shutdown + + def test_interval_flush_timer_stops_on_close(self): + """Test that interval flush timer stops when close() is called.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + + # Verify timer is running + assert manager._flush_timer is not None + timer_ref = manager._flush_timer # Keep reference + assert timer_ref.is_alive() + assert not manager._shutdown + + # Close manager + manager.close() + + # Verify timer is stopped + assert manager._shutdown + assert manager._flush_timer is None # Should be set to None + # Give timer a moment to actually stop + time.sleep(0.1) + assert not timer_ref.is_alive() # Verify thread actually stopped + + def test_interval_flush_timer_stops_on_context_exit(self): + """Test that interval flush timer stops when exiting context manager.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + with AgentCoreMemorySessionManager(config) as manager: + # Verify timer is running inside context + assert manager._flush_timer is not None + assert not manager._shutdown + + # Verify timer is stopped after context exit + assert manager._shutdown + + def test_interval_flush_callback_flushes_when_buffer_has_messages(self): + """Test that interval flush callback flushes buffer when it has messages.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + mock_client.create_event.return_value = {"eventId": "event_123"} + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + + # Add messages to buffer + with manager._buffer_lock: + manager._message_buffer.append( + ("test-session", [("user", "test message")], False, manager._get_monotonic_timestamp()) + ) + + assert manager.pending_message_count() == 1 + + # Manually trigger interval flush callback + manager._interval_flush_callback() + + # Verify buffer was flushed + assert manager.pending_message_count() == 0 + + # Cleanup + manager.close() + + def test_interval_flush_callback_skips_when_buffer_empty(self): + """Test that interval flush callback skips flush when buffer is empty.""" + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + batch_size=10, + flush_interval_seconds=5.0, + ) + + mock_client = Mock() + mock_client.list_events.return_value = [] + + with ( + patch( + "bedrock_agentcore.memory.integrations.strands.session_manager.MemoryClient", + return_value=mock_client, + ), + patch("boto3.Session") as mock_boto_session, + patch("strands.session.repository_session_manager.RepositorySessionManager.__init__", return_value=None), + ): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_boto_session.return_value = mock_session + + manager = AgentCoreMemorySessionManager(config) + + # Verify buffer is empty + assert manager.pending_message_count() == 0 + + # Track flush calls + original_flush = manager._flush_messages + flush_called = {"count": 0} + + def tracked_flush(): + flush_called["count"] += 1 + return original_flush() + + manager._flush_messages = tracked_flush + + # Manually trigger interval flush callback + manager._interval_flush_callback() + + # Verify flush was not called (buffer was empty) + assert flush_called["count"] == 0 + + # Cleanup + manager.close() + + def test_config_flush_interval_validation(self): + """Test that flush_interval_seconds must be positive.""" + # Valid: positive value + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + flush_interval_seconds=30.0, + ) + assert config.flush_interval_seconds == 30.0 + + # Valid: None (disabled) + config = AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + flush_interval_seconds=None, + ) + assert config.flush_interval_seconds is None + + # Invalid: zero or negative should raise validation error + from pydantic import ValidationError + + with pytest.raises(ValidationError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + flush_interval_seconds=0.0, + ) + + with pytest.raises(ValidationError): + AgentCoreMemoryConfig( + memory_id="test-memory", + session_id="test-session", + actor_id="test-actor", + flush_interval_seconds=-5.0, + )