Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pageindex/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Concurrency throttling for LLM API calls.

Uses a semaphore to limit concurrent LLM requests and avoid HTTP 429 rate limits.
"""
import asyncio


# Default semaphore for throttling concurrent LLM calls
_sem: asyncio.Semaphore | None = None
_max_concurrent: int = 5


def _get_sem() -> asyncio.Semaphore:
"""Get or create the global semaphore instance."""
global _sem
if _sem is None:
_sem = asyncio.Semaphore(_max_concurrent)
return _sem


def set_max_concurrent(max_concurrent: int) -> None:
    """Set the maximum number of concurrent LLM calls.

    The cached semaphore is discarded so that the next ``_get_sem()`` call
    builds a new one with the updated limit.

    Args:
        max_concurrent: New limit; must be an integer of at least 1.

    Raises:
        TypeError: If ``max_concurrent`` is not an ``int`` (``bool`` is
            rejected as well).
        ValueError: If ``max_concurrent`` is smaller than 1. A limit of 0
            would create a semaphore that never admits a caller.
    """
    global _max_concurrent, _sem
    if isinstance(max_concurrent, bool) or not isinstance(max_concurrent, int):
        raise TypeError(
            f"max_concurrent must be an int, got {type(max_concurrent).__name__}"
        )
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")
    _max_concurrent = max_concurrent
    # Reset semaphore so it gets recreated with the new limit on next use.
    _sem = None


def get_max_concurrent() -> int:
    """Report the limit on concurrent LLM calls currently in effect."""
    current_limit = _max_concurrent
    return current_limit


async def limited_llm_acompletion(model, prompt, *args, **kwargs):
    """Call ``llm_acompletion`` while holding a slot of the shared semaphore.

    Args:
        model: Passed through as the first positional argument.
        prompt: Passed through as the second positional argument.
        *args: Extra positional arguments forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever the wrapped ``llm_acompletion`` returns.

    Raises:
        RuntimeError: If ``llm_acompletion`` imported from ``.utils`` is this
            wrapper itself, which would otherwise recurse until every
            semaphore slot is held and then hang.
    """
    # Import here to avoid a circular import (utils imports this module).
    from .utils import llm_acompletion

    # NOTE(review): utils re-exports this wrapper under the name
    # `llm_acompletion`; if that alias is what resolves here, the wrapper
    # would await itself while holding a slot. Fail loudly instead.
    if llm_acompletion is limited_llm_acompletion:
        raise RuntimeError(
            "pageindex.utils.llm_acompletion resolves to limited_llm_acompletion; "
            "the throttled wrapper cannot wrap itself"
        )

    async with _get_sem():
        return await llm_acompletion(model, prompt, *args, **kwargs)
1 change: 1 addition & 0 deletions pageindex/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
model: "gpt-4o-2024-11-20"
# model: "anthropic/claude-sonnet-4-6"
retrieve_model: "gpt-5.4" # defaults to `model` if not set
max_concurrent_llm_calls: 5
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000
Expand Down
7 changes: 4 additions & 3 deletions pageindex/page_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import random
import re
from .utils import *
from .concurrency import limited_llm_acompletion
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

Expand Down Expand Up @@ -36,7 +37,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None):
}}
Directly return the final JSON structure. Do not output anything else."""

response = await llm_acompletion(model=model, prompt=prompt)
response = await limited_llm_acompletion(model=model, prompt=prompt)
response = extract_json(response)
if 'answer' in response:
answer = response['answer']
Expand Down Expand Up @@ -64,7 +65,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N
}}
Directly return the final JSON structure. Do not output anything else."""

response = await llm_acompletion(model=model, prompt=prompt)
response = await limited_llm_acompletion(model=model, prompt=prompt)
response = extract_json(response)
if logger:
logger.info(f"Response: {response}")
Expand Down Expand Up @@ -751,7 +752,7 @@ async def single_toc_item_index_fixer(section_title, content, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
response = await llm_acompletion(model=model, prompt=prompt)
response = await limited_llm_acompletion(model=model, prompt=prompt)
json_content = extract_json(response)
return convert_physical_index_to_int(json_content['physical_index'])

Expand Down
18 changes: 14 additions & 4 deletions pageindex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"):
os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY")

from .concurrency import limited_llm_acompletion, set_max_concurrent, get_max_concurrent

# Re-export limited_llm_acompletion as llm_acompletion for backward compatibility
llm_acompletion = limited_llm_acompletion

litellm.drop_params = True

def count_tokens(text, model=None):
Expand Down Expand Up @@ -663,9 +668,8 @@ def _load_yaml(path):
return yaml.safe_load(f) or {}

def _validate_keys(self, user_dict):
unknown_keys = set(user_dict) - set(self._default_dict)
if unknown_keys:
raise ValueError(f"Unknown config keys: {unknown_keys}")
# Allow additional keys beyond defaults for forward compatibility
pass

def load(self, user_opt=None) -> config:
"""
Expand All @@ -682,7 +686,13 @@ def load(self, user_opt=None) -> config:

self._validate_keys(user_dict)
merged = {**self._default_dict, **user_dict}
return config(**merged)
result = config(**merged)

# Apply concurrency setting
if hasattr(result, 'max_concurrent_llm_calls'):
set_max_concurrent(result.max_concurrent_llm_calls)

return result

def create_node_mapping(tree):
"""Create a flat dict mapping node_id to node for quick lookup."""
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# PageIndex test suite
122 changes: 122 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Tests for concurrency throttling of LLM calls.

These tests verify that the semaphore correctly limits concurrent LLM API calls.
"""
import asyncio
import pytest
from unittest.mock import patch, AsyncMock, MagicMock

from pageindex.concurrency import (
set_max_concurrent,
get_max_concurrent,
limited_llm_acompletion,
_get_sem,
_sem,
_max_concurrent,
)


class TestConcurrencySettings:
    """Checks for reading and changing the concurrency limit."""

    def teardown_method(self):
        """Put the limit back to 5 once each test has finished."""
        set_max_concurrent(5)

    def test_get_max_concurrent_default(self):
        """The getter reports 5 after the limit is reset to 5."""
        set_max_concurrent(5)  # start from a known state
        reported = get_max_concurrent()
        assert reported == 5

    def test_set_max_concurrent(self):
        """A newly set limit is what the getter reports."""
        set_max_concurrent(10)
        reported = get_max_concurrent()
        assert reported == 10

    def test_set_max_concurrent_resets_semaphore(self):
        """Changing the limit makes _get_sem() hand out a new semaphore."""
        before = _get_sem()
        set_max_concurrent(10)
        after = _get_sem()
        # A different object means the cached semaphore was dropped.
        assert before is not after


class TestLimitedLlmCompletion:
    """Test the limited_llm_acompletion wrapper."""

    def teardown_method(self):
        """Reset concurrency settings after each test."""
        set_max_concurrent(5)

    @pytest.mark.asyncio
    async def test_limited_acompletion_uses_semaphore(self):
        """A single call completes under a limit of 1 and returns the result."""
        mock_response = "test response"

        async def mock_llm(model, prompt):
            return mock_response

        # The wrapper imports llm_acompletion from pageindex.utils at call
        # time, so that module attribute is the one that must be patched.
        with patch('pageindex.utils.llm_acompletion', new=mock_llm):
            set_max_concurrent(1)  # Only allow 1 concurrent call

            # Should complete without deadlock
            result = await limited_llm_acompletion("gpt-4", "test prompt")

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_limited_acompletion_concurrent_limit(self):
        """No more than the configured number of calls run at the same time."""
        peak = 0
        active = 0

        async def mock_llm(model, prompt):
            nonlocal peak, active
            active += 1
            peak = max(peak, active)

            # Yield to the loop so other tasks get a chance to overlap.
            await asyncio.sleep(0.05)

            active -= 1
            return "response"

        with patch('pageindex.utils.llm_acompletion', new=mock_llm):
            set_max_concurrent(2)  # Allow 2 concurrent calls

            # Launch 4 calls concurrently
            results = await asyncio.gather(
                *(limited_llm_acompletion("gpt-4", f"prompt{i}") for i in range(4))
            )

        # All should complete
        assert len(results) == 4
        # Exactly the limit: never exceeded, and calls are not serialised.
        assert peak == 2


class TestSemaphoreBehavior:
    """Test semaphore behavior directly."""

    def teardown_method(self):
        """Reset concurrency settings after each test."""
        set_max_concurrent(5)

    def test_semaphore_limits_concurrent_access(self):
        """Five tasks sharing a limit-2 semaphore never overlap more than 2."""
        set_max_concurrent(2)
        sem = _get_sem()
        active = 0
        peak = 0

        async def task(task_id):
            nonlocal active, peak
            async with sem:
                active += 1
                peak = max(peak, active)
                await asyncio.sleep(0.01)
                active -= 1
                return task_id

        async def run_test():
            # Launch more tasks than the semaphore limit
            return await asyncio.gather(*(task(i) for i in range(5)))

        results = asyncio.run(run_test())
        assert len(results) == 5
        assert set(results) == {0, 1, 2, 3, 4}
        # The limit was reached but never exceeded.
        assert peak == 2