Skip to content
66 changes: 55 additions & 11 deletions agent_memory_server/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Any

import tiktoken
Expand Down Expand Up @@ -54,6 +55,9 @@


logger = get_logger(__name__)
_tiktoken_encoding: Any | None = None
_tiktoken_encoding_last_failed_at: float | None = None
_TIKTOKEN_ENCODING_RETRY_INTERVAL_SECONDS = 300

router = APIRouter()

Expand Down Expand Up @@ -102,15 +106,56 @@ def _get_effective_token_limit(

def _calculate_messages_token_count(messages: list[MemoryMessage]) -> int:
    """Return the combined token count across all working-memory messages.

    Delegates per-message counting to ``_count_message_tokens`` so the
    tiktoken/fallback decision lives in one place.
    """
    total = 0
    for message in messages:
        total += _count_message_tokens(message)
    return total


def _get_tiktoken_encoding() -> Any | None:
    """Return the cached cl100k_base encoding, or None while in failure backoff.

    A successful load is cached in ``_tiktoken_encoding`` for the lifetime of
    the process. After a failed load, further attempts are suppressed for
    ``_TIKTOKEN_ENCODING_RETRY_INTERVAL_SECONDS`` so a broken environment
    (e.g. encoding data cannot be downloaded) is not retried on every call.
    Callers treat ``None`` as "fall back to the character-based estimate".
    """
    global _tiktoken_encoding, _tiktoken_encoding_last_failed_at

    if _tiktoken_encoding is not None:
        return _tiktoken_encoding

    now = time.monotonic()
    failed_at = _tiktoken_encoding_last_failed_at
    if failed_at is not None and now - failed_at < _TIKTOKEN_ENCODING_RETRY_INTERVAL_SECONDS:
        # Still inside the backoff window: skip the expensive retry.
        return None

    try:
        encoding = tiktoken.get_encoding("cl100k_base")
    except Exception as exc:
        # Record the failure time so subsequent calls back off instead of
        # re-attempting the load immediately.
        _tiktoken_encoding_last_failed_at = now
        logger.warning(
            "tiktoken encoding unavailable, using character-based token estimate",
            error=str(exc),
        )
        return None

    _tiktoken_encoding = encoding
    _tiktoken_encoding_last_failed_at = None
    return _tiktoken_encoding


def _estimate_text_token_count(text: str) -> int:
"""Estimate token count when tiktoken is unavailable."""
return max(1, (len(text) + 3) // 4)


def _count_text_tokens(text: str) -> int:
    """Count tokens with tiktoken when the encoding loads; otherwise estimate.

    Falls back to the character-based approximation whenever
    ``_get_tiktoken_encoding`` reports the encoding as unavailable.
    """
    encoding = _get_tiktoken_encoding()
    if encoding is not None:
        return len(encoding.encode(text))
    return _estimate_text_token_count(text)

for msg in messages:
msg_str = f"{msg.role}: {msg.content}"
msg_tokens = len(encoding.encode(msg_str))
total_tokens += msg_tokens

return total_tokens
def _count_message_tokens(message: MemoryMessage) -> int:
    """Count tokens for a single working-memory message.

    The message is rendered as ``"role: content"`` before counting, matching
    how messages are sized elsewhere in this module.
    """
    rendered = f"{message.role}: {message.content}"
    return _count_text_tokens(rendered)


def _calculate_context_usage_percentages(
Expand Down Expand Up @@ -250,7 +295,6 @@ async def _summarize_working_memory(
buffer_tokens = min(max(230, summarization_max_tokens // 100), 1000)
max_message_tokens = summarization_max_tokens - summary_max_tokens - buffer_tokens

encoding = tiktoken.get_encoding("cl100k_base")
total_tokens = 0
messages_to_summarize = []

Expand All @@ -266,7 +310,7 @@ async def _summarize_working_memory(
for i in range(len(memory.messages) - 1, -1, -1):
msg = memory.messages[i]
msg_str = f"{msg.role}: {msg.content}"
msg_tokens = len(encoding.encode(msg_str))
msg_tokens = _count_text_tokens(msg_str)

if recent_messages_tokens + msg_tokens <= target_remaining_tokens:
recent_messages_tokens += msg_tokens
Expand All @@ -281,12 +325,12 @@ async def _summarize_working_memory(

for msg in messages_to_check:
msg_str = f"{msg.role}: {msg.content}"
msg_tokens = len(encoding.encode(msg_str))
msg_tokens = _count_text_tokens(msg_str)

# Handle oversized messages
if msg_tokens > max_message_tokens:
msg_str = msg_str[: max_message_tokens // 2]
msg_tokens = len(encoding.encode(msg_str))
msg_tokens = _count_text_tokens(msg_str)

if total_tokens + msg_tokens <= max_message_tokens:
total_tokens += msg_tokens
Expand Down
102 changes: 102 additions & 0 deletions tests/test_issue_237.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Tests for GitHub issue #237: safe token counting when tiktoken is unavailable."""

from unittest.mock import Mock, patch

import pytest

from agent_memory_server.api import (
_calculate_messages_token_count,
_get_tiktoken_encoding,
)
from agent_memory_server.models import MemoryMessage


class TestIssue237TiktokenFallback:
    """Regression tests for issue #237: token counting must not 500 when tiktoken fails."""

    def test_calculate_messages_token_count_falls_back_when_tiktoken_unavailable(
        self,
    ):
        """Counting should use the character-based estimate if the encoding fails to load."""
        history = [MemoryMessage(role="user", content="Hello world")]

        with (
            patch("agent_memory_server.api._tiktoken_encoding", None),
            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
            patch(
                "agent_memory_server.api.tiktoken.get_encoding",
                side_effect=Exception("Could not download encoding data"),
            ),
        ):
            counted = _calculate_messages_token_count(history)

        assert counted > 0

    @pytest.mark.asyncio
    async def test_get_working_memory_uses_fallback_when_tiktoken_unavailable(
        self, client
    ):
        """The GET endpoint should still return session data when tokenization fails."""
        session_id = "issue-237-api"

        put_response = await client.put(
            f"/v1/working-memory/{session_id}",
            json={
                "messages": [{"role": "user", "content": "Hello from issue 237"}],
                "user_id": "alice",
                "namespace": "demo",
            },
        )
        assert put_response.status_code == 200

        with (
            patch("agent_memory_server.api._tiktoken_encoding", None),
            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
            patch(
                "agent_memory_server.api.tiktoken.get_encoding",
                side_effect=Exception("Could not download encoding data"),
            ),
        ):
            get_response = await client.get(
                f"/v1/working-memory/{session_id}?model_name=gpt-4o"
            )

        assert get_response.status_code == 200, get_response.text
        payload = get_response.json()
        assert payload["session_id"] == session_id
        assert len(payload["messages"]) == 1

    def test_get_tiktoken_encoding_skips_retries_within_backoff_window(self):
        """A second call inside the retry interval must not hit the loader again."""
        failing_loader = Mock(side_effect=Exception("Could not download encoding"))

        with (
            patch("agent_memory_server.api._tiktoken_encoding", None),
            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
            patch("agent_memory_server.api.time.monotonic", side_effect=[100.0, 101.0]),
            patch("agent_memory_server.api.tiktoken.get_encoding", failing_loader),
        ):
            assert _get_tiktoken_encoding() is None
            assert _get_tiktoken_encoding() is None

        # Only the first call may attempt the load; the second is backed off.
        assert failing_loader.call_count == 1

    def test_get_tiktoken_encoding_retries_after_backoff_window(self):
        """Once the backoff window has elapsed, loading should be attempted again."""

        class _StubEncoding:
            # Minimal stand-in for a tiktoken encoding object.
            def encode(self, text: str) -> list[int]:
                return [1] * len(text)

        flaky_loader = Mock(
            side_effect=[Exception("temporary failure"), _StubEncoding()]
        )

        with (
            patch("agent_memory_server.api._tiktoken_encoding", None),
            patch("agent_memory_server.api._tiktoken_encoding_last_failed_at", None),
            patch("agent_memory_server.api.time.monotonic", side_effect=[100.0, 401.0]),
            patch("agent_memory_server.api.tiktoken.get_encoding", flaky_loader),
        ):
            assert _get_tiktoken_encoding() is None
            assert _get_tiktoken_encoding() is not None

        # First call fails, second call (past the 300s window) succeeds.
        assert flaky_loader.call_count == 2
Loading