diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..bb84bc8 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,43 @@ +name: Deploy Documentation + +on: + push: + branches: [main] + paths: + - 'docs/**' + - 'mkdocs.yml' + workflow_dispatch: + +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: pages + cancel-in-progress: false + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install mkdocs-material "mkdocstrings[python]>=0.24" + - run: pip install -e . + - run: mkdocs build + - uses: actions/upload-pages-artifact@v3 + with: + path: site/ + + deploy: + needs: build + runs-on: ubuntu-latest + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index a1f1c70..9b8dbff 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -1,8 +1,10 @@ name: Python Tests on: + push: + branches: [main] pull_request: - branches: [ main ] + branches: [main] jobs: test: @@ -12,24 +14,24 @@ jobs: python-version: ["3.12"] steps: - - uses: actions/checkout@v3 - + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' - + - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e ".[dev]" - + - name: Lint with ruff run: | - ruff format --check --diff . - ruff check --select I . - + ruff format --check --diff lettucedetect/ tests/ + ruff check lettucedetect/ tests/ --extend-exclude lettucedetect/integrations/ + - name: Test with pytest run: | - pytest tests/test_inference_pytest.py -v \ No newline at end of file + pytest tests/test_inference_pytest.py -v diff --git a/.gitignore b/.gitignore index b81c6d0..3c2f5ce 100644 --- a/.gitignore +++ b/.gitignore @@ -170,6 +170,9 @@ cython_debug/ # PyPI configuration file .pypirc +# macOS +.DS_Store + # data/ data/ @@ -178,4 +181,5 @@ output/ temp/ # cache/ -lettucedetect/cache/ \ No newline at end of file +lettucedetect/cache/ +testing/ \ No newline at end of file diff --git a/docs/EUROBERT.md b/docs/EUROBERT.md index 0380d8c..3249382 100644 --- a/docs/EUROBERT.md +++ b/docs/EUROBERT.md @@ -1,7 +1,7 @@ # πŸ₯¬ LettuceDetect Goes Multilingual: Fine-tuning EuroBERT on Synthetic RAGTruth Translations

- LettuceDetect Multilingual Task Force + LettuceDetect Multilingual Task Force
Expanding hallucination detection across languages for RAG pipelines.

diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md deleted file mode 100644 index c6b8cdb..0000000 --- a/docs/EVALUATION.md +++ /dev/null @@ -1,13 +0,0 @@ -# Evaluation - -## Use LLM baselines - -```bash -python scripts/evaluate_llm.py --model "gpt-4o-mini" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level" -``` - -## Use HallucinationDetector - -```bash -python scripts/evaluate.py --model_path "output/hallucination_detection_de_210m" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level" -``` diff --git a/docs/api/datasets.md b/docs/api/datasets.md new file mode 100644 index 0000000..6aae94d --- /dev/null +++ b/docs/api/datasets.md @@ -0,0 +1,13 @@ +# Datasets + +## HallucinationSample + +::: lettucedetect.datasets.hallucination_dataset.HallucinationSample + +## HallucinationData + +::: lettucedetect.datasets.hallucination_dataset.HallucinationData + +## HallucinationDataset + +::: lettucedetect.datasets.hallucination_dataset.HallucinationDataset diff --git a/docs/api/detectors.md b/docs/api/detectors.md new file mode 100644 index 0000000..da56499 --- /dev/null +++ b/docs/api/detectors.md @@ -0,0 +1,13 @@ +# Detectors + +## Factory + +::: lettucedetect.detectors.factory.make_detector + +## Base Detector + +::: lettucedetect.detectors.base.BaseDetector + +## Transformer Detector + +::: lettucedetect.detectors.transformer.TransformerDetector diff --git a/docs/api/inference.md b/docs/api/inference.md new file mode 100644 index 0000000..14a827e --- /dev/null +++ b/docs/api/inference.md @@ -0,0 +1,5 @@ +# Inference + +The main entry point for hallucination detection. + +::: lettucedetect.models.inference.HallucinationDetector diff --git a/docs/api/training.md b/docs/api/training.md new file mode 100644 index 0000000..91a935c --- /dev/null +++ b/docs/api/training.md @@ -0,0 +1,9 @@ +# Training + +## Trainer + +::: lettucedetect.models.trainer.Trainer + +## Evaluator + +::: lettucedetect.models.evaluator diff --git a/docs/benchmarks.md b/docs/benchmarks.md new file mode 100644 index 0000000..87e4e97 --- /dev/null +++ b/docs/benchmarks.md @@ -0,0 +1,43 @@ +# Benchmarks + +## RAGTruth (English) + +Evaluated on the [RAGTruth](https://aclanthology.org/2024.acl-long.585/) test set. This benchmark measures how well models detect hallucinations in LLM-generated text across QA, summarization, and data-to-text tasks. + +### Example-Level Detection + +Binary classification: does the answer contain any hallucination? + +| Model | Type | Overall F1 | +|-------|------|-----------| +| GPT-4 | LLM (zero-shot) | 63.4% | +| Luna | Encoder | 65.4% | +| **lettucedetect-base-v1** | **Encoder (149M)** | **76.8%** | +| Llama-2-13B (fine-tuned) | LLM | 78.7% | +| **lettucedetect-large-v1** | **Encoder (395M)** | **79.2%** | +| RAG-HAT (Llama-3-8B) | LLM | 83.9% | + +### Span-Level Detection + +LettuceDetect achieves state-of-the-art span-level results among models that report this metric, outperforming fine-tuned Llama-2-13B. Span-level evaluation measures how precisely the model can locate the exact hallucinated text within an answer. + +## What These Numbers Mean + +- **Example F1** β€” Can the model tell if an answer has *any* hallucination? Higher is better. +- **Span F1** β€” Can the model point to *exactly which parts* are hallucinated? This is the harder task and where LettuceDetect excels relative to its size. 
+ +LettuceDetect models are 50-500x smaller than LLM-based detectors while achieving competitive or better accuracy. + +## Citation + +```bibtex +@misc{Kovacs:2025, + title={LettuceDetect: A Hallucination Detection Framework for RAG Applications}, + author={ÁdΓ‘m KovΓ‘cs and GΓ‘bor Recski}, + year={2025}, + eprint={2502.17125}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2502.17125}, +} +``` diff --git a/docs/code-hallucination/architecture-research.md b/docs/code-hallucination/architecture-research.md new file mode 100644 index 0000000..232114a --- /dev/null +++ b/docs/code-hallucination/architecture-research.md @@ -0,0 +1,154 @@ +# Architecture Research: Detection Models for Code Hallucination + +Research notes on model architectures for training on the code hallucination dataset. We compare four approaches ranging from fast encoder-based classifiers to generative span detectors. + +## Approach A: Token Classification (Encoder) + +**Architecture:** ModernBERT/EuroBERT + linear classification head + +The current LettuceDetect approach. Each answer token gets a binary label (0=supported, 1=hallucinated). Consecutive hallucinated tokens are merged into spans at inference. + +``` +Input: [CLS] context [SEP] question [SEP] answer [SEP] +Output: [-100, -100, ..., 0, 0, 1, 1, 1, 0, 0, ...] + ^^^^^^^^^ hallucinated span +``` + +| Property | Value | +|----------|-------| +| **Models** | ModernBERT-base (149M), ModernBERT-large (395M), EuroBERT (210M-2.1B) | +| **Context** | 8K tokens | +| **Inference** | Single forward pass, 30-60 samples/sec on A100 | +| **Training** | Standard token classification, CrossEntropyLoss | +| **Validated by** | LettuceDetect (79.2% F1), HaluGate (vLLM), PsiloQA (EMNLP 2025) | + +**Strengths:** Fast, simple, production-ready. Handles long contiguous spans well. +**Weaknesses:** No code-specific pretraining. Cannot explain *why* something is hallucinated. + +--- + +## Approach B: Token Classification (Decoder LLM) + +**Architecture:** Qwen3.5-2B + bidirectional attention (LLM2Vec) + linear head + +Use a decoder LLM pretrained on massive code corpora, convert to bidirectional encoder via [LLM2Vec](https://arxiv.org/abs/2404.05961), then add a token classification head. + +``` +Step 1: Load Qwen3.5-2B base (2B params, code-heavy pretraining) +Step 2: Enable bidirectional attention (remove causal mask) +Step 3: Short MNTP adaptation (masked next token prediction with LoRA) +Step 4: Add linear head (hidden_dim=2048 β†’ 2 classes) +Step 5: Fine-tune on code hallucination dataset with LoRA +``` + +| Property | Value | +|----------|-------| +| **Model** | Qwen3.5-2B (2B params) | +| **Context** | 262K native (practically limited by GPU memory) | +| **Inference** | Single forward pass, ~5-15 samples/sec | +| **VRAM** | ~5-8GB in bf16 | +| **Reference** | [Looking Right is Sometimes Right (ACL 2024)](https://arxiv.org/abs/2401.14556) β€” 0.947 F1 on NER with mask removal | + +**Strengths:** Deep code understanding from pretraining. Bidirectional attention after conversion. +**Weaknesses:** 5x larger than ModernBERT. Requires LLM2Vec conversion step. Novel (unvalidated for hallucination detection). + +**Key insight:** The [ACL 2024 paper](https://arxiv.org/abs/2401.14556) showed decoder LLMs with causal mask removal reach 0.947 F1 on NER, significantly above RoBERTa-large (0.900). The gains come from combining rich pretrained representations with bidirectional context. 
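+
+As a rough illustration of the shared final step (Approaches A and B both end in a linear head over token states), the sketch below loads a token-classification head on ModernBERT. For Approach B the backbone would be an LLM2Vec-converted decoder wrapped with LoRA; that conversion is assumed here, not shown, and the snippet is a sketch rather than the validated LettuceDetect training code.
+
+```python
+# Minimal sketch: linear token-classification head, as in Approach A.
+# Assumes transformers >= 4.48 (ModernBERT support); not the exact trainer code.
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+name = "answerdotai/ModernBERT-base"
+tokenizer = AutoTokenizer.from_pretrained(name)
+model = AutoModelForTokenClassification.from_pretrained(
+    name, num_labels=2  # 0 = supported, 1 = hallucinated
+)
+
+enc = tokenizer("context passage", "candidate answer", return_tensors="pt")
+logits = model(**enc).logits   # shape (1, seq_len, 2)
+pred = logits.argmax(-1)       # per-token labels; merge runs of 1s into spans
+```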
+ +--- + +## Approach C: Chunk Verification (Reranker-style) + +**Architecture:** Qwen3.5-2B or Qwen3-0.6B, reranker-style yes/no scoring + +Inspired by [Qwen3-Reranker](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B). Split the answer into chunks (lines, statements), then ask the model for each chunk: "Is this code correct given the context?" + +``` +Input: "Given this source code, is this line correct? yes/no" +Output: P(yes) = 0.12 β†’ hallucinated + P(yes) = 0.95 β†’ supported +``` + +No architectural modifications. Uses the LLM's native next-token prediction to classify. + +| Property | Value | +|----------|-------| +| **Models** | Qwen3-0.6B (tiny, fast) or Qwen3.5-2B | +| **Inference** | N forward passes per sample (one per chunk) | +| **Training** | Standard SFT with yes/no labels | +| **Reference** | [MiniCheck (EMNLP 2024)](https://arxiv.org/abs/2404.10774) β€” GPT-4-level at 400x lower cost | + +**Strengths:** No architecture changes. Uses LLM code reasoning directly. Can work with tiny models. +**Weaknesses:** Slowest inference (N passes per sample). Chunk boundary sensitivity. No sub-chunk granularity. + +--- + +## Approach D: Generative Span Detection + +**Architecture:** Qwen3.5-2B, standard SFT, generates JSON with hallucinated spans + +The model directly outputs which spans are hallucinated and why. This is the reverse of the hallucination injection process. + +``` +Input: "Given the source code and answer, identify hallucinated spans." +Output: { + "hallucinated_spans": [ + {"text": "response.json_decode()", "explanation": "method is json(), not json_decode()"} + ] +} +``` + +| Property | Value | +|----------|-------| +| **Models** | Qwen3.5-2B or larger | +| **Inference** | Single generation (autoregressive, slower than forward pass) | +| **Training** | Standard SFT with LoRA | +| **SOTA** | [RL4HS (Oct 2025)](https://arxiv.org/abs/2510.02173) β€” 58.3 F1 on RAGTruth, beats GPT-5 (42.2) and o3 (51.2) | + +**Strengths:** + +- No architecture changes β€” pure text generation +- Free explanations alongside span detection +- Naturally handles variable span counts +- Can leverage the LLM's code knowledge ("this API doesn't exist") +- Training data format already matches (reverse of injection pipeline) +- Current SOTA approach (RL4HS) + +**Weaknesses:** Autoregressive generation is slower. Risk of hallucinating in the detector itself. String matching needed to map spans back to character offsets. + +**RL enhancement:** [RL4HS](https://arxiv.org/abs/2510.02173) shows that adding reinforcement learning (GRPO with span-level rewards) on top of SFT dramatically improves performance. SFT alone is a strong baseline; RL pushes it to SOTA. + +--- + +## Comparison + +| | A. Encoder token | B. LLM token | C. Chunk verifier | D. Generative span | +|---|---|---|---|---| +| **Base model** | ModernBERT-large | Qwen3.5-2B | Qwen3-0.6B | Qwen3.5-2B | +| **Parameters** | 395M | 2B | 0.6B | 2B | +| **Architecture mods** | None | Mask removal | None | None | +| **Inference speed** | Fastest | Medium | Slowest | Medium-slow | +| **Explainable** | No | No | No | Yes | +| **Code understanding** | Limited | Deep | Deep | Deep | +| **Training complexity** | Simple | LLM2Vec + LoRA | Simple SFT | Simple SFT | +| **SOTA reference** | LettuceDetect, HaluGate | ACL 2024 paper | MiniCheck | RL4HS | + +## Recommended Experiments + +1. **A vs D** β€” Token classification (ModernBERT) vs generative span detection (Qwen3.5-2B). 
The core comparison: fast encoder vs reasoning LLM, both trained on the same dataset. + +2. **A vs B** β€” Does code pretraining help token classification? Same task, different backbone. + +3. **D with RL** β€” If SFT results are promising, add GRPO with span-overlap rewards (following RL4HS). + +## Key References + +- [LettuceDetect (arXiv:2502.17125)](https://arxiv.org/abs/2502.17125) β€” Encoder token classification baseline +- [HaluGate (vLLM, Dec 2025)](https://blog.vllm.ai/2025/12/14/halugate.html) β€” Production ModernBERT + NLI pipeline +- [RL4HS (arXiv:2510.02173)](https://arxiv.org/abs/2510.02173) β€” SOTA generative span detection with RL +- [FAVA (COLM 2024)](https://arxiv.org/abs/2401.06855) β€” Generative hallucination editing +- [PsiloQA (EMNLP 2025)](https://arxiv.org/abs/2510.04849) β€” Multilingual encoder-based span detection +- [Looking Right is Sometimes Right (ACL 2024)](https://arxiv.org/abs/2401.14556) β€” Decoder LLMs for token classification +- [LLM2Vec (2024)](https://arxiv.org/abs/2404.05961) β€” Converting decoders to bidirectional encoders +- [MiniCheck (EMNLP 2024)](https://arxiv.org/abs/2404.10774) β€” Sentence-level fact checking +- [Qwen3-Reranker](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B) β€” LLM-based yes/no classification +- [CodeMirage (2024)](https://arxiv.org/abs/2408.08333) β€” Code hallucination taxonomy (snippet-level only) diff --git a/docs/code-hallucination/configuration.md b/docs/code-hallucination/configuration.md new file mode 100644 index 0000000..e77e073 --- /dev/null +++ b/docs/code-hallucination/configuration.md @@ -0,0 +1,68 @@ +# Configuration + +All pipeline configuration is centralized in `scripts/code_hallucination/config.py`. + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `OPENAI_API_KEY` | (none) | API key for the LLM provider | +| `API_BASE_URL` | `https://api.groq.com/openai/v1` | OpenAI-compatible API endpoint | +| `MODEL` | `moonshotai/kimi-k2-instruct-0905` | Model name | +| `BATCH_SIZE` | `1` | Concurrent requests. Set >1 for local vLLM to saturate GPU | +| `CONTEXT7_API_KEY` | (none) | API key for Context7 documentation service | + +These can also be overridden via CLI flags (`--api-key`, `--base-url`, `--model`). 
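+
+For example, a bulk-generation run against Groq might look like the following (the key is a placeholder; the values mirror the defaults in the table above):
+
+```bash
+OPENAI_API_KEY=your_groq_key \
+API_BASE_URL=https://api.groq.com/openai/v1 \
+MODEL=moonshotai/kimi-k2-instruct-0905 \
+BATCH_SIZE=8 \
+  python -m scripts.code_hallucination.pipeline --all
+```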
+ +## Dataset Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `HALLUCINATION_RATIO` | `0.4` | Fraction of instances that get hallucination injection | +| `DOCS_RATIO` | `0.5` | Fraction of instances that get Context7 documentation | +| `MAX_FILE_CHARS` | `12000` | Maximum characters per source file | +| `MAX_CONTEXT7_CHARS` | `4000` | Maximum characters per library doc | +| `LLM_TEMPERATURE` | `0.7` | Temperature for query rewriting | +| `HALLUCINATION_TEMPERATURE` | `0.8` | Temperature for hallucination injection (higher for variety) | +| `MAX_RETRIES` | `3` | API retry attempts | +| `RETRY_DELAY` | `2.0` | Base delay between retries (seconds) | + +## Answer Format Weights + +| Format | Weight | Description | +|--------|--------|-------------| +| `complete_function` | 0.4 | Full patched function body via AST | +| `edit_style` | 0.3 | "In file X, replace Y with Z" | +| `fragment` | 0.3 | Added/changed lines from diff | + +## Hallucination Types + +Assigned round-robin across injected instances: + +- **structural** β€” Non-existent APIs, wrong methods, invented parameters +- **behavioral** β€” Wrong values, logic errors, swapped conditions +- **semantic** β€” Code that looks correct but does something subtly different + +## File Paths + +All data is stored under `data/code_hallucination/`: + +| Path | Description | +|------|-------------| +| `swebench_instances.json` | Phase 1: loaded instances | +| `repos/` | Phase 2: bare git clones | +| `source_cache/` | Phase 2: per-instance source data | +| `queries.jsonl` | Phase 3: rewritten queries | +| `documentation.jsonl` | Phase 4: library docs | +| `formats.jsonl` | Phase 5: assigned formats | +| `hallucinated_samples.jsonl` | Phase 6: injected hallucinations | +| `code_hallucination_data.json` | Phase 7: final dataset | +| `code_hallucination_metadata.json` | Phase 7: metadata | +| `validation_report.txt` | Phase 9: quality report | + +## Data Sources + +| Source | Dataset ID | +|--------|-----------| +| SWE-bench (full) | `princeton-nlp/SWE-bench` | +| SWE-bench Lite | `princeton-nlp/SWE-bench_Lite` | diff --git a/docs/code-hallucination/index.md b/docs/code-hallucination/index.md new file mode 100644 index 0000000..4a6f01f --- /dev/null +++ b/docs/code-hallucination/index.md @@ -0,0 +1,294 @@ +# Code Hallucination Dataset Pipeline + +A modular 9-phase pipeline for generating span-level code hallucination detection datasets from [SWE-bench](https://www.swebench.com/). Produces training data where each sample contains source code context, a code answer, and character-level annotations marking hallucinated spans. + +## Why This Dataset? + +Existing hallucination detection datasets (RAGTruth, RAGBench) focus on **text** β€” question answering, summarization, data-to-text. There is no established **span-level code hallucination dataset**. CodeMirage classifies entire snippets but doesn't localize where the hallucination is. + +This pipeline generates samples where an LLM coding assistant answers a developer's question about a real codebase, and we know exactly which character spans in the answer are hallucinated β€” enabling training of both token-level classifiers (ModernBERT) and generative span detectors (decoder LLMs). 
+ +## Dataset Overview + +| Property | Value | +|----------|-------| +| **Source** | SWE-bench (all splits) | +| **Total instances** | ~21,500 (19k train + 225 dev + 2.3k test) | +| **Repos** | 53 unique repos, zero overlap between splits | +| **Clean/hallucinated ratio** | ~60% clean / ~40% hallucinated | +| **Hallucination types** | Structural, behavioral, semantic | +| **Answer formats** | Complete function, edit-style, code fragment | +| **Annotation granularity** | Character-level spans | + +## Quick Start + +### Test with a few examples + +```bash +# Using Groq (fast, free tier available) +OPENAI_API_KEY=your_groq_key \ + python -m scripts.code_hallucination.pipeline --test 5 + +# Using any OpenAI-compatible API +OPENAI_API_KEY=your_key \ +API_BASE_URL=https://api.example.com/v1 \ +MODEL=your-model-name \ + python -m scripts.code_hallucination.pipeline --test 10 +``` + +### Run the full pipeline + +```bash +# Run all 9 phases +python -m scripts.code_hallucination.pipeline --all + +# Run specific phases +python -m scripts.code_hallucination.pipeline --phase 1 2 3 + +# Override LLM settings via CLI +python -m scripts.code_hallucination.pipeline --all \ + --api-key YOUR_KEY \ + --base-url https://api.groq.com/openai/v1 \ + --model moonshotai/kimi-k2-instruct-0905 +``` + +### Run with local vLLM (recommended for bulk generation) + +```bash +# Terminal 1: Start vLLM +vllm serve Qwen/Qwen3.5-2B --port 8000 + +# Terminal 2: Run pipeline with batch processing +BATCH_SIZE=16 \ +API_BASE_URL=http://localhost:8000/v1 \ +OPENAI_API_KEY=dummy \ +MODEL=Qwen/Qwen3.5-2B \ + python -m scripts.code_hallucination.pipeline --all +``` + +`BATCH_SIZE>1` enables async concurrent requests β€” no rate limiting, full GPU saturation. + +### CLI Options + +| Flag | Description | +|------|-------------| +| `--test N` | Test mode: run pipeline on N random test instances using GitHub API (no repo cloning) | +| `--all` | Run all 9 phases | +| `--phase 1 2 3` | Run specific phases (1-9) | +| `--api-key` | LLM API key (or set `OPENAI_API_KEY` env var) | +| `--base-url` | LLM API base URL (or set `API_BASE_URL` env var) | +| `--model` | LLM model name (or set `MODEL` env var) | + +| Environment Variable | Description | +|---------------------|-------------| +| `BATCH_SIZE` | Number of concurrent requests (default: 1). Set >1 for local vLLM | +| `OPENAI_API_KEY` | API key for the LLM provider | +| `API_BASE_URL` | OpenAI-compatible API endpoint | +| `MODEL` | Model name | +| `CONTEXT7_API_KEY` | API key for Context7 documentation service | + +## Pipeline Architecture + +```mermaid +graph TD + A[Phase 1: Load SWE-bench] --> B[Phase 2: Fetch Sources] + B --> C[Phase 3: Rewrite Queries] + B --> D[Phase 4: Fetch Docs] + B --> E[Phase 5: Assign Formats] + A --> F[Phase 8: Select Targets] + C --> G[Phase 6: Inject Hallucinations] + E --> G + F --> G + D --> H[Phase 7: Assemble Samples] + G --> H + E --> H + C --> H + H --> I[Phase 9: Validate] +``` + +Phases 3 and 4 can run in parallel. Phase 8 (target selection) runs before Phase 6 (injection). 
+ +## Output Format + +Each sample follows the `HallucinationSample` format used by LettuceDetect: + +```json +{ + "prompt": "File: django/http/response.py\n```python\n...\n```\n\nUser request: How do I fix the Content-Type header issue?", + "answer": "def fix_content_type(self):\n self.headers['Content-Type'] = response.get_type()\n ...", + "labels": [ + {"start": 55, "end": 75, "label": "structural"} + ], + "split": "test", + "task_type": "code_generation", + "dataset": "swebench_code", + "language": "en" +} +``` + +- **`prompt`**: Source code files + documentation + user query +- **`answer`**: Code in one of three formats (complete function, edit-style, fragment) +- **`labels`**: Character-level span annotations (empty for clean samples) +- **`split`**: train/dev/test (inherited from SWE-bench, zero repo overlap) + +## Supported LLM Providers + +The pipeline works with any OpenAI-compatible API. Tested with: + +| Provider | Model | Notes | +|----------|-------|-------| +| [Groq](https://groq.com) | `moonshotai/kimi-k2-instruct-0905` | Fast, free tier | +| [Groq](https://groq.com) | `llama-3.3-70b-versatile` | Good quality | +| [Novita AI](https://novita.ai) | `qwen/qwen3.5-27b` | Good for bulk generation | +| Local (vLLM/Ollama) | Any model | Free, best for large runs | + +## Directory Structure + +``` +scripts/code_hallucination/ +β”œβ”€β”€ __init__.py # Package declaration +β”œβ”€β”€ pipeline.py # CLI orchestrator +β”œβ”€β”€ config.py # Constants, paths, API settings +β”œβ”€β”€ swebench_loader.py # Phase 1: Load SWE-bench instances +β”œβ”€β”€ source_fetcher.py # Phase 2: Clone repos, fetch source files +β”œβ”€β”€ query_rewriter.py # Phase 3: LLM query rewriting +β”œβ”€β”€ context7_docs.py # Phase 4: Fetch library documentation +β”œβ”€β”€ format_builder.py # Phase 5: Assign answer formats +β”œβ”€β”€ hallucination_injector.py # Phase 6: LLM hallucination injection +β”œβ”€β”€ sample_assembler.py # Phase 7: Assemble final samples +β”œβ”€β”€ splitter.py # Phase 8: Select hallucination targets +└── validator.py # Phase 9: Quality validation +``` + +Output data: + +``` +data/code_hallucination/ +β”œβ”€β”€ swebench_instances.json # Phase 1 output +β”œβ”€β”€ repos/ # Phase 2: cloned repos (bare) +β”œβ”€β”€ source_cache/ # Phase 2: per-instance source data +β”‚ └── {instance_id}.json +β”œβ”€β”€ queries.jsonl # Phase 3 output +β”œβ”€β”€ documentation.jsonl # Phase 4 output +β”œβ”€β”€ formats.jsonl # Phase 5 output +β”œβ”€β”€ hallucinated_samples.jsonl # Phase 6 output +β”œβ”€β”€ code_hallucination_data.json # Phase 7: final dataset +β”œβ”€β”€ code_hallucination_metadata.json # Phase 7: metadata +└── validation_report.txt # Phase 9 output +``` + +## Design Decisions + +### One sample per instance +Each SWE-bench instance produces exactly one sample β€” either clean (gold patch answer) or hallucinated (LLM-injected). No instance appears in both classes. This avoids the artificial pairing problem where models learn to distinguish the specific instance rather than the hallucination. + +### JSON-based span annotations +Hallucination spans are extracted from the LLM's structured JSON response, not from difflib character-level diffs. The LLM returns `{"hallucinated_code": "...", "changes": [{"original": "...", "hallucinated": "..."}]}` and spans are found by string matching. This produces clean, meaningful spans (avg 70 chars) instead of noisy character-level artifacts (1-3 char noise from difflib). + +### 50/50 documentation split +Half of instances include Context7 library documentation, half don't. 
This teaches models to handle both documented and undocumented scenarios. + +### Zero repo overlap between splits +SWE-bench's train/dev/test splits naturally have zero repository overlap across 53 unique repos. This means test performance measures generalization to completely unseen codebases. + +## Training on This Dataset + +Once you've generated the dataset, there are two training approaches. See [Architecture Research](architecture-research.md) for the full rationale behind each. + +### Approach A: Token Classification (ModernBERT) + +The standard LettuceDetect approach β€” a lightweight encoder that labels each token as supported or hallucinated. + +```bash +# Code data only +python scripts/train_code_hallucination.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/code_hallucination_detector \ + --batch-size 4 \ + --epochs 6 \ + --learning-rate 1e-5 + +# Code + RAGTruth combined (better generalization across text and code) +python scripts/train_code_hallucination.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --ragtruth-path data/ragtruth/ragtruth_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/code_hallucination_detector \ + --batch-size 4 \ + --epochs 6 +``` + +The training script uses SWE-bench splits directly β€” train for training, dev for validation, test held out. Zero repository overlap between splits. + +### Approach D: Generative Span Detection (Qwen SFT) + +Fine-tune a decoder LLM to read context + answer and generate a JSON list of hallucinated spans with explanations. This is the reverse of the injection pipeline. + +```bash +# Requires: pip install peft +python scripts/train_generative_detector.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name Qwen/Qwen3.5-2B \ + --output-dir output/generative_detector \ + --batch-size 2 \ + --epochs 3 \ + --lora-r 16 +``` + +The model learns to output: + +```json +{"hallucinated_spans": [{"text": "response.json_decode()", "explanation": "method is json(), not json_decode()"}]} +``` + +For clean samples it outputs `{"hallucinated_spans": []}`. + +Training uses [LoRA](https://arxiv.org/abs/2106.09685) for memory efficiency (~5-8GB VRAM). Only the model's response tokens contribute to the loss β€” context and prompt tokens are masked. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--model-name` | `Qwen/Qwen3.5-2B` | Any HuggingFace causal LM | +| `--lora-r` | 16 | LoRA rank (higher = more capacity, more memory) | +| `--lora-alpha` | 32 | LoRA scaling factor | +| `--batch-size` | 2 | Training batch size | +| `--epochs` | 3 | Number of training epochs | +| `--learning-rate` | 2e-4 | Learning rate | +| `--gradient-accumulation-steps` | 4 | Accumulate gradients over N steps (simulates larger batch) | + +### Evaluate + +```bash +python scripts/evaluate_code_hallucination.py \ + --model_path output/code_hallucination_detector \ + --data_path data/code_hallucination/code_hallucination_data.json \ + --evaluation_type example_level +``` + +## End-to-End Example + +```bash +# 1. Generate dataset (local vLLM for speed) +vllm serve Qwen/Qwen3.5-2B --port 8000 # in another terminal + +BATCH_SIZE=16 \ +API_BASE_URL=http://localhost:8000/v1 \ +OPENAI_API_KEY=dummy \ +MODEL=Qwen/Qwen3.5-2B \ + python -m scripts.code_hallucination.pipeline --all + +# 2. 
Train +python scripts/train_code_hallucination.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/code_hallucination_detector \ + --batch-size 4 --epochs 6 + +# 3. Evaluate +python scripts/evaluate_code_hallucination.py \ + --model_path output/code_hallucination_detector \ + --data_path data/code_hallucination/code_hallucination_data.json \ + --evaluation_type example_level +``` + +The pipeline is **fully resumable** β€” every slow phase saves results incrementally to JSONL. If it crashes, re-run the same command and it picks up where it left off. diff --git a/docs/code-hallucination/phases.md b/docs/code-hallucination/phases.md new file mode 100644 index 0000000..6e93da0 --- /dev/null +++ b/docs/code-hallucination/phases.md @@ -0,0 +1,272 @@ +# Pipeline Phases + +Detailed documentation for each of the 9 pipeline phases. + +## Phase 1: Load SWE-bench + +**Module:** `swebench_loader.py` + +Loads all SWE-bench splits from HuggingFace and tags each instance with split and Lite membership. + +| Split | Instances | Repos | Source | +|-------|-----------|-------|--------| +| Train | 19,008 | 35 | `princeton-nlp/SWE-bench` train | +| Dev | 225 | 6 | `princeton-nlp/SWE-bench` dev | +| Test | 2,294 | 12 | `princeton-nlp/SWE-bench` test | +| Lite | 300 | 12 | `princeton-nlp/SWE-bench_Lite` (subset of test) | + +**Key function:** `load_all_splits() -> list[dict]` + +Each instance includes: `instance_id`, `repo`, `base_commit`, `patch`, `problem_statement`, `split`, `is_lite`. + +**Output:** `data/code_hallucination/swebench_instances.json` + +--- + +## Phase 2: Fetch Sources + +**Module:** `source_fetcher.py` + +Clones repositories and extracts source code at the base commit for each instance. Builds three answer format variants. + +### Strategy + +- **Default:** Clone repos as bare git repos to `data/code_hallucination/repos/`. Use `git show {commit}:{path}` for instant file access. +- **Test mode:** Use GitHub raw API (`raw.githubusercontent.com`) β€” slower but no cloning needed. +- **Fallback:** If cloning fails, automatically falls back to GitHub API. + +### What it extracts per instance + +| Field | Description | +|-------|-------------| +| `changed_files` | File paths modified by the gold patch | +| `source_files` | Original source code at base commit | +| `patch_code` | Added/changed lines from the diff (fragment format) | +| `edit_style` | "In file X, replace Y with Z" format | +| `modified_functions` | AST-extracted functions that changed (complete function format) | + +### Key functions + +- `extract_changed_files(patch)` β€” Parse unified diff for file paths (anchored regex, not `lstrip("b/")`) +- `clone_repo(repo)` β€” `git clone --bare` with 30min timeout +- `fetch_file_at_commit(repo_dir, commit, filepath)` β€” `git show` for file contents +- `apply_patch_and_get_file(repo_dir, commit, patch, filepath)` β€” Apply patch in temp worktree +- `extract_modified_functions(original, patched)` β€” AST-based function diff + +**Output:** `data/code_hallucination/source_cache/{instance_id}.json` + +--- + +## Phase 3: Rewrite Queries + +**Module:** `query_rewriter.py` + +Transforms raw GitHub issue `problem_statement` fields into natural developer queries using an LLM. + +### Example + +**Before (raw issue):** +> BUG: DataFrame.groupby with as_index=False gives wrong result when grouping by single column with duplicate name. Steps to reproduce: ... 
+ +**After (rewritten):** +> I'm getting wrong results when using DataFrame.groupby with as_index=False on a column that has a duplicate name. How do I fix this? + +### Prompt strategy + +The LLM is instructed to: + +- Write conversational, natural language +- Extract the core technical ask +- Remove GitHub formatting, reproduction steps, tracebacks +- Keep to 1-3 sentences + +### Resumability + +Writes results to JSONL incrementally. On restart, skips already-processed `instance_id`s. + +**Output:** `data/code_hallucination/queries.jsonl` + +--- + +## Phase 4: Fetch Documentation + +**Module:** `context7_docs.py` + +Fetches library documentation via the [Context7](https://context7.com) API for **50% of instances** (configurable via `DOCS_RATIO`). + +### Library detection + +Detects libraries from: +1. Import statements in the patch (`import django`, `from sklearn import ...`) +2. File paths (`django/http/response.py` β†’ django) + +Maps to Context7 library names via a predefined dictionary. + +### 50/50 split rationale + +Half of samples include documentation context, half don't. This creates training variety β€” models learn to detect hallucinations both with and without documentation support. + +Instances not selected for docs still get an entry written with empty docs (by design, not failure). + +**Output:** `data/code_hallucination/documentation.jsonl` + +--- + +## Phase 5: Assign Answer Formats + +**Module:** `format_builder.py` + +Each instance gets exactly one answer format, chosen by weighted random selection from available options. + +### Format types + +**Complete function** (weight: 0.4) +```python +def validate_response(self, response): + if response.status_code != 200: + raise ValidationError(f"Unexpected status: {response.status_code}") + return response.json() +``` +Extracted via Python AST from the patched source. Only available when changes are inside a function (~60% of patches). + +**Edit-style** (weight: 0.3) +``` +In file django/http/response.py, replace: + def set_cookie(self, key, value=""): + self.cookies[key] = value +with: + def set_cookie(self, key, value="", max_age=None): + self.cookies[key] = value + if max_age is not None: + self.cookies[key]["max-age"] = max_age +``` +Available for all patches where changed regions can be extracted. + +**Fragment** (weight: 0.3) +```python +if max_age is not None: + self.cookies[key]["max-age"] = max_age + self.cookies[key]["expires"] = http_date(time.time() + max_age) +``` +Added/changed lines from the diff with surrounding context. + +**Output:** `data/code_hallucination/formats.jsonl` + +--- + +## Phase 6: Inject Hallucinations + +**Module:** `hallucination_injector.py` + +Uses an LLM to inject realistic hallucinations into selected instances (determined by Phase 8). Returns structured JSON with span annotations. 
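+
+In code, the span-mapping step (detailed under "JSON-based span extraction" below) amounts to something like this sketch; the field names follow the JSON schema shown in that subsection, while the helper itself is illustrative, not the module's actual API:
+
+```python
+# Sketch: locate each LLM-reported change in the hallucinated code by string
+# matching. A minimum length filters out trivially short matches.
+def find_spans(hallucinated_code: str, changes: list[dict]) -> list[dict]:
+    spans = []
+    for change in changes:
+        needle = change["hallucinated"]
+        start = hallucinated_code.find(needle)
+        if start != -1 and len(needle) >= 3:
+            spans.append({"start": start, "end": start + len(needle)})
+    return spans
+```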
+ +### Hallucination types (round-robin) + +| Type | Description | Example | +|------|-------------|---------| +| **Structural** | Non-existent APIs, wrong methods, invented parameters | `response.json_decode()` instead of `response.json()` | +| **Behavioral** | Wrong values, logic errors, off-by-one, swapped conditions | `if status >= 200` instead of `if status == 200` | +| **Semantic** | Code that looks right but does something subtly different | Sorting ascending instead of descending | + +### JSON-based span extraction + +The LLM returns structured output: + +```json +{ + "hallucinated_code": "def fix(self):\n self.data = response.json_decode()\n ...", + "changes": [ + { + "original": "response.json()", + "hallucinated": "response.json_decode()", + "explanation": "json_decode() is not a valid method on Response objects" + } + ] +} +``` + +Spans are found by string-matching each `change["hallucinated"]` in `hallucinated_code`. This produces clean, meaningful spans (minimum 3 chars) with zero noise. + +### Quality metrics (from 100-sample test runs) + +| Metric | Value | +|--------|-------| +| Noise-only samples | 0% | +| Min span length | 10 chars | +| Avg span length | 70 chars | +| Avg spans per sample | 1.2 | + +**Output:** `data/code_hallucination/hallucinated_samples.jsonl` + +--- + +## Phase 7: Assemble Samples + +**Module:** `sample_assembler.py` + +Combines all intermediate data into the final `HallucinationSample` format. + +### Prompt construction + +``` +File: path/to/file.py +```python + +``` + +Documentation for django: + + +User request: +``` + +### Sample types + +**Clean samples** (~60%): Gold patch answer, empty labels, from instances NOT selected for injection. + +**Hallucinated samples** (~40%): LLM-modified answer with character-level span annotations. + +### Outputs + +- `data/code_hallucination/code_hallucination_data.json` β€” List of samples +- `data/code_hallucination/code_hallucination_metadata.json` β€” Metadata (instance_id, repo, format_type, hallucination_type, injector_model, is_hallucinated) + +--- + +## Phase 8: Select Hallucination Targets + +**Module:** `splitter.py` + +Selects which instances receive hallucination injection. Applies the hallucination ratio (default 40%) **uniformly within each split** to maintain consistent class distribution. + +``` +Train: ~7,600 hallucinated + ~11,400 clean = ~19,000 +Dev: ~90 hallucinated + ~135 clean = ~225 +Test: ~920 hallucinated + ~1,374 clean = ~2,294 +``` + +!!! note + Phase 8 runs **before** Phase 6 in the pipeline (target selection must happen before injection). + +**Output:** Set of `instance_id`s (used in-memory by Phase 6 and Phase 7) + +--- + +## Phase 9: Validate + +**Module:** `validator.py` + +Runs automated quality checks and generates a report. 
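+
+The first check in the table below, for instance, reduces to a sketch like this (an illustrative helper, not the validator's actual API):
+
+```python
+# No negative offsets, no empty spans, no out-of-bounds offsets.
+def spans_valid(answer: str, labels: list[dict]) -> bool:
+    return all(0 <= l["start"] < l["end"] <= len(answer) for l in labels)
+```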
+ +### Checks performed + +| Check | Description | +|-------|-------------| +| **Span validity** | No negative offsets, empty spans, or out-of-bounds | +| **Span coverage** | Distribution of hallucinated text ratio; flags <2% or >80% | +| **Distributions** | Format type, hallucination type, injector model, repo, split | +| **Near-duplicates** | Jaccard similarity >0.95 on sampled answer pairs | +| **AST parseability** | For complete_function format, checks if answer parses as valid Python | +| **Length statistics** | Prompt and answer character length ranges | + +**Output:** `data/code_hallucination/validation_report.txt` diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md new file mode 100644 index 0000000..273b669 --- /dev/null +++ b/docs/getting-started/installation.md @@ -0,0 +1,42 @@ +# Installation + +## From PyPI + +```bash +pip install lettucedetect +``` + +## From Source (development) + +```bash +git clone https://github.com/KRLabsOrg/LettuceDetect.git +cd LettuceDetect +pip install -e . +``` + +## Optional Dependencies + +```bash +# Web API server +pip install -e ".[api]" + +# Development tools (testing, linting) +pip install -e ".[dev]" + +# Documentation site +pip install -e ".[docs]" +``` + +## Requirements + +- Python >= 3.10 +- PyTorch >= 2.6.0 +- Transformers >= 4.48.3 + +## Environment Variables + +For LLM-based detection or data generation: + +```bash +export OPENAI_API_KEY=your_api_key +``` diff --git a/docs/getting-started/models.md b/docs/getting-started/models.md new file mode 100644 index 0000000..b79af3d --- /dev/null +++ b/docs/getting-started/models.md @@ -0,0 +1,31 @@ +# Models + +## English Models + +| Model | Base | Max Tokens | Example F1 | Span F1 | +|-------|------|-----------|-----------|---------| +| [lettucedetect-base-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-modernbert-en-v1) | ModernBERT-base | 4K | 76.8% | SOTA | +| [lettucedetect-large-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-large-modernbert-en-v1) | ModernBERT-large | 4K | 79.2% | SOTA | + +## Multilingual Models + +| Model | Base | Languages | Max Tokens | +|-------|------|-----------|-----------| +| [lettucedetect-base-eurobert-multilingual-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1) | EuroBERT-210M | en, de, fr, es, it, pl, cn | 8K | + +## TinyLettuce (Distilled) + +Smaller models for resource-constrained environments. See [TinyLettuce docs](../TINYLETTUCE.md). + +## Using a Model + +```python +from lettucedetect.models.inference import HallucinationDetector + +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-large-modernbert-en-v1" +) +``` + +Models are downloaded automatically from HuggingFace Hub on first use. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md new file mode 100644 index 0000000..906c1d3 --- /dev/null +++ b/docs/getting-started/quickstart.md @@ -0,0 +1,65 @@ +# Quick Start + +## Detect Hallucinations + +```python +from lettucedetect.models.inference import HallucinationDetector + +# Load a pre-trained model +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1" +) + +# Provide context, question, and answer +contexts = [ + "France is a country in Europe. The capital of France is Paris. " + "The population of France is 67 million." +] +question = "What is the capital of France? What is the population?" 
+answer = "The capital of France is Paris. The population of France is 69 million." + +# Get span-level predictions +predictions = detector.predict( + context=contexts, + question=question, + answer=answer, + output_format="spans" +) +print(predictions) +# [{'start': 31, 'end': 71, 'confidence': 0.99, +# 'text': ' The population of France is 69 million.'}] +``` + +## Available Models + +| Model | Language | Context | Size | +|-------|----------|---------|------| +| `KRLabsOrg/lettucedetect-base-modernbert-en-v1` | English | 4K | 149M | +| `KRLabsOrg/lettucedetect-large-modernbert-en-v1` | English | 4K | 395M | +| `KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1` | 7 languages | 8K | 210M | + +See [Models](models.md) for the full list. + +## Detection Methods + +```python +# Transformer-based (recommended for production) +detector = HallucinationDetector(method="transformer", model_path="...") + +# LLM-based (uses OpenAI API) +detector = HallucinationDetector(method="llm", model_path="gpt-4o-mini") + +# RAG Fact Checker (triplet-based) +detector = HallucinationDetector(method="rag_fact_checker", model_path="gpt-4o-mini") +``` + +## Output Formats + +```python +# Span-level: exact character ranges of hallucinated text +predictions = detector.predict(..., output_format="spans") + +# Sentence-level: which sentences contain hallucinations +predictions = detector.predict(..., output_format="sentences") +``` diff --git a/docs/guide/api.md b/docs/guide/api.md new file mode 100644 index 0000000..45f1f67 --- /dev/null +++ b/docs/guide/api.md @@ -0,0 +1,37 @@ +# Web API + +LettuceDetect includes a FastAPI server for HTTP-based hallucination detection. + +## Start the Server + +```bash +# Development +python scripts/start_api.py dev + +# Production +python scripts/start_api.py prod + +# Custom model +python scripts/start_api.py dev --model KRLabsOrg/lettucedetect-large-modernbert-en-v1 +``` + +## Python Client + +```python +from lettucedetect_api.client import LettuceClient + +client = LettuceClient("http://127.0.0.1:8000") + +contexts = ["The capital of France is Paris."] +question = "What is the capital of France?" +answer = "The capital of France is Berlin." + +response = client.detect_spans(contexts, question, answer) +print(response) +``` + +## Installation + +```bash +pip install lettucedetect[api] +``` diff --git a/docs/guide/detection-methods.md b/docs/guide/detection-methods.md new file mode 100644 index 0000000..6f75e3c --- /dev/null +++ b/docs/guide/detection-methods.md @@ -0,0 +1,42 @@ +# Detection Methods + +LettuceDetect supports three ways to detect hallucinations, each with different trade-offs. + +## Transformer (Recommended) + +Fine-tuned encoder models that classify each token in the answer as supported or hallucinated. Best balance of speed and accuracy. + +```python +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-large-modernbert-en-v1" +) +``` + +**How it works:** The model reads the context, question, and answer together. It labels each answer token, then merges consecutive hallucinated tokens into character spans. A single forward pass β€” fast enough for production use (30-60 samples/sec on GPU). + +**When to use:** Production systems, latency-sensitive applications, or when you need precise span locations. + +## LLM-based + +Uses OpenAI-compatible APIs (GPT-4, Claude, etc.) for hallucination detection. No fine-tuning needed. 
+ +```python +detector = HallucinationDetector(method="llm", model_path="gpt-4o-mini") +``` + +**How it works:** Sends context + question + answer to the LLM with a prompt requesting hallucination spans in a structured format. + +**When to use:** Quick prototyping, or when you want the LLM to explain *why* something is hallucinated. + +## RAG Fact Checker + +Triplet-based fact checking that breaks the answer into structured claims and verifies each one. + +```python +detector = HallucinationDetector(method="rag_fact_checker", model_path="gpt-4o-mini") +``` + +**How it works:** Extracts (subject, predicate, object) claims from the answer, then checks each claim against the context. + +**When to use:** When you want claim-level granularity and structured verification results. diff --git a/docs/guide/evaluation.md b/docs/guide/evaluation.md new file mode 100644 index 0000000..e59354b --- /dev/null +++ b/docs/guide/evaluation.md @@ -0,0 +1,47 @@ +# Evaluation + +LettuceDetect supports three levels of evaluation, from coarse to fine-grained. + +## Example-Level + +The simplest check: does the answer contain **any** hallucination? This is a binary yes/no classification per answer. + +```bash +python scripts/evaluate.py \ + --model_path output/hallucination_detector \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type example_level +``` + +## Span-Level (Character IoU) + +Measures how well predicted hallucination spans overlap with the gold annotations at the character level. This is the most demanding metric β€” the model must identify not just *that* something is wrong, but *exactly where*. + +```bash +python scripts/evaluate.py \ + --model_path output/hallucination_detector \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type span_level +``` + +## Token-Level + +Per-token precision, recall, and F1 for the hallucinated class. This is what the model is directly optimized for during training. + +## Metrics Explained + +- **Precision**: Of everything the model flagged as hallucinated, how much was actually hallucinated? +- **Recall**: Of all the actual hallucinations, how many did the model catch? +- **F1**: The balance between precision and recall (harmonic mean). This is the primary metric. +- **AUROC**: Area under the ROC curve β€” measures how well the model separates hallucinated from supported tokens across all confidence thresholds. + +## LLM Baselines + +You can also evaluate LLM-based detectors for comparison: + +```bash +python scripts/evaluate_llm.py \ + --model "gpt-4o-mini" \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type example_level +``` diff --git a/docs/guide/integrations.md b/docs/guide/integrations.md new file mode 100644 index 0000000..953bcd7 --- /dev/null +++ b/docs/guide/integrations.md @@ -0,0 +1,59 @@ +# Integrations + +LettuceDetect can be used with popular LLM frameworks. Integration examples are available in the repository's `integrations/` directory (see the [GitHub repo](https://github.com/KRLabsOrg/LettuceDetect)). + +## LangChain + +Use LettuceDetect as a callback, chain component, or tool within LangChain pipelines. 
+ +```python +from lettucedetect.models.inference import HallucinationDetector + +# Create detector +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1" +) + +# Use in your LangChain pipeline +def check_hallucination(context, question, answer): + spans = detector.predict( + context=context, question=question, + answer=answer, output_format="spans" + ) + return len(spans) == 0 # True if no hallucinations detected +``` + +## Pydantic AI + +Use LettuceDetect within Pydantic AI agents for structured hallucination checking. + +## Haystack + +Add LettuceDetect as a pipeline component in Haystack for post-generation verification. + +## General Pattern + +Any framework that gives you access to the retrieved context and generated answer can use LettuceDetect: + +```python +from lettucedetect.models.inference import HallucinationDetector + +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1" +) + +# After your RAG pipeline generates an answer: +spans = detector.predict( + context=retrieved_documents, + question=user_query, + answer=generated_answer, + output_format="spans" +) + +if spans: + print(f"Found {len(spans)} hallucinated spans") + for span in spans: + print(f" '{span['text']}' (confidence: {span['confidence']:.2f})") +``` diff --git a/docs/guide/multilingual.md b/docs/guide/multilingual.md new file mode 100644 index 0000000..0f20a83 --- /dev/null +++ b/docs/guide/multilingual.md @@ -0,0 +1,29 @@ +# Multilingual Support + +LettuceDetect supports hallucination detection in multiple languages via EuroBERT-based models. + +## Supported Languages + +| Language | Code | Model | +|----------|------|-------| +| English | en | ModernBERT + EuroBERT | +| German | de | EuroBERT | +| French | fr | EuroBERT | +| Spanish | es | EuroBERT | +| Italian | it | EuroBERT | +| Polish | pl | EuroBERT | +| Chinese | cn | EuroBERT | +| Hungarian | hu | EuroBERT | + +## Usage + +```python +from lettucedetect.models.inference import HallucinationDetector + +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1" +) +``` + +The multilingual model handles all supported languages with a single checkpoint. No language detection or switching needed. diff --git a/docs/guide/training.md b/docs/guide/training.md new file mode 100644 index 0000000..59c454f --- /dev/null +++ b/docs/guide/training.md @@ -0,0 +1,109 @@ +# Training Your Own Model + +LettuceDetect models are standard HuggingFace token classifiers. You can fine-tune them on your own data or on the provided RAGTruth dataset. + +## How It Works + +The model reads a concatenated input β€” context, question, and answer β€” and classifies each answer token as either **supported** (0) or **hallucinated** (1). Consecutive hallucinated tokens are merged into spans at inference time. + +``` +Input: [CLS] context [SEP] question [SEP] answer [SEP] +Labels: [-100, -100, ..., -100, ..., 0, 0, 1, 1, 1, 0, ...] + ^^^^^^^^^^^^^^^^^^^^^^^^ + only answer tokens are labeled +``` + +- `-100` = ignored by the loss function (context + question tokens) +- `0` = supported (answer token backed by context) +- `1` = hallucinated (answer token not supported by context) + +The best checkpoint is saved based on validation F1 for the hallucinated class. 
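+
+As a concrete illustration, here is a minimal sketch of how character-level span annotations can be projected onto token labels using a HuggingFace offset mapping (a sketch of the labeling scheme above, not the trainer's exact internals):
+
+```python
+# Sketch: tokens overlapping an annotated character span get label 1, the
+# rest 0. Real training additionally masks context/question tokens with -100.
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+answer = "The capital of France is Paris. The population is 69 million."
+bad = " The population is 69 million."
+spans = [{"start": answer.index(bad), "end": answer.index(bad) + len(bad)}]
+
+enc = tokenizer(answer, return_offsets_mapping=True, add_special_tokens=False)
+labels = [
+    1 if any(s["start"] < end and s["end"] > start for s in spans) else 0
+    for start, end in enc["offset_mapping"]
+]
+```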
+ +## Train on RAGTruth + +[RAGTruth](https://aclanthology.org/2024.acl-long.585/) is a text hallucination dataset covering QA, summarization, and data-to-text tasks. + +```bash +# Preprocess the raw data first +python lettucedetect/preprocess/preprocess_ragtruth.py \ + --input_dir data/ragtruth --output_dir data/ragtruth + +# Train +python scripts/train.py \ + --ragtruth-path data/ragtruth/ragtruth_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/hallucination_detector \ + --batch-size 4 \ + --epochs 6 \ + --learning-rate 1e-5 +``` + +## Train on Your Own Data + +Your data must follow the `HallucinationSample` format β€” a JSON list where each item has: + +```json +{ + "prompt": "Your context text here...", + "answer": "The LLM-generated answer to check...", + "labels": [{"start": 45, "end": 72, "label": "hallucination"}], + "split": "train", + "task_type": "qa", + "dataset": "ragtruth", + "language": "en" +} +``` + +- `labels` contains character-level spans within the `answer` marking hallucinated regions. Empty list `[]` for clean samples. +- `split` should be `"train"`, `"dev"`, or `"test"`. + +Then train the same way: + +```bash +python scripts/train.py \ + --ragtruth-path path/to/your_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/my_detector \ + --batch-size 4 \ + --epochs 6 +``` + +## Configuration + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--model-name` | `answerdotai/ModernBERT-base` | HuggingFace model to fine-tune | +| `--batch-size` | 4 | Training batch size | +| `--epochs` | 6 | Number of training epochs | +| `--learning-rate` | 1e-5 | Learning rate | +| `--max-length` | 4096 | Maximum input length in tokens | +| `--seed` | 42 | Random seed for reproducibility | + +## Recommended Base Models + +| Model | Parameters | Max Tokens | Best For | +|-------|-----------|------------|----------| +| `answerdotai/ModernBERT-base` | 149M | 8K | Fast training, good baseline | +| `answerdotai/ModernBERT-large` | 395M | 8K | Best English performance | +| `EuroBERT/EuroBERT-210m` | 210M | 8K | Multilingual (7 languages) | +| `EuroBERT/EuroBERT-610m` | 610M | 8K | Best multilingual performance | + +## Evaluate + +```bash +# Example-level: does the answer contain any hallucination? +python scripts/evaluate.py \ + --model_path output/hallucination_detector \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type example_level + +# Span-level: how well do predicted spans match gold spans? +python scripts/evaluate.py \ + --model_path output/hallucination_detector \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type span_level +``` + +## Code Hallucination + +For training on code hallucination data (from SWE-bench), see the [Code Hallucination Dataset](../code-hallucination/index.md) section which covers generating the dataset and training both encoder and generative models. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..619df26 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,60 @@ +# LettuceDetect + +

+ LettuceDetect Logo +

+
+**A lightweight hallucination detection framework for RAG applications.**
+
+LettuceDetect is an encoder-based model built on [ModernBERT](https://github.com/AnswerDotAI/ModernBERT) that detects unsupported spans in LLM-generated answers by comparing them against provided context. It provides token-level precision for identifying exactly which parts of an answer are hallucinated.
+
+## Highlights
+
+- **Token-level precision** β€” identifies exact hallucinated spans, not just "this answer has a problem"
+- **Fast inference** β€” 30-60 samples/sec on A100, suitable for production
+- **Long context** β€” supports up to 4K tokens (ModernBERT) or 8K tokens (EuroBERT)
+- **Multilingual** β€” English, German, French, Spanish, Italian, Polish, Chinese, Hungarian
+- **Open source** β€” MIT license, models on HuggingFace
+
+## Quick Example
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+    method="transformer",
+    model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1"
+)
+
+contexts = ["The capital of France is Paris. The population is 67 million."]
+question = "What is the capital and population of France?"
+answer = "The capital of France is Paris. The population of France is 69 million."
+
+predictions = detector.predict(
+    context=contexts, question=question, answer=answer, output_format="spans"
+)
+# [{'start': 31, 'end': 71, 'confidence': 0.99, 'text': ' The population of France is 69 million.'}]
+```
+
+## Performance
+
+| Model | Example F1 | vs GPT-4 | vs Luna | Parameters |
+|-------|-----------|----------|---------|-----------|
+| lettucedetect-base-v1 | 76.8% | +13.4% | +11.4% | 149M |
+| lettucedetect-large-v1 | **79.2%** | +15.8% | +13.8% | 395M |
+
+Evaluated on [RAGTruth](https://aclanthology.org/2024.acl-long.585/) test set. Surpasses GPT-4, Luna, and fine-tuned Llama-2-13B.
+

## What's New

- **[Code Hallucination Dataset](code-hallucination/index.md)** — A pipeline for generating span-level code hallucination data from SWE-bench (~18k samples across 53 repos)
- **Multilingual models** — EuroBERT-based models for 8 languages
- **Web API** — FastAPI server with async client support

## Links

- [GitHub](https://github.com/KRLabsOrg/LettuceDetect)
- [PyPI](https://pypi.org/project/lettucedetect/)
- [arXiv Paper](https://arxiv.org/abs/2502.17125)
- [HuggingFace Models](https://huggingface.co/KRLabsOrg)
- [Streamlit Demo](https://huggingface.co/spaces/KRLabsOrg/LettuceDetect)
diff --git a/lettucedetect/__init__.py b/lettucedetect/__init__.py
index 928b8f6..62a7324 100644
--- a/lettucedetect/__init__.py
+++ b/lettucedetect/__init__.py
@@ -15,7 +15,7 @@
 # Direct RAGFactChecker access for advanced users
 from lettucedetect.ragfactchecker import RAGFactChecker

-__version__ = "0.1.7"
+__version__ = "0.1.8"

 __all__ = [
     "HallucinationData",
diff --git a/lettucedetect/datasets/analyze_lengths.py b/lettucedetect/datasets/analyze_lengths.py
new file mode 100644
index 0000000..e98fd3c
--- /dev/null
+++ b/lettucedetect/datasets/analyze_lengths.py
@@ -0,0 +1,184 @@
+import argparse
+import json
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List
+
+import matplotlib.pyplot as plt
+from lettucedetect.datasets.hallucination_dataset import HallucinationData, HallucinationSample
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+
+def analyze_lengths(
+    samples: List[HallucinationSample],
+    tokenizer: AutoTokenizer,
+    max_length: int = 4096,
+) -> Dict[str, Any]:
+    """Analyze the token lengths of samples and identify outliers.
+
+    Args:
+        samples: List of HallucinationSample objects
+        tokenizer: Tokenizer to use
+        max_length: Maximum allowed length
+
+    Returns:
+        Dictionary containing analysis results
+
+    """
+    lengths = []
+    outliers = []
+    split_stats = defaultdict(list)
+
+    for sample in tqdm(samples, desc="Analyzing lengths"):
+        # Tokenize combined context and answer
+        tokens = tokenizer(
+            sample.prompt,
+            sample.answer,
+            truncation=False,  # No truncation to get true lengths
+            add_special_tokens=True,
+        )
+
+        length = len(tokens["input_ids"])
+        lengths.append(length)
+        split_stats[sample.split].append(length)
+
+        if length > max_length:
+            outliers.append(
+                {
+                    "index": len(lengths) - 1,
+                    "length": length,
+                    "prompt_length": len(tokenizer(sample.prompt)["input_ids"]),
+                    "answer_length": len(tokenizer(sample.answer)["input_ids"]),
+                    "split": sample.split,
+                }
+            )
+
+    # Calculate statistics
+    stats = {
+        "total_samples": len(samples),
+        "max_length": max(lengths),
+        "min_length": min(lengths),
+        "mean_length": sum(lengths) / len(lengths),
+        "num_outliers": len(outliers),
+        "outliers": outliers,
+        "split_stats": {
+            split: {
+                "count": len(lens),
+                "max": max(lens),
+                "min": min(lens),
+                "mean": sum(lens) / len(lens),
+                "outliers": len([l for l in lens if l > max_length]),
+            }
+            for split, lens in split_stats.items()
+        },
+    }
+
+    return stats
+
+
+def plot_length_distribution(lengths: List[int], output_path: str, max_length: int):
+    """Plot the distribution of sequence lengths."""
+    plt.figure(figsize=(10, 6))
+    plt.hist(lengths, bins=50)
+    plt.axvline(x=max_length, color="r", linestyle="--", label=f"Max Length ({max_length})")
+    plt.xlabel("Sequence Length")
+    plt.ylabel("Count")
+    plt.title("Distribution of Sequence Lengths")
+    plt.legend()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def filter_dataset(
+    data: HallucinationData, max_length: int, tokenizer: AutoTokenizer, output_path: Path
+) -> HallucinationData:
+    """Filter out samples that exceed max_length."""
+    filtered_samples = []
+
+    for sample in tqdm(data.samples, desc="Filtering samples"):
+        tokens = tokenizer(
+            sample.prompt,
+            sample.answer,
+            truncation=False,
+            add_special_tokens=True,
+        )
+
+        if len(tokens["input_ids"]) <= max_length:
+            filtered_samples.append(sample)
+
+    filtered_data = HallucinationData(samples=filtered_samples)
+
+    # Save filtered dataset
+    with open(output_path, "w") as f:
+        json.dump(filtered_data.to_json(), f, indent=2)
+
+    return filtered_data
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Analyze and filter dataset based on sequence lengths"
+    )
+    parser.add_argument("--input_file", type=str, required=True, help="Path to input JSON dataset")
+    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="meta-llama/Llama-2-7b-hf",
+        help="Model name for tokenizer",
+    )
+    parser.add_argument("--max_length", type=int, default=4096, help="Maximum sequence length")
+
+    args = parser.parse_args()
+
+    # Create output directory
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+    # Load dataset
+    with open(args.input_file) as f:
+        data = HallucinationData.from_json(json.load(f))
+
+    # Analyze lengths
+    stats = analyze_lengths(data.samples, tokenizer, args.max_length)
+
+    # Save statistics
+    with open(output_dir / "length_stats.json", "w") as f:
+        json.dump(stats, f, indent=2)
+
+    # Plot distribution
+    plot_length_distribution(
+        [
+            len(tokenizer(s.prompt, s.answer, add_special_tokens=True)["input_ids"])
+            for s in data.samples
+        ],
+        str(output_dir / "length_distribution.png"),
+        args.max_length,
+    )
+
+    # Filter dataset
+    filtered_data = filter_dataset(
+        data, args.max_length, tokenizer, output_dir / "filtered_dataset.json"
+    )
+
+    print(f"\nAnalysis complete. Results saved to {output_dir}")
+    print(f"Original dataset size: {len(data.samples)}")
+    print(f"Filtered dataset size: {len(filtered_data.samples)}")
+    print(f"Removed {len(data.samples) - len(filtered_data.samples)} samples")
+
+    # Print split-wise statistics (use a distinct loop variable so the
+    # `stats` dict is not shadowed while iterating over it)
+    print("\nSplit-wise statistics:")
+    for split, split_stat in stats["split_stats"].items():
+        print(f"\n{split.upper()}:")
+        print(f"  Total samples: {split_stat['count']}")
+        print(f"  Samples exceeding max length: {split_stat['outliers']}")
+        print(f"  Max length: {split_stat['max']}")
+        print(f"  Mean length: {split_stat['mean']:.2f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lettucedetect/datasets/hallucination_dataset.py b/lettucedetect/datasets/hallucination_dataset.py
index 64e2644..1713193 100644
--- a/lettucedetect/datasets/hallucination_dataset.py
+++ b/lettucedetect/datasets/hallucination_dataset.py
@@ -8,15 +8,30 @@

 @dataclass
 class HallucinationSample:
+    """A single hallucination detection sample.
+
+    Attributes:
+        prompt: Context text (source documents, code files, documentation, user query).
+        answer: The LLM-generated answer to check for hallucinations.
+        labels: List of span annotations. Each dict has ``start``, ``end`` (character offsets
+            within ``answer``), and ``label`` keys. Empty list for clean samples.
+        split: Dataset split (``train``, ``dev``, or ``test``).
+        task_type: Task type (e.g. ``summarization``, ``qa``, ``code_generation``).
+ dataset: Source dataset (``ragtruth``, ``ragbench``, or ``swebench_code``). + language: Language code. + + """ + prompt: str answer: str labels: list[dict] split: Literal["train", "dev", "test"] task_type: str - dataset: Literal["ragtruth", "ragbench"] - language: Literal["en", "de"] + dataset: Literal["ragtruth", "ragbench", "swebench_code"] + language: Literal["en", "de", "fr", "es", "it", "pl", "cn", "hu"] def to_json(self) -> dict: + """Serialize to a JSON-compatible dict.""" return { "prompt": self.prompt, "answer": self.answer, @@ -29,6 +44,7 @@ def to_json(self) -> dict: @classmethod def from_json(cls, json_dict: dict) -> "HallucinationSample": + """Deserialize from a JSON dict.""" return cls( prompt=json_dict["prompt"], answer=json_dict["answer"], @@ -42,13 +58,22 @@ def from_json(cls, json_dict: dict) -> "HallucinationSample": @dataclass class HallucinationData: + """A collection of hallucination detection samples. + + Attributes: + samples: List of :class:`HallucinationSample` instances. + + """ + samples: list[HallucinationSample] def to_json(self) -> list[dict]: + """Serialize all samples to a JSON-compatible list.""" return [sample.to_json() for sample in self.samples] @classmethod def from_json(cls, json_dict: list[dict]) -> "HallucinationData": + """Deserialize from a list of JSON dicts.""" return cls( samples=[HallucinationSample.from_json(sample) for sample in json_dict], ) @@ -74,6 +99,7 @@ def __init__( self.max_length = max_length def __len__(self) -> int: + """Return the number of samples in the dataset.""" return len(self.samples) @classmethod @@ -84,8 +110,10 @@ def prepare_tokenized_input( answer: str, max_length: int = 4096, ) -> tuple[dict[str, torch.Tensor], list[int], torch.Tensor, int]: - """Tokenizes the context and answer together, computes the answer start token index, - and initializes a labels list (using -100 for context tokens and 0 for answer tokens). + """Tokenize context and answer, compute answer start index, and initialize labels. + + Computes the answer start token index and initializes a labels list + (using -100 for context tokens and 0 for answer tokens). :param tokenizer: The tokenizer to use. :param context: The context string. @@ -109,19 +137,14 @@ def prepare_tokenized_input( offsets = encoding.pop("offset_mapping")[0] # shape: (seq_length, 2) - # Simple approach: encode just the context with special tokens - # For most tokenizers, the answer starts right after this - context_only = tokenizer(context, add_special_tokens=True, return_tensors="pt") - # The answer starts after the context sequence (with its special tokens) - answer_start_token = context_only["input_ids"].shape[1] - - # Handle any edge cases where this might land on a special token - if ( - answer_start_token < offsets.size(0) - and offsets[answer_start_token][0] == offsets[answer_start_token][1] - ): - # If we landed on a special token, move forward - answer_start_token += 1 + # Compute answer_start_token from the answer side. This is correct + # even when the context has been truncated by truncation="only_first", + # because the answer is never truncated in that mode. 
+        # Layout: [CLS] context_tokens [SEP] answer_tokens [SEP]
+        answer_only = tokenizer(answer, add_special_tokens=False, return_tensors="pt")
+        answer_token_count = answer_only["input_ids"].shape[1]
+        total_seq_len = encoding["input_ids"].shape[1]
+        answer_start_token = total_seq_len - answer_token_count - 1  # -1 for trailing [SEP]

         # Initialize labels: -100 for tokens before the answer, 0 for tokens in the answer.
         labels = [-100] * encoding["input_ids"].shape[1]
diff --git a/lettucedetect/detectors/base.py b/lettucedetect/detectors/base.py
index f9d20c6..9127751 100644
--- a/lettucedetect/detectors/base.py
+++ b/lettucedetect/detectors/base.py
@@ -19,21 +19,33 @@ def predict(
         """Predict hallucination tokens or spans given passages and an answer.

         :param context: List of passages that were supplied to the LLM / user.
-        :param answer: Model‑generated answer to inspect.
+        :param answer: Model-generated answer to inspect.
         :param question: Original question (``None`` for summarisation).
-        :param output_format: ``"tokens"`` for token‑level dicts, ``"spans"`` for character spans.
+        :param output_format: ``"tokens"`` for token-level dicts, ``"spans"`` for character spans.
         :returns: List of predictions in requested format.
         """
         pass

     @abstractmethod
     def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
-        """Predict hallucinations when you already have a *single* full prompt string."""
+        """Predict hallucinations from a pre-built prompt string.
+
+        :param prompt: Full prompt (context + question already concatenated).
+        :param answer: Model-generated answer to inspect.
+        :param output_format: ``"tokens"`` or ``"spans"``.
+        :returns: List of predictions in requested format.
+        """
         pass

     @abstractmethod
     def predict_prompt_batch(
         self, prompts: list[str], answers: list[str], output_format: str = "tokens"
     ) -> list:
-        """Batch version of :meth:`predict_prompt`."""
+        """Batch version of :meth:`predict_prompt`.
+
+        :param prompts: List of full prompt strings.
+        :param answers: List of answers to inspect.
+        :param output_format: ``"tokens"`` or ``"spans"``.
+        :returns: List of prediction lists, one per input pair.
+        """
         pass
diff --git a/lettucedetect/detectors/cache.py b/lettucedetect/detectors/cache.py
index 437b21e..159edf5 100644
--- a/lettucedetect/detectors/cache.py
+++ b/lettucedetect/detectors/cache.py
@@ -1,4 +1,4 @@
-"""Thread‑safe JSON cache with SHA‑256 keys."""
+"""Thread-safe JSON cache with SHA-256 keys."""

 from __future__ import annotations

@@ -10,9 +10,13 @@

 class CacheManager:
-    """Disk‑backed cache for expensive LLM calls."""
+    """Disk-backed cache for expensive LLM calls."""

-    def __init__(self, file_path: str | Path):
+    def __init__(self, file_path: str | Path) -> None:
+        """Initialize cache from disk.
+
+        :param file_path: Path to the JSON cache file.
+ """ self.path = Path(file_path) self.path.parent.mkdir(parents=True, exist_ok=True) self._lock = threading.Lock() @@ -24,11 +28,13 @@ def __init__(self, file_path: str | Path): def _hash(*parts: str) -> str: return hashlib.sha256("||".join(parts).encode()).hexdigest() - def get(self, key: str) -> Any | None: + def get(self, key: str) -> Any | None: # noqa: ANN401 + """Retrieve a cached value by key.""" with self._lock: return self._data.get(key) - def set(self, key: str, value: Any) -> None: + def set(self, key: str, value: Any) -> None: # noqa: ANN401 + """Store a value in the cache and persist to disk.""" with self._lock: self._data[key] = value self.path.write_text(json.dumps(self._data, ensure_ascii=False), encoding="utf-8") diff --git a/lettucedetect/detectors/llm.py b/lettucedetect/detectors/llm.py index 4c4b010..93ea354 100644 --- a/lettucedetect/detectors/llm.py +++ b/lettucedetect/detectors/llm.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging import os import re from concurrent.futures import ThreadPoolExecutor @@ -12,6 +13,8 @@ from lettucedetect.detectors.cache import CacheManager from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils +logger = logging.getLogger(__name__) + # JSON schema for structured response format HALLUCINATION_SCHEMA = { "type": "json_schema", @@ -72,7 +75,7 @@ def __init__( ) path = Path(fewshot_path) if not path.exists(): - print(f"Warning: Few-shot examples file not found at {path}") + logger.warning("Few-shot examples file not found at %s", path) self.fewshot = json.loads(path.read_text(encoding="utf-8")) if path.exists() else [] # Load hallucination detection template @@ -90,9 +93,9 @@ def __init__( / "cache" / f"cache_{model.replace(':', '_')}_{lang}.json" ) - print(f"Using default cache file: {cache_file}") + logger.info("Using default cache file: %s", cache_file) else: - print(f"Using provided cache file: {cache_file}") + logger.info("Using provided cache file: %s", cache_file) self.cache = CacheManager(cache_file) @@ -185,8 +188,8 @@ def _predict(self, prompt: str, answer: str) -> list[dict]: payload = json.loads(cached) return self._to_spans(payload["hallucination_list"], answer) except (json.JSONDecodeError, KeyError) as e: - print(f"Error parsing LLM response: {e}") - print(f"Raw response: {cached}") + logger.error("Error parsing LLM response: %s", e) + logger.debug("Raw response: %s", cached) return [] def predict( @@ -199,7 +202,7 @@ def predict( """Predict hallucination spans from the provided context, answer, and question. :param context: List of passages that were supplied to the LLM / user. - :param answer: Model‑generated answer to inspect. + :param answer: Model-generated answer to inspect. :param question: Original question (``None`` for summarisation). :param output_format: ``"spans"`` for character spans. :returns: List of spans. 
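
The cache above is what makes repeated LLM-based detection affordable: every response is stored under a SHA-256 key derived from the request, and `set()` rewrites the whole JSON file under a lock, so concurrent threads see a consistent view. A minimal sketch of the pattern, assuming a hypothetical `fake_llm_call` stand-in and a key built the same way as the `_hash` helper shown above:

```python
import hashlib

from lettucedetect.detectors.cache import CacheManager

cache = CacheManager("cache/cache_gpt-4o_en.json")  # example path; parent dir is created automatically


def fake_llm_call(prompt: str, answer: str) -> str:
    """Hypothetical stand-in for the OpenAI request issued by the LLM detector."""
    return '{"hallucination_list": []}'


def detect_cached(prompt: str, answer: str) -> str:
    # Identical (prompt, answer) pairs map to the same digest, so the
    # expensive call is skipped on every repeat.
    key = hashlib.sha256("||".join((prompt, answer)).encode()).hexdigest()
    cached = cache.get(key)
    if cached is not None:
        return cached
    response = fake_llm_call(prompt, answer)
    cache.set(key, response)  # also persists the JSON file to disk
    return response
```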
diff --git a/lettucedetect/detectors/rag_fact_checker.py b/lettucedetect/detectors/rag_fact_checker.py index fb6cb73..3d66bd9 100644 --- a/lettucedetect/detectors/rag_fact_checker.py +++ b/lettucedetect/detectors/rag_fact_checker.py @@ -1,6 +1,8 @@ """Simple RAGFactChecker detector wrapper for lettuceDetect factory pattern.""" -from typing import Any, Dict, List +from __future__ import annotations + +from typing import Any from lettucedetect.detectors.base import BaseDetector @@ -14,9 +16,9 @@ class RAGFactCheckerDetector(BaseDetector): def __init__( self, - openai_api_key: str = None, + openai_api_key: str | None = None, model: str = "gpt-4o", - base_url: str = None, + base_url: str | None = None, temperature: float = 0.0, **kwargs, ): @@ -38,12 +40,12 @@ def __init__( def predict( self, - context: List[str], + context: list[str], answer: str, - question: str = None, + question: str | None = None, output_format: str = "tokens", **kwargs, - ) -> List[Dict[str, Any]] | Dict[str, Any]: + ) -> list[dict[str, Any]] | dict[str, Any]: """Predict hallucinations using RAGFactChecker. :param context: List of context documents @@ -81,13 +83,13 @@ def predict( def predict_prompt( self, prompt: str, answer: str, output_format: str = "tokens" - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Predict using a single prompt string as context.""" return self.predict([prompt], answer, output_format=output_format) def predict_prompt_batch( - self, prompts: List[str], answers: List[str], output_format: str = "tokens" - ) -> List[List[Dict[str, Any]]]: + self, prompts: list[str], answers: list[str], output_format: str = "tokens" + ) -> list[list[dict[str, Any]]]: """Batch prediction using RAGFactChecker's batch processing.""" if len(prompts) != len(answers): raise ValueError("Number of prompts must match number of answers") @@ -108,7 +110,7 @@ def predict_prompt_batch( return converted_results - def _convert_to_tokens(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]: + def _convert_to_tokens(self, answer: str, rag_result: dict[str, Any]) -> list[dict[str, Any]]: """Convert RAGFactChecker result to token format.""" tokens = answer.split() hallucinated_triplets = rag_result.get("hallucinated_triplets", []) @@ -130,7 +132,7 @@ def _convert_to_tokens(self, answer: str, rag_result: Dict[str, Any]) -> List[Di return token_predictions - def _convert_to_spans(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]: + def _convert_to_spans(self, answer: str, rag_result: dict[str, Any]) -> list[dict[str, Any]]: """Convert RAGFactChecker result to span format with improved triplet matching.""" spans = [] hallucinated_triplets = rag_result.get("hallucinated_triplets", []) @@ -195,7 +197,7 @@ def _convert_to_spans(self, answer: str, rag_result: Dict[str, Any]) -> List[Dic # Merge overlapping spans return self._merge_overlapping_spans(spans) - def _merge_overlapping_spans(self, spans: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _merge_overlapping_spans(self, spans: list[dict[str, Any]]) -> list[dict[str, Any]]: """Merge overlapping spans to avoid duplicates.""" if not spans: return spans diff --git a/lettucedetect/detectors/transformer.py b/lettucedetect/detectors/transformer.py index 3ef8c23..a64efa8 100644 --- a/lettucedetect/detectors/transformer.py +++ b/lettucedetect/detectors/transformer.py @@ -1,7 +1,9 @@ -"""Transformer‑based hallucination detector.""" +"""Transformer-based hallucination detector.""" from __future__ import annotations +import logging + import 
torch from transformers import AutoModelForTokenClassification, AutoTokenizer @@ -9,15 +11,28 @@ from lettucedetect.detectors.base import BaseDetector from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils +logger = logging.getLogger(__name__) + __all__ = ["TransformerDetector"] class TransformerDetector(BaseDetector): - """Detect hallucinations with a fine‑tuned token classifier.""" + """Detect hallucinations with a fine-tuned token classifier. + + When the combined context + answer exceeds ``max_length`` tokens, the + context is automatically split into chunks. Each chunk is scored + independently and the per-token hallucination probability is aggregated + across chunks with ``max()``. + """ def __init__( - self, model_path: str, max_length: int = 4096, device=None, lang: Lang = "en", **tok_kwargs - ): + self, + model_path: str, + max_length: int = 4096, + device: torch.device | str | None = None, + lang: Lang = "en", + **tok_kwargs: object, + ) -> None: """Initialize the transformer detector. :param model_path: Path to the pre-trained model. @@ -36,143 +51,358 @@ def __init__( ) self.model.to(self.device).eval() - def _predict(self, prompt: str, answer: str, output_format: str) -> list: - """Predict hallucination tokens or spans from the provided prompt and answer. + # ------------------------------------------------------------------ + # Chunking helpers + # ------------------------------------------------------------------ + + def _group_passages_into_chunks( + self, + context: list[str], + question: str | None, + answer: str, + ) -> list[list[str]]: + """Group passages into chunks so each chunk, wrapped in the full instruction template, fits in ``max_length``. + + Preserves the instruction template (question, "Bear in mind...", etc.) + in every chunk by working at the passage level instead of slicing raw + tokens. + + :param context: List of passage strings. + :param question: Original question (``None`` for summarisation). + :param answer: The answer string (token budget is reserved). + :returns: List of passage groups. Each group will be formatted into a + complete prompt via ``PromptUtils.format_context``. + """ + answer_tokens = self.tokenizer(answer, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ].shape[1] + # Total prompt-token budget: max_length minus answer tokens minus 3 special tokens + total_budget = self.max_length - answer_tokens - 3 + + # Fast path: check whether all passages fit in one prompt. + full_prompt = PromptUtils.format_context(context, question, self.lang) + full_prompt_tokens = self.tokenizer( + full_prompt, add_special_tokens=False, return_tensors="pt" + )["input_ids"].shape[1] + + if full_prompt_tokens <= total_budget: + return [context] + + # Measure instruction overhead (everything except the passage content). + minimal_prompt = PromptUtils.format_context([""], question, self.lang) + instruction_overhead = self.tokenizer( + minimal_prompt, add_special_tokens=False, return_tensors="pt" + )["input_ids"].shape[1] + + passage_budget = total_budget - instruction_overhead + if passage_budget <= 0: + # Instructions + answer alone exceed max_length; single group, will truncate. + return [context] + + # Tokenize each formatted passage line to get its token count. 
+ p_word = LANG_TO_PASSAGE[self.lang] + passage_token_counts: list[int] = [] + for i, passage in enumerate(context): + line = f"{p_word} {i + 1}: {passage}" + tok_count = self.tokenizer(line, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ].shape[1] + passage_token_counts.append(tok_count) + + # Greedily group passages into buckets that fit within the budget. + groups: list[list[str]] = [] + current_group: list[str] = [] + current_tokens = 0 + + for passage, tok_count in zip(context, passage_token_counts): + # +1 for the newline separator between passages + effective = tok_count + (1 if current_group else 0) + if current_tokens + effective > passage_budget and current_group: + groups.append(current_group) + current_group = [] + current_tokens = 0 + effective = tok_count + current_group.append(passage) + current_tokens += effective + + if current_group: + groups.append(current_group) + + return groups if groups else [context] + + # ------------------------------------------------------------------ + # Single-chunk prediction (the original _predict logic) + # ------------------------------------------------------------------ + + def _predict_single(self, prompt: str, answer: str, output_format: str) -> list: + """Run prediction on a single (prompt, answer) pair that fits in ``max_length``. :param prompt: The prompt string. :param answer: The answer string. - :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + :param output_format: ``"tokens"`` or ``"spans"``. """ - if output_format not in ["tokens", "spans"]: - raise ValueError( - f"TransformerDetector doesn't support '{output_format}' format. " - "Use 'tokens' or 'spans'" - ) - # Use the shared tokenization logic from HallucinationDataset encoding, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input( self.tokenizer, prompt, answer, self.max_length ) - # Create a label tensor: mark tokens before answer as -100 (ignored) and answer tokens as 0. labels = torch.full_like(encoding.input_ids[0], -100, device=self.device) labels[answer_start_token:] = 0 - # Move encoding to the device + encoding = { key: value.to(self.device) for key, value in encoding.items() if key in ["input_ids", "attention_mask", "labels"] } - # Run model inference with torch.no_grad(): outputs = self.model(**encoding) logits = outputs.logits token_preds = torch.argmax(logits, dim=-1)[0] probabilities = torch.softmax(logits, dim=-1)[0] - # Mask out predictions for context tokens. token_preds = torch.where(labels == -100, labels, token_preds) if output_format == "tokens": - # return token probabilities for each token (with the tokens as well, if not -100) - token_probs = [] - input_ids = encoding["input_ids"][0] # Get the input_ids tensor from the encoding dict + token_probs: list[dict] = [] + input_ids = encoding["input_ids"][0] for i, (token, pred, prob) in enumerate(zip(input_ids, token_preds, probabilities)): - if not labels[i].item() == -100: + if labels[i].item() != -100: token_probs.append( { "token": self.tokenizer.decode([token]), "pred": pred.item(), - "prob": prob[1].item(), # Get probability for class 1 (hallucination) + "prob": prob[1].item(), } ) return token_probs - elif output_format == "spans": - # Compute the answer's character offset (the first token of the answer). 
- if answer_start_token < offsets.size(0): - answer_char_offset = offsets[answer_start_token][0].item() - else: - answer_char_offset = 0 - - spans: list[dict] = [] - current_span: dict | None = None - - # Iterate over tokens in the answer region. - for i in range(answer_start_token, token_preds.size(0)): - # Skip tokens marked as ignored. - if labels[i].item() == -100: - continue - - token_start, token_end = offsets[i].tolist() - # Skip special tokens with zero length. - if token_start == token_end: - continue - - # Adjust offsets relative to the answer text. - rel_start = token_start - answer_char_offset - rel_end = token_end - answer_char_offset - - is_hallucination = ( - token_preds[i].item() == 1 - ) # assuming class 1 indicates hallucination. - confidence = probabilities[i, 1].item() if is_hallucination else 0.0 - - if is_hallucination: - if current_span is None: - current_span = { - "start": rel_start, - "end": rel_end, - "confidence": confidence, - } - else: - # Extend the current span. - current_span["end"] = rel_end - current_span["confidence"] = max(current_span["confidence"], confidence) + + # output_format == "spans" + if answer_start_token < offsets.size(0): + answer_char_offset = offsets[answer_start_token][0].item() + else: + answer_char_offset = 0 + + spans: list[dict] = [] + current_span: dict | None = None + + for i in range(answer_start_token, token_preds.size(0)): + if labels[i].item() == -100: + continue + + token_start, token_end = offsets[i].tolist() + if token_start == token_end: + continue + + rel_start = token_start - answer_char_offset + rel_end = token_end - answer_char_offset + + is_hallucination = token_preds[i].item() == 1 + confidence = probabilities[i, 1].item() if is_hallucination else 0.0 + + if is_hallucination: + if current_span is None: + current_span = { + "start": rel_start, + "end": rel_end, + "confidence": confidence, + } else: - # If we were building a hallucination span, finalize it. - if current_span is not None: - # Extract the hallucinated text from the answer. - span_text = answer[current_span["start"] : current_span["end"]] - current_span["text"] = span_text - spans.append(current_span) - current_span = None - - # Append any span still in progress. - if current_span is not None: - span_text = answer[current_span["start"] : current_span["end"]] - current_span["text"] = span_text - spans.append(current_span) - - return spans + current_span["end"] = rel_end + current_span["confidence"] = max(current_span["confidence"], confidence) + else: + if current_span is not None: + current_span["text"] = answer[current_span["start"] : current_span["end"]] + spans.append(current_span) + current_span = None + + if current_span is not None: + current_span["text"] = answer[current_span["start"] : current_span["end"]] + spans.append(current_span) + + return spans + + # ------------------------------------------------------------------ + # Multi-chunk prediction with max() aggregation + # ------------------------------------------------------------------ + + def _predict_chunked(self, chunk_prompts: list[str], answer: str, output_format: str) -> list: + """Run prediction over multiple context chunks and aggregate with ``max()``. + + For each answer token, the hallucination probability is the maximum + across all chunks. This is conservative: a token is only considered + supported if *every* chunk considers it supported. + + :param chunk_prompts: Context chunk strings. + :param answer: The answer string (same for all chunks). 
+ :param output_format: ``"tokens"`` or ``"spans"``. + """ + all_token_results: list[list[dict]] = [] + for chunk in chunk_prompts: + tokens = self._predict_single(chunk, answer, output_format="tokens") + all_token_results.append(tokens) + + n_tokens = len(all_token_results[0]) + + # Aggregate: max hallucination probability across chunks. + aggregated: list[dict] = [] + for tok_idx in range(n_tokens): + max_prob = max( + chunk_result[tok_idx]["prob"] + for chunk_result in all_token_results + if tok_idx < len(chunk_result) + ) + aggregated.append( + { + "token": all_token_results[0][tok_idx]["token"], + "pred": 1 if max_prob >= 0.5 else 0, + "prob": max_prob, + } + ) + + if output_format == "tokens": + return aggregated + + # For "spans", build character spans from the aggregated token predictions. + # Use offset mapping from the first chunk (answer offsets are identical + # across chunks because BERT tokenizers process segments independently). + return self._build_spans_from_tokens(aggregated, chunk_prompts[0], answer) + + def _build_spans_from_tokens( + self, token_results: list[dict], prompt: str, answer: str + ) -> list[dict]: + """Convert aggregated token predictions into character-level spans. + + Uses the tokenizer offset mapping from encoding *(prompt, answer)* to + get precise character positions within the answer. + + :param token_results: Per-answer-token dicts with ``token``, ``pred``, ``prob``. + :param prompt: A context prompt string (used to get answer offset mapping). + :param answer: The original answer string. + :returns: List of span dicts with ``start``, ``end``, ``confidence``, ``text``. + """ + _, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input( + self.tokenizer, prompt, answer, self.max_length + ) + + if answer_start_token < offsets.size(0): + answer_char_offset = offsets[answer_start_token][0].item() else: - raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.") + answer_char_offset = 0 + + # Collect answer-region offsets (skip special tokens with zero length). 
+ answer_offsets: list[tuple[int, int]] = [] + for i in range(answer_start_token, offsets.size(0)): + s, e = offsets[i].tolist() + if s == e: + continue + answer_offsets.append((s - answer_char_offset, e - answer_char_offset)) + + spans: list[dict] = [] + current_span: dict | None = None + + for tok_idx, tok in enumerate(token_results): + if tok_idx >= len(answer_offsets): + break + rel_start, rel_end = answer_offsets[tok_idx] + is_hall = tok["pred"] == 1 + confidence = tok["prob"] if is_hall else 0.0 - def predict(self, context, answer, question=None, output_format="tokens") -> list: + if is_hall: + if current_span is None: + current_span = { + "start": rel_start, + "end": rel_end, + "confidence": confidence, + } + else: + current_span["end"] = rel_end + current_span["confidence"] = max(current_span["confidence"], confidence) + else: + if current_span is not None: + current_span["text"] = answer[current_span["start"] : current_span["end"]] + spans.append(current_span) + current_span = None + + if current_span is not None: + current_span["text"] = answer[current_span["start"] : current_span["end"]] + spans.append(current_span) + + return spans + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def predict( + self, + context: list[str], + answer: str, + question: str | None = None, + output_format: str = "tokens", + ) -> list: """Predict hallucination tokens or spans from the provided context, answer, and question. :param context: List of passages that were supplied to the LLM / user. - :param answer: Model‑generated answer to inspect. + :param answer: Model-generated answer to inspect. :param question: Original question (``None`` for summarisation). - :param output_format: ``"tokens"`` for token‑level dicts, ``"spans"`` for character spans. + :param output_format: ``"tokens"`` for token-level dicts, ``"spans"`` for character spans. :returns: List of predictions in requested format. """ - formatted_prompt = PromptUtils.format_context(context, question, self.lang) - return self._predict(formatted_prompt, answer, output_format) + if output_format not in ("tokens", "spans"): + raise ValueError( + f"TransformerDetector doesn't support '{output_format}' format. " + "Use 'tokens' or 'spans'" + ) - def predict_prompt(self, prompt, answer, output_format="tokens") -> list: + groups = self._group_passages_into_chunks(context, question, answer) + + if len(groups) == 1: + prompt = PromptUtils.format_context(groups[0], question, self.lang) + return self._predict_single(prompt, answer, output_format) + + chunk_prompts = [PromptUtils.format_context(group, question, self.lang) for group in groups] + return self._predict_chunked(chunk_prompts, answer, output_format) + + def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list: """Predict hallucination tokens or spans from the provided prompt and answer. + Note: unlike :meth:`predict`, this method does **not** chunk the prompt + automatically. If the prompt + answer exceed ``max_length``, the prompt + will be truncated. Use :meth:`predict` with structured passages for + automatic chunking. + :param prompt: The prompt string. :param answer: The answer string. - :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + :param output_format: ``"tokens"`` or ``"spans"``. + :returns: List of predictions in requested format. 
""" - return self._predict(prompt, answer, output_format) + if output_format not in ("tokens", "spans"): + raise ValueError( + f"TransformerDetector doesn't support '{output_format}' format. " + "Use 'tokens' or 'spans'" + ) + # Warn if the input will be truncated. + total_tokens = self.tokenizer(prompt, answer, add_special_tokens=True, return_tensors="pt")[ + "input_ids" + ].shape[1] + if total_tokens > self.max_length: + logger.warning( + "predict_prompt: input (%d tokens) exceeds max_length (%d). " + "The prompt will be truncated. Use predict() with structured " + "passages for automatic chunking.", + total_tokens, + self.max_length, + ) + return self._predict_single(prompt, answer, output_format) - def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list: + def predict_prompt_batch( + self, prompts: list[str], answers: list[str], output_format: str = "tokens" + ) -> list: """Predict hallucination tokens or spans from the provided prompts and answers. :param prompts: List of prompt strings. :param answers: List of answer strings. - :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + :param output_format: ``"tokens"`` or ``"spans"``. + :returns: List of prediction lists, one per input pair. """ - return [self._predict(p, a, output_format) for p, a in zip(prompts, answers)] + return [self.predict_prompt(p, a, output_format) for p, a in zip(prompts, answers)] diff --git a/lettucedetect/models/generation.py b/lettucedetect/models/generation.py index 939b3f0..f378d12 100644 --- a/lettucedetect/models/generation.py +++ b/lettucedetect/models/generation.py @@ -1,6 +1,8 @@ """Simple hallucination generation using RAGFactChecker.""" -from typing import Any, Dict, List, Optional +from __future__ import annotations + +from typing import Any from lettucedetect.ragfactchecker import RAGFactChecker @@ -14,9 +16,9 @@ class HallucinationGenerator: def __init__( self, method: str = "rag_fact_checker", - openai_api_key: str = None, + openai_api_key: str | None = None, model: str = "gpt-4o", - base_url: str = None, + base_url: str | None = None, temperature: float = 0.0, **kwargs, ): @@ -36,12 +38,12 @@ def __init__( def generate( self, - context: List[str], + context: list[str], question: str, - answer: str = None, - error_types: Optional[List[str]] = None, + answer: str | None = None, + error_types: list[str] | None = None, intensity: float = 0.3, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Generate hallucinated content. :param context: List of context documents @@ -65,12 +67,12 @@ def generate( def generate_batch( self, - contexts: List[List[str]], - questions: List[str], - answers: List[str] = None, - error_types: Optional[List[str]] = None, + contexts: list[list[str]], + questions: list[str], + answers: list[str] | None = None, + error_types: list[str] | None = None, intensity: float = 0.3, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Generate hallucinated content for multiple inputs. :param contexts: List of context lists @@ -96,12 +98,12 @@ def generate_batch( async def generate_batch_async( self, - contexts: List[List[str]], - questions: List[str], - answers: List[str] = None, - error_types: Optional[List[str]] = None, + contexts: list[list[str]], + questions: list[str], + answers: list[str] | None = None, + error_types: list[str] | None = None, intensity: float = 0.3, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Generate hallucinated content for multiple inputs. 
:param contexts: List of context lists diff --git a/lettucedetect/models/inference.py b/lettucedetect/models/inference.py index 035ed3c..28de429 100644 --- a/lettucedetect/models/inference.py +++ b/lettucedetect/models/inference.py @@ -13,10 +13,15 @@ class HallucinationDetector: """Facade class that delegates to a concrete detector chosen by *method*. :param method: ``"transformer"`` (token-classifier) or ``"llm"`` (OpenAI function-calling). - :param kwargs: Passed straight through to the chosen detector’s constructor. + :param kwargs: Passed straight through to the chosen detector's constructor. """ - def __init__(self, method: str = "transformer", **kwargs): + def __init__(self, method: str = "transformer", **kwargs) -> None: + """Initialize the detector. + + :param method: Detection method to use. + :param kwargs: Passed to the detector constructor. + """ self.detector = make_detector(method, **kwargs) def predict( @@ -47,6 +52,7 @@ def predict_prompt_batch( self, prompts: list[str], answers: list[str], output_format: str = "tokens" ) -> list: """Batch version of :py:meth:`predict_prompt`. + Length of *prompts* and *answers* must match. :param prompts: List of prompt strings. diff --git a/lettucedetect/models/trainer.py b/lettucedetect/models/trainer.py index 1ec97d7..c951908 100644 --- a/lettucedetect/models/trainer.py +++ b/lettucedetect/models/trainer.py @@ -12,6 +12,12 @@ class Trainer: + """Token classification trainer with epoch-based training and validation. + + Trains a model using AdamW, evaluates on a test set after each epoch, + and saves the best checkpoint based on hallucinated-class F1. + """ + def __init__( self, model: Module, diff --git a/lettucedetect/prompts/examples_hu.json b/lettucedetect/prompts/examples_hu.json index 509272f..7c15fb4 100644 --- a/lettucedetect/prompts/examples_hu.json +++ b/lettucedetect/prompts/examples_hu.json @@ -34,3 +34,4 @@ } ] + diff --git a/lettucedetect/prompts/summary_prompt_hu.txt b/lettucedetect/prompts/summary_prompt_hu.txt index 2a995dc..4d59f77 100644 --- a/lettucedetect/prompts/summary_prompt_hu.txt +++ b/lettucedetect/prompts/summary_prompt_hu.txt @@ -2,3 +2,4 @@ Foglalja ΓΆssze az alΓ‘bbi szΓΆveget: ${text} output: + diff --git a/lettucedetect/ragfactchecker.py b/lettucedetect/ragfactchecker.py index c62d6f6..3b1d014 100644 --- a/lettucedetect/ragfactchecker.py +++ b/lettucedetect/ragfactchecker.py @@ -1,8 +1,10 @@ """Simple, clean RAGFactChecker wrapper for lettuceDetect.""" +from __future__ import annotations + import logging import os -from typing import Any, Dict, List, Optional +from typing import Any class RAGFactChecker: @@ -17,9 +19,9 @@ class RAGFactChecker: def __init__( self, - openai_api_key: Optional[str] = None, + openai_api_key: str | None = None, model: str = "gpt-4o", - base_url: Optional[str] = None, + base_url: str | None = None, temperature: float = 0.0, ): """Initialize RAGFactChecker. @@ -43,7 +45,7 @@ def __init__( self.logger = logging.getLogger(__name__) self._setup_components() - def _setup_components(self): + def _setup_components(self) -> None: """Initialize RAGFactChecker components.""" try: from rag_fact_checker.data import Config @@ -75,7 +77,7 @@ def _setup_components(self): # ============ TRIPLET OPERATIONS ============ - def generate_triplets(self, text: str) -> List[List[str]]: + def generate_triplets(self, text: str) -> list[list[str]]: """Generate triplets from text. 
:param text: Input text @@ -88,8 +90,8 @@ def generate_triplets(self, text: str) -> List[List[str]]: return result.triplets def compare_triplets( - self, answer_triplets: List[List[str]], reference_triplets: List[List[str]] - ) -> Dict[str, Any]: + self, answer_triplets: list[list[str]], reference_triplets: list[list[str]] + ) -> dict[str, Any]: """Compare answer triplets against reference triplets. :param answer_triplets: Triplets from answer to check @@ -103,7 +105,7 @@ def compare_triplets( ) return {"fact_check_results": result.fact_check_prediction_binary, "raw_output": result} - def analyze_text_pair(self, answer_text: str, reference_text: str) -> Dict[str, Any]: + def analyze_text_pair(self, answer_text: str, reference_text: str) -> dict[str, Any]: """Generate and compare triplets for two texts. :param answer_text: Text to analyze @@ -125,8 +127,8 @@ def analyze_text_pair(self, answer_text: str, reference_text: str) -> Dict[str, # ============ HALLUCINATION DETECTION ============ def detect_hallucinations( - self, context: List[str], answer: str, question: Optional[str] = None - ) -> Dict[str, Any]: + self, context: list[str], answer: str, question: str | None = None + ) -> dict[str, Any]: """Detect hallucinations in answer given context. :param context: List of context documents @@ -158,8 +160,8 @@ def detect_hallucinations( # ============ HALLUCINATION GENERATION ============ def generate_hallucination_from_context( - self, context: List[str], question: str - ) -> Dict[str, Any]: + self, context: list[str], question: str + ) -> dict[str, Any]: """Generate hallucinated content from context and question. :param context: List of context documents @@ -181,9 +183,9 @@ def generate_hallucination_from_answer( self, correct_answer: str, question: str, - error_types: Optional[List[str]] = None, + error_types: list[str] | None = None, intensity: float = 0.3, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Generate hallucinated version of a correct answer. 
:param correct_answer: The correct answer to modify @@ -223,11 +225,11 @@ def generate_hallucination_from_answer( async def generate_hallucination_from_answer_batch_async( self, - correct_answers: List[str], - questions: List[str], - error_types: Optional[List[List[str]]] = None, - intensities: Optional[List[float]] = None, - ) -> List[Dict[str, Any]]: + correct_answers: list[str], + questions: list[str], + error_types: list[list[str]] | None = None, + intensities: list[float] | None = None, + ) -> list[dict[str, Any]]: """Generate hallucinated version of multiple correct answers.""" error_type_enums_list = None if error_types: @@ -253,9 +255,9 @@ async def generate_hallucination_from_answer_batch_async( async def generate_hallucination_from_context_batch_async( self, - contexts: List[List[str]], - questions: List[str], - ) -> List[Dict[str, Any]]: + contexts: list[list[str]], + questions: list[str], + ) -> list[dict[str, Any]]: """Generate hallucinated version of multiple correct answers.""" result = await self.reference_generator.generate_hlcntn_data_batch_async( contexts, questions @@ -264,11 +266,11 @@ async def generate_hallucination_from_context_batch_async( def generate_hallucination_from_answer_batch( self, - correct_answers: List[str], - questions: List[str], - error_types: Optional[List[List[str]]] = None, - intensities: Optional[List[float]] = None, - ) -> List[Dict[str, Any]]: + correct_answers: list[str], + questions: list[str], + error_types: list[list[str]] | None = None, + intensities: list[float] | None = None, + ) -> list[dict[str, Any]]: """Generate hallucinated version of multiple correct answers. :param correct_answers: List of correct answers to modify @@ -303,9 +305,9 @@ def generate_hallucination_from_answer_batch( def generate_hallucination_from_context_batch( self, - contexts: List[List[str]], - questions: List[str], - ) -> List[Dict[str, Any]]: + contexts: list[list[str]], + questions: list[str], + ) -> list[dict[str, Any]]: """Generate hallucinated version of multiple correct answers. :param contexts: List of context document lists @@ -317,7 +319,7 @@ def generate_hallucination_from_context_batch( result = self.reference_generator.generate_hlcntn_data_batch(contexts, questions) return result - def generate_triplets_batch(self, texts: List[str]) -> List[List[List[str]]]: + def generate_triplets_batch(self, texts: list[str]) -> list[list[list[str]]]: """Generate triplets for multiple texts. :param texts: List of input texts @@ -341,8 +343,8 @@ def generate_triplets_batch(self, texts: List[str]) -> List[List[List[str]]]: return results def detect_hallucinations_batch( - self, contexts: List[List[str]], answers: List[str], questions: Optional[List[str]] = None - ) -> List[Dict[str, Any]]: + self, contexts: list[list[str]], answers: list[str], questions: list[str] | None = None + ) -> list[dict[str, Any]]: """Detect hallucinations for multiple context-answer pairs. 
:param contexts: List of context document lists diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..e3bbf28 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,95 @@ +site_name: LettuceDetect +site_description: A lightweight hallucination detection framework for RAG applications +site_url: https://krlabsorg.github.io/LettuceDetect +repo_url: https://github.com/KRLabsOrg/LettuceDetect +repo_name: KRLabsOrg/LettuceDetect + +theme: + name: material + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: green + accent: light green + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: green + accent: light green + toggle: + icon: material/brightness-4 + name: Switch to light mode + features: + - navigation.sections + - navigation.expand + - navigation.top + - content.code.copy + - content.code.annotate + - toc.follow + icon: + repo: fontawesome/brands/github + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.tabbed: + alternate_style: true + - pymdownx.snippets + - attr_list + - md_in_html + - tables + - toc: + permalink: true + +plugins: + - search + - mkdocstrings: + handlers: + python: + options: + show_source: false + show_root_heading: true + show_root_full_path: false + heading_level: 2 + members_order: source + docstring_style: sphinx + merge_init_into_class: true + show_if_no_docstring: false + +nav: + - Home: index.md + - Getting Started: + - Installation: getting-started/installation.md + - Quick Start: getting-started/quickstart.md + - Models: getting-started/models.md + - Guide: + - Detection Methods: guide/detection-methods.md + - Evaluation: guide/evaluation.md + - Training: guide/training.md + - Multilingual: guide/multilingual.md + - Web API: guide/api.md + - Integrations: guide/integrations.md + - Code Hallucination Dataset: + - Overview: code-hallucination/index.md + - Pipeline Phases: code-hallucination/phases.md + - Configuration: code-hallucination/configuration.md + - Architecture Research: code-hallucination/architecture-research.md + - Research: + - TinyLettuce: TINYLETTUCE.md + - Multilingual (EuroBERT): EUROBERT.md + - API Reference: + - Inference: api/inference.md + - Detectors: api/detectors.md + - Datasets: api/datasets.md + - Training: api/training.md + - Benchmarks: benchmarks.md diff --git a/pyproject.toml b/pyproject.toml index 0d62884..0faab52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ [project.urls] Homepage = "https://github.com/krlabsorg/lettucedetect" +Documentation = "https://krlabsorg.github.io/LettuceDetect" [project.optional-dependencies] dev = [ @@ -42,11 +43,15 @@ api = [ "pydantic-settings>=2.8.0", "httpx>=0.28" ] +docs = [ + "mkdocs-material>=9.0", + "mkdocstrings[python]>=0.24", +] [tool.setuptools] packages = ["lettucedetect", "lettucedetect_api"] -[tool.pytest] +[tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*_pytest.py" @@ -77,8 +82,13 @@ ignore = [ "ANN003", # **kwargs annotation "ANN204", # missing return type for __init__ "PTH123", # path.open + "C901", # function complexity ] [tool.ruff.lint.per-file-ignores] +"tests/**" = ["S101", "ANN001", "ANN201", "ANN202", "D103", "F841"] +"lettucedetect/preprocess/**" = ["ANN001", 
"ANN201"] +"lettucedetect/datasets/analyze_lengths.py" = ["ANN201", "D103", "E741"] +"lettucedetect/models/evaluator.py" = ["ANN001", "ANN201", "D401"] "lettucedetect_api/test_server.py" = ["S101"] "lettucedetect_api/test_client.py" = ["S101"] \ No newline at end of file diff --git a/scripts/add_clean_swebench_samples.py b/scripts/add_clean_swebench_samples.py new file mode 100644 index 0000000..2c5ae32 --- /dev/null +++ b/scripts/add_clean_swebench_samples.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +"""Add clean (non-hallucinated) samples from remaining SWE-bench Lite instances. + +Fetches source files from GitHub and builds LettuceDetect-format samples +with empty labels (= supported/correct code). +""" + +import json +import re +import time + +import requests +from datasets import load_dataset + +INPUT_DATASET = ( + "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect_v2.json" +) +RAW_DATASET = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json" +OUTPUT_PATH = INPUT_DATASET # Overwrite with merged data + +GITHUB_RAW_BASE = "https://raw.githubusercontent.com" +MAX_FILE_CHARS = 12000 +MAX_NEW_SAMPLES = 200 # Cap to avoid too many GitHub requests + + +def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None: + url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}" + try: + r = requests.get(url, timeout=15) + if r.status_code == 200: + return r.text[:MAX_FILE_CHARS] + return None + except Exception as e: + print(f" Error fetching {filepath}: {e}") + return None + + +def extract_changed_files(patch: str) -> list[str]: + files = [] + for line in patch.split("\n"): + if line.startswith("diff --git"): + match = re.search(r"b/(.+)$", line) + if match: + files.append(match.group(1)) + return files + + +def extract_code_from_patch(patch: str) -> str: + added_lines = [] + in_hunk = False + for line in patch.split("\n"): + if line.startswith("@@"): + in_hunk = True + continue + if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"): + continue + if in_hunk: + if line.startswith("+"): + added_lines.append(line[1:]) + elif line.startswith("-"): + continue + else: + if line.startswith(" "): + added_lines.append(line[1:]) + else: + added_lines.append(line) + return "\n".join(added_lines) + + +def build_prompt(source_files: dict[str, str], user_query: str) -> str: + parts = [] + for filepath, content in source_files.items(): + parts.append(f"File: {filepath}\n```python\n{content}\n```") + parts.append(f"User request: {user_query}") + return "\n\n".join(parts) + + +def main(): + # Load existing data + with open(INPUT_DATASET) as f: + existing_data = json.load(f) + print(f"Existing dataset: {len(existing_data)} samples") + + # Get already-used instance IDs + with open(RAW_DATASET) as f: + raw_data = json.load(f) + used_ids = {item["instance_id"] for item in raw_data} + print(f"Already used instance IDs: {len(used_ids)}") + + # Load SWE-bench Lite + print("Loading SWE-bench Lite...") + ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + + # Filter to unused instances + remaining = [s for s in ds if s["instance_id"] not in used_ids] + print(f"Remaining SWE-bench instances: {len(remaining)}") + + # Cap the number of new samples + remaining = remaining[:MAX_NEW_SAMPLES] + print(f"Processing {len(remaining)} instances...") + + new_samples = [] + for i, sample in enumerate(remaining): + instance_id = sample["instance_id"] + repo = sample["repo"] + commit = sample["base_commit"] + patch = 
sample["patch"] + + print(f"\n[{i + 1}/{len(remaining)}] {instance_id}") + + # Extract changed files + changed_files = extract_changed_files(patch) + if not changed_files: + print(" SKIP: No changed files found") + continue + + # Fetch source files from GitHub + source_files = {} + for filepath in changed_files[:3]: # Limit to 3 files per sample + content = fetch_file_from_github(repo, commit, filepath) + if content: + source_files[filepath] = content + print(f" Fetched {filepath}: {len(content)} chars") + else: + print(f" Failed: {filepath}") + time.sleep(0.3) + + if not source_files: + print(" SKIP: No source files fetched") + continue + + # Extract code from gold patch + code = extract_code_from_patch(patch) + if not code.strip(): + print(" SKIP: Empty code") + continue + + # Build prompt with source files + user query + user_query = sample["problem_statement"][:500] + prompt = build_prompt(source_files, user_query) + + new_samples.append( + { + "prompt": prompt, + "answer": code, + "labels": [], + "split": "train", + "task_type": "code_generation", + "dataset": "code_hallucination_swebench", + "language": "en", + } + ) + + # Progress check + if (i + 1) % 20 == 0: + print(f"\n Progress: {len(new_samples)} clean samples so far\n") + + # Merge with existing data + merged = existing_data + new_samples + print(f"\nNew clean samples: {len(new_samples)}") + print(f"Total merged: {len(merged)}") + + # Save + with open(OUTPUT_PATH, "w") as f: + json.dump(merged, f, indent=2) + + # Stats + n_clean = sum(1 for s in merged if not s["labels"]) + n_hall = sum(1 for s in merged if s["labels"]) + print("\nFinal dataset:") + print(f" Clean: {n_clean} ({n_clean / len(merged) * 100:.1f}%)") + print(f" Hallucinated: {n_hall} ({n_hall / len(merged) * 100:.1f}%)") + print(f" Total: {len(merged)}") + print(f"\nSaved to {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/scripts/code_hallucination/__init__.py b/scripts/code_hallucination/__init__.py new file mode 100644 index 0000000..706e62a --- /dev/null +++ b/scripts/code_hallucination/__init__.py @@ -0,0 +1 @@ +"""Code hallucination dataset generation pipeline from SWE-bench.""" diff --git a/scripts/code_hallucination/config.py b/scripts/code_hallucination/config.py new file mode 100644 index 0000000..85591b9 --- /dev/null +++ b/scripts/code_hallucination/config.py @@ -0,0 +1,56 @@ +"""Configuration for the code hallucination dataset pipeline.""" + +import os +from pathlib import Path + +# === Paths === +PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent +DATA_DIR = PROJECT_ROOT / "data" / "code_hallucination" +REPOS_DIR = DATA_DIR / "repos" +SOURCE_CACHE_DIR = DATA_DIR / "source_cache" + +# Intermediate outputs +INSTANCES_PATH = DATA_DIR / "swebench_instances.json" +QUERIES_PATH = DATA_DIR / "queries.jsonl" +DOCS_PATH = DATA_DIR / "documentation.jsonl" +FORMATS_PATH = DATA_DIR / "formats.jsonl" +HALLUCINATED_PATH = DATA_DIR / "hallucinated_samples.jsonl" + +# Final outputs +DATASET_PATH = DATA_DIR / "code_hallucination_data.json" +METADATA_PATH = DATA_DIR / "code_hallucination_metadata.json" +VALIDATION_REPORT_PATH = DATA_DIR / "validation_report.txt" + +# === LLM API Config === +# Override via env vars or CLI args +API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1") +API_KEY = os.environ.get("OPENAI_API_KEY", "") +MODEL = os.environ.get("MODEL", "moonshotai/kimi-k2-instruct-0905") + +# Context7 +CONTEXT7_BASE = "https://context7.com/api/v2" +CONTEXT7_API_KEY = os.environ.get("CONTEXT7_API_KEY", "") 
+DOCS_RATIO = 0.5 # Only fetch docs for 50% of instances + +# === Dataset Config === +HALLUCINATION_RATIO = 0.4 # 40% hallucinated, 60% clean +MAX_FILE_CHARS = 12000 # Cap individual source file size +MAX_CONTEXT7_CHARS = 4000 # Documentation fetch limit + +# === LLM Config === +RETRY_DELAY = 2 +MAX_RETRIES = 3 +LLM_TEMPERATURE = 0.7 +HALLUCINATION_TEMPERATURE = 0.8 +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) # >1 for local vLLM + +# Hallucination types (round-robin assignment) +HALLUCINATION_TYPES = ["structural", "behavioral", "semantic"] + +# Answer format types +FORMAT_TYPES = ["complete_function", "edit_style", "fragment"] +FORMAT_WEIGHTS = [0.4, 0.3, 0.3] # Target distribution + +# SWE-bench datasets +SWEBENCH_FULL = "princeton-nlp/SWE-bench" +SWEBENCH_LITE = "princeton-nlp/SWE-bench_Lite" diff --git a/scripts/code_hallucination/context7_docs.py b/scripts/code_hallucination/context7_docs.py new file mode 100644 index 0000000..4c1df87 --- /dev/null +++ b/scripts/code_hallucination/context7_docs.py @@ -0,0 +1,225 @@ +"""Phase 4: Fetch library documentation from Context7 API. + +Only fetches docs for ~50% of instances (configurable via DOCS_RATIO). +The other 50% get empty docs, creating training variety β€” models learn +to handle both with-docs and without-docs scenarios. +""" + +import json +import random +import re + +import requests + +from .config import CONTEXT7_API_KEY, CONTEXT7_BASE, DOCS_PATH, DOCS_RATIO, MAX_CONTEXT7_CHARS + +# Map repo paths / import names to likely library names for Context7 +PATH_TO_LIB = { + "django": "django", + "astropy": "astropy", + "sympy": "sympy", + "sklearn": "scikit-learn", + "matplotlib": "matplotlib", + "requests": "requests", + "flask": "flask", + "pytest": "pytest", + "sphinx": "sphinx", + "xarray": "xarray", + "seaborn": "seaborn", + "pylint": "pylint", + "pandas": "pandas", + "numpy": "numpy", + "scipy": "scipy", + "transformers": "transformers", + "jax": "jax", + "torch": "pytorch", + "tensorflow": "tensorflow", + "sqlalchemy": "sqlalchemy", + "celery": "celery", + "pydantic": "pydantic", + "fastapi": "fastapi", + "httpx": "httpx", +} + + +def extract_imports_from_patch(patch: str) -> list[str]: + """Extract Python import statements from added lines in a patch.""" + imports = set() + for line in patch.split("\n"): + if line.startswith("+") and not line.startswith("+++"): + clean = line[1:].strip() + if clean.startswith("import ") or clean.startswith("from "): + match = re.match(r"(?:from|import)\s+([\w.]+)", clean) + if match: + module = match.group(1).split(".")[0] + if module and not module.startswith("_"): + imports.add(module) + return list(imports) + + +def extract_libraries_from_files(changed_files: list[str]) -> list[str]: + """Infer libraries from file paths.""" + libs = set() + for f in changed_files: + for key, lib in PATH_TO_LIB.items(): + if key in f: + libs.add(lib) + return list(libs) + + +def fetch_context7_docs( + library_name: str, query: str, max_chars: int = MAX_CONTEXT7_CHARS +) -> str | None: + """Fetch documentation from Context7 for a library + query.""" + try: + headers = {} + if CONTEXT7_API_KEY: + headers["Authorization"] = f"Bearer {CONTEXT7_API_KEY}" + + r = requests.get( + f"{CONTEXT7_BASE}/libs/search", + params={"query": query, "libraryName": library_name}, + headers=headers, + timeout=10, + ) + if r.status_code != 200: + return None + results = r.json().get("results", []) + if not results: + return None + + lib_id = results[0]["id"] + + r2 = requests.get( + f"{CONTEXT7_BASE}/context", + 
params={"libraryId": lib_id, "query": query, "type": "txt"}, + headers=headers, + timeout=10, + ) + if r2.status_code != 200: + return None + + doc_text = r2.text[:max_chars] + return doc_text if doc_text.strip() else None + except Exception as e: + print(f" Context7 error for {library_name}: {e}") + return None + + +def get_documentation_for_instance( + changed_files: list[str], patch: str, problem_statement: str +) -> dict[str, str]: + """Fetch documentation for libraries referenced in an instance.""" + imported_libs = extract_imports_from_patch(patch) + path_libs = extract_libraries_from_files(changed_files) + all_libs = list(set(imported_libs + path_libs)) + + short_query = problem_statement[:200].replace("\n", " ").strip() + + docs = {} + for lib in all_libs[:3]: + doc = fetch_context7_docs(lib, short_query) + if doc: + docs[lib] = doc + + return docs + + +def select_docs_instances( + instances: list[dict], ratio: float = DOCS_RATIO, seed: int = 42 +) -> set[str]: + """Select which instances should get documentation fetched. + + Returns set of instance_ids that should have docs. + """ + rng = random.Random(seed) + ids = [inst["instance_id"] for inst in instances] + n_with_docs = int(len(ids) * ratio) + rng.shuffle(ids) + return set(ids[:n_with_docs]) + + +def load_existing_docs(path=DOCS_PATH) -> dict[str, dict]: + """Load already-fetched docs for resumability.""" + existing = {} + if path.exists(): + with open(path) as f: + for line in f: + try: + entry = json.loads(line) + existing[entry["instance_id"]] = entry["docs"] + except (json.JSONDecodeError, KeyError): + continue + return existing + + +def run(instances: list[dict]): + """Run Phase 4: Fetch documentation for selected instances (~50%).""" + print("=" * 60) + print("Phase 4: Context7 Documentation") + print("=" * 60) + + DOCS_PATH.parent.mkdir(parents=True, exist_ok=True) + + # Select which instances get docs + docs_ids = select_docs_instances(instances) + print( + f"Selected {len(docs_ids)}/{len(instances)} instances for documentation ({DOCS_RATIO:.0%})" + ) + + existing = load_existing_docs() + print(f"Already fetched: {len(existing)} instances") + + to_process = [inst for inst in instances if inst["instance_id"] not in existing] + print(f"Remaining: {len(to_process)} instances to process") + + processed = 0 + with_docs = 0 + skipped_by_ratio = 0 + + with open(DOCS_PATH, "a") as f: + for i, inst in enumerate(to_process): + instance_id = inst["instance_id"] + + # Skip docs for instances not selected (write empty docs) + if instance_id not in docs_ids: + entry = {"instance_id": instance_id, "docs": {}} + f.write(json.dumps(entry) + "\n") + f.flush() + processed += 1 + skipped_by_ratio += 1 + continue + + changed_files = inst.get("changed_files", []) + if not changed_files: + from .source_fetcher import extract_changed_files + + changed_files = extract_changed_files(inst["patch"]) + + docs = get_documentation_for_instance( + changed_files, inst["patch"], inst["problem_statement"] + ) + + entry = {"instance_id": instance_id, "docs": docs} + f.write(json.dumps(entry) + "\n") + f.flush() + + processed += 1 + if docs: + with_docs += 1 + + if processed % 100 == 0: + print( + f" Progress: {processed}/{len(to_process)} ({with_docs} with docs, {skipped_by_ratio} skipped)" + ) + + print( + f"\nDone: {processed} processed, {with_docs} with docs, {skipped_by_ratio} skipped (no-docs by design)" + ) + + +if __name__ == "__main__": + from .swebench_loader import load_instances + + instances = load_instances() + run(instances) diff --git 
a/scripts/code_hallucination/format_builder.py b/scripts/code_hallucination/format_builder.py new file mode 100644 index 0000000..78531ab --- /dev/null +++ b/scripts/code_hallucination/format_builder.py @@ -0,0 +1,115 @@ +"""Phase 5: Assign answer format to each instance.""" + +import json +import random + +from .config import FORMAT_TYPES, FORMAT_WEIGHTS, FORMATS_PATH, SOURCE_CACHE_DIR + + +def assign_format(source_data: dict) -> tuple[str, str]: + """Assign a format type and build the answer for an instance. + + Returns (format_type, answer_text). + Falls back if preferred format isn't available. + """ + has_functions = bool(source_data.get("modified_functions")) + has_edit = bool(source_data.get("edit_style")) + has_fragment = bool(source_data.get("patch_code", "").strip()) + + # Build available formats + available = [] + if has_functions: + available.append("complete_function") + if has_edit: + available.append("edit_style") + if has_fragment: + available.append("fragment") + + if not available: + return None, None + + # Weighted random choice from available formats + weights = [] + for fmt in available: + idx = FORMAT_TYPES.index(fmt) + weights.append(FORMAT_WEIGHTS[idx]) + + # Normalize weights + total = sum(weights) + weights = [w / total for w in weights] + + chosen = random.choices(available, weights=weights, k=1)[0] + + # Build answer text + if chosen == "complete_function": + funcs = source_data["modified_functions"] + # Take the first (or longest) modified function + func = max(funcs, key=lambda f: len(f.get("patched", ""))) + answer = func["patched"] + elif chosen == "edit_style": + answer = source_data["edit_style"] + else: # fragment + answer = source_data["patch_code"] + + return chosen, answer + + +def run(instances: list[dict], source_cache_dir=SOURCE_CACHE_DIR): + """Run Phase 5: Assign formats and build answers. + + Returns list of dicts with instance_id, format_type, answer. + """ + print("=" * 60) + print("Phase 5: Answer Format Building") + print("=" * 60) + + FORMATS_PATH.parent.mkdir(parents=True, exist_ok=True) + + results = [] + format_counts = {fmt: 0 for fmt in FORMAT_TYPES} + skipped = 0 + + for inst in instances: + instance_id = inst["instance_id"] + + # Load source data from cache + cache_path = source_cache_dir / f"{instance_id}.json" + if not cache_path.exists(): + skipped += 1 + continue + + with open(cache_path) as f: + source_data = json.load(f) + + fmt, answer = assign_format(source_data) + if fmt is None: + skipped += 1 + continue + + results.append( + { + "instance_id": instance_id, + "format_type": fmt, + "answer": answer, + } + ) + format_counts[fmt] += 1 + + # Save + with open(FORMATS_PATH, "w") as f: + for entry in results: + f.write(json.dumps(entry) + "\n") + + print(f"\nAssigned formats for {len(results)} instances (skipped {skipped})") + for fmt, count in format_counts.items(): + pct = count * 100 // max(len(results), 1) + print(f" {fmt}: {count} ({pct}%)") + + return results + + +if __name__ == "__main__": + from .swebench_loader import load_instances + + instances = load_instances() + run(instances) diff --git a/scripts/code_hallucination/hallucination_injector.py b/scripts/code_hallucination/hallucination_injector.py new file mode 100644 index 0000000..19dfeed --- /dev/null +++ b/scripts/code_hallucination/hallucination_injector.py @@ -0,0 +1,424 @@ +"""Phase 6: Inject hallucinations using LLM with JSON span annotations. + +Supports both sequential (remote API) and async batch (local vLLM) modes. 
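+Each reported change is kept only if it is an exact substring of the generated
+code, which is what makes character-level span labels recoverable downstream.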
+Set BATCH_SIZE>1 env var for parallel requests to local vLLM. +""" + +import asyncio +import json +import re +import textwrap +import time + +from openai import AsyncOpenAI, OpenAI + +from .config import ( + API_BASE_URL, + API_KEY, + BATCH_SIZE, + HALLUCINATED_PATH, + HALLUCINATION_TEMPERATURE, + HALLUCINATION_TYPES, + MAX_RETRIES, + MODEL, + RETRY_DELAY, +) + +INJECTION_SYSTEM_PROMPT = textwrap.dedent("""\ + You are a code hallucination injector for building a hallucination detection dataset. + + Given correct code and context, create a hallucinated version with a specific type of error. + + Hallucination types: + - STRUCTURAL: Change a function call, import, or parameter to something that + doesn't exist or is wrong. Code should still parse but reference non-existent + APIs, wrong methods, or invented parameters. + - BEHAVIORAL: Use correct APIs but with wrong values or logic. Wrong defaults, + off-by-one errors, swapped conditions, wrong argument values. + - SEMANTIC: Code that looks like it addresses the user's request but does + something subtly different or opposite. The code parses, uses real APIs, + but fails to do what was asked. + + Rules: + - Make changes PLAUSIBLE - something an LLM would realistically generate + - Changes must be SUBTLE, not obviously broken + - The hallucinated code must still be syntactically valid + - Make 1-3 changes, not more + + Respond in this exact JSON format (no markdown, no code blocks): + { + "hallucinated_code": "the full modified code with hallucinations injected", + "changes": [ + { + "original": "exact original code that was changed", + "hallucinated": "what you changed it to", + "explanation": "why this is a hallucination" + } + ] + } + + IMPORTANT: + - "original" must be an exact substring of the correct code + - "hallucinated" must be an exact substring of your hallucinated_code + - Return ONLY valid JSON, nothing else +""") + + +def inject_hallucination( + client: OpenAI, + model: str, + clean_answer: str, + hall_type: str, + user_query: str = "", + context: str = "", +) -> dict | None: + """Inject a hallucination and get back structured JSON with spans. + + Returns dict with 'hallucinated_code' and 'changes', or None if failed. + """ + user_msg = f"""Hallucination type to inject: {hall_type.upper()} + +User's original request: {user_query[:500]} + +Context (source code): +{context[:2000]} + +Correct code to modify: +{clean_answer} + +Generate a hallucinated version with {hall_type} error(s). Return JSON only.""" + + for attempt in range(MAX_RETRIES): + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": INJECTION_SYSTEM_PROMPT}, + {"role": "user", "content": user_msg}, + ], + temperature=HALLUCINATION_TEMPERATURE, + max_tokens=4000, + ) + raw = response.choices[0].message.content.strip() + + # Parse JSON from response + json_match = re.search(r"\{[\s\S]*\}", raw) + if not json_match: + if attempt < MAX_RETRIES - 1: + continue + return None + + result = json.loads(json_match.group()) + + if "hallucinated_code" not in result or "changes" not in result: + if attempt < MAX_RETRIES - 1: + continue + return None + + # Verify the hallucinated code is actually different + if result["hallucinated_code"].strip() == clean_answer.strip(): + if attempt < MAX_RETRIES - 1: + continue + return None + + return result + + except (json.JSONDecodeError, Exception) as e: + if attempt < MAX_RETRIES - 1: + wait = RETRY_DELAY * (attempt + 1) + print(f" Injection error (attempt {attempt + 1}): {e}. 
Retrying in {wait}s...") + time.sleep(wait) + else: + return None + + +def compute_span_offsets(hallucinated_code: str, hallucinated_span: str) -> list[dict]: + """Find character offsets of a hallucinated span within the answer code.""" + spans = [] + idx = hallucinated_code.find(hallucinated_span) + if idx != -1: + spans.append({"start": idx, "end": idx + len(hallucinated_span)}) + return spans + + +def build_labels_from_changes( + hallucinated_code: str, changes: list[dict], hall_type: str +) -> list[dict]: + """Build span labels by finding each hallucinated string in the code. + + Only includes spans where the hallucinated text is actually found in the answer. + """ + labels = [] + for change in changes: + h_span = change.get("hallucinated", "") + if not h_span or len(h_span) < 3: + continue + if h_span not in hallucinated_code: + continue + + offsets = compute_span_offsets(hallucinated_code, h_span) + for offset in offsets[:1]: # First occurrence only + labels.append( + { + "start": offset["start"], + "end": offset["end"], + "label": hall_type, + } + ) + + return labels + + +def load_existing_hallucinations(path=HALLUCINATED_PATH) -> dict[str, dict]: + """Load already-processed hallucinations for resumability.""" + existing = {} + if path.exists(): + with open(path) as f: + for line in f: + try: + entry = json.loads(line) + existing[entry["instance_id"]] = entry + except (json.JSONDecodeError, KeyError): + continue + return existing + + +async def _inject_one_async( + aclient: AsyncOpenAI, + model: str, + clean_answer: str, + hall_type: str, + user_query: str, + context: str, +) -> dict | None: + """Async version of inject_hallucination for batch processing.""" + user_msg = f"""Hallucination type to inject: {hall_type.upper()} + +User's original request: {user_query[:500]} + +Context (source code): +{context[:2000]} + +Correct code to modify: +{clean_answer} + +Generate a hallucinated version with {hall_type} error(s). 
Return JSON only.""" + + for attempt in range(MAX_RETRIES): + try: + response = await aclient.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": INJECTION_SYSTEM_PROMPT}, + {"role": "user", "content": user_msg}, + ], + temperature=HALLUCINATION_TEMPERATURE, + max_tokens=4000, + ) + raw = response.choices[0].message.content.strip() + json_match = re.search(r"\{[\s\S]*\}", raw) + if not json_match: + continue + result = json.loads(json_match.group()) + if "hallucinated_code" not in result or "changes" not in result: + continue + if result["hallucinated_code"].strip() == clean_answer.strip(): + continue + return result + except Exception: + if attempt < MAX_RETRIES - 1: + await asyncio.sleep(RETRY_DELAY * (attempt + 1)) + else: + return None + return None + + +def _process_result(result, instance_id, hall_type, fmt_data, model): + """Process a single injection result into a JSONL entry.""" + if result is None: + return None + hallucinated_code = result["hallucinated_code"] + changes = result.get("changes", []) + labels = build_labels_from_changes(hallucinated_code, changes, hall_type) + if not labels: + return None + return { + "instance_id": instance_id, + "hallucinated_answer": hallucinated_code, + "labels": labels, + "hallucination_type": hall_type, + "injector_model": model, + "format_type": fmt_data.get("format_type", "fragment"), + } + + +def run( + instances_to_inject: list[dict], + formats: dict[str, dict], + queries: dict[str, str], + api_key: str = API_KEY, + base_url: str = API_BASE_URL, + model: str = MODEL, +): + """Run Phase 6: Inject hallucinations into selected instances. + + Uses async batch processing when BATCH_SIZE > 1 (for local vLLM). + Falls back to sequential processing for remote APIs (BATCH_SIZE=1). 
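+    Results are appended to HALLUCINATED_PATH as JSONL and flushed per entry,
+    so an interrupted run can resume without re-injecting finished instances.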
+ """ + print("=" * 60) + print("Phase 6: Hallucination Injection") + print("=" * 60) + + HALLUCINATED_PATH.parent.mkdir(parents=True, exist_ok=True) + + print(f"Using {base_url} with model {model}") + print(f"Batch size: {BATCH_SIZE}") + + existing = load_existing_hallucinations() + print(f"Already processed: {len(existing)}") + + to_process = [ + inst + for inst in instances_to_inject + if inst["instance_id"] not in existing and inst["instance_id"] in formats + ] + print(f"Remaining: {len(to_process)} instances to inject") + + if BATCH_SIZE > 1: + results = _run_batched(to_process, formats, queries, api_key, base_url, model) + else: + results = _run_sequential(to_process, formats, queries, api_key, base_url, model) + + # Stats + type_counts = {} + for r in results: + t = r["hallucination_type"] + type_counts[t] = type_counts.get(t, 0) + 1 + print("By type:", type_counts) + + if results: + avg_spans = sum(len(r["labels"]) for r in results) / len(results) + span_sizes = [lab["end"] - lab["start"] for r in results for lab in r["labels"]] + print(f"Avg spans per sample: {avg_spans:.1f}") + print( + f"Span sizes: min={min(span_sizes)}, max={max(span_sizes)}, avg={sum(span_sizes) // len(span_sizes)}" + ) + + return results + + +def _run_sequential(to_process, formats, queries, api_key, base_url, model): + """Sequential processing for remote APIs (rate-limited).""" + client = OpenAI(api_key=api_key, base_url=base_url) + processed = 0 + failed = 0 + no_spans = 0 + results = [] + + with open(HALLUCINATED_PATH, "a") as f: + for i, inst in enumerate(to_process): + instance_id = inst["instance_id"] + fmt_data = formats.get(instance_id, {}) + clean_answer = fmt_data.get("answer", "") + if not clean_answer: + failed += 1 + continue + + hall_type = HALLUCINATION_TYPES[i % len(HALLUCINATION_TYPES)] + query = queries.get(instance_id, "") + context = inst.get("problem_statement", "")[:2000] + + result = inject_hallucination(client, model, clean_answer, hall_type, query, context) + entry = _process_result(result, instance_id, hall_type, fmt_data, model) + + if entry is None: + if result is not None: + no_spans += 1 + failed += 1 + continue + + f.write(json.dumps(entry) + "\n") + f.flush() + results.append(entry) + processed += 1 + + if processed % 50 == 0: + print(f" Progress: {processed}/{len(to_process)} (failed: {failed})") + + print(f"\nDone: {processed} injected, {failed} failed ({no_spans} had no matchable spans)") + return results + + +def _run_batched(to_process, formats, queries, api_key, base_url, model): + """Async batch processing for local vLLM (no rate limiting needed).""" + aclient = AsyncOpenAI(api_key=api_key, base_url=base_url) + processed = 0 + failed = 0 + no_spans = 0 + results = [] + + async def process_batches(): + nonlocal processed, failed, no_spans + + with open(HALLUCINATED_PATH, "a") as f: + for batch_start in range(0, len(to_process), BATCH_SIZE): + batch = to_process[batch_start : batch_start + BATCH_SIZE] + + # Build async tasks for the batch + tasks = [] + batch_meta = [] + for i, inst in enumerate(batch): + global_idx = batch_start + i + instance_id = inst["instance_id"] + fmt_data = formats.get(instance_id, {}) + clean_answer = fmt_data.get("answer", "") + if not clean_answer: + failed += 1 + continue + + hall_type = HALLUCINATION_TYPES[global_idx % len(HALLUCINATION_TYPES)] + query = queries.get(instance_id, "") + context = inst.get("problem_statement", "")[:2000] + + tasks.append( + _inject_one_async(aclient, model, clean_answer, hall_type, query, context) + ) + 
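+                    # Keep metadata positionally aligned with tasks so each
+                    # asyncio.gather result can be matched back to its instance.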
batch_meta.append((instance_id, hall_type, fmt_data)) + + if not tasks: + continue + + # Run batch concurrently + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results and write immediately + for result, (instance_id, hall_type, fmt_data) in zip(batch_results, batch_meta): + if isinstance(result, Exception): + failed += 1 + continue + + entry = _process_result(result, instance_id, hall_type, fmt_data, model) + if entry is None: + if result is not None: + no_spans += 1 + failed += 1 + continue + + f.write(json.dumps(entry) + "\n") + f.flush() + results.append(entry) + processed += 1 + + if processed % 50 == 0 or batch_start + BATCH_SIZE >= len(to_process): + total = processed + failed + print( + f" Progress: {total}/{len(to_process)} ({processed} ok, {failed} failed)" + ) + + asyncio.run(process_batches()) + print(f"\nDone: {processed} injected, {failed} failed ({no_spans} had no matchable spans)") + return results + + +if __name__ == "__main__": + print("Run via pipeline.py") diff --git a/scripts/code_hallucination/pipeline.py b/scripts/code_hallucination/pipeline.py new file mode 100644 index 0000000..d4e79a9 --- /dev/null +++ b/scripts/code_hallucination/pipeline.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +"""Orchestrator for the code hallucination dataset pipeline. + +Usage: + # Run all phases + python -m scripts.code_hallucination.pipeline --all + + # Run specific phase + python -m scripts.code_hallucination.pipeline --phase 1 + + # Test with a few examples + python -m scripts.code_hallucination.pipeline --test 5 + + # Override LLM settings via env vars: + OPENAI_API_KEY=xxx API_BASE_URL=https://api.groq.com/openai/v1 MODEL=moonshotai/kimi-k2-instruct-0905 \ + python -m scripts.code_hallucination.pipeline --test 10 +""" + +import argparse +import json +import random + +from .config import ( + API_BASE_URL, + API_KEY, + DATA_DIR, + DOCS_PATH, + FORMATS_PATH, + HALLUCINATED_PATH, + MODEL, + QUERIES_PATH, +) + + +def load_jsonl_dict(path, key="instance_id", value_key=None) -> dict: + """Load a JSONL file into a dict keyed by instance_id.""" + result = {} + if not path.exists(): + return result + with open(path) as f: + for line in f: + try: + entry = json.loads(line) + if value_key: + result[entry[key]] = entry[value_key] + else: + result[entry[key]] = entry + except (json.JSONDecodeError, KeyError): + continue + return result + + +def run_test(n: int = 5, api_key: str = API_KEY, base_url: str = API_BASE_URL, model: str = MODEL): + """Run a quick test with n instances from the test split.""" + print("=" * 60) + print(f"TEST MODE: Running pipeline with {n} instances") + print(f"LLM: {model} @ {base_url}") + print("=" * 60) + + from .swebench_loader import load_all_splits + + # Load and filter to test split + all_instances = load_all_splits() + test_instances = [i for i in all_instances if i["split"] == "test"] + + # Take n random instances + rng = random.Random(42) + selected = rng.sample(test_instances, min(n, len(test_instances))) + print(f"Selected {len(selected)} test instances") + + # Save temporary instances + DATA_DIR.mkdir(parents=True, exist_ok=True) + test_path = DATA_DIR / "test_instances.json" + with open(test_path, "w") as f: + json.dump(selected, f, indent=2) + + # Phase 2: Fetch sources (use GitHub API for test mode β€” no cloning needed) + from .source_fetcher import run as run_fetch + + sources = run_fetch(selected, use_github_api=True) + + if not sources: + print("No sources fetched, aborting test") + return + + # Phase 3: Rewrite 
queries + from .query_rewriter import run as run_queries + + run_queries(selected, api_key=api_key, base_url=base_url, model=model) + + # Phase 4: Context7 docs + from .context7_docs import run as run_docs + + run_docs(selected) + + # Phase 5: Assign formats + from .format_builder import run as run_formats + + run_formats(selected) + + # Phase 8: Select targets (before phase 6) + from .splitter import select_hallucination_targets + + targets = select_hallucination_targets(selected) + + # Phase 6: Inject hallucinations + from .hallucination_injector import run as run_inject + + formats = load_jsonl_dict(FORMATS_PATH) + queries = load_jsonl_dict(QUERIES_PATH, value_key="query") + to_inject = [i for i in selected if i["instance_id"] in targets] + run_inject(to_inject, formats, queries, api_key=api_key, base_url=base_url, model=model) + + # Phase 7: Assemble + from .sample_assembler import run as run_assemble + + docs = load_jsonl_dict(DOCS_PATH, value_key="docs") + hallucinations = load_jsonl_dict(HALLUCINATED_PATH) + samples, metadata = run_assemble(selected, queries, docs, formats, hallucinations, targets) + + # Phase 9: Validate + from .validator import run as run_validate + + run_validate(samples, metadata) + + print("\n" + "=" * 60) + print("TEST COMPLETE") + print("=" * 60) + print(f"Generated {len(samples)} samples from {n} test instances") + + # Show a sample + if samples: + print("\n--- Sample Example ---") + s = samples[0] + print(f"Prompt length: {len(s['prompt'])} chars") + print(f"Answer length: {len(s['answer'])} chars") + print(f"Labels: {len(s['labels'])}") + print(f"Split: {s['split']}") + print(f"Answer preview: {s['answer'][:200]}...") + + +def main(): + parser = argparse.ArgumentParser(description="Code hallucination dataset pipeline") + parser.add_argument( + "--phase", nargs="+", type=int, choices=range(1, 10), help="Run specific phase(s)" + ) + parser.add_argument("--all", action="store_true", help="Run all phases") + parser.add_argument("--test", type=int, metavar="N", help="Test with N instances") + parser.add_argument("--api-key", type=str, default=API_KEY, help="LLM API key") + parser.add_argument("--base-url", type=str, default=API_BASE_URL, help="LLM API base URL") + parser.add_argument("--model", type=str, default=MODEL, help="LLM model name") + args = parser.parse_args() + + if args.test: + run_test(args.test, api_key=args.api_key, base_url=args.base_url, model=args.model) + return + + if not args.phase and not args.all: + parser.print_help() + return + + phases = list(range(1, 10)) if args.all else args.phase + + for phase in sorted(phases): + print(f"\n{'#' * 60}") + print(f"# Running Phase {phase}") + print(f"{'#' * 60}\n") + + if phase == 1: + from .swebench_loader import run + + run() + elif phase == 2: + from .source_fetcher import run + from .swebench_loader import load_instances + + run(load_instances()) + elif phase == 3: + from .query_rewriter import run + from .swebench_loader import load_instances + + run(load_instances(), api_key=args.api_key, base_url=args.base_url, model=args.model) + elif phase == 4: + from .context7_docs import run + from .swebench_loader import load_instances + + run(load_instances()) + elif phase == 5: + from .format_builder import run + from .swebench_loader import load_instances + + run(load_instances()) + elif phase == 6: + from .hallucination_injector import run + from .splitter import select_hallucination_targets + from .swebench_loader import load_instances + + instances = load_instances() + formats = 
load_jsonl_dict(FORMATS_PATH) + queries = load_jsonl_dict(QUERIES_PATH, value_key="query") + targets = select_hallucination_targets(instances) + to_inject = [i for i in instances if i["instance_id"] in targets] + run( + to_inject, + formats, + queries, + api_key=args.api_key, + base_url=args.base_url, + model=args.model, + ) + elif phase == 7: + from .sample_assembler import run + from .splitter import select_hallucination_targets + from .swebench_loader import load_instances + + instances = load_instances() + queries = load_jsonl_dict(QUERIES_PATH, value_key="query") + docs = load_jsonl_dict(DOCS_PATH, value_key="docs") + formats = load_jsonl_dict(FORMATS_PATH) + hallucinations = load_jsonl_dict(HALLUCINATED_PATH) + targets = select_hallucination_targets(instances) + run(instances, queries, docs, formats, hallucinations, targets) + elif phase == 8: + from .splitter import run + from .swebench_loader import load_instances + + run(load_instances()) + elif phase == 9: + from .validator import run + + run() + + print("\nPipeline complete!") + + +if __name__ == "__main__": + main() diff --git a/scripts/code_hallucination/query_rewriter.py b/scripts/code_hallucination/query_rewriter.py new file mode 100644 index 0000000..966c70f --- /dev/null +++ b/scripts/code_hallucination/query_rewriter.py @@ -0,0 +1,141 @@ +"""Phase 3: Rewrite problem statements into natural user queries via LLM.""" + +import json +import textwrap +import time + +from openai import OpenAI + +from .config import ( + API_BASE_URL, + API_KEY, + LLM_TEMPERATURE, + MAX_RETRIES, + MODEL, + QUERIES_PATH, + RETRY_DELAY, +) + +REWRITE_SYSTEM_PROMPT = textwrap.dedent("""\ + You transform GitHub issue descriptions into realistic user queries + that a developer would type into an AI coding assistant (like Claude Code or Cursor). + + Rules: + - Make it conversational and natural + - Keep the core technical ask but remove GitHub formatting + - Remove reproduction steps, stack traces, verbose details + - Keep it to 1-3 sentences + - Don't mention "issue" or "bug report" + - Sound like someone asking for help, not filing a report +""") + + +def get_client(api_key: str = API_KEY, base_url: str = API_BASE_URL) -> OpenAI: + return OpenAI(api_key=api_key, base_url=base_url) + + +def llm_call( + client: OpenAI, + model: str, + system: str, + user: str, + temperature: float = LLM_TEMPERATURE, + max_tokens: int = 300, +) -> str: + """Make an LLM call with retries.""" + for attempt in range(MAX_RETRIES): + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=temperature, + max_tokens=max_tokens, + ) + return response.choices[0].message.content.strip() + except Exception as e: + if attempt < MAX_RETRIES - 1: + wait = RETRY_DELAY * (attempt + 1) + print(f" LLM error (attempt {attempt + 1}): {e}. 
Retrying in {wait}s...") + time.sleep(wait) + else: + raise + + +def rewrite_query(client: OpenAI, model: str, problem_statement: str, repo: str) -> str: + """Rewrite a problem statement into a natural user query.""" + user_msg = f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}" + return llm_call(client, model, REWRITE_SYSTEM_PROMPT, user_msg) + + +def load_existing_queries(path=QUERIES_PATH) -> dict[str, str]: + """Load already-processed queries for resumability.""" + existing = {} + if path.exists(): + with open(path) as f: + for line in f: + try: + entry = json.loads(line) + existing[entry["instance_id"]] = entry["query"] + except (json.JSONDecodeError, KeyError): + continue + return existing + + +def run( + instances: list[dict], + api_key: str = API_KEY, + base_url: str = API_BASE_URL, + model: str = MODEL, +): + """Run Phase 3: Rewrite all queries.""" + print("=" * 60) + print("Phase 3: Query Rewriting") + print("=" * 60) + + QUERIES_PATH.parent.mkdir(parents=True, exist_ok=True) + + client = get_client(api_key, base_url) + print(f"Using {base_url} with model {model}") + + # Load existing for resumability + existing = load_existing_queries() + print(f"Already processed: {len(existing)} queries") + + to_process = [inst for inst in instances if inst["instance_id"] not in existing] + print(f"Remaining: {len(to_process)} queries to process") + + processed = 0 + failed = 0 + + with open(QUERIES_PATH, "a") as f: + for i, inst in enumerate(to_process): + instance_id = inst["instance_id"] + repo = inst["repo"] + problem = inst["problem_statement"] + + try: + query = rewrite_query(client, model, problem, repo) + entry = {"instance_id": instance_id, "query": query} + f.write(json.dumps(entry) + "\n") + f.flush() + processed += 1 + + if processed % 50 == 0: + print(f" Progress: {processed}/{len(to_process)} (failed: {failed})") + except Exception as e: + print(f" ERROR {instance_id}: {e}") + failed += 1 + + print(f"\nDone: {processed} new queries, {failed} failed") + total = len(existing) + processed + print(f"Total queries: {total}") + + +if __name__ == "__main__": + from .swebench_loader import load_instances + + instances = load_instances() + run(instances) diff --git a/scripts/code_hallucination/sample_assembler.py b/scripts/code_hallucination/sample_assembler.py new file mode 100644 index 0000000..bdcbd61 --- /dev/null +++ b/scripts/code_hallucination/sample_assembler.py @@ -0,0 +1,182 @@ +"""Phase 7: Assemble final HallucinationSample format.""" + +import json + +from .config import DATASET_PATH, METADATA_PATH, SOURCE_CACHE_DIR + + +def build_prompt( + source_files: dict[str, str], + documentation: dict[str, str], + user_query: str, +) -> str: + """Build the prompt (context) for a sample. + + Format: source files + documentation + user query. + """ + parts = [] + + for filepath, content in source_files.items(): + parts.append(f"File: {filepath}\n```python\n{content}\n```") + + for lib, doc in documentation.items(): + parts.append(f"Documentation for {lib}:\n{doc}") + + parts.append(f"User request: {user_query}") + + return "\n\n".join(parts) + + +def assemble_samples( + instances: list[dict], + source_cache: dict[str, dict], + queries: dict[str, str], + docs: dict[str, dict], + formats: dict[str, dict], + hallucinations: dict[str, dict], + hallucination_instance_ids: set[str], +) -> tuple[list[dict], list[dict]]: + """Assemble all samples into HallucinationSample format. + + Returns (samples, metadata) where each instance maps to exactly 1 sample. 
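+    Instances selected for injection that lack a usable hallucination entry
+    (injection failed or produced no matchable spans) fall through to the
+    clean branch rather than being dropped.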
+ """ + samples = [] + metadata = [] + + for inst in instances: + instance_id = inst["instance_id"] + split = inst["split"] + repo = inst["repo"] + + # Skip if no source data or format + if instance_id not in source_cache or instance_id not in formats: + continue + + source_data = source_cache[instance_id] + fmt_data = formats[instance_id] + query = queries.get(instance_id, inst.get("problem_statement", "")[:500]) + doc = docs.get(instance_id, {}) + + # Build prompt from source files + docs + query + source_files = source_data.get("source_files", {}) + prompt = build_prompt(source_files, doc, query) + + if instance_id in hallucination_instance_ids and instance_id in hallucinations: + # Hallucinated sample + hall_data = hallucinations[instance_id] + sample = { + "prompt": prompt, + "answer": hall_data["hallucinated_answer"], + "labels": hall_data["labels"], + "split": split, + "task_type": "code_generation", + "dataset": "swebench_code", + "language": "en", + } + meta = { + "instance_id": instance_id, + "repo": repo, + "split": split, + "is_lite": inst.get("is_lite", False), + "format_type": hall_data.get("format_type", fmt_data.get("format_type")), + "hallucination_type": hall_data.get("hallucination_type"), + "injector_model": hall_data.get("injector_model"), + "is_hallucinated": True, + } + else: + # Clean sample + answer = fmt_data.get("answer", "") + if not answer.strip(): + continue + + sample = { + "prompt": prompt, + "answer": answer, + "labels": [], + "split": split, + "task_type": "code_generation", + "dataset": "swebench_code", + "language": "en", + } + meta = { + "instance_id": instance_id, + "repo": repo, + "split": split, + "is_lite": inst.get("is_lite", False), + "format_type": fmt_data.get("format_type"), + "hallucination_type": None, + "injector_model": None, + "is_hallucinated": False, + } + + samples.append(sample) + metadata.append(meta) + + return samples, metadata + + +def run( + instances: list[dict], + queries: dict[str, str], + docs: dict[str, dict], + formats: dict[str, dict], + hallucinations: dict[str, dict], + hallucination_instance_ids: set[str], +): + """Run Phase 7: Assemble all samples.""" + print("=" * 60) + print("Phase 7: Sample Assembly") + print("=" * 60) + + # Load source cache + source_cache = {} + for inst in instances: + cache_path = SOURCE_CACHE_DIR / f"{inst['instance_id']}.json" + if cache_path.exists(): + with open(cache_path) as f: + source_cache[inst["instance_id"]] = json.load(f) + + print(f"Source cache: {len(source_cache)} instances") + print(f"Queries: {len(queries)}") + print(f"Docs: {len(docs)}") + print(f"Formats: {len(formats)}") + print(f"Hallucinations: {len(hallucinations)}") + print(f"Hallucination targets: {len(hallucination_instance_ids)}") + + samples, metadata = assemble_samples( + instances, + source_cache, + queries, + docs, + formats, + hallucinations, + hallucination_instance_ids, + ) + + # Save + DATASET_PATH.parent.mkdir(parents=True, exist_ok=True) + + with open(DATASET_PATH, "w") as f: + json.dump(samples, f, indent=2) + + with open(METADATA_PATH, "w") as f: + json.dump(metadata, f, indent=2) + + # Stats + n_clean = sum(1 for s in samples if not s["labels"]) + n_hall = sum(1 for s in samples if s["labels"]) + print(f"\nTotal samples: {len(samples)}") + print(f" Clean: {n_clean} ({n_clean * 100 // max(len(samples), 1)}%)") + print(f" Hallucinated: {n_hall} ({n_hall * 100 // max(len(samples), 1)}%)") + + split_counts = {} + for s in samples: + split_counts[s["split"]] = split_counts.get(s["split"], 0) + 1 + for split, 
count in sorted(split_counts.items()): + print(f" {split}: {count}") + + return samples, metadata + + +if __name__ == "__main__": + print("Run via pipeline.py") diff --git a/scripts/code_hallucination/source_fetcher.py b/scripts/code_hallucination/source_fetcher.py new file mode 100644 index 0000000..fa0c3ac --- /dev/null +++ b/scripts/code_hallucination/source_fetcher.py @@ -0,0 +1,460 @@ +"""Phase 2: Clone repos and fetch source files via git show.""" + +import ast +import json +import re +import subprocess +import tempfile +from pathlib import Path + +import requests + +from .config import MAX_FILE_CHARS, REPOS_DIR, SOURCE_CACHE_DIR + +GITHUB_RAW_BASE = "https://raw.githubusercontent.com" + + +def extract_changed_files(patch: str) -> list[str]: + """Extract file paths from a unified diff using anchored regex. + + Fixed: Uses re.match on 'diff --git' lines instead of re.search(r"b/(.+)$") + which was ambiguous on paths containing 'b/'. + """ + files = [] + for line in patch.split("\n"): + if line.startswith("diff --git"): + match = re.match(r"diff --git a/(.+?) b/(.+)$", line) + if match: + files.append(match.group(2)) + return files + + +def clone_repo(repo: str, repos_dir: Path = REPOS_DIR) -> Path | None: + """Clone a repo if not already present. Returns repo directory path.""" + repos_dir.mkdir(parents=True, exist_ok=True) + repo_dir = repos_dir / repo.replace("/", "__") + + if repo_dir.exists(): + return repo_dir + + print(f" Cloning {repo}...") + try: + result = subprocess.run( + ["git", "clone", "--bare", f"https://github.com/{repo}.git", str(repo_dir)], + capture_output=True, + text=True, + timeout=1800, # 30 min for large repos + ) + if result.returncode != 0: + print(f" ERROR cloning {repo}: {result.stderr[:200]}") + return None + except subprocess.TimeoutExpired: + print(f" TIMEOUT cloning {repo} (>30 min)") + return None + + return repo_dir + + +def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None: + """Fallback: fetch a file from GitHub raw API (for when repo isn't cloned).""" + url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}" + try: + r = requests.get(url, timeout=15) + if r.status_code == 200: + return r.text[:MAX_FILE_CHARS] + return None + except Exception: + return None + + +def fetch_file_at_commit(repo_dir: Path, commit: str, filepath: str) -> str | None: + """Fetch a file's contents at a specific commit using git show.""" + try: + result = subprocess.run( + ["git", "show", f"{commit}:{filepath}"], + cwd=str(repo_dir), + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + return result.stdout[:MAX_FILE_CHARS] + return None + except (subprocess.TimeoutExpired, Exception) as e: + print(f" Error fetching {filepath}@{commit[:8]}: {e}") + return None + + +def apply_patch_and_get_file(repo_dir: Path, commit: str, patch: str, filepath: str) -> str | None: + """Apply a patch to get the post-fix version of a file. + + Uses a temporary worktree to apply the patch cleanly. 
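+    Returns None if the worktree cannot be created or the patch does not apply.
+    The temporary directory is deleted on exit; stale worktree metadata left in
+    the bare repo can be cleared with `git worktree prune`.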
+ """ + try: + with tempfile.TemporaryDirectory() as tmpdir: + # Create a temporary worktree at the base commit + result = subprocess.run( + ["git", "worktree", "add", tmpdir, commit], + cwd=str(repo_dir), + capture_output=True, + text=True, + timeout=60, + ) + if result.returncode != 0: + return None + + # Apply the patch + result = subprocess.run( + ["git", "apply", "--allow-empty", "-"], + input=patch, + cwd=tmpdir, + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + return None + + # Read the patched file + patched_path = Path(tmpdir) / filepath + if patched_path.exists(): + content = patched_path.read_text()[:MAX_FILE_CHARS] + return content + + # Clean up worktree + subprocess.run( + ["git", "worktree", "remove", "--force", tmpdir], + cwd=str(repo_dir), + capture_output=True, + timeout=30, + ) + return None + except Exception as e: + print(f" Error applying patch for {filepath}: {e}") + return None + + +def extract_modified_functions(original_source: str, patched_source: str) -> list[dict]: + """Extract functions that were modified between original and patched source. + + Returns list of dicts with 'name', 'original', 'patched' function bodies. + """ + + def get_functions(source: str) -> dict[str, str]: + """Parse source and extract function name -> source mapping.""" + try: + tree = ast.parse(source) + except SyntaxError: + return {} + + funcs = {} + lines = source.split("\n") + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + start = node.lineno - 1 + end = node.end_lineno + func_source = "\n".join(lines[start:end]) + funcs[node.name] = func_source + return funcs + + orig_funcs = get_functions(original_source) + patched_funcs = get_functions(patched_source) + + modified = [] + for name, patched_body in patched_funcs.items(): + orig_body = orig_funcs.get(name) + if orig_body and orig_body != patched_body: + modified.append( + { + "name": name, + "original": orig_body, + "patched": patched_body, + } + ) + elif not orig_body: + # New function added by patch + modified.append( + { + "name": name, + "original": None, + "patched": patched_body, + } + ) + + return modified + + +def extract_code_from_patch(patch: str) -> str: + """Extract added/changed lines from a unified diff as code fragment. + + Fixed: Uses removeprefix("b/") instead of lstrip("b/"). + """ + added_lines = [] + in_hunk = False + + for line in patch.split("\n"): + if line.startswith("@@"): + in_hunk = True + continue + if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"): + continue + if in_hunk: + if line.startswith("+"): + added_lines.append(line[1:]) + elif line.startswith("-"): + continue + else: + if line.startswith(" "): + added_lines.append(line[1:]) + else: + added_lines.append(line) + + return "\n".join(added_lines) + + +def build_edit_style_answer(patch: str, changed_files: list[str]) -> str | None: + """Build an edit-style answer from a patch. + + Format: + In file path/to/file.py, replace: + ```python + old code + ``` + with: + ```python + new code + ``` + """ + edits = [] + + current_file = None + old_lines = [] + new_lines = [] + in_hunk = False + + for line in patch.split("\n"): + if line.startswith("diff --git"): + # Flush previous hunk + if current_file and (old_lines or new_lines): + edits.append( + { + "file": current_file, + "old": "\n".join(old_lines), + "new": "\n".join(new_lines), + } + ) + old_lines = [] + new_lines = [] + match = re.match(r"diff --git a/(.+?) 
b/(.+)$", line) + if match: + current_file = match.group(2) + in_hunk = False + continue + + if line.startswith("@@"): + # Flush previous hunk in same file + if old_lines or new_lines: + edits.append( + { + "file": current_file, + "old": "\n".join(old_lines), + "new": "\n".join(new_lines), + } + ) + old_lines = [] + new_lines = [] + in_hunk = True + continue + + if line.startswith("---") or line.startswith("+++"): + continue + + if in_hunk: + if line.startswith("-"): + old_lines.append(line[1:]) + elif line.startswith("+"): + new_lines.append(line[1:]) + # Context lines are skipped in edit-style + + # Flush last hunk + if current_file and (old_lines or new_lines): + edits.append( + { + "file": current_file, + "old": "\n".join(old_lines), + "new": "\n".join(new_lines), + } + ) + + if not edits: + return None + + parts = [] + for edit in edits: + if edit["old"] and edit["new"]: + parts.append( + f"In file {edit['file']}, replace:\n```python\n{edit['old']}\n```\nwith:\n```python\n{edit['new']}\n```" + ) + elif edit["new"]: + parts.append(f"In file {edit['file']}, add:\n```python\n{edit['new']}\n```") + + return "\n\n".join(parts) if parts else None + + +def fetch_source_for_instance( + instance: dict, repos_dir: Path = REPOS_DIR, use_github_api: bool = False +) -> dict | None: + """Fetch all source files for a single SWE-bench instance. + + Args: + instance: SWE-bench instance dict + repos_dir: Directory for cloned repos + use_github_api: If True, use GitHub raw API instead of git clone + + Returns dict with: + instance_id, changed_files, source_files, patch_code, edit_style, + modified_functions + Or None if fetching failed. + + """ + repo = instance["repo"] + commit = instance["base_commit"] + patch = instance["patch"] + + # Extract changed files from patch + changed_files = extract_changed_files(patch) + if not changed_files: + return None + + # Fetch source files + source_files = {} + repo_dir = None + + if use_github_api: + # Use GitHub raw API (slower but no cloning needed) + import time + + for filepath in changed_files: + content = fetch_file_from_github(repo, commit, filepath) + if content: + source_files[filepath] = content + time.sleep(0.3) + else: + # Clone repo and use git show (fast for bulk) + repo_dir = clone_repo(repo, repos_dir) + if repo_dir is None: + # Fallback to GitHub API + import time + + print(f" Falling back to GitHub API for {repo}") + for filepath in changed_files: + content = fetch_file_from_github(repo, commit, filepath) + if content: + source_files[filepath] = content + time.sleep(0.3) + else: + for filepath in changed_files: + content = fetch_file_at_commit(repo_dir, commit, filepath) + if content: + source_files[filepath] = content + + if not source_files: + return None + + # Build different answer formats + # Fragment format + patch_code = extract_code_from_patch(patch) + + # Edit-style format + edit_style = build_edit_style_answer(patch, changed_files) + + # Complete function format (needs patched source via git apply) + modified_functions = [] + if repo_dir is not None: + for filepath in changed_files: + if filepath in source_files: + patched_source = apply_patch_and_get_file(repo_dir, commit, patch, filepath) + if patched_source: + funcs = extract_modified_functions(source_files[filepath], patched_source) + for func in funcs: + func["file"] = filepath + modified_functions.extend(funcs) + + return { + "instance_id": instance["instance_id"], + "changed_files": changed_files, + "source_files": source_files, + "patch_code": patch_code, + "edit_style": 
edit_style, + "modified_functions": modified_functions, + } + + +def run(instances: list[dict], use_github_api: bool = False): + """Run Phase 2: Fetch source files for all instances. + + Args: + instances: List of SWE-bench instances + use_github_api: If True, use GitHub raw API instead of cloning repos + + """ + print("=" * 60) + print("Phase 2: Source File Fetching") + print("=" * 60) + + SOURCE_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + if not use_github_api: + REPOS_DIR.mkdir(parents=True, exist_ok=True) + # Group by repo for efficient cloning + repos = set(inst["repo"] for inst in instances) + print(f"Need to clone {len(repos)} repos") + for repo in sorted(repos): + clone_repo(repo) + else: + print("Using GitHub raw API (no cloning)") + + # Fetch sources per instance + results = [] + failed = 0 + + for i, instance in enumerate(instances): + if (i + 1) % 100 == 0: + print(f" Progress: {i + 1}/{len(instances)} ({len(results)} success, {failed} failed)") + + # Skip if already cached + cache_path = SOURCE_CACHE_DIR / f"{instance['instance_id']}.json" + if cache_path.exists(): + with open(cache_path) as f: + results.append(json.load(f)) + continue + + result = fetch_source_for_instance(instance, use_github_api=use_github_api) + if result: + results.append(result) + # Cache result + cache_path = SOURCE_CACHE_DIR / f"{instance['instance_id']}.json" + with open(cache_path, "w") as f: + json.dump(result, f) + else: + failed += 1 + + print(f"\nDone: {len(results)} success, {failed} failed out of {len(instances)}") + + # Report format availability + n_fragment = sum(1 for r in results if r.get("patch_code", "").strip()) + n_edit = sum(1 for r in results if r.get("edit_style")) + n_function = sum(1 for r in results if r.get("modified_functions")) + print("Format availability:") + print(f" Fragment: {n_fragment}") + print(f" Edit-style: {n_edit}") + print(f" Complete function: {n_function}") + + return results + + +if __name__ == "__main__": + from .swebench_loader import load_instances + + instances = load_instances() + run(instances) diff --git a/scripts/code_hallucination/splitter.py b/scripts/code_hallucination/splitter.py new file mode 100644 index 0000000..a98c70a --- /dev/null +++ b/scripts/code_hallucination/splitter.py @@ -0,0 +1,64 @@ +"""Phase 8: Train/dev/test split by SWE-bench splits. + +Since we use SWE-bench splits directly, this phase mostly validates +the splits and selects which instances get hallucination injection. +""" + +import random + +from .config import HALLUCINATION_RATIO + + +def select_hallucination_targets( + instances: list[dict], + ratio: float = HALLUCINATION_RATIO, + seed: int = 42, +) -> set[str]: + """Select which instances get hallucination injection. + + Applies the ratio uniformly within each split to maintain + consistent class distribution across train/dev/test. + + Returns set of instance_ids that should be hallucinated. 
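+    The fixed seed keeps selection deterministic, so later phases (injection
+    and assembly) can independently recompute the same target set.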
+ """ + rng = random.Random(seed) + targets = set() + + # Group by split + by_split = {} + for inst in instances: + split = inst["split"] + if split not in by_split: + by_split[split] = [] + by_split[split].append(inst) + + for split, split_instances in by_split.items(): + n_hall = int(len(split_instances) * ratio) + rng.shuffle(split_instances) + for inst in split_instances[:n_hall]: + targets.add(inst["instance_id"]) + + n_clean = len(split_instances) - n_hall + print(f" {split}: {n_hall} hallucinated + {n_clean} clean = {len(split_instances)}") + + return targets + + +def run(instances: list[dict]) -> set[str]: + """Run Phase 8: Select hallucination targets.""" + print("=" * 60) + print("Phase 8: Split & Target Selection") + print("=" * 60) + + targets = select_hallucination_targets(instances) + print(f"\nTotal hallucination targets: {len(targets)} out of {len(instances)}") + print(f"Ratio: {len(targets) / max(len(instances), 1):.1%}") + + return targets + + +if __name__ == "__main__": + from .swebench_loader import load_instances + + instances = load_instances() + run(instances) diff --git a/scripts/code_hallucination/swebench_loader.py b/scripts/code_hallucination/swebench_loader.py new file mode 100644 index 0000000..523fd19 --- /dev/null +++ b/scripts/code_hallucination/swebench_loader.py @@ -0,0 +1,93 @@ +"""Phase 1: Load SWE-bench instances from all splits.""" + +import json + +from datasets import load_dataset + +from .config import DATA_DIR, INSTANCES_PATH, SWEBENCH_FULL, SWEBENCH_LITE + + +def load_all_splits() -> list[dict]: + """Load all SWE-bench splits and tag each instance. + + Returns list of dicts with fields: + instance_id, repo, base_commit, patch, test_patch, + problem_statement, hints_text, created_at, version, + FAIL_TO_PASS, PASS_TO_PASS, environment_setup_commit, + split, is_lite + """ + # Load Lite instance IDs for tagging + print("Loading SWE-bench Lite...") + lite_ds = load_dataset(SWEBENCH_LITE, split="test") + lite_ids = {row["instance_id"] for row in lite_ds} + print(f" Lite: {len(lite_ids)} instances") + + all_instances = [] + + for split_name in ["train", "dev", "test"]: + print(f"Loading SWE-bench {split_name} split...") + ds = load_dataset(SWEBENCH_FULL, split=split_name) + print(f" {split_name}: {len(ds)} instances") + + for row in ds: + instance = dict(row) + instance["split"] = split_name + instance["is_lite"] = instance["instance_id"] in lite_ids + all_instances.append(instance) + + # Report stats + repos_by_split = {} + for inst in all_instances: + split = inst["split"] + repo = inst["repo"] + if split not in repos_by_split: + repos_by_split[split] = set() + repos_by_split[split].add(repo) + + print(f"\nTotal: {len(all_instances)} instances") + for split, repos in repos_by_split.items(): + count = sum(1 for i in all_instances if i["split"] == split) + print(f" {split}: {count} instances across {len(repos)} repos") + + n_lite = sum(1 for i in all_instances if i["is_lite"]) + print(f" Lite-tagged: {n_lite}") + + # Verify zero repo overlap + all_splits = list(repos_by_split.keys()) + for i, s1 in enumerate(all_splits): + for s2 in all_splits[i + 1 :]: + overlap = repos_by_split[s1] & repos_by_split[s2] + if overlap: + print(f" WARNING: Repo overlap between {s1} and {s2}: {overlap}") + else: + print(f" {s1} ∩ {s2}: 0 repo overlap βœ“") + + return all_instances + + +def save_instances(instances: list[dict], path=INSTANCES_PATH): + """Save instances to JSON.""" + DATA_DIR.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + 
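+        # indent=2 trades file size for a human-readable intermediate snapshot.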
json.dump(instances, f, indent=2) + print(f"Saved {len(instances)} instances to {path}") + + +def load_instances(path=INSTANCES_PATH) -> list[dict]: + """Load previously saved instances.""" + with open(path) as f: + return json.load(f) + + +def run(): + """Run Phase 1: Load and save all SWE-bench instances.""" + print("=" * 60) + print("Phase 1: Load SWE-bench") + print("=" * 60) + instances = load_all_splits() + save_instances(instances) + return instances + + +if __name__ == "__main__": + run() diff --git a/scripts/code_hallucination/validator.py b/scripts/code_hallucination/validator.py new file mode 100644 index 0000000..50d5251 --- /dev/null +++ b/scripts/code_hallucination/validator.py @@ -0,0 +1,200 @@ +"""Phase 9: Quality checks and validation report.""" + +import ast +import json +from collections import Counter + +from .config import DATASET_PATH, METADATA_PATH, VALIDATION_REPORT_PATH + + +def validate_spans(samples: list[dict]) -> list[str]: + """Check span boundary validity.""" + issues = [] + for i, sample in enumerate(samples): + answer_len = len(sample["answer"]) + for label in sample.get("labels", []): + start = label.get("start", 0) + end = label.get("end", 0) + if start < 0 or end < 0: + issues.append(f"Sample {i}: negative span offset ({start}, {end})") + if end <= start: + issues.append(f"Sample {i}: empty/inverted span ({start}, {end})") + if end > answer_len: + issues.append(f"Sample {i}: span exceeds answer length ({end} > {answer_len})") + return issues + + +def check_span_coverage(samples: list[dict]) -> dict: + """Report span coverage distribution for hallucinated samples.""" + coverages = [] + for sample in samples: + if not sample.get("labels"): + continue + answer_len = len(sample["answer"]) + if answer_len == 0: + continue + total_span = sum(label["end"] - label["start"] for label in sample["labels"]) + coverage = total_span / answer_len + coverages.append(coverage) + + if not coverages: + return {"count": 0} + + return { + "count": len(coverages), + "min": min(coverages), + "max": max(coverages), + "mean": sum(coverages) / len(coverages), + "low_coverage": sum(1 for c in coverages if c < 0.02), + "high_coverage": sum(1 for c in coverages if c > 0.80), + } + + +def check_distributions(metadata: list[dict]) -> dict: + """Report format/type/LLM/repo distributions.""" + format_counts = Counter(m.get("format_type") for m in metadata if m.get("format_type")) + type_counts = Counter( + m.get("hallucination_type") for m in metadata if m.get("hallucination_type") + ) + model_counts = Counter(m.get("injector_model") for m in metadata if m.get("injector_model")) + repo_counts = Counter(m.get("repo") for m in metadata) + split_counts = Counter(m.get("split") for m in metadata) + + return { + "format": dict(format_counts), + "hallucination_type": dict(type_counts), + "injector_model": dict(model_counts), + "repos": len(repo_counts), + "top_repos": dict(repo_counts.most_common(10)), + "split": dict(split_counts), + } + + +def check_near_duplicates(samples: list[dict], threshold: float = 0.95) -> int: + """Simple near-duplicate check via answer Jaccard similarity (sampled).""" + import random + + rng = random.Random(42) + + n = min(500, len(samples)) + sample_indices = rng.sample(range(len(samples)), n) + + duplicates = 0 + for i in range(len(sample_indices)): + for j in range(i + 1, min(i + 5, len(sample_indices))): + a = set(samples[sample_indices[i]]["answer"].split()) + b = set(samples[sample_indices[j]]["answer"].split()) + if not a or not b: + continue + jaccard = 
len(a & b) / len(a | b) + if jaccard > threshold: + duplicates += 1 + + return duplicates + + +def check_ast_parseability(samples: list[dict], metadata: list[dict]) -> dict: + """Check AST parseability for complete_function format samples.""" + total = 0 + parseable = 0 + + for sample, meta in zip(samples, metadata): + if meta.get("format_type") != "complete_function": + continue + total += 1 + try: + ast.parse(sample["answer"]) + parseable += 1 + except SyntaxError: + pass + + return { + "total": total, + "parseable": parseable, + "rate": parseable / max(total, 1), + } + + +def run(samples: list[dict] = None, metadata: list[dict] = None): + """Run Phase 9: Validation.""" + print("=" * 60) + print("Phase 9: Validation") + print("=" * 60) + + if samples is None: + with open(DATASET_PATH) as f: + samples = json.load(f) + if metadata is None: + with open(METADATA_PATH) as f: + metadata = json.load(f) + + report_lines = [] + + def report(text): + print(text) + report_lines.append(text) + + report(f"Total samples: {len(samples)}") + n_clean = sum(1 for s in samples if not s["labels"]) + n_hall = sum(1 for s in samples if s["labels"]) + report(f"Clean: {n_clean}, Hallucinated: {n_hall}") + report("") + + # 1. Span validity + report("=== Span Validity ===") + span_issues = validate_spans(samples) + report(f"Issues found: {len(span_issues)}") + for issue in span_issues[:10]: + report(f" {issue}") + report("") + + # 2. Span coverage + report("=== Span Coverage ===") + coverage = check_span_coverage(samples) + for k, v in coverage.items(): + report(f" {k}: {v}") + report("") + + # 3. Distributions + report("=== Distributions ===") + dists = check_distributions(metadata) + for k, v in dists.items(): + report(f" {k}: {v}") + report("") + + # 4. Near duplicates + report("=== Near Duplicates ===") + n_dup = check_near_duplicates(samples) + report(f"Near duplicates (sampled): {n_dup}") + report("") + + # 5. AST parseability + report("=== AST Parseability ===") + ast_check = check_ast_parseability(samples, metadata) + for k, v in ast_check.items(): + report(f" {k}: {v}") + report("") + + # 6. Length stats + report("=== Length Statistics ===") + prompt_lens = [len(s["prompt"]) for s in samples] + answer_lens = [len(s["answer"]) for s in samples] + if prompt_lens: + report( + f" Prompt chars - min: {min(prompt_lens):,}, max: {max(prompt_lens):,}, avg: {sum(prompt_lens) // len(prompt_lens):,}" + ) + report( + f" Answer chars - min: {min(answer_lens):,}, max: {max(answer_lens):,}, avg: {sum(answer_lens) // len(answer_lens):,}" + ) + + # Save report + VALIDATION_REPORT_PATH.parent.mkdir(parents=True, exist_ok=True) + with open(VALIDATION_REPORT_PATH, "w") as f: + f.write("\n".join(report_lines)) + + print(f"\nReport saved to {VALIDATION_REPORT_PATH}") + return report_lines + + +if __name__ == "__main__": + run() diff --git a/scripts/enrich_code_hallucination_dataset.py b/scripts/enrich_code_hallucination_dataset.py new file mode 100644 index 0000000..0ceb5a2 --- /dev/null +++ b/scripts/enrich_code_hallucination_dataset.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +"""Enrich code hallucination dataset with actual source file contents from GitHub. + +Takes the raw dataset and: +1. Fetches actual source files from GitHub at the base commit +2. Builds proper context (source files + docs + query) +3. Converts diff-format answers to actual code +4. 
Outputs in exact LettuceDetect HallucinationSample format +""" + +import json +import os +import time + +import requests + +INPUT_PATH = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json" +OUTPUT_PATH = ( + "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect_v2.json" +) +GITHUB_RAW_BASE = "https://raw.githubusercontent.com" +MAX_FILE_CHARS = 12000 # Cap individual file size to avoid blowing up context + + +def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None: + """Fetch a file's contents from GitHub at a specific commit.""" + url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}" + try: + r = requests.get(url, timeout=15) + if r.status_code == 200: + return r.text[:MAX_FILE_CHARS] + return None + except Exception as e: + print(f" Error fetching {filepath}: {e}") + return None + + +def extract_code_from_patch(patch: str) -> str: + """Extract just the added lines from a unified diff as the 'answer' code. + + Returns the new code (added lines) that represents what was generated. + """ + added_lines = [] + in_hunk = False + + for line in patch.split("\n"): + if line.startswith("@@"): + in_hunk = True + continue + if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"): + continue + if in_hunk: + if line.startswith("+"): + added_lines.append(line[1:]) # Remove the '+' prefix + elif line.startswith("-"): + continue # Skip removed lines + else: + # Context line - include to maintain readability + if line.startswith(" "): + added_lines.append(line[1:]) + else: + added_lines.append(line) + + return "\n".join(added_lines) + + +def build_prompt( + source_files: dict[str, str], + documentation: dict[str, str], + user_query: str, +) -> str: + """Build the prompt (context) in the same style as RAGTruth. + + Format: all context concatenated, similar to how RAGTruth provides + the source prompt that was given to the LLM. 
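+
+    Illustrative layout of the returned string (file name, doc excerpt, and
+    query below are hypothetical):
+
+        File: django/db/models/fields/__init__.py
+        ```python
+        <file contents>
+        ```
+
+        Documentation for django:
+        <doc excerpt>
+
+        User request: Can you fix the max_length validation?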
+    """
+    parts = []
+
+    # Source code files (the main context)
+    for filepath, content in source_files.items():
+        parts.append(f"File: {filepath}\n```python\n{content}\n```")
+
+    # Documentation
+    for lib, doc in documentation.items():
+        parts.append(f"Documentation for {lib}:\n{doc}")
+
+    # User query (acts as the "question" / instruction)
+    parts.append(f"User request: {user_query}")
+
+    return "\n\n".join(parts)
+
+
+def compute_span_offsets_in_code(code: str, hallucinated_span: str) -> list[dict]:
+    """Find character offsets of the first occurrence of a hallucinated span in the answer code."""
+    idx = code.find(hallucinated_span)
+    if idx == -1:
+        return []
+    return [{"start": idx, "end": idx + len(hallucinated_span)}]
+
+
+def process_sample(sample: dict, idx: int, total: int) -> list[dict]:
+    """Process one raw sample into LettuceDetect format samples."""
+    instance_id = sample["instance_id"]
+    repo = sample["repo"]
+    commit = sample["base_commit"]
+    changed_files = sample["changed_files"]
+
+    print(f"[{idx + 1}/{total}] {instance_id}")
+
+    # Step 1: Fetch actual source files from GitHub
+    source_files = {}
+    for filepath in changed_files:
+        # Some paths in SWE-bench have the 'a/' or weird format from diff parsing
+        # Clean up: "models/fields/__init__.py b/django/db/models/fields/__init__.py"
+        # Take the last valid-looking path
+        clean_paths = [
+            p.strip() for p in filepath.split(" ") if "/" in p and not p.startswith("a/")
+        ]
+        if not clean_paths:
+            clean_paths = [filepath]
+
+        for path in clean_paths:
+            # removeprefix, not lstrip: lstrip("b/") strips every leading 'b' or '/'
+            # character, which would mangle paths like "builtins/x.py"
+            path = path.removeprefix("b/")
+            content = fetch_file_from_github(repo, commit, path)
+            if content:
+                source_files[path] = content
+                print(f"  Fetched {path}: {len(content)} chars")
+            else:
+                print(f"  Failed to fetch {path}")
+            time.sleep(0.3)  # Rate limit
+
+    if not source_files:
+        print("  SKIP: No source files fetched")
+        return []
+
+    # Step 2: Build the prompt (context)
+    documentation = sample.get("documentation", {})
+    user_query = sample.get("user_query", "")
+    prompt = build_prompt(source_files, documentation, user_query)
+
+    # Step 3: Create correct (negative) sample
+    gold_code = extract_code_from_patch(sample["gold_patch"])
+    samples_out = []
+
+    if gold_code.strip():
+        samples_out.append(
+            {
+                "prompt": prompt,
+                "answer": gold_code,
+                "labels": [],
+                "split": "train",
+                "task_type": "code_generation",
+                "dataset": "code_hallucination_swebench",
+                "language": "en",
+            }
+        )
+
+    # Step 4: Create hallucinated (positive) samples
+    for hall in sample.get("hallucinations", []):
+        if not isinstance(hall, dict) or "hallucinated_patch" not in hall:
+            continue
+
+        hall_code = extract_code_from_patch(hall["hallucinated_patch"])
+        if not hall_code.strip():
+            continue
+
+        # Compute span labels in the answer code
+        labels = []
+        for change in hall.get("changes", []):
+            h_span = change.get("hallucinated", "")
+            if h_span and h_span in hall_code:
+                offsets = compute_span_offsets_in_code(hall_code, h_span)
+                for offset in offsets:
+                    labels.append(
+                        {
+                            "start": offset["start"],
+                            "end": offset["end"],
+                            "label": hall.get("type", "hallucinated"),
+                        }
+                    )
+
+        if labels:
+            samples_out.append(
+                {
+                    "prompt": prompt,
+                    "answer": hall_code,
+                    "labels": labels,
+                    "split": "train",
+                    "task_type": "code_generation",
+                    "dataset": "code_hallucination_swebench",
+                    "language": "en",
+                }
+            )
+
+    n_correct = sum(1 for s in samples_out if not s["labels"])
+    print(
+        f"  Generated {len(samples_out)} samples ({n_correct} correct + {len(samples_out) - n_correct} hallucinated)"
+    )
+    return samples_out
+
+
+def main():
+    print("=" * 60)
+    print("Enriching Code Hallucination Dataset")
+    print("Fetching source files + converting to LettuceDetect format")
+    print("=" * 60)
+
+    with open(INPUT_PATH) as f:
+        raw_data = json.load(f)
+
+    print(f"Loaded {len(raw_data)} raw samples\n")
+
+    all_samples = []
+    for i, sample in enumerate(raw_data):
+        new_samples = process_sample(sample, i, len(raw_data))
+        all_samples.extend(new_samples)
+
+    # Save
+    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
+    with open(OUTPUT_PATH, "w") as f:
+        json.dump(all_samples, f, indent=2)
+
+    # Stats
+    print("\n" + "=" * 60)
+    print("FINAL DATASET")
+    print("=" * 60)
+    print(f"Total samples: {len(all_samples)}")
+    n_correct = sum(1 for s in all_samples if not s["labels"])
+    n_hall = sum(1 for s in all_samples if s["labels"])
+    print(f"Correct (negative): {n_correct}")
+    print(f"Hallucinated (positive): {n_hall}")
+
+    prompt_lens = [len(s["prompt"]) for s in all_samples]
+    answer_lens = [len(s["answer"]) for s in all_samples]
+    total_lens = [p + a for p, a in zip(prompt_lens, answer_lens)]
+
+    print(
+        f"\nPrompt chars - min: {min(prompt_lens):,}, max: {max(prompt_lens):,}, avg: {sum(prompt_lens) // len(prompt_lens):,}"
+    )
+    print(
+        f"Answer chars - min: {min(answer_lens):,}, max: {max(answer_lens):,}, avg: {sum(answer_lens) // len(answer_lens):,}"
+    )
+    print(
+        f"Total chars - min: {min(total_lens):,}, max: {max(total_lens):,}, avg: {sum(total_lens) // len(total_lens):,}"
+    )
+
+    est_tokens = [t // 4 for t in total_lens]
+    print(
+        f"\nEst. tokens - min: {min(est_tokens):,}, max: {max(est_tokens):,}, avg: {sum(est_tokens) // len(est_tokens):,}"
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/evaluate_code_hallucination.py b/scripts/evaluate_code_hallucination.py
new file mode 100644
index 0000000..87a68fd
--- /dev/null
+++ b/scripts/evaluate_code_hallucination.py
@@ -0,0 +1,296 @@
+#!/usr/bin/env python3
+"""Evaluate an LLM baseline on the code hallucination dataset.
+
+Uses LettuceDetect's existing LLMDetector and evaluator infrastructure.
+Supports the Groq API with any OpenAI-compatible model.
+
+Usage:
+    # With Groq + Kimi
+    OPENAI_API_KEY=gsk_... OPENAI_API_BASE=https://api.groq.com/openai/v1 \
+    python scripts/evaluate_code_hallucination.py \
+        --model moonshotai/kimi-k2-instruct-0905 \
+        --data_path data/code_hallucination_lettucedetect_v2.json \
+        --evaluation_type example_level
+"""
+
+import argparse
+import json
+import os
+import re
+import time
+from pathlib import Path
+from string import Template
+
+from openai import OpenAI
+from sklearn.metrics import auc, classification_report, precision_recall_fscore_support, roc_curve
+from tqdm import tqdm
+
+from lettucedetect.datasets.hallucination_dataset import HallucinationSample
+
+# Simpler prompt for code hallucination detection (no structured output dependency)
+CODE_HALLUCINATION_PROMPT = Template("""
+You are an expert code reviewer who must identify hallucinated code spans.
+
+A "code hallucination" is any part of the generated code that:
+(a) Uses APIs, methods, functions, classes, or parameters that do NOT exist in the codebase context provided
+(b) Contradicts the documentation or codebase shown in the source
+(c) Does something semantically different from what the user requested
+
+## Instructions
+1. Read the generated code inside <answer>...</answer>.
+2. Compare it against the source context in <context>...</context>.
+3. Identify any code spans that are hallucinated.
+4. Return a JSON object with this exact format (no markdown, no code blocks):
+   {"hallucination_list": ["exact_span_1", "exact_span_2", ...]}
+   If no hallucinations found, return: {"hallucination_list": []}
+
+IMPORTANT: Each item in hallucination_list must be an EXACT substring from the answer.
+
+<context>
+${context}
+</context>
+
+<answer>
+${answer}
+</answer>
+""")
+
+
+def predict_with_llm(
+    client: OpenAI,
+    model: str,
+    prompt: str,
+    answer: str,
+    temperature: float = 0.0,
+) -> list[dict]:
+    """Predict hallucination spans using an LLM."""
+    llm_prompt = CODE_HALLUCINATION_PROMPT.substitute(context=prompt, answer=answer)
+
+    try:
+        resp = client.chat.completions.create(
+            model=model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are an expert in detecting hallucinations in LLM-generated code.",
+                },
+                {"role": "user", "content": llm_prompt},
+            ],
+            temperature=temperature,
+            max_tokens=1000,
+        )
+        raw = resp.choices[0].message.content.strip()
+
+        # Parse JSON from response
+        json_match = re.search(r"\{[\s\S]*\}", raw)
+        if json_match:
+            payload = json.loads(json_match.group())
+            spans = []
+            for sub in payload.get("hallucination_list", []):
+                if not sub:
+                    continue
+                match = re.search(re.escape(sub), answer)
+                if match:
+                    spans.append({"start": match.start(), "end": match.end(), "text": sub})
+            return spans
+        return []
+    except Exception as e:
+        print(f"  Error: {e}")
+        return []
+
+
+def evaluate_example_level(
+    samples: list[HallucinationSample],
+    client: OpenAI,
+    model: str,
+    temperature: float = 0.0,
+) -> dict:
+    """Evaluate at example level — binary: does this sample contain hallucinations?"""
+    example_preds = []
+    example_labels = []
+
+    for sample in tqdm(samples, desc="Evaluating (example level)"):
+        predicted_spans = predict_with_llm(client, model, sample.prompt, sample.answer, temperature)
+        true_label = 1 if sample.labels else 0
+        pred_label = 1 if predicted_spans else 0
+        example_labels.append(true_label)
+        example_preds.append(pred_label)
+        time.sleep(0.5)  # Rate limit
+
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        example_labels, example_preds, labels=[0, 1], average=None, zero_division=0
+    )
+
+    # Computed from hard 0/1 predictions (the LLM emits no scores), so this
+    # AUROC reduces to balanced accuracy
+    fpr, tpr, _ = roc_curve(example_labels, example_preds)
+    auroc = auc(fpr, tpr)
+
+    results = {
+        "supported": {
+            "precision": float(precision[0]),
+            "recall": float(recall[0]),
+            "f1": float(f1[0]),
+        },
+        "hallucinated": {
+            "precision": float(precision[1]),
+            "recall": float(recall[1]),
+            "f1": float(f1[1]),
+        },
+        "auroc": auroc,
+    }
+
+    report = classification_report(
+        example_labels,
+        example_preds,
+        target_names=["Supported", "Hallucinated"],
+        digits=4,
+        zero_division=0,
+    )
+    print("\nExample-Level Classification Report:")
+    print(report)
+    print(f"AUROC: {auroc:.4f}")
+
+    return results
+
+
+def evaluate_char_level(
+    samples: list[HallucinationSample],
+    client: OpenAI,
+    model: str,
+    temperature: float = 0.0,
+) -> dict:
+    """Evaluate at character level — overlap between predicted and gold spans."""
+    total_overlap = 0
+    total_predicted = 0
+    total_gold = 0
+
+    for sample in tqdm(samples, desc="Evaluating (char level)"):
+        predicted_spans = predict_with_llm(client, model, sample.prompt, sample.answer, temperature)
+        gold_spans = sample.labels
+
+        total_predicted += sum(p["end"] - p["start"] for p in predicted_spans)
+        total_gold += sum(g["end"] - g["start"] for g in gold_spans)
+
+        for pred in predicted_spans:
+            for gold in gold_spans:
+                overlap_start = max(pred["start"], gold["start"])
+                overlap_end = min(pred["end"], gold["end"])
+                if overlap_end > 
overlap_start: + total_overlap += overlap_end - overlap_start + + time.sleep(0.5) # Rate limit + + precision = total_overlap / total_predicted if total_predicted > 0 else 0 + recall = total_overlap / total_gold if total_gold > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + + print("\nCharacter-Level Results:") + print(f" Precision: {precision:.4f}") + print(f" Recall: {recall:.4f}") + print(f" F1: {f1:.4f}") + + return {"precision": precision, "recall": recall, "f1": f1} + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate LLM baseline on code hallucination dataset" + ) + parser.add_argument("--model", type=str, default="moonshotai/kimi-k2-instruct-0905") + parser.add_argument("--data_path", type=str, required=True) + parser.add_argument( + "--evaluation_type", + type=str, + default="example_level", + choices=["example_level", "char_level", "both"], + ) + parser.add_argument( + "--max_samples", + type=int, + default=None, + help="Limit number of test samples (for quick testing)", + ) + parser.add_argument( + "--test_ratio", type=float, default=0.3, help="Fraction of data to use as test set" + ) + parser.add_argument("--seed", type=int, default=42) + + args = parser.parse_args() + + # Load data + data_path = Path(args.data_path) + raw_data = json.loads(data_path.read_text()) + + # Convert to HallucinationSample objects + samples = [] + for item in raw_data: + # Handle the dataset literal type - use "ragtruth" as fallback + dataset = item.get("dataset", "ragtruth") + if dataset not in ("ragtruth", "ragbench"): + dataset = "ragtruth" # Fallback for compatibility + language = item.get("language", "en") + if language not in ("en", "de"): + language = "en" + + samples.append( + HallucinationSample( + prompt=item["prompt"], + answer=item["answer"], + labels=item["labels"], + split=item.get("split", "test"), + task_type=item.get("task_type", "code_generation"), + dataset=dataset, + language=language, + ) + ) + + # Split into test set + import random + + random.seed(args.seed) + random.shuffle(samples) + + test_size = int(len(samples) * args.test_ratio) + test_samples = samples[:test_size] + + if args.max_samples: + test_samples = test_samples[: args.max_samples] + + n_positive = sum(1 for s in test_samples if s.labels) + n_negative = sum(1 for s in test_samples if not s.labels) + + print(f"Dataset: {data_path}") + print(f"Total samples: {len(samples)}") + print(f"Test samples: {len(test_samples)} (positive: {n_positive}, negative: {n_negative})") + print(f"Model: {args.model}") + print(f"API base: {os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')}") + + # Setup client + client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + ) + + # Run evaluation + results = {} + + if args.evaluation_type in ("example_level", "both"): + print("\n" + "=" * 60) + print("EXAMPLE-LEVEL EVALUATION") + print("=" * 60) + results["example_level"] = evaluate_example_level(test_samples, client, args.model) + + if args.evaluation_type in ("char_level", "both"): + print("\n" + "=" * 60) + print("CHARACTER-LEVEL EVALUATION") + print("=" * 60) + results["char_level"] = evaluate_char_level(test_samples, client, args.model) + + # Save results + output_path = data_path.parent / f"eval_results_{args.model.replace('/', '_')}.json" + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {output_path}") + + +if __name__ == "__main__": + 
main() diff --git a/scripts/filter_data.py b/scripts/filter_data.py new file mode 100644 index 0000000..94bedd8 --- /dev/null +++ b/scripts/filter_data.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +import argparse +import json +import logging +from pathlib import Path + +from lettucedetect.datasets.hallucination_dataset import HallucinationData + +# Set up logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("filter_data") + + +def filter_dataset( + input_file: Path, output_file: Path, split: str = "test", task_type: str = "Data2txt" +) -> None: + """Filter dataset to include only samples with specific split and task type. + + :param input_file: Path to the input JSON file + :param output_file: Path to save the filtered JSON file + :param split: The split to filter for (default: 'test') + :param task_type: The task type to filter for (default: 'Data2txt') + """ + if not input_file.exists(): + raise FileNotFoundError(f"Input file not found: {input_file}") + + # Create output directory if it doesn't exist + output_file.parent.mkdir(parents=True, exist_ok=True) + + # Load data + try: + with open(input_file, "r") as f: + data = json.load(f) + + dataset = HallucinationData.from_json(data) + logger.info(f"Loaded {len(dataset.samples)} samples from {input_file}") + except Exception as e: + logger.error(f"Error loading input data: {e}") + raise + + # Filter samples + filtered_samples = [ + sample + for sample in dataset.samples + if sample.split == split and sample.task_type == task_type + ] + + # Create filtered dataset + filtered_dataset = HallucinationData(samples=filtered_samples) + + # Save filtered data + with open(output_file, "w") as f: + json.dump(filtered_dataset.to_json(), f, indent=2) + + logger.info(f"Filtered dataset from {len(dataset.samples)} to {len(filtered_samples)} samples") + logger.info(f"Saved filtered dataset to {output_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Filter hallucination dataset to specific split and task type" + ) + parser.add_argument("--input", type=str, required=True, help="Path to input JSON file") + parser.add_argument("--output", type=str, required=True, help="Path to save filtered JSON file") + parser.add_argument( + "--split", type=str, default="test", help="Split to filter for (default: 'test')" + ) + parser.add_argument( + "--task-type", + type=str, + default="Data2txt", + help="Task type to filter for (default: 'Data2txt')", + ) + + args = parser.parse_args() + + filter_dataset(Path(args.input), Path(args.output), args.split, args.task_type) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_code_hallucination_dataset.py b/scripts/generate_code_hallucination_dataset.py new file mode 100644 index 0000000..46b7634 --- /dev/null +++ b/scripts/generate_code_hallucination_dataset.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +"""Generate a code hallucination detection dataset from SWE-bench + Context7. + +Pipeline: +1. Load SWE-bench Lite samples +2. Extract context files and imports from patches +3. Fetch relevant documentation via Context7 +4. Transform issues into user-style queries (via LLM) +5. Inject hallucinations (structural + behavioral + semantic) +6. 
Convert to LettuceDetect training format +""" + +import json +import os +import re +import textwrap +import time +from typing import Any + +import requests +from openai import OpenAI + +# === Config === +GROQ_BASE_URL = "https://api.groq.com/openai/v1" +MODEL = "moonshotai/kimi-k2-instruct-0905" +CONTEXT7_BASE = "https://context7.com/api/v2" +OUTPUT_PATH = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json" +LETTUCEDETECT_OUTPUT_PATH = ( + "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect.json" +) +NUM_SAMPLES = 50 +RETRY_DELAY = 2 +MAX_RETRIES = 3 + + +def get_client() -> OpenAI: + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("Set OPENAI_API_KEY to your Groq API key") + return OpenAI(api_key=api_key, base_url=GROQ_BASE_URL) + + +# === Patch Parsing === + + +def extract_changed_files(patch: str) -> list[str]: + """Extract file paths changed in a unified diff.""" + files = [] + for line in patch.split("\n"): + if line.startswith("diff --git"): + match = re.search(r"b/(.+)$", line) + if match: + files.append(match.group(1)) + return files + + +def extract_imports_from_patch(patch: str) -> list[str]: + """Extract Python import statements from added lines in a patch.""" + imports = set() + for line in patch.split("\n"): + if line.startswith("+") and not line.startswith("+++"): + clean = line[1:].strip() + # Match import statements + if clean.startswith("import ") or clean.startswith("from "): + # Extract the top-level module + match = re.match(r"(?:from|import)\s+([\w.]+)", clean) + if match: + module = match.group(1).split(".")[0] + # Filter out local/relative imports and stdlib + if module and not module.startswith("_"): + imports.add(module) + return list(imports) + + +def extract_libraries_from_files(changed_files: list[str]) -> list[str]: + """Infer which external libraries might be relevant from file paths.""" + # Map repo paths to likely library names + path_to_lib = { + "django": "django", + "astropy": "astropy", + "sympy": "sympy", + "sklearn": "scikit-learn", + "matplotlib": "matplotlib", + "requests": "requests", + "flask": "flask", + "pytest": "pytest", + "sphinx": "sphinx", + "xarray": "xarray", + "seaborn": "seaborn", + "pylint": "pylint", + } + libs = set() + for f in changed_files: + for key, lib in path_to_lib.items(): + if key in f: + libs.add(lib) + return list(libs) + + +# === Context7 Documentation === + + +def fetch_context7_docs(library_name: str, query: str, max_chars: int = 2000) -> str | None: + """Fetch documentation from Context7 for a library + query.""" + try: + # Step 1: Resolve library ID + r = requests.get( + f"{CONTEXT7_BASE}/libs/search", + params={"query": query, "libraryName": library_name}, + timeout=10, + ) + if r.status_code != 200: + return None + results = r.json().get("results", []) + if not results: + return None + + lib_id = results[0]["id"] + + # Step 2: Get relevant docs + r2 = requests.get( + f"{CONTEXT7_BASE}/context", + params={"libraryId": lib_id, "query": query, "type": "txt"}, + timeout=10, + ) + if r2.status_code != 200: + return None + + doc_text = r2.text[:max_chars] + return doc_text if doc_text.strip() else None + + except Exception as e: + print(f" Context7 error for {library_name}: {e}") + return None + + +def get_documentation_context( + changed_files: list[str], patch: str, problem_statement: str +) -> dict[str, str]: + """Fetch documentation for libraries referenced in the patch.""" + docs = {} + + # Get libraries from imports in the patch 
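+    # (e.g., an added line "+from django.db import models" yields "django")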
+ imported_libs = extract_imports_from_patch(patch) + # Get libraries from file paths + path_libs = extract_libraries_from_files(changed_files) + + all_libs = list(set(imported_libs + path_libs)) + + # Extract a short query from the problem statement + short_query = problem_statement[:200].replace("\n", " ").strip() + + for lib in all_libs[:3]: # Limit to 3 libraries to avoid rate limits + doc = fetch_context7_docs(lib, short_query) + if doc: + docs[lib] = doc + + return docs + + +# === LLM Calls === + + +def llm_call( + client: OpenAI, system: str, user: str, temperature: float = 0.7, max_tokens: int = 500 +) -> str: + """Make an LLM call with retries.""" + for attempt in range(MAX_RETRIES): + try: + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=temperature, + max_tokens=max_tokens, + ) + return response.choices[0].message.content.strip() + except Exception as e: + if attempt < MAX_RETRIES - 1: + wait = RETRY_DELAY * (attempt + 1) + print(f" LLM error (attempt {attempt + 1}): {e}. Retrying in {wait}s...") + time.sleep(wait) + else: + raise + + +def transform_to_user_query(client: OpenAI, problem_statement: str, repo: str) -> str: + """Transform a GitHub issue into a realistic user-to-agent query.""" + system = textwrap.dedent("""\ + You transform GitHub issue descriptions into realistic user queries + that a developer would type into an AI coding assistant (like Claude Code or Cursor). + + Rules: + - Make it conversational and natural + - Keep the core technical ask but remove GitHub formatting + - Remove reproduction steps, stack traces, verbose details + - Keep it to 1-3 sentences + - Don't mention "issue" or "bug report" + - Sound like someone asking for help, not filing a report + """) + user = f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}" + return llm_call(client, system, user, temperature=0.7, max_tokens=300) + + +def inject_hallucinations( + client: OpenAI, + gold_patch: str, + user_query: str, + problem_statement: str, + repo: str, + documentation: dict[str, str], +) -> dict[str, Any]: + """Inject hallucinations into a gold patch.""" + doc_context = "" + if documentation: + doc_context = "\n\nRelevant library documentation:\n" + for lib, doc in documentation.items(): + doc_context += f"\n--- {lib} docs ---\n{doc[:800]}\n" + + system = textwrap.dedent("""\ + You are a code hallucination injector for building a hallucination detection dataset. + + Given a correct code patch, user query, and documentation context, create THREE + hallucinated versions of the patch: + + 1. STRUCTURAL: Change a function call, import, or parameter to something that + doesn't exist or is wrong. Code should still parse but reference non-existent + APIs, wrong methods, or invented parameters. If documentation is provided, + you can hallucinate by using API calls that contradict the docs. + + 2. BEHAVIORAL: Use correct APIs but with wrong values or logic. Wrong defaults, + off-by-one errors, swapped conditions, wrong argument values. + + 3. SEMANTIC: Code that looks like it addresses the user's request but does + something subtly different or opposite. This should be the most subtle - + the code parses, uses real APIs, but fails to do what was asked. + Examples: implementing global when per-item was asked, any() vs all(), + catching exceptions instead of fixing root cause, inverted conditions. 
+ + Respond in this exact JSON format (no markdown, no code blocks): + { + "hallucinations": [ + { + "type": "structural", + "hallucinated_patch": "the full modified patch text", + "changes": [ + { + "original": "exact original code span", + "hallucinated": "what you changed it to", + "explanation": "why this is a hallucination" + } + ] + }, + { + "type": "behavioral", + "hallucinated_patch": "...", + "changes": [...] + }, + { + "type": "semantic", + "hallucinated_patch": "...", + "changes": [...] + } + ] + } + + IMPORTANT: + - Hallucinations must be PLAUSIBLE - something an LLM would realistically generate + - Each change must be subtle, not obviously broken + - Return ONLY valid JSON + """) + + user = f"""Repository: {repo} + +User's query: {user_query} + +Original issue context: {problem_statement[:2000]}{doc_context} + +Correct gold patch: +{gold_patch} + +Generate three hallucinated versions of this patch.""" + + raw = llm_call(client, system, user, temperature=0.8, max_tokens=4000) + + # Parse JSON from response + json_match = re.search(r"\{[\s\S]*\}", raw) + if json_match: + try: + return json.loads(json_match.group()) + except json.JSONDecodeError: + return {"raw_response": raw, "parse_error": True} + return {"raw_response": raw, "parse_error": True} + + +# === LettuceDetect Format Conversion === + + +def compute_span_offsets(code: str, hallucinated_span: str) -> list[dict]: + """Find the character offsets of a hallucinated span within code.""" + spans = [] + start = 0 + while True: + idx = code.find(hallucinated_span, start) + if idx == -1: + break + spans.append({"start": idx, "end": idx + len(hallucinated_span)}) + start = idx + 1 + return spans + + +def convert_to_lettucedetect_format(samples: list[dict]) -> list[dict]: + """Convert enriched SWE-bench samples to LettuceDetect training format. 
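+    Each raw sample yields one clean (gold-patch) sample plus up to one
+    training sample per hallucinated variant.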
+ + For each hallucinated variant, create a training sample with: + - prompt: context (codebase files + docs + user query) + - answer: the hallucinated code + - labels: span annotations of hallucinated parts + """ + ld_samples = [] + + for sample in samples: + # Build context prompt from available information + context_parts = [] + + # Add user query as the "question" + user_query = sample.get("user_query", "") + + # Add documentation context + docs = sample.get("documentation", {}) + if docs: + for lib, doc in docs.items(): + context_parts.append(f"Documentation for {lib}:\n{doc}") + + # Add the gold patch as reference context (what the correct code looks like) + context_parts.append(f"Changed files: {', '.join(sample.get('changed_files', []))}") + + # Add original problem statement as additional context + problem = sample.get("original_problem_statement", "") + if problem: + context_parts.append(f"Issue description:\n{problem[:1500]}") + + context_text = "\n\n".join(context_parts) + + # Create a "correct" sample (gold patch, no hallucination labels) + gold_patch = sample.get("gold_patch", "") + ld_samples.append( + { + "prompt": context_text, + "answer": gold_patch, + "labels": [], + "split": "train", + "task_type": "code_generation", + "dataset": "code_hallucination_swebench", + "language": "python", + "instance_id": sample["instance_id"], + "repo": sample["repo"], + "user_query": user_query, + } + ) + + # Create hallucinated samples + for hall in sample.get("hallucinations", []): + if isinstance(hall, dict) and "hallucinated_patch" in hall: + h_patch = hall["hallucinated_patch"] + labels = [] + + for change in hall.get("changes", []): + h_span = change.get("hallucinated", "") + if h_span and h_span in h_patch: + offsets = compute_span_offsets(h_patch, h_span) + for offset in offsets[:1]: # Take first occurrence + labels.append( + { + "start": offset["start"], + "end": offset["end"], + "label": "hallucinated", + "hallucination_type": hall.get("type", "unknown"), + "explanation": change.get("explanation", ""), + "original_span": change.get("original", ""), + } + ) + + if labels: # Only add if we found valid span offsets + ld_samples.append( + { + "prompt": context_text, + "answer": h_patch, + "labels": labels, + "split": "train", + "task_type": "code_generation", + "dataset": "code_hallucination_swebench", + "language": "python", + "instance_id": sample["instance_id"], + "repo": sample["repo"], + "user_query": user_query, + "hallucination_type": hall.get("type", "unknown"), + } + ) + + return ld_samples + + +# === Main Pipeline === + + +def process_sample(client: OpenAI, sample: dict, idx: int, total: int) -> dict | None: + """Process a single SWE-bench sample.""" + instance_id = sample["instance_id"] + repo = sample["repo"] + problem_statement = sample["problem_statement"] + gold_patch = sample["patch"] + + print(f"\n[{idx + 1}/{total}] {instance_id} ({repo})") + + # Step 1: Extract file and import info + changed_files = extract_changed_files(gold_patch) + print(f" Files: {changed_files}") + + # Step 2: Fetch documentation via Context7 + print(" Fetching documentation from Context7...") + documentation = get_documentation_context(changed_files, gold_patch, problem_statement) + if documentation: + print(f" Got docs for: {list(documentation.keys())}") + else: + print(" No external docs found (repo-internal change)") + + # Step 3: Transform issue to user query + print(" Generating user query...") + try: + user_query = transform_to_user_query(client, problem_statement, repo) + print(f" 
Query: {user_query[:120]}...") + except Exception as e: + print(f" ERROR generating query: {e}") + return None + + # Small delay to avoid rate limits + time.sleep(1) + + # Step 4: Inject hallucinations + print(" Injecting hallucinations...") + try: + hall_result = inject_hallucinations( + client, gold_patch, user_query, problem_statement, repo, documentation + ) + except Exception as e: + print(f" ERROR injecting hallucinations: {e}") + return None + + if hall_result.get("parse_error"): + print(" WARNING: Failed to parse hallucination JSON") + return None + + hallucinations = hall_result.get("hallucinations", []) + print(f" Generated {len(hallucinations)} hallucination variants") + for h in hallucinations: + h_type = h.get("type", "?") + n_changes = len(h.get("changes", [])) + print(f" - {h_type}: {n_changes} changes") + + # Small delay between samples + time.sleep(1) + + return { + "instance_id": instance_id, + "repo": repo, + "base_commit": sample["base_commit"], + "original_problem_statement": problem_statement, + "user_query": user_query, + "changed_files": changed_files, + "gold_patch": gold_patch, + "test_patch": sample["test_patch"], + "fail_to_pass": sample["FAIL_TO_PASS"], + "documentation": documentation, + "hallucinations": hallucinations, + } + + +def select_diverse_samples(dataset, n: int) -> list[int]: + """Select diverse samples across different repos.""" + # Group by repo + repo_indices: dict[str, list[int]] = {} + for i in range(len(dataset)): + repo = dataset[i]["repo"] + if repo not in repo_indices: + repo_indices[repo] = [] + repo_indices[repo].append(i) + + print(f"Repos available: {list(repo_indices.keys())}") + print(f"Samples per repo: {', '.join(f'{k}: {len(v)}' for k, v in repo_indices.items())}") + + # Round-robin across repos + selected = [] + repo_iters = {repo: iter(indices) for repo, indices in repo_indices.items()} + repos = list(repo_indices.keys()) + repo_idx = 0 + + while len(selected) < n: + repo = repos[repo_idx % len(repos)] + try: + idx = next(repo_iters[repo]) + selected.append(idx) + except StopIteration: + repos.remove(repo) + if not repos: + break + repo_idx += 1 + + return selected[:n] + + +def main(): + from datasets import load_dataset + + client = get_client() + + print("=" * 60) + print("Code Hallucination Dataset Generator") + print("SWE-bench + Context7 + LLM Injection") + print("=" * 60) + + # Load SWE-bench Lite + print("\nLoading SWE-bench Lite...") + ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + print(f"Loaded {len(ds)} samples") + + # Select diverse samples + indices = select_diverse_samples(ds, NUM_SAMPLES) + print(f"\nSelected {len(indices)} diverse samples") + + # Process each sample + results = [] + failed = 0 + + for i, idx in enumerate(indices): + sample = ds[idx] + result = process_sample(client, sample, i, len(indices)) + if result: + results.append(result) + else: + failed += 1 + + # Progress + if (i + 1) % 10 == 0: + print( + f"\n === Progress: {i + 1}/{len(indices)}, success: {len(results)}, failed: {failed} ===\n" + ) + + # Save raw results + os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True) + with open(OUTPUT_PATH, "w") as f: + json.dump(results, f, indent=2) + print(f"\nSaved {len(results)} raw samples to {OUTPUT_PATH}") + + # Convert to LettuceDetect format + print("\nConverting to LettuceDetect training format...") + ld_samples = convert_to_lettucedetect_format(results) + + with open(LETTUCEDETECT_OUTPUT_PATH, "w") as f: + json.dump(ld_samples, f, indent=2) + print(f"Saved {len(ld_samples)} 
LettuceDetect samples to {LETTUCEDETECT_OUTPUT_PATH}") + + # Statistics + print("\n" + "=" * 60) + print("DATASET STATISTICS") + print("=" * 60) + print(f"SWE-bench samples processed: {len(results)}") + print(f"Failed samples: {failed}") + + repos = set(r["repo"] for r in results) + print(f"Repos covered: {len(repos)}") + for repo in sorted(repos): + count = sum(1 for r in results if r["repo"] == repo) + print(f" {repo}: {count}") + + n_with_docs = sum(1 for r in results if r.get("documentation")) + print(f"\nSamples with Context7 docs: {n_with_docs}/{len(results)}") + + hall_counts = {"structural": 0, "behavioral": 0, "semantic": 0} + for r in results: + for h in r.get("hallucinations", []): + t = h.get("type", "unknown") + if t in hall_counts: + hall_counts[t] += 1 + print(f"Hallucination variants: {hall_counts}") + + print(f"\nLettuceDetect training samples: {len(ld_samples)}") + n_positive = sum(1 for s in ld_samples if s.get("labels")) + n_negative = sum(1 for s in ld_samples if not s.get("labels")) + print(f" Hallucinated (positive): {n_positive}") + print(f" Correct (negative): {n_negative}") + + +if __name__ == "__main__": + main() diff --git a/scripts/prototype_code_hallucination.py b/scripts/prototype_code_hallucination.py new file mode 100644 index 0000000..292bbc4 --- /dev/null +++ b/scripts/prototype_code_hallucination.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +"""Prototype: Transform SWE-bench into a code hallucination detection dataset. + +Pipeline: +1. Load SWE-bench sample (issue, repo, patch) +2. Extract context files from the patch (what files were relevant) +3. Transform issue into user-style query (via LLM) +4. Inject hallucinations into the gold patch (structural + semantic) +5. Output annotated samples +""" + +import json +import os +import re +import textwrap +from typing import Any + +from openai import OpenAI + +# Groq API (OpenAI-compatible) +GROQ_BASE_URL = "https://api.groq.com/openai/v1" +MODEL = "moonshotai/kimi-k2-instruct-0905" + + +def get_client() -> OpenAI: + """Get Groq client.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("Set OPENAI_API_KEY to your Groq API key") + return OpenAI(api_key=api_key, base_url=GROQ_BASE_URL) + + +def extract_changed_files(patch: str) -> list[str]: + """Extract file paths that were changed from a unified diff patch.""" + files = [] + for line in patch.split("\n"): + if line.startswith("diff --git"): + # Extract b/ path (the destination file) + match = re.search(r"b/(.+)$", line) + if match: + files.append(match.group(1)) + return files + + +def extract_patch_changes(patch: str) -> list[dict]: + """Parse a unified diff into structured changes per file.""" + file_diffs = [] + current_file = None + current_hunks = [] + + for line in patch.split("\n"): + if line.startswith("diff --git"): + if current_file: + file_diffs.append({"file": current_file, "diff": "\n".join(current_hunks)}) + match = re.search(r"b/(.+)$", line) + current_file = match.group(1) if match else "unknown" + current_hunks = [line] + elif current_file: + current_hunks.append(line) + + if current_file: + file_diffs.append({"file": current_file, "diff": "\n".join(current_hunks)}) + + return file_diffs + + +def transform_to_user_query(client: OpenAI, problem_statement: str, repo: str) -> str: + """Transform a GitHub issue into a realistic user-to-agent query.""" + response = client.chat.completions.create( + model=MODEL, + messages=[ + { + "role": "system", + "content": textwrap.dedent("""\ + You transform GitHub issue 
descriptions into realistic user queries + that a developer would type into an AI coding assistant (like Claude Code + or Cursor). + + Rules: + - Make it conversational and natural, like someone typing into a chat + - Keep the core technical ask but remove GitHub-specific formatting + - Remove reproduction steps, stack traces, and verbose details + - Keep it to 1-3 sentences + - Don't mention "issue" or "bug report" + - Make it sound like someone asking for help, not filing a report + + Examples: + Issue: "BUG: DataFrame.merge raises TypeError when merging on columns with different dtypes" + Query: "pd.merge is crashing with a TypeError when I try to merge two dataframes where the join columns have different dtypes. Can you fix the merge to handle dtype coercion?" + + Issue: "Enable quiet mode/no-verbose in CLI for programmatic use" + Query: "Can you add a --quiet flag to the CLI? I need to use it in scripts and the verbose output is getting in the way." + """), + }, + { + "role": "user", + "content": f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}", + }, + ], + temperature=0.7, + max_tokens=300, + ) + return response.choices[0].message.content.strip() + + +def inject_hallucinations( + client: OpenAI, + gold_patch: str, + user_query: str, + problem_statement: str, + repo: str, +) -> dict[str, Any]: + """Inject hallucinations into a gold patch, returning annotated hallucinated code.""" + response = client.chat.completions.create( + model=MODEL, + messages=[ + { + "role": "system", + "content": textwrap.dedent("""\ + You are a code hallucination injector for building a hallucination detection dataset. + + Given a correct code patch and the user's query, create THREE hallucinated versions: + + 1. STRUCTURAL hallucination: Change a function call, import, or parameter to + something that doesn't exist or is wrong for this codebase. The code should + still parse but reference non-existent APIs, wrong method names, or invented + parameters. + + 2. BEHAVIORAL hallucination: Use correct APIs but with wrong values or logic + that would produce incorrect behavior. For example, wrong default values, + off-by-one errors, or swapped conditions. + + 3. SEMANTIC hallucination: Code that looks like it addresses the user's request + but actually does something subtly different or opposite. This is the hardest + type - the code should parse, use real APIs, but fail to do what was asked. + Examples: implementing global when per-item was asked, catching exceptions + instead of fixing the root cause, or implementing the inverse logic. + + For EACH hallucinated version, respond in this exact JSON format: + { + "hallucinations": [ + { + "type": "structural", + "hallucinated_patch": "the full modified patch", + "changes": [ + { + "original": "the original correct code span", + "hallucinated": "what you changed it to", + "explanation": "why this is a hallucination" + } + ] + }, + { + "type": "behavioral", + "hallucinated_patch": "...", + "changes": [...] + }, + { + "type": "semantic", + "hallucinated_patch": "...", + "changes": [...] 
+ } + ] + } + + IMPORTANT: + - The hallucinated patches must be plausible β€” something an LLM would realistically generate + - Each change must be subtle, not obviously broken + - Return ONLY valid JSON, no markdown code blocks + """), + }, + { + "role": "user", + "content": f"""Repository: {repo} + +User's query: {user_query} + +Original issue context: {problem_statement[:2000]} + +Correct gold patch: +{gold_patch} + +Generate three hallucinated versions of this patch.""", + }, + ], + temperature=0.8, + max_tokens=4000, + ) + + raw = response.choices[0].message.content.strip() + # Try to extract JSON from the response + # Sometimes LLMs wrap it in ```json blocks + json_match = re.search(r"\{[\s\S]*\}", raw) + if json_match: + try: + return json.loads(json_match.group()) + except json.JSONDecodeError: + return {"raw_response": raw, "parse_error": True} + return {"raw_response": raw, "parse_error": True} + + +def process_sample(client: OpenAI, sample: dict) -> dict: + """Process a single SWE-bench sample into a hallucination detection sample.""" + instance_id = sample["instance_id"] + repo = sample["repo"] + problem_statement = sample["problem_statement"] + gold_patch = sample["patch"] + + print(f"\n{'=' * 60}") + print(f"Processing: {instance_id}") + print(f"Repo: {repo}") + print(f"{'=' * 60}") + + # Step 1: Extract changed files from the patch + changed_files = extract_changed_files(gold_patch) + print(f"\nChanged files: {changed_files}") + + # Step 2: Transform issue into user query + print("\nTransforming issue into user query...") + user_query = transform_to_user_query(client, problem_statement, repo) + print(f"User query: {user_query}") + + # Step 3: Inject hallucinations + print("\nInjecting hallucinations...") + hallucination_result = inject_hallucinations( + client, gold_patch, user_query, problem_statement, repo + ) + + if hallucination_result.get("parse_error"): + print("WARNING: Failed to parse hallucination response") + print( + f"Raw response (first 500 chars): {hallucination_result.get('raw_response', '')[:500]}" + ) + + # Build the output sample + output = { + "instance_id": instance_id, + "repo": repo, + "base_commit": sample["base_commit"], + "original_problem_statement": problem_statement, + "user_query": user_query, + "changed_files": changed_files, + "gold_patch": gold_patch, + "test_patch": sample["test_patch"], + "fail_to_pass": sample["FAIL_TO_PASS"], + "hallucinations": hallucination_result.get("hallucinations", []), + } + + # Print summary + if not hallucination_result.get("parse_error"): + for h in hallucination_result.get("hallucinations", []): + print(f"\n [{h.get('type', 'unknown').upper()}]") + for c in h.get("changes", []): + print(f" Original: {c.get('original', 'N/A')[:80]}") + print(f" Hallucinated: {c.get('hallucinated', 'N/A')[:80]}") + print(f" Explanation: {c.get('explanation', 'N/A')[:100]}") + + return output + + +def main(): + """Run prototype on a few SWE-bench Lite samples.""" + from datasets import load_dataset + + client = get_client() + + # Load SWE-bench Lite (300 curated samples) + print("Loading SWE-bench Lite...") + ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + + # Pick a few diverse samples to test + # Choose samples from different repos with different patch sizes + test_indices = [0, 5, 10] # Start with 3 samples + results = [] + + for idx in test_indices: + sample = ds[idx] + try: + result = process_sample(client, sample) + results.append(result) + except Exception as e: + print(f"\nERROR processing 
{sample['instance_id']}: {e}") + continue + + # Save results + output_path = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_prototype.json" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + + print(f"\n\nSaved {len(results)} samples to {output_path}") + + # Print summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + for r in results: + print(f"\n{r['instance_id']} ({r['repo']})") + print(f" Query: {r['user_query'][:100]}...") + print(f" Files: {r['changed_files']}") + n_hall = len(r.get("hallucinations", [])) + print(f" Hallucinations generated: {n_hall}") + for h in r.get("hallucinations", []): + print(f" - {h.get('type', 'unknown')}: {len(h.get('changes', []))} changes") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_code_hallucination.py b/scripts/train_code_hallucination.py new file mode 100644 index 0000000..09930ed --- /dev/null +++ b/scripts/train_code_hallucination.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Train a hallucination detector on the code hallucination dataset. + +Can optionally combine with RAGTruth data for mixed training. + +Usage: + # Train on code hallucination data only + python scripts/train_code_hallucination.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/code_hallucination_detector + + # Train on both code + RAGTruth data + python scripts/train_code_hallucination.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --ragtruth-path data/ragtruth/ragtruth_data.json \ + --model-name answerdotai/ModernBERT-base \ + --output-dir output/code_hallucination_detector +""" + +import argparse +import json +import random +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForTokenClassification, + AutoTokenizer, + DataCollatorForTokenClassification, +) + +from lettucedetect.datasets.hallucination_dataset import ( + HallucinationData, + HallucinationDataset, + HallucinationSample, +) +from lettucedetect.models.trainer import Trainer + + +def set_seed(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_code_data(path: str) -> list[HallucinationSample]: + """Load code hallucination dataset. + + Uses the SWE-bench split field directly (train/dev/test). 
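+
+    One input item is expected to look like (illustrative values; the field
+    names match the dataset files produced by the generation scripts):
+
+        {"prompt": "...context...", "answer": "def fix(): ...",
+         "labels": [{"start": 4, "end": 9, "label": "structural"}],
+         "split": "train", "task_type": "code_generation"}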
+ """ + data = json.loads(Path(path).read_text()) + samples = [] + for item in data: + samples.append( + HallucinationSample( + prompt=item["prompt"], + answer=item["answer"], + labels=item["labels"], + split=item.get("split", "train"), + task_type=item.get("task_type", "code_generation"), + dataset="swebench_code", + language="en", + ) + ) + return samples + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train hallucination detector on code data") + parser.add_argument( + "--code-data-path", type=str, required=True, help="Path to code hallucination dataset JSON" + ) + parser.add_argument( + "--ragtruth-path", + type=str, + default=None, + help="Optional path to RAGTruth data for mixed training", + ) + parser.add_argument("--model-name", type=str, default="answerdotai/ModernBERT-base") + parser.add_argument("--output-dir", type=str, default="output/code_hallucination_detector") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--epochs", type=int, default=6) + parser.add_argument("--learning-rate", type=float, default=1e-5) + parser.add_argument("--max-length", type=int, default=4096) + parser.add_argument("--seed", type=int, default=42) + return parser.parse_args() + + +def main(): + args = parse_args() + set_seed(args.seed) + + # Load code hallucination data + print(f"Loading code hallucination data from {args.code_data_path}") + code_samples = load_code_data(args.code_data_path) + + n_clean = sum(1 for s in code_samples if not s.labels) + n_hall = sum(1 for s in code_samples if s.labels) + print(f" Total: {len(code_samples)} (clean: {n_clean}, hallucinated: {n_hall})") + + # Use SWE-bench splits directly (zero repo overlap by design) + # Train on train, validate on dev, hold out test for final evaluation + train_samples = [s for s in code_samples if s.split == "train"] + dev_samples = [s for s in code_samples if s.split == "dev"] + test_samples = [s for s in code_samples if s.split == "test"] + + print(f" Train (SWE-bench train): {len(train_samples)}") + print(f" Dev (SWE-bench dev): {len(dev_samples)}") + print(f" Test (SWE-bench test, held out): {len(test_samples)}") + + # Optionally add RAGTruth data + if args.ragtruth_path: + print(f"\nLoading RAGTruth data from {args.ragtruth_path}") + ragtruth_data = HallucinationData.from_json( + json.loads(Path(args.ragtruth_path).read_text()) + ) + ragtruth_train = [s for s in ragtruth_data.samples if s.split == "train"] + ragtruth_dev = [s for s in ragtruth_data.samples if s.split in ("dev", "test")] + + print(f" RAGTruth train: {len(ragtruth_train)}, dev: {len(ragtruth_dev)}") + train_samples.extend(ragtruth_train) + dev_samples.extend(ragtruth_dev) + + print("\nFinal splits:") + print(f" Train: {len(train_samples)}") + print(f" Dev: {len(dev_samples)}") + + # Setup tokenizer and model + print(f"\nLoading model: {args.model_name}") + tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) + data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100) + + train_dataset = HallucinationDataset(train_samples, tokenizer, max_length=args.max_length) + dev_dataset = HallucinationDataset(dev_samples, tokenizer, max_length=args.max_length) + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + collate_fn=data_collator, + ) + dev_loader = DataLoader( + dev_dataset, + batch_size=args.batch_size, + shuffle=False, + collate_fn=data_collator, + ) + + model = 
AutoModelForTokenClassification.from_pretrained( + args.model_name, + num_labels=2, + trust_remote_code=True, + ) + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + train_loader=train_loader, + test_loader=dev_loader, + epochs=args.epochs, + learning_rate=args.learning_rate, + save_path=args.output_dir, + ) + + print(f"\nStarting training for {args.epochs} epochs...") + trainer.train() + print(f"\nModel saved to {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_generative_detector.py b/scripts/train_generative_detector.py new file mode 100644 index 0000000..3b6ea4e --- /dev/null +++ b/scripts/train_generative_detector.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +"""Train a generative hallucination span detector via SFT. + +Fine-tunes a decoder LLM (e.g. Qwen3.5-2B) to detect hallucinated spans +in code by generating structured JSON output. + +This is Approach D: the model reads context + answer and generates: + {"hallucinated_spans": [{"text": "...", "explanation": "..."}]} +or for clean samples: + {"hallucinated_spans": []} + +Usage: + # Train with LoRA on a single GPU + python scripts/train_generative_detector.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name Qwen/Qwen3.5-2B \ + --output-dir output/generative_detector + + # With custom settings + python scripts/train_generative_detector.py \ + --code-data-path data/code_hallucination/code_hallucination_data.json \ + --model-name Qwen/Qwen3.5-2B \ + --output-dir output/generative_detector \ + --lora-r 16 \ + --batch-size 2 \ + --epochs 3 \ + --max-length 4096 +""" + +import argparse +import json +import random +from pathlib import Path + +import torch +from peft import LoraConfig, TaskType, get_peft_model +from torch.utils.data import DataLoader, Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup + +SYSTEM_PROMPT = ( + "You are a code hallucination detector. Given source code context and a code answer, " + "identify any hallucinated spans β€” code that is factually wrong, uses non-existent APIs, " + "has incorrect logic, or doesn't match the source context.\n\n" + "Respond with JSON only:\n" + '{"hallucinated_spans": [{"text": "exact hallucinated text", "explanation": "why it is wrong"}]}\n\n' + "If the answer is correct, respond with:\n" + '{"hallucinated_spans": []}' +) + + +def build_training_pairs(data_path: str) -> list[dict]: + """Build (input, output) pairs from the code hallucination dataset. + + For hallucinated samples: output lists the hallucinated spans with explanations. + For clean samples: output is {"hallucinated_spans": []}. 
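+
+    Example assistant target for one labeled span (values illustrative):
+
+        {"hallucinated_spans": [{"text": "cache.flush()",
+                                 "explanation": "structural error in code"}]}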
+ """ + data = json.loads(Path(data_path).read_text()) + pairs = [] + + for item in data: + prompt = item["prompt"] + answer = item["answer"] + labels = item.get("labels", []) + split = item.get("split", "train") + + user_msg = f"Context:\n{prompt}\n\nAnswer to check:\n{answer}" + + if labels: + spans = [] + for label in labels: + start = label["start"] + end = label["end"] + text = answer[start:end] + if text.strip(): + spans.append( + { + "text": text, + "explanation": f"{label.get('label', 'hallucination')} error in code", + } + ) + assistant_msg = json.dumps({"hallucinated_spans": spans}) + else: + assistant_msg = json.dumps({"hallucinated_spans": []}) + + pairs.append( + { + "system": SYSTEM_PROMPT, + "user": user_msg, + "assistant": assistant_msg, + "split": split, + } + ) + + return pairs + + +class SFTDataset(Dataset): + """Dataset for supervised fine-tuning with chat-formatted inputs.""" + + def __init__(self, pairs: list[dict], tokenizer, max_length: int = 4096): + self.tokenizer = tokenizer + self.max_length = max_length + self.examples = [] + + for pair in pairs: + messages = [ + {"role": "system", "content": pair["system"]}, + {"role": "user", "content": pair["user"]}, + {"role": "assistant", "content": pair["assistant"]}, + ] + + # Tokenize full conversation + full_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + full_ids = tokenizer( + full_text, truncation=True, max_length=max_length, return_tensors="pt" + ) + + # Tokenize without assistant response to find where labels start + messages_no_assistant = [ + {"role": "system", "content": pair["system"]}, + {"role": "user", "content": pair["user"]}, + ] + prompt_text = tokenizer.apply_chat_template( + messages_no_assistant, tokenize=False, add_generation_prompt=True + ) + prompt_ids = tokenizer( + prompt_text, truncation=True, max_length=max_length, return_tensors="pt" + ) + prompt_len = prompt_ids["input_ids"].shape[1] + + input_ids = full_ids["input_ids"].squeeze(0) + attention_mask = full_ids["attention_mask"].squeeze(0) + + # Labels: -100 for prompt tokens (masked), actual ids for assistant tokens + labels = input_ids.clone() + labels[:prompt_len] = -100 + + self.examples.append( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + ) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + return self.examples[idx] + + +def collate_fn(batch): + """Pad batch to same length.""" + max_len = max(ex["input_ids"].shape[0] for ex in batch) + + input_ids = [] + attention_mask = [] + labels = [] + + for ex in batch: + pad_len = max_len - ex["input_ids"].shape[0] + input_ids.append(torch.cat([ex["input_ids"], torch.zeros(pad_len, dtype=torch.long)])) + attention_mask.append( + torch.cat([ex["attention_mask"], torch.zeros(pad_len, dtype=torch.long)]) + ) + labels.append(torch.cat([ex["labels"], torch.full((pad_len,), -100, dtype=torch.long)])) + + return { + "input_ids": torch.stack(input_ids), + "attention_mask": torch.stack(attention_mask), + "labels": torch.stack(labels), + } + + +def evaluate(model, dataloader, device): + """Compute average loss on validation set.""" + model.eval() + total_loss = 0 + total_steps = 0 + + with torch.no_grad(): + for batch in dataloader: + batch = {k: v.to(device) for k, v in batch.items()} + outputs = model(**batch) + total_loss += outputs.loss.item() + total_steps += 1 + + return total_loss / max(total_steps, 1) + + +def main(): + parser = argparse.ArgumentParser(description="Train 
+    parser = argparse.ArgumentParser(description="Train generative hallucination span detector")
+    parser.add_argument("--code-data-path", type=str, required=True)
+    parser.add_argument("--model-name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
+    parser.add_argument("--output-dir", type=str, default="output/generative_detector")
+    parser.add_argument("--batch-size", type=int, default=2)
+    parser.add_argument("--epochs", type=int, default=3)
+    parser.add_argument("--learning-rate", type=float, default=2e-4)
+    parser.add_argument("--max-length", type=int, default=4096)
+    parser.add_argument("--lora-r", type=int, default=16)
+    parser.add_argument("--lora-alpha", type=int, default=32)
+    parser.add_argument("--lora-dropout", type=float, default=0.05)
+    parser.add_argument("--warmup-ratio", type=float, default=0.05)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
+    args = parser.parse_args()
+
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    # Build training pairs
+    print(f"Loading data from {args.code_data_path}")
+    pairs = build_training_pairs(args.code_data_path)
+
+    train_pairs = [p for p in pairs if p["split"] == "train"]
+    dev_pairs = [p for p in pairs if p["split"] == "dev"]
+    test_pairs = [p for p in pairs if p["split"] == "test"]
+
+    print(f"Train: {len(train_pairs)}, Dev: {len(dev_pairs)}, Test (held out): {len(test_pairs)}")
+
+    random.shuffle(train_pairs)
+
+    # Load tokenizer and model
+    print(f"Loading model: {args.model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+    )
+
+    # Apply LoRA to all attention and MLP projections
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+    )
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
+    # Build datasets
+    print("Tokenizing datasets...")
+    train_dataset = SFTDataset(train_pairs, tokenizer, max_length=args.max_length)
+    dev_dataset = SFTDataset(dev_pairs, tokenizer, max_length=args.max_length)
+
+    train_loader = DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
+    )
+    dev_loader = DataLoader(
+        dev_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
+    )
+
+    print(f"Train examples: {len(train_dataset)}, Dev examples: {len(dev_dataset)}")
+
+    # Optimizer and scheduler
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
+    total_steps = (len(train_loader) // args.gradient_accumulation_steps) * args.epochs
+    warmup_steps = int(total_steps * args.warmup_ratio)
+    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
+
+    # Training loop
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    best_dev_loss = float("inf")
+
+    print(f"\nTraining for {args.epochs} epochs ({total_steps} steps)...")
+
+    for epoch in range(args.epochs):
+        model.train()
+        epoch_loss = 0
+        optimizer.zero_grad()
+
+        for step, batch in enumerate(train_loader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**batch)
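+            # Scale the loss so accumulated gradients average over the
+            # effective batch (batch_size * gradient_accumulation_steps)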
+            loss = outputs.loss / args.gradient_accumulation_steps
+            loss.backward()
+            epoch_loss += outputs.loss.item()
+
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+            if (step + 1) % 100 == 0:
+                avg_loss = epoch_loss / (step + 1)
+                print(
+                    f"  Epoch {epoch + 1} step {step + 1}/{len(train_loader)} loss={avg_loss:.4f}"
+                )
+
+        # Step once more if a partial accumulation window remains, so its
+        # gradients are not silently discarded at the epoch boundary
+        if len(train_loader) % args.gradient_accumulation_steps != 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+        # Evaluate
+        train_avg = epoch_loss / len(train_loader)
+        dev_loss = evaluate(model, dev_loader, device)
+        print(f"Epoch {epoch + 1}: train_loss={train_avg:.4f}, dev_loss={dev_loss:.4f}")
+
+        # Save best
+        if dev_loss < best_dev_loss:
+            best_dev_loss = dev_loss
+            model.save_pretrained(output_dir / "best")
+            tokenizer.save_pretrained(output_dir / "best")
+            print(f"  Saved best model (dev_loss={dev_loss:.4f})")
+
+    # Save final
+    model.save_pretrained(output_dir / "final")
+    tokenizer.save_pretrained(output_dir / "final")
+    print(f"\nTraining complete. Best dev loss: {best_dev_loss:.4f}")
+    print(f"Models saved to {output_dir}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/conftest.py b/tests/conftest.py
index a0cecaf..4b5c624 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,43 +1 @@
 """Shared fixtures for pytest tests."""
-
-from unittest.mock import MagicMock
-
-import pytest
-import torch
-
-
-@pytest.fixture
-def mock_tokenizer():
-    """Create a mock tokenizer for testing."""
-    tokenizer = MagicMock()
-    tokenizer.encode.return_value = [101, 102, 103, 104, 105]
-
-    # Mock tokenizer call to return encoding
-    tokenizer.return_value = {
-        "input_ids": torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108]]),
-        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]),
-        "offset_mapping": torch.tensor(
-            [
-                [0, 0],  # [CLS]
-                [0, 4],  # "This"
-                [5, 7],  # "is"
-                [8, 9],  # "a"
-                [10, 16],  # "prompt"
-                [0, 0],  # [SEP]
-                [0, 4],  # "This"
-                [5, 12],  # "answer"
-            ]
-        ),
-    }
-
-    return tokenizer
-
-
-@pytest.fixture
-def mock_model():
-    """Create a mock model for testing."""
-    model = MagicMock()
-    mock_output = MagicMock()
-    mock_output.logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]])
-    model.return_value = mock_output
-    return model
diff --git a/tests/run_pytest.py b/tests/run_pytest.py
old mode 100644
new mode 100755
diff --git a/tests/test_inference_pytest.py b/tests/test_inference_pytest.py
index 546852b..e6f7c34 100644
--- a/tests/test_inference_pytest.py
+++ b/tests/test_inference_pytest.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+from lettucedetect.datasets.hallucination_dataset import HallucinationDataset
 from lettucedetect.detectors.prompt_utils import PromptUtils
 from lettucedetect.detectors.transformer import TransformerDetector
 from lettucedetect.models.inference import HallucinationDetector
@@ -129,19 +130,15 @@ def test_init(self):
 
     def test_predict(self):
         """Test predict method."""
-
-        # Create a proper mock encoding with input_ids as a tensor attribute
-        class MockEncoding:
-            def __init__(self):
-                self.input_ids = torch.tensor([[101, 102, 103]])
-
-        mock_encoding = MockEncoding()
-        mock_labels = torch.tensor([0, 0, 0])
-        mock_offsets = torch.tensor([[0, 0], [0, 1], [1, 2]])
-        mock_answer_start = 1
-
-        # Patch the _predict method to avoid the actual implementation
-        with patch.object(TransformerDetector, "_predict", return_value=[]):
+        # Patch internals to avoid actual model inference
+        with (
+            patch.object(
+                TransformerDetector,
+                "_group_passages_into_chunks",
+                return_value=[["This is a test context."]],
+            ),
+            patch.object(TransformerDetector, "_predict_single", return_value=[]),
+        ):
             detector = TransformerDetector(model_path="dummy_path")
             context = ["This is a test context."]
             answer = "This is a test answer."
@@ -175,3 +172,191 @@ def test_form_prompt_without_question(self):
         # Check that the prompt contains the text to summarize
         assert "This is a text to summarize." in prompt
         assert "Summarize" in prompt
+
+
+class TestChunking:
+    """Tests for automatic context chunking when input exceeds max_length."""
+
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        """Set up a TransformerDetector with mocked model/tokenizer."""
+        with (
+            patch(
+                "lettucedetect.detectors.transformer.AutoTokenizer.from_pretrained",
+            ) as tok_cls,
+            patch(
+                "lettucedetect.detectors.transformer.AutoModelForTokenClassification.from_pretrained",
+            ) as model_cls,
+        ):
+            # Set up a realistic tokenizer mock
+            self.mock_tokenizer = MagicMock()
+            self.mock_model = MagicMock()
+            tok_cls.return_value = self.mock_tokenizer
+            model_cls.return_value = self.mock_model
+
+            self.detector = TransformerDetector(model_path="dummy", max_length=32)
+            yield
+
+    def test_aggregation_uses_max(self):
+        """Max aggregation: highest hallucination prob across chunks wins."""
+        # Simulate two chunks with different probabilities for the same tokens
+        chunk1_tokens = [
+            {"token": "The", "pred": 0, "prob": 0.1},
+            {"token": "answer", "pred": 0, "prob": 0.3},
+            {"token": "is", "pred": 1, "prob": 0.8},
+        ]
+        chunk2_tokens = [
+            {"token": "The", "pred": 0, "prob": 0.2},
+            {"token": "answer", "pred": 1, "prob": 0.6},
+            {"token": "is", "pred": 0, "prob": 0.4},
+        ]
+
+        with patch.object(self.detector, "_predict_single") as mock_single:
+            mock_single.side_effect = [chunk1_tokens, chunk2_tokens]
+
+            # Mock _build_spans_from_tokens for the spans path
+            with patch.object(self.detector, "_build_spans_from_tokens", return_value=[]):
+                result = self.detector._predict_chunked(
+                    ["chunk1", "chunk2"], "The answer is", output_format="tokens"
+                )
+
+        assert len(result) == 3
+        # Token 0: max(0.1, 0.2) = 0.2 β†’ pred 0
+        assert result[0]["prob"] == 0.2
+        assert result[0]["pred"] == 0
+        # Token 1: max(0.3, 0.6) = 0.6 β†’ pred 1 (β‰₯ 0.5)
+        assert result[1]["prob"] == 0.6
+        assert result[1]["pred"] == 1
+        # Token 2: max(0.8, 0.4) = 0.8 β†’ pred 1
+        assert result[2]["prob"] == 0.8
+        assert result[2]["pred"] == 1
+
+    def test_group_passages_single_group(self):
+        """When all passages fit, _group_passages_into_chunks returns one group."""
+
+        def tokenizer_side_effect(text, **kwargs):
+            # Approximate: 1 token per word
+            words = len(text.split()) if text else 1
+            return {"input_ids": torch.zeros(1, words, dtype=torch.long)}
+
+        self.mock_tokenizer.side_effect = tokenizer_side_effect
+
+        with patch(
+            "lettucedetect.detectors.transformer.PromptUtils.format_context",
+            return_value="short prompt with passages",
+        ):
+            groups = self.detector._group_passages_into_chunks(
+                ["passage one", "passage two"], "What?", "short answer"
+            )
+        assert len(groups) == 1
+        assert groups[0] == ["passage one", "passage two"]
+
+    def test_group_passages_multiple_groups(self):
+        """When passages exceed budget, they should be split into multiple groups."""
+        # max_length=32, answer=5 tokens β†’ total_budget = 32-5-3 = 24
+        # instruction_overhead = 10 tokens β†’ passage_budget = 24-10 = 14
+        # Each passage = 10 tokens β†’ 2 passages won't fit (10+1+10=21 > 14), so 1 per group
+
+        def tokenizer_side_effect(text, **kwargs):
+            if text == "short answer":
+                return {"input_ids": torch.zeros(1, 5, dtype=torch.long)}
+            if text.startswith("passage"):
+                return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
+            # "full_prompt" β†’ 50 tokens (exceeds budget, triggers chunking)
+            if "full_prompt" in text:
+                return {"input_ids": torch.zeros(1, 50, dtype=torch.long)}
+            # "minimal" β†’ 10 tokens (instruction overhead)
+            return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
+
+        self.mock_tokenizer.side_effect = tokenizer_side_effect
+
+        format_calls = [0]
+
+        def format_side_effect(ctx, q, lang):
+            format_calls[0] += 1
+            if format_calls[0] == 1:
+                return "full_prompt_long"  # First call: all passages β†’ triggers chunking
+            return "minimal"  # Second call: [""] β†’ measures instruction overhead
+
+        with patch(
+            "lettucedetect.detectors.transformer.PromptUtils.format_context",
+            side_effect=format_side_effect,
+        ):
+            groups = self.detector._group_passages_into_chunks(
+                ["passage A", "passage B", "passage C"], "What?", "short answer"
+            )
+
+        # With passage_budget = 14 and each passage = 10 tokens,
+        # each group holds 1 passage (10 < 14, but 10+1+10 = 21 > 14)
+        assert len(groups) == 3
+        assert groups[0] == ["passage A"]
+        assert groups[1] == ["passage B"]
+        assert groups[2] == ["passage C"]
+
+    def test_predict_uses_passage_chunking(self):
+        """predict() should use _group_passages_into_chunks for chunking."""
+        with (
+            patch.object(
+                self.detector,
+                "_group_passages_into_chunks",
+                return_value=[["p1"], ["p2"]],
+            ) as mock_group,
+            patch.object(
+                self.detector,
+                "_predict_chunked",
+                return_value=[{"token": "x", "pred": 0, "prob": 0.1}],
+            ) as mock_chunked,
+            patch(
+                "lettucedetect.detectors.transformer.PromptUtils.format_context",
+                side_effect=lambda ctx, q, lang: f"prompt_{ctx[0]}",
+            ),
+        ):
+            self.detector.predict(["p1", "p2"], "answer", "What?", "tokens")
+            mock_group.assert_called_once()
+            mock_chunked.assert_called_once_with(["prompt_p1", "prompt_p2"], "answer", "tokens")
+
+
+class TestAnswerStartToken:
+    """Tests for the answer_start_token fix in prepare_tokenized_input."""
+
+    def test_answer_start_token_basic(self):
+        """answer_start_token should point to the first answer token."""
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        context = "The capital of France is Paris."
+        answer = "Paris is the capital."
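+
+        # prepare_tokenized_input is assumed here to return a
+        # (encoding, labels, offsets, answer_start_token) tuple; only the
+        # offsets and the start index matter for this check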
+        _, _, offsets, answer_start = HallucinationDataset.prepare_tokenized_input(
+            tokenizer, context, answer, max_length=512
+        )
+
+        # answer_start should be within bounds
+        assert answer_start > 0
+        assert answer_start < offsets.size(0)
+        # The offset at answer_start should be non-zero (actual token, not special)
+        assert offsets[answer_start][1].item() > offsets[answer_start][0].item()
+
+    def test_answer_start_token_with_truncation(self):
+        """answer_start_token should be correct even when context is truncated."""
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        # Create a very long context that will be truncated at max_length=32
+        context = "word " * 200  # ~200 tokens
+        answer = "short answer"
+
+        encoding, _, offsets, answer_start = HallucinationDataset.prepare_tokenized_input(
+            tokenizer, context, answer, max_length=32
+        )
+
+        total_len = encoding["input_ids"].shape[1]
+        assert total_len == 32  # Should be truncated to max_length
+
+        # answer_start should be within the actual sequence
+        assert answer_start > 0
+        assert answer_start < total_len
+
+        # The answer tokens should be at the end (before trailing [SEP])
+        # Verify by checking that the text at answer_start offset is non-empty
+        assert offsets[answer_start][1].item() > offsets[answer_start][0].item()
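+
+
+# Quick manual run of just these test classes (assuming dev dependencies
+# are installed):
+#   pytest tests/test_inference_pytest.py::TestChunking -v
+#   pytest tests/test_inference_pytest.py::TestAnswerStartToken -v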