diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py
index b5be13dc..6f84fe08 100644
--- a/agent_memory_server/api.py
+++ b/agent_memory_server/api.py
@@ -1,3 +1,4 @@
+import time
 from typing import Any
 
 import tiktoken
@@ -54,6 +55,9 @@
 logger = get_logger(__name__)
 
+_tiktoken_encoding: Any | None = None
+_tiktoken_encoding_last_failed_at: float | None = None
+_TIKTOKEN_ENCODING_RETRY_INTERVAL_SECONDS = 300
 
 router = APIRouter()
 
@@ -102,15 +106,56 @@ def _get_effective_token_limit(
 
 def _calculate_messages_token_count(messages: list[MemoryMessage]) -> int:
     """Calculate total token count for a list of messages."""
-    encoding = tiktoken.get_encoding("cl100k_base")
-    total_tokens = 0
+    return sum(_count_message_tokens(msg) for msg in messages)
+
+
+def _get_tiktoken_encoding() -> Any | None:
+    """Load the tokenizer encoding once and fall back safely if unavailable."""
+    global _tiktoken_encoding, _tiktoken_encoding_last_failed_at
+
+    if _tiktoken_encoding is not None:
+        return _tiktoken_encoding
+
+    now = time.monotonic()
+    last_failed_at = _tiktoken_encoding_last_failed_at
+    retry_after = (
+        None
+        if last_failed_at is None
+        else last_failed_at + _TIKTOKEN_ENCODING_RETRY_INTERVAL_SECONDS
+    )
+    if retry_after is not None and now < retry_after:
+        return None
+
+    try:
+        _tiktoken_encoding = tiktoken.get_encoding("cl100k_base")
+        _tiktoken_encoding_last_failed_at = None
+    except Exception as exc:
+        _tiktoken_encoding_last_failed_at = now
+        logger.warning(
+            "tiktoken encoding unavailable, using character-based token estimate",
+            error=str(exc),
+        )
+        return None
+
+    return _tiktoken_encoding
+
+
+def _estimate_text_token_count(text: str) -> int:
+    """Estimate token count when tiktoken is unavailable."""
+    return max(1, (len(text) + 3) // 4)
+
+
+def _count_text_tokens(text: str) -> int:
+    """Count tokens accurately when possible, otherwise fall back to estimation."""
+    encoding = _get_tiktoken_encoding()
+    if encoding is None:
+        return _estimate_text_token_count(text)
+    return len(encoding.encode(text))
 
-    for msg in messages:
-        msg_str = f"{msg.role}: {msg.content}"
-        msg_tokens = len(encoding.encode(msg_str))
-        total_tokens += msg_tokens
 
-    return total_tokens
+def _count_message_tokens(message: MemoryMessage) -> int:
+    """Count tokens for a single working-memory message."""
+    return _count_text_tokens(f"{message.role}: {message.content}")
 
 
 def _calculate_context_usage_percentages(
@@ -250,7 +295,6 @@ async def _summarize_working_memory(
     buffer_tokens = min(max(230, summarization_max_tokens // 100), 1000)
     max_message_tokens = summarization_max_tokens - summary_max_tokens - buffer_tokens
 
-    encoding = tiktoken.get_encoding("cl100k_base")
     total_tokens = 0
     messages_to_summarize = []
 
@@ -266,7 +310,7 @@ async def _summarize_working_memory(
     for i in range(len(memory.messages) - 1, -1, -1):
         msg = memory.messages[i]
         msg_str = f"{msg.role}: {msg.content}"
-        msg_tokens = len(encoding.encode(msg_str))
+        msg_tokens = _count_text_tokens(msg_str)
 
         if recent_messages_tokens + msg_tokens <= target_remaining_tokens:
             recent_messages_tokens += msg_tokens
@@ -281,12 +325,12 @@ async def _summarize_working_memory(
 
     for msg in messages_to_check:
         msg_str = f"{msg.role}: {msg.content}"
-        msg_tokens = len(encoding.encode(msg_str))
+        msg_tokens = _count_text_tokens(msg_str)
 
         # Handle oversized messages
         if msg_tokens > max_message_tokens:
             msg_str = msg_str[: max_message_tokens // 2]
-            msg_tokens = len(encoding.encode(msg_str))
+            msg_tokens = _count_text_tokens(msg_str)
 
         if total_tokens + msg_tokens <= max_message_tokens:
             total_tokens += msg_tokens
diff --git a/tests/test_issue_237.py b/tests/test_issue_237.py
new file mode 100644
index 00000000..067bd5e3
--- /dev/null
+++ b/tests/test_issue_237.py
@@ -0,0 +1,102 @@
+"""Tests for GitHub issue #237: safe token counting when tiktoken is unavailable."""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from agent_memory_server.api import (
+    _calculate_messages_token_count,
+    _get_tiktoken_encoding,
+)
+from agent_memory_server.models import MemoryMessage
+
+
+class TestIssue237TiktokenFallback:
+    def test_calculate_messages_token_count_falls_back_when_tiktoken_unavailable(
+        self,
+    ):
+        """Token counting should degrade gracefully when the encoding cannot load."""
+        messages = [MemoryMessage(role="user", content="Hello world")]
+
+        with (
+            patch("agent_memory_server.api._tiktoken_encoding", None),
+            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
+            patch(
+                "agent_memory_server.api.tiktoken.get_encoding",
+                side_effect=Exception("Could not download encoding data"),
+            ),
+        ):
+            token_count = _calculate_messages_token_count(messages)
+
+        assert token_count > 0
+
+    @pytest.mark.asyncio
+    async def test_get_working_memory_uses_fallback_when_tiktoken_unavailable(
+        self, client
+    ):
+        """GET should return session data instead of a 500 when tokenization fails."""
+        session_id = "issue-237-api"
+
+        put_response = await client.put(
+            f"/v1/working-memory/{session_id}",
+            json={
+                "messages": [{"role": "user", "content": "Hello from issue 237"}],
+                "user_id": "alice",
+                "namespace": "demo",
+            },
+        )
+        assert put_response.status_code == 200
+
+        with (
+            patch("agent_memory_server.api._tiktoken_encoding", None),
+            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
+            patch(
+                "agent_memory_server.api.tiktoken.get_encoding",
+                side_effect=Exception("Could not download encoding data"),
+            ),
+        ):
+            get_response = await client.get(
+                f"/v1/working-memory/{session_id}?model_name=gpt-4o"
+            )
+
+        assert get_response.status_code == 200, get_response.text
+        data = get_response.json()
+        assert data["session_id"] == session_id
+        assert len(data["messages"]) == 1
+
+    def test_get_tiktoken_encoding_skips_retries_within_backoff_window(self):
+        """Repeated calls should not re-attempt loading within the retry interval."""
+        mock_get_encoding = Mock(side_effect=Exception("Could not download encoding"))
+
+        with (
+            patch("agent_memory_server.api._tiktoken_encoding", None),
+            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
+            patch("agent_memory_server.api.time.monotonic", side_effect=[100.0, 101.0]),
+            patch("agent_memory_server.api.tiktoken.get_encoding", mock_get_encoding),
+        ):
+            assert _get_tiktoken_encoding() is None
+            assert _get_tiktoken_encoding() is None
+
+        assert mock_get_encoding.call_count == 1
+
+    def test_get_tiktoken_encoding_retries_after_backoff_window(self):
+        """A later call should retry loading once the backoff window has passed."""
+
+        class FakeEncoding:
+            def encode(self, text: str) -> list[int]:
+                return [1] * len(text)
+
+        mock_get_encoding = Mock(
+            side_effect=[Exception("temporary failure"), FakeEncoding()]
+        )
+
+        with (
+            patch("agent_memory_server.api._tiktoken_encoding", None),
+            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
+            patch("agent_memory_server.api.time.monotonic", side_effect=[100.0, 401.0]),
+            patch("agent_memory_server.api.tiktoken.get_encoding", mock_get_encoding),
+        ):
+            assert _get_tiktoken_encoding() is None
+            assert _get_tiktoken_encoding() is not None
+
+        assert mock_get_encoding.call_count == 2