feat: Implement long-term memory integration with memory_tool and MemoryWriteTracker #52

prashant4654 wants to merge 4 commits into 10xHub:main
Conversation
Commits:
- feat: Implement long-term memory integration with memory_tool and MemoryWriteTracker
- fix: Correct token details key in LiteLLMConverter test
- test: Add comprehensive tests for long-term memory functionality
Signed-off-by: prashant4654 <ee23btech11218@iith.ac.in>
Pull request overview
This PR adds first-class long-term memory (LTM) support to AgentFlow, including an LLM-callable memory_tool, a preload node to inject retrieved memories into context, and shutdown handling to await pending async memory writes. It also improves LiteLLM response conversion (token defaults + reasoning extraction) and updates tests accordingly.
Changes:
- Introduces `agentflow.store.long_term_memory` with `memory_tool`, a preload-node factory, system-prompt helpers, and a pending-write tracker.
- Updates graph shutdown (`CompiledGraph.aclose`) to wait for pending memory writes before shutting down background tasks.
- Improves LiteLLM converter robustness, updates the LiteLLM converter tests, and adds an LTM test suite.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| agentflow/store/long_term_memory.py | New LTM integration module: tool, preload node, prompt helpers, and write tracking. |
| agentflow/graph/compiled_graph.py | Awaits pending memory writes during graceful shutdown. |
| agentflow/adapters/llm/litellm_converter.py | More robust token-usage parsing and reasoning extraction from additional provider fields. |
| agentflow/store/__init__.py | Exposes new LTM APIs via store package exports. |
| agentflow/graph/tool_node/constants.py | Adds task_manager to injectable tool params for schema/invocation behavior. |
| tests/store/test_long_term_memory.py | Comprehensive tests for the new LTM module and tool behavior. |
| tests/adapters/test_litellm_converter.py | Updates test fixture usage field to match the new reasoning-token source. |
Comments suppressed due to low confidence (1)
agentflow/adapters/llm/litellm_converter.py:79
`reasoning_tokens` is now sourced only from `completion_tokens_details.reasoning_tokens`. Some LiteLLM/OpenAI-compatible responses still report reasoning tokens under `prompt_tokens_details.reasoning_tokens` (as this file previously assumed), so this change can silently drop reasoning token counts to 0 for those providers. Consider falling back to `prompt_tokens_details` when `completion_tokens_details` is missing/None.
```python
        ImportError: If LiteLLM is not installed.
    """
    if not HAS_LITELLM:
        raise ImportError("litellm is not installed. Please install it to use this converter.")
    data = response.model_dump()
    usages_data = data.get("usage", {})
    usages = TokenUsages(
```
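The fallback the review suggests can be sketched as a small standalone helper. This is a hedged illustration, not AgentFlow code: the `extract_reasoning_tokens` name and the plain-dict `usage` shape are assumptions for the example; the converter's actual parsing lives elsewhere.

```python
def extract_reasoning_tokens(usage: dict) -> int:
    """Prefer completion_tokens_details, then fall back to prompt_tokens_details."""
    for details_key in ("completion_tokens_details", "prompt_tokens_details"):
        # `or {}` also covers the case where the details field is present but None.
        details = usage.get(details_key) or {}
        tokens = details.get("reasoning_tokens")
        if tokens is not None:
            return int(tokens)
    return 0

# A provider that only reports under prompt_tokens_details still yields a count:
usage = {"prompt_tokens_details": {"reasoning_tokens": 7}}
print(extract_reasoning_tokens(usage))  # 7
```

With this shape, providers using either details field are handled, and a response with neither field degrades to 0 instead of raising.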
```python
    task = MagicMock(spec=asyncio.Task)
    task.done.return_value = False
    task.add_done_callback = MagicMock()
    mgr.create_task = MagicMock(return_value=task)
```
The mock_task_manager.create_task fixture returns a MagicMock and does not actually schedule (or close) the coroutine passed in. In the async-write tests, memory_tool() creates a _do_write(...) coroutine object and hands it to create_task; with this mock, that coroutine is never awaited and can trigger RuntimeWarning: coroutine was never awaited during test teardown. Consider adjusting the mock to either close the passed coroutine or wrap it in a real asyncio.create_task(...) (and manage/cancel it in the test).
Suggested change:
```diff
-    task = MagicMock(spec=asyncio.Task)
-    task.done.return_value = False
-    task.add_done_callback = MagicMock()
-    mgr.create_task = MagicMock(return_value=task)
+    def _create_task(coro, *args, **kwargs):
+        # Schedule the coroutine so it is not left un-awaited.
+        # Using a real asyncio.Task keeps behavior close to production code
+        # while still allowing us to assert calls via the MagicMock wrapper.
+        return asyncio.create_task(coro)
+
+    mgr.create_task = MagicMock(side_effect=_create_task)
```
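A minimal, self-contained illustration of why the `side_effect` wrapper works: the `MagicMock` still records calls for assertions while a real task actually runs the coroutine, so no "coroutine was never awaited" warning fires at teardown. The names here (`mgr`, `_do_write`) are toy stand-ins, not the PR's actual fixtures.

```python
import asyncio
from unittest.mock import MagicMock

async def main() -> str:
    mgr = MagicMock()
    # Wrap a real asyncio.create_task so the coroutine is genuinely scheduled,
    # while the MagicMock wrapper keeps call assertions working.
    mgr.create_task = MagicMock(side_effect=lambda coro: asyncio.create_task(coro))

    async def _do_write() -> str:
        return "written"

    task = mgr.create_task(_do_write())
    mgr.create_task.assert_called_once()  # call recording still works
    return await task                     # the coroutine actually runs

print(asyncio.run(main()))  # written
```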
```python
        query,
        memory_type=mem_type,
        limit=limit,
        score_threshold=score_threshold if score_threshold > 0 else None,
```
In the search branch, the category parameter is currently ignored, even though it’s part of the tool’s public signature and BaseStore.asearch(...) supports category. This makes category a no-op for search calls. Consider passing category through to store.asearch(...) (e.g., when non-empty) so search behavior matches the tool API.
Suggested change:
```diff
         score_threshold=score_threshold if score_threshold > 0 else None,
+        category=category if category else None,
```
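The empty-to-`None` normalization pattern used for both parameters can be shown with a tiny standalone helper (the `build_search_kwargs` name is illustrative, not part of the PR):

```python
def build_search_kwargs(limit: int, score_threshold: float, category: str) -> dict:
    # Falsy sentinel values (0.0 threshold, "" category) are normalized to None
    # so the store treats them as "no filter" rather than a literal value.
    return {
        "limit": limit,
        "score_threshold": score_threshold if score_threshold > 0 else None,
        "category": category if category else None,
    }

print(build_search_kwargs(5, 0.0, ""))  # {'limit': 5, 'score_threshold': None, 'category': None}
```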
```python
    if memory_types:
        search_kwargs["memory_type"] = memory_types[0]
    if max_tokens is not None:
        search_kwargs["max_tokens"] = max_tokens

    try:
        results = await store.asearch(config, query, **search_kwargs)
```
memory_types is typed as a list but only memory_types[0] is used for memory_type filtering. This is misleading for callers and makes it easy to think multiple types are supported when they aren’t. Either change the parameter to a single memory_type: MemoryType | None, or implement multi-type support (e.g., multiple searches/filters) in a way the store API can handle.
Suggested change:
```diff
-    if memory_types:
-        search_kwargs["memory_type"] = memory_types[0]
     if max_tokens is not None:
         search_kwargs["max_tokens"] = max_tokens

     try:
-        results = await store.asearch(config, query, **search_kwargs)
+        # Support searching across multiple memory types by performing
+        # separate searches per type and aggregating the results, while
+        # respecting the overall limit.
+        if memory_types:
+            results: list[MemorySearchResult] = []
+            for mt in memory_types:
+                remaining = limit - len(results)
+                if remaining <= 0:
+                    break
+                typed_kwargs = dict(search_kwargs)
+                typed_kwargs["limit"] = remaining
+                typed_kwargs["memory_type"] = mt
+                partial = await store.asearch(config, query, **typed_kwargs)
+                if not partial:
+                    continue
+                results.extend(partial)
+        else:
+            results = await store.asearch(config, query, **search_kwargs)
```
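The per-type aggregation idea can be demonstrated in isolation with a synchronous fake backend. This is a hedged sketch: `search_by_types` and the fake `search` callable are illustrative stand-ins for `store.asearch`, not AgentFlow API.

```python
from typing import Callable

def search_by_types(
    search: Callable[[str, int], list[str]],
    memory_types: list[str],
    limit: int,
) -> list[str]:
    """Run one search per memory type, shrinking the limit as results accumulate."""
    results: list[str] = []
    for mt in memory_types:
        remaining = limit - len(results)
        if remaining <= 0:
            break  # overall limit reached; skip remaining types
        results.extend(search(mt, remaining))
    return results

# Fake backend: "episodic" has 2 hits, "semantic" has 3; overall limit is 3.
data = {"episodic": ["e1", "e2"], "semantic": ["s1", "s2", "s3"]}
hits = search_by_types(lambda mt, lim: data.get(mt, [])[:lim], ["episodic", "semantic"], 3)
print(hits)  # ['e1', 'e2', 's1']
```

Earlier types get priority under the shared limit, which is a deliberate design choice; interleaving or score-based merging across types would need a different aggregation step.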
```python
        tasks = list(self._pending)
        if not tasks:
            return {"status": "completed", "pending_writes": 0}

        count = len(tasks)
        logger.info("Waiting for %d pending memory writes to complete...", count)
        try:
            if timeout:
                await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout=timeout,
                )
            else:
                await asyncio.gather(*tasks, return_exceptions=True)
            logger.info("All %d pending memory writes completed.", count)
            return {"status": "completed", "pending_writes": 0, "completed": count}
        except TimeoutError:
            remaining = len(self._pending)
            logger.warning(
                "Timeout waiting for memory writes: %d/%d still pending", remaining, count
            )
            return {
                "status": "timeout",
                "pending_writes": remaining,
                "completed": count - remaining,
            }
```
MemoryWriteTracker.wait_for_pending() snapshots self._pending once (tasks = list(self._pending)) and only awaits that initial list. If new memory-write tasks are tracked while the wait is in progress (or after the snapshot but before shutdown completes), they won’t be awaited even though the method’s docstring says it waits for “all pending writes”. Consider looping until the pending set is empty (or timeout elapses), or taking the snapshot under the lock and re-checking after each gather.
Suggested change:
```diff
-        tasks = list(self._pending)
-        if not tasks:
-            return {"status": "completed", "pending_writes": 0}
-
-        count = len(tasks)
-        logger.info("Waiting for %d pending memory writes to complete...", count)
-        try:
-            if timeout:
-                await asyncio.wait_for(
-                    asyncio.gather(*tasks, return_exceptions=True),
-                    timeout=timeout,
-                )
-            else:
-                await asyncio.gather(*tasks, return_exceptions=True)
-            logger.info("All %d pending memory writes completed.", count)
-            return {"status": "completed", "pending_writes": 0, "completed": count}
-        except TimeoutError:
-            remaining = len(self._pending)
-            logger.warning(
-                "Timeout waiting for memory writes: %d/%d still pending", remaining, count
-            )
-            return {
-                "status": "timeout",
-                "pending_writes": remaining,
-                "completed": count - remaining,
-            }
+        # We may have new writes tracked while we are waiting, so we loop
+        # until the pending set is empty (or the overall timeout elapses),
+        # taking snapshots under the lock on each iteration.
+        loop = asyncio.get_running_loop()
+        deadline: float | None = loop.time() + timeout if timeout is not None else None
+        initial_count: int | None = None
+        while True:
+            async with self._lock:
+                tasks = list(self._pending)
+            pending_now = len(tasks)
+            if pending_now == 0:
+                # No pending writes; if we never saw any, match the original return.
+                if initial_count is None:
+                    return {"status": "completed", "pending_writes": 0}
+                logger.info("All %d pending memory writes completed.", initial_count)
+                return {
+                    "status": "completed",
+                    "pending_writes": 0,
+                    "completed": initial_count,
+                }
+            # First time we see pending tasks, record and log.
+            if initial_count is None:
+                initial_count = pending_now
+                logger.info(
+                    "Waiting for %d pending memory writes to complete...", initial_count
+                )
+            # Compute remaining timeout for this iteration, if any.
+            per_iter_timeout: float | None
+            if deadline is not None:
+                remaining_time = deadline - loop.time()
+                if remaining_time <= 0:
+                    remaining = self.pending_count
+                    logger.warning(
+                        "Timeout waiting for memory writes: %d/%d still pending",
+                        remaining,
+                        initial_count,
+                    )
+                    return {
+                        "status": "timeout",
+                        "pending_writes": remaining,
+                        "completed": max(0, initial_count - remaining),
+                    }
+                per_iter_timeout = remaining_time
+            else:
+                per_iter_timeout = None
+            try:
+                if per_iter_timeout is not None:
+                    await asyncio.wait_for(
+                        asyncio.gather(*tasks, return_exceptions=True),
+                        timeout=per_iter_timeout,
+                    )
+                else:
+                    await asyncio.gather(*tasks, return_exceptions=True)
+            except asyncio.TimeoutError:
+                remaining = self.pending_count
+                logger.warning(
+                    "Timeout waiting for memory writes: %d/%d still pending",
+                    remaining,
+                    initial_count,
+                )
+                return {
+                    "status": "timeout",
+                    "pending_writes": remaining,
+                    "completed": max(0, initial_count - remaining),
+                }
```
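The core of the fix, re-snapshotting the pending set on every pass, can be demonstrated with a toy tracker. This is a hedged, self-contained sketch: `ToyTracker` stands in for `MemoryWriteTracker` and omits locking, timeouts, and stats to keep the loop-until-empty behavior visible.

```python
import asyncio

class ToyTracker:
    """Toy stand-in for MemoryWriteTracker illustrating the loop-until-empty wait."""

    def __init__(self) -> None:
        self._pending: set[asyncio.Task] = set()

    def track(self, coro) -> asyncio.Task:
        task = asyncio.create_task(coro)
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)
        return task

    async def wait_for_pending(self) -> None:
        # Re-snapshot on every pass so writes tracked mid-wait are awaited too.
        while self._pending:
            tasks = list(self._pending)
            await asyncio.gather(*tasks, return_exceptions=True)
            # Drop the finished snapshot even if done-callbacks lag,
            # guaranteeing forward progress on each iteration.
            self._pending.difference_update(tasks)

async def main() -> list[str]:
    tracker = ToyTracker()
    results: list[str] = []

    async def second_write() -> None:
        results.append("second")

    async def first_write() -> None:
        # A new write gets tracked while wait_for_pending() is in progress;
        # a single snapshot would miss it, the loop does not.
        tracker.track(second_write())
        results.append("first")

    tracker.track(first_write())
    await tracker.wait_for_pending()
    return sorted(results)

print(asyncio.run(main()))  # ['first', 'second']
```

A one-shot `gather` over the initial snapshot would return after `first_write` alone; the loop also awaits the write scheduled during the wait.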
```python
        # 0. Wait for pending memory writes to complete (no cancellation)
        try:
            from agentflow.store.long_term_memory import get_write_tracker

            tracker = get_write_tracker()
            if tracker.pending_count > 0:
                logger.info(
                    "Waiting for %d pending memory writes before shutdown...",
                    tracker.pending_count,
                )
                write_stats = await tracker.wait_for_pending(timeout=self._shutdown_timeout)
                stats["memory_writes"] = write_stats
            else:
                stats["memory_writes"] = {"status": "completed", "pending_writes": 0}
        except Exception as e:
            stats["memory_writes"] = {"status": "error", "error": str(e)}
            logger.exception("Error waiting for pending memory writes: %s", e)
```
aclose() waits on a global memory-write tracker (get_write_tracker() returns a module-level singleton). In a process running multiple graphs concurrently, closing one graph could block on memory writes scheduled by other graphs (and vice versa), making shutdown timing unpredictable. Consider scoping the tracker to the graph/task_manager instance (e.g., inject a per-graph MemoryWriteTracker, or have BackgroundTaskManager expose a way to wait for a subset of tasks) so aclose() only waits for its own writes.
Suggested change:
```diff
-        # 0. Wait for pending memory writes to complete (no cancellation)
-        try:
-            from agentflow.store.long_term_memory import get_write_tracker
-
-            tracker = get_write_tracker()
-            if tracker.pending_count > 0:
-                logger.info(
-                    "Waiting for %d pending memory writes before shutdown...",
-                    tracker.pending_count,
-                )
-                write_stats = await tracker.wait_for_pending(timeout=self._shutdown_timeout)
-                stats["memory_writes"] = write_stats
-            else:
-                stats["memory_writes"] = {"status": "completed", "pending_writes": 0}
-        except Exception as e:
-            stats["memory_writes"] = {"status": "error", "error": str(e)}
-            logger.exception("Error waiting for pending memory writes: %s", e)
+        # 0. Memory writes: avoid waiting on a global tracker to prevent cross-graph coupling.
+        # Any long-term memory writes should be managed by per-graph components (e.g., the
+        # background task manager or graph-specific stores), so we only record a placeholder
+        # status here without blocking on a process-wide singleton.
+        stats["memory_writes"] = {
+            "status": "not_tracked",
+            "detail": "CompiledGraph.aclose() does not wait on the global memory-write tracker.",
+        }
```
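The per-instance scoping the reviewer suggests can be sketched with toy classes. This is a hedged illustration, not AgentFlow code: `GraphScopedTracker` and `ToyGraph` are invented names showing how an injected tracker confines `aclose()` to its own graph's writes.

```python
import asyncio

class GraphScopedTracker:
    """One tracker per graph; no cross-graph coupling at shutdown."""

    def __init__(self) -> None:
        self._pending: set[asyncio.Task] = set()

    def track(self, coro) -> asyncio.Task:
        task = asyncio.create_task(coro)
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)
        return task

    @property
    def pending_count(self) -> int:
        return len(self._pending)

    async def wait_for_pending(self) -> None:
        if self._pending:
            await asyncio.gather(*self._pending, return_exceptions=True)

class ToyGraph:
    def __init__(self) -> None:
        # Injected per instance instead of fetched from a module-level singleton.
        self.tracker = GraphScopedTracker()

    async def aclose(self) -> int:
        # Waits only for this graph's writes, never another graph's.
        waited = self.tracker.pending_count
        await self.tracker.wait_for_pending()
        return waited

async def main() -> tuple[int, int]:
    g1, g2 = ToyGraph(), ToyGraph()
    g1.tracker.track(asyncio.sleep(0))
    # Closing g2 sees no pending writes even though g1 has one in flight.
    return (await g2.aclose(), await g1.aclose())

print(asyncio.run(main()))  # (0, 1)
```

With a global singleton, `g2.aclose()` would have blocked on `g1`'s write; with per-graph trackers each shutdown is independent and its timing predictable.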
This pull request introduces a comprehensive long-term memory integration for AgentFlow, enabling LLMs to interact with persistent memory through a new tool and supporting infrastructure. The changes include a new `memory_tool` for search/store/update/delete operations, a system for tracking pending memory writes to guarantee graceful shutdown, and enhancements to memory retrieval and prompt handling. The LiteLLM reasoning-block and token issues are also fixed.

Long-term memory integration:
- Added `agentflow/store/long_term_memory.py`, providing:
  - `memory_tool`: an LLM-callable tool for searching, storing, updating, and deleting long-term memories.
  - `create_memory_preload_node`: a factory for injecting retrieved memories into the agent state before LLM calls.
  - `get_memory_system_prompt`: a helper for system-prompt fragments tailored to different retrieval modes.
- Updated `agentflow/store/__init__.py` to export the new long-term memory components, making them available for import throughout the codebase.

Graceful shutdown improvements:
- Modified `agentflow/graph/compiled_graph.py` to await pending memory writes before shutting down, using the new `MemoryWriteTracker` for robust resource management.

LLM response and content block handling:
- Enhanced `agentflow/adapters/llm/litellm_converter.py` by deriving reasoning content from `thinking_blocks` when not directly present, supporting more provider formats.

Test and schema updates:
- Updated `tests/adapters/test_litellm_converter.py` to use the correct field (`completion_tokens_details`) for reasoning tokens, reflecting upstream API changes.

Constants and configuration:
- Added `"task_manager"` to the set of injectable node names in `agentflow/graph/tool_node/constants.py`, supporting new memory tool dependencies.