diff --git a/README.md b/README.md
index a85fbd01d..7eeb37488 100644
--- a/README.md
+++ b/README.md
@@ -175,7 +175,8 @@ You can customize the processing with additional optional arguments:
 --max-tokens-per-node Max tokens per node (default: 20000)
 --if-add-node-id Add node ID (yes/no, default: yes)
 --if-add-node-summary Add node summary (yes/no, default: yes)
---if-add-doc-description Add doc description (yes/no, default: yes)
+--if-add-doc-description Add doc description (yes/no, default: no)
+--if-add-node-text Add node text (yes/no, default: no)
 ```
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 000000000..527a2bef2
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,8 @@
+model: "gpt-4o-2024-11-20"
+toc_check_page_num: 20
+max_page_num_each_node: 10
+max_token_num_each_node: 20000
+if_add_node_id: true
+if_add_node_summary: true
+if_add_doc_description: false
+if_add_node_text: false
diff --git a/pageindex/config.py b/pageindex/config.py
new file mode 100644
index 000000000..cc0e3c547
--- /dev/null
+++ b/pageindex/config.py
@@ -0,0 +1,101 @@
+import os
+import yaml
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+
+class PageIndexConfig(BaseModel):
+    """
+    Configuration schema for PageIndex.
+    """
+    # Keep fallback defaults aligned with the shipped config.yaml profile so
+    # repo, test, and installed-package contexts behave the same way.
+    model: str = Field(default="gpt-4o-2024-11-20", description="LLM model to use")
+    retrieve_model: Optional[str] = Field(
+        default="gpt-5.4",
+        description="Model to use for agentic retrieval flows",
+    )
+
+    # PDF Processing
+    toc_check_page_num: int = Field(default=20, description="Number of pages to check for TOC")
+    max_page_num_each_node: int = Field(default=10, description="Maximum pages per leaf node")
+    max_token_num_each_node: int = Field(default=20000, description="Max tokens per node")  # Approx
+
+    # Enrichment
+    if_add_node_id: bool = Field(default=True, description="Add unique ID to nodes")
+    if_add_node_summary: bool = Field(default=True, description="Generate summary for nodes")
+    if_add_doc_description: bool = Field(default=False, description="Generate doc-level description")
+    if_add_node_text: bool = Field(default=False, description="Keep raw text in nodes")
+
+    # Tree Optimization
+    if_thinning: bool = Field(default=True, description="Merge small adjacent nodes")
+    thinning_threshold: int = Field(default=500, description="Token threshold for merging")
+    summary_token_threshold: int = Field(default=200, description="Min tokens required to trigger summary generation")
+
+    # Additional
+    api_key: Optional[str] = Field(default=None, description="OpenAI API Key (optional, prefers env var)")
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
+
+
+class ConfigLoader:
+    def __init__(self, default_path: Optional[Union[str, Path]] = None):
+        if default_path is None:
+            env_path = os.getenv("PAGEINDEX_CONFIG")
+            if env_path:
+                default_path = Path(env_path)
+            else:
+                cwd_path = Path.cwd() / "config.yaml"
+                repo_path = Path(__file__).resolve().parents[1] / "config.yaml"
+                default_path = cwd_path if cwd_path.exists() else repo_path
+
+        self.default_path = default_path
+        self._default_dict = self._load_yaml(default_path) if default_path else {}
+
+    @staticmethod
+    def _load_yaml(path: Optional[Union[str, Path]]) -> Dict[str, Any]:
+        """Best-effort YAML load; always returns a (possibly empty) dict."""
+        if not path:
+            return {}
+        # BUGFIX: __init__ accepts str paths, but str has no .exists() —
+        # calling it outside the try block crashed with AttributeError.
+        path = Path(path)
+        if not path.exists():
+            return {}
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f) or {}
+            if not isinstance(data, dict):
+                # BUGFIX: a scalar/list YAML root would later blow up the
+                # {**defaults, **user} merge with a confusing TypeError.
+                print(f"Warning: Config root in {path} is not a mapping; ignoring it.")
+                return {}
+            return data
+        except Exception as e:
+            print(f"Warning: Failed to load config from {path}: {e}")
+            return {}
+
+    def load(self, user_opt: Optional[Union[Dict[str, Any], Any]] = None) -> PageIndexConfig:
+        """
+        Load configuration, merging defaults with user overrides and validating via Pydantic.
+
+        Args:
+            user_opt: Dictionary or object with overrides.
+
+        Returns:
+            PageIndexConfig: Validated configuration object.
+
+        Raises:
+            ValueError: If the merged configuration fails validation.
+        """
+        user_dict = self._normalize_user_opt(user_opt)
+
+        # Merge defaults and user overrides; user values win.
+        merged_data = {**self._default_dict, **user_dict}
+
+        try:
+            return PageIndexConfig(**merged_data)
+        except ValidationError as e:
+            # Chain the cause so the pydantic detail survives in tracebacks.
+            raise ValueError(f"Configuration validation failed: {e}") from e
+
+    @staticmethod
+    def _normalize_user_opt(user_opt: Optional[Union[Dict[str, Any], Any]]) -> Dict[str, Any]:
+        if user_opt is None:
+            return {}
+        if isinstance(user_opt, BaseModel):
+            # Preserve explicit None values while ignoring unset fields.
+            return user_opt.model_dump(exclude_unset=True)
+        if isinstance(user_opt, dict):
+            return dict(user_opt)
+        if hasattr(user_opt, "__dict__"):
+            # Generic objects cannot distinguish "unset" from "explicitly None",
+            # so keep the previous behavior for backwards compatibility.
+        return {k: v for k, v in vars(user_opt).items() if v is not None}
+        raise TypeError(f"user_opt must be dict or object, got {type(user_opt)}")
diff --git a/pageindex/config.yaml b/pageindex/config.yaml
index 591fe9331..2b3c515e1 100644
--- a/pageindex/config.yaml
+++ b/pageindex/config.yaml
@@ -1,10 +1,9 @@
 model: "gpt-4o-2024-11-20"
-# model: "anthropic/claude-sonnet-4-6"
 retrieve_model: "gpt-5.4" # defaults to `model` if not set
 toc_check_page_num: 20
 max_page_num_each_node: 10
 max_token_num_each_node: 20000
-if_add_node_id: "yes"
-if_add_node_summary: "yes"
-if_add_doc_description: "no"
-if_add_node_text: "no"
\ No newline at end of file
+if_add_node_id: true
+if_add_node_summary: true
+if_add_doc_description: false
+if_add_node_text: false
diff --git a/pageindex/core/__init__.py b/pageindex/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pageindex/core/llm.py b/pageindex/core/llm.py
new file mode 100644
index 000000000..264788c76
--- /dev/null
+++ b/pageindex/core/llm.py
@@ -0,0 +1,245 @@
+import tiktoken
+import openai
+import logging
+import os
+import time
+import json
+import asyncio
+from typing import Optional, List, Dict, Any, Union, Tuple
+from dotenv import load_dotenv
+
+load_dotenv()
+
+# NOTE(review): read once at import time; the `api_key=OPENAI_API_KEY` default
+# arguments below freeze this value, so env changes made after import are
+# ignored unless callers pass api_key explicitly — confirm this is intended.
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("CHATGPT_API_KEY")
+
+def count_tokens(text: Optional[str], model: str = "gpt-4o") -> int:
+    """
+    Count the number of tokens in a text string using the specified model's encoding.
+
+    Args:
+        text (Optional[str]): The text to encode. If None, returns 0.
+        model (str): The model name to use for encoding. Defaults to "gpt-4o".
+
+    Returns:
+        int: The number of tokens.
+    """
+    if not text:
+        return 0
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Fallback for newer or unknown models
+        enc = tiktoken.get_encoding("cl100k_base")
+    tokens = enc.encode(text)
+    return len(tokens)
+
+def ChatGPT_API_with_finish_reason(
+    model: str,
+    prompt: str,
+    api_key: Optional[str] = OPENAI_API_KEY,
+    chat_history: Optional[List[Dict[str, str]]] = None
+) -> Tuple[str, str]:
+    """
+    Call OpenAI Chat Completion API and return content along with finish reason.
+
+    Args:
+        model (str): The model name (e.g., "gpt-4o").
+        prompt (str): The user prompt.
+        api_key (Optional[str]): OpenAI API key. Defaults to env var.
+        chat_history (Optional[List[Dict[str, str]]]): Previous messages for context.
+
+    Returns:
+        Tuple[str, str]: A tuple containing (content, finish_reason).
+            finish_reason is "finished", "max_output_reached",
+            "missing_api_key", or "error".
+    """
+    max_retries = 10
+    if not api_key:
+        logging.error("No API key provided.")
+        return "Error", "missing_api_key"
+
+    client = openai.OpenAI(api_key=api_key)
+    for i in range(max_retries):
+        try:
+            if chat_history:
+                messages = chat_history.copy()  # Avoid modifying original list if passed by ref (shallow copy enough for append)
+                messages.append({"role": "user", "content": prompt})
+            else:
+                messages = [{"role": "user", "content": prompt}]
+
+            # temperature=0 keeps responses as deterministic as the API allows.
+            response = client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+
+            content = response.choices[0].message.content or ""
+            finish_reason = response.choices[0].finish_reason
+
+            if finish_reason == "length":
+                return content, "max_output_reached"
+            else:
+                return content, "finished"
+
+        except Exception as e:
+            print('************* Retrying *************')
+            logging.error(f"Error: {e}")
+            if i < max_retries - 1:
+                time.sleep(1)
+            else:
+                logging.error('Max retries reached for prompt: ' + prompt[:50] + '...')
+                return "Error", "error"
+    # Defensive: unreachable while max_retries > 0 (last iteration returns above).
+    return "Error", "max_retries"
+
+def ChatGPT_API(
+    model: str,
+    prompt: str,
+    api_key: Optional[str] = OPENAI_API_KEY,
+    chat_history: Optional[List[Dict[str, str]]] = None
+) -> str:
+    """
+    Call OpenAI Chat Completion API and return the content string.
+
+    Args:
+        model (str): The model name.
+        prompt (str): The user prompt.
+        api_key (Optional[str]): OpenAI API key.
+        chat_history (Optional[List[Dict[str, str]]]): Previous messages.
+
+    Returns:
+        str: The response content, or "Error" if failed.
+    """
+    max_retries = 10
+    if not api_key:
+        logging.error("No API key provided.")
+        return "Error"
+
+    client = openai.OpenAI(api_key=api_key)
+    for i in range(max_retries):
+        try:
+            if chat_history:
+                # Shallow copy so the caller's history is not mutated by append.
+                messages = chat_history.copy()
+                messages.append({"role": "user", "content": prompt})
+            else:
+                messages = [{"role": "user", "content": prompt}]
+
+            response = client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+
+            return response.choices[0].message.content or ""
+        except Exception as e:
+            print('************* Retrying *************')
+            logging.error(f"Error: {e}")
+            if i < max_retries - 1:
+                time.sleep(1)
+            else:
+                logging.error('Max retries reached for prompt: ' + prompt[:50] + '...')
+                return "Error"
+    # Defensive: unreachable while max_retries > 0 (last iteration returns above).
+    return "Error"
+
+async def ChatGPT_API_async(
+    model: str,
+    prompt: str,
+    api_key: Optional[str] = OPENAI_API_KEY
+) -> str:
+    """
+    Asynchronously call OpenAI Chat Completion API.
+
+    Args:
+        model (str): The model name.
+        prompt (str): The user prompt.
+        api_key (Optional[str]): OpenAI API key.
+
+    Returns:
+        str: The response content, or "Error" if failed.
+    """
+    max_retries = 10
+    if not api_key:
+        logging.error("No API key provided.")
+        return "Error"
+
+    messages = [{"role": "user", "content": prompt}]
+    for i in range(max_retries):
+        try:
+            # A fresh client per attempt; the async context manager closes
+            # the underlying HTTP session even when the request fails.
+            async with openai.AsyncOpenAI(api_key=api_key) as client:
+                response = await client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    temperature=0,
+                )
+                return response.choices[0].message.content or ""
+        except Exception as e:
+            print('************* Retrying *************')
+            logging.error(f"Error: {e}")
+            if i < max_retries - 1:
+                await asyncio.sleep(1)
+            else:
+                logging.error('Max retries reached for prompt: ' + prompt[:50] + '...')
+                return "Error"
+    return "Error"
+
+def get_json_content(response: str) -> str:
+    """
+    Extract content inside markdown JSON code blocks.
+
+    Args:
+        response (str): The full raw response string.
+
+    Returns:
+        str: The extracted JSON string stripped of markers.
+    """
+    start_idx = response.find("```json")
+    if start_idx != -1:
+        start_idx += 7
+        response = response[start_idx:]
+
+    end_idx = response.rfind("```")
+    if end_idx != -1:
+        response = response[:end_idx]
+
+    json_content = response.strip()
+    return json_content
+
+def extract_json(content: str) -> Union[Dict[str, Any], List[Any]]:
+    """
+    Robustly extract and parse JSON from a string, handling common LLM formatting issues.
+
+    Args:
+        content (str): The text containing JSON.
+
+    Returns:
+        Union[Dict, List]: The parsed JSON object, or an empty dict on failure.
+    """
+    try:
+        # First, try to extract JSON enclosed within ```json and ```
+        start_idx = content.find("```json")
+        if start_idx != -1:
+            start_idx += 7  # Adjust index to start after the delimiter
+            end_idx = content.rfind("```")
+            # BUGFIX: with no closing fence, rfind() lands on the opening
+            # ```json fence itself (end_idx < start_idx), so the slice was
+            # silently empty and valid payloads parsed as {}. Fall back to
+            # end-of-string instead.
+            if end_idx <= start_idx:
+                end_idx = len(content)
+            json_content = content[start_idx:end_idx].strip()
+        else:
+            # If no delimiters, assume entire content could be JSON
+            json_content = content.strip()
+
+        # Clean up common issues that might cause parsing errors.
+        # NOTE(review): these replacements also touch JSON *string values*
+        # (a value containing the word 'None', or embedded newlines); kept
+        # for backwards compatibility with existing LLM-output quirks.
+        json_content = json_content.replace('None', 'null')  # Replace Python None with JSON null
+        json_content = json_content.replace('\n', ' ').replace('\r', ' ')  # Remove newlines
+        json_content = ' '.join(json_content.split())  # Normalize whitespace
+
+        # Attempt to parse and return the JSON object
+        return json.loads(json_content)
+    except json.JSONDecodeError as e:
+        logging.error(f"Failed to extract JSON: {e}")
+        # Try to clean up the content further if initial parsing fails:
+        # remove any trailing commas before closing brackets/braces.
+        try:
+            json_content = json_content.replace(',]', ']').replace(',}', '}')
+            return json.loads(json_content)
+        except json.JSONDecodeError:  # BUGFIX: was a bare `except:` — keep it narrow
+            logging.error("Failed to parse JSON even after cleanup")
+            return {}
+    except Exception as e:
+        logging.error(f"Unexpected error while extracting JSON: {e}")
+        return {}
diff --git a/pageindex/core/logging.py b/pageindex/core/logging.py
new file mode 100644
index 000000000..9e7cd0be3
--- /dev/null
+++ b/pageindex/core/logging.py
@@ -0,0 +1,65 @@
+import os
+import json
+from datetime import datetime
+from typing import Any, Dict, Optional, Union, List
+from .pdf import get_pdf_name
+
+class JsonLogger:
+    """
+    A simple JSON-based logger that writes distinct log files for each run session.
+    """
+    def __init__(self, file_path: Union[str, Any]):
+        """
+        Initialize the logger.
+
+        Args:
+            file_path (Union[str, Any]): The source file path (usually PDF) to derive the log filename from.
+        """
+        # Extract PDF name for logger name
+        pdf_name = get_pdf_name(file_path)
+
+        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+        self.filename = f"{pdf_name}_{current_time}.json"
+        os.makedirs("./logs", exist_ok=True)
+        # Initialize empty list to store all messages
+        self.log_data: List[Dict[str, Any]] = []
+
+    def log(self, level: str, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None:
+        """
+        Log a message.
+
+        Args:
+            level (str): Log level (INFO, ERROR, etc.)
+            message (Union[str, Dict]): The message content.
+            **kwargs: Extra fields merged into the log entry.
+        """
+        entry: Dict[str, Any] = {}
+        if isinstance(message, dict):
+            # BUGFIX: copy instead of aliasing — the original `entry = message`
+            # injected 'level'/'timestamp' keys into the caller's dict.
+            entry = dict(message)
+        else:
+            entry = {'message': message}
+
+        entry['level'] = level
+        entry['timestamp'] = datetime.now().isoformat()
+        entry.update(kwargs)
+
+        self.log_data.append(entry)
+
+        # Write entire log data to file (inefficient for large logs, but simple for now)
+        with open(self._filepath(), "w", encoding='utf-8') as f:
+            json.dump(self.log_data, f, indent=2, ensure_ascii=False)
+
+    def info(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None:
+        self.log("INFO", message, **kwargs)
+
+    def error(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None:
+        self.log("ERROR", message, **kwargs)
+
+    def debug(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None:
+        self.log("DEBUG", message, **kwargs)
+
+    def exception(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None:
+        kwargs["exception"] = True
+        self.log("ERROR", message, **kwargs)
+
+    def _filepath(self) -> str:
+        return os.path.join("logs", self.filename)
diff --git a/pageindex/core/pdf.py b/pageindex/core/pdf.py
new file mode 100644
index 000000000..3f66009a0
--- /dev/null
+++ b/pageindex/core/pdf.py
@@ -0,0 +1,210 @@
+import PyPDF2
+import pymupdf
+import re
+import os
+import tiktoken
+from io import BytesIO
+from typing import List, Tuple, Union, Optional
+from .llm import count_tokens
+
+def extract_text_from_pdf(pdf_path: str) -> str:
+    """
+    Extract
all text from a PDF file using PyPDF2.
+
+    Args:
+        pdf_path (str): Path to the PDF file.
+
+    Returns:
+        str: Concatenated text from all pages.
+    """
+    pdf_reader = PyPDF2.PdfReader(pdf_path)
+    text = ""
+    for page_num in range(len(pdf_reader.pages)):
+        page = pdf_reader.pages[page_num]
+        # extract_text() may return None for image-only pages
+        text += page.extract_text() or ""
+    return text
+
+def get_pdf_title(pdf_path: Union[str, BytesIO]) -> str:
+    """
+    Extract the title from PDF metadata.
+
+    Args:
+        pdf_path (Union[str, BytesIO]): Path to PDF or BytesIO object.
+
+    Returns:
+        str: Title of the PDF or 'Untitled'.
+    """
+    pdf_reader = PyPDF2.PdfReader(pdf_path)
+    meta = pdf_reader.metadata
+    title = meta.title if meta and meta.title else 'Untitled'
+    return title
+
+def get_text_of_pages(pdf_path: str, start_page: int, end_page: int, tag: bool = True) -> str:
+    """
+    Get text from a specific range of pages in a PDF.
+
+    Args:
+        pdf_path (str): Path to the PDF file.
+        start_page (int): Start page number (1-based).
+        end_page (int): End page number (1-based, inclusive).
+        tag (bool): If True, wraps page text in <physical_index_N>...</physical_index_N> tags.
+
+    Returns:
+        str: Extracted text.
+    """
+    pdf_reader = PyPDF2.PdfReader(pdf_path)
+    text = ""
+    for page_num in range(start_page-1, end_page):
+        if page_num < len(pdf_reader.pages):
+            page = pdf_reader.pages[page_num]
+            # BUGFIX: extract_text() can return None; avoid "None" text / TypeError.
+            page_text = page.extract_text() or ""
+            if tag:
+                # NOTE(review): the <physical_index_N> wrapper literals were
+                # reconstructed (tag text was lost in transit); N is the
+                # 1-based physical page number, parsed back by the regex
+                # helpers below — confirm against the original repo.
+                text += f"<physical_index_{page_num + 1}>\n{page_text}\n</physical_index_{page_num + 1}>\n\n"
+            else:
+                text += page_text
+    return text
+
+def get_first_start_page_from_text(text: str) -> int:
+    """
+    Extract the first <physical_index_N> tag found in text.
+
+    Args:
+        text (str): Text containing <physical_index_N> tags.
+
+    Returns:
+        int: Page number or -1 if not found.
+    """
+    start_page = -1
+    # Pattern reconstructed to match the tags emitted by get_text_of_pages.
+    start_page_match = re.search(r'<physical_index_(\d+)>', text)
+    if start_page_match:
+        start_page = int(start_page_match.group(1))
+    return start_page
+
+def get_last_start_page_from_text(text: str) -> int:
+    """
+    Extract the last <physical_index_N> tag found in text.
+
+    Args:
+        text (str): Text containing <physical_index_N> tags.
+
+    Returns:
+        int: Page number or -1 if not found.
+    """
+    start_page = -1
+    start_page_matches = re.finditer(r'<physical_index_(\d+)>', text)
+    matches_list = list(start_page_matches)
+    if matches_list:
+        start_page = int(matches_list[-1].group(1))
+    return start_page
+
+
+def sanitize_filename(filename: str, replacement: str = '-') -> str:
+    """Replace illegal characters in filename."""
+    return filename.replace('/', replacement)
+
+def get_pdf_name(pdf_path: Union[str, BytesIO]) -> str:
+    """
+    Get a sanitized name for the PDF file.
+
+    Args:
+        pdf_path (Union[str, BytesIO]): Path or file object.
+
+    Returns:
+        str: Filename or logical title.
+    """
+    pdf_name = "Untitled.pdf"
+    if isinstance(pdf_path, str):
+        pdf_name = os.path.basename(pdf_path)
+    elif isinstance(pdf_path, BytesIO):
+        pdf_reader = PyPDF2.PdfReader(pdf_path)
+        meta = pdf_reader.metadata
+        if meta and meta.title:
+            pdf_name = meta.title
+    pdf_name = sanitize_filename(pdf_name)
+    return pdf_name
+
+
+def get_page_tokens(
+    pdf_path: Union[str, BytesIO],
+    model: str = "gpt-4o-2024-11-20",
+    pdf_parser: str = "PyPDF2"
+) -> List[Tuple[str, int]]:
+    """
+    Extract text and token counts for each page.
+
+    Args:
+        pdf_path (Union[str, BytesIO]): Path to PDF.
+        model (str): Model name for token counting.
+        pdf_parser (str): "PyPDF2" or "PyMuPDF".
+
+    Returns:
+        List[Tuple[str, int]]: List of (page_text, token_count).
+    """
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Unknown/new model names fall back to the common GPT-4-class encoding.
+        enc = tiktoken.get_encoding("cl100k_base")
+    if pdf_parser == "PyPDF2":
+        pdf_reader = PyPDF2.PdfReader(pdf_path)
+        page_list = []
+        for page_num in range(len(pdf_reader.pages)):
+            page = pdf_reader.pages[page_num]
+            page_text = page.extract_text() or ""
+            token_length = len(enc.encode(page_text))
+            page_list.append((page_text, token_length))
+        return page_list
+    elif pdf_parser == "PyMuPDF":
+        if isinstance(pdf_path, BytesIO):
+            pdf_stream = pdf_path
+            doc = pymupdf.open(stream=pdf_stream, filetype="pdf")
+        elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"):
+            doc = pymupdf.open(pdf_path)
+        else:
+            raise ValueError(f"Invalid pdf path for PyMuPDF: {pdf_path}")
+
+        # BUGFIX: close the document when done — the original returned
+        # without closing, leaking the file handle held by pymupdf.
+        try:
+            page_list = []
+            for page in doc:
+                page_text = page.get_text()
+                token_length = len(enc.encode(page_text))
+                page_list.append((page_text, token_length))
+            return page_list
+        finally:
+            doc.close()
+    else:
+        raise ValueError(f"Unsupported PDF parser: {pdf_parser}")
+
+
+
+def get_text_of_pdf_pages(pdf_pages: List[Tuple[str, int]], start_page: int, end_page: int) -> str:
+    """
+    Combine text from a list of page tuples [1-based range].
+
+    Args:
+        pdf_pages (List[Tuple[str, int]]): Output from get_page_tokens.
+        start_page (int): Start page (1-based).
+        end_page (int): End page (1-based, inclusive).
+
+    Returns:
+        str: Combined text.
+    """
+    text = ""
+    # Safe indexing: silently skip out-of-range pages.
+    total_pages = len(pdf_pages)
+    for page_num in range(start_page-1, end_page):
+        if 0 <= page_num < total_pages:
+            text += pdf_pages[page_num][0]
+    return text
+
+def get_text_of_pdf_pages_with_labels(pdf_pages: List[Tuple[str, int]], start_page: int, end_page: int) -> str:
+    """
+    Combine text from pages with <physical_index_N>...</physical_index_N> tags.
+    """
+    text = ""
+    total_pages = len(pdf_pages)
+    for page_num in range(start_page-1, end_page):
+        if 0 <= page_num < total_pages:
+            # NOTE(review): tag literals reconstructed (lost in transit); N is
+            # the 1-based physical page number the regex helpers parse back.
+            text += f"<physical_index_{page_num + 1}>\n{pdf_pages[page_num][0]}\n</physical_index_{page_num + 1}>\n\n"
+    return text
+
+def get_number_of_pages(pdf_path: Union[str, BytesIO]) -> int:
+    """Get total page count of a PDF."""
+    pdf_reader = PyPDF2.PdfReader(pdf_path)
+    return len(pdf_reader.pages)
diff --git a/pageindex/core/tree.py b/pageindex/core/tree.py
new file mode 100644
index 000000000..b7babd88a
--- /dev/null
+++ b/pageindex/core/tree.py
@@ -0,0 +1,545 @@
+import copy
+import json
+import asyncio
+from typing import List, Dict, Any, Optional, Union
+from .llm import count_tokens, ChatGPT_API, ChatGPT_API_async
+
+# Type aliases for tree structures
+Node = Dict[str, Any]
+Tree = List[Node]
+Structure = Union[Node, List[Any]]  # Recursive definition limitation in MyPy, using Any for nested
+
+def write_node_id(data: Structure, node_id: int = 0) -> int:
+    """
+    Recursively assign sequential node_ids to a tree structure.
+
+    Args:
+        data (Structure): The tree or node to process.
+        node_id (int): The starting ID.
+
+    Returns:
+        int: The next available node_id.
+    """
+    if isinstance(data, dict):
+        data['node_id'] = str(node_id).zfill(4)
+        node_id += 1
+        # Substring match: any key containing 'nodes' is treated as children.
+        for key in list(data.keys()):
+            if 'nodes' in key:
+                node_id = write_node_id(data[key], node_id)
+    elif isinstance(data, list):
+        for index in range(len(data)):
+            node_id = write_node_id(data[index], node_id)
+    return node_id
+
+def get_nodes(structure: Structure) -> List[Node]:
+    """
+    Flatten the tree into a list of nodes, excluding their children 'nodes' list from the copy.
+
+    Args:
+        structure (Structure): The tree structure.
+
+    Returns:
+        List[Node]: A flat list of node dictionaries (without 'nodes' key).
+    """
+    if isinstance(structure, dict):
+        snapshot = copy.deepcopy(structure)
+        snapshot.pop('nodes', None)
+        collected = [snapshot]
+        # Any key containing 'nodes' is treated as a child container.
+        for key in list(structure.keys()):
+            if 'nodes' in key:
+                collected += get_nodes(structure[key])
+        return collected
+    if isinstance(structure, list):
+        collected = []
+        for element in structure:
+            collected += get_nodes(element)
+        return collected
+    return []
+
+def structure_to_list(structure: Structure) -> List[Node]:
+    """
+    Flatten the tree into a list of references to all nodes (including containers).
+
+    Args:
+        structure (Structure): The tree structure.
+
+    Returns:
+        List[Node]: Flat list of all nodes.
+    """
+    if isinstance(structure, dict):
+        flat = [structure]
+        if 'nodes' in structure:
+            flat += structure_to_list(structure['nodes'])
+        return flat
+    if isinstance(structure, list):
+        flat = []
+        for entry in structure:
+            flat += structure_to_list(entry)
+        return flat
+    return []
+
+
+def get_leaf_nodes(structure: Structure) -> List[Node]:
+    """
+    Get all leaf nodes (nodes with no children).
+
+    Args:
+        structure (Structure): The tree structure.
+
+    Returns:
+        List[Node]: List of leaf node copies (without 'nodes' key).
+    """
+    if isinstance(structure, list):
+        leaves = []
+        for entry in structure:
+            leaves += get_leaf_nodes(entry)
+        return leaves
+    if isinstance(structure, dict):
+        if not structure.get('nodes'):
+            # Leaf: return a detached copy without the (possibly empty) child list.
+            detached = copy.deepcopy(structure)
+            detached.pop('nodes', None)
+            return [detached]
+        leaves = []
+        for key in list(structure.keys()):
+            if 'nodes' in key:
+                leaves += get_leaf_nodes(structure[key])
+        return leaves
+    return []
+
+def is_leaf_node(data: Structure, node_id: str) -> bool:
+    """
+    Check if a node with specific ID is a leaf node.
+
+    Args:
+        data (Structure): The tree structure.
+        node_id (str): The ID to check.
+
+    Returns:
+        bool: True if node exists and has no children.
+    """
+    # Iterative pre-order search replaces the nested recursive helper of the
+    # original; children are pushed in reverse so pop() preserves visit order.
+    pending: List[Any] = [data]
+    while pending:
+        current = pending.pop()
+        if isinstance(current, dict):
+            if current.get('node_id') == node_id:
+                # Found it: leaf means no (truthy) child list.
+                return not current.get('nodes')
+            for key in reversed(list(current.keys())):
+                if 'nodes' in key:
+                    pending.append(current[key])
+        elif isinstance(current, list):
+            pending.extend(reversed(current))
+    # No node with that id anywhere in the structure.
+    return False
+
+def get_last_node(structure: List[Any]) -> Any:
+    """Get the last element of a list structure."""
+    return structure[-1]
+
+def list_to_tree(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Convert a flat list of nodes with dot-notation 'structure' keys (e.g., '1.1')
+    into a nested tree.
+
+    Args:
+        data (List[Dict[str, Any]]): List of node dictionaries.
+
+    Returns:
+        List[Dict[str, Any]]: The nested tree structure.
+    """
+    def get_parent_structure(structure: Optional[str]) -> Optional[str]:
+        """Helper function to get the parent structure code"""
+        if not structure:
+            return None
+        parts = str(structure).split('.')
+        return '.'.join(parts[:-1]) if len(parts) > 1 else None
+
+    # First pass: Create nodes and track parent-child relationships
+    nodes: Dict[str, Dict[str, Any]] = {}
+    root_nodes: List[Dict[str, Any]] = []
+
+    for item in data:
+        structure = str(item.get('structure', ''))
+        node = {
+            'title': item.get('title'),
+            'start_index': item.get('start_index'),
+            'end_index': item.get('end_index'),
+            'nodes': []
+        }
+
+        nodes[structure] = node
+
+        # Find parent
+        parent_structure = get_parent_structure(structure)
+
+        if parent_structure:
+            # Add as child to parent if parent exists
+            # (relies on parents appearing before children in `data`;
+            # orphans are promoted to roots rather than dropped)
+            if parent_structure in nodes:
+                nodes[parent_structure]['nodes'].append(node)
+            else:
+                root_nodes.append(node)
+        else:
+            # No parent, this is a root node
+            root_nodes.append(node)
+
+    # Helper function to clean empty children arrays
+    def clean_node(node: Dict[str, Any]) -> Dict[str, Any]:
+        if not node['nodes']:
+            del node['nodes']
+        else:
+            for child in node['nodes']:
+                clean_node(child)
+        return node
+
+    # Clean and return the tree
+    return [clean_node(node) for node in root_nodes]
+
+def add_preface_if_needed(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Inject a Preface node if the first node starts after page 1.
+    """
+    if not isinstance(data, list) or not data:
+        return data
+
+    if data[0].get('physical_index') is not None and data[0]['physical_index'] > 1:
+        preface_node = {
+            "structure": "0",
+            "title": "Preface",
+            "physical_index": 1,
+        }
+        data.insert(0, preface_node)
+    return data
+
+
+def post_processing(structure: List[Dict[str, Any]], end_physical_index: int) -> Union[List[Dict[str, Any]], List[Any]]:
+    """
+    Calculate start/end indices based on 'physical_index' and convert to tree if possible.
+
+    Args:
+        structure: List of flat nodes (mutated in place).
+        end_physical_index: Total pages or end index.
+
+    Returns:
+        Tree or List.
+    """
+    # First convert page_number to start_index in flat list
+    for i, item in enumerate(structure):
+        item['start_index'] = item.get('physical_index')
+        if i < len(structure) - 1:
+            # If the next section starts at the top of its page
+            # ('appear_start' == 'yes'), this one ends on the previous page;
+            # otherwise both sections share that page (ranges overlap by one).
+            if structure[i + 1].get('appear_start') == 'yes':
+                item['end_index'] = structure[i + 1]['physical_index']-1
+            else:
+                item['end_index'] = structure[i + 1]['physical_index']
+        else:
+            # Last section runs to the end of the document.
+            item['end_index'] = end_physical_index
+    tree = list_to_tree(structure)
+    if len(tree)!=0:
+        return tree
+    else:
+        ### remove appear_start
+        for node in structure:
+            node.pop('appear_start', None)
+            node.pop('physical_index', None)
+        return structure
+
+def clean_structure_post(data: Structure) -> Structure:
+    """Recursively clean internal processing fields from structure."""
+    if isinstance(data, dict):
+        data.pop('page_number', None)
+        data.pop('start_index', None)
+        data.pop('end_index', None)
+        if 'nodes' in data:
+            clean_structure_post(data['nodes'])
+    elif isinstance(data, list):
+        for section in data:
+            clean_structure_post(section)
+    return data
+
+def remove_fields(data: Structure, fields: List[str] = ['text']) -> Structure:
+    """Recursively remove specified fields from the structure.
+
+    Returns a new structure; the input is not mutated. The mutable default
+    list is safe here because it is only ever read, never modified.
+    """
+    if isinstance(data, dict):
+        return {k: remove_fields(v, fields)
+                for k, v in data.items() if k not in fields}
+    elif isinstance(data, list):
+        return [remove_fields(item, fields) for item in data]
+    return data
+
+def print_toc(tree: List[Dict[str, Any]], indent: int = 0) -> None:
+    """Print Table of Contents to stdout."""
+    for node in tree:
+        print(' ' * indent + str(node.get('title', '')))
+        if node.get('nodes'):
+            print_toc(node['nodes'], indent + 1)
+
+def print_json(data: Any, max_len: int = 40, indent: int = 2) -> None:
+    """Pretty print JSON with truncated strings."""
+    def simplify_data(obj: Any) -> Any:
+        # Recursively truncate long strings so the dump stays readable.
+        if isinstance(obj, dict):
+            return {k: simplify_data(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [simplify_data(item) for item in obj]
+        elif isinstance(obj, str) and len(obj) > max_len:
+            return obj[:max_len] + '...'
+        else:
+            return obj
+
+    simplified = simplify_data(data)
+    print(json.dumps(simplified, indent=indent, ensure_ascii=False))
+
+
+def print_wrapped(text: Any, width: int = 100) -> None:
+    """Print text wrapped to specified width."""
+    import textwrap
+
+    if text is None:
+        return
+    for line in str(text).splitlines():
+        # Preserve blank lines (textwrap.wrap would drop them).
+        if not line.strip():
+            print()
+            continue
+        for wrapped in textwrap.wrap(line, width=width):
+            print(wrapped)
+
+
+def print_tree(tree: List[Dict[str, Any]], exclude_fields: Optional[List[str]] = None, indent: int = 0, max_summary_len: int = 120) -> None:
+    """Print tree structure with node IDs and summaries."""
+    if exclude_fields:
+        # Cast to Any to satisfy mypy since remove_fields returns Structure
+        tree = remove_fields(tree, fields=exclude_fields)  # type: ignore
+
+    for node in tree:
+        node_id = node.get('node_id', '')
+        title = node.get('title', '')
+        start = node.get('start_index')
+        end = node.get('end_index')
+        summary = node.get('summary') or node.get('prefix_summary')
+        page_range = None
+        if start is not None and end is not None:
+            # Single page renders as "N", a span as "N-M".
+            page_range = start if start == end else f"{start}-{end}"
+        line = f"{node_id}\t{page_range}\t{title}" if page_range else f"{node_id}\t{title}"
+        if summary:
+            short_summary = summary if len(summary) <= max_summary_len else summary[:max_summary_len] + '...'
+                line = f"{line} — {short_summary}"
+            print(' ' * indent + line)
+            if node.get('nodes'):
+                print_tree(node['nodes'], exclude_fields=exclude_fields, indent=indent + 1, max_summary_len=max_summary_len)
+
+
+def create_node_mapping(tree: List[Dict[str, Any]], include_page_ranges: bool = False, max_page: Optional[int] = None) -> Dict[str, Any]:
+    """Create a dictionary mapping node_ids to nodes.
+
+    Args:
+        tree: The nested tree structure.
+        include_page_ranges: If True, map to dicts with clamped page ranges.
+        max_page: Upper bound used to clamp page indices (None disables clamping).
+    """
+    mapping = {}
+
+    def clamp_page(value: Optional[int]) -> Optional[int]:
+        # Clamp into [1, max_page]; pass through when either side is unknown.
+        if value is None or max_page is None:
+            return value
+        return max(1, min(value, max_page))
+
+    def visit(node: Dict[str, Any]) -> None:
+        node_id = node.get('node_id')
+        if node_id:
+            if include_page_ranges:
+                start = clamp_page(node.get('start_index'))
+                end = clamp_page(node.get('end_index'))
+                mapping[node_id] = {
+                    'node': node,
+                    'start_index': start,
+                    'end_index': end,
+                }
+            else:
+                mapping[node_id] = node
+        for child in node.get('nodes') or []:
+            visit(child)
+
+    for root in tree:
+        visit(root)
+
+    return mapping
+
+
+def remove_structure_text(data: Structure) -> Structure:
+    """Recursively remove 'text' field (mutates in place)."""
+    if isinstance(data, dict):
+        data.pop('text', None)
+        if 'nodes' in data:
+            remove_structure_text(data['nodes'])
+    elif isinstance(data, list):
+        for item in data:
+            remove_structure_text(item)
+    return data
+
+
+def check_token_limit(structure: Structure, limit: int = 110000) -> None:
+    """Check if any node exceeds the token limit and report it to stdout."""
+    flat_list = structure_to_list(structure)
+    for node in flat_list:
+        text = node.get('text', '')
+        num_tokens = count_tokens(text, model='gpt-4o')
+        if num_tokens > limit:
+            print(f"Node ID: {node.get('node_id')} has {num_tokens} tokens")
+            print("Start Index:", node.get('start_index'))
+            print("End Index:", node.get('end_index'))
+            print("Title:", node.get('title'))
+            print("\n")
+
+
+def convert_physical_index_to_int(data: Any) -> Any:
+    """Convert physical_index strings (e.g., '<physical_index_12>') to integers inplace.
+
+    For a list input, converts each dict item's 'physical_index' in place and
+    returns the list. For a string input, returns the parsed int or None.
+    """
+    if isinstance(data, list):
+        for i in range(len(data)):
+            # Check if item is a dictionary and has 'physical_index' key
+            if isinstance(data[i], dict) and 'physical_index' in data[i]:
+                if isinstance(data[i]['physical_index'], str):
+                    # NOTE(review): the '<physical_index_N>' literals below were
+                    # reconstructed (the tag text was lost in transit); they
+                    # must match the tags produced by the pdf labelling
+                    # helpers — confirm against the original repo.
+                    if data[i]['physical_index'].startswith('<physical_index_'):
+                        data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].rstrip('>').strip())
+                    elif data[i]['physical_index'].startswith('physical_index_'):
+                        data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip())
+    elif isinstance(data, str):
+        if data.startswith('<physical_index_'):
+            data = int(data.split('_')[-1].rstrip('>').strip())
+        elif data.startswith('physical_index_'):
+            data = int(data.split('_')[-1].strip())
+        # Check data is int
+        if isinstance(data, int):
+            return data
+        else:
+            return None
+    return data
+
+
+def convert_page_to_int(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Convert 'page' field to int if possible."""
+    for item in data:
+        if 'page' in item and isinstance(item['page'], str):
+            try:
+                item['page'] = int(item['page'])
+            except ValueError:
+                # Keep original value if conversion fails
+                pass
+    return data
+
+# Deferred import: pdf.py imports from llm.py, and importing at the top of
+# this module alongside the .llm import keeps the dependency visible here.
+from .pdf import get_text_of_pdf_pages, get_text_of_pdf_pages_with_labels
+
+def add_node_text(node: Structure, pdf_pages: List[Any]) -> None:
+    """Recursively add text to nodes from pdf_pages list based on page range."""
+    if isinstance(node, dict):
+        start_page = node.get('start_index')
+        end_page = node.get('end_index')
+        if start_page is not None and end_page is not None:
+            node['text'] = get_text_of_pdf_pages(pdf_pages, start_page, end_page)
+        if 'nodes' in node:
+            add_node_text(node['nodes'], pdf_pages)
+    elif isinstance(node, list):
+        for index in range(len(node)):
+            add_node_text(node[index], pdf_pages)
+    return
+
+
+def add_node_text_with_labels(node: Structure, pdf_pages: List[Any]) -> None:
+    """Recursively add text with physical index labels."""
+    if isinstance(node, dict):
+        start_page = node.get('start_index')
+        end_page = node.get('end_index')
+        if start_page is not None and end_page is not None:
+            node['text'] = get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page)
+        if 'nodes' in node:
+            add_node_text_with_labels(node['nodes'], pdf_pages)
+    elif isinstance(node, list):
+        for index in range(len(node)):
+            add_node_text_with_labels(node[index], pdf_pages)
+    return
+
+
+async def generate_node_summary(node: Dict[str, Any], model: Optional[str] = None) -> str:
+    """Generate summary for a node using LLM."""
+    # Ensure text exists
+    text = node.get('text', '')
+    prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.
+
+    Partial Document Text: {text}
+
+    Directly return the description, do not include any other text.
+    """
+    # Note: model name should ideally be passed, default handled in API
+    response = await ChatGPT_API_async(model or "gpt-4o", prompt)
+    return response
+
+
+async def generate_summaries_for_structure(structure: Structure, model: Optional[str] = None) -> Structure:
+    """Generate summaries for all nodes in the structure (concurrently)."""
+    nodes = structure_to_list(structure)
+    tasks = [generate_node_summary(node, model=model) for node in nodes]
+    summaries = await asyncio.gather(*tasks)
+
+    for node, summary in zip(nodes, summaries):
+        node['summary'] = summary
+    return structure
+
+
+def create_clean_structure_for_description(structure: Structure) -> Structure:
+    """
+    Create a clean structure for document description generation,
+    excluding unnecessary fields like 'text'.
+ """ + if isinstance(structure, dict): + clean_node: Dict[str, Any] = {} + # Only include essential fields for description + for key in ['title', 'node_id', 'summary', 'prefix_summary']: + if key in structure: + clean_node[key] = structure[key] + + # Recursively process child nodes + if 'nodes' in structure and structure['nodes']: + clean_node['nodes'] = create_clean_structure_for_description(structure['nodes']) + + return clean_node + elif isinstance(structure, list): + return [create_clean_structure_for_description(item) for item in structure] # type: ignore + else: + return structure + + +def generate_doc_description(structure: Structure, model: str = "gpt-4o") -> str: + """Generate a one-sentence description for the entire document structure.""" + prompt = f"""Your are an expert in generating descriptions for a document. + You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents. + + Document Structure: {structure} + + Directly return the description, do not include any other text. 
+ """ + response = ChatGPT_API(model, prompt) + return response + + +def reorder_dict(data: Dict[str, Any], key_order: List[str]) -> Dict[str, Any]: + """Reorder dictionary keys.""" + if not key_order: + return data + return {key: data[key] for key in key_order if key in data} + + +def format_structure(structure: Structure, order: Optional[List[str]] = None) -> Structure: + """Recursively format and reorder keys in the structure.""" + if not order: + return structure + if isinstance(structure, dict): + if 'nodes' in structure: + structure['nodes'] = format_structure(structure['nodes'], order) + if not structure.get('nodes'): + structure.pop('nodes', None) + structure = reorder_dict(structure, order) + elif isinstance(structure, list): + structure = [format_structure(item, order) for item in structure] # type: ignore + return structure diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 9004309fb..aeab58ed3 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -4,7 +4,11 @@ import math import random import re -from .utils import * +from .core.llm import ChatGPT_API, ChatGPT_API_with_finish_reason, ChatGPT_API_async, extract_json, count_tokens, get_json_content +from .core.tree import convert_page_to_int, convert_physical_index_to_int, add_node_text, add_node_text_with_labels +from .core.pdf import get_number_of_pages, get_pdf_title, get_page_tokens, get_text_of_pages, get_first_start_page_from_text, get_last_start_page_from_text +from .core.logging import JsonLogger +from pageindex.config import ConfigLoader import os from concurrent.futures import ThreadPoolExecutor, as_completed @@ -36,7 +40,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None): }} Directly return the final JSON structure. 
Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await ChatGPT_API_async(model=model, prompt=prompt) response = extract_json(response) if 'answer' in response: answer = response['answer'] @@ -64,7 +68,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N }} Directly return the final JSON structure. Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await ChatGPT_API_async(model=model, prompt=prompt) response = extract_json(response) if logger: logger.info(f"Response: {response}") @@ -116,7 +120,7 @@ def toc_detector_single_page(content, model=None): Directly return the final JSON structure. Do not output anything else. Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) # print('response', response) json_content = extract_json(response) return json_content['toc_detected'] @@ -135,7 +139,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -153,7 +157,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None): Directly return the final JSON structure. 
Do not output anything else.""" prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -165,7 +169,7 @@ def extract_toc_content(content, model=None): Directly return the full table of contents content. Do not output anything else.""" - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if_complete = check_if_toc_transformation_is_complete(content, response, model) if if_complete == "yes" and finish_reason == "finished": @@ -176,24 +180,23 @@ def extract_toc_content(content, model=None): {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) - attempt = 0 - max_attempts = 5 - + attempts = 0 + max_attempts = 10 while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: + attempts += 1 + if attempts > max_attempts: raise Exception('Failed to complete table of contents after maximum retries') - + chat_history = [ - {"role": "user", "content": prompt}, - {"role": "assistant", "content": response}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the 
structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) @@ -215,7 +218,7 @@ def detect_page_index(toc_content, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['page_index_given_in_toc'] @@ -264,7 +267,7 @@ def toc_index_extractor(toc, content, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content @@ -292,7 +295,7 @@ def toc_transformer(toc_content, model=None): Directly return the final JSON structure, do not output anything else. 
""" prompt = init_prompt + '\n Given table of contents\n:' + toc_content - last_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) if if_complete == "yes" and finish_reason == "finished": last_complete = extract_json(last_complete) @@ -300,12 +303,13 @@ def toc_transformer(toc_content, model=None): return cleaned_response last_complete = get_json_content(last_complete) - attempt = 0 - max_attempts = 5 + attempts = 0 + max_attempts = 10 while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: - raise Exception('Failed to complete toc transformation after maximum retries') + attempts += 1 + if attempts > max_attempts: + raise Exception('Failed to complete table of contents after maximum retries') + position = last_complete.rfind('}') if position != -1: last_complete = last_complete[:position+2] @@ -321,11 +325,17 @@ def toc_transformer(toc_content, model=None): Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) - if new_complete.startswith('```json'): - new_complete = get_json_content(new_complete) - last_complete = last_complete+new_complete + new_complete_cleaned = new_complete.strip() + if new_complete_cleaned.startswith("```json"): + new_complete_cleaned = new_complete_cleaned[7:] + if new_complete_cleaned.startswith("```"): + new_complete_cleaned = new_complete_cleaned[3:] + if new_complete_cleaned.endswith("```"): + new_complete_cleaned = new_complete_cleaned[:-3] + + last_complete = last_complete + new_complete_cleaned if_complete = 
check_if_toc_transformation_is_complete(toc_content, last_complete, model) @@ -482,7 +492,7 @@ def add_page_number_to_toc(part, structure, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n" - current_json_raw = llm_completion(model=model, prompt=prompt) + current_json_raw = ChatGPT_API(model=model, prompt=prompt) json_result = extract_json(current_json_raw) for item in json_result: @@ -504,7 +514,7 @@ def remove_first_physical_index_section(text): return text ### add verify completeness -def generate_toc_continue(toc_content, part, model=None): +def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"): print('start generate_toc_continue') prompt = """ You are an expert in extracting hierarchical tree structure. @@ -532,7 +542,7 @@ def generate_toc_continue(toc_content, part, model=None): Directly return the additional part of the final JSON structure. Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2) - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': return extract_json(response) else: @@ -566,7 +576,7 @@ def generate_toc_init(part, model=None): Directly return the final JSON structure. 
Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': return extract_json(response) @@ -674,9 +684,9 @@ def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): page_contents = [] for page_index in range(prev_physical_index, next_physical_index+1): # Add bounds checking to prevent IndexError - list_index = page_index - start_index - if list_index >= 0 and list_index < len(page_list): - page_text = f"\n{page_list[list_index][0]}\n\n\n" + page_list_idx = page_index - start_index + if page_list_idx >= 0 and page_list_idx < len(page_list): + page_text = f"\n{page_list[page_list_idx][0]}\n\n\n" page_contents.append(page_text) else: continue @@ -737,7 +747,7 @@ def check_toc(page_list, opt=None): ################### fix incorrect toc ######################################################### -async def single_toc_item_index_fixer(section_title, content, model=None): +def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"): toc_extractor_prompt = """ You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document. @@ -750,8 +760,27 @@ async def single_toc_item_index_fixer(section_title, content, model=None): } Directly return the final JSON structure. 
Do not output anything else.""" - prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content - response = await llm_acompletion(model=model, prompt=prompt) + prompt = tob_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content + response = ChatGPT_API(model=model, prompt=prompt) + json_content = extract_json(response) + return convert_physical_index_to_int(json_content['physical_index']) + + +async def single_toc_item_index_fixer_async(section_title, content, model="gpt-4o-2024-11-20"): + tob_extractor_prompt = """ + You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document. + + The provided pages contains tags like and to indicate the physical location of the page X. + + Reply in a JSON format: + { + "thinking": , contains the start of this section>, + "physical_index": "" (keep the format) + } + Directly return the final JSON structure. Do not output anything else.""" + + prompt = tob_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content + response = await ChatGPT_API_async(model=model, prompt=prompt) json_content = extract_json(response) return convert_physical_index_to_int(json_content['physical_index']) @@ -820,7 +849,7 @@ async def process_and_check_item(incorrect_item): continue content_range = ''.join(page_contents) - physical_index_int = await single_toc_item_index_fixer(incorrect_item['title'], content_range, model) + physical_index_int = await single_toc_item_index_fixer_async(incorrect_item['title'], content_range, model) # Check if the result is correct check_item = incorrect_item.copy() @@ -1074,6 +1103,8 @@ def page_index_main(doc, opt=None): raise ValueError("Unsupported input type. 
Expected a PDF file path or BytesIO object.") print('Parsing PDF...') + if opt is None: + opt = ConfigLoader().load() page_list = get_page_tokens(doc, model=opt.model) logger.info({'total_page_number': len(page_list)}) @@ -1081,17 +1112,17 @@ def page_index_main(doc, opt=None): async def page_index_builder(): structure = await tree_parser(page_list, opt, doc=doc, logger=logger) - if opt.if_add_node_id == 'yes': + if opt.if_add_node_id: write_node_id(structure) - if opt.if_add_node_text == 'yes': + if opt.if_add_node_text: add_node_text(structure, page_list) - if opt.if_add_node_summary == 'yes': - if opt.if_add_node_text == 'no': + if opt.if_add_node_summary: + if not opt.if_add_node_text: add_node_text(structure, page_list) await generate_summaries_for_structure(structure, model=opt.model) - if opt.if_add_node_text == 'no': + if not opt.if_add_node_text: remove_structure_text(structure) - if opt.if_add_doc_description == 'yes': + if opt.if_add_doc_description: # Create a clean structure without unnecessary fields for description generation clean_structure = create_clean_structure_for_description(structure) doc_description = generate_doc_description(clean_structure, model=opt.model) diff --git a/pageindex/utils.py b/pageindex/utils.py index f00ccf3a7..8b05fd114 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -1,710 +1,49 @@ -import litellm -import logging -import os -import textwrap -from datetime import datetime -import time -import json -import PyPDF2 -import copy -import asyncio -import pymupdf -from io import BytesIO -from dotenv import load_dotenv -load_dotenv() -import logging -import yaml -from pathlib import Path -from types import SimpleNamespace as config - -# Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY -if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): - os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") - -litellm.drop_params = True - -def count_tokens(text, model=None): - if not 
text: - return 0 - return litellm.token_counter(model=model, text=text) - - -def llm_completion(model, prompt, chat_history=None, return_finish_reason=False): - if model: - model = model.removeprefix("litellm/") - max_retries = 10 - messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}] - for i in range(max_retries): - try: - response = litellm.completion( - model=model, - messages=messages, - temperature=0, - ) - content = response.choices[0].message.content - if return_finish_reason: - finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished" - return content, finish_reason - return content - except Exception as e: - print('************* Retrying *************') - logging.error(f"Error: {e}") - if i < max_retries - 1: - time.sleep(1) - else: - logging.error('Max retries reached for prompt: ' + prompt) - if return_finish_reason: - return "", "error" - return "" - - - -async def llm_acompletion(model, prompt): - if model: - model = model.removeprefix("litellm/") - max_retries = 10 - messages = [{"role": "user", "content": prompt}] - for i in range(max_retries): - try: - response = await litellm.acompletion( - model=model, - messages=messages, - temperature=0, - ) - return response.choices[0].message.content - except Exception as e: - print('************* Retrying *************') - logging.error(f"Error: {e}") - if i < max_retries - 1: - await asyncio.sleep(1) - else: - logging.error('Max retries reached for prompt: ' + prompt) - return "" - - -def get_json_content(response): - start_idx = response.find("```json") - if start_idx != -1: - start_idx += 7 - response = response[start_idx:] - - end_idx = response.rfind("```") - if end_idx != -1: - response = response[:end_idx] - - json_content = response.strip() - return json_content - - -def extract_json(content): - try: - # First, try to extract JSON enclosed within ```json and ``` - start_idx = 
content.find("```json") - if start_idx != -1: - start_idx += 7 # Adjust index to start after the delimiter - end_idx = content.rfind("```") - json_content = content[start_idx:end_idx].strip() - else: - # If no delimiters, assume entire content could be JSON - json_content = content.strip() - - # Clean up common issues that might cause parsing errors - json_content = json_content.replace('None', 'null') # Replace Python None with JSON null - json_content = json_content.replace('\n', ' ').replace('\r', ' ') # Remove newlines - json_content = ' '.join(json_content.split()) # Normalize whitespace - - # Attempt to parse and return the JSON object - return json.loads(json_content) - except json.JSONDecodeError as e: - logging.error(f"Failed to extract JSON: {e}") - # Try to clean up the content further if initial parsing fails - try: - # Remove any trailing commas before closing brackets/braces - json_content = json_content.replace(',]', ']').replace(',}', '}') - return json.loads(json_content) - except: - logging.error("Failed to parse JSON even after cleanup") - return {} - except Exception as e: - logging.error(f"Unexpected error while extracting JSON: {e}") - return {} - -def write_node_id(data, node_id=0): - if isinstance(data, dict): - data['node_id'] = str(node_id).zfill(4) - node_id += 1 - for key in list(data.keys()): - if 'nodes' in key: - node_id = write_node_id(data[key], node_id) - elif isinstance(data, list): - for index in range(len(data)): - node_id = write_node_id(data[index], node_id) - return node_id - -def get_nodes(structure): - if isinstance(structure, dict): - structure_node = copy.deepcopy(structure) - structure_node.pop('nodes', None) - nodes = [structure_node] - for key in list(structure.keys()): - if 'nodes' in key: - nodes.extend(get_nodes(structure[key])) - return nodes - elif isinstance(structure, list): - nodes = [] - for item in structure: - nodes.extend(get_nodes(item)) - return nodes - -def structure_to_list(structure): - if 
isinstance(structure, dict): - nodes = [] - nodes.append(structure) - if 'nodes' in structure: - nodes.extend(structure_to_list(structure['nodes'])) - return nodes - elif isinstance(structure, list): - nodes = [] - for item in structure: - nodes.extend(structure_to_list(item)) - return nodes - - -def get_leaf_nodes(structure): - if isinstance(structure, dict): - if not structure['nodes']: - structure_node = copy.deepcopy(structure) - structure_node.pop('nodes', None) - return [structure_node] - else: - leaf_nodes = [] - for key in list(structure.keys()): - if 'nodes' in key: - leaf_nodes.extend(get_leaf_nodes(structure[key])) - return leaf_nodes - elif isinstance(structure, list): - leaf_nodes = [] - for item in structure: - leaf_nodes.extend(get_leaf_nodes(item)) - return leaf_nodes - -def is_leaf_node(data, node_id): - # Helper function to find the node by its node_id - def find_node(data, node_id): - if isinstance(data, dict): - if data.get('node_id') == node_id: - return data - for key in data.keys(): - if 'nodes' in key: - result = find_node(data[key], node_id) - if result: - return result - elif isinstance(data, list): - for item in data: - result = find_node(item, node_id) - if result: - return result - return None - - # Find the node with the given node_id - node = find_node(data, node_id) - - # Check if the node is a leaf node - if node and not node.get('nodes'): - return True - return False - -def get_last_node(structure): - return structure[-1] - - -def extract_text_from_pdf(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - ###return text not list - text="" - for page_num in range(len(pdf_reader.pages)): - page = pdf_reader.pages[page_num] - text+=page.extract_text() - return text - -def get_pdf_title(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - meta = pdf_reader.metadata - title = meta.title if meta and meta.title else 'Untitled' - return title - -def get_text_of_pages(pdf_path, start_page, end_page, tag=True): - pdf_reader = 
PyPDF2.PdfReader(pdf_path) - text = "" - for page_num in range(start_page-1, end_page): - page = pdf_reader.pages[page_num] - page_text = page.extract_text() - if tag: - text += f"\n{page_text}\n\n" - else: - text += page_text - return text - -def get_first_start_page_from_text(text): - start_page = -1 - start_page_match = re.search(r'', text) - if start_page_match: - start_page = int(start_page_match.group(1)) - return start_page - -def get_last_start_page_from_text(text): - start_page = -1 - # Find all matches of start_index tags - start_page_matches = re.finditer(r'', text) - # Convert iterator to list and get the last match if any exist - matches_list = list(start_page_matches) - if matches_list: - start_page = int(matches_list[-1].group(1)) - return start_page - - -def sanitize_filename(filename, replacement='-'): - # In Linux, only '/' and '\0' (null) are invalid in filenames. - # Null can't be represented in strings, so we only handle '/'. - return filename.replace('/', replacement) - -def get_pdf_name(pdf_path): - # Extract PDF name - if isinstance(pdf_path, str): - pdf_name = os.path.basename(pdf_path) - elif isinstance(pdf_path, BytesIO): - pdf_reader = PyPDF2.PdfReader(pdf_path) - meta = pdf_reader.metadata - pdf_name = meta.title if meta and meta.title else 'Untitled' - pdf_name = sanitize_filename(pdf_name) - return pdf_name - - -class JsonLogger: - def __init__(self, file_path): - # Extract PDF name for logger name - pdf_name = get_pdf_name(file_path) - - current_time = datetime.now().strftime("%Y%m%d_%H%M%S") - self.filename = f"{pdf_name}_{current_time}.json" - os.makedirs("./logs", exist_ok=True) - # Initialize empty list to store all messages - self.log_data = [] - - def log(self, level, message, **kwargs): - if isinstance(message, dict): - self.log_data.append(message) - else: - self.log_data.append({'message': message}) - # Add new message to the log data - - # Write entire log data to file - with open(self._filepath(), "w") as f: - 
json.dump(self.log_data, f, indent=2) - - def info(self, message, **kwargs): - self.log("INFO", message, **kwargs) - - def error(self, message, **kwargs): - self.log("ERROR", message, **kwargs) - - def debug(self, message, **kwargs): - self.log("DEBUG", message, **kwargs) - - def exception(self, message, **kwargs): - kwargs["exception"] = True - self.log("ERROR", message, **kwargs) - - def _filepath(self): - return os.path.join("logs", self.filename) - - - - -def list_to_tree(data): - def get_parent_structure(structure): - """Helper function to get the parent structure code""" - if not structure: - return None - parts = str(structure).split('.') - return '.'.join(parts[:-1]) if len(parts) > 1 else None - - # First pass: Create nodes and track parent-child relationships - nodes = {} - root_nodes = [] - - for item in data: - structure = item.get('structure') - node = { - 'title': item.get('title'), - 'start_index': item.get('start_index'), - 'end_index': item.get('end_index'), - 'nodes': [] - } - - nodes[structure] = node - - # Find parent - parent_structure = get_parent_structure(structure) - - if parent_structure: - # Add as child to parent if parent exists - if parent_structure in nodes: - nodes[parent_structure]['nodes'].append(node) - else: - root_nodes.append(node) - else: - # No parent, this is a root node - root_nodes.append(node) - - # Helper function to clean empty children arrays - def clean_node(node): - if not node['nodes']: - del node['nodes'] - else: - for child in node['nodes']: - clean_node(child) - return node - - # Clean and return the tree - return [clean_node(node) for node in root_nodes] - -def add_preface_if_needed(data): - if not isinstance(data, list) or not data: - return data - - if data[0]['physical_index'] is not None and data[0]['physical_index'] > 1: - preface_node = { - "structure": "0", - "title": "Preface", - "physical_index": 1, - } - data.insert(0, preface_node) - return data - - - -def get_page_tokens(pdf_path, model=None, 
pdf_parser="PyPDF2"): - if pdf_parser == "PyPDF2": - pdf_reader = PyPDF2.PdfReader(pdf_path) - page_list = [] - for page_num in range(len(pdf_reader.pages)): - page = pdf_reader.pages[page_num] - page_text = page.extract_text() - token_length = litellm.token_counter(model=model, text=page_text) - page_list.append((page_text, token_length)) - return page_list - elif pdf_parser == "PyMuPDF": - if isinstance(pdf_path, BytesIO): - pdf_stream = pdf_path - doc = pymupdf.open(stream=pdf_stream, filetype="pdf") - elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"): - doc = pymupdf.open(pdf_path) - page_list = [] - for page in doc: - page_text = page.get_text() - token_length = litellm.token_counter(model=model, text=page_text) - page_list.append((page_text, token_length)) - return page_list - else: - raise ValueError(f"Unsupported PDF parser: {pdf_parser}") - - - -def get_text_of_pdf_pages(pdf_pages, start_page, end_page): - text = "" - for page_num in range(start_page-1, end_page): - text += pdf_pages[page_num][0] - return text - -def get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page): - text = "" - for page_num in range(start_page-1, end_page): - text += f"\n{pdf_pages[page_num][0]}\n\n" - return text - -def get_number_of_pages(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - num = len(pdf_reader.pages) - return num - - - -def post_processing(structure, end_physical_index): - # First convert page_number to start_index in flat list - for i, item in enumerate(structure): - item['start_index'] = item.get('physical_index') - if i < len(structure) - 1: - if structure[i + 1].get('appear_start') == 'yes': - item['end_index'] = structure[i + 1]['physical_index']-1 - else: - item['end_index'] = structure[i + 1]['physical_index'] - else: - item['end_index'] = end_physical_index - tree = list_to_tree(structure) - if len(tree)!=0: - return tree - else: - ### remove appear_start - for node in structure: - 
node.pop('appear_start', None) - node.pop('physical_index', None) - return structure - -def clean_structure_post(data): - if isinstance(data, dict): - data.pop('page_number', None) - data.pop('start_index', None) - data.pop('end_index', None) - if 'nodes' in data: - clean_structure_post(data['nodes']) - elif isinstance(data, list): - for section in data: - clean_structure_post(section) - return data - -def remove_fields(data, fields=['text']): - if isinstance(data, dict): - return {k: remove_fields(v, fields) - for k, v in data.items() if k not in fields} - elif isinstance(data, list): - return [remove_fields(item, fields) for item in data] - return data - -def print_toc(tree, indent=0): - for node in tree: - print(' ' * indent + node['title']) - if node.get('nodes'): - print_toc(node['nodes'], indent + 1) - -def print_json(data, max_len=40, indent=2): - def simplify_data(obj): - if isinstance(obj, dict): - return {k: simplify_data(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [simplify_data(item) for item in obj] - elif isinstance(obj, str) and len(obj) > max_len: - return obj[:max_len] + '...' 
- else: - return obj - - simplified = simplify_data(data) - print(json.dumps(simplified, indent=indent, ensure_ascii=False)) - - -def remove_structure_text(data): - if isinstance(data, dict): - data.pop('text', None) - if 'nodes' in data: - remove_structure_text(data['nodes']) - elif isinstance(data, list): - for item in data: - remove_structure_text(item) - return data - - -def check_token_limit(structure, limit=110000): - list = structure_to_list(structure) - for node in list: - num_tokens = count_tokens(node['text'], model=None) - if num_tokens > limit: - print(f"Node ID: {node['node_id']} has {num_tokens} tokens") - print("Start Index:", node['start_index']) - print("End Index:", node['end_index']) - print("Title:", node['title']) - print("\n") - - -def convert_physical_index_to_int(data): - if isinstance(data, list): - for i in range(len(data)): - # Check if item is a dictionary and has 'physical_index' key - if isinstance(data[i], dict) and 'physical_index' in data[i]: - if isinstance(data[i]['physical_index'], str): - if data[i]['physical_index'].startswith('').strip()) - elif data[i]['physical_index'].startswith('physical_index_'): - data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip()) - elif isinstance(data, str): - if data.startswith('').strip()) - elif data.startswith('physical_index_'): - data = int(data.split('_')[-1].strip()) - # Check data is int - if isinstance(data, int): - return data - else: - return None - return data - - -def convert_page_to_int(data): - for item in data: - if 'page' in item and isinstance(item['page'], str): - try: - item['page'] = int(item['page']) - except ValueError: - # Keep original value if conversion fails - pass - return data - - -def add_node_text(node, pdf_pages): - if isinstance(node, dict): - start_page = node.get('start_index') - end_page = node.get('end_index') - node['text'] = get_text_of_pdf_pages(pdf_pages, start_page, end_page) - if 'nodes' in node: - add_node_text(node['nodes'], 
pdf_pages) - elif isinstance(node, list): - for index in range(len(node)): - add_node_text(node[index], pdf_pages) - return - - -def add_node_text_with_labels(node, pdf_pages): - if isinstance(node, dict): - start_page = node.get('start_index') - end_page = node.get('end_index') - node['text'] = get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page) - if 'nodes' in node: - add_node_text_with_labels(node['nodes'], pdf_pages) - elif isinstance(node, list): - for index in range(len(node)): - add_node_text_with_labels(node[index], pdf_pages) - return - - -async def generate_node_summary(node, model=None): - prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document. - - Partial Document Text: {node['text']} - - Directly return the description, do not include any other text. - """ - response = await llm_acompletion(model, prompt) - return response - - -async def generate_summaries_for_structure(structure, model=None): - nodes = structure_to_list(structure) - tasks = [generate_node_summary(node, model=model) for node in nodes] - summaries = await asyncio.gather(*tasks) - - for node, summary in zip(nodes, summaries): - node['summary'] = summary - return structure - - -def create_clean_structure_for_description(structure): +from .core.llm import * +from .core.pdf import * +from .core.tree import * +from .core.logging import * +from .config import ConfigLoader, PageIndexConfig + +# --------------------------------------------------------------------------- +# Backward-compatibility wrappers for legacy symbols that were previously +# defined in utils.py but are no longer present in the new core modules. +# These delegates allow existing code that imports from pageindex.utils to +# continue functioning without modification. 
+# --------------------------------------------------------------------------- +import asyncio as _asyncio +import functools as _functools +from typing import Any as _Any +from .core import llm as _core_llm + + +def llm_completion(*args: _Any, **kwargs: _Any) -> _Any: + """Backward-compatible wrapper: delegates to ChatGPT_API in core.llm. + + If the caller passes return_finish_reason=True (legacy kwarg), we route + to ChatGPT_API_with_finish_reason and return (response, finish_reason). """ - Create a clean structure for document description generation, - excluding unnecessary fields like 'text'. - """ - if isinstance(structure, dict): - clean_node = {} - # Only include essential fields for description - for key in ['title', 'node_id', 'summary', 'prefix_summary']: - if key in structure: - clean_node[key] = structure[key] - - # Recursively process child nodes - if 'nodes' in structure and structure['nodes']: - clean_node['nodes'] = create_clean_structure_for_description(structure['nodes']) - - return clean_node - elif isinstance(structure, list): - return [create_clean_structure_for_description(item) for item in structure] - else: - return structure - - -def generate_doc_description(structure, model=None): - prompt = f"""Your are an expert in generating descriptions for a document. - You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents. - - Document Structure: {structure} - - Directly return the description, do not include any other text. 
- """ - response = llm_completion(model, prompt) - return response - - -def reorder_dict(data, key_order): - if not key_order: - return data - return {key: data[key] for key in key_order if key in data} - - -def format_structure(structure, order=None): - if not order: - return structure - if isinstance(structure, dict): - if 'nodes' in structure: - structure['nodes'] = format_structure(structure['nodes'], order) - if not structure.get('nodes'): - structure.pop('nodes', None) - structure = reorder_dict(structure, order) - elif isinstance(structure, list): - structure = [format_structure(item, order) for item in structure] - return structure - - -class ConfigLoader: - def __init__(self, default_path: str = None): - if default_path is None: - default_path = Path(__file__).parent / "config.yaml" - self._default_dict = self._load_yaml(default_path) - - @staticmethod - def _load_yaml(path): - with open(path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) or {} - - def _validate_keys(self, user_dict): - unknown_keys = set(user_dict) - set(self._default_dict) - if unknown_keys: - raise ValueError(f"Unknown config keys: {unknown_keys}") - - def load(self, user_opt=None) -> config: - """ - Load the configuration, merging user options with default values. 
- """ - if user_opt is None: - user_dict = {} - elif isinstance(user_opt, config): - user_dict = vars(user_opt) - elif isinstance(user_opt, dict): - user_dict = user_opt - else: - raise TypeError("user_opt must be dict, config(SimpleNamespace) or None") - - self._validate_keys(user_dict) - merged = {**self._default_dict, **user_dict} - return config(**merged) - -def create_node_mapping(tree): - """Create a flat dict mapping node_id to node for quick lookup.""" - mapping = {} - def _traverse(nodes): - for node in nodes: - if node.get('node_id'): - mapping[node['node_id']] = node - if node.get('nodes'): - _traverse(node['nodes']) - _traverse(tree) - return mapping - -def print_tree(tree, indent=0): - for node in tree: - summary = node.get('summary') or node.get('prefix_summary', '') - summary_str = f" — {summary[:60]}..." if summary else "" - print(' ' * indent + f"[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}") - if node.get('nodes'): - print_tree(node['nodes'], indent + 1) - -def print_wrapped(text, width=100): - for line in text.splitlines(): - print(textwrap.fill(line, width=width)) - + return_finish_reason = kwargs.pop("return_finish_reason", False) + if hasattr(_core_llm, "llm_completion"): + return _core_llm.llm_completion(*args, **kwargs) + if return_finish_reason and hasattr(_core_llm, "ChatGPT_API_with_finish_reason"): + return _core_llm.ChatGPT_API_with_finish_reason(*args, **kwargs) + if hasattr(_core_llm, "ChatGPT_API"): + return _core_llm.ChatGPT_API(*args, **kwargs) + raise RuntimeError( + "llm_completion is not available in pageindex.core.llm. " + "Update your code to use ChatGPT_API directly." 
+ ) + + +async def llm_acompletion(*args: _Any, **kwargs: _Any) -> _Any: + """Backward-compatible async wrapper: delegates to ChatGPT_API_async in core.llm.""" + if hasattr(_core_llm, "llm_acompletion"): + return await _core_llm.llm_acompletion(*args, **kwargs) + if hasattr(_core_llm, "ChatGPT_API_async"): + return await _core_llm.ChatGPT_API_async(*args, **kwargs) + # Fallback: run sync completion in a thread executor using functools.partial + # to safely pass keyword arguments (run_in_executor doesn't accept kwargs) + loop = _asyncio.get_running_loop() + return await loop.run_in_executor( + None, _functools.partial(llm_completion, *args, **kwargs) + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..e9752821a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "pageindex" +version = "0.1.0" +description = "Vectorless, reasoning-based RAG indexer" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +dependencies = [ + "openai==1.101.0", + "pymupdf==1.26.4", + "PyPDF2==3.0.1", + "python-dotenv==1.1.0", + "tiktoken==0.11.0", + "pyyaml==6.0.2", + "pydantic>=2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", +] + +[project.scripts] +pageindex = "pageindex.cli:main" + +[tool.setuptools.packages.find] +where = ["."] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..c360dfec5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +import os +import sys + + +# Add project root to python path for testing +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..6076f7966 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,91 @@ +import pytest +from types import SimpleNamespace +from pydantic 
import BaseModel +from pageindex.config import ConfigLoader, PageIndexConfig + + +def test_config_loader_default(tmp_path): + # Mock config file + config_file = tmp_path / "config.yaml" + config_file.write_text('model: "gpt-4-test"\nmax_page_num_each_node: 10', encoding="utf-8") + + loader = ConfigLoader(default_path=config_file) + cfg = loader.load() + + assert isinstance(cfg, PageIndexConfig) + assert cfg.model == "gpt-4-test" + assert cfg.max_page_num_each_node == 10 + # Check default logic + assert cfg.toc_check_page_num == 20 + + +def test_config_loader_override(): + loader = ConfigLoader(default_path=None) + override = {"model": "gpt-override", "if_add_node_id": False} + + cfg = loader.load(user_opt=override) + assert cfg.model == "gpt-override" + assert cfg.if_add_node_id is False + + +def test_config_validation_error(): + loader = ConfigLoader(default_path=None) + # Pass invalid type for integer field + override = {"max_page_num_each_node": "not-an-int"} + + with pytest.raises(ValueError, match="Configuration validation failed"): + loader.load(user_opt=override) + + +def test_partial_override_object(): + args = SimpleNamespace(model="cmd-model", other_arg=None) + loader = ConfigLoader(default_path=None) + cfg = loader.load(user_opt=args) + assert cfg.model == "cmd-model" + + +def test_config_loader_fallback_defaults_match_repo_profile(tmp_path): + loader = ConfigLoader(default_path=tmp_path / "missing.yaml") + cfg = loader.load() + + assert cfg.model == "gpt-4o-2024-11-20" + assert cfg.retrieve_model == "gpt-5.4" + assert cfg.toc_check_page_num == 20 + assert cfg.max_page_num_each_node == 10 + assert cfg.max_token_num_each_node == 20000 + assert cfg.if_add_node_id is True + assert cfg.if_add_node_summary is True + assert cfg.if_add_doc_description is False + assert cfg.if_add_node_text is False + + +def test_dict_override_can_explicitly_clear_yaml_value(tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text('model: 
"gpt-4-test"\napi_key: "from-config"', encoding="utf-8") + + loader = ConfigLoader(default_path=config_file) + cfg = loader.load(user_opt={"api_key": None}) + + assert cfg.model == "gpt-4-test" + assert cfg.api_key is None + + +class OverrideModel(BaseModel): + api_key: str | None = None + + +def test_pydantic_override_can_explicitly_clear_yaml_value(tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text('model: "gpt-4-test"\napi_key: "from-config"', encoding="utf-8") + + loader = ConfigLoader(default_path=config_file) + cfg = loader.load(user_opt=OverrideModel(api_key=None)) + + assert cfg.model == "gpt-4-test" + assert cfg.api_key is None + + +def test_config_loader_is_reexported_from_utils(): + from pageindex.utils import ConfigLoader as UtilsConfigLoader + + assert UtilsConfigLoader is ConfigLoader diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 000000000..6dcbbab97 --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,21 @@ +from pageindex.core.llm import extract_json, count_tokens + + +def test_extract_json_basic(): + text = '{"key": "value"}' + assert extract_json(text) == {"key": "value"} + + +def test_extract_json_with_markdown(): + text = 'Here is the json:\n```json\n{"key": "value"}\n```' + assert extract_json(text) == {"key": "value"} + + +def test_extract_json_with_trailing_commas(): + text = '{"key": "value",}' + assert extract_json(text) == {"key": "value"} + + +def test_count_tokens(): + text = "Hello world" + assert count_tokens(text) > 0 diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 000000000..2c62ec8bc --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,38 @@ +import pytest +from pageindex.core.tree import list_to_tree, structure_to_list, get_nodes, write_node_id + + +@pytest.fixture +def sample_structure(): + return [ + {"structure": "1", "title": "Chapter 1", "start_index": 1, "end_index": 5}, + {"structure": "1.1", "title": "Section 1.1", "start_index": 1, 
"end_index": 3}, + {"structure": "1.2", "title": "Section 1.2", "start_index": 4, "end_index": 5}, + {"structure": "2", "title": "Chapter 2", "start_index": 6, "end_index": 10}, + ] + + +def test_list_to_tree(sample_structure): + tree = list_to_tree(sample_structure) + assert len(tree) == 2 + assert tree[0]["title"] == "Chapter 1" + assert len(tree[0]["nodes"]) == 2 + assert tree[0]["nodes"][0]["title"] == "Section 1.1" + assert tree[1]["title"] == "Chapter 2" + assert "nodes" not in tree[1] or len(tree[1]["nodes"]) == 0 + + +def test_structure_to_list(sample_structure): + tree = list_to_tree(sample_structure) + flat_list = structure_to_list(tree) + assert len(flat_list) == 4 + titles = [item["title"] for item in flat_list] + assert "Chapter 1" in titles + assert "Section 1.1" in titles + + +def test_write_node_id(sample_structure): + tree = list_to_tree(sample_structure) + write_node_id(tree) + assert tree[0]["node_id"] == "0000" + assert tree[0]["nodes"][0]["node_id"] == "0001"