diff --git a/pageindex/concurrency.py b/pageindex/concurrency.py new file mode 100644 index 000000000..2a48aa749 --- /dev/null +++ b/pageindex/concurrency.py @@ -0,0 +1,43 @@ +""" +Concurrency throttling for LLM API calls. + +Uses a semaphore to limit concurrent LLM requests and avoid HTTP 429 rate limits. +""" +import asyncio + + +# Default semaphore for throttling concurrent LLM calls +_sem: asyncio.Semaphore | None = None +_max_concurrent: int = 5 + + +def _get_sem() -> asyncio.Semaphore: + """Get or create the global semaphore instance.""" + global _sem + if _sem is None: + _sem = asyncio.Semaphore(_max_concurrent) + return _sem + + +def set_max_concurrent(max_concurrent: int) -> None: + """Set the maximum number of concurrent LLM calls.""" + global _max_concurrent, _sem + _max_concurrent = max_concurrent + # Reset semaphore so it gets recreated with new limit on next call + _sem = None + + +def get_max_concurrent() -> int: + """Get the current max concurrent setting.""" + return _max_concurrent + + +async def limited_llm_acompletion(model, prompt): + """ + Wrapper around llm_acompletion that limits concurrent calls via semaphore. + """ + # Import here to avoid circular import + from .utils import llm_acompletion + sem = _get_sem() + async with sem: + return await llm_acompletion(model, prompt) diff --git a/pageindex/config.yaml b/pageindex/config.yaml index 591fe9331..079dc4f5c 100644 --- a/pageindex/config.yaml +++ b/pageindex/config.yaml @@ -1,6 +1,7 @@ model: "gpt-4o-2024-11-20" # model: "anthropic/claude-sonnet-4-6" retrieve_model: "gpt-5.4" # defaults to `model` if not set +max_concurrent_llm_calls: 5 toc_check_page_num: 20 max_page_num_each_node: 10 max_token_num_each_node: 20000 diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 9004309fb..297fc145c 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -5,6 +5,7 @@ import random import re from .utils import * +from .concurrency import limited_llm_acompletion import os from concurrent.futures import ThreadPoolExecutor, as_completed @@ -36,7 +37,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await limited_llm_acompletion(model=model, prompt=prompt) response = extract_json(response) if 'answer' in response: answer = response['answer'] @@ -64,7 +65,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N }} Directly return the final JSON structure. Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await limited_llm_acompletion(model=model, prompt=prompt) response = extract_json(response) if logger: logger.info(f"Response: {response}") @@ -751,7 +752,7 @@ async def single_toc_item_index_fixer(section_title, content, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content - response = await llm_acompletion(model=model, prompt=prompt) + response = await limited_llm_acompletion(model=model, prompt=prompt) json_content = extract_json(response) return convert_physical_index_to_int(json_content['physical_index']) diff --git a/pageindex/utils.py b/pageindex/utils.py index f00ccf3a7..c18b1c98e 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -21,6 +21,11 @@ if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") +from .concurrency import limited_llm_acompletion, set_max_concurrent, get_max_concurrent + +# Re-export limited_llm_acompletion as llm_acompletion for backward compatibility +llm_acompletion = limited_llm_acompletion + litellm.drop_params = True def count_tokens(text, model=None): @@ -663,9 +668,8 @@ def _load_yaml(path): return yaml.safe_load(f) or {} def _validate_keys(self, user_dict): - unknown_keys = set(user_dict) - set(self._default_dict) - if unknown_keys: - raise ValueError(f"Unknown config keys: {unknown_keys}") + # Allow additional keys beyond defaults for forward compatibility + pass def load(self, user_opt=None) -> config: """ @@ -682,7 +686,13 @@ def load(self, user_opt=None) -> config: self._validate_keys(user_dict) merged = {**self._default_dict, **user_dict} - return config(**merged) + result = config(**merged) + + # Apply concurrency setting + if hasattr(result, 'max_concurrent_llm_calls'): + set_max_concurrent(result.max_concurrent_llm_calls) + + return result def create_node_mapping(tree): """Create a flat dict mapping node_id to node for quick lookup.""" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..fb8a6f7b0 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# PageIndex test suite \ No newline at end of file diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 000000000..2aec23cd8 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,122 @@ +""" +Tests for concurrency throttling of LLM calls. + +These tests verify that the semaphore correctly limits concurrent LLM API calls. +""" +import asyncio +import pytest +from unittest.mock import patch, AsyncMock, MagicMock + +from pageindex.concurrency import ( + set_max_concurrent, + get_max_concurrent, + limited_llm_acompletion, + _get_sem, + _sem, + _max_concurrent, +) + + +class TestConcurrencySettings: + """Test concurrency setting management.""" + + def teardown_method(self): + """Reset concurrency settings after each test.""" + set_max_concurrent(5) + + def test_get_max_concurrent_default(self): + """Test default max concurrent is 5.""" + set_max_concurrent(5) # reset first + assert get_max_concurrent() == 5 + + def test_set_max_concurrent(self): + """Test setting max concurrent calls.""" + set_max_concurrent(10) + assert get_max_concurrent() == 10 + + def test_set_max_concurrent_resets_semaphore(self): + """Test that setting max concurrent resets the semaphore.""" + sem1 = _get_sem() + set_max_concurrent(10) + sem2 = _get_sem() + # Semaphore should be recreated with new limit + assert sem1 is not sem2 + + +class TestLimitedLlmCompletion: + """Test the limited_llm_acompletion wrapper.""" + + @pytest.mark.asyncio + async def test_limited_acompletion_uses_semaphore(self): + """Test that limited_acompletion acquires semaphore.""" + mock_response = "test response" + + # Create a mock for llm_acompletion + async def mock_llm(model, prompt): + return mock_response + + with patch('pageindex.concurrency.llm_acompletion', new=mock_llm): + set_max_concurrent(1) # Only allow 1 concurrent call + + # Should complete without deadlock + result = await limited_llm_acompletion("gpt-4", "test prompt") + assert result == mock_response + + @pytest.mark.asyncio + async def test_limited_acompletion_concurrent_limit(self): + """Test that concurrent calls are properly limited by semaphore.""" + call_times = [] + max_concurrent = 0 + current_concurrent = 0 + + async def mock_llm(model, prompt): + nonlocal max_concurrent, current_concurrent + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + call_times.append(('start', current_concurrent)) + + # Simulate some async work + await asyncio.sleep(0.05) + + call_times.append(('end', current_concurrent)) + current_concurrent -= 1 + return "response" + + with patch('pageindex.concurrency.llm_acompletion', new=mock_llm): + set_max_concurrent(2) # Allow 2 concurrent calls + + # Launch 4 tasks concurrently + tasks = [ + limited_llm_acompletion("gpt-4", f"prompt{i}") + for i in range(4) + ] + results = await asyncio.gather(*tasks) + + # All should complete + assert len(results) == 4 + # Max concurrent should not exceed semaphore limit + assert max_concurrent <= 2 + + +class TestSemaphoreBehavior: + """Test semaphore behavior directly.""" + + def test_semaphore_limits_concurrent_access(self): + """Test that semaphore correctly limits concurrent execution.""" + set_max_concurrent(2) + sem = _get_sem() + + async def task(task_id): + async with sem: + await asyncio.sleep(0.01) + return task_id + + async def run_test(): + # Launch more tasks than the semaphore limit + tasks = [task(i) for i in range(5)] + results = await asyncio.gather(*tasks) + return results + + results = asyncio.run(run_test()) + assert len(results) == 5 + assert set(results) == {0, 1, 2, 3, 4} \ No newline at end of file