diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 0000000..bb84bc8
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,43 @@
+name: Deploy Documentation
+
+on:
+ push:
+ branches: [main]
+ paths:
+ - 'docs/**'
+ - 'mkdocs.yml'
+ workflow_dispatch:
+
+permissions:
+ contents: read
+ pages: write
+ id-token: write
+
+concurrency:
+ group: pages
+ cancel-in-progress: false
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ - run: pip install mkdocs-material "mkdocstrings[python]>=0.24"
+ - run: pip install -e .
+ - run: mkdocs build
+ - uses: actions/upload-pages-artifact@v3
+ with:
+ path: site/
+
+ deploy:
+ needs: build
+ runs-on: ubuntu-latest
+ environment:
+ name: github-pages
+ url: ${{ steps.deployment.outputs.page_url }}
+ steps:
+ - id: deployment
+ uses: actions/deploy-pages@v4
diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index a1f1c70..9b8dbff 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -1,8 +1,10 @@
name: Python Tests
on:
+ push:
+ branches: [main]
pull_request:
- branches: [ main ]
+ branches: [main]
jobs:
test:
@@ -12,24 +14,24 @@ jobs:
python-version: ["3.12"]
steps:
- - uses: actions/checkout@v3
-
+ - uses: actions/checkout@v4
+
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
-
+
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
-
+
- name: Lint with ruff
run: |
- ruff format --check --diff .
- ruff check --select I .
-
+ ruff format --check --diff lettucedetect/ tests/
+ ruff check lettucedetect/ tests/ --extend-exclude lettucedetect/integrations/
+
- name: Test with pytest
run: |
- pytest tests/test_inference_pytest.py -v
\ No newline at end of file
+ pytest tests/test_inference_pytest.py -v
diff --git a/.gitignore b/.gitignore
index b81c6d0..3c2f5ce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,6 +170,9 @@ cython_debug/
# PyPI configuration file
.pypirc
+# macOS
+.DS_Store
+
# data/
data/
@@ -178,4 +181,5 @@ output/
temp/
# cache/
-lettucedetect/cache/
\ No newline at end of file
+lettucedetect/cache/
+testing/
\ No newline at end of file
diff --git a/docs/EUROBERT.md b/docs/EUROBERT.md
index 0380d8c..3249382 100644
--- a/docs/EUROBERT.md
+++ b/docs/EUROBERT.md
@@ -1,7 +1,7 @@
 # 🥬 LettuceDetect Goes Multilingual: Fine-tuning EuroBERT on Synthetic RAGTruth Translations
-
+
Expanding hallucination detection across languages for RAG pipelines.
diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md
deleted file mode 100644
index c6b8cdb..0000000
--- a/docs/EVALUATION.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# Evaluation
-
-## Use LLM baselines
-
-```bash
-python scripts/evaluate_llm.py --model "gpt-4o-mini" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level"
-```
-
-## Use HallucinationDetector
-
-```bash
-python scripts/evaluate.py --model_path "output/hallucination_detection_de_210m" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level"
-```
diff --git a/docs/api/datasets.md b/docs/api/datasets.md
new file mode 100644
index 0000000..6aae94d
--- /dev/null
+++ b/docs/api/datasets.md
@@ -0,0 +1,13 @@
+# Datasets
+
+## HallucinationSample
+
+::: lettucedetect.datasets.hallucination_dataset.HallucinationSample
+
+## HallucinationData
+
+::: lettucedetect.datasets.hallucination_dataset.HallucinationData
+
+## HallucinationDataset
+
+::: lettucedetect.datasets.hallucination_dataset.HallucinationDataset
diff --git a/docs/api/detectors.md b/docs/api/detectors.md
new file mode 100644
index 0000000..da56499
--- /dev/null
+++ b/docs/api/detectors.md
@@ -0,0 +1,13 @@
+# Detectors
+
+## Factory
+
+::: lettucedetect.detectors.factory.make_detector
+
+## Base Detector
+
+::: lettucedetect.detectors.base.BaseDetector
+
+## Transformer Detector
+
+::: lettucedetect.detectors.transformer.TransformerDetector
diff --git a/docs/api/inference.md b/docs/api/inference.md
new file mode 100644
index 0000000..14a827e
--- /dev/null
+++ b/docs/api/inference.md
@@ -0,0 +1,5 @@
+# Inference
+
+The main entry point for hallucination detection.
+
+::: lettucedetect.models.inference.HallucinationDetector
diff --git a/docs/api/training.md b/docs/api/training.md
new file mode 100644
index 0000000..91a935c
--- /dev/null
+++ b/docs/api/training.md
@@ -0,0 +1,9 @@
+# Training
+
+## Trainer
+
+::: lettucedetect.models.trainer.Trainer
+
+## Evaluator
+
+::: lettucedetect.models.evaluator
diff --git a/docs/benchmarks.md b/docs/benchmarks.md
new file mode 100644
index 0000000..87e4e97
--- /dev/null
+++ b/docs/benchmarks.md
@@ -0,0 +1,43 @@
+# Benchmarks
+
+## RAGTruth (English)
+
+Evaluated on the [RAGTruth](https://aclanthology.org/2024.acl-long.585/) test set. This benchmark measures how well models detect hallucinations in LLM-generated text across QA, summarization, and data-to-text tasks.
+
+### Example-Level Detection
+
+Binary classification: does the answer contain any hallucination?
+
+| Model | Type | Overall F1 |
+|-------|------|-----------|
+| GPT-4 | LLM (zero-shot) | 63.4% |
+| Luna | Encoder | 65.4% |
+| **lettucedetect-base-v1** | **Encoder (149M)** | **76.8%** |
+| Llama-2-13B (fine-tuned) | LLM | 78.7% |
+| **lettucedetect-large-v1** | **Encoder (395M)** | **79.2%** |
+| RAG-HAT (Llama-3-8B) | LLM | 83.9% |
+
+### Span-Level Detection
+
+LettuceDetect achieves state-of-the-art span-level results among models that report this metric, outperforming fine-tuned Llama-2-13B. Span-level evaluation measures how precisely the model can locate the exact hallucinated text within an answer.
+
+## What These Numbers Mean
+
+- **Example F1** – Can the model tell if an answer has *any* hallucination? Higher is better.
+- **Span F1** – Can the model point to *exactly which parts* are hallucinated? This is the harder task and where LettuceDetect excels relative to its size.
+
+LettuceDetect models are 50-500x smaller than LLM-based detectors while achieving competitive or better accuracy.
+
+## Citation
+
+```bibtex
+@misc{Kovacs:2025,
+ title={LettuceDetect: A Hallucination Detection Framework for RAG Applications},
+  author={Ádám Kovács and Gábor Recski},
+ year={2025},
+ eprint={2502.17125},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2502.17125},
+}
+```
diff --git a/docs/code-hallucination/architecture-research.md b/docs/code-hallucination/architecture-research.md
new file mode 100644
index 0000000..232114a
--- /dev/null
+++ b/docs/code-hallucination/architecture-research.md
@@ -0,0 +1,154 @@
+# Architecture Research: Detection Models for Code Hallucination
+
+Research notes on model architectures for training on the code hallucination dataset. We compare four approaches ranging from fast encoder-based classifiers to generative span detectors.
+
+## Approach A: Token Classification (Encoder)
+
+**Architecture:** ModernBERT/EuroBERT + linear classification head
+
+The current LettuceDetect approach. Each answer token gets a binary label (0=supported, 1=hallucinated). Consecutive hallucinated tokens are merged into spans at inference.
+
+```
+Input: [CLS] context [SEP] question [SEP] answer [SEP]
+Output: [-100, -100, ..., 0, 0, 1, 1, 1, 0, 0, ...]
+ ^^^^^^^^^ hallucinated span
+```
+
+| Property | Value |
+|----------|-------|
+| **Models** | ModernBERT-base (149M), ModernBERT-large (395M), EuroBERT (210M-2.1B) |
+| **Context** | 8K tokens |
+| **Inference** | Single forward pass, 30-60 samples/sec on A100 |
+| **Training** | Standard token classification, CrossEntropyLoss |
+| **Validated by** | LettuceDetect (79.2% F1), HaluGate (vLLM), PsiloQA (EMNLP 2025) |
+
+**Strengths:** Fast, simple, production-ready. Handles long contiguous spans well.
+**Weaknesses:** No code-specific pretraining. Cannot explain *why* something is hallucinated.
+
+---
+
+## Approach B: Token Classification (Decoder LLM)
+
+**Architecture:** Qwen3.5-2B + bidirectional attention (LLM2Vec) + linear head
+
+Use a decoder LLM pretrained on massive code corpora, convert to bidirectional encoder via [LLM2Vec](https://arxiv.org/abs/2404.05961), then add a token classification head.
+
+```
+Step 1: Load Qwen3.5-2B base (2B params, code-heavy pretraining)
+Step 2: Enable bidirectional attention (remove causal mask)
+Step 3: Short MNTP adaptation (masked next token prediction with LoRA)
+Step 4: Add linear head (hidden_dim=2048 → 2 classes)
+Step 5: Fine-tune on code hallucination dataset with LoRA
+```
+
+| Property | Value |
+|----------|-------|
+| **Model** | Qwen3.5-2B (2B params) |
+| **Context** | 262K native (practically limited by GPU memory) |
+| **Inference** | Single forward pass, ~5-15 samples/sec |
+| **VRAM** | ~5-8GB in bf16 |
+| **Reference** | [Looking Right is Sometimes Right (ACL 2024)](https://arxiv.org/abs/2401.14556) – 0.947 F1 on NER with mask removal |
+
+**Strengths:** Deep code understanding from pretraining. Bidirectional attention after conversion.
+**Weaknesses:** 5x larger than ModernBERT. Requires LLM2Vec conversion step. Novel (unvalidated for hallucination detection).
+
+**Key insight:** The [ACL 2024 paper](https://arxiv.org/abs/2401.14556) showed decoder LLMs with causal mask removal reach 0.947 F1 on NER, significantly above RoBERTa-large (0.900). The gains come from combining rich pretrained representations with bidirectional context.
+
+---
+
+## Approach C: Chunk Verification (Reranker-style)
+
+**Architecture:** Qwen3.5-2B or Qwen3-0.6B, reranker-style yes/no scoring
+
+Inspired by [Qwen3-Reranker](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B). Split the answer into chunks (lines, statements), then ask the model for each chunk: "Is this code correct given the context?"
+
+```
+Input: "Given this source code, is this line correct? yes/no"
+Output: P(yes) = 0.12 → hallucinated
+        P(yes) = 0.95 → supported
+```
+
+No architectural modifications. Uses the LLM's native next-token prediction to classify.
+
+| Property | Value |
+|----------|-------|
+| **Models** | Qwen3-0.6B (tiny, fast) or Qwen3.5-2B |
+| **Inference** | N forward passes per sample (one per chunk) |
+| **Training** | Standard SFT with yes/no labels |
+| **Reference** | [MiniCheck (EMNLP 2024)](https://arxiv.org/abs/2404.10774) – GPT-4-level at 400x lower cost |
+
+**Strengths:** No architecture changes. Uses LLM code reasoning directly. Can work with tiny models.
+**Weaknesses:** Slowest inference (N passes per sample). Chunk boundary sensitivity. No sub-chunk granularity.
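+
+A sketch of the scoring step under this approach (model name and prompt wording are placeholders; any causal LM with "yes"/"no" tokens works the same way):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_name = "Qwen/Qwen3-0.6B"  # placeholder
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+
+def p_yes(context: str, chunk: str) -> float:
+    """Probability that the model answers 'yes' to 'is this chunk correct given the context?'."""
+    prompt = (
+        f"Source code:\n{context}\n\n"
+        f"Candidate line:\n{chunk}\n\n"
+        "Is this line correct given the source code? Answer yes or no:"
+    )
+    ids = tokenizer(prompt, return_tensors="pt").input_ids
+    with torch.no_grad():
+        next_token_logits = model(ids).logits[0, -1]
+    yes_id = tokenizer(" yes", add_special_tokens=False).input_ids[0]
+    no_id = tokenizer(" no", add_special_tokens=False).input_ids[0]
+    probs = torch.softmax(next_token_logits[[yes_id, no_id]], dim=-1)
+    return probs[0].item()
+
+# Chunks with low P(yes) are flagged as hallucinated; the threshold is tuned on dev data.
+print(p_yes("def add(a, b):\n    return a + b", "result = add(1, 2, 3)"))
+```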
+
+---
+
+## Approach D: Generative Span Detection
+
+**Architecture:** Qwen3.5-2B, standard SFT, generates JSON with hallucinated spans
+
+The model directly outputs which spans are hallucinated and why. This is the reverse of the hallucination injection process.
+
+```
+Input: "Given the source code and answer, identify hallucinated spans."
+Output: {
+ "hallucinated_spans": [
+ {"text": "response.json_decode()", "explanation": "method is json(), not json_decode()"}
+ ]
+}
+```
+
+| Property | Value |
+|----------|-------|
+| **Models** | Qwen3.5-2B or larger |
+| **Inference** | Single generation (autoregressive, slower than forward pass) |
+| **Training** | Standard SFT with LoRA |
+| **SOTA** | [RL4HS (Oct 2025)](https://arxiv.org/abs/2510.02173) – 58.3 F1 on RAGTruth, beats GPT-5 (42.2) and o3 (51.2) |
+
+**Strengths:**
+
+- No architecture changes – pure text generation
+- Free explanations alongside span detection
+- Naturally handles variable span counts
+- Can leverage the LLM's code knowledge ("this API doesn't exist")
+- Training data format already matches (reverse of injection pipeline)
+- Current SOTA approach (RL4HS)
+
+**Weaknesses:** Autoregressive generation is slower. Risk of hallucinating in the detector itself. String matching needed to map spans back to character offsets.
+
+**RL enhancement:** [RL4HS](https://arxiv.org/abs/2510.02173) shows that adding reinforcement learning (GRPO with span-level rewards) on top of SFT dramatically improves performance. SFT alone is a strong baseline; RL pushes it to SOTA.
+
+---
+
+## Comparison
+
+| | A. Encoder token | B. LLM token | C. Chunk verifier | D. Generative span |
+|---|---|---|---|---|
+| **Base model** | ModernBERT-large | Qwen3.5-2B | Qwen3-0.6B | Qwen3.5-2B |
+| **Parameters** | 395M | 2B | 0.6B | 2B |
+| **Architecture mods** | None | Mask removal | None | None |
+| **Inference speed** | Fastest | Medium | Slowest | Medium-slow |
+| **Explainable** | No | No | No | Yes |
+| **Code understanding** | Limited | Deep | Deep | Deep |
+| **Training complexity** | Simple | LLM2Vec + LoRA | Simple SFT | Simple SFT |
+| **SOTA reference** | LettuceDetect, HaluGate | ACL 2024 paper | MiniCheck | RL4HS |
+
+## Recommended Experiments
+
+1. **A vs D** – Token classification (ModernBERT) vs generative span detection (Qwen3.5-2B). The core comparison: fast encoder vs reasoning LLM, both trained on the same dataset.
+
+2. **A vs B** – Does code pretraining help token classification? Same task, different backbone.
+
+3. **D with RL** – If SFT results are promising, add GRPO with span-overlap rewards (following RL4HS).
+
+## Key References
+
+- [LettuceDetect (arXiv:2502.17125)](https://arxiv.org/abs/2502.17125) – Encoder token classification baseline
+- [HaluGate (vLLM, Dec 2025)](https://blog.vllm.ai/2025/12/14/halugate.html) – Production ModernBERT + NLI pipeline
+- [RL4HS (arXiv:2510.02173)](https://arxiv.org/abs/2510.02173) – SOTA generative span detection with RL
+- [FAVA (COLM 2024)](https://arxiv.org/abs/2401.06855) – Generative hallucination editing
+- [PsiloQA (EMNLP 2025)](https://arxiv.org/abs/2510.04849) – Multilingual encoder-based span detection
+- [Looking Right is Sometimes Right (ACL 2024)](https://arxiv.org/abs/2401.14556) – Decoder LLMs for token classification
+- [LLM2Vec (2024)](https://arxiv.org/abs/2404.05961) – Converting decoders to bidirectional encoders
+- [MiniCheck (EMNLP 2024)](https://arxiv.org/abs/2404.10774) – Sentence-level fact checking
+- [Qwen3-Reranker](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B) – LLM-based yes/no classification
+- [CodeMirage (2024)](https://arxiv.org/abs/2408.08333) – Code hallucination taxonomy (snippet-level only)
diff --git a/docs/code-hallucination/configuration.md b/docs/code-hallucination/configuration.md
new file mode 100644
index 0000000..e77e073
--- /dev/null
+++ b/docs/code-hallucination/configuration.md
@@ -0,0 +1,68 @@
+# Configuration
+
+All pipeline configuration is centralized in `scripts/code_hallucination/config.py`.
+
+## Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `OPENAI_API_KEY` | (none) | API key for the LLM provider |
+| `API_BASE_URL` | `https://api.groq.com/openai/v1` | OpenAI-compatible API endpoint |
+| `MODEL` | `moonshotai/kimi-k2-instruct-0905` | Model name |
+| `BATCH_SIZE` | `1` | Concurrent requests. Set >1 for local vLLM to saturate GPU |
+| `CONTEXT7_API_KEY` | (none) | API key for Context7 documentation service |
+
+These can also be overridden via CLI flags (`--api-key`, `--base-url`, `--model`).
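+
+The precedence is the usual one: CLI flag over environment variable over built-in default. A sketch of that pattern (illustrative only; the actual `config.py` may be organized differently):
+
+```python
+import os
+
+# Defaults mirror the table above.
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
+MODEL = os.environ.get("MODEL", "moonshotai/kimi-k2-instruct-0905")
+BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
+
+def resolve_base_url(cli_value: str | None) -> str:
+    """--base-url wins over API_BASE_URL, which wins over the default."""
+    return cli_value if cli_value is not None else API_BASE_URL
+```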
+
+## Dataset Parameters
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `HALLUCINATION_RATIO` | `0.4` | Fraction of instances that get hallucination injection |
+| `DOCS_RATIO` | `0.5` | Fraction of instances that get Context7 documentation |
+| `MAX_FILE_CHARS` | `12000` | Maximum characters per source file |
+| `MAX_CONTEXT7_CHARS` | `4000` | Maximum characters per library doc |
+| `LLM_TEMPERATURE` | `0.7` | Temperature for query rewriting |
+| `HALLUCINATION_TEMPERATURE` | `0.8` | Temperature for hallucination injection (higher for variety) |
+| `MAX_RETRIES` | `3` | API retry attempts |
+| `RETRY_DELAY` | `2.0` | Base delay between retries (seconds) |
+
+## Answer Format Weights
+
+| Format | Weight | Description |
+|--------|--------|-------------|
+| `complete_function` | 0.4 | Full patched function body via AST |
+| `edit_style` | 0.3 | "In file X, replace Y with Z" |
+| `fragment` | 0.3 | Added/changed lines from diff |
+
+## Hallucination Types
+
+Assigned round-robin across injected instances:
+
+- **structural** – Non-existent APIs, wrong methods, invented parameters
+- **behavioral** – Wrong values, logic errors, swapped conditions
+- **semantic** – Code that looks correct but does something subtly different
+
+## File Paths
+
+All data is stored under `data/code_hallucination/`:
+
+| Path | Description |
+|------|-------------|
+| `swebench_instances.json` | Phase 1: loaded instances |
+| `repos/` | Phase 2: bare git clones |
+| `source_cache/` | Phase 2: per-instance source data |
+| `queries.jsonl` | Phase 3: rewritten queries |
+| `documentation.jsonl` | Phase 4: library docs |
+| `formats.jsonl` | Phase 5: assigned formats |
+| `hallucinated_samples.jsonl` | Phase 6: injected hallucinations |
+| `code_hallucination_data.json` | Phase 7: final dataset |
+| `code_hallucination_metadata.json` | Phase 7: metadata |
+| `validation_report.txt` | Phase 9: quality report |
+
+## Data Sources
+
+| Source | Dataset ID |
+|--------|-----------|
+| SWE-bench (full) | `princeton-nlp/SWE-bench` |
+| SWE-bench Lite | `princeton-nlp/SWE-bench_Lite` |
diff --git a/docs/code-hallucination/index.md b/docs/code-hallucination/index.md
new file mode 100644
index 0000000..4a6f01f
--- /dev/null
+++ b/docs/code-hallucination/index.md
@@ -0,0 +1,294 @@
+# Code Hallucination Dataset Pipeline
+
+A modular 9-phase pipeline for generating span-level code hallucination detection datasets from [SWE-bench](https://www.swebench.com/). Produces training data where each sample contains source code context, a code answer, and character-level annotations marking hallucinated spans.
+
+## Why This Dataset?
+
+Existing hallucination detection datasets (RAGTruth, RAGBench) focus on **text** – question answering, summarization, data-to-text. There is no established **span-level code hallucination dataset**. CodeMirage classifies entire snippets but doesn't localize where the hallucination is.
+
+This pipeline generates samples where an LLM coding assistant answers a developer's question about a real codebase, and we know exactly which character spans in the answer are hallucinated – enabling training of both token-level classifiers (ModernBERT) and generative span detectors (decoder LLMs).
+
+## Dataset Overview
+
+| Property | Value |
+|----------|-------|
+| **Source** | SWE-bench (all splits) |
+| **Total instances** | ~21,500 (19k train + 225 dev + 2.3k test) |
+| **Repos** | 53 unique repos, zero overlap between splits |
+| **Clean/hallucinated ratio** | ~60% clean / ~40% hallucinated |
+| **Hallucination types** | Structural, behavioral, semantic |
+| **Answer formats** | Complete function, edit-style, code fragment |
+| **Annotation granularity** | Character-level spans |
+
+## Quick Start
+
+### Test with a few examples
+
+```bash
+# Using Groq (fast, free tier available)
+OPENAI_API_KEY=your_groq_key \
+ python -m scripts.code_hallucination.pipeline --test 5
+
+# Using any OpenAI-compatible API
+OPENAI_API_KEY=your_key \
+API_BASE_URL=https://api.example.com/v1 \
+MODEL=your-model-name \
+ python -m scripts.code_hallucination.pipeline --test 10
+```
+
+### Run the full pipeline
+
+```bash
+# Run all 9 phases
+python -m scripts.code_hallucination.pipeline --all
+
+# Run specific phases
+python -m scripts.code_hallucination.pipeline --phase 1 2 3
+
+# Override LLM settings via CLI
+python -m scripts.code_hallucination.pipeline --all \
+ --api-key YOUR_KEY \
+ --base-url https://api.groq.com/openai/v1 \
+ --model moonshotai/kimi-k2-instruct-0905
+```
+
+### Run with local vLLM (recommended for bulk generation)
+
+```bash
+# Terminal 1: Start vLLM
+vllm serve Qwen/Qwen3.5-2B --port 8000
+
+# Terminal 2: Run pipeline with batch processing
+BATCH_SIZE=16 \
+API_BASE_URL=http://localhost:8000/v1 \
+OPENAI_API_KEY=dummy \
+MODEL=Qwen/Qwen3.5-2B \
+ python -m scripts.code_hallucination.pipeline --all
+```
+
+`BATCH_SIZE>1` enables async concurrent requests – no rate limiting, full GPU saturation.
+
+### CLI Options
+
+| Flag | Description |
+|------|-------------|
+| `--test N` | Test mode: run pipeline on N random test instances using GitHub API (no repo cloning) |
+| `--all` | Run all 9 phases |
+| `--phase 1 2 3` | Run specific phases (1-9) |
+| `--api-key` | LLM API key (or set `OPENAI_API_KEY` env var) |
+| `--base-url` | LLM API base URL (or set `API_BASE_URL` env var) |
+| `--model` | LLM model name (or set `MODEL` env var) |
+
+| Environment Variable | Description |
+|---------------------|-------------|
+| `BATCH_SIZE` | Number of concurrent requests (default: 1). Set >1 for local vLLM |
+| `OPENAI_API_KEY` | API key for the LLM provider |
+| `API_BASE_URL` | OpenAI-compatible API endpoint |
+| `MODEL` | Model name |
+| `CONTEXT7_API_KEY` | API key for Context7 documentation service |
+
+## Pipeline Architecture
+
+```mermaid
+graph TD
+ A[Phase 1: Load SWE-bench] --> B[Phase 2: Fetch Sources]
+ B --> C[Phase 3: Rewrite Queries]
+ B --> D[Phase 4: Fetch Docs]
+ B --> E[Phase 5: Assign Formats]
+ A --> F[Phase 8: Select Targets]
+ C --> G[Phase 6: Inject Hallucinations]
+ E --> G
+ F --> G
+ D --> H[Phase 7: Assemble Samples]
+ G --> H
+ E --> H
+ C --> H
+ H --> I[Phase 9: Validate]
+```
+
+Phases 3 and 4 can run in parallel. Phase 8 (target selection) runs before Phase 6 (injection).
+
+## Output Format
+
+Each sample follows the `HallucinationSample` format used by LettuceDetect:
+
+```json
+{
+ "prompt": "File: django/http/response.py\n```python\n...\n```\n\nUser request: How do I fix the Content-Type header issue?",
+ "answer": "def fix_content_type(self):\n self.headers['Content-Type'] = response.get_type()\n ...",
+ "labels": [
+ {"start": 55, "end": 75, "label": "structural"}
+ ],
+ "split": "test",
+ "task_type": "code_generation",
+ "dataset": "swebench_code",
+ "language": "en"
+}
+```
+
+- **`prompt`**: Source code files + documentation + user query
+- **`answer`**: Code in one of three formats (complete function, edit-style, fragment)
+- **`labels`**: Character-level span annotations (empty for clean samples)
+- **`split`**: train/dev/test (inherited from SWE-bench, zero repo overlap)
+
+## Supported LLM Providers
+
+The pipeline works with any OpenAI-compatible API. Tested with:
+
+| Provider | Model | Notes |
+|----------|-------|-------|
+| [Groq](https://groq.com) | `moonshotai/kimi-k2-instruct-0905` | Fast, free tier |
+| [Groq](https://groq.com) | `llama-3.3-70b-versatile` | Good quality |
+| [Novita AI](https://novita.ai) | `qwen/qwen3.5-27b` | Good for bulk generation |
+| Local (vLLM/Ollama) | Any model | Free, best for large runs |
+
+## Directory Structure
+
+```
+scripts/code_hallucination/
+├── __init__.py                 # Package declaration
+├── pipeline.py                 # CLI orchestrator
+├── config.py                   # Constants, paths, API settings
+├── swebench_loader.py          # Phase 1: Load SWE-bench instances
+├── source_fetcher.py           # Phase 2: Clone repos, fetch source files
+├── query_rewriter.py           # Phase 3: LLM query rewriting
+├── context7_docs.py            # Phase 4: Fetch library documentation
+├── format_builder.py           # Phase 5: Assign answer formats
+├── hallucination_injector.py   # Phase 6: LLM hallucination injection
+├── sample_assembler.py         # Phase 7: Assemble final samples
+├── splitter.py                 # Phase 8: Select hallucination targets
+└── validator.py                # Phase 9: Quality validation
+```
+
+Output data:
+
+```
+data/code_hallucination/
+├── swebench_instances.json            # Phase 1 output
+├── repos/                             # Phase 2: cloned repos (bare)
+├── source_cache/                      # Phase 2: per-instance source data
+│   └── {instance_id}.json
+├── queries.jsonl                      # Phase 3 output
+├── documentation.jsonl                # Phase 4 output
+├── formats.jsonl                      # Phase 5 output
+├── hallucinated_samples.jsonl         # Phase 6 output
+├── code_hallucination_data.json       # Phase 7: final dataset
+├── code_hallucination_metadata.json   # Phase 7: metadata
+└── validation_report.txt              # Phase 9 output
+```
+
+## Design Decisions
+
+### One sample per instance
+Each SWE-bench instance produces exactly one sample – either clean (gold patch answer) or hallucinated (LLM-injected). No instance appears in both classes. This avoids the artificial pairing problem where models learn to distinguish the specific instance rather than the hallucination.
+
+### JSON-based span annotations
+Hallucination spans are extracted from the LLM's structured JSON response, not from difflib character-level diffs. The LLM returns `{"hallucinated_code": "...", "changes": [{"original": "...", "hallucinated": "..."}]}` and spans are found by string matching. This produces clean, meaningful spans (avg 70 chars) instead of noisy character-level artifacts (1-3 char noise from difflib).
+
+### 50/50 documentation split
+Half of instances include Context7 library documentation, half don't. This teaches models to handle both documented and undocumented scenarios.
+
+### Zero repo overlap between splits
+SWE-bench's train/dev/test splits naturally have zero repository overlap across 53 unique repos. This means test performance measures generalization to completely unseen codebases.
+
+## Training on This Dataset
+
+Once you've generated the dataset, there are two training approaches. See [Architecture Research](architecture-research.md) for the full rationale behind each.
+
+### Approach A: Token Classification (ModernBERT)
+
+The standard LettuceDetect approach – a lightweight encoder that labels each token as supported or hallucinated.
+
+```bash
+# Code data only
+python scripts/train_code_hallucination.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/code_hallucination_detector \
+ --batch-size 4 \
+ --epochs 6 \
+ --learning-rate 1e-5
+
+# Code + RAGTruth combined (better generalization across text and code)
+python scripts/train_code_hallucination.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --ragtruth-path data/ragtruth/ragtruth_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/code_hallucination_detector \
+ --batch-size 4 \
+ --epochs 6
+```
+
+The training script uses SWE-bench splits directly – train for training, dev for validation, test held out. Zero repository overlap between splits.
+
+### Approach D: Generative Span Detection (Qwen SFT)
+
+Fine-tune a decoder LLM to read context + answer and generate a JSON list of hallucinated spans with explanations. This is the reverse of the injection pipeline.
+
+```bash
+# Requires: pip install peft
+python scripts/train_generative_detector.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name Qwen/Qwen3.5-2B \
+ --output-dir output/generative_detector \
+ --batch-size 2 \
+ --epochs 3 \
+ --lora-r 16
+```
+
+The model learns to output:
+
+```json
+{"hallucinated_spans": [{"text": "response.json_decode()", "explanation": "method is json(), not json_decode()"}]}
+```
+
+For clean samples it outputs `{"hallucinated_spans": []}`.
+
+Training uses [LoRA](https://arxiv.org/abs/2106.09685) for memory efficiency (~5-8GB VRAM). Only the model's response tokens contribute to the loss – context and prompt tokens are masked.
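+
+A sketch of that masking step (simplified; the model name follows the example above and the actual training script may differ):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-2B")  # placeholder causal LM
+
+def build_sft_example(prompt: str, response: str) -> dict:
+    """Concatenate prompt and response; prompt tokens get label -100 so they carry no loss."""
+    prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids
+    response_ids = tokenizer(response, add_special_tokens=False).input_ids
+    input_ids = prompt_ids + response_ids + [tokenizer.eos_token_id]
+    labels = [-100] * len(prompt_ids) + response_ids + [tokenizer.eos_token_id]
+    return {"input_ids": input_ids, "labels": labels}
+
+example = build_sft_example(
+    "Identify hallucinated spans in the answer below.\n...\nJSON:",
+    '{"hallucinated_spans": []}',
+)
+```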
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `--model-name` | `Qwen/Qwen3.5-2B` | Any HuggingFace causal LM |
+| `--lora-r` | 16 | LoRA rank (higher = more capacity, more memory) |
+| `--lora-alpha` | 32 | LoRA scaling factor |
+| `--batch-size` | 2 | Training batch size |
+| `--epochs` | 3 | Number of training epochs |
+| `--learning-rate` | 2e-4 | Learning rate |
+| `--gradient-accumulation-steps` | 4 | Accumulate gradients over N steps (simulates larger batch) |
+
+### Evaluate
+
+```bash
+python scripts/evaluate_code_hallucination.py \
+ --model_path output/code_hallucination_detector \
+ --data_path data/code_hallucination/code_hallucination_data.json \
+ --evaluation_type example_level
+```
+
+## End-to-End Example
+
+```bash
+# 1. Generate dataset (local vLLM for speed)
+vllm serve Qwen/Qwen3.5-2B --port 8000 # in another terminal
+
+BATCH_SIZE=16 \
+API_BASE_URL=http://localhost:8000/v1 \
+OPENAI_API_KEY=dummy \
+MODEL=Qwen/Qwen3.5-2B \
+ python -m scripts.code_hallucination.pipeline --all
+
+# 2. Train
+python scripts/train_code_hallucination.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/code_hallucination_detector \
+ --batch-size 4 --epochs 6
+
+# 3. Evaluate
+python scripts/evaluate_code_hallucination.py \
+ --model_path output/code_hallucination_detector \
+ --data_path data/code_hallucination/code_hallucination_data.json \
+ --evaluation_type example_level
+```
+
+The pipeline is **fully resumable** – every slow phase saves results incrementally to JSONL. If it crashes, re-run the same command and it picks up where it left off.
diff --git a/docs/code-hallucination/phases.md b/docs/code-hallucination/phases.md
new file mode 100644
index 0000000..6e93da0
--- /dev/null
+++ b/docs/code-hallucination/phases.md
@@ -0,0 +1,272 @@
+# Pipeline Phases
+
+Detailed documentation for each of the 9 pipeline phases.
+
+## Phase 1: Load SWE-bench
+
+**Module:** `swebench_loader.py`
+
+Loads all SWE-bench splits from HuggingFace and tags each instance with split and Lite membership.
+
+| Split | Instances | Repos | Source |
+|-------|-----------|-------|--------|
+| Train | 19,008 | 35 | `princeton-nlp/SWE-bench` train |
+| Dev | 225 | 6 | `princeton-nlp/SWE-bench` dev |
+| Test | 2,294 | 12 | `princeton-nlp/SWE-bench` test |
+| Lite | 300 | 12 | `princeton-nlp/SWE-bench_Lite` (subset of test) |
+
+**Key function:** `load_all_splits() -> list[dict]`
+
+Each instance includes: `instance_id`, `repo`, `base_commit`, `patch`, `problem_statement`, `split`, `is_lite`.
+
+**Output:** `data/code_hallucination/swebench_instances.json`
+
+---
+
+## Phase 2: Fetch Sources
+
+**Module:** `source_fetcher.py`
+
+Clones repositories and extracts source code at the base commit for each instance. Builds three answer format variants.
+
+### Strategy
+
+- **Default:** Clone repos as bare git repos to `data/code_hallucination/repos/`. Use `git show {commit}:{path}` for instant file access.
+- **Test mode:** Use GitHub raw API (`raw.githubusercontent.com`) – slower but no cloning needed.
+- **Fallback:** If cloning fails, automatically falls back to GitHub API.
+
+### What it extracts per instance
+
+| Field | Description |
+|-------|-------------|
+| `changed_files` | File paths modified by the gold patch |
+| `source_files` | Original source code at base commit |
+| `patch_code` | Added/changed lines from the diff (fragment format) |
+| `edit_style` | "In file X, replace Y with Z" format |
+| `modified_functions` | AST-extracted functions that changed (complete function format) |
+
+### Key functions
+
+- `extract_changed_files(patch)` – Parse unified diff for file paths (anchored regex, not `lstrip("b/")`; see the sketch below)
+- `clone_repo(repo)` – `git clone --bare` with 30min timeout
+- `fetch_file_at_commit(repo_dir, commit, filepath)` – `git show` for file contents
+- `apply_patch_and_get_file(repo_dir, commit, patch, filepath)` – Apply patch in temp worktree
+- `extract_modified_functions(original, patched)` – AST-based function diff
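+
+A sketch of the anchored-regex idea behind `extract_changed_files` (the real implementation may differ in details):
+
+```python
+import re
+
+# Match only real diff headers ("+++ b/<path>"), never arbitrary lines that merely contain "b/".
+_DIFF_HEADER = re.compile(r"^\+\+\+ b/(?P<path>.+)$", re.MULTILINE)
+
+def extract_changed_files(patch: str) -> list[str]:
+    """Return file paths touched by a unified diff, in order of first appearance."""
+    # Note: path.lstrip("b/") would be wrong here, since str.lstrip strips *characters*,
+    # turning a path like "bin/x.py" into "in/x.py".
+    seen: list[str] = []
+    for match in _DIFF_HEADER.finditer(patch):
+        path = match.group("path")
+        if path not in seen:
+            seen.append(path)
+    return seen
+```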
+
+**Output:** `data/code_hallucination/source_cache/{instance_id}.json`
+
+---
+
+## Phase 3: Rewrite Queries
+
+**Module:** `query_rewriter.py`
+
+Transforms raw GitHub issue `problem_statement` fields into natural developer queries using an LLM.
+
+### Example
+
+**Before (raw issue):**
+> BUG: DataFrame.groupby with as_index=False gives wrong result when grouping by single column with duplicate name. Steps to reproduce: ...
+
+**After (rewritten):**
+> I'm getting wrong results when using DataFrame.groupby with as_index=False on a column that has a duplicate name. How do I fix this?
+
+### Prompt strategy
+
+The LLM is instructed to:
+
+- Write conversational, natural language
+- Extract the core technical ask
+- Remove GitHub formatting, reproduction steps, tracebacks
+- Keep to 1-3 sentences
+
+### Resumability
+
+Writes results to JSONL incrementally. On restart, skips already-processed `instance_id`s.
+
+**Output:** `data/code_hallucination/queries.jsonl`
+
+---
+
+## Phase 4: Fetch Documentation
+
+**Module:** `context7_docs.py`
+
+Fetches library documentation via the [Context7](https://context7.com) API for **50% of instances** (configurable via `DOCS_RATIO`).
+
+### Library detection
+
+Detects libraries from:
+1. Import statements in the patch (`import django`, `from sklearn import ...`)
+2. File paths (`django/http/response.py` → django)
+
+Maps to Context7 library names via a predefined dictionary.
+
+### 50/50 split rationale
+
+Half of samples include documentation context, half don't. This creates training variety – models learn to detect hallucinations both with and without documentation support.
+
+Instances not selected for docs still get an entry written with empty docs (by design, not failure).
+
+**Output:** `data/code_hallucination/documentation.jsonl`
+
+---
+
+## Phase 5: Assign Answer Formats
+
+**Module:** `format_builder.py`
+
+Each instance gets exactly one answer format, chosen by weighted random selection from available options.
+
+### Format types
+
+**Complete function** (weight: 0.4)
+```python
+def validate_response(self, response):
+ if response.status_code != 200:
+ raise ValidationError(f"Unexpected status: {response.status_code}")
+ return response.json()
+```
+Extracted via Python AST from the patched source. Only available when changes are inside a function (~60% of patches).
+
+**Edit-style** (weight: 0.3)
+```
+In file django/http/response.py, replace:
+ def set_cookie(self, key, value=""):
+ self.cookies[key] = value
+with:
+ def set_cookie(self, key, value="", max_age=None):
+ self.cookies[key] = value
+ if max_age is not None:
+ self.cookies[key]["max-age"] = max_age
+```
+Available for all patches where changed regions can be extracted.
+
+**Fragment** (weight: 0.3)
+```python
+if max_age is not None:
+ self.cookies[key]["max-age"] = max_age
+ self.cookies[key]["expires"] = http_date(time.time() + max_age)
+```
+Added/changed lines from the diff with surrounding context.
+
+**Output:** `data/code_hallucination/formats.jsonl`
+
+---
+
+## Phase 6: Inject Hallucinations
+
+**Module:** `hallucination_injector.py`
+
+Uses an LLM to inject realistic hallucinations into selected instances (determined by Phase 8). Returns structured JSON with span annotations.
+
+### Hallucination types (round-robin)
+
+| Type | Description | Example |
+|------|-------------|---------|
+| **Structural** | Non-existent APIs, wrong methods, invented parameters | `response.json_decode()` instead of `response.json()` |
+| **Behavioral** | Wrong values, logic errors, off-by-one, swapped conditions | `if status >= 200` instead of `if status == 200` |
+| **Semantic** | Code that looks right but does something subtly different | Sorting ascending instead of descending |
+
+### JSON-based span extraction
+
+The LLM returns structured output:
+
+```json
+{
+ "hallucinated_code": "def fix(self):\n self.data = response.json_decode()\n ...",
+ "changes": [
+ {
+ "original": "response.json()",
+ "hallucinated": "response.json_decode()",
+ "explanation": "json_decode() is not a valid method on Response objects"
+ }
+ ]
+}
+```
+
+Spans are found by string-matching each `change["hallucinated"]` in `hallucinated_code`. This produces clean, meaningful spans (minimum 3 chars) with zero noise.
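+
+A sketch of that matching step (simplified relative to the pipeline's own matcher, which also handles repeated substrings):
+
+```python
+def extract_spans(hallucinated_code: str, changes: list[dict], label: str) -> list[dict]:
+    """Locate each injected change inside the hallucinated answer by exact string match."""
+    spans = []
+    for change in changes:
+        needle = change["hallucinated"]
+        start = hallucinated_code.find(needle)
+        if start == -1 or len(needle) < 3:  # drop unmatched or too-short spans
+            continue
+        spans.append({"start": start, "end": start + len(needle), "label": label})
+    return spans
+
+llm_output = {
+    "hallucinated_code": "def fix(self):\n    self.data = response.json_decode()",
+    "changes": [{"original": "response.json()", "hallucinated": "response.json_decode()"}],
+}
+print(extract_spans(llm_output["hallucinated_code"], llm_output["changes"], "structural"))
+```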
+
+### Quality metrics (from 100-sample test runs)
+
+| Metric | Value |
+|--------|-------|
+| Noise-only samples | 0% |
+| Min span length | 10 chars |
+| Avg span length | 70 chars |
+| Avg spans per sample | 1.2 |
+
+**Output:** `data/code_hallucination/hallucinated_samples.jsonl`
+
+---
+
+## Phase 7: Assemble Samples
+
+**Module:** `sample_assembler.py`
+
+Combines all intermediate data into the final `HallucinationSample` format.
+
+### Prompt construction
+
+````
+File: django/http/response.py
+```python
+{source file contents at the base commit}
+```
+
+Documentation for django:
+{library documentation from Context7, when the instance was selected for docs}
+
+User request: {rewritten developer query}
+````
+
+### Sample types
+
+**Clean samples** (~60%): Gold patch answer, empty labels, from instances NOT selected for injection.
+
+**Hallucinated samples** (~40%): LLM-modified answer with character-level span annotations.
+
+### Outputs
+
+- `data/code_hallucination/code_hallucination_data.json` – List of samples
+- `data/code_hallucination/code_hallucination_metadata.json` – Metadata (instance_id, repo, format_type, hallucination_type, injector_model, is_hallucinated)
+
+---
+
+## Phase 8: Select Hallucination Targets
+
+**Module:** `splitter.py`
+
+Selects which instances receive hallucination injection. Applies the hallucination ratio (default 40%) **uniformly within each split** to maintain consistent class distribution.
+
+```
+Train: ~7,600 hallucinated + ~11,400 clean = ~19,000
+Dev: ~90 hallucinated + ~135 clean = ~225
+Test: ~920 hallucinated + ~1,374 clean = ~2,294
+```
+
+!!! note
+ Phase 8 runs **before** Phase 6 in the pipeline (target selection must happen before injection).
+
+**Output:** Set of `instance_id`s (used in-memory by Phase 6 and Phase 7)
+
+---
+
+## Phase 9: Validate
+
+**Module:** `validator.py`
+
+Runs automated quality checks and generates a report.
+
+### Checks performed
+
+| Check | Description |
+|-------|-------------|
+| **Span validity** | No negative offsets, empty spans, or out-of-bounds |
+| **Span coverage** | Distribution of hallucinated text ratio; flags <2% or >80% |
+| **Distributions** | Format type, hallucination type, injector model, repo, split |
+| **Near-duplicates** | Jaccard similarity >0.95 on sampled answer pairs |
+| **AST parseability** | For complete_function format, checks if answer parses as valid Python |
+| **Length statistics** | Prompt and answer character length ranges |
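+
+Two of these checks are simple enough to sketch (a hypothetical helper, not the validator module itself; `format_type` comes from the metadata file):
+
+```python
+import ast
+
+def check_sample(sample: dict, format_type: str) -> list[str]:
+    """Return problems for one sample: bad span offsets or unparseable complete-function answers."""
+    problems = []
+    answer = sample["answer"]
+    for span in sample["labels"]:
+        if span["start"] < 0 or span["end"] > len(answer) or span["start"] >= span["end"]:
+            problems.append(f"invalid span {span['start']}:{span['end']}")
+    if format_type == "complete_function":
+        try:
+            ast.parse(answer)
+        except SyntaxError as exc:
+            problems.append(f"answer does not parse: {exc.msg}")
+    return problems
+```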
+
+**Output:** `data/code_hallucination/validation_report.txt`
diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md
new file mode 100644
index 0000000..273b669
--- /dev/null
+++ b/docs/getting-started/installation.md
@@ -0,0 +1,42 @@
+# Installation
+
+## From PyPI
+
+```bash
+pip install lettucedetect
+```
+
+## From Source (development)
+
+```bash
+git clone https://github.com/KRLabsOrg/LettuceDetect.git
+cd LettuceDetect
+pip install -e .
+```
+
+## Optional Dependencies
+
+```bash
+# Web API server
+pip install -e ".[api]"
+
+# Development tools (testing, linting)
+pip install -e ".[dev]"
+
+# Documentation site
+pip install -e ".[docs]"
+```
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch >= 2.6.0
+- Transformers >= 4.48.3
+
+## Environment Variables
+
+For LLM-based detection or data generation:
+
+```bash
+export OPENAI_API_KEY=your_api_key
+```
diff --git a/docs/getting-started/models.md b/docs/getting-started/models.md
new file mode 100644
index 0000000..b79af3d
--- /dev/null
+++ b/docs/getting-started/models.md
@@ -0,0 +1,31 @@
+# Models
+
+## English Models
+
+| Model | Base | Max Tokens | Example F1 | Span F1 |
+|-------|------|-----------|-----------|---------|
+| [lettucedetect-base-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-modernbert-en-v1) | ModernBERT-base | 4K | 76.8% | SOTA |
+| [lettucedetect-large-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-large-modernbert-en-v1) | ModernBERT-large | 4K | 79.2% | SOTA |
+
+## Multilingual Models
+
+| Model | Base | Languages | Max Tokens |
+|-------|------|-----------|-----------|
+| [lettucedetect-base-eurobert-multilingual-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1) | EuroBERT-210M | en, de, fr, es, it, pl, cn | 8K |
+
+## TinyLettuce (Distilled)
+
+Smaller models for resource-constrained environments. See [TinyLettuce docs](../TINYLETTUCE.md).
+
+## Using a Model
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-large-modernbert-en-v1"
+)
+```
+
+Models are downloaded automatically from HuggingFace Hub on first use.
diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md
new file mode 100644
index 0000000..906c1d3
--- /dev/null
+++ b/docs/getting-started/quickstart.md
@@ -0,0 +1,65 @@
+# Quick Start
+
+## Detect Hallucinations
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+# Load a pre-trained model
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1"
+)
+
+# Provide context, question, and answer
+contexts = [
+ "France is a country in Europe. The capital of France is Paris. "
+ "The population of France is 67 million."
+]
+question = "What is the capital of France? What is the population?"
+answer = "The capital of France is Paris. The population of France is 69 million."
+
+# Get span-level predictions
+predictions = detector.predict(
+ context=contexts,
+ question=question,
+ answer=answer,
+ output_format="spans"
+)
+print(predictions)
+# [{'start': 31, 'end': 71, 'confidence': 0.99,
+# 'text': ' The population of France is 69 million.'}]
+```
+
+## Available Models
+
+| Model | Language | Context | Size |
+|-------|----------|---------|------|
+| `KRLabsOrg/lettucedetect-base-modernbert-en-v1` | English | 4K | 149M |
+| `KRLabsOrg/lettucedetect-large-modernbert-en-v1` | English | 4K | 395M |
+| `KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1` | 7 languages | 8K | 210M |
+
+See [Models](models.md) for the full list.
+
+## Detection Methods
+
+```python
+# Transformer-based (recommended for production)
+detector = HallucinationDetector(method="transformer", model_path="...")
+
+# LLM-based (uses OpenAI API)
+detector = HallucinationDetector(method="llm", model_path="gpt-4o-mini")
+
+# RAG Fact Checker (triplet-based)
+detector = HallucinationDetector(method="rag_fact_checker", model_path="gpt-4o-mini")
+```
+
+## Output Formats
+
+```python
+# Span-level: exact character ranges of hallucinated text
+predictions = detector.predict(..., output_format="spans")
+
+# Sentence-level: which sentences contain hallucinations
+predictions = detector.predict(..., output_format="sentences")
+```
diff --git a/docs/guide/api.md b/docs/guide/api.md
new file mode 100644
index 0000000..45f1f67
--- /dev/null
+++ b/docs/guide/api.md
@@ -0,0 +1,37 @@
+# Web API
+
+LettuceDetect includes a FastAPI server for HTTP-based hallucination detection.
+
+## Start the Server
+
+```bash
+# Development
+python scripts/start_api.py dev
+
+# Production
+python scripts/start_api.py prod
+
+# Custom model
+python scripts/start_api.py dev --model KRLabsOrg/lettucedetect-large-modernbert-en-v1
+```
+
+## Python Client
+
+```python
+from lettucedetect_api.client import LettuceClient
+
+client = LettuceClient("http://127.0.0.1:8000")
+
+contexts = ["The capital of France is Paris."]
+question = "What is the capital of France?"
+answer = "The capital of France is Berlin."
+
+response = client.detect_spans(contexts, question, answer)
+print(response)
+```
+
+## Installation
+
+```bash
+pip install lettucedetect[api]
+```
diff --git a/docs/guide/detection-methods.md b/docs/guide/detection-methods.md
new file mode 100644
index 0000000..6f75e3c
--- /dev/null
+++ b/docs/guide/detection-methods.md
@@ -0,0 +1,42 @@
+# Detection Methods
+
+LettuceDetect supports three ways to detect hallucinations, each with different trade-offs.
+
+## Transformer (Recommended)
+
+Fine-tuned encoder models that classify each token in the answer as supported or hallucinated. Best balance of speed and accuracy.
+
+```python
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-large-modernbert-en-v1"
+)
+```
+
+**How it works:** The model reads the context, question, and answer together, labels each answer token, then merges consecutive hallucinated tokens into character spans. Detection takes a single forward pass, fast enough for production use (30-60 samples/sec on GPU).
+
+**When to use:** Production systems, latency-sensitive applications, or when you need precise span locations.
+
+## LLM-based
+
+Uses OpenAI-compatible APIs (GPT-4, Claude, etc.) for hallucination detection. No fine-tuning needed.
+
+```python
+detector = HallucinationDetector(method="llm", model_path="gpt-4o-mini")
+```
+
+**How it works:** Sends context + question + answer to the LLM with a prompt requesting hallucination spans in a structured format.
+
+**When to use:** Quick prototyping, or when you want the LLM to explain *why* something is hallucinated.
+
+## RAG Fact Checker
+
+Triplet-based fact checking that breaks the answer into structured claims and verifies each one.
+
+```python
+detector = HallucinationDetector(method="rag_fact_checker", model_path="gpt-4o-mini")
+```
+
+**How it works:** Extracts (subject, predicate, object) claims from the answer, then checks each claim against the context.
+
+**When to use:** When you want claim-level granularity and structured verification results.
diff --git a/docs/guide/evaluation.md b/docs/guide/evaluation.md
new file mode 100644
index 0000000..e59354b
--- /dev/null
+++ b/docs/guide/evaluation.md
@@ -0,0 +1,47 @@
+# Evaluation
+
+LettuceDetect supports three levels of evaluation, from coarse to fine-grained.
+
+## Example-Level
+
+The simplest check: does the answer contain **any** hallucination? This is a binary yes/no classification per answer.
+
+```bash
+python scripts/evaluate.py \
+ --model_path output/hallucination_detector \
+ --data_path data/ragtruth/ragtruth_data.json \
+ --evaluation_type example_level
+```
+
+## Span-Level (Character IoU)
+
+Measures how well predicted hallucination spans overlap with the gold annotations at the character level. This is the most demanding metric – the model must identify not just *that* something is wrong, but *exactly where*.
+
+```bash
+python scripts/evaluate.py \
+ --model_path output/hallucination_detector \
+ --data_path data/ragtruth/ragtruth_data.json \
+ --evaluation_type span_level
+```
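+
+Character-level scoring reduces to set overlap between predicted and gold character positions; a simplified sketch of the idea (the evaluation script may aggregate across samples differently):
+
+```python
+def char_prf(
+    pred_spans: list[tuple[int, int]], gold_spans: list[tuple[int, int]]
+) -> tuple[float, float, float]:
+    """Precision, recall, and F1 over individual hallucinated character positions."""
+    pred = {i for start, end in pred_spans for i in range(start, end)}
+    gold = {i for start, end in gold_spans for i in range(start, end)}
+    if not pred or not gold:
+        return 0.0, 0.0, 0.0
+    overlap = len(pred & gold)
+    precision, recall = overlap / len(pred), overlap / len(gold)
+    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
+    return precision, recall, f1
+
+print(char_prf(pred_spans=[(31, 72)], gold_spans=[(31, 71)]))
+```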
+
+## Token-Level
+
+Per-token precision, recall, and F1 for the hallucinated class. This is what the model is directly optimized for during training.
+
+## Metrics Explained
+
+- **Precision**: Of everything the model flagged as hallucinated, how much was actually hallucinated?
+- **Recall**: Of all the actual hallucinations, how many did the model catch?
+- **F1**: The balance between precision and recall (harmonic mean). This is the primary metric.
+- **AUROC**: Area under the ROC curve – measures how well the model separates hallucinated from supported tokens across all confidence thresholds.
+
+## LLM Baselines
+
+You can also evaluate LLM-based detectors for comparison:
+
+```bash
+python scripts/evaluate_llm.py \
+ --model "gpt-4o-mini" \
+ --data_path data/ragtruth/ragtruth_data.json \
+ --evaluation_type example_level
+```
diff --git a/docs/guide/integrations.md b/docs/guide/integrations.md
new file mode 100644
index 0000000..953bcd7
--- /dev/null
+++ b/docs/guide/integrations.md
@@ -0,0 +1,59 @@
+# Integrations
+
+LettuceDetect can be used with popular LLM frameworks. Integration examples are available in the repository's `integrations/` directory (see the [GitHub repo](https://github.com/KRLabsOrg/LettuceDetect)).
+
+## LangChain
+
+Use LettuceDetect as a callback, chain component, or tool within LangChain pipelines.
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+# Create detector
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1"
+)
+
+# Use in your LangChain pipeline
+def check_hallucination(context, question, answer):
+ spans = detector.predict(
+ context=context, question=question,
+ answer=answer, output_format="spans"
+ )
+ return len(spans) == 0 # True if no hallucinations detected
+```
+
+## Pydantic AI
+
+Use LettuceDetect within Pydantic AI agents for structured hallucination checking.
+
+## Haystack
+
+Add LettuceDetect as a pipeline component in Haystack for post-generation verification.
+
+## General Pattern
+
+Any framework that gives you access to the retrieved context and generated answer can use LettuceDetect:
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1"
+)
+
+# After your RAG pipeline generates an answer:
+spans = detector.predict(
+ context=retrieved_documents,
+ question=user_query,
+ answer=generated_answer,
+ output_format="spans"
+)
+
+if spans:
+ print(f"Found {len(spans)} hallucinated spans")
+ for span in spans:
+ print(f" '{span['text']}' (confidence: {span['confidence']:.2f})")
+```
diff --git a/docs/guide/multilingual.md b/docs/guide/multilingual.md
new file mode 100644
index 0000000..0f20a83
--- /dev/null
+++ b/docs/guide/multilingual.md
@@ -0,0 +1,29 @@
+# Multilingual Support
+
+LettuceDetect supports hallucination detection in multiple languages via EuroBERT-based models.
+
+## Supported Languages
+
+| Language | Code | Model |
+|----------|------|-------|
+| English | en | ModernBERT + EuroBERT |
+| German | de | EuroBERT |
+| French | fr | EuroBERT |
+| Spanish | es | EuroBERT |
+| Italian | it | EuroBERT |
+| Polish | pl | EuroBERT |
+| Chinese | cn | EuroBERT |
+| Hungarian | hu | EuroBERT |
+
+## Usage
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-base-eurobert-multilingual-v1"
+)
+```
+
+The multilingual model handles all supported languages with a single checkpoint. No language detection or switching needed.
diff --git a/docs/guide/training.md b/docs/guide/training.md
new file mode 100644
index 0000000..59c454f
--- /dev/null
+++ b/docs/guide/training.md
@@ -0,0 +1,109 @@
+# Training Your Own Model
+
+LettuceDetect models are standard HuggingFace token classifiers. You can fine-tune them on your own data or on the provided RAGTruth dataset.
+
+## How It Works
+
+The model reads a concatenated input – context, question, and answer – and classifies each answer token as either **supported** (0) or **hallucinated** (1). Consecutive hallucinated tokens are merged into spans at inference time.
+
+```
+Input: [CLS] context [SEP] question [SEP] answer [SEP]
+Labels: [-100, -100, ..., -100, ..., 0, 0, 1, 1, 1, 0, ...]
+ ^^^^^^^^^^^^^^^^^^^^^^^^
+ only answer tokens are labeled
+```
+
+- `-100` = ignored by the loss function (context + question tokens)
+- `0` = supported (answer token backed by context)
+- `1` = hallucinated (answer token not supported by context)
+
+The best checkpoint is saved based on validation F1 for the hallucinated class.
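+
+A sketch of how the character-level `labels` become token labels via the tokenizer's offset mapping (simplified relative to the actual dataset class):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+def token_labels(prompt: str, answer: str, spans: list[dict], max_length: int = 4096) -> list[int]:
+    """Label answer tokens 1 if they overlap a hallucinated span, 0 otherwise, -100 elsewhere."""
+    enc = tokenizer(prompt, answer, truncation=True, max_length=max_length,
+                    return_offsets_mapping=True)
+    labels = []
+    for seq_id, (start, end) in zip(enc.sequence_ids(), enc["offset_mapping"]):
+        if seq_id != 1:  # special tokens and prompt tokens are ignored by the loss
+            labels.append(-100)
+        elif any(start < span["end"] and end > span["start"] for span in spans):
+            labels.append(1)  # token overlaps a hallucinated character span
+        else:
+            labels.append(0)
+    return labels
+
+labels = token_labels("Context ...", "The population is 69 million.", [{"start": 18, "end": 28}])
+```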
+
+## Train on RAGTruth
+
+[RAGTruth](https://aclanthology.org/2024.acl-long.585/) is a text hallucination dataset covering QA, summarization, and data-to-text tasks.
+
+```bash
+# Preprocess the raw data first
+python lettucedetect/preprocess/preprocess_ragtruth.py \
+ --input_dir data/ragtruth --output_dir data/ragtruth
+
+# Train
+python scripts/train.py \
+ --ragtruth-path data/ragtruth/ragtruth_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/hallucination_detector \
+ --batch-size 4 \
+ --epochs 6 \
+ --learning-rate 1e-5
+```
+
+## Train on Your Own Data
+
+Your data must follow the `HallucinationSample` format – a JSON list where each item has:
+
+```json
+{
+ "prompt": "Your context text here...",
+ "answer": "The LLM-generated answer to check...",
+ "labels": [{"start": 45, "end": 72, "label": "hallucination"}],
+ "split": "train",
+ "task_type": "qa",
+ "dataset": "ragtruth",
+ "language": "en"
+}
+```
+
+- `labels` contains character-level spans within the `answer` marking hallucinated regions. Empty list `[]` for clean samples.
+- `split` should be `"train"`, `"dev"`, or `"test"`.
+
+Then train the same way:
+
+```bash
+python scripts/train.py \
+ --ragtruth-path path/to/your_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/my_detector \
+ --batch-size 4 \
+ --epochs 6
+```
+
+## Configuration
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `--model-name` | `answerdotai/ModernBERT-base` | HuggingFace model to fine-tune |
+| `--batch-size` | 4 | Training batch size |
+| `--epochs` | 6 | Number of training epochs |
+| `--learning-rate` | 1e-5 | Learning rate |
+| `--max-length` | 4096 | Maximum input length in tokens |
+| `--seed` | 42 | Random seed for reproducibility |
+
+## Recommended Base Models
+
+| Model | Parameters | Max Tokens | Best For |
+|-------|-----------|------------|----------|
+| `answerdotai/ModernBERT-base` | 149M | 8K | Fast training, good baseline |
+| `answerdotai/ModernBERT-large` | 395M | 8K | Best English performance |
+| `EuroBERT/EuroBERT-210m` | 210M | 8K | Multilingual (7 languages) |
+| `EuroBERT/EuroBERT-610m` | 610M | 8K | Best multilingual performance |
+
+## Evaluate
+
+```bash
+# Example-level: does the answer contain any hallucination?
+python scripts/evaluate.py \
+ --model_path output/hallucination_detector \
+ --data_path data/ragtruth/ragtruth_data.json \
+ --evaluation_type example_level
+
+# Span-level: how well do predicted spans match gold spans?
+python scripts/evaluate.py \
+ --model_path output/hallucination_detector \
+ --data_path data/ragtruth/ragtruth_data.json \
+ --evaluation_type span_level
+```
+
+## Code Hallucination
+
+For training on code hallucination data (from SWE-bench), see the [Code Hallucination Dataset](../code-hallucination/index.md) section which covers generating the dataset and training both encoder and generative models.
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..619df26
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,60 @@
+# LettuceDetect
+
+
+
+
+
+**A lightweight hallucination detection framework for RAG applications.**
+
+LettuceDetect is an encoder-based model built on [ModernBERT](https://github.com/AnswerDotAI/ModernBERT) that detects unsupported spans in LLM-generated answers by comparing them against provided context. It provides token-level precision for identifying exactly which parts of an answer are hallucinated.
+
+## Highlights
+
+- **Token-level precision** – identifies exact hallucinated spans, not just "this answer has a problem"
+- **Fast inference** – 30-60 samples/sec on A100, suitable for production
+- **Long context** – supports up to 4K tokens (ModernBERT) or 8K tokens (EuroBERT)
+- **Multilingual** – English, German, French, Spanish, Italian, Polish, Chinese, Hungarian
+- **Open source** – MIT license, models on HuggingFace
+
+## Quick Example
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedetect-base-modernbert-en-v1"
+)
+
+contexts = ["The capital of France is Paris. The population is 67 million."]
+question = "What is the capital and population of France?"
+answer = "The capital of France is Paris. The population is 69 million."
+
+predictions = detector.predict(
+ context=contexts, question=question, answer=answer, output_format="spans"
+)
+# [{'start': 31, 'end': 71, 'confidence': 0.99, 'text': ' The population of France is 69 million.'}]
+```
+
+## Performance
+
+| Model | Example F1 | vs GPT-4 | vs Luna | Parameters |
+|-------|-----------|----------|---------|-----------|
+| lettucedetect-base-v1 | 76.8% | +13.4% | +11.4% | 149M |
+| lettucedetect-large-v1 | **79.2%** | +15.8% | +13.8% | 395M |
+
+Evaluated on [RAGTruth](https://aclanthology.org/2024.acl-long.585/) test set. Surpasses GPT-4, Luna, and fine-tuned Llama-2-13B.
+
+## What's New
+
+- **[Code Hallucination Dataset](code-hallucination/index.md)** – A pipeline for generating span-level code hallucination data from SWE-bench (~18k samples across 53 repos)
+- **Multilingual models** – EuroBERT-based models for 8 languages
+- **Web API** – FastAPI server with async client support
+
+## Links
+
+- [GitHub](https://github.com/KRLabsOrg/LettuceDetect)
+- [PyPI](https://pypi.org/project/lettucedetect/)
+- [arXiv Paper](https://arxiv.org/abs/2502.17125)
+- [HuggingFace Models](https://huggingface.co/KRLabsOrg)
+- [Streamlit Demo](https://huggingface.co/spaces/KRLabsOrg/LettuceDetect)
diff --git a/lettucedetect/__init__.py b/lettucedetect/__init__.py
index 928b8f6..62a7324 100644
--- a/lettucedetect/__init__.py
+++ b/lettucedetect/__init__.py
@@ -15,7 +15,7 @@
# Direct RAGFactChecker access for advanced users
from lettucedetect.ragfactchecker import RAGFactChecker
-__version__ = "0.1.7"
+__version__ = "0.1.8"
__all__ = [
"HallucinationData",
diff --git a/lettucedetect/datasets/analyze_lengths.py b/lettucedetect/datasets/analyze_lengths.py
new file mode 100644
index 0000000..e98fd3c
--- /dev/null
+++ b/lettucedetect/datasets/analyze_lengths.py
@@ -0,0 +1,184 @@
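+"""Analyze token lengths of a hallucination dataset and filter over-length samples.
+
+Reads a LettuceDetect-format JSON dataset, reports per-split length statistics,
+plots the length distribution, and writes ``filtered_dataset.json`` containing only
+the samples that fit within ``--max_length``. Example invocation (paths are
+illustrative)::
+
+    python lettucedetect/datasets/analyze_lengths.py --input_file data/ragtruth/ragtruth_data.json --output_dir output/length_analysis
+"""
+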
+import argparse
+import json
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List
+
+import matplotlib.pyplot as plt
+from hallucination_dataset import HallucinationData, HallucinationSample
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+
+def analyze_lengths(
+ samples: List[HallucinationSample],
+ tokenizer: AutoTokenizer,
+ max_length: int = 4096,
+) -> Dict[str, Any]:
+ """Analyze the token lengths of samples and identify outliers.
+
+ Args:
+ samples: List of HallucinationSample objects
+ tokenizer: Tokenizer to use
+ max_length: Maximum allowed length
+
+ Returns:
+ Dictionary containing analysis results
+
+ """
+ lengths = []
+ outliers = []
+ split_stats = defaultdict(list)
+
+ for sample in tqdm(samples, desc="Analyzing lengths"):
+ # Tokenize combined context and answer
+ tokens = tokenizer(
+ sample.prompt,
+ sample.answer,
+ truncation=False, # No truncation to get true lengths
+ add_special_tokens=True,
+ )
+
+ length = len(tokens["input_ids"])
+ lengths.append(length)
+ split_stats[sample.split].append(length)
+
+ if length > max_length:
+ outliers.append(
+ {
+ "index": len(lengths) - 1,
+ "length": length,
+ "prompt_length": len(tokenizer(sample.prompt)["input_ids"]),
+ "answer_length": len(tokenizer(sample.answer)["input_ids"]),
+ "split": sample.split,
+ }
+ )
+
+ # Calculate statistics
+ stats = {
+ "total_samples": len(samples),
+ "max_length": max(lengths),
+ "min_length": min(lengths),
+ "mean_length": sum(lengths) / len(lengths),
+ "num_outliers": len(outliers),
+ "outliers": outliers,
+ "split_stats": {
+ split: {
+ "count": len(lens),
+ "max": max(lens),
+ "min": min(lens),
+ "mean": sum(lens) / len(lens),
+ "outliers": len([l for l in lens if l > max_length]),
+ }
+ for split, lens in split_stats.items()
+ },
+ }
+
+ return stats
+
+
+def plot_length_distribution(lengths: List[int], output_path: str, max_length: int):
+ """Plot the distribution of sequence lengths."""
+ plt.figure(figsize=(10, 6))
+ plt.hist(lengths, bins=50)
+ plt.axvline(x=max_length, color="r", linestyle="--", label=f"Max Length ({max_length})")
+ plt.xlabel("Sequence Length")
+ plt.ylabel("Count")
+ plt.title("Distribution of Sequence Lengths")
+ plt.legend()
+ plt.savefig(output_path)
+ plt.close()
+
+
+def filter_dataset(
+ data: HallucinationData, max_length: int, tokenizer: AutoTokenizer, output_path: Path
+) -> HallucinationData:
+ """Filter out samples that exceed max_length."""
+ filtered_samples = []
+
+ for sample in tqdm(data.samples, desc="Filtering samples"):
+ tokens = tokenizer(
+ sample.prompt,
+ sample.answer,
+ truncation=False,
+ add_special_tokens=True,
+ )
+
+ if len(tokens["input_ids"]) <= max_length:
+ filtered_samples.append(sample)
+
+ filtered_data = HallucinationData(samples=filtered_samples)
+
+ # Save filtered dataset
+ with open(output_path, "w") as f:
+ json.dump(filtered_data.to_json(), f, indent=2)
+
+ return filtered_data
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Analyze and filter dataset based on sequence lengths"
+ )
+ parser.add_argument("--input_file", type=str, required=True, help="Path to input JSON dataset")
+ parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files")
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="meta-llama/Llama-2-7b-hf",
+ help="Model name for tokenizer",
+ )
+ parser.add_argument("--max_length", type=int, default=4096, help="Maximum sequence length")
+
+ args = parser.parse_args()
+
+ # Create output directory
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+ # Load dataset
+ with open(args.input_file) as f:
+ data = HallucinationData.from_json(json.load(f))
+
+ # Analyze lengths
+ stats = analyze_lengths(data.samples, tokenizer, args.max_length)
+
+ # Save statistics
+ with open(output_dir / "length_stats.json", "w") as f:
+ json.dump(stats, f, indent=2)
+
+ # Plot distribution
+ plot_length_distribution(
+ [
+ len(tokenizer(s.prompt, s.answer, add_special_tokens=True)["input_ids"])
+ for s in data.samples
+ ],
+ str(output_dir / "length_distribution.png"),
+ args.max_length,
+ )
+
+ # Filter dataset
+ filtered_data = filter_dataset(
+ data, args.max_length, tokenizer, output_dir / "filtered_dataset.json"
+ )
+
+ print(f"\nAnalysis complete. Results saved to {output_dir}")
+ print(f"Original dataset size: {len(data.samples)}")
+ print(f"Filtered dataset size: {len(filtered_data.samples)}")
+ print(f"Removed {len(data.samples) - len(filtered_data.samples)} samples")
+
+ # Print split-wise statistics
+ print("\nSplit-wise statistics:")
+    for split, split_stat in stats["split_stats"].items():
+        print(f"\n{split.upper()}:")
+        print(f"  Total samples: {split_stat['count']}")
+        print(f"  Samples exceeding max length: {split_stat['outliers']}")
+        print(f"  Max length: {split_stat['max']}")
+        print(f"  Mean length: {split_stat['mean']:.2f}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lettucedetect/datasets/hallucination_dataset.py b/lettucedetect/datasets/hallucination_dataset.py
index 64e2644..1713193 100644
--- a/lettucedetect/datasets/hallucination_dataset.py
+++ b/lettucedetect/datasets/hallucination_dataset.py
@@ -8,15 +8,30 @@
@dataclass
class HallucinationSample:
+ """A single hallucination detection sample.
+
+ Attributes:
+ prompt: Context text (source documents, code files, documentation, user query).
+ answer: The LLM-generated answer to check for hallucinations.
+ labels: List of span annotations. Each dict has ``start``, ``end`` (character offsets
+ within ``answer``), and ``label`` keys. Empty list for clean samples.
+ split: Dataset split (``train``, ``dev``, or ``test``).
+ task_type: Task type (e.g. ``summarization``, ``qa``, ``code_generation``).
+ dataset: Source dataset (``ragtruth``, ``ragbench``, or ``swebench_code``).
+ language: Language code.
+
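+    Example (values, including the ``label`` string, are illustrative)::
+
+        sample = HallucinationSample(
+            prompt="Paris is the capital of France. Its population is about 2 million.",
+            answer="Paris is the capital. It has 3 million residents.",
+            labels=[{"start": 22, "end": 49, "label": "hallucinated"}],
+            split="train",
+            task_type="qa",
+            dataset="ragtruth",
+            language="en",
+        )
+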
+ """
+
prompt: str
answer: str
labels: list[dict]
split: Literal["train", "dev", "test"]
task_type: str
- dataset: Literal["ragtruth", "ragbench"]
- language: Literal["en", "de"]
+ dataset: Literal["ragtruth", "ragbench", "swebench_code"]
+ language: Literal["en", "de", "fr", "es", "it", "pl", "cn", "hu"]
def to_json(self) -> dict:
+ """Serialize to a JSON-compatible dict."""
return {
"prompt": self.prompt,
"answer": self.answer,
@@ -29,6 +44,7 @@ def to_json(self) -> dict:
@classmethod
def from_json(cls, json_dict: dict) -> "HallucinationSample":
+ """Deserialize from a JSON dict."""
return cls(
prompt=json_dict["prompt"],
answer=json_dict["answer"],
@@ -42,13 +58,22 @@ def from_json(cls, json_dict: dict) -> "HallucinationSample":
@dataclass
class HallucinationData:
+ """A collection of hallucination detection samples.
+
+ Attributes:
+ samples: List of :class:`HallucinationSample` instances.
+
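+    Example (path is illustrative)::
+
+        import json
+
+        with open("data/ragtruth/ragtruth_data.json") as f:
+            data = HallucinationData.from_json(json.load(f))
+        print(len(data.samples))
+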
+ """
+
samples: list[HallucinationSample]
def to_json(self) -> list[dict]:
+ """Serialize all samples to a JSON-compatible list."""
return [sample.to_json() for sample in self.samples]
@classmethod
def from_json(cls, json_dict: list[dict]) -> "HallucinationData":
+ """Deserialize from a list of JSON dicts."""
return cls(
samples=[HallucinationSample.from_json(sample) for sample in json_dict],
)
@@ -74,6 +99,7 @@ def __init__(
self.max_length = max_length
def __len__(self) -> int:
+ """Return the number of samples in the dataset."""
return len(self.samples)
@classmethod
@@ -84,8 +110,10 @@ def prepare_tokenized_input(
answer: str,
max_length: int = 4096,
) -> tuple[dict[str, torch.Tensor], list[int], torch.Tensor, int]:
- """Tokenizes the context and answer together, computes the answer start token index,
- and initializes a labels list (using -100 for context tokens and 0 for answer tokens).
+ """Tokenize context and answer, compute answer start index, and initialize labels.
+
+ Computes the answer start token index and initializes a labels list
+ (using -100 for context tokens and 0 for answer tokens).
:param tokenizer: The tokenizer to use.
:param context: The context string.
@@ -109,19 +137,14 @@ def prepare_tokenized_input(
offsets = encoding.pop("offset_mapping")[0] # shape: (seq_length, 2)
- # Simple approach: encode just the context with special tokens
- # For most tokenizers, the answer starts right after this
- context_only = tokenizer(context, add_special_tokens=True, return_tensors="pt")
- # The answer starts after the context sequence (with its special tokens)
- answer_start_token = context_only["input_ids"].shape[1]
-
- # Handle any edge cases where this might land on a special token
- if (
- answer_start_token < offsets.size(0)
- and offsets[answer_start_token][0] == offsets[answer_start_token][1]
- ):
- # If we landed on a special token, move forward
- answer_start_token += 1
+ # Compute answer_start_token from the answer side. This is correct
+ # even when the context has been truncated by truncation="only_first",
+ # because the answer is never truncated in that mode.
+ # Layout: [CLS] context_tokens [SEP] answer_tokens [SEP]
+ answer_only = tokenizer(answer, add_special_tokens=False, return_tensors="pt")
+ answer_token_count = answer_only["input_ids"].shape[1]
+ total_seq_len = encoding["input_ids"].shape[1]
+ answer_start_token = total_seq_len - answer_token_count - 1 # -1 for trailing [SEP]
         # Initialize labels: -100 for tokens before the answer, 0 for tokens in the answer.
labels = [-100] * encoding["input_ids"].shape[1]
diff --git a/lettucedetect/detectors/base.py b/lettucedetect/detectors/base.py
index f9d20c6..9127751 100644
--- a/lettucedetect/detectors/base.py
+++ b/lettucedetect/detectors/base.py
@@ -19,21 +19,33 @@ def predict(
"""Predict hallucination tokens or spans given passages and an answer.
:param context: List of passages that were supplied to the LLM / user.
- :param answer: Modelβgenerated answer to inspect.
+ :param answer: Model-generated answer to inspect.
:param question: Original question (``None`` for summarisation).
- :param output_format: ``"tokens"`` for tokenβlevel dicts, ``"spans"`` for character spans.
+ :param output_format: ``"tokens"`` for token-level dicts, ``"spans"`` for character spans.
:returns: List of predictions in requested format.
"""
pass
@abstractmethod
def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
- """Predict hallucinations when you already have a *single* full prompt string."""
+ """Predict hallucinations from a pre-built prompt string.
+
+ :param prompt: Full prompt (context + question already concatenated).
+ :param answer: Model-generated answer to inspect.
+ :param output_format: ``"tokens"`` or ``"spans"``.
+ :returns: List of predictions in requested format.
+ """
pass
@abstractmethod
def predict_prompt_batch(
self, prompts: list[str], answers: list[str], output_format: str = "tokens"
) -> list:
- """Batch version of `predict_prompt`."""
+ """Batch version of :meth:`predict_prompt`.
+
+ :param prompts: List of full prompt strings.
+ :param answers: List of answers to inspect.
+ :param output_format: ``"tokens"`` or ``"spans"``.
+ :returns: List of prediction lists, one per input pair.
+ """
pass
diff --git a/lettucedetect/detectors/cache.py b/lettucedetect/detectors/cache.py
index 437b21e..159edf5 100644
--- a/lettucedetect/detectors/cache.py
+++ b/lettucedetect/detectors/cache.py
@@ -1,4 +1,4 @@
-"""Threadβsafe JSON cache with SHAβ256 keys."""
+"""Thread-safe JSON cache with SHA-256 keys."""
from __future__ import annotations
@@ -10,9 +10,13 @@
class CacheManager:
- """Diskβbacked cache for expensive LLM calls."""
+ """Disk-backed cache for expensive LLM calls."""
- def __init__(self, file_path: str | Path):
+ def __init__(self, file_path: str | Path) -> None:
+ """Initialize cache from disk.
+
+ :param file_path: Path to the JSON cache file.
+ """
self.path = Path(file_path)
self.path.parent.mkdir(parents=True, exist_ok=True)
self._lock = threading.Lock()
@@ -24,11 +28,13 @@ def __init__(self, file_path: str | Path):
def _hash(*parts: str) -> str:
return hashlib.sha256("||".join(parts).encode()).hexdigest()
- def get(self, key: str) -> Any | None:
+ def get(self, key: str) -> Any | None: # noqa: ANN401
+ """Retrieve a cached value by key."""
with self._lock:
return self._data.get(key)
- def set(self, key: str, value: Any) -> None:
+ def set(self, key: str, value: Any) -> None: # noqa: ANN401
+ """Store a value in the cache and persist to disk."""
with self._lock:
self._data[key] = value
self.path.write_text(json.dumps(self._data, ensure_ascii=False), encoding="utf-8")
diff --git a/lettucedetect/detectors/llm.py b/lettucedetect/detectors/llm.py
index 4c4b010..93ea354 100644
--- a/lettucedetect/detectors/llm.py
+++ b/lettucedetect/detectors/llm.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
+import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
@@ -12,6 +13,8 @@
from lettucedetect.detectors.cache import CacheManager
from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils
+logger = logging.getLogger(__name__)
+
# JSON schema for structured response format
HALLUCINATION_SCHEMA = {
"type": "json_schema",
@@ -72,7 +75,7 @@ def __init__(
)
path = Path(fewshot_path)
if not path.exists():
- print(f"Warning: Few-shot examples file not found at {path}")
+ logger.warning("Few-shot examples file not found at %s", path)
self.fewshot = json.loads(path.read_text(encoding="utf-8")) if path.exists() else []
# Load hallucination detection template
@@ -90,9 +93,9 @@ def __init__(
/ "cache"
/ f"cache_{model.replace(':', '_')}_{lang}.json"
)
- print(f"Using default cache file: {cache_file}")
+ logger.info("Using default cache file: %s", cache_file)
else:
- print(f"Using provided cache file: {cache_file}")
+ logger.info("Using provided cache file: %s", cache_file)
self.cache = CacheManager(cache_file)
@@ -185,8 +188,8 @@ def _predict(self, prompt: str, answer: str) -> list[dict]:
payload = json.loads(cached)
return self._to_spans(payload["hallucination_list"], answer)
except (json.JSONDecodeError, KeyError) as e:
- print(f"Error parsing LLM response: {e}")
- print(f"Raw response: {cached}")
+ logger.error("Error parsing LLM response: %s", e)
+ logger.debug("Raw response: %s", cached)
return []
def predict(
@@ -199,7 +202,7 @@ def predict(
"""Predict hallucination spans from the provided context, answer, and question.
:param context: List of passages that were supplied to the LLM / user.
- :param answer: Modelβgenerated answer to inspect.
+ :param answer: Model-generated answer to inspect.
:param question: Original question (``None`` for summarisation).
:param output_format: ``"spans"`` for character spans.
:returns: List of spans.
diff --git a/lettucedetect/detectors/rag_fact_checker.py b/lettucedetect/detectors/rag_fact_checker.py
index fb6cb73..3d66bd9 100644
--- a/lettucedetect/detectors/rag_fact_checker.py
+++ b/lettucedetect/detectors/rag_fact_checker.py
@@ -1,6 +1,8 @@
"""Simple RAGFactChecker detector wrapper for lettuceDetect factory pattern."""
-from typing import Any, Dict, List
+from __future__ import annotations
+
+from typing import Any
from lettucedetect.detectors.base import BaseDetector
@@ -14,9 +16,9 @@ class RAGFactCheckerDetector(BaseDetector):
def __init__(
self,
- openai_api_key: str = None,
+ openai_api_key: str | None = None,
model: str = "gpt-4o",
- base_url: str = None,
+ base_url: str | None = None,
temperature: float = 0.0,
**kwargs,
):
@@ -38,12 +40,12 @@ def __init__(
def predict(
self,
- context: List[str],
+ context: list[str],
answer: str,
- question: str = None,
+ question: str | None = None,
output_format: str = "tokens",
**kwargs,
- ) -> List[Dict[str, Any]] | Dict[str, Any]:
+ ) -> list[dict[str, Any]] | dict[str, Any]:
"""Predict hallucinations using RAGFactChecker.
:param context: List of context documents
@@ -81,13 +83,13 @@ def predict(
def predict_prompt(
self, prompt: str, answer: str, output_format: str = "tokens"
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""Predict using a single prompt string as context."""
return self.predict([prompt], answer, output_format=output_format)
def predict_prompt_batch(
- self, prompts: List[str], answers: List[str], output_format: str = "tokens"
- ) -> List[List[Dict[str, Any]]]:
+ self, prompts: list[str], answers: list[str], output_format: str = "tokens"
+ ) -> list[list[dict[str, Any]]]:
"""Batch prediction using RAGFactChecker's batch processing."""
if len(prompts) != len(answers):
raise ValueError("Number of prompts must match number of answers")
@@ -108,7 +110,7 @@ def predict_prompt_batch(
return converted_results
- def _convert_to_tokens(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]:
+ def _convert_to_tokens(self, answer: str, rag_result: dict[str, Any]) -> list[dict[str, Any]]:
"""Convert RAGFactChecker result to token format."""
tokens = answer.split()
hallucinated_triplets = rag_result.get("hallucinated_triplets", [])
@@ -130,7 +132,7 @@ def _convert_to_tokens(self, answer: str, rag_result: Dict[str, Any]) -> List[Di
return token_predictions
- def _convert_to_spans(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]:
+ def _convert_to_spans(self, answer: str, rag_result: dict[str, Any]) -> list[dict[str, Any]]:
"""Convert RAGFactChecker result to span format with improved triplet matching."""
spans = []
hallucinated_triplets = rag_result.get("hallucinated_triplets", [])
@@ -195,7 +197,7 @@ def _convert_to_spans(self, answer: str, rag_result: Dict[str, Any]) -> List[Dic
# Merge overlapping spans
return self._merge_overlapping_spans(spans)
- def _merge_overlapping_spans(self, spans: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ def _merge_overlapping_spans(self, spans: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge overlapping spans to avoid duplicates."""
if not spans:
return spans
diff --git a/lettucedetect/detectors/transformer.py b/lettucedetect/detectors/transformer.py
index 3ef8c23..a64efa8 100644
--- a/lettucedetect/detectors/transformer.py
+++ b/lettucedetect/detectors/transformer.py
@@ -1,7 +1,9 @@
-"""Transformerβbased hallucination detector."""
+"""Transformer-based hallucination detector."""
from __future__ import annotations
+import logging
+
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
@@ -9,15 +11,28 @@
from lettucedetect.detectors.base import BaseDetector
from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils
+logger = logging.getLogger(__name__)
+
__all__ = ["TransformerDetector"]
class TransformerDetector(BaseDetector):
- """Detect hallucinations with a fineβtuned token classifier."""
+ """Detect hallucinations with a fine-tuned token classifier.
+
+ When the combined context + answer exceeds ``max_length`` tokens, the
+ context is automatically split into chunks. Each chunk is scored
+ independently and the per-token hallucination probability is aggregated
+ across chunks with ``max()``.
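+
+    Example (a minimal sketch; passages, question, and answer are illustrative)::
+
+        detector = TransformerDetector(
+            "KRLabsOrg/lettucedetect-base-modernbert-en-v1", max_length=4096
+        )
+        spans = detector.predict(
+            context=["Passage 1 text ...", "Passage 2 text ..."],
+            question="What does the source say?",
+            answer="The source says ...",
+            output_format="spans",
+        )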
+ """
def __init__(
- self, model_path: str, max_length: int = 4096, device=None, lang: Lang = "en", **tok_kwargs
- ):
+ self,
+ model_path: str,
+ max_length: int = 4096,
+ device: torch.device | str | None = None,
+ lang: Lang = "en",
+ **tok_kwargs: object,
+ ) -> None:
"""Initialize the transformer detector.
:param model_path: Path to the pre-trained model.
@@ -36,143 +51,358 @@ def __init__(
)
self.model.to(self.device).eval()
- def _predict(self, prompt: str, answer: str, output_format: str) -> list:
- """Predict hallucination tokens or spans from the provided prompt and answer.
+ # ------------------------------------------------------------------
+ # Chunking helpers
+ # ------------------------------------------------------------------
+
+ def _group_passages_into_chunks(
+ self,
+ context: list[str],
+ question: str | None,
+ answer: str,
+ ) -> list[list[str]]:
+ """Group passages into chunks so each chunk, wrapped in the full instruction template, fits in ``max_length``.
+
+ Preserves the instruction template (question, "Bear in mind...", etc.)
+ in every chunk by working at the passage level instead of slicing raw
+ tokens.
+
+ :param context: List of passage strings.
+ :param question: Original question (``None`` for summarisation).
+ :param answer: The answer string (token budget is reserved).
+ :returns: List of passage groups. Each group will be formatted into a
+ complete prompt via ``PromptUtils.format_context``.
+ """
+ answer_tokens = self.tokenizer(answer, add_special_tokens=False, return_tensors="pt")[
+ "input_ids"
+ ].shape[1]
+ # Total prompt-token budget: max_length minus answer tokens minus 3 special tokens
+ total_budget = self.max_length - answer_tokens - 3
+
+ # Fast path: check whether all passages fit in one prompt.
+ full_prompt = PromptUtils.format_context(context, question, self.lang)
+ full_prompt_tokens = self.tokenizer(
+ full_prompt, add_special_tokens=False, return_tensors="pt"
+ )["input_ids"].shape[1]
+
+ if full_prompt_tokens <= total_budget:
+ return [context]
+
+ # Measure instruction overhead (everything except the passage content).
+ minimal_prompt = PromptUtils.format_context([""], question, self.lang)
+ instruction_overhead = self.tokenizer(
+ minimal_prompt, add_special_tokens=False, return_tensors="pt"
+ )["input_ids"].shape[1]
+
+ passage_budget = total_budget - instruction_overhead
+ if passage_budget <= 0:
+ # Instructions + answer alone exceed max_length; single group, will truncate.
+ return [context]
+
+ # Tokenize each formatted passage line to get its token count.
+ p_word = LANG_TO_PASSAGE[self.lang]
+ passage_token_counts: list[int] = []
+ for i, passage in enumerate(context):
+ line = f"{p_word} {i + 1}: {passage}"
+ tok_count = self.tokenizer(line, add_special_tokens=False, return_tensors="pt")[
+ "input_ids"
+ ].shape[1]
+ passage_token_counts.append(tok_count)
+
+ # Greedily group passages into buckets that fit within the budget.
+ groups: list[list[str]] = []
+ current_group: list[str] = []
+ current_tokens = 0
+
+ for passage, tok_count in zip(context, passage_token_counts):
+ # +1 for the newline separator between passages
+ effective = tok_count + (1 if current_group else 0)
+ if current_tokens + effective > passage_budget and current_group:
+ groups.append(current_group)
+ current_group = []
+ current_tokens = 0
+ effective = tok_count
+ current_group.append(passage)
+ current_tokens += effective
+
+ if current_group:
+ groups.append(current_group)
+
+ return groups if groups else [context]
+
+ # ------------------------------------------------------------------
+ # Single-chunk prediction (the original _predict logic)
+ # ------------------------------------------------------------------
+
+ def _predict_single(self, prompt: str, answer: str, output_format: str) -> list:
+ """Run prediction on a single (prompt, answer) pair that fits in ``max_length``.
:param prompt: The prompt string.
:param answer: The answer string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ :param output_format: ``"tokens"`` or ``"spans"``.
"""
- if output_format not in ["tokens", "spans"]:
- raise ValueError(
- f"TransformerDetector doesn't support '{output_format}' format. "
- "Use 'tokens' or 'spans'"
- )
- # Use the shared tokenization logic from HallucinationDataset
encoding, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input(
self.tokenizer, prompt, answer, self.max_length
)
- # Create a label tensor: mark tokens before answer as -100 (ignored) and answer tokens as 0.
labels = torch.full_like(encoding.input_ids[0], -100, device=self.device)
labels[answer_start_token:] = 0
- # Move encoding to the device
+
encoding = {
key: value.to(self.device)
for key, value in encoding.items()
if key in ["input_ids", "attention_mask", "labels"]
}
- # Run model inference
with torch.no_grad():
outputs = self.model(**encoding)
logits = outputs.logits
token_preds = torch.argmax(logits, dim=-1)[0]
probabilities = torch.softmax(logits, dim=-1)[0]
- # Mask out predictions for context tokens.
token_preds = torch.where(labels == -100, labels, token_preds)
if output_format == "tokens":
- # return token probabilities for each token (with the tokens as well, if not -100)
- token_probs = []
- input_ids = encoding["input_ids"][0] # Get the input_ids tensor from the encoding dict
+ token_probs: list[dict] = []
+ input_ids = encoding["input_ids"][0]
for i, (token, pred, prob) in enumerate(zip(input_ids, token_preds, probabilities)):
- if not labels[i].item() == -100:
+ if labels[i].item() != -100:
token_probs.append(
{
"token": self.tokenizer.decode([token]),
"pred": pred.item(),
- "prob": prob[1].item(), # Get probability for class 1 (hallucination)
+ "prob": prob[1].item(),
}
)
return token_probs
- elif output_format == "spans":
- # Compute the answer's character offset (the first token of the answer).
- if answer_start_token < offsets.size(0):
- answer_char_offset = offsets[answer_start_token][0].item()
- else:
- answer_char_offset = 0
-
- spans: list[dict] = []
- current_span: dict | None = None
-
- # Iterate over tokens in the answer region.
- for i in range(answer_start_token, token_preds.size(0)):
- # Skip tokens marked as ignored.
- if labels[i].item() == -100:
- continue
-
- token_start, token_end = offsets[i].tolist()
- # Skip special tokens with zero length.
- if token_start == token_end:
- continue
-
- # Adjust offsets relative to the answer text.
- rel_start = token_start - answer_char_offset
- rel_end = token_end - answer_char_offset
-
- is_hallucination = (
- token_preds[i].item() == 1
- ) # assuming class 1 indicates hallucination.
- confidence = probabilities[i, 1].item() if is_hallucination else 0.0
-
- if is_hallucination:
- if current_span is None:
- current_span = {
- "start": rel_start,
- "end": rel_end,
- "confidence": confidence,
- }
- else:
- # Extend the current span.
- current_span["end"] = rel_end
- current_span["confidence"] = max(current_span["confidence"], confidence)
+
+ # output_format == "spans"
+ if answer_start_token < offsets.size(0):
+ answer_char_offset = offsets[answer_start_token][0].item()
+ else:
+ answer_char_offset = 0
+
+ spans: list[dict] = []
+ current_span: dict | None = None
+
+ for i in range(answer_start_token, token_preds.size(0)):
+ if labels[i].item() == -100:
+ continue
+
+ token_start, token_end = offsets[i].tolist()
+ if token_start == token_end:
+ continue
+
+ rel_start = token_start - answer_char_offset
+ rel_end = token_end - answer_char_offset
+
+ is_hallucination = token_preds[i].item() == 1
+ confidence = probabilities[i, 1].item() if is_hallucination else 0.0
+
+ if is_hallucination:
+ if current_span is None:
+ current_span = {
+ "start": rel_start,
+ "end": rel_end,
+ "confidence": confidence,
+ }
else:
- # If we were building a hallucination span, finalize it.
- if current_span is not None:
- # Extract the hallucinated text from the answer.
- span_text = answer[current_span["start"] : current_span["end"]]
- current_span["text"] = span_text
- spans.append(current_span)
- current_span = None
-
- # Append any span still in progress.
- if current_span is not None:
- span_text = answer[current_span["start"] : current_span["end"]]
- current_span["text"] = span_text
- spans.append(current_span)
-
- return spans
+ current_span["end"] = rel_end
+ current_span["confidence"] = max(current_span["confidence"], confidence)
+ else:
+ if current_span is not None:
+ current_span["text"] = answer[current_span["start"] : current_span["end"]]
+ spans.append(current_span)
+ current_span = None
+
+ if current_span is not None:
+ current_span["text"] = answer[current_span["start"] : current_span["end"]]
+ spans.append(current_span)
+
+ return spans
+
+ # ------------------------------------------------------------------
+ # Multi-chunk prediction with max() aggregation
+ # ------------------------------------------------------------------
+
+ def _predict_chunked(self, chunk_prompts: list[str], answer: str, output_format: str) -> list:
+ """Run prediction over multiple context chunks and aggregate with ``max()``.
+
+ For each answer token, the hallucination probability is the maximum
+ across all chunks. This is conservative: a token is only considered
+ supported if *every* chunk considers it supported.
+
+ :param chunk_prompts: Context chunk strings.
+ :param answer: The answer string (same for all chunks).
+ :param output_format: ``"tokens"`` or ``"spans"``.
+ """
+ all_token_results: list[list[dict]] = []
+ for chunk in chunk_prompts:
+ tokens = self._predict_single(chunk, answer, output_format="tokens")
+ all_token_results.append(tokens)
+
+ n_tokens = len(all_token_results[0])
+
+ # Aggregate: max hallucination probability across chunks.
+ aggregated: list[dict] = []
+ for tok_idx in range(n_tokens):
+ max_prob = max(
+ chunk_result[tok_idx]["prob"]
+ for chunk_result in all_token_results
+ if tok_idx < len(chunk_result)
+ )
+ aggregated.append(
+ {
+ "token": all_token_results[0][tok_idx]["token"],
+ "pred": 1 if max_prob >= 0.5 else 0,
+ "prob": max_prob,
+ }
+ )
+
+ if output_format == "tokens":
+ return aggregated
+
+ # For "spans", build character spans from the aggregated token predictions.
+ # Use offset mapping from the first chunk (answer offsets are identical
+ # across chunks because BERT tokenizers process segments independently).
+ return self._build_spans_from_tokens(aggregated, chunk_prompts[0], answer)
+
+ def _build_spans_from_tokens(
+ self, token_results: list[dict], prompt: str, answer: str
+ ) -> list[dict]:
+ """Convert aggregated token predictions into character-level spans.
+
+ Uses the tokenizer offset mapping from encoding *(prompt, answer)* to
+ get precise character positions within the answer.
+
+ :param token_results: Per-answer-token dicts with ``token``, ``pred``, ``prob``.
+ :param prompt: A context prompt string (used to get answer offset mapping).
+ :param answer: The original answer string.
+ :returns: List of span dicts with ``start``, ``end``, ``confidence``, ``text``.
+ """
+ _, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input(
+ self.tokenizer, prompt, answer, self.max_length
+ )
+
+ if answer_start_token < offsets.size(0):
+ answer_char_offset = offsets[answer_start_token][0].item()
else:
- raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.")
+ answer_char_offset = 0
+
+ # Collect answer-region offsets (skip special tokens with zero length).
+ answer_offsets: list[tuple[int, int]] = []
+ for i in range(answer_start_token, offsets.size(0)):
+ s, e = offsets[i].tolist()
+ if s == e:
+ continue
+ answer_offsets.append((s - answer_char_offset, e - answer_char_offset))
+
+ spans: list[dict] = []
+ current_span: dict | None = None
+
+ for tok_idx, tok in enumerate(token_results):
+ if tok_idx >= len(answer_offsets):
+ break
+ rel_start, rel_end = answer_offsets[tok_idx]
+ is_hall = tok["pred"] == 1
+ confidence = tok["prob"] if is_hall else 0.0
- def predict(self, context, answer, question=None, output_format="tokens") -> list:
+ if is_hall:
+ if current_span is None:
+ current_span = {
+ "start": rel_start,
+ "end": rel_end,
+ "confidence": confidence,
+ }
+ else:
+ current_span["end"] = rel_end
+ current_span["confidence"] = max(current_span["confidence"], confidence)
+ else:
+ if current_span is not None:
+ current_span["text"] = answer[current_span["start"] : current_span["end"]]
+ spans.append(current_span)
+ current_span = None
+
+ if current_span is not None:
+ current_span["text"] = answer[current_span["start"] : current_span["end"]]
+ spans.append(current_span)
+
+ return spans
+
+ # ------------------------------------------------------------------
+ # Public API
+ # ------------------------------------------------------------------
+
+ def predict(
+ self,
+ context: list[str],
+ answer: str,
+ question: str | None = None,
+ output_format: str = "tokens",
+ ) -> list:
"""Predict hallucination tokens or spans from the provided context, answer, and question.
:param context: List of passages that were supplied to the LLM / user.
- :param answer: Modelβgenerated answer to inspect.
+ :param answer: Model-generated answer to inspect.
:param question: Original question (``None`` for summarisation).
- :param output_format: ``"tokens"`` for tokenβlevel dicts, ``"spans"`` for character spans.
+ :param output_format: ``"tokens"`` for token-level dicts, ``"spans"`` for character spans.
:returns: List of predictions in requested format.
"""
- formatted_prompt = PromptUtils.format_context(context, question, self.lang)
- return self._predict(formatted_prompt, answer, output_format)
+ if output_format not in ("tokens", "spans"):
+ raise ValueError(
+ f"TransformerDetector doesn't support '{output_format}' format. "
+ "Use 'tokens' or 'spans'"
+ )
- def predict_prompt(self, prompt, answer, output_format="tokens") -> list:
+ groups = self._group_passages_into_chunks(context, question, answer)
+
+ if len(groups) == 1:
+ prompt = PromptUtils.format_context(groups[0], question, self.lang)
+ return self._predict_single(prompt, answer, output_format)
+
+ chunk_prompts = [PromptUtils.format_context(group, question, self.lang) for group in groups]
+ return self._predict_chunked(chunk_prompts, answer, output_format)
+
+ def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
"""Predict hallucination tokens or spans from the provided prompt and answer.
+ Note: unlike :meth:`predict`, this method does **not** chunk the prompt
+ automatically. If the prompt + answer exceed ``max_length``, the prompt
+ will be truncated. Use :meth:`predict` with structured passages for
+ automatic chunking.
+
:param prompt: The prompt string.
:param answer: The answer string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ :param output_format: ``"tokens"`` or ``"spans"``.
+ :returns: List of predictions in requested format.
"""
- return self._predict(prompt, answer, output_format)
+ if output_format not in ("tokens", "spans"):
+ raise ValueError(
+ f"TransformerDetector doesn't support '{output_format}' format. "
+ "Use 'tokens' or 'spans'"
+ )
+ # Warn if the input will be truncated.
+ total_tokens = self.tokenizer(prompt, answer, add_special_tokens=True, return_tensors="pt")[
+ "input_ids"
+ ].shape[1]
+ if total_tokens > self.max_length:
+ logger.warning(
+ "predict_prompt: input (%d tokens) exceeds max_length (%d). "
+ "The prompt will be truncated. Use predict() with structured "
+ "passages for automatic chunking.",
+ total_tokens,
+ self.max_length,
+ )
+ return self._predict_single(prompt, answer, output_format)
- def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list:
+ def predict_prompt_batch(
+ self, prompts: list[str], answers: list[str], output_format: str = "tokens"
+ ) -> list:
"""Predict hallucination tokens or spans from the provided prompts and answers.
:param prompts: List of prompt strings.
:param answers: List of answer strings.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ :param output_format: ``"tokens"`` or ``"spans"``.
+ :returns: List of prediction lists, one per input pair.
"""
- return [self._predict(p, a, output_format) for p, a in zip(prompts, answers)]
+ return [self.predict_prompt(p, a, output_format) for p, a in zip(prompts, answers)]
diff --git a/lettucedetect/models/generation.py b/lettucedetect/models/generation.py
index 939b3f0..f378d12 100644
--- a/lettucedetect/models/generation.py
+++ b/lettucedetect/models/generation.py
@@ -1,6 +1,8 @@
"""Simple hallucination generation using RAGFactChecker."""
-from typing import Any, Dict, List, Optional
+from __future__ import annotations
+
+from typing import Any
from lettucedetect.ragfactchecker import RAGFactChecker
@@ -14,9 +16,9 @@ class HallucinationGenerator:
def __init__(
self,
method: str = "rag_fact_checker",
- openai_api_key: str = None,
+ openai_api_key: str | None = None,
model: str = "gpt-4o",
- base_url: str = None,
+ base_url: str | None = None,
temperature: float = 0.0,
**kwargs,
):
@@ -36,12 +38,12 @@ def __init__(
def generate(
self,
- context: List[str],
+ context: list[str],
question: str,
- answer: str = None,
- error_types: Optional[List[str]] = None,
+ answer: str | None = None,
+ error_types: list[str] | None = None,
intensity: float = 0.3,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Generate hallucinated content.
:param context: List of context documents
@@ -65,12 +67,12 @@ def generate(
def generate_batch(
self,
- contexts: List[List[str]],
- questions: List[str],
- answers: List[str] = None,
- error_types: Optional[List[str]] = None,
+ contexts: list[list[str]],
+ questions: list[str],
+ answers: list[str] | None = None,
+ error_types: list[str] | None = None,
intensity: float = 0.3,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated content for multiple inputs.
:param contexts: List of context lists
@@ -96,12 +98,12 @@ def generate_batch(
async def generate_batch_async(
self,
- contexts: List[List[str]],
- questions: List[str],
- answers: List[str] = None,
- error_types: Optional[List[str]] = None,
+ contexts: list[list[str]],
+ questions: list[str],
+ answers: list[str] | None = None,
+ error_types: list[str] | None = None,
intensity: float = 0.3,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated content for multiple inputs.
:param contexts: List of context lists
diff --git a/lettucedetect/models/inference.py b/lettucedetect/models/inference.py
index 035ed3c..28de429 100644
--- a/lettucedetect/models/inference.py
+++ b/lettucedetect/models/inference.py
@@ -13,10 +13,15 @@ class HallucinationDetector:
"""Facade class that delegates to a concrete detector chosen by *method*.
:param method: ``"transformer"`` (token-classifier) or ``"llm"`` (OpenAI function-calling).
- :param kwargs: Passed straight through to the chosen detectorβs constructor.
+ :param kwargs: Passed straight through to the chosen detector's constructor.
"""
- def __init__(self, method: str = "transformer", **kwargs):
+ def __init__(self, method: str = "transformer", **kwargs) -> None:
+ """Initialize the detector.
+
+ :param method: Detection method to use.
+ :param kwargs: Passed to the detector constructor.
+ """
self.detector = make_detector(method, **kwargs)
def predict(
@@ -47,6 +52,7 @@ def predict_prompt_batch(
self, prompts: list[str], answers: list[str], output_format: str = "tokens"
) -> list:
"""Batch version of :py:meth:`predict_prompt`.
+
Length of *prompts* and *answers* must match.
:param prompts: List of prompt strings.
diff --git a/lettucedetect/models/trainer.py b/lettucedetect/models/trainer.py
index 1ec97d7..c951908 100644
--- a/lettucedetect/models/trainer.py
+++ b/lettucedetect/models/trainer.py
@@ -12,6 +12,12 @@
class Trainer:
+ """Token classification trainer with epoch-based training and validation.
+
+ Trains a model using AdamW, evaluates on a test set after each epoch,
+ and saves the best checkpoint based on hallucinated-class F1.
+ """
+
def __init__(
self,
model: Module,
diff --git a/lettucedetect/prompts/examples_hu.json b/lettucedetect/prompts/examples_hu.json
index 509272f..7c15fb4 100644
--- a/lettucedetect/prompts/examples_hu.json
+++ b/lettucedetect/prompts/examples_hu.json
@@ -34,3 +34,4 @@
}
]
+
diff --git a/lettucedetect/prompts/summary_prompt_hu.txt b/lettucedetect/prompts/summary_prompt_hu.txt
index 2a995dc..4d59f77 100644
--- a/lettucedetect/prompts/summary_prompt_hu.txt
+++ b/lettucedetect/prompts/summary_prompt_hu.txt
@@ -2,3 +2,4 @@ Foglalja ΓΆssze az alΓ‘bbi szΓΆveget:
${text}
output:
+
diff --git a/lettucedetect/ragfactchecker.py b/lettucedetect/ragfactchecker.py
index c62d6f6..3b1d014 100644
--- a/lettucedetect/ragfactchecker.py
+++ b/lettucedetect/ragfactchecker.py
@@ -1,8 +1,10 @@
"""Simple, clean RAGFactChecker wrapper for lettuceDetect."""
+from __future__ import annotations
+
import logging
import os
-from typing import Any, Dict, List, Optional
+from typing import Any
class RAGFactChecker:
@@ -17,9 +19,9 @@ class RAGFactChecker:
def __init__(
self,
- openai_api_key: Optional[str] = None,
+ openai_api_key: str | None = None,
model: str = "gpt-4o",
- base_url: Optional[str] = None,
+ base_url: str | None = None,
temperature: float = 0.0,
):
"""Initialize RAGFactChecker.
@@ -43,7 +45,7 @@ def __init__(
self.logger = logging.getLogger(__name__)
self._setup_components()
- def _setup_components(self):
+ def _setup_components(self) -> None:
"""Initialize RAGFactChecker components."""
try:
from rag_fact_checker.data import Config
@@ -75,7 +77,7 @@ def _setup_components(self):
# ============ TRIPLET OPERATIONS ============
- def generate_triplets(self, text: str) -> List[List[str]]:
+ def generate_triplets(self, text: str) -> list[list[str]]:
"""Generate triplets from text.
:param text: Input text
@@ -88,8 +90,8 @@ def generate_triplets(self, text: str) -> List[List[str]]:
return result.triplets
def compare_triplets(
- self, answer_triplets: List[List[str]], reference_triplets: List[List[str]]
- ) -> Dict[str, Any]:
+ self, answer_triplets: list[list[str]], reference_triplets: list[list[str]]
+ ) -> dict[str, Any]:
"""Compare answer triplets against reference triplets.
:param answer_triplets: Triplets from answer to check
@@ -103,7 +105,7 @@ def compare_triplets(
)
return {"fact_check_results": result.fact_check_prediction_binary, "raw_output": result}
- def analyze_text_pair(self, answer_text: str, reference_text: str) -> Dict[str, Any]:
+ def analyze_text_pair(self, answer_text: str, reference_text: str) -> dict[str, Any]:
"""Generate and compare triplets for two texts.
:param answer_text: Text to analyze
@@ -125,8 +127,8 @@ def analyze_text_pair(self, answer_text: str, reference_text: str) -> Dict[str,
# ============ HALLUCINATION DETECTION ============
def detect_hallucinations(
- self, context: List[str], answer: str, question: Optional[str] = None
- ) -> Dict[str, Any]:
+ self, context: list[str], answer: str, question: str | None = None
+ ) -> dict[str, Any]:
"""Detect hallucinations in answer given context.
:param context: List of context documents
@@ -158,8 +160,8 @@ def detect_hallucinations(
# ============ HALLUCINATION GENERATION ============
def generate_hallucination_from_context(
- self, context: List[str], question: str
- ) -> Dict[str, Any]:
+ self, context: list[str], question: str
+ ) -> dict[str, Any]:
"""Generate hallucinated content from context and question.
:param context: List of context documents
@@ -181,9 +183,9 @@ def generate_hallucination_from_answer(
self,
correct_answer: str,
question: str,
- error_types: Optional[List[str]] = None,
+ error_types: list[str] | None = None,
intensity: float = 0.3,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Generate hallucinated version of a correct answer.
:param correct_answer: The correct answer to modify
@@ -223,11 +225,11 @@ def generate_hallucination_from_answer(
async def generate_hallucination_from_answer_batch_async(
self,
- correct_answers: List[str],
- questions: List[str],
- error_types: Optional[List[List[str]]] = None,
- intensities: Optional[List[float]] = None,
- ) -> List[Dict[str, Any]]:
+ correct_answers: list[str],
+ questions: list[str],
+ error_types: list[list[str]] | None = None,
+ intensities: list[float] | None = None,
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated version of multiple correct answers."""
error_type_enums_list = None
if error_types:
@@ -253,9 +255,9 @@ async def generate_hallucination_from_answer_batch_async(
async def generate_hallucination_from_context_batch_async(
self,
- contexts: List[List[str]],
- questions: List[str],
- ) -> List[Dict[str, Any]]:
+ contexts: list[list[str]],
+ questions: list[str],
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated version of multiple correct answers."""
result = await self.reference_generator.generate_hlcntn_data_batch_async(
contexts, questions
@@ -264,11 +266,11 @@ async def generate_hallucination_from_context_batch_async(
def generate_hallucination_from_answer_batch(
self,
- correct_answers: List[str],
- questions: List[str],
- error_types: Optional[List[List[str]]] = None,
- intensities: Optional[List[float]] = None,
- ) -> List[Dict[str, Any]]:
+ correct_answers: list[str],
+ questions: list[str],
+ error_types: list[list[str]] | None = None,
+ intensities: list[float] | None = None,
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated version of multiple correct answers.
:param correct_answers: List of correct answers to modify
@@ -303,9 +305,9 @@ def generate_hallucination_from_answer_batch(
def generate_hallucination_from_context_batch(
self,
- contexts: List[List[str]],
- questions: List[str],
- ) -> List[Dict[str, Any]]:
+ contexts: list[list[str]],
+ questions: list[str],
+ ) -> list[dict[str, Any]]:
"""Generate hallucinated version of multiple correct answers.
:param contexts: List of context document lists
@@ -317,7 +319,7 @@ def generate_hallucination_from_context_batch(
result = self.reference_generator.generate_hlcntn_data_batch(contexts, questions)
return result
- def generate_triplets_batch(self, texts: List[str]) -> List[List[List[str]]]:
+ def generate_triplets_batch(self, texts: list[str]) -> list[list[list[str]]]:
"""Generate triplets for multiple texts.
:param texts: List of input texts
@@ -341,8 +343,8 @@ def generate_triplets_batch(self, texts: List[str]) -> List[List[List[str]]]:
return results
def detect_hallucinations_batch(
- self, contexts: List[List[str]], answers: List[str], questions: Optional[List[str]] = None
- ) -> List[Dict[str, Any]]:
+ self, contexts: list[list[str]], answers: list[str], questions: list[str] | None = None
+ ) -> list[dict[str, Any]]:
"""Detect hallucinations for multiple context-answer pairs.
:param contexts: List of context document lists
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000..e3bbf28
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,95 @@
+site_name: LettuceDetect
+site_description: A lightweight hallucination detection framework for RAG applications
+site_url: https://krlabsorg.github.io/LettuceDetect
+repo_url: https://github.com/KRLabsOrg/LettuceDetect
+repo_name: KRLabsOrg/LettuceDetect
+
+theme:
+ name: material
+ palette:
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ primary: green
+ accent: light green
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ primary: green
+ accent: light green
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
+ features:
+ - navigation.sections
+ - navigation.expand
+ - navigation.top
+ - content.code.copy
+ - content.code.annotate
+ - toc.follow
+ icon:
+ repo: fontawesome/brands/github
+
+markdown_extensions:
+ - admonition
+ - pymdownx.details
+ - pymdownx.superfences:
+ custom_fences:
+ - name: mermaid
+ class: mermaid
+ format: !!python/name:pymdownx.superfences.fence_code_format
+ - pymdownx.highlight:
+ anchor_linenums: true
+ - pymdownx.inlinehilite
+ - pymdownx.tabbed:
+ alternate_style: true
+ - pymdownx.snippets
+ - attr_list
+ - md_in_html
+ - tables
+ - toc:
+ permalink: true
+
+plugins:
+ - search
+ - mkdocstrings:
+ handlers:
+ python:
+ options:
+ show_source: false
+ show_root_heading: true
+ show_root_full_path: false
+ heading_level: 2
+ members_order: source
+ docstring_style: sphinx
+ merge_init_into_class: true
+ show_if_no_docstring: false
+
+nav:
+ - Home: index.md
+ - Getting Started:
+ - Installation: getting-started/installation.md
+ - Quick Start: getting-started/quickstart.md
+ - Models: getting-started/models.md
+ - Guide:
+ - Detection Methods: guide/detection-methods.md
+ - Evaluation: guide/evaluation.md
+ - Training: guide/training.md
+ - Multilingual: guide/multilingual.md
+ - Web API: guide/api.md
+ - Integrations: guide/integrations.md
+ - Code Hallucination Dataset:
+ - Overview: code-hallucination/index.md
+ - Pipeline Phases: code-hallucination/phases.md
+ - Configuration: code-hallucination/configuration.md
+ - Architecture Research: code-hallucination/architecture-research.md
+ - Research:
+ - TinyLettuce: TINYLETTUCE.md
+ - Multilingual (EuroBERT): EUROBERT.md
+ - API Reference:
+ - Inference: api/inference.md
+ - Detectors: api/detectors.md
+ - Datasets: api/datasets.md
+ - Training: api/training.md
+ - Benchmarks: benchmarks.md
diff --git a/pyproject.toml b/pyproject.toml
index 0d62884..0faab52 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
[project.urls]
Homepage = "https://github.com/krlabsorg/lettucedetect"
+Documentation = "https://krlabsorg.github.io/LettuceDetect"
[project.optional-dependencies]
dev = [
@@ -42,11 +43,15 @@ api = [
"pydantic-settings>=2.8.0",
"httpx>=0.28"
]
+docs = [
+ "mkdocs-material>=9.0",
+ "mkdocstrings[python]>=0.24",
+]
[tool.setuptools]
packages = ["lettucedetect", "lettucedetect_api"]
-[tool.pytest]
+[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*_pytest.py"
@@ -77,8 +82,13 @@ ignore = [
"ANN003", # **kwargs annotation
"ANN204", # missing return type for __init__
"PTH123", # path.open
+ "C901", # function complexity
]
[tool.ruff.lint.per-file-ignores]
+"tests/**" = ["S101", "ANN001", "ANN201", "ANN202", "D103", "F841"]
+"lettucedetect/preprocess/**" = ["ANN001", "ANN201"]
+"lettucedetect/datasets/analyze_lengths.py" = ["ANN201", "D103", "E741"]
+"lettucedetect/models/evaluator.py" = ["ANN001", "ANN201", "D401"]
"lettucedetect_api/test_server.py" = ["S101"]
"lettucedetect_api/test_client.py" = ["S101"]
\ No newline at end of file
diff --git a/scripts/add_clean_swebench_samples.py b/scripts/add_clean_swebench_samples.py
new file mode 100644
index 0000000..2c5ae32
--- /dev/null
+++ b/scripts/add_clean_swebench_samples.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+"""Add clean (non-hallucinated) samples from remaining SWE-bench Lite instances.
+
+Fetches source files from GitHub and builds LettuceDetect-format samples
+with empty labels (= supported/correct code).
+"""
+
+import json
+import re
+import time
+
+import requests
+from datasets import load_dataset
+
+INPUT_DATASET = (
+ "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect_v2.json"
+)
+RAW_DATASET = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json"
+OUTPUT_PATH = INPUT_DATASET # Overwrite with merged data
+
+GITHUB_RAW_BASE = "https://raw.githubusercontent.com"
+MAX_FILE_CHARS = 12000
+MAX_NEW_SAMPLES = 200 # Cap to avoid too many GitHub requests
+
+
+def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None:
+ url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}"
+ try:
+ r = requests.get(url, timeout=15)
+ if r.status_code == 200:
+ return r.text[:MAX_FILE_CHARS]
+ return None
+ except Exception as e:
+ print(f" Error fetching {filepath}: {e}")
+ return None
+
+
+def extract_changed_files(patch: str) -> list[str]:
+ files = []
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ match = re.search(r"b/(.+)$", line)
+ if match:
+ files.append(match.group(1))
+ return files
+
+
+def extract_code_from_patch(patch: str) -> str:
+ added_lines = []
+ in_hunk = False
+ for line in patch.split("\n"):
+ if line.startswith("@@"):
+ in_hunk = True
+ continue
+ if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"):
+ continue
+ if in_hunk:
+ if line.startswith("+"):
+ added_lines.append(line[1:])
+ elif line.startswith("-"):
+ continue
+ else:
+ if line.startswith(" "):
+ added_lines.append(line[1:])
+ else:
+ added_lines.append(line)
+ return "\n".join(added_lines)
+
+
+def build_prompt(source_files: dict[str, str], user_query: str) -> str:
+ parts = []
+ for filepath, content in source_files.items():
+ parts.append(f"File: {filepath}\n```python\n{content}\n```")
+ parts.append(f"User request: {user_query}")
+ return "\n\n".join(parts)
+
+
+def main():
+ # Load existing data
+ with open(INPUT_DATASET) as f:
+ existing_data = json.load(f)
+ print(f"Existing dataset: {len(existing_data)} samples")
+
+ # Get already-used instance IDs
+ with open(RAW_DATASET) as f:
+ raw_data = json.load(f)
+ used_ids = {item["instance_id"] for item in raw_data}
+ print(f"Already used instance IDs: {len(used_ids)}")
+
+ # Load SWE-bench Lite
+ print("Loading SWE-bench Lite...")
+ ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
+
+ # Filter to unused instances
+ remaining = [s for s in ds if s["instance_id"] not in used_ids]
+ print(f"Remaining SWE-bench instances: {len(remaining)}")
+
+ # Cap the number of new samples
+ remaining = remaining[:MAX_NEW_SAMPLES]
+ print(f"Processing {len(remaining)} instances...")
+
+ new_samples = []
+ for i, sample in enumerate(remaining):
+ instance_id = sample["instance_id"]
+ repo = sample["repo"]
+ commit = sample["base_commit"]
+ patch = sample["patch"]
+
+ print(f"\n[{i + 1}/{len(remaining)}] {instance_id}")
+
+ # Extract changed files
+ changed_files = extract_changed_files(patch)
+ if not changed_files:
+ print(" SKIP: No changed files found")
+ continue
+
+ # Fetch source files from GitHub
+ source_files = {}
+ for filepath in changed_files[:3]: # Limit to 3 files per sample
+ content = fetch_file_from_github(repo, commit, filepath)
+ if content:
+ source_files[filepath] = content
+ print(f" Fetched {filepath}: {len(content)} chars")
+ else:
+ print(f" Failed: {filepath}")
+ time.sleep(0.3)
+
+ if not source_files:
+ print(" SKIP: No source files fetched")
+ continue
+
+ # Extract code from gold patch
+ code = extract_code_from_patch(patch)
+ if not code.strip():
+ print(" SKIP: Empty code")
+ continue
+
+ # Build prompt with source files + user query
+ user_query = sample["problem_statement"][:500]
+ prompt = build_prompt(source_files, user_query)
+
+ new_samples.append(
+ {
+ "prompt": prompt,
+ "answer": code,
+ "labels": [],
+ "split": "train",
+ "task_type": "code_generation",
+ "dataset": "code_hallucination_swebench",
+ "language": "en",
+ }
+ )
+
+ # Progress check
+ if (i + 1) % 20 == 0:
+ print(f"\n Progress: {len(new_samples)} clean samples so far\n")
+
+ # Merge with existing data
+ merged = existing_data + new_samples
+ print(f"\nNew clean samples: {len(new_samples)}")
+ print(f"Total merged: {len(merged)}")
+
+ # Save
+ with open(OUTPUT_PATH, "w") as f:
+ json.dump(merged, f, indent=2)
+
+ # Stats
+ n_clean = sum(1 for s in merged if not s["labels"])
+ n_hall = sum(1 for s in merged if s["labels"])
+ print("\nFinal dataset:")
+ print(f" Clean: {n_clean} ({n_clean / len(merged) * 100:.1f}%)")
+ print(f" Hallucinated: {n_hall} ({n_hall / len(merged) * 100:.1f}%)")
+ print(f" Total: {len(merged)}")
+ print(f"\nSaved to {OUTPUT_PATH}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/code_hallucination/__init__.py b/scripts/code_hallucination/__init__.py
new file mode 100644
index 0000000..706e62a
--- /dev/null
+++ b/scripts/code_hallucination/__init__.py
@@ -0,0 +1 @@
+"""Code hallucination dataset generation pipeline from SWE-bench."""
diff --git a/scripts/code_hallucination/config.py b/scripts/code_hallucination/config.py
new file mode 100644
index 0000000..85591b9
--- /dev/null
+++ b/scripts/code_hallucination/config.py
@@ -0,0 +1,56 @@
+"""Configuration for the code hallucination dataset pipeline."""
+
+import os
+from pathlib import Path
+
+# === Paths ===
+PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
+DATA_DIR = PROJECT_ROOT / "data" / "code_hallucination"
+REPOS_DIR = DATA_DIR / "repos"
+SOURCE_CACHE_DIR = DATA_DIR / "source_cache"
+
+# Intermediate outputs
+INSTANCES_PATH = DATA_DIR / "swebench_instances.json"
+QUERIES_PATH = DATA_DIR / "queries.jsonl"
+DOCS_PATH = DATA_DIR / "documentation.jsonl"
+FORMATS_PATH = DATA_DIR / "formats.jsonl"
+HALLUCINATED_PATH = DATA_DIR / "hallucinated_samples.jsonl"
+
+# Final outputs
+DATASET_PATH = DATA_DIR / "code_hallucination_data.json"
+METADATA_PATH = DATA_DIR / "code_hallucination_metadata.json"
+VALIDATION_REPORT_PATH = DATA_DIR / "validation_report.txt"
+
+# === LLM API Config ===
+# Override via env vars or CLI args
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
+API_KEY = os.environ.get("OPENAI_API_KEY", "")
+MODEL = os.environ.get("MODEL", "moonshotai/kimi-k2-instruct-0905")
+
+# Context7
+CONTEXT7_BASE = "https://context7.com/api/v2"
+CONTEXT7_API_KEY = os.environ.get("CONTEXT7_API_KEY", "")
+DOCS_RATIO = 0.5 # Only fetch docs for 50% of instances
+
+# === Dataset Config ===
+HALLUCINATION_RATIO = 0.4 # 40% hallucinated, 60% clean
+MAX_FILE_CHARS = 12000 # Cap individual source file size
+MAX_CONTEXT7_CHARS = 4000 # Documentation fetch limit
+
+# === LLM Config ===
+RETRY_DELAY = 2
+MAX_RETRIES = 3
+LLM_TEMPERATURE = 0.7
+HALLUCINATION_TEMPERATURE = 0.8
+BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1")) # >1 for local vLLM
+
+# Hallucination types (round-robin assignment)
+HALLUCINATION_TYPES = ["structural", "behavioral", "semantic"]
+
+# Answer format types
+FORMAT_TYPES = ["complete_function", "edit_style", "fragment"]
+FORMAT_WEIGHTS = [0.4, 0.3, 0.3] # Target distribution
+
+# SWE-bench datasets
+SWEBENCH_FULL = "princeton-nlp/SWE-bench"
+SWEBENCH_LITE = "princeton-nlp/SWE-bench_Lite"
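+
+# A minimal usage sketch (hypothetical values): point the pipeline at a local
+# vLLM server and enable batched injection before running any phase.
+#
+#   export API_BASE_URL="http://localhost:8000/v1"
+#   export OPENAI_API_KEY="dummy-key"
+#   export MODEL="my-local-model"
+#   export BATCH_SIZE="8"
+#   python -m scripts.code_hallucination.pipeline --all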
diff --git a/scripts/code_hallucination/context7_docs.py b/scripts/code_hallucination/context7_docs.py
new file mode 100644
index 0000000..4c1df87
--- /dev/null
+++ b/scripts/code_hallucination/context7_docs.py
@@ -0,0 +1,225 @@
+"""Phase 4: Fetch library documentation from Context7 API.
+
+Only fetches docs for ~50% of instances (configurable via DOCS_RATIO).
+The other 50% get empty docs, creating training variety: models learn
+to handle both with-docs and without-docs scenarios.
+"""
+
+import json
+import random
+import re
+
+import requests
+
+from .config import CONTEXT7_API_KEY, CONTEXT7_BASE, DOCS_PATH, DOCS_RATIO, MAX_CONTEXT7_CHARS
+
+# Map repo paths / import names to likely library names for Context7
+PATH_TO_LIB = {
+ "django": "django",
+ "astropy": "astropy",
+ "sympy": "sympy",
+ "sklearn": "scikit-learn",
+ "matplotlib": "matplotlib",
+ "requests": "requests",
+ "flask": "flask",
+ "pytest": "pytest",
+ "sphinx": "sphinx",
+ "xarray": "xarray",
+ "seaborn": "seaborn",
+ "pylint": "pylint",
+ "pandas": "pandas",
+ "numpy": "numpy",
+ "scipy": "scipy",
+ "transformers": "transformers",
+ "jax": "jax",
+ "torch": "pytorch",
+ "tensorflow": "tensorflow",
+ "sqlalchemy": "sqlalchemy",
+ "celery": "celery",
+ "pydantic": "pydantic",
+ "fastapi": "fastapi",
+ "httpx": "httpx",
+}
+
+
+def extract_imports_from_patch(patch: str) -> list[str]:
+ """Extract Python import statements from added lines in a patch."""
+ imports = set()
+ for line in patch.split("\n"):
+ if line.startswith("+") and not line.startswith("+++"):
+ clean = line[1:].strip()
+ if clean.startswith("import ") or clean.startswith("from "):
+ match = re.match(r"(?:from|import)\s+([\w.]+)", clean)
+ if match:
+ module = match.group(1).split(".")[0]
+ if module and not module.startswith("_"):
+ imports.add(module)
+ return list(imports)
+
+
+def extract_libraries_from_files(changed_files: list[str]) -> list[str]:
+ """Infer libraries from file paths."""
+ libs = set()
+ for f in changed_files:
+ for key, lib in PATH_TO_LIB.items():
+ if key in f:
+ libs.add(lib)
+ return list(libs)
+
+
+def fetch_context7_docs(
+ library_name: str, query: str, max_chars: int = MAX_CONTEXT7_CHARS
+) -> str | None:
+ """Fetch documentation from Context7 for a library + query."""
+ try:
+ headers = {}
+ if CONTEXT7_API_KEY:
+ headers["Authorization"] = f"Bearer {CONTEXT7_API_KEY}"
+
+ r = requests.get(
+ f"{CONTEXT7_BASE}/libs/search",
+ params={"query": query, "libraryName": library_name},
+ headers=headers,
+ timeout=10,
+ )
+ if r.status_code != 200:
+ return None
+ results = r.json().get("results", [])
+ if not results:
+ return None
+
+ lib_id = results[0]["id"]
+
+ r2 = requests.get(
+ f"{CONTEXT7_BASE}/context",
+ params={"libraryId": lib_id, "query": query, "type": "txt"},
+ headers=headers,
+ timeout=10,
+ )
+ if r2.status_code != 200:
+ return None
+
+ doc_text = r2.text[:max_chars]
+ return doc_text if doc_text.strip() else None
+ except Exception as e:
+ print(f" Context7 error for {library_name}: {e}")
+ return None
+
+
+def get_documentation_for_instance(
+ changed_files: list[str], patch: str, problem_statement: str
+) -> dict[str, str]:
+ """Fetch documentation for libraries referenced in an instance."""
+ imported_libs = extract_imports_from_patch(patch)
+ path_libs = extract_libraries_from_files(changed_files)
+ all_libs = list(set(imported_libs + path_libs))
+
+ short_query = problem_statement[:200].replace("\n", " ").strip()
+
+ docs = {}
+ for lib in all_libs[:3]:
+ doc = fetch_context7_docs(lib, short_query)
+ if doc:
+ docs[lib] = doc
+
+ return docs
+
+
+def select_docs_instances(
+ instances: list[dict], ratio: float = DOCS_RATIO, seed: int = 42
+) -> set[str]:
+ """Select which instances should get documentation fetched.
+
+ Returns set of instance_ids that should have docs.
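+
+    With the default DOCS_RATIO of 0.5 this selects roughly half of the
+    instance ids; run() writes empty docs for the rest.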
+ """
+ rng = random.Random(seed)
+ ids = [inst["instance_id"] for inst in instances]
+ n_with_docs = int(len(ids) * ratio)
+ rng.shuffle(ids)
+ return set(ids[:n_with_docs])
+
+
+def load_existing_docs(path=DOCS_PATH) -> dict[str, dict]:
+ """Load already-fetched docs for resumability."""
+ existing = {}
+ if path.exists():
+ with open(path) as f:
+ for line in f:
+ try:
+ entry = json.loads(line)
+ existing[entry["instance_id"]] = entry["docs"]
+ except (json.JSONDecodeError, KeyError):
+ continue
+ return existing
+
+
+def run(instances: list[dict]):
+ """Run Phase 4: Fetch documentation for selected instances (~50%)."""
+ print("=" * 60)
+ print("Phase 4: Context7 Documentation")
+ print("=" * 60)
+
+ DOCS_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+ # Select which instances get docs
+ docs_ids = select_docs_instances(instances)
+ print(
+ f"Selected {len(docs_ids)}/{len(instances)} instances for documentation ({DOCS_RATIO:.0%})"
+ )
+
+ existing = load_existing_docs()
+ print(f"Already fetched: {len(existing)} instances")
+
+ to_process = [inst for inst in instances if inst["instance_id"] not in existing]
+ print(f"Remaining: {len(to_process)} instances to process")
+
+ processed = 0
+ with_docs = 0
+ skipped_by_ratio = 0
+
+ with open(DOCS_PATH, "a") as f:
+ for i, inst in enumerate(to_process):
+ instance_id = inst["instance_id"]
+
+ # Skip docs for instances not selected (write empty docs)
+ if instance_id not in docs_ids:
+ entry = {"instance_id": instance_id, "docs": {}}
+ f.write(json.dumps(entry) + "\n")
+ f.flush()
+ processed += 1
+ skipped_by_ratio += 1
+ continue
+
+ changed_files = inst.get("changed_files", [])
+ if not changed_files:
+ from .source_fetcher import extract_changed_files
+
+ changed_files = extract_changed_files(inst["patch"])
+
+ docs = get_documentation_for_instance(
+ changed_files, inst["patch"], inst["problem_statement"]
+ )
+
+ entry = {"instance_id": instance_id, "docs": docs}
+ f.write(json.dumps(entry) + "\n")
+ f.flush()
+
+ processed += 1
+ if docs:
+ with_docs += 1
+
+ if processed % 100 == 0:
+ print(
+ f" Progress: {processed}/{len(to_process)} ({with_docs} with docs, {skipped_by_ratio} skipped)"
+ )
+
+ print(
+ f"\nDone: {processed} processed, {with_docs} with docs, {skipped_by_ratio} skipped (no-docs by design)"
+ )
+
+
+if __name__ == "__main__":
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ run(instances)
diff --git a/scripts/code_hallucination/format_builder.py b/scripts/code_hallucination/format_builder.py
new file mode 100644
index 0000000..78531ab
--- /dev/null
+++ b/scripts/code_hallucination/format_builder.py
@@ -0,0 +1,115 @@
+"""Phase 5: Assign answer format to each instance."""
+
+import json
+import random
+
+from .config import FORMAT_TYPES, FORMAT_WEIGHTS, FORMATS_PATH, SOURCE_CACHE_DIR
+
+
+def assign_format(source_data: dict) -> tuple[str | None, str | None]:
+ """Assign a format type and build the answer for an instance.
+
+ Returns (format_type, answer_text).
+ Falls back if preferred format isn't available.
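+
+    Example (illustrative): if only edit_style and fragment are available, their
+    FORMAT_WEIGHTS of 0.3 and 0.3 are renormalized to 0.5 each before sampling.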
+ """
+ has_functions = bool(source_data.get("modified_functions"))
+ has_edit = bool(source_data.get("edit_style"))
+ has_fragment = bool(source_data.get("patch_code", "").strip())
+
+ # Build available formats
+ available = []
+ if has_functions:
+ available.append("complete_function")
+ if has_edit:
+ available.append("edit_style")
+ if has_fragment:
+ available.append("fragment")
+
+ if not available:
+ return None, None
+
+ # Weighted random choice from available formats
+ weights = []
+ for fmt in available:
+ idx = FORMAT_TYPES.index(fmt)
+ weights.append(FORMAT_WEIGHTS[idx])
+
+ # Normalize weights
+ total = sum(weights)
+ weights = [w / total for w in weights]
+
+ chosen = random.choices(available, weights=weights, k=1)[0]
+
+ # Build answer text
+ if chosen == "complete_function":
+ funcs = source_data["modified_functions"]
+ # Take the first (or longest) modified function
+ func = max(funcs, key=lambda f: len(f.get("patched", "")))
+ answer = func["patched"]
+ elif chosen == "edit_style":
+ answer = source_data["edit_style"]
+ else: # fragment
+ answer = source_data["patch_code"]
+
+ return chosen, answer
+
+
+def run(instances: list[dict], source_cache_dir=SOURCE_CACHE_DIR):
+ """Run Phase 5: Assign formats and build answers.
+
+ Returns list of dicts with instance_id, format_type, answer.
+ """
+ print("=" * 60)
+ print("Phase 5: Answer Format Building")
+ print("=" * 60)
+
+ FORMATS_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+ results = []
+ format_counts = {fmt: 0 for fmt in FORMAT_TYPES}
+ skipped = 0
+
+ for inst in instances:
+ instance_id = inst["instance_id"]
+
+ # Load source data from cache
+ cache_path = source_cache_dir / f"{instance_id}.json"
+ if not cache_path.exists():
+ skipped += 1
+ continue
+
+ with open(cache_path) as f:
+ source_data = json.load(f)
+
+ fmt, answer = assign_format(source_data)
+ if fmt is None:
+ skipped += 1
+ continue
+
+ results.append(
+ {
+ "instance_id": instance_id,
+ "format_type": fmt,
+ "answer": answer,
+ }
+ )
+ format_counts[fmt] += 1
+
+ # Save
+ with open(FORMATS_PATH, "w") as f:
+ for entry in results:
+ f.write(json.dumps(entry) + "\n")
+
+ print(f"\nAssigned formats for {len(results)} instances (skipped {skipped})")
+ for fmt, count in format_counts.items():
+ pct = count * 100 // max(len(results), 1)
+ print(f" {fmt}: {count} ({pct}%)")
+
+ return results
+
+
+if __name__ == "__main__":
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ run(instances)
diff --git a/scripts/code_hallucination/hallucination_injector.py b/scripts/code_hallucination/hallucination_injector.py
new file mode 100644
index 0000000..19dfeed
--- /dev/null
+++ b/scripts/code_hallucination/hallucination_injector.py
@@ -0,0 +1,424 @@
+"""Phase 6: Inject hallucinations using LLM with JSON span annotations.
+
+Supports both sequential (remote API) and async batch (local vLLM) modes.
+Set BATCH_SIZE>1 env var for parallel requests to local vLLM.
+"""
+
+import asyncio
+import json
+import re
+import textwrap
+import time
+
+from openai import AsyncOpenAI, OpenAI
+
+from .config import (
+ API_BASE_URL,
+ API_KEY,
+ BATCH_SIZE,
+ HALLUCINATED_PATH,
+ HALLUCINATION_TEMPERATURE,
+ HALLUCINATION_TYPES,
+ MAX_RETRIES,
+ MODEL,
+ RETRY_DELAY,
+)
+
+INJECTION_SYSTEM_PROMPT = textwrap.dedent("""\
+ You are a code hallucination injector for building a hallucination detection dataset.
+
+ Given correct code and context, create a hallucinated version with a specific type of error.
+
+ Hallucination types:
+ - STRUCTURAL: Change a function call, import, or parameter to something that
+ doesn't exist or is wrong. Code should still parse but reference non-existent
+ APIs, wrong methods, or invented parameters.
+ - BEHAVIORAL: Use correct APIs but with wrong values or logic. Wrong defaults,
+ off-by-one errors, swapped conditions, wrong argument values.
+ - SEMANTIC: Code that looks like it addresses the user's request but does
+ something subtly different or opposite. The code parses, uses real APIs,
+ but fails to do what was asked.
+
+ Rules:
+ - Make changes PLAUSIBLE - something an LLM would realistically generate
+ - Changes must be SUBTLE, not obviously broken
+ - The hallucinated code must still be syntactically valid
+ - Make 1-3 changes, not more
+
+ Respond in this exact JSON format (no markdown, no code blocks):
+ {
+ "hallucinated_code": "the full modified code with hallucinations injected",
+ "changes": [
+ {
+ "original": "exact original code that was changed",
+ "hallucinated": "what you changed it to",
+ "explanation": "why this is a hallucination"
+ }
+ ]
+ }
+
+ IMPORTANT:
+ - "original" must be an exact substring of the correct code
+ - "hallucinated" must be an exact substring of your hallucinated_code
+ - Return ONLY valid JSON, nothing else
+""")
+
+
+def inject_hallucination(
+ client: OpenAI,
+ model: str,
+ clean_answer: str,
+ hall_type: str,
+ user_query: str = "",
+ context: str = "",
+) -> dict | None:
+ """Inject a hallucination and get back structured JSON with spans.
+
+ Returns dict with 'hallucinated_code' and 'changes', or None if failed.
+ """
+ user_msg = f"""Hallucination type to inject: {hall_type.upper()}
+
+User's original request: {user_query[:500]}
+
+Context (source code):
+{context[:2000]}
+
+Correct code to modify:
+{clean_answer}
+
+Generate a hallucinated version with {hall_type} error(s). Return JSON only."""
+
+ for attempt in range(MAX_RETRIES):
+ try:
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": INJECTION_SYSTEM_PROMPT},
+ {"role": "user", "content": user_msg},
+ ],
+ temperature=HALLUCINATION_TEMPERATURE,
+ max_tokens=4000,
+ )
+ raw = response.choices[0].message.content.strip()
+
+ # Parse JSON from response
+ json_match = re.search(r"\{[\s\S]*\}", raw)
+ if not json_match:
+ if attempt < MAX_RETRIES - 1:
+ continue
+ return None
+
+ result = json.loads(json_match.group())
+
+ if "hallucinated_code" not in result or "changes" not in result:
+ if attempt < MAX_RETRIES - 1:
+ continue
+ return None
+
+ # Verify the hallucinated code is actually different
+ if result["hallucinated_code"].strip() == clean_answer.strip():
+ if attempt < MAX_RETRIES - 1:
+ continue
+ return None
+
+ return result
+
+        except Exception as e:
+ if attempt < MAX_RETRIES - 1:
+ wait = RETRY_DELAY * (attempt + 1)
+ print(f" Injection error (attempt {attempt + 1}): {e}. Retrying in {wait}s...")
+ time.sleep(wait)
+ else:
+ return None
+
+
+def compute_span_offsets(hallucinated_code: str, hallucinated_span: str) -> list[dict]:
+ """Find character offsets of a hallucinated span within the answer code."""
+ spans = []
+ idx = hallucinated_code.find(hallucinated_span)
+ if idx != -1:
+ spans.append({"start": idx, "end": idx + len(hallucinated_span)})
+ return spans
+
+
+def build_labels_from_changes(
+ hallucinated_code: str, changes: list[dict], hall_type: str
+) -> list[dict]:
+ """Build span labels by finding each hallucinated string in the code.
+
+ Only includes spans where the hallucinated text is actually found in the answer.
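+
+    Example (hypothetical): if change["hallucinated"] is "verify_ssl=False" and it
+    first occurs at offset 24 of the answer, the label becomes
+    {"start": 24, "end": 40, "label": hall_type}.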
+ """
+ labels = []
+ for change in changes:
+ h_span = change.get("hallucinated", "")
+ if not h_span or len(h_span) < 3:
+ continue
+ if h_span not in hallucinated_code:
+ continue
+
+ offsets = compute_span_offsets(hallucinated_code, h_span)
+ for offset in offsets[:1]: # First occurrence only
+ labels.append(
+ {
+ "start": offset["start"],
+ "end": offset["end"],
+ "label": hall_type,
+ }
+ )
+
+ return labels
+
+
+def load_existing_hallucinations(path=HALLUCINATED_PATH) -> dict[str, dict]:
+ """Load already-processed hallucinations for resumability."""
+ existing = {}
+ if path.exists():
+ with open(path) as f:
+ for line in f:
+ try:
+ entry = json.loads(line)
+ existing[entry["instance_id"]] = entry
+ except (json.JSONDecodeError, KeyError):
+ continue
+ return existing
+
+
+async def _inject_one_async(
+ aclient: AsyncOpenAI,
+ model: str,
+ clean_answer: str,
+ hall_type: str,
+ user_query: str,
+ context: str,
+) -> dict | None:
+ """Async version of inject_hallucination for batch processing."""
+ user_msg = f"""Hallucination type to inject: {hall_type.upper()}
+
+User's original request: {user_query[:500]}
+
+Context (source code):
+{context[:2000]}
+
+Correct code to modify:
+{clean_answer}
+
+Generate a hallucinated version with {hall_type} error(s). Return JSON only."""
+
+ for attempt in range(MAX_RETRIES):
+ try:
+ response = await aclient.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": INJECTION_SYSTEM_PROMPT},
+ {"role": "user", "content": user_msg},
+ ],
+ temperature=HALLUCINATION_TEMPERATURE,
+ max_tokens=4000,
+ )
+ raw = response.choices[0].message.content.strip()
+ json_match = re.search(r"\{[\s\S]*\}", raw)
+ if not json_match:
+ continue
+ result = json.loads(json_match.group())
+ if "hallucinated_code" not in result or "changes" not in result:
+ continue
+ if result["hallucinated_code"].strip() == clean_answer.strip():
+ continue
+ return result
+ except Exception:
+ if attempt < MAX_RETRIES - 1:
+ await asyncio.sleep(RETRY_DELAY * (attempt + 1))
+ else:
+ return None
+ return None
+
+
+def _process_result(result, instance_id, hall_type, fmt_data, model):
+ """Process a single injection result into a JSONL entry."""
+ if result is None:
+ return None
+ hallucinated_code = result["hallucinated_code"]
+ changes = result.get("changes", [])
+ labels = build_labels_from_changes(hallucinated_code, changes, hall_type)
+ if not labels:
+ return None
+ return {
+ "instance_id": instance_id,
+ "hallucinated_answer": hallucinated_code,
+ "labels": labels,
+ "hallucination_type": hall_type,
+ "injector_model": model,
+ "format_type": fmt_data.get("format_type", "fragment"),
+ }
+
+
+def run(
+ instances_to_inject: list[dict],
+ formats: dict[str, dict],
+ queries: dict[str, str],
+ api_key: str = API_KEY,
+ base_url: str = API_BASE_URL,
+ model: str = MODEL,
+):
+ """Run Phase 6: Inject hallucinations into selected instances.
+
+ Uses async batch processing when BATCH_SIZE > 1 (for local vLLM).
+ Falls back to sequential processing for remote APIs (BATCH_SIZE=1).
+ """
+ print("=" * 60)
+ print("Phase 6: Hallucination Injection")
+ print("=" * 60)
+
+ HALLUCINATED_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+ print(f"Using {base_url} with model {model}")
+ print(f"Batch size: {BATCH_SIZE}")
+
+ existing = load_existing_hallucinations()
+ print(f"Already processed: {len(existing)}")
+
+ to_process = [
+ inst
+ for inst in instances_to_inject
+ if inst["instance_id"] not in existing and inst["instance_id"] in formats
+ ]
+ print(f"Remaining: {len(to_process)} instances to inject")
+
+ if BATCH_SIZE > 1:
+ results = _run_batched(to_process, formats, queries, api_key, base_url, model)
+ else:
+ results = _run_sequential(to_process, formats, queries, api_key, base_url, model)
+
+ # Stats
+ type_counts = {}
+ for r in results:
+ t = r["hallucination_type"]
+ type_counts[t] = type_counts.get(t, 0) + 1
+ print("By type:", type_counts)
+
+ if results:
+ avg_spans = sum(len(r["labels"]) for r in results) / len(results)
+ span_sizes = [lab["end"] - lab["start"] for r in results for lab in r["labels"]]
+ print(f"Avg spans per sample: {avg_spans:.1f}")
+ print(
+ f"Span sizes: min={min(span_sizes)}, max={max(span_sizes)}, avg={sum(span_sizes) // len(span_sizes)}"
+ )
+
+ return results
+
+
+def _run_sequential(to_process, formats, queries, api_key, base_url, model):
+ """Sequential processing for remote APIs (rate-limited)."""
+ client = OpenAI(api_key=api_key, base_url=base_url)
+ processed = 0
+ failed = 0
+ no_spans = 0
+ results = []
+
+ with open(HALLUCINATED_PATH, "a") as f:
+ for i, inst in enumerate(to_process):
+ instance_id = inst["instance_id"]
+ fmt_data = formats.get(instance_id, {})
+ clean_answer = fmt_data.get("answer", "")
+ if not clean_answer:
+ failed += 1
+ continue
+
+ hall_type = HALLUCINATION_TYPES[i % len(HALLUCINATION_TYPES)]
+ query = queries.get(instance_id, "")
+ context = inst.get("problem_statement", "")[:2000]
+
+ result = inject_hallucination(client, model, clean_answer, hall_type, query, context)
+ entry = _process_result(result, instance_id, hall_type, fmt_data, model)
+
+ if entry is None:
+ if result is not None:
+ no_spans += 1
+ failed += 1
+ continue
+
+ f.write(json.dumps(entry) + "\n")
+ f.flush()
+ results.append(entry)
+ processed += 1
+
+ if processed % 50 == 0:
+ print(f" Progress: {processed}/{len(to_process)} (failed: {failed})")
+
+ print(f"\nDone: {processed} injected, {failed} failed ({no_spans} had no matchable spans)")
+ return results
+
+
+def _run_batched(to_process, formats, queries, api_key, base_url, model):
+ """Async batch processing for local vLLM (no rate limiting needed)."""
+ aclient = AsyncOpenAI(api_key=api_key, base_url=base_url)
+ processed = 0
+ failed = 0
+ no_spans = 0
+ results = []
+
+ async def process_batches():
+ nonlocal processed, failed, no_spans
+
+ with open(HALLUCINATED_PATH, "a") as f:
+ for batch_start in range(0, len(to_process), BATCH_SIZE):
+ batch = to_process[batch_start : batch_start + BATCH_SIZE]
+
+ # Build async tasks for the batch
+ tasks = []
+ batch_meta = []
+ for i, inst in enumerate(batch):
+ global_idx = batch_start + i
+ instance_id = inst["instance_id"]
+ fmt_data = formats.get(instance_id, {})
+ clean_answer = fmt_data.get("answer", "")
+ if not clean_answer:
+ failed += 1
+ continue
+
+ hall_type = HALLUCINATION_TYPES[global_idx % len(HALLUCINATION_TYPES)]
+ query = queries.get(instance_id, "")
+ context = inst.get("problem_statement", "")[:2000]
+
+ tasks.append(
+ _inject_one_async(aclient, model, clean_answer, hall_type, query, context)
+ )
+ batch_meta.append((instance_id, hall_type, fmt_data))
+
+ if not tasks:
+ continue
+
+ # Run batch concurrently
+ batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ # Process results and write immediately
+ for result, (instance_id, hall_type, fmt_data) in zip(batch_results, batch_meta):
+ if isinstance(result, Exception):
+ failed += 1
+ continue
+
+ entry = _process_result(result, instance_id, hall_type, fmt_data, model)
+ if entry is None:
+ if result is not None:
+ no_spans += 1
+ failed += 1
+ continue
+
+ f.write(json.dumps(entry) + "\n")
+ f.flush()
+ results.append(entry)
+ processed += 1
+
+ if processed % 50 == 0 or batch_start + BATCH_SIZE >= len(to_process):
+ total = processed + failed
+ print(
+ f" Progress: {total}/{len(to_process)} ({processed} ok, {failed} failed)"
+ )
+
+ asyncio.run(process_batches())
+ print(f"\nDone: {processed} injected, {failed} failed ({no_spans} had no matchable spans)")
+ return results
+
+
+if __name__ == "__main__":
+ print("Run via pipeline.py")
diff --git a/scripts/code_hallucination/pipeline.py b/scripts/code_hallucination/pipeline.py
new file mode 100644
index 0000000..d4e79a9
--- /dev/null
+++ b/scripts/code_hallucination/pipeline.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python3
+"""Orchestrator for the code hallucination dataset pipeline.
+
+Usage:
+ # Run all phases
+ python -m scripts.code_hallucination.pipeline --all
+
+ # Run specific phase
+ python -m scripts.code_hallucination.pipeline --phase 1
+
+ # Test with a few examples
+ python -m scripts.code_hallucination.pipeline --test 5
+
+ # Override LLM settings via env vars:
+ OPENAI_API_KEY=xxx API_BASE_URL=https://api.groq.com/openai/v1 MODEL=moonshotai/kimi-k2-instruct-0905 \
+ python -m scripts.code_hallucination.pipeline --test 10
+"""
+
+import argparse
+import json
+import random
+
+from .config import (
+ API_BASE_URL,
+ API_KEY,
+ DATA_DIR,
+ DOCS_PATH,
+ FORMATS_PATH,
+ HALLUCINATED_PATH,
+ MODEL,
+ QUERIES_PATH,
+)
+
+
+def load_jsonl_dict(path, key="instance_id", value_key=None) -> dict:
+ """Load a JSONL file into a dict keyed by instance_id."""
+ result = {}
+ if not path.exists():
+ return result
+ with open(path) as f:
+ for line in f:
+ try:
+ entry = json.loads(line)
+ if value_key:
+ result[entry[key]] = entry[value_key]
+ else:
+ result[entry[key]] = entry
+ except (json.JSONDecodeError, KeyError):
+ continue
+ return result
+
+
+def run_test(n: int = 5, api_key: str = API_KEY, base_url: str = API_BASE_URL, model: str = MODEL):
+ """Run a quick test with n instances from the test split."""
+ print("=" * 60)
+ print(f"TEST MODE: Running pipeline with {n} instances")
+ print(f"LLM: {model} @ {base_url}")
+ print("=" * 60)
+
+ from .swebench_loader import load_all_splits
+
+ # Load and filter to test split
+ all_instances = load_all_splits()
+ test_instances = [i for i in all_instances if i["split"] == "test"]
+
+ # Take n random instances
+ rng = random.Random(42)
+ selected = rng.sample(test_instances, min(n, len(test_instances)))
+ print(f"Selected {len(selected)} test instances")
+
+ # Save temporary instances
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
+ test_path = DATA_DIR / "test_instances.json"
+ with open(test_path, "w") as f:
+ json.dump(selected, f, indent=2)
+
+    # Phase 2: Fetch sources (use GitHub API for test mode, no cloning needed)
+ from .source_fetcher import run as run_fetch
+
+ sources = run_fetch(selected, use_github_api=True)
+
+ if not sources:
+ print("No sources fetched, aborting test")
+ return
+
+ # Phase 3: Rewrite queries
+ from .query_rewriter import run as run_queries
+
+ run_queries(selected, api_key=api_key, base_url=base_url, model=model)
+
+ # Phase 4: Context7 docs
+ from .context7_docs import run as run_docs
+
+ run_docs(selected)
+
+ # Phase 5: Assign formats
+ from .format_builder import run as run_formats
+
+ run_formats(selected)
+
+ # Phase 8: Select targets (before phase 6)
+ from .splitter import select_hallucination_targets
+
+ targets = select_hallucination_targets(selected)
+
+ # Phase 6: Inject hallucinations
+ from .hallucination_injector import run as run_inject
+
+ formats = load_jsonl_dict(FORMATS_PATH)
+ queries = load_jsonl_dict(QUERIES_PATH, value_key="query")
+ to_inject = [i for i in selected if i["instance_id"] in targets]
+ run_inject(to_inject, formats, queries, api_key=api_key, base_url=base_url, model=model)
+
+ # Phase 7: Assemble
+ from .sample_assembler import run as run_assemble
+
+ docs = load_jsonl_dict(DOCS_PATH, value_key="docs")
+ hallucinations = load_jsonl_dict(HALLUCINATED_PATH)
+ samples, metadata = run_assemble(selected, queries, docs, formats, hallucinations, targets)
+
+ # Phase 9: Validate
+ from .validator import run as run_validate
+
+ run_validate(samples, metadata)
+
+ print("\n" + "=" * 60)
+ print("TEST COMPLETE")
+ print("=" * 60)
+ print(f"Generated {len(samples)} samples from {n} test instances")
+
+ # Show a sample
+ if samples:
+ print("\n--- Sample Example ---")
+ s = samples[0]
+ print(f"Prompt length: {len(s['prompt'])} chars")
+ print(f"Answer length: {len(s['answer'])} chars")
+ print(f"Labels: {len(s['labels'])}")
+ print(f"Split: {s['split']}")
+ print(f"Answer preview: {s['answer'][:200]}...")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Code hallucination dataset pipeline")
+ parser.add_argument(
+ "--phase", nargs="+", type=int, choices=range(1, 10), help="Run specific phase(s)"
+ )
+ parser.add_argument("--all", action="store_true", help="Run all phases")
+ parser.add_argument("--test", type=int, metavar="N", help="Test with N instances")
+ parser.add_argument("--api-key", type=str, default=API_KEY, help="LLM API key")
+ parser.add_argument("--base-url", type=str, default=API_BASE_URL, help="LLM API base URL")
+ parser.add_argument("--model", type=str, default=MODEL, help="LLM model name")
+ args = parser.parse_args()
+
+ if args.test:
+ run_test(args.test, api_key=args.api_key, base_url=args.base_url, model=args.model)
+ return
+
+ if not args.phase and not args.all:
+ parser.print_help()
+ return
+
+ phases = list(range(1, 10)) if args.all else args.phase
+
+ for phase in sorted(phases):
+ print(f"\n{'#' * 60}")
+ print(f"# Running Phase {phase}")
+ print(f"{'#' * 60}\n")
+
+ if phase == 1:
+ from .swebench_loader import run
+
+ run()
+ elif phase == 2:
+ from .source_fetcher import run
+ from .swebench_loader import load_instances
+
+ run(load_instances())
+ elif phase == 3:
+ from .query_rewriter import run
+ from .swebench_loader import load_instances
+
+ run(load_instances(), api_key=args.api_key, base_url=args.base_url, model=args.model)
+ elif phase == 4:
+ from .context7_docs import run
+ from .swebench_loader import load_instances
+
+ run(load_instances())
+ elif phase == 5:
+ from .format_builder import run
+ from .swebench_loader import load_instances
+
+ run(load_instances())
+ elif phase == 6:
+ from .hallucination_injector import run
+ from .splitter import select_hallucination_targets
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ formats = load_jsonl_dict(FORMATS_PATH)
+ queries = load_jsonl_dict(QUERIES_PATH, value_key="query")
+ targets = select_hallucination_targets(instances)
+ to_inject = [i for i in instances if i["instance_id"] in targets]
+ run(
+ to_inject,
+ formats,
+ queries,
+ api_key=args.api_key,
+ base_url=args.base_url,
+ model=args.model,
+ )
+ elif phase == 7:
+ from .sample_assembler import run
+ from .splitter import select_hallucination_targets
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ queries = load_jsonl_dict(QUERIES_PATH, value_key="query")
+ docs = load_jsonl_dict(DOCS_PATH, value_key="docs")
+ formats = load_jsonl_dict(FORMATS_PATH)
+ hallucinations = load_jsonl_dict(HALLUCINATED_PATH)
+ targets = select_hallucination_targets(instances)
+ run(instances, queries, docs, formats, hallucinations, targets)
+ elif phase == 8:
+ from .splitter import run
+ from .swebench_loader import load_instances
+
+ run(load_instances())
+ elif phase == 9:
+ from .validator import run
+
+ run()
+
+ print("\nPipeline complete!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/code_hallucination/query_rewriter.py b/scripts/code_hallucination/query_rewriter.py
new file mode 100644
index 0000000..966c70f
--- /dev/null
+++ b/scripts/code_hallucination/query_rewriter.py
@@ -0,0 +1,141 @@
+"""Phase 3: Rewrite problem statements into natural user queries via LLM."""
+
+import json
+import textwrap
+import time
+
+from openai import OpenAI
+
+from .config import (
+ API_BASE_URL,
+ API_KEY,
+ LLM_TEMPERATURE,
+ MAX_RETRIES,
+ MODEL,
+ QUERIES_PATH,
+ RETRY_DELAY,
+)
+
+REWRITE_SYSTEM_PROMPT = textwrap.dedent("""\
+ You transform GitHub issue descriptions into realistic user queries
+ that a developer would type into an AI coding assistant (like Claude Code or Cursor).
+
+ Rules:
+ - Make it conversational and natural
+ - Keep the core technical ask but remove GitHub formatting
+ - Remove reproduction steps, stack traces, verbose details
+ - Keep it to 1-3 sentences
+ - Don't mention "issue" or "bug report"
+ - Sound like someone asking for help, not filing a report
+""")
+
+
+def get_client(api_key: str = API_KEY, base_url: str = API_BASE_URL) -> OpenAI:
+ return OpenAI(api_key=api_key, base_url=base_url)
+
+
+def llm_call(
+ client: OpenAI,
+ model: str,
+ system: str,
+ user: str,
+ temperature: float = LLM_TEMPERATURE,
+ max_tokens: int = 300,
+) -> str:
+ """Make an LLM call with retries."""
+ for attempt in range(MAX_RETRIES):
+ try:
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": system},
+ {"role": "user", "content": user},
+ ],
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ return response.choices[0].message.content.strip()
+ except Exception as e:
+ if attempt < MAX_RETRIES - 1:
+ wait = RETRY_DELAY * (attempt + 1)
+ print(f" LLM error (attempt {attempt + 1}): {e}. Retrying in {wait}s...")
+ time.sleep(wait)
+ else:
+ raise
+
+
+def rewrite_query(client: OpenAI, model: str, problem_statement: str, repo: str) -> str:
+ """Rewrite a problem statement into a natural user query."""
+ user_msg = f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}"
+ return llm_call(client, model, REWRITE_SYSTEM_PROMPT, user_msg)
+
+
+def load_existing_queries(path=QUERIES_PATH) -> dict[str, str]:
+ """Load already-processed queries for resumability."""
+ existing = {}
+ if path.exists():
+ with open(path) as f:
+ for line in f:
+ try:
+ entry = json.loads(line)
+ existing[entry["instance_id"]] = entry["query"]
+ except (json.JSONDecodeError, KeyError):
+ continue
+ return existing
+
+
+def run(
+ instances: list[dict],
+ api_key: str = API_KEY,
+ base_url: str = API_BASE_URL,
+ model: str = MODEL,
+):
+ """Run Phase 3: Rewrite all queries."""
+ print("=" * 60)
+ print("Phase 3: Query Rewriting")
+ print("=" * 60)
+
+ QUERIES_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+ client = get_client(api_key, base_url)
+ print(f"Using {base_url} with model {model}")
+
+ # Load existing for resumability
+ existing = load_existing_queries()
+ print(f"Already processed: {len(existing)} queries")
+
+ to_process = [inst for inst in instances if inst["instance_id"] not in existing]
+ print(f"Remaining: {len(to_process)} queries to process")
+
+ processed = 0
+ failed = 0
+
+ with open(QUERIES_PATH, "a") as f:
+ for i, inst in enumerate(to_process):
+ instance_id = inst["instance_id"]
+ repo = inst["repo"]
+ problem = inst["problem_statement"]
+
+ try:
+ query = rewrite_query(client, model, problem, repo)
+ entry = {"instance_id": instance_id, "query": query}
+ f.write(json.dumps(entry) + "\n")
+ f.flush()
+ processed += 1
+
+ if processed % 50 == 0:
+ print(f" Progress: {processed}/{len(to_process)} (failed: {failed})")
+ except Exception as e:
+ print(f" ERROR {instance_id}: {e}")
+ failed += 1
+
+ print(f"\nDone: {processed} new queries, {failed} failed")
+ total = len(existing) + processed
+ print(f"Total queries: {total}")
+
+
+if __name__ == "__main__":
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ run(instances)
diff --git a/scripts/code_hallucination/sample_assembler.py b/scripts/code_hallucination/sample_assembler.py
new file mode 100644
index 0000000..bdcbd61
--- /dev/null
+++ b/scripts/code_hallucination/sample_assembler.py
@@ -0,0 +1,182 @@
+"""Phase 7: Assemble final HallucinationSample format."""
+
+import json
+
+from .config import DATASET_PATH, METADATA_PATH, SOURCE_CACHE_DIR
+
+
+def build_prompt(
+ source_files: dict[str, str],
+ documentation: dict[str, str],
+ user_query: str,
+) -> str:
+ """Build the prompt (context) for a sample.
+
+ Format: source files + documentation + user query.
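+
+    Sketch of the resulting layout (hypothetical paths and content):
+
+        File: pkg/client.py
+        ```python
+        ...
+        ```
+
+        Documentation for requests:
+        ...
+
+        User request: make the retry logic configurable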
+ """
+ parts = []
+
+ for filepath, content in source_files.items():
+ parts.append(f"File: {filepath}\n```python\n{content}\n```")
+
+ for lib, doc in documentation.items():
+ parts.append(f"Documentation for {lib}:\n{doc}")
+
+ parts.append(f"User request: {user_query}")
+
+ return "\n\n".join(parts)
+
+
+def assemble_samples(
+ instances: list[dict],
+ source_cache: dict[str, dict],
+ queries: dict[str, str],
+ docs: dict[str, dict],
+ formats: dict[str, dict],
+ hallucinations: dict[str, dict],
+ hallucination_instance_ids: set[str],
+) -> tuple[list[dict], list[dict]]:
+ """Assemble all samples into HallucinationSample format.
+
+ Returns (samples, metadata) where each instance maps to exactly 1 sample.
+ """
+ samples = []
+ metadata = []
+
+ for inst in instances:
+ instance_id = inst["instance_id"]
+ split = inst["split"]
+ repo = inst["repo"]
+
+ # Skip if no source data or format
+ if instance_id not in source_cache or instance_id not in formats:
+ continue
+
+ source_data = source_cache[instance_id]
+ fmt_data = formats[instance_id]
+ query = queries.get(instance_id, inst.get("problem_statement", "")[:500])
+ doc = docs.get(instance_id, {})
+
+ # Build prompt from source files + docs + query
+ source_files = source_data.get("source_files", {})
+ prompt = build_prompt(source_files, doc, query)
+
+ if instance_id in hallucination_instance_ids and instance_id in hallucinations:
+ # Hallucinated sample
+ hall_data = hallucinations[instance_id]
+ sample = {
+ "prompt": prompt,
+ "answer": hall_data["hallucinated_answer"],
+ "labels": hall_data["labels"],
+ "split": split,
+ "task_type": "code_generation",
+ "dataset": "swebench_code",
+ "language": "en",
+ }
+ meta = {
+ "instance_id": instance_id,
+ "repo": repo,
+ "split": split,
+ "is_lite": inst.get("is_lite", False),
+ "format_type": hall_data.get("format_type", fmt_data.get("format_type")),
+ "hallucination_type": hall_data.get("hallucination_type"),
+ "injector_model": hall_data.get("injector_model"),
+ "is_hallucinated": True,
+ }
+ else:
+ # Clean sample
+ answer = fmt_data.get("answer", "")
+ if not answer.strip():
+ continue
+
+ sample = {
+ "prompt": prompt,
+ "answer": answer,
+ "labels": [],
+ "split": split,
+ "task_type": "code_generation",
+ "dataset": "swebench_code",
+ "language": "en",
+ }
+ meta = {
+ "instance_id": instance_id,
+ "repo": repo,
+ "split": split,
+ "is_lite": inst.get("is_lite", False),
+ "format_type": fmt_data.get("format_type"),
+ "hallucination_type": None,
+ "injector_model": None,
+ "is_hallucinated": False,
+ }
+
+ samples.append(sample)
+ metadata.append(meta)
+
+ return samples, metadata
+
+
+def run(
+ instances: list[dict],
+ queries: dict[str, str],
+ docs: dict[str, dict],
+ formats: dict[str, dict],
+ hallucinations: dict[str, dict],
+ hallucination_instance_ids: set[str],
+):
+ """Run Phase 7: Assemble all samples."""
+ print("=" * 60)
+ print("Phase 7: Sample Assembly")
+ print("=" * 60)
+
+ # Load source cache
+ source_cache = {}
+ for inst in instances:
+ cache_path = SOURCE_CACHE_DIR / f"{inst['instance_id']}.json"
+ if cache_path.exists():
+ with open(cache_path) as f:
+ source_cache[inst["instance_id"]] = json.load(f)
+
+ print(f"Source cache: {len(source_cache)} instances")
+ print(f"Queries: {len(queries)}")
+ print(f"Docs: {len(docs)}")
+ print(f"Formats: {len(formats)}")
+ print(f"Hallucinations: {len(hallucinations)}")
+ print(f"Hallucination targets: {len(hallucination_instance_ids)}")
+
+ samples, metadata = assemble_samples(
+ instances,
+ source_cache,
+ queries,
+ docs,
+ formats,
+ hallucinations,
+ hallucination_instance_ids,
+ )
+
+ # Save
+ DATASET_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(DATASET_PATH, "w") as f:
+ json.dump(samples, f, indent=2)
+
+ with open(METADATA_PATH, "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ # Stats
+ n_clean = sum(1 for s in samples if not s["labels"])
+ n_hall = sum(1 for s in samples if s["labels"])
+ print(f"\nTotal samples: {len(samples)}")
+ print(f" Clean: {n_clean} ({n_clean * 100 // max(len(samples), 1)}%)")
+ print(f" Hallucinated: {n_hall} ({n_hall * 100 // max(len(samples), 1)}%)")
+
+ split_counts = {}
+ for s in samples:
+ split_counts[s["split"]] = split_counts.get(s["split"], 0) + 1
+ for split, count in sorted(split_counts.items()):
+ print(f" {split}: {count}")
+
+ return samples, metadata
+
+
+if __name__ == "__main__":
+ print("Run via pipeline.py")
diff --git a/scripts/code_hallucination/source_fetcher.py b/scripts/code_hallucination/source_fetcher.py
new file mode 100644
index 0000000..fa0c3ac
--- /dev/null
+++ b/scripts/code_hallucination/source_fetcher.py
@@ -0,0 +1,460 @@
+"""Phase 2: Clone repos and fetch source files via git show."""
+
+import ast
+import json
+import re
+import subprocess
+import tempfile
+from pathlib import Path
+
+import requests
+
+from .config import MAX_FILE_CHARS, REPOS_DIR, SOURCE_CACHE_DIR
+
+GITHUB_RAW_BASE = "https://raw.githubusercontent.com"
+
+
+def extract_changed_files(patch: str) -> list[str]:
+ """Extract file paths from a unified diff using anchored regex.
+
+ Fixed: Uses re.match on 'diff --git' lines instead of re.search(r"b/(.+)$")
+ which was ambiguous on paths containing 'b/'.
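+
+    Example: the header line
+        diff --git a/django/db/models/query.py b/django/db/models/query.py
+    yields "django/db/models/query.py" (the post-image path).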
+ """
+ files = []
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ match = re.match(r"diff --git a/(.+?) b/(.+)$", line)
+ if match:
+ files.append(match.group(2))
+ return files
+
+
+def clone_repo(repo: str, repos_dir: Path = REPOS_DIR) -> Path | None:
+ """Clone a repo if not already present. Returns repo directory path."""
+ repos_dir.mkdir(parents=True, exist_ok=True)
+ repo_dir = repos_dir / repo.replace("/", "__")
+
+ if repo_dir.exists():
+ return repo_dir
+
+ print(f" Cloning {repo}...")
+ try:
+ result = subprocess.run(
+ ["git", "clone", "--bare", f"https://github.com/{repo}.git", str(repo_dir)],
+ capture_output=True,
+ text=True,
+ timeout=1800, # 30 min for large repos
+ )
+ if result.returncode != 0:
+ print(f" ERROR cloning {repo}: {result.stderr[:200]}")
+ return None
+ except subprocess.TimeoutExpired:
+ print(f" TIMEOUT cloning {repo} (>30 min)")
+ return None
+
+ return repo_dir
+
+
+def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None:
+ """Fallback: fetch a file from GitHub raw API (for when repo isn't cloned)."""
+ url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}"
+ try:
+ r = requests.get(url, timeout=15)
+ if r.status_code == 200:
+ return r.text[:MAX_FILE_CHARS]
+ return None
+ except Exception:
+ return None
+
+
+def fetch_file_at_commit(repo_dir: Path, commit: str, filepath: str) -> str | None:
+ """Fetch a file's contents at a specific commit using git show."""
+ try:
+ result = subprocess.run(
+ ["git", "show", f"{commit}:{filepath}"],
+ cwd=str(repo_dir),
+ capture_output=True,
+ text=True,
+ timeout=30,
+ )
+ if result.returncode == 0:
+ return result.stdout[:MAX_FILE_CHARS]
+ return None
+    except Exception as e:
+ print(f" Error fetching {filepath}@{commit[:8]}: {e}")
+ return None
+
+
+def apply_patch_and_get_file(repo_dir: Path, commit: str, patch: str, filepath: str) -> str | None:
+ """Apply a patch to get the post-fix version of a file.
+
+ Uses a temporary worktree to apply the patch cleanly.
+ """
+ try:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Create a temporary worktree at the base commit
+ result = subprocess.run(
+ ["git", "worktree", "add", tmpdir, commit],
+ cwd=str(repo_dir),
+ capture_output=True,
+ text=True,
+ timeout=60,
+ )
+ if result.returncode != 0:
+ return None
+
+ # Apply the patch
+ result = subprocess.run(
+ ["git", "apply", "--allow-empty", "-"],
+ input=patch,
+ cwd=tmpdir,
+ capture_output=True,
+ text=True,
+ timeout=30,
+ )
+ if result.returncode != 0:
+ return None
+
+            # Read the patched file, then always remove the worktree so the bare
+            # repo does not accumulate stale worktree registrations
+            content = None
+            patched_path = Path(tmpdir) / filepath
+            if patched_path.exists():
+                content = patched_path.read_text()[:MAX_FILE_CHARS]
+
+            subprocess.run(
+                ["git", "worktree", "remove", "--force", tmpdir],
+                cwd=str(repo_dir),
+                capture_output=True,
+                timeout=30,
+            )
+            return content
+ except Exception as e:
+ print(f" Error applying patch for {filepath}: {e}")
+ return None
+
+
+def extract_modified_functions(original_source: str, patched_source: str) -> list[dict]:
+ """Extract functions that were modified between original and patched source.
+
+ Returns list of dicts with 'name', 'original', 'patched' function bodies.
+ """
+
+ def get_functions(source: str) -> dict[str, str]:
+ """Parse source and extract function name -> source mapping."""
+ try:
+ tree = ast.parse(source)
+ except SyntaxError:
+ return {}
+
+ funcs = {}
+ lines = source.split("\n")
+ for node in ast.walk(tree):
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ start = node.lineno - 1
+ end = node.end_lineno
+ func_source = "\n".join(lines[start:end])
+ funcs[node.name] = func_source
+ return funcs
+
+ orig_funcs = get_functions(original_source)
+ patched_funcs = get_functions(patched_source)
+
+ modified = []
+ for name, patched_body in patched_funcs.items():
+ orig_body = orig_funcs.get(name)
+ if orig_body and orig_body != patched_body:
+ modified.append(
+ {
+ "name": name,
+ "original": orig_body,
+ "patched": patched_body,
+ }
+ )
+ elif not orig_body:
+ # New function added by patch
+ modified.append(
+ {
+ "name": name,
+ "original": None,
+ "patched": patched_body,
+ }
+ )
+
+ return modified
+
+
+def extract_code_from_patch(patch: str) -> str:
+ """Extract added/changed lines from a unified diff as code fragment.
+
+ Fixed: Uses removeprefix("b/") instead of lstrip("b/").
+ """
+ added_lines = []
+ in_hunk = False
+
+ for line in patch.split("\n"):
+ if line.startswith("@@"):
+ in_hunk = True
+ continue
+ if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"):
+ continue
+ if in_hunk:
+ if line.startswith("+"):
+ added_lines.append(line[1:])
+ elif line.startswith("-"):
+ continue
+ else:
+ if line.startswith(" "):
+ added_lines.append(line[1:])
+ else:
+ added_lines.append(line)
+
+ return "\n".join(added_lines)
+
+
+def build_edit_style_answer(patch: str, changed_files: list[str]) -> str | None:
+ """Build an edit-style answer from a patch.
+
+ Format:
+ In file path/to/file.py, replace:
+ ```python
+ old code
+ ```
+ with:
+ ```python
+ new code
+ ```
+ """
+ edits = []
+
+ current_file = None
+ old_lines = []
+ new_lines = []
+ in_hunk = False
+
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ # Flush previous hunk
+ if current_file and (old_lines or new_lines):
+ edits.append(
+ {
+ "file": current_file,
+ "old": "\n".join(old_lines),
+ "new": "\n".join(new_lines),
+ }
+ )
+ old_lines = []
+ new_lines = []
+ match = re.match(r"diff --git a/(.+?) b/(.+)$", line)
+ if match:
+ current_file = match.group(2)
+ in_hunk = False
+ continue
+
+ if line.startswith("@@"):
+ # Flush previous hunk in same file
+ if old_lines or new_lines:
+ edits.append(
+ {
+ "file": current_file,
+ "old": "\n".join(old_lines),
+ "new": "\n".join(new_lines),
+ }
+ )
+ old_lines = []
+ new_lines = []
+ in_hunk = True
+ continue
+
+ if line.startswith("---") or line.startswith("+++"):
+ continue
+
+ if in_hunk:
+ if line.startswith("-"):
+ old_lines.append(line[1:])
+ elif line.startswith("+"):
+ new_lines.append(line[1:])
+ # Context lines are skipped in edit-style
+
+ # Flush last hunk
+ if current_file and (old_lines or new_lines):
+ edits.append(
+ {
+ "file": current_file,
+ "old": "\n".join(old_lines),
+ "new": "\n".join(new_lines),
+ }
+ )
+
+ if not edits:
+ return None
+
+ parts = []
+ for edit in edits:
+ if edit["old"] and edit["new"]:
+ parts.append(
+ f"In file {edit['file']}, replace:\n```python\n{edit['old']}\n```\nwith:\n```python\n{edit['new']}\n```"
+ )
+ elif edit["new"]:
+ parts.append(f"In file {edit['file']}, add:\n```python\n{edit['new']}\n```")
+
+ return "\n\n".join(parts) if parts else None
+
+
+def fetch_source_for_instance(
+ instance: dict, repos_dir: Path = REPOS_DIR, use_github_api: bool = False
+) -> dict | None:
+ """Fetch all source files for a single SWE-bench instance.
+
+ Args:
+ instance: SWE-bench instance dict
+ repos_dir: Directory for cloned repos
+ use_github_api: If True, use GitHub raw API instead of git clone
+
+ Returns dict with:
+ instance_id, changed_files, source_files, patch_code, edit_style,
+ modified_functions
+ Or None if fetching failed.
+
+ """
+ repo = instance["repo"]
+ commit = instance["base_commit"]
+ patch = instance["patch"]
+
+ # Extract changed files from patch
+ changed_files = extract_changed_files(patch)
+ if not changed_files:
+ return None
+
+ # Fetch source files
+ source_files = {}
+ repo_dir = None
+
+ if use_github_api:
+ # Use GitHub raw API (slower but no cloning needed)
+ import time
+
+ for filepath in changed_files:
+ content = fetch_file_from_github(repo, commit, filepath)
+ if content:
+ source_files[filepath] = content
+ time.sleep(0.3)
+ else:
+ # Clone repo and use git show (fast for bulk)
+ repo_dir = clone_repo(repo, repos_dir)
+ if repo_dir is None:
+ # Fallback to GitHub API
+ import time
+
+ print(f" Falling back to GitHub API for {repo}")
+ for filepath in changed_files:
+ content = fetch_file_from_github(repo, commit, filepath)
+ if content:
+ source_files[filepath] = content
+ time.sleep(0.3)
+ else:
+ for filepath in changed_files:
+ content = fetch_file_at_commit(repo_dir, commit, filepath)
+ if content:
+ source_files[filepath] = content
+
+ if not source_files:
+ return None
+
+ # Build different answer formats
+ # Fragment format
+ patch_code = extract_code_from_patch(patch)
+
+ # Edit-style format
+ edit_style = build_edit_style_answer(patch, changed_files)
+
+ # Complete function format (needs patched source via git apply)
+ modified_functions = []
+ if repo_dir is not None:
+ for filepath in changed_files:
+ if filepath in source_files:
+ patched_source = apply_patch_and_get_file(repo_dir, commit, patch, filepath)
+ if patched_source:
+ funcs = extract_modified_functions(source_files[filepath], patched_source)
+ for func in funcs:
+ func["file"] = filepath
+ modified_functions.extend(funcs)
+
+ return {
+ "instance_id": instance["instance_id"],
+ "changed_files": changed_files,
+ "source_files": source_files,
+ "patch_code": patch_code,
+ "edit_style": edit_style,
+ "modified_functions": modified_functions,
+ }
+
+
+def run(instances: list[dict], use_github_api: bool = False):
+ """Run Phase 2: Fetch source files for all instances.
+
+ Args:
+ instances: List of SWE-bench instances
+ use_github_api: If True, use GitHub raw API instead of cloning repos
+
+ """
+ print("=" * 60)
+ print("Phase 2: Source File Fetching")
+ print("=" * 60)
+
+ SOURCE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+ if not use_github_api:
+ REPOS_DIR.mkdir(parents=True, exist_ok=True)
+ # Group by repo for efficient cloning
+ repos = set(inst["repo"] for inst in instances)
+ print(f"Need to clone {len(repos)} repos")
+ for repo in sorted(repos):
+ clone_repo(repo)
+ else:
+ print("Using GitHub raw API (no cloning)")
+
+ # Fetch sources per instance
+ results = []
+ failed = 0
+
+ for i, instance in enumerate(instances):
+ if (i + 1) % 100 == 0:
+ print(f" Progress: {i + 1}/{len(instances)} ({len(results)} success, {failed} failed)")
+
+ # Skip if already cached
+ cache_path = SOURCE_CACHE_DIR / f"{instance['instance_id']}.json"
+ if cache_path.exists():
+ with open(cache_path) as f:
+ results.append(json.load(f))
+ continue
+
+ result = fetch_source_for_instance(instance, use_github_api=use_github_api)
+ if result:
+ results.append(result)
+ # Cache result
+ cache_path = SOURCE_CACHE_DIR / f"{instance['instance_id']}.json"
+ with open(cache_path, "w") as f:
+ json.dump(result, f)
+ else:
+ failed += 1
+
+ print(f"\nDone: {len(results)} success, {failed} failed out of {len(instances)}")
+
+ # Report format availability
+ n_fragment = sum(1 for r in results if r.get("patch_code", "").strip())
+ n_edit = sum(1 for r in results if r.get("edit_style"))
+ n_function = sum(1 for r in results if r.get("modified_functions"))
+ print("Format availability:")
+ print(f" Fragment: {n_fragment}")
+ print(f" Edit-style: {n_edit}")
+ print(f" Complete function: {n_function}")
+
+ return results
+
+
+if __name__ == "__main__":
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ run(instances)
diff --git a/scripts/code_hallucination/splitter.py b/scripts/code_hallucination/splitter.py
new file mode 100644
index 0000000..a98c70a
--- /dev/null
+++ b/scripts/code_hallucination/splitter.py
@@ -0,0 +1,64 @@
+"""Phase 8: Train/dev/test split by SWE-bench splits.
+
+Since we use SWE-bench splits directly, this phase mostly validates
+the splits and selects which instances get hallucination injection.
+"""
+
+import random
+
+from .config import HALLUCINATION_RATIO
+
+
+def select_hallucination_targets(
+ instances: list[dict],
+ ratio: float = HALLUCINATION_RATIO,
+ seed: int = 42,
+) -> set[str]:
+ """Select which instances get hallucination injection.
+
+ Applies the ratio uniformly within each split to maintain
+ consistent class distribution across train/dev/test.
+
+ Returns set of instance_ids that should be hallucinated.
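+
+    Example: with the default HALLUCINATION_RATIO of 0.4, a split of 500
+    instances contributes 200 hallucination targets and 300 clean samples.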
+ """
+ rng = random.Random(seed)
+ targets = set()
+
+ # Group by split
+ by_split = {}
+ for inst in instances:
+ split = inst["split"]
+ if split not in by_split:
+ by_split[split] = []
+ by_split[split].append(inst)
+
+ for split, split_instances in by_split.items():
+ n_hall = int(len(split_instances) * ratio)
+ rng.shuffle(split_instances)
+ for inst in split_instances[:n_hall]:
+ targets.add(inst["instance_id"])
+
+ n_clean = len(split_instances) - n_hall
+ print(f" {split}: {n_hall} hallucinated + {n_clean} clean = {len(split_instances)}")
+
+ return targets
+
+
+def run(instances: list[dict]) -> set[str]:
+ """Run Phase 8: Select hallucination targets."""
+ print("=" * 60)
+ print("Phase 8: Split & Target Selection")
+ print("=" * 60)
+
+ targets = select_hallucination_targets(instances)
+ print(f"\nTotal hallucination targets: {len(targets)} out of {len(instances)}")
+ print(f"Ratio: {len(targets) / max(len(instances), 1):.1%}")
+
+ return targets
+
+
+if __name__ == "__main__":
+ from .swebench_loader import load_instances
+
+ instances = load_instances()
+ run(instances)
diff --git a/scripts/code_hallucination/swebench_loader.py b/scripts/code_hallucination/swebench_loader.py
new file mode 100644
index 0000000..523fd19
--- /dev/null
+++ b/scripts/code_hallucination/swebench_loader.py
@@ -0,0 +1,93 @@
+"""Phase 1: Load SWE-bench instances from all splits."""
+
+import json
+
+from datasets import load_dataset
+
+from .config import DATA_DIR, INSTANCES_PATH, SWEBENCH_FULL, SWEBENCH_LITE
+
+
+def load_all_splits() -> list[dict]:
+ """Load all SWE-bench splits and tag each instance.
+
+ Returns list of dicts with fields:
+ instance_id, repo, base_commit, patch, test_patch,
+ problem_statement, hints_text, created_at, version,
+ FAIL_TO_PASS, PASS_TO_PASS, environment_setup_commit,
+ split, is_lite
+ """
+ # Load Lite instance IDs for tagging
+ print("Loading SWE-bench Lite...")
+ lite_ds = load_dataset(SWEBENCH_LITE, split="test")
+ lite_ids = {row["instance_id"] for row in lite_ds}
+ print(f" Lite: {len(lite_ids)} instances")
+
+ all_instances = []
+
+ for split_name in ["train", "dev", "test"]:
+ print(f"Loading SWE-bench {split_name} split...")
+ ds = load_dataset(SWEBENCH_FULL, split=split_name)
+ print(f" {split_name}: {len(ds)} instances")
+
+ for row in ds:
+ instance = dict(row)
+ instance["split"] = split_name
+ instance["is_lite"] = instance["instance_id"] in lite_ids
+ all_instances.append(instance)
+
+ # Report stats
+ repos_by_split = {}
+ for inst in all_instances:
+ split = inst["split"]
+ repo = inst["repo"]
+ if split not in repos_by_split:
+ repos_by_split[split] = set()
+ repos_by_split[split].add(repo)
+
+ print(f"\nTotal: {len(all_instances)} instances")
+ for split, repos in repos_by_split.items():
+ count = sum(1 for i in all_instances if i["split"] == split)
+ print(f" {split}: {count} instances across {len(repos)} repos")
+
+ n_lite = sum(1 for i in all_instances if i["is_lite"])
+ print(f" Lite-tagged: {n_lite}")
+
+ # Verify zero repo overlap
+ all_splits = list(repos_by_split.keys())
+ for i, s1 in enumerate(all_splits):
+ for s2 in all_splits[i + 1 :]:
+ overlap = repos_by_split[s1] & repos_by_split[s2]
+ if overlap:
+ print(f" WARNING: Repo overlap between {s1} and {s2}: {overlap}")
+ else:
+                print(f"  {s1} ∩ {s2}: 0 repo overlap ✓")
+
+ return all_instances
+
+
+def save_instances(instances: list[dict], path=INSTANCES_PATH):
+ """Save instances to JSON."""
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as f:
+ json.dump(instances, f, indent=2)
+ print(f"Saved {len(instances)} instances to {path}")
+
+
+def load_instances(path=INSTANCES_PATH) -> list[dict]:
+ """Load previously saved instances."""
+ with open(path) as f:
+ return json.load(f)
+
+
+def run():
+ """Run Phase 1: Load and save all SWE-bench instances."""
+ print("=" * 60)
+ print("Phase 1: Load SWE-bench")
+ print("=" * 60)
+ instances = load_all_splits()
+ save_instances(instances)
+ return instances
+
+
+if __name__ == "__main__":
+ run()
diff --git a/scripts/code_hallucination/validator.py b/scripts/code_hallucination/validator.py
new file mode 100644
index 0000000..50d5251
--- /dev/null
+++ b/scripts/code_hallucination/validator.py
@@ -0,0 +1,200 @@
+"""Phase 9: Quality checks and validation report."""
+
+import ast
+import json
+from collections import Counter
+
+from .config import DATASET_PATH, METADATA_PATH, VALIDATION_REPORT_PATH
+
+
+def validate_spans(samples: list[dict]) -> list[str]:
+ """Check span boundary validity."""
+ issues = []
+ for i, sample in enumerate(samples):
+ answer_len = len(sample["answer"])
+ for label in sample.get("labels", []):
+ start = label.get("start", 0)
+ end = label.get("end", 0)
+ if start < 0 or end < 0:
+ issues.append(f"Sample {i}: negative span offset ({start}, {end})")
+ if end <= start:
+ issues.append(f"Sample {i}: empty/inverted span ({start}, {end})")
+ if end > answer_len:
+ issues.append(f"Sample {i}: span exceeds answer length ({end} > {answer_len})")
+ return issues
+
+
+def check_span_coverage(samples: list[dict]) -> dict:
+ """Report span coverage distribution for hallucinated samples."""
+ coverages = []
+ for sample in samples:
+ if not sample.get("labels"):
+ continue
+ answer_len = len(sample["answer"])
+ if answer_len == 0:
+ continue
+ total_span = sum(label["end"] - label["start"] for label in sample["labels"])
+ coverage = total_span / answer_len
+ coverages.append(coverage)
+
+ if not coverages:
+ return {"count": 0}
+
+ return {
+ "count": len(coverages),
+ "min": min(coverages),
+ "max": max(coverages),
+ "mean": sum(coverages) / len(coverages),
+ "low_coverage": sum(1 for c in coverages if c < 0.02),
+ "high_coverage": sum(1 for c in coverages if c > 0.80),
+ }
+
+
+def check_distributions(metadata: list[dict]) -> dict:
+ """Report format/type/LLM/repo distributions."""
+ format_counts = Counter(m.get("format_type") for m in metadata if m.get("format_type"))
+ type_counts = Counter(
+ m.get("hallucination_type") for m in metadata if m.get("hallucination_type")
+ )
+ model_counts = Counter(m.get("injector_model") for m in metadata if m.get("injector_model"))
+ repo_counts = Counter(m.get("repo") for m in metadata)
+ split_counts = Counter(m.get("split") for m in metadata)
+
+ return {
+ "format": dict(format_counts),
+ "hallucination_type": dict(type_counts),
+ "injector_model": dict(model_counts),
+ "repos": len(repo_counts),
+ "top_repos": dict(repo_counts.most_common(10)),
+ "split": dict(split_counts),
+ }
+
+
+def check_near_duplicates(samples: list[dict], threshold: float = 0.95) -> int:
+ """Simple near-duplicate check via answer Jaccard similarity (sampled)."""
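+    # Illustrative: two answers with identical word sets have Jaccard 1.0 and count
+    # as near duplicates under the default 0.95 threshold.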
+ import random
+
+ rng = random.Random(42)
+
+ n = min(500, len(samples))
+ sample_indices = rng.sample(range(len(samples)), n)
+
+ duplicates = 0
+ for i in range(len(sample_indices)):
+ for j in range(i + 1, min(i + 5, len(sample_indices))):
+ a = set(samples[sample_indices[i]]["answer"].split())
+ b = set(samples[sample_indices[j]]["answer"].split())
+ if not a or not b:
+ continue
+ jaccard = len(a & b) / len(a | b)
+ if jaccard > threshold:
+ duplicates += 1
+
+ return duplicates
+
+
+def check_ast_parseability(samples: list[dict], metadata: list[dict]) -> dict:
+ """Check AST parseability for complete_function format samples."""
+ total = 0
+ parseable = 0
+
+ for sample, meta in zip(samples, metadata):
+ if meta.get("format_type") != "complete_function":
+ continue
+ total += 1
+ try:
+ ast.parse(sample["answer"])
+ parseable += 1
+ except SyntaxError:
+ pass
+
+ return {
+ "total": total,
+ "parseable": parseable,
+ "rate": parseable / max(total, 1),
+ }
+
+
+def run(samples: list[dict] | None = None, metadata: list[dict] | None = None):
+ """Run Phase 9: Validation."""
+ print("=" * 60)
+ print("Phase 9: Validation")
+ print("=" * 60)
+
+ if samples is None:
+ with open(DATASET_PATH) as f:
+ samples = json.load(f)
+ if metadata is None:
+ with open(METADATA_PATH) as f:
+ metadata = json.load(f)
+
+ report_lines = []
+
+ def report(text):
+ print(text)
+ report_lines.append(text)
+
+ report(f"Total samples: {len(samples)}")
+ n_clean = sum(1 for s in samples if not s["labels"])
+ n_hall = sum(1 for s in samples if s["labels"])
+ report(f"Clean: {n_clean}, Hallucinated: {n_hall}")
+ report("")
+
+ # 1. Span validity
+ report("=== Span Validity ===")
+ span_issues = validate_spans(samples)
+ report(f"Issues found: {len(span_issues)}")
+ for issue in span_issues[:10]:
+ report(f" {issue}")
+ report("")
+
+ # 2. Span coverage
+ report("=== Span Coverage ===")
+ coverage = check_span_coverage(samples)
+ for k, v in coverage.items():
+ report(f" {k}: {v}")
+ report("")
+
+ # 3. Distributions
+ report("=== Distributions ===")
+ dists = check_distributions(metadata)
+ for k, v in dists.items():
+ report(f" {k}: {v}")
+ report("")
+
+ # 4. Near duplicates
+ report("=== Near Duplicates ===")
+ n_dup = check_near_duplicates(samples)
+ report(f"Near duplicates (sampled): {n_dup}")
+ report("")
+
+ # 5. AST parseability
+ report("=== AST Parseability ===")
+ ast_check = check_ast_parseability(samples, metadata)
+ for k, v in ast_check.items():
+ report(f" {k}: {v}")
+ report("")
+
+ # 6. Length stats
+ report("=== Length Statistics ===")
+ prompt_lens = [len(s["prompt"]) for s in samples]
+ answer_lens = [len(s["answer"]) for s in samples]
+ if prompt_lens:
+ report(
+ f" Prompt chars - min: {min(prompt_lens):,}, max: {max(prompt_lens):,}, avg: {sum(prompt_lens) // len(prompt_lens):,}"
+ )
+ report(
+ f" Answer chars - min: {min(answer_lens):,}, max: {max(answer_lens):,}, avg: {sum(answer_lens) // len(answer_lens):,}"
+ )
+
+ # Save report
+ VALIDATION_REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
+ with open(VALIDATION_REPORT_PATH, "w") as f:
+ f.write("\n".join(report_lines))
+
+ print(f"\nReport saved to {VALIDATION_REPORT_PATH}")
+ return report_lines
+
+
+if __name__ == "__main__":
+ run()
diff --git a/scripts/enrich_code_hallucination_dataset.py b/scripts/enrich_code_hallucination_dataset.py
new file mode 100644
index 0000000..0ceb5a2
--- /dev/null
+++ b/scripts/enrich_code_hallucination_dataset.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+"""Enrich code hallucination dataset with actual source file contents from GitHub.
+
+Takes the raw dataset and:
+1. Fetches actual source files from GitHub at the base commit
+2. Builds proper context (source files + docs + query)
+3. Converts diff-format answers to actual code
+4. Outputs in exact LettuceDetect HallucinationSample format
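+
+Illustrative output record (fields match the samples built below; values are made up):
+    {"prompt": "<source files + docs + user request>", "answer": "<extracted code>",
+     "labels": [{"start": 12, "end": 40, "label": "structural"}],
+     "split": "train", "task_type": "code_generation",
+     "dataset": "code_hallucination_swebench", "language": "en"}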
+"""
+
+import json
+import os
+import time
+
+import requests
+
+INPUT_PATH = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json"
+OUTPUT_PATH = (
+ "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect_v2.json"
+)
+GITHUB_RAW_BASE = "https://raw.githubusercontent.com"
+MAX_FILE_CHARS = 12000 # Cap individual file size to avoid blowing up context
+
+
+def fetch_file_from_github(repo: str, commit: str, filepath: str) -> str | None:
+ """Fetch a file's contents from GitHub at a specific commit."""
+ url = f"{GITHUB_RAW_BASE}/{repo}/{commit}/{filepath}"
+ try:
+ r = requests.get(url, timeout=15)
+ if r.status_code == 200:
+ return r.text[:MAX_FILE_CHARS]
+ return None
+ except Exception as e:
+ print(f" Error fetching {filepath}: {e}")
+ return None
+
+
+def extract_code_from_patch(patch: str) -> str:
+ """Extract just the added lines from a unified diff as the 'answer' code.
+
+ Returns the new code (added lines) that represents what was generated.
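+
+    Illustrative example: a hunk such as
+        @@ -1,3 +1,3 @@
+         def f():
+        -    return 1
+        +    return 2
+    yields the context line "def f():" followed by the added line "    return 2";
+    removed lines and diff headers are dropped.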
+ """
+ added_lines = []
+ in_hunk = False
+
+ for line in patch.split("\n"):
+ if line.startswith("@@"):
+ in_hunk = True
+ continue
+ if line.startswith("diff --git") or line.startswith("---") or line.startswith("+++"):
+ continue
+ if in_hunk:
+ if line.startswith("+"):
+ added_lines.append(line[1:]) # Remove the '+' prefix
+ elif line.startswith("-"):
+ continue # Skip removed lines
+ else:
+ # Context line - include to maintain readability
+ if line.startswith(" "):
+ added_lines.append(line[1:])
+ else:
+ added_lines.append(line)
+
+ return "\n".join(added_lines)
+
+
+def build_prompt(
+ source_files: dict[str, str],
+ documentation: dict[str, str],
+ user_query: str,
+) -> str:
+ """Build the prompt (context) in the same style as RAGTruth.
+
+ Format: all context concatenated, similar to how RAGTruth provides
+ the source prompt that was given to the LLM.
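+
+    Illustrative layout (abridged):
+        File: pkg/module.py          (file body follows in a ```python fence)
+        Documentation for somelib: ...
+        User request: Can you fix ...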
+ """
+ parts = []
+
+ # Source code files (the main context)
+ for filepath, content in source_files.items():
+ parts.append(f"File: {filepath}\n```python\n{content}\n```")
+
+ # Documentation
+ for lib, doc in documentation.items():
+ parts.append(f"Documentation for {lib}:\n{doc}")
+
+ # User query (acts as the "question" / instruction)
+ parts.append(f"User request: {user_query}")
+
+ return "\n\n".join(parts)
+
+
+def compute_span_offsets_in_code(code: str, hallucinated_span: str) -> list[dict]:
+ """Find character offsets of a hallucinated span within the answer code."""
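+    # Illustrative: compute_span_offsets_in_code("x = foo(1)", "foo") -> [{"start": 4, "end": 7}]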
+ spans = []
+ start = 0
+ while True:
+ idx = code.find(hallucinated_span, start)
+ if idx == -1:
+ break
+ spans.append({"start": idx, "end": idx + len(hallucinated_span)})
+ start = idx + 1
+ break # Take first occurrence only
+ return spans
+
+
+def process_sample(sample: dict, idx: int, total: int) -> list[dict]:
+ """Process one raw sample into LettuceDetect format samples."""
+ instance_id = sample["instance_id"]
+ repo = sample["repo"]
+ commit = sample["base_commit"]
+ changed_files = sample["changed_files"]
+
+ print(f"[{idx + 1}/{total}] {instance_id}")
+
+ # Step 1: Fetch actual source files from GitHub
+ source_files = {}
+ for filepath in changed_files:
+        # Some SWE-bench paths carry diff-parsing artifacts, e.g.
+        # "models/fields/__init__.py b/django/db/models/fields/__init__.py".
+        # Keep the valid-looking paths and drop any 'a/'-prefixed ones.
+ clean_paths = [
+ p.strip() for p in filepath.split(" ") if "/" in p and not p.startswith("a/")
+ ]
+ if not clean_paths:
+ clean_paths = [filepath]
+
+ for path in clean_paths:
+            path = path.removeprefix("b/")
+ content = fetch_file_from_github(repo, commit, path)
+ if content:
+ source_files[path] = content
+ print(f" Fetched {path}: {len(content)} chars")
+ else:
+ print(f" Failed to fetch {path}")
+ time.sleep(0.3) # Rate limit
+
+ if not source_files:
+ print(" SKIP: No source files fetched")
+ return []
+
+ # Step 2: Build the prompt (context)
+ documentation = sample.get("documentation", {})
+ user_query = sample.get("user_query", "")
+ prompt = build_prompt(source_files, documentation, user_query)
+
+ # Step 3: Create correct (negative) sample
+ gold_code = extract_code_from_patch(sample["gold_patch"])
+ samples_out = []
+
+ if gold_code.strip():
+ samples_out.append(
+ {
+ "prompt": prompt,
+ "answer": gold_code,
+ "labels": [],
+ "split": "train",
+ "task_type": "code_generation",
+ "dataset": "code_hallucination_swebench",
+ "language": "en",
+ }
+ )
+
+ # Step 4: Create hallucinated (positive) samples
+ for hall in sample.get("hallucinations", []):
+ if not isinstance(hall, dict) or "hallucinated_patch" not in hall:
+ continue
+
+ hall_code = extract_code_from_patch(hall["hallucinated_patch"])
+ if not hall_code.strip():
+ continue
+
+ # Compute span labels in the answer code
+ labels = []
+ for change in hall.get("changes", []):
+ h_span = change.get("hallucinated", "")
+ if h_span and h_span in hall_code:
+ offsets = compute_span_offsets_in_code(hall_code, h_span)
+ for offset in offsets:
+ labels.append(
+ {
+ "start": offset["start"],
+ "end": offset["end"],
+ "label": hall.get("type", "hallucinated"),
+ }
+ )
+
+ if labels:
+ samples_out.append(
+ {
+ "prompt": prompt,
+ "answer": hall_code,
+ "labels": labels,
+ "split": "train",
+ "task_type": "code_generation",
+ "dataset": "code_hallucination_swebench",
+ "language": "en",
+ }
+ )
+
+    n_hall_out = sum(1 for s in samples_out if s["labels"])
+    print(
+        f"  Generated {len(samples_out)} samples "
+        f"({len(samples_out) - n_hall_out} correct + {n_hall_out} hallucinated)"
+    )
+ return samples_out
+
+
+def main():
+ print("=" * 60)
+ print("Enriching Code Hallucination Dataset")
+ print("Fetching source files + converting to LettuceDetect format")
+ print("=" * 60)
+
+ with open(INPUT_PATH) as f:
+ raw_data = json.load(f)
+
+ print(f"Loaded {len(raw_data)} raw samples\n")
+
+ all_samples = []
+ for i, sample in enumerate(raw_data):
+ new_samples = process_sample(sample, i, len(raw_data))
+ all_samples.extend(new_samples)
+
+ # Save
+ os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
+ with open(OUTPUT_PATH, "w") as f:
+ json.dump(all_samples, f, indent=2)
+
+ # Stats
+ print("\n" + "=" * 60)
+ print("FINAL DATASET")
+ print("=" * 60)
+ print(f"Total samples: {len(all_samples)}")
+ n_correct = sum(1 for s in all_samples if not s["labels"])
+ n_hall = sum(1 for s in all_samples if s["labels"])
+ print(f"Correct (negative): {n_correct}")
+ print(f"Hallucinated (positive): {n_hall}")
+
+ prompt_lens = [len(s["prompt"]) for s in all_samples]
+ answer_lens = [len(s["answer"]) for s in all_samples]
+ total_lens = [p + a for p, a in zip(prompt_lens, answer_lens)]
+
+ print(
+ f"\nPrompt chars - min: {min(prompt_lens):,}, max: {max(prompt_lens):,}, avg: {sum(prompt_lens) // len(prompt_lens):,}"
+ )
+ print(
+ f"Answer chars - min: {min(answer_lens):,}, max: {max(answer_lens):,}, avg: {sum(answer_lens) // len(answer_lens):,}"
+ )
+ print(
+ f"Total chars - min: {min(total_lens):,}, max: {max(total_lens):,}, avg: {sum(total_lens) // len(total_lens):,}"
+ )
+
+ est_tokens = [t // 4 for t in total_lens]
+ print(
+ f"\nEst. tokens - min: {min(est_tokens):,}, max: {max(est_tokens):,}, avg: {sum(est_tokens) // len(est_tokens):,}"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/evaluate_code_hallucination.py b/scripts/evaluate_code_hallucination.py
new file mode 100644
index 0000000..87a68fd
--- /dev/null
+++ b/scripts/evaluate_code_hallucination.py
@@ -0,0 +1,296 @@
+#!/usr/bin/env python3
+"""Evaluate LLM baseline on the code hallucination dataset.
+
+Uses LettuceDetect's existing LLMDetector and evaluator infrastructure.
+Supports Groq API with any OpenAI-compatible model.
+
+Usage:
+ # With Groq + Kimi
+ OPENAI_API_KEY=gsk_... OPENAI_API_BASE=https://api.groq.com/openai/v1 \
+ python scripts/evaluate_code_hallucination.py \
+ --model moonshotai/kimi-k2-instruct-0905 \
+ --data_path data/code_hallucination_lettucedetect_v2.json \
+ --evaluation_type example_level
+"""
+
+import argparse
+import json
+import os
+import re
+import time
+from pathlib import Path
+from string import Template
+
+from openai import OpenAI
+from sklearn.metrics import auc, classification_report, precision_recall_fscore_support, roc_curve
+from tqdm import tqdm
+
+from lettucedetect.datasets.hallucination_dataset import HallucinationSample
+
+# Simpler prompt for code hallucination detection (no structured output dependency)
+CODE_HALLUCINATION_PROMPT = Template("""
+You are an expert code reviewer who must identify hallucinated code spans.
+
+A "code hallucination" is any part of the generated code that:
+(a) Uses APIs, methods, functions, classes, or parameters that do NOT exist in the codebase context provided
+(b) Contradicts the documentation or codebase shown in the source
+(c) Does something semantically different from what the user requested
+
+## Instructions
+1. Read the generated code inside <answer> ... </answer>.
+2. Compare it against the source context in <source> ... </source>.
+3. Identify any code spans that are hallucinated.
+4. Return a JSON object with this exact format (no markdown, no code blocks):
+ {"hallucination_list": ["exact_span_1", "exact_span_2", ...]}
+ If no hallucinations found, return: {"hallucination_list": []}
+
+IMPORTANT: Each item in hallucination_list must be an EXACT substring from the answer.
+
+<source>
+${context}
+</source>
+
+<answer>
+${answer}
+</answer>
+""")
+
+
+def predict_with_llm(
+ client: OpenAI,
+ model: str,
+ prompt: str,
+ answer: str,
+ temperature: float = 0.0,
+) -> list[dict]:
+ """Predict hallucination spans using an LLM."""
+ llm_prompt = CODE_HALLUCINATION_PROMPT.substitute(context=prompt, answer=answer)
+
+ try:
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert in detecting hallucinations in LLM-generated code.",
+ },
+ {"role": "user", "content": llm_prompt},
+ ],
+ temperature=temperature,
+ max_tokens=1000,
+ )
+ raw = resp.choices[0].message.content.strip()
+
+ # Parse JSON from response
+ json_match = re.search(r"\{[\s\S]*\}", raw)
+ if json_match:
+ payload = json.loads(json_match.group())
+ spans = []
+ for sub in payload.get("hallucination_list", []):
+ if not sub:
+ continue
+ match = re.search(re.escape(sub), answer)
+ if match:
+ spans.append({"start": match.start(), "end": match.end(), "text": sub})
+ return spans
+ return []
+ except Exception as e:
+ print(f" Error: {e}")
+ return []
+
+
+def evaluate_example_level(
+ samples: list[HallucinationSample],
+ client: OpenAI,
+ model: str,
+ temperature: float = 0.0,
+) -> dict:
+    """Evaluate at example level (binary): does this sample contain hallucinations?"""
+ example_preds = []
+ example_labels = []
+
+ for sample in tqdm(samples, desc="Evaluating (example level)"):
+ predicted_spans = predict_with_llm(client, model, sample.prompt, sample.answer, temperature)
+ true_label = 1 if sample.labels else 0
+ pred_label = 1 if predicted_spans else 0
+ example_labels.append(true_label)
+ example_preds.append(pred_label)
+ time.sleep(0.5) # Rate limit
+
+ precision, recall, f1, _ = precision_recall_fscore_support(
+ example_labels, example_preds, labels=[0, 1], average=None, zero_division=0
+ )
+
+ fpr, tpr, _ = roc_curve(example_labels, example_preds)
+ auroc = auc(fpr, tpr)
+
+ results = {
+ "supported": {
+ "precision": float(precision[0]),
+ "recall": float(recall[0]),
+ "f1": float(f1[0]),
+ },
+ "hallucinated": {
+ "precision": float(precision[1]),
+ "recall": float(recall[1]),
+ "f1": float(f1[1]),
+ },
+ "auroc": auroc,
+ }
+
+ report = classification_report(
+ example_labels,
+ example_preds,
+ target_names=["Supported", "Hallucinated"],
+ digits=4,
+ zero_division=0,
+ )
+ print("\nExample-Level Classification Report:")
+ print(report)
+ print(f"AUROC: {auroc:.4f}")
+
+ return results
+
+
+def evaluate_char_level(
+ samples: list[HallucinationSample],
+ client: OpenAI,
+ model: str,
+ temperature: float = 0.0,
+) -> dict:
+    """Evaluate at character level: overlap between predicted and gold spans."""
+ total_overlap = 0
+ total_predicted = 0
+ total_gold = 0
+
+ for sample in tqdm(samples, desc="Evaluating (char level)"):
+ predicted_spans = predict_with_llm(client, model, sample.prompt, sample.answer, temperature)
+ gold_spans = sample.labels
+
+ total_predicted += sum(p["end"] - p["start"] for p in predicted_spans)
+ total_gold += sum(g["end"] - g["start"] for g in gold_spans)
+
+ for pred in predicted_spans:
+ for gold in gold_spans:
+ overlap_start = max(pred["start"], gold["start"])
+ overlap_end = min(pred["end"], gold["end"])
+ if overlap_end > overlap_start:
+ total_overlap += overlap_end - overlap_start
+
+ time.sleep(0.5) # Rate limit
+
+ precision = total_overlap / total_predicted if total_predicted > 0 else 0
+ recall = total_overlap / total_gold if total_gold > 0 else 0
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
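+    # Illustrative: one predicted span (0, 10) against one gold span (5, 15) gives
+    # overlap 5, so precision = recall = 0.5 and F1 = 0.5.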
+
+ print("\nCharacter-Level Results:")
+ print(f" Precision: {precision:.4f}")
+ print(f" Recall: {recall:.4f}")
+ print(f" F1: {f1:.4f}")
+
+ return {"precision": precision, "recall": recall, "f1": f1}
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Evaluate LLM baseline on code hallucination dataset"
+ )
+ parser.add_argument("--model", type=str, default="moonshotai/kimi-k2-instruct-0905")
+ parser.add_argument("--data_path", type=str, required=True)
+ parser.add_argument(
+ "--evaluation_type",
+ type=str,
+ default="example_level",
+ choices=["example_level", "char_level", "both"],
+ )
+ parser.add_argument(
+ "--max_samples",
+ type=int,
+ default=None,
+ help="Limit number of test samples (for quick testing)",
+ )
+ parser.add_argument(
+ "--test_ratio", type=float, default=0.3, help="Fraction of data to use as test set"
+ )
+ parser.add_argument("--seed", type=int, default=42)
+
+ args = parser.parse_args()
+
+ # Load data
+ data_path = Path(args.data_path)
+ raw_data = json.loads(data_path.read_text())
+
+ # Convert to HallucinationSample objects
+ samples = []
+ for item in raw_data:
+ # Handle the dataset literal type - use "ragtruth" as fallback
+ dataset = item.get("dataset", "ragtruth")
+ if dataset not in ("ragtruth", "ragbench"):
+ dataset = "ragtruth" # Fallback for compatibility
+ language = item.get("language", "en")
+ if language not in ("en", "de"):
+ language = "en"
+
+ samples.append(
+ HallucinationSample(
+ prompt=item["prompt"],
+ answer=item["answer"],
+ labels=item["labels"],
+ split=item.get("split", "test"),
+ task_type=item.get("task_type", "code_generation"),
+ dataset=dataset,
+ language=language,
+ )
+ )
+
+ # Split into test set
+ import random
+
+ random.seed(args.seed)
+ random.shuffle(samples)
+
+ test_size = int(len(samples) * args.test_ratio)
+ test_samples = samples[:test_size]
+
+ if args.max_samples:
+ test_samples = test_samples[: args.max_samples]
+
+ n_positive = sum(1 for s in test_samples if s.labels)
+ n_negative = sum(1 for s in test_samples if not s.labels)
+
+ print(f"Dataset: {data_path}")
+ print(f"Total samples: {len(samples)}")
+ print(f"Test samples: {len(test_samples)} (positive: {n_positive}, negative: {n_negative})")
+ print(f"Model: {args.model}")
+ print(f"API base: {os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')}")
+
+ # Setup client
+ client = OpenAI(
+ api_key=os.getenv("OPENAI_API_KEY"),
+ base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ )
+
+ # Run evaluation
+ results = {}
+
+ if args.evaluation_type in ("example_level", "both"):
+ print("\n" + "=" * 60)
+ print("EXAMPLE-LEVEL EVALUATION")
+ print("=" * 60)
+ results["example_level"] = evaluate_example_level(test_samples, client, args.model)
+
+ if args.evaluation_type in ("char_level", "both"):
+ print("\n" + "=" * 60)
+ print("CHARACTER-LEVEL EVALUATION")
+ print("=" * 60)
+ results["char_level"] = evaluate_char_level(test_samples, client, args.model)
+
+ # Save results
+ output_path = data_path.parent / f"eval_results_{args.model.replace('/', '_')}.json"
+ with open(output_path, "w") as f:
+ json.dump(results, f, indent=2)
+ print(f"\nResults saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/filter_data.py b/scripts/filter_data.py
new file mode 100644
index 0000000..94bedd8
--- /dev/null
+++ b/scripts/filter_data.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+from pathlib import Path
+
+from lettucedetect.datasets.hallucination_dataset import HallucinationData
+
+# Set up logging
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("filter_data")
+
+
+def filter_dataset(
+ input_file: Path, output_file: Path, split: str = "test", task_type: str = "Data2txt"
+) -> None:
+ """Filter dataset to include only samples with specific split and task type.
+
+ :param input_file: Path to the input JSON file
+ :param output_file: Path to save the filtered JSON file
+ :param split: The split to filter for (default: 'test')
+ :param task_type: The task type to filter for (default: 'Data2txt')
+ """
+ if not input_file.exists():
+ raise FileNotFoundError(f"Input file not found: {input_file}")
+
+ # Create output directory if it doesn't exist
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ # Load data
+ try:
+ with open(input_file, "r") as f:
+ data = json.load(f)
+
+ dataset = HallucinationData.from_json(data)
+ logger.info(f"Loaded {len(dataset.samples)} samples from {input_file}")
+ except Exception as e:
+ logger.error(f"Error loading input data: {e}")
+ raise
+
+ # Filter samples
+ filtered_samples = [
+ sample
+ for sample in dataset.samples
+ if sample.split == split and sample.task_type == task_type
+ ]
+
+ # Create filtered dataset
+ filtered_dataset = HallucinationData(samples=filtered_samples)
+
+ # Save filtered data
+ with open(output_file, "w") as f:
+ json.dump(filtered_dataset.to_json(), f, indent=2)
+
+ logger.info(f"Filtered dataset from {len(dataset.samples)} to {len(filtered_samples)} samples")
+ logger.info(f"Saved filtered dataset to {output_file}")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Filter hallucination dataset to specific split and task type"
+ )
+ parser.add_argument("--input", type=str, required=True, help="Path to input JSON file")
+ parser.add_argument("--output", type=str, required=True, help="Path to save filtered JSON file")
+ parser.add_argument(
+ "--split", type=str, default="test", help="Split to filter for (default: 'test')"
+ )
+ parser.add_argument(
+ "--task-type",
+ type=str,
+ default="Data2txt",
+ help="Task type to filter for (default: 'Data2txt')",
+ )
+
+ args = parser.parse_args()
+
+ filter_dataset(Path(args.input), Path(args.output), args.split, args.task_type)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/generate_code_hallucination_dataset.py b/scripts/generate_code_hallucination_dataset.py
new file mode 100644
index 0000000..46b7634
--- /dev/null
+++ b/scripts/generate_code_hallucination_dataset.py
@@ -0,0 +1,597 @@
+#!/usr/bin/env python3
+"""Generate a code hallucination detection dataset from SWE-bench + Context7.
+
+Pipeline:
+1. Load SWE-bench Lite samples
+2. Extract context files and imports from patches
+3. Fetch relevant documentation via Context7
+4. Transform issues into user-style queries (via LLM)
+5. Inject hallucinations (structural + behavioral + semantic)
+6. Convert to LettuceDetect training format
+"""
+
+import json
+import os
+import re
+import textwrap
+import time
+from typing import Any
+
+import requests
+from openai import OpenAI
+
+# === Config ===
+GROQ_BASE_URL = "https://api.groq.com/openai/v1"
+MODEL = "moonshotai/kimi-k2-instruct-0905"
+CONTEXT7_BASE = "https://context7.com/api/v2"
+OUTPUT_PATH = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_dataset.json"
+LETTUCEDETECT_OUTPUT_PATH = (
+ "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_lettucedetect.json"
+)
+NUM_SAMPLES = 50
+RETRY_DELAY = 2
+MAX_RETRIES = 3
+
+
+def get_client() -> OpenAI:
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError("Set OPENAI_API_KEY to your Groq API key")
+ return OpenAI(api_key=api_key, base_url=GROQ_BASE_URL)
+
+
+# === Patch Parsing ===
+
+
+def extract_changed_files(patch: str) -> list[str]:
+ """Extract file paths changed in a unified diff."""
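+    # Illustrative: "diff --git a/django/db/models/query.py b/django/db/models/query.py"
+    # contributes "django/db/models/query.py".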
+ files = []
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ match = re.search(r"b/(.+)$", line)
+ if match:
+ files.append(match.group(1))
+ return files
+
+
+def extract_imports_from_patch(patch: str) -> list[str]:
+ """Extract Python import statements from added lines in a patch."""
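+    # Illustrative: an added line "+from django.db import models" contributes "django".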
+ imports = set()
+ for line in patch.split("\n"):
+ if line.startswith("+") and not line.startswith("+++"):
+ clean = line[1:].strip()
+ # Match import statements
+ if clean.startswith("import ") or clean.startswith("from "):
+ # Extract the top-level module
+ match = re.match(r"(?:from|import)\s+([\w.]+)", clean)
+ if match:
+ module = match.group(1).split(".")[0]
+ # Filter out local/relative imports and stdlib
+                    # Skip empty names (from relative imports) and private modules
+ imports.add(module)
+ return list(imports)
+
+
+def extract_libraries_from_files(changed_files: list[str]) -> list[str]:
+ """Infer which external libraries might be relevant from file paths."""
+ # Map repo paths to likely library names
+ path_to_lib = {
+ "django": "django",
+ "astropy": "astropy",
+ "sympy": "sympy",
+ "sklearn": "scikit-learn",
+ "matplotlib": "matplotlib",
+ "requests": "requests",
+ "flask": "flask",
+ "pytest": "pytest",
+ "sphinx": "sphinx",
+ "xarray": "xarray",
+ "seaborn": "seaborn",
+ "pylint": "pylint",
+ }
+ libs = set()
+ for f in changed_files:
+ for key, lib in path_to_lib.items():
+ if key in f:
+ libs.add(lib)
+ return list(libs)
+
+
+# === Context7 Documentation ===
+
+
+def fetch_context7_docs(library_name: str, query: str, max_chars: int = 2000) -> str | None:
+ """Fetch documentation from Context7 for a library + query."""
+ try:
+ # Step 1: Resolve library ID
+ r = requests.get(
+ f"{CONTEXT7_BASE}/libs/search",
+ params={"query": query, "libraryName": library_name},
+ timeout=10,
+ )
+ if r.status_code != 200:
+ return None
+ results = r.json().get("results", [])
+ if not results:
+ return None
+
+ lib_id = results[0]["id"]
+
+ # Step 2: Get relevant docs
+ r2 = requests.get(
+ f"{CONTEXT7_BASE}/context",
+ params={"libraryId": lib_id, "query": query, "type": "txt"},
+ timeout=10,
+ )
+ if r2.status_code != 200:
+ return None
+
+ doc_text = r2.text[:max_chars]
+ return doc_text if doc_text.strip() else None
+
+ except Exception as e:
+ print(f" Context7 error for {library_name}: {e}")
+ return None
+
+
+def get_documentation_context(
+ changed_files: list[str], patch: str, problem_statement: str
+) -> dict[str, str]:
+ """Fetch documentation for libraries referenced in the patch."""
+ docs = {}
+
+ # Get libraries from imports in the patch
+ imported_libs = extract_imports_from_patch(patch)
+ # Get libraries from file paths
+ path_libs = extract_libraries_from_files(changed_files)
+
+ all_libs = list(set(imported_libs + path_libs))
+
+ # Extract a short query from the problem statement
+ short_query = problem_statement[:200].replace("\n", " ").strip()
+
+ for lib in all_libs[:3]: # Limit to 3 libraries to avoid rate limits
+ doc = fetch_context7_docs(lib, short_query)
+ if doc:
+ docs[lib] = doc
+
+ return docs
+
+
+# === LLM Calls ===
+
+
+def llm_call(
+ client: OpenAI, system: str, user: str, temperature: float = 0.7, max_tokens: int = 500
+) -> str:
+ """Make an LLM call with retries."""
+ for attempt in range(MAX_RETRIES):
+ try:
+ response = client.chat.completions.create(
+ model=MODEL,
+ messages=[
+ {"role": "system", "content": system},
+ {"role": "user", "content": user},
+ ],
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ return response.choices[0].message.content.strip()
+ except Exception as e:
+ if attempt < MAX_RETRIES - 1:
+ wait = RETRY_DELAY * (attempt + 1)
+ print(f" LLM error (attempt {attempt + 1}): {e}. Retrying in {wait}s...")
+ time.sleep(wait)
+ else:
+ raise
+
+
+def transform_to_user_query(client: OpenAI, problem_statement: str, repo: str) -> str:
+ """Transform a GitHub issue into a realistic user-to-agent query."""
+ system = textwrap.dedent("""\
+ You transform GitHub issue descriptions into realistic user queries
+ that a developer would type into an AI coding assistant (like Claude Code or Cursor).
+
+ Rules:
+ - Make it conversational and natural
+ - Keep the core technical ask but remove GitHub formatting
+ - Remove reproduction steps, stack traces, verbose details
+ - Keep it to 1-3 sentences
+ - Don't mention "issue" or "bug report"
+ - Sound like someone asking for help, not filing a report
+ """)
+ user = f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}"
+ return llm_call(client, system, user, temperature=0.7, max_tokens=300)
+
+
+def inject_hallucinations(
+ client: OpenAI,
+ gold_patch: str,
+ user_query: str,
+ problem_statement: str,
+ repo: str,
+ documentation: dict[str, str],
+) -> dict[str, Any]:
+ """Inject hallucinations into a gold patch."""
+ doc_context = ""
+ if documentation:
+ doc_context = "\n\nRelevant library documentation:\n"
+ for lib, doc in documentation.items():
+ doc_context += f"\n--- {lib} docs ---\n{doc[:800]}\n"
+
+ system = textwrap.dedent("""\
+ You are a code hallucination injector for building a hallucination detection dataset.
+
+ Given a correct code patch, user query, and documentation context, create THREE
+ hallucinated versions of the patch:
+
+ 1. STRUCTURAL: Change a function call, import, or parameter to something that
+ doesn't exist or is wrong. Code should still parse but reference non-existent
+ APIs, wrong methods, or invented parameters. If documentation is provided,
+ you can hallucinate by using API calls that contradict the docs.
+
+ 2. BEHAVIORAL: Use correct APIs but with wrong values or logic. Wrong defaults,
+ off-by-one errors, swapped conditions, wrong argument values.
+
+ 3. SEMANTIC: Code that looks like it addresses the user's request but does
+ something subtly different or opposite. This should be the most subtle -
+ the code parses, uses real APIs, but fails to do what was asked.
+ Examples: implementing global when per-item was asked, any() vs all(),
+ catching exceptions instead of fixing root cause, inverted conditions.
+
+ Respond in this exact JSON format (no markdown, no code blocks):
+ {
+ "hallucinations": [
+ {
+ "type": "structural",
+ "hallucinated_patch": "the full modified patch text",
+ "changes": [
+ {
+ "original": "exact original code span",
+ "hallucinated": "what you changed it to",
+ "explanation": "why this is a hallucination"
+ }
+ ]
+ },
+ {
+ "type": "behavioral",
+ "hallucinated_patch": "...",
+ "changes": [...]
+ },
+ {
+ "type": "semantic",
+ "hallucinated_patch": "...",
+ "changes": [...]
+ }
+ ]
+ }
+
+ IMPORTANT:
+ - Hallucinations must be PLAUSIBLE - something an LLM would realistically generate
+ - Each change must be subtle, not obviously broken
+ - Return ONLY valid JSON
+ """)
+
+ user = f"""Repository: {repo}
+
+User's query: {user_query}
+
+Original issue context: {problem_statement[:2000]}{doc_context}
+
+Correct gold patch:
+{gold_patch}
+
+Generate three hallucinated versions of this patch."""
+
+ raw = llm_call(client, system, user, temperature=0.8, max_tokens=4000)
+
+ # Parse JSON from response
+ json_match = re.search(r"\{[\s\S]*\}", raw)
+ if json_match:
+ try:
+ return json.loads(json_match.group())
+ except json.JSONDecodeError:
+ return {"raw_response": raw, "parse_error": True}
+ return {"raw_response": raw, "parse_error": True}
+
+
+# === LettuceDetect Format Conversion ===
+
+
+def compute_span_offsets(code: str, hallucinated_span: str) -> list[dict]:
+ """Find the character offsets of a hallucinated span within code."""
+ spans = []
+ start = 0
+ while True:
+ idx = code.find(hallucinated_span, start)
+ if idx == -1:
+ break
+ spans.append({"start": idx, "end": idx + len(hallucinated_span)})
+ start = idx + 1
+ return spans
+
+
+def convert_to_lettucedetect_format(samples: list[dict]) -> list[dict]:
+ """Convert enriched SWE-bench samples to LettuceDetect training format.
+
+ For each hallucinated variant, create a training sample with:
+ - prompt: context (codebase files + docs + user query)
+ - answer: the hallucinated code
+ - labels: span annotations of hallucinated parts
+ """
+ ld_samples = []
+
+ for sample in samples:
+ # Build context prompt from available information
+ context_parts = []
+
+ # Add user query as the "question"
+ user_query = sample.get("user_query", "")
+
+ # Add documentation context
+ docs = sample.get("documentation", {})
+ if docs:
+ for lib, doc in docs.items():
+ context_parts.append(f"Documentation for {lib}:\n{doc}")
+
+ # Add the gold patch as reference context (what the correct code looks like)
+ context_parts.append(f"Changed files: {', '.join(sample.get('changed_files', []))}")
+
+ # Add original problem statement as additional context
+ problem = sample.get("original_problem_statement", "")
+ if problem:
+ context_parts.append(f"Issue description:\n{problem[:1500]}")
+
+ context_text = "\n\n".join(context_parts)
+
+ # Create a "correct" sample (gold patch, no hallucination labels)
+ gold_patch = sample.get("gold_patch", "")
+ ld_samples.append(
+ {
+ "prompt": context_text,
+ "answer": gold_patch,
+ "labels": [],
+ "split": "train",
+ "task_type": "code_generation",
+ "dataset": "code_hallucination_swebench",
+ "language": "python",
+ "instance_id": sample["instance_id"],
+ "repo": sample["repo"],
+ "user_query": user_query,
+ }
+ )
+
+ # Create hallucinated samples
+ for hall in sample.get("hallucinations", []):
+ if isinstance(hall, dict) and "hallucinated_patch" in hall:
+ h_patch = hall["hallucinated_patch"]
+ labels = []
+
+ for change in hall.get("changes", []):
+ h_span = change.get("hallucinated", "")
+ if h_span and h_span in h_patch:
+ offsets = compute_span_offsets(h_patch, h_span)
+ for offset in offsets[:1]: # Take first occurrence
+ labels.append(
+ {
+ "start": offset["start"],
+ "end": offset["end"],
+ "label": "hallucinated",
+ "hallucination_type": hall.get("type", "unknown"),
+ "explanation": change.get("explanation", ""),
+ "original_span": change.get("original", ""),
+ }
+ )
+
+ if labels: # Only add if we found valid span offsets
+ ld_samples.append(
+ {
+ "prompt": context_text,
+ "answer": h_patch,
+ "labels": labels,
+ "split": "train",
+ "task_type": "code_generation",
+ "dataset": "code_hallucination_swebench",
+ "language": "python",
+ "instance_id": sample["instance_id"],
+ "repo": sample["repo"],
+ "user_query": user_query,
+ "hallucination_type": hall.get("type", "unknown"),
+ }
+ )
+
+ return ld_samples
+
+
+# === Main Pipeline ===
+
+
+def process_sample(client: OpenAI, sample: dict, idx: int, total: int) -> dict | None:
+ """Process a single SWE-bench sample."""
+ instance_id = sample["instance_id"]
+ repo = sample["repo"]
+ problem_statement = sample["problem_statement"]
+ gold_patch = sample["patch"]
+
+ print(f"\n[{idx + 1}/{total}] {instance_id} ({repo})")
+
+ # Step 1: Extract file and import info
+ changed_files = extract_changed_files(gold_patch)
+ print(f" Files: {changed_files}")
+
+ # Step 2: Fetch documentation via Context7
+ print(" Fetching documentation from Context7...")
+ documentation = get_documentation_context(changed_files, gold_patch, problem_statement)
+ if documentation:
+ print(f" Got docs for: {list(documentation.keys())}")
+ else:
+ print(" No external docs found (repo-internal change)")
+
+ # Step 3: Transform issue to user query
+ print(" Generating user query...")
+ try:
+ user_query = transform_to_user_query(client, problem_statement, repo)
+ print(f" Query: {user_query[:120]}...")
+ except Exception as e:
+ print(f" ERROR generating query: {e}")
+ return None
+
+ # Small delay to avoid rate limits
+ time.sleep(1)
+
+ # Step 4: Inject hallucinations
+ print(" Injecting hallucinations...")
+ try:
+ hall_result = inject_hallucinations(
+ client, gold_patch, user_query, problem_statement, repo, documentation
+ )
+ except Exception as e:
+ print(f" ERROR injecting hallucinations: {e}")
+ return None
+
+ if hall_result.get("parse_error"):
+ print(" WARNING: Failed to parse hallucination JSON")
+ return None
+
+ hallucinations = hall_result.get("hallucinations", [])
+ print(f" Generated {len(hallucinations)} hallucination variants")
+ for h in hallucinations:
+ h_type = h.get("type", "?")
+ n_changes = len(h.get("changes", []))
+ print(f" - {h_type}: {n_changes} changes")
+
+ # Small delay between samples
+ time.sleep(1)
+
+ return {
+ "instance_id": instance_id,
+ "repo": repo,
+ "base_commit": sample["base_commit"],
+ "original_problem_statement": problem_statement,
+ "user_query": user_query,
+ "changed_files": changed_files,
+ "gold_patch": gold_patch,
+ "test_patch": sample["test_patch"],
+ "fail_to_pass": sample["FAIL_TO_PASS"],
+ "documentation": documentation,
+ "hallucinations": hallucinations,
+ }
+
+
+def select_diverse_samples(dataset, n: int) -> list[int]:
+ """Select diverse samples across different repos."""
+ # Group by repo
+ repo_indices: dict[str, list[int]] = {}
+ for i in range(len(dataset)):
+ repo = dataset[i]["repo"]
+ if repo not in repo_indices:
+ repo_indices[repo] = []
+ repo_indices[repo].append(i)
+
+ print(f"Repos available: {list(repo_indices.keys())}")
+ print(f"Samples per repo: {', '.join(f'{k}: {len(v)}' for k, v in repo_indices.items())}")
+
+ # Round-robin across repos
+ selected = []
+ repo_iters = {repo: iter(indices) for repo, indices in repo_indices.items()}
+ repos = list(repo_indices.keys())
+ repo_idx = 0
+
+ while len(selected) < n:
+ repo = repos[repo_idx % len(repos)]
+ try:
+ idx = next(repo_iters[repo])
+ selected.append(idx)
+ except StopIteration:
+ repos.remove(repo)
+ if not repos:
+ break
+ repo_idx += 1
+
+ return selected[:n]
+
+
+def main():
+ from datasets import load_dataset
+
+ client = get_client()
+
+ print("=" * 60)
+ print("Code Hallucination Dataset Generator")
+ print("SWE-bench + Context7 + LLM Injection")
+ print("=" * 60)
+
+ # Load SWE-bench Lite
+ print("\nLoading SWE-bench Lite...")
+ ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
+ print(f"Loaded {len(ds)} samples")
+
+ # Select diverse samples
+ indices = select_diverse_samples(ds, NUM_SAMPLES)
+ print(f"\nSelected {len(indices)} diverse samples")
+
+ # Process each sample
+ results = []
+ failed = 0
+
+ for i, idx in enumerate(indices):
+ sample = ds[idx]
+ result = process_sample(client, sample, i, len(indices))
+ if result:
+ results.append(result)
+ else:
+ failed += 1
+
+ # Progress
+ if (i + 1) % 10 == 0:
+ print(
+ f"\n === Progress: {i + 1}/{len(indices)}, success: {len(results)}, failed: {failed} ===\n"
+ )
+
+ # Save raw results
+ os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
+ with open(OUTPUT_PATH, "w") as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved {len(results)} raw samples to {OUTPUT_PATH}")
+
+ # Convert to LettuceDetect format
+ print("\nConverting to LettuceDetect training format...")
+ ld_samples = convert_to_lettucedetect_format(results)
+
+ with open(LETTUCEDETECT_OUTPUT_PATH, "w") as f:
+ json.dump(ld_samples, f, indent=2)
+ print(f"Saved {len(ld_samples)} LettuceDetect samples to {LETTUCEDETECT_OUTPUT_PATH}")
+
+ # Statistics
+ print("\n" + "=" * 60)
+ print("DATASET STATISTICS")
+ print("=" * 60)
+ print(f"SWE-bench samples processed: {len(results)}")
+ print(f"Failed samples: {failed}")
+
+ repos = set(r["repo"] for r in results)
+ print(f"Repos covered: {len(repos)}")
+ for repo in sorted(repos):
+ count = sum(1 for r in results if r["repo"] == repo)
+ print(f" {repo}: {count}")
+
+ n_with_docs = sum(1 for r in results if r.get("documentation"))
+ print(f"\nSamples with Context7 docs: {n_with_docs}/{len(results)}")
+
+ hall_counts = {"structural": 0, "behavioral": 0, "semantic": 0}
+ for r in results:
+ for h in r.get("hallucinations", []):
+ t = h.get("type", "unknown")
+ if t in hall_counts:
+ hall_counts[t] += 1
+ print(f"Hallucination variants: {hall_counts}")
+
+ print(f"\nLettuceDetect training samples: {len(ld_samples)}")
+ n_positive = sum(1 for s in ld_samples if s.get("labels"))
+ n_negative = sum(1 for s in ld_samples if not s.get("labels"))
+ print(f" Hallucinated (positive): {n_positive}")
+ print(f" Correct (negative): {n_negative}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/prototype_code_hallucination.py b/scripts/prototype_code_hallucination.py
new file mode 100644
index 0000000..292bbc4
--- /dev/null
+++ b/scripts/prototype_code_hallucination.py
@@ -0,0 +1,308 @@
+#!/usr/bin/env python3
+"""Prototype: Transform SWE-bench into a code hallucination detection dataset.
+
+Pipeline:
+1. Load SWE-bench sample (issue, repo, patch)
+2. Extract context files from the patch (what files were relevant)
+3. Transform issue into user-style query (via LLM)
+4. Inject hallucinations into the gold patch (structural + semantic)
+5. Output annotated samples
+"""
+
+import json
+import os
+import re
+import textwrap
+from typing import Any
+
+from openai import OpenAI
+
+# Groq API (OpenAI-compatible)
+GROQ_BASE_URL = "https://api.groq.com/openai/v1"
+MODEL = "moonshotai/kimi-k2-instruct-0905"
+
+
+def get_client() -> OpenAI:
+ """Get Groq client."""
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError("Set OPENAI_API_KEY to your Groq API key")
+ return OpenAI(api_key=api_key, base_url=GROQ_BASE_URL)
+
+
+def extract_changed_files(patch: str) -> list[str]:
+ """Extract file paths that were changed from a unified diff patch."""
+ files = []
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ # Extract b/ path (the destination file)
+ match = re.search(r"b/(.+)$", line)
+ if match:
+ files.append(match.group(1))
+ return files
+
+
+def extract_patch_changes(patch: str) -> list[dict]:
+ """Parse a unified diff into structured changes per file."""
+ file_diffs = []
+ current_file = None
+ current_hunks = []
+
+ for line in patch.split("\n"):
+ if line.startswith("diff --git"):
+ if current_file:
+ file_diffs.append({"file": current_file, "diff": "\n".join(current_hunks)})
+ match = re.search(r"b/(.+)$", line)
+ current_file = match.group(1) if match else "unknown"
+ current_hunks = [line]
+ elif current_file:
+ current_hunks.append(line)
+
+ if current_file:
+ file_diffs.append({"file": current_file, "diff": "\n".join(current_hunks)})
+
+ return file_diffs
+
+
+def transform_to_user_query(client: OpenAI, problem_statement: str, repo: str) -> str:
+ """Transform a GitHub issue into a realistic user-to-agent query."""
+ response = client.chat.completions.create(
+ model=MODEL,
+ messages=[
+ {
+ "role": "system",
+ "content": textwrap.dedent("""\
+ You transform GitHub issue descriptions into realistic user queries
+ that a developer would type into an AI coding assistant (like Claude Code
+ or Cursor).
+
+ Rules:
+ - Make it conversational and natural, like someone typing into a chat
+ - Keep the core technical ask but remove GitHub-specific formatting
+ - Remove reproduction steps, stack traces, and verbose details
+ - Keep it to 1-3 sentences
+ - Don't mention "issue" or "bug report"
+ - Make it sound like someone asking for help, not filing a report
+
+ Examples:
+ Issue: "BUG: DataFrame.merge raises TypeError when merging on columns with different dtypes"
+ Query: "pd.merge is crashing with a TypeError when I try to merge two dataframes where the join columns have different dtypes. Can you fix the merge to handle dtype coercion?"
+
+ Issue: "Enable quiet mode/no-verbose in CLI for programmatic use"
+ Query: "Can you add a --quiet flag to the CLI? I need to use it in scripts and the verbose output is getting in the way."
+ """),
+ },
+ {
+ "role": "user",
+ "content": f"Repository: {repo}\n\nGitHub Issue:\n{problem_statement[:3000]}",
+ },
+ ],
+ temperature=0.7,
+ max_tokens=300,
+ )
+ return response.choices[0].message.content.strip()
+
+
+def inject_hallucinations(
+ client: OpenAI,
+ gold_patch: str,
+ user_query: str,
+ problem_statement: str,
+ repo: str,
+) -> dict[str, Any]:
+ """Inject hallucinations into a gold patch, returning annotated hallucinated code."""
+ response = client.chat.completions.create(
+ model=MODEL,
+ messages=[
+ {
+ "role": "system",
+ "content": textwrap.dedent("""\
+ You are a code hallucination injector for building a hallucination detection dataset.
+
+ Given a correct code patch and the user's query, create THREE hallucinated versions:
+
+ 1. STRUCTURAL hallucination: Change a function call, import, or parameter to
+ something that doesn't exist or is wrong for this codebase. The code should
+ still parse but reference non-existent APIs, wrong method names, or invented
+ parameters.
+
+ 2. BEHAVIORAL hallucination: Use correct APIs but with wrong values or logic
+ that would produce incorrect behavior. For example, wrong default values,
+ off-by-one errors, or swapped conditions.
+
+ 3. SEMANTIC hallucination: Code that looks like it addresses the user's request
+ but actually does something subtly different or opposite. This is the hardest
+ type - the code should parse, use real APIs, but fail to do what was asked.
+ Examples: implementing global when per-item was asked, catching exceptions
+ instead of fixing the root cause, or implementing the inverse logic.
+
+ For EACH hallucinated version, respond in this exact JSON format:
+ {
+ "hallucinations": [
+ {
+ "type": "structural",
+ "hallucinated_patch": "the full modified patch",
+ "changes": [
+ {
+ "original": "the original correct code span",
+ "hallucinated": "what you changed it to",
+ "explanation": "why this is a hallucination"
+ }
+ ]
+ },
+ {
+ "type": "behavioral",
+ "hallucinated_patch": "...",
+ "changes": [...]
+ },
+ {
+ "type": "semantic",
+ "hallucinated_patch": "...",
+ "changes": [...]
+ }
+ ]
+ }
+
+ IMPORTANT:
+                    - The hallucinated patches must be plausible - something an LLM would realistically generate
+ - Each change must be subtle, not obviously broken
+ - Return ONLY valid JSON, no markdown code blocks
+ """),
+ },
+ {
+ "role": "user",
+ "content": f"""Repository: {repo}
+
+User's query: {user_query}
+
+Original issue context: {problem_statement[:2000]}
+
+Correct gold patch:
+{gold_patch}
+
+Generate three hallucinated versions of this patch.""",
+ },
+ ],
+ temperature=0.8,
+ max_tokens=4000,
+ )
+
+ raw = response.choices[0].message.content.strip()
+ # Try to extract JSON from the response
+ # Sometimes LLMs wrap it in ```json blocks
+ json_match = re.search(r"\{[\s\S]*\}", raw)
+ if json_match:
+ try:
+ return json.loads(json_match.group())
+ except json.JSONDecodeError:
+ return {"raw_response": raw, "parse_error": True}
+ return {"raw_response": raw, "parse_error": True}
+
+
+def process_sample(client: OpenAI, sample: dict) -> dict:
+ """Process a single SWE-bench sample into a hallucination detection sample."""
+ instance_id = sample["instance_id"]
+ repo = sample["repo"]
+ problem_statement = sample["problem_statement"]
+ gold_patch = sample["patch"]
+
+ print(f"\n{'=' * 60}")
+ print(f"Processing: {instance_id}")
+ print(f"Repo: {repo}")
+ print(f"{'=' * 60}")
+
+ # Step 1: Extract changed files from the patch
+ changed_files = extract_changed_files(gold_patch)
+ print(f"\nChanged files: {changed_files}")
+
+ # Step 2: Transform issue into user query
+ print("\nTransforming issue into user query...")
+ user_query = transform_to_user_query(client, problem_statement, repo)
+ print(f"User query: {user_query}")
+
+ # Step 3: Inject hallucinations
+ print("\nInjecting hallucinations...")
+ hallucination_result = inject_hallucinations(
+ client, gold_patch, user_query, problem_statement, repo
+ )
+
+ if hallucination_result.get("parse_error"):
+ print("WARNING: Failed to parse hallucination response")
+ print(
+ f"Raw response (first 500 chars): {hallucination_result.get('raw_response', '')[:500]}"
+ )
+
+ # Build the output sample
+ output = {
+ "instance_id": instance_id,
+ "repo": repo,
+ "base_commit": sample["base_commit"],
+ "original_problem_statement": problem_statement,
+ "user_query": user_query,
+ "changed_files": changed_files,
+ "gold_patch": gold_patch,
+ "test_patch": sample["test_patch"],
+ "fail_to_pass": sample["FAIL_TO_PASS"],
+ "hallucinations": hallucination_result.get("hallucinations", []),
+ }
+
+ # Print summary
+ if not hallucination_result.get("parse_error"):
+ for h in hallucination_result.get("hallucinations", []):
+ print(f"\n [{h.get('type', 'unknown').upper()}]")
+ for c in h.get("changes", []):
+ print(f" Original: {c.get('original', 'N/A')[:80]}")
+ print(f" Hallucinated: {c.get('hallucinated', 'N/A')[:80]}")
+ print(f" Explanation: {c.get('explanation', 'N/A')[:100]}")
+
+ return output
+
+
+def main():
+ """Run prototype on a few SWE-bench Lite samples."""
+ from datasets import load_dataset
+
+ client = get_client()
+
+ # Load SWE-bench Lite (300 curated samples)
+ print("Loading SWE-bench Lite...")
+ ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
+
+ # Pick a few diverse samples to test
+ # Choose samples from different repos with different patch sizes
+ test_indices = [0, 5, 10] # Start with 3 samples
+ results = []
+
+ for idx in test_indices:
+ sample = ds[idx]
+ try:
+ result = process_sample(client, sample)
+ results.append(result)
+ except Exception as e:
+ print(f"\nERROR processing {sample['instance_id']}: {e}")
+ continue
+
+ # Save results
+ output_path = "/Users/adamkovacs/projects/LettuceDetect/data/code_hallucination_prototype.json"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ with open(output_path, "w") as f:
+ json.dump(results, f, indent=2)
+
+ print(f"\n\nSaved {len(results)} samples to {output_path}")
+
+ # Print summary
+ print("\n" + "=" * 60)
+ print("SUMMARY")
+ print("=" * 60)
+ for r in results:
+ print(f"\n{r['instance_id']} ({r['repo']})")
+ print(f" Query: {r['user_query'][:100]}...")
+ print(f" Files: {r['changed_files']}")
+ n_hall = len(r.get("hallucinations", []))
+ print(f" Hallucinations generated: {n_hall}")
+ for h in r.get("hallucinations", []):
+ print(f" - {h.get('type', 'unknown')}: {len(h.get('changes', []))} changes")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/train_code_hallucination.py b/scripts/train_code_hallucination.py
new file mode 100644
index 0000000..09930ed
--- /dev/null
+++ b/scripts/train_code_hallucination.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+"""Train a hallucination detector on the code hallucination dataset.
+
+Can optionally combine with RAGTruth data for mixed training.
+
+Usage:
+ # Train on code hallucination data only
+ python scripts/train_code_hallucination.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/code_hallucination_detector
+
+ # Train on both code + RAGTruth data
+ python scripts/train_code_hallucination.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --ragtruth-path data/ragtruth/ragtruth_data.json \
+ --model-name answerdotai/ModernBERT-base \
+ --output-dir output/code_hallucination_detector
+"""
+
+import argparse
+import json
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from transformers import (
+ AutoModelForTokenClassification,
+ AutoTokenizer,
+ DataCollatorForTokenClassification,
+)
+
+from lettucedetect.datasets.hallucination_dataset import (
+ HallucinationData,
+ HallucinationDataset,
+ HallucinationSample,
+)
+from lettucedetect.models.trainer import Trainer
+
+
+def set_seed(seed: int = 42):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def load_code_data(path: str) -> list[HallucinationSample]:
+ """Load code hallucination dataset.
+
+ Uses the SWE-bench split field directly (train/dev/test).
+ """
+ data = json.loads(Path(path).read_text())
+ samples = []
+ for item in data:
+ samples.append(
+ HallucinationSample(
+ prompt=item["prompt"],
+ answer=item["answer"],
+ labels=item["labels"],
+ split=item.get("split", "train"),
+ task_type=item.get("task_type", "code_generation"),
+ dataset="swebench_code",
+ language="en",
+ )
+ )
+ return samples
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train hallucination detector on code data")
+ parser.add_argument(
+ "--code-data-path", type=str, required=True, help="Path to code hallucination dataset JSON"
+ )
+ parser.add_argument(
+ "--ragtruth-path",
+ type=str,
+ default=None,
+ help="Optional path to RAGTruth data for mixed training",
+ )
+ parser.add_argument("--model-name", type=str, default="answerdotai/ModernBERT-base")
+ parser.add_argument("--output-dir", type=str, default="output/code_hallucination_detector")
+ parser.add_argument("--batch-size", type=int, default=4)
+ parser.add_argument("--epochs", type=int, default=6)
+ parser.add_argument("--learning-rate", type=float, default=1e-5)
+ parser.add_argument("--max-length", type=int, default=4096)
+ parser.add_argument("--seed", type=int, default=42)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+ set_seed(args.seed)
+
+ # Load code hallucination data
+ print(f"Loading code hallucination data from {args.code_data_path}")
+ code_samples = load_code_data(args.code_data_path)
+
+ n_clean = sum(1 for s in code_samples if not s.labels)
+ n_hall = sum(1 for s in code_samples if s.labels)
+ print(f" Total: {len(code_samples)} (clean: {n_clean}, hallucinated: {n_hall})")
+
+ # Use SWE-bench splits directly (zero repo overlap by design)
+ # Train on train, validate on dev, hold out test for final evaluation
+ train_samples = [s for s in code_samples if s.split == "train"]
+ dev_samples = [s for s in code_samples if s.split == "dev"]
+ test_samples = [s for s in code_samples if s.split == "test"]
+
+ print(f" Train (SWE-bench train): {len(train_samples)}")
+ print(f" Dev (SWE-bench dev): {len(dev_samples)}")
+ print(f" Test (SWE-bench test, held out): {len(test_samples)}")
+
+ # Optionally add RAGTruth data
+ if args.ragtruth_path:
+ print(f"\nLoading RAGTruth data from {args.ragtruth_path}")
+ ragtruth_data = HallucinationData.from_json(
+ json.loads(Path(args.ragtruth_path).read_text())
+ )
+ ragtruth_train = [s for s in ragtruth_data.samples if s.split == "train"]
+ ragtruth_dev = [s for s in ragtruth_data.samples if s.split in ("dev", "test")]
+
+ print(f" RAGTruth train: {len(ragtruth_train)}, dev: {len(ragtruth_dev)}")
+ train_samples.extend(ragtruth_train)
+ dev_samples.extend(ragtruth_dev)
+
+ print("\nFinal splits:")
+ print(f" Train: {len(train_samples)}")
+ print(f" Dev: {len(dev_samples)}")
+
+ # Setup tokenizer and model
+ print(f"\nLoading model: {args.model_name}")
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
+ data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100)
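+    # label_pad_token_id=-100 matches the default ignore index, so padded label
+    # positions are excluded from the token-classification loss.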
+
+ train_dataset = HallucinationDataset(train_samples, tokenizer, max_length=args.max_length)
+ dev_dataset = HallucinationDataset(dev_samples, tokenizer, max_length=args.max_length)
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ collate_fn=data_collator,
+ )
+ dev_loader = DataLoader(
+ dev_dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ collate_fn=data_collator,
+ )
+
+ model = AutoModelForTokenClassification.from_pretrained(
+ args.model_name,
+ num_labels=2,
+ trust_remote_code=True,
+ )
+
+ trainer = Trainer(
+ model=model,
+ tokenizer=tokenizer,
+ train_loader=train_loader,
+ test_loader=dev_loader,
+ epochs=args.epochs,
+ learning_rate=args.learning_rate,
+ save_path=args.output_dir,
+ )
+
+ print(f"\nStarting training for {args.epochs} epochs...")
+ trainer.train()
+ print(f"\nModel saved to {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/train_generative_detector.py b/scripts/train_generative_detector.py
new file mode 100644
index 0000000..3b6ea4e
--- /dev/null
+++ b/scripts/train_generative_detector.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python3
+"""Train a generative hallucination span detector via SFT.
+
+Fine-tunes a decoder LLM (e.g. Qwen3.5-2B) to detect hallucinated spans
+in code by generating structured JSON output.
+
+This is Approach D: the model reads context + answer and generates:
+ {"hallucinated_spans": [{"text": "...", "explanation": "..."}]}
+or for clean samples:
+ {"hallucinated_spans": []}
+
+Usage:
+ # Train with LoRA on a single GPU
+ python scripts/train_generative_detector.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name Qwen/Qwen3.5-2B \
+ --output-dir output/generative_detector
+
+ # With custom settings
+ python scripts/train_generative_detector.py \
+ --code-data-path data/code_hallucination/code_hallucination_data.json \
+ --model-name Qwen/Qwen3.5-2B \
+ --output-dir output/generative_detector \
+ --lora-r 16 \
+ --batch-size 2 \
+ --epochs 3 \
+ --max-length 4096
+"""
+
+import argparse
+import json
+import random
+from pathlib import Path
+
+import torch
+from peft import LoraConfig, TaskType, get_peft_model
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
+
+SYSTEM_PROMPT = (
+ "You are a code hallucination detector. Given source code context and a code answer, "
+    "identify any hallucinated spans: code that is factually wrong, uses non-existent APIs, "
+ "has incorrect logic, or doesn't match the source context.\n\n"
+ "Respond with JSON only:\n"
+ '{"hallucinated_spans": [{"text": "exact hallucinated text", "explanation": "why it is wrong"}]}\n\n'
+ "If the answer is correct, respond with:\n"
+ '{"hallucinated_spans": []}'
+)
+
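+# Illustrative target completion for a hallucinated answer (the span text and
+# explanation below are made-up examples, not taken from the dataset):
+#   {"hallucinated_spans": [{"text": "os.path.exists_all(paths)",
+#                            "explanation": "os.path has no exists_all function"}]}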
+
+def build_training_pairs(data_path: str) -> list[dict]:
+ """Build (input, output) pairs from the code hallucination dataset.
+
+ For hallucinated samples: output lists the hallucinated spans with explanations.
+ For clean samples: output is {"hallucinated_spans": []}.
+ """
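+    # Sketch of one produced pair (hypothetical values, shown for orientation):
+    #   {"system": SYSTEM_PROMPT,
+    #    "user": "Context:\n<code context>\n\nAnswer to check:\n<code answer>",
+    #    "assistant": '{"hallucinated_spans": [...]}',
+    #    "split": "train"}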
+ data = json.loads(Path(data_path).read_text())
+ pairs = []
+
+ for item in data:
+ prompt = item["prompt"]
+ answer = item["answer"]
+ labels = item.get("labels", [])
+ split = item.get("split", "train")
+
+ user_msg = f"Context:\n{prompt}\n\nAnswer to check:\n{answer}"
+
+ if labels:
+ spans = []
+ for label in labels:
+ start = label["start"]
+ end = label["end"]
+ text = answer[start:end]
+ if text.strip():
+ spans.append(
+ {
+ "text": text,
+ "explanation": f"{label.get('label', 'hallucination')} error in code",
+ }
+ )
+ assistant_msg = json.dumps({"hallucinated_spans": spans})
+ else:
+ assistant_msg = json.dumps({"hallucinated_spans": []})
+
+ pairs.append(
+ {
+ "system": SYSTEM_PROMPT,
+ "user": user_msg,
+ "assistant": assistant_msg,
+ "split": split,
+ }
+ )
+
+ return pairs
+
+
+class SFTDataset(Dataset):
+ """Dataset for supervised fine-tuning with chat-formatted inputs."""
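+
+    # Per-example label layout (sketch): only the assistant's JSON completion is
+    # supervised; everything before it is masked out with -100.
+    #   input_ids: [system ... user ... assistant-JSON ...]
+    #   labels:    [ -100  ... -100 ... assistant-JSON ...]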
+
+ def __init__(self, pairs: list[dict], tokenizer, max_length: int = 4096):
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.examples = []
+
+ for pair in pairs:
+ messages = [
+ {"role": "system", "content": pair["system"]},
+ {"role": "user", "content": pair["user"]},
+ {"role": "assistant", "content": pair["assistant"]},
+ ]
+
+ # Tokenize full conversation
+ full_text = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=False
+ )
+ full_ids = tokenizer(
+ full_text, truncation=True, max_length=max_length, return_tensors="pt"
+ )
+
+ # Tokenize without assistant response to find where labels start
+ messages_no_assistant = [
+ {"role": "system", "content": pair["system"]},
+ {"role": "user", "content": pair["user"]},
+ ]
+ prompt_text = tokenizer.apply_chat_template(
+ messages_no_assistant, tokenize=False, add_generation_prompt=True
+ )
+ prompt_ids = tokenizer(
+ prompt_text, truncation=True, max_length=max_length, return_tensors="pt"
+ )
+            prompt_len = prompt_ids["input_ids"].shape[1]
+
+            input_ids = full_ids["input_ids"].squeeze(0)
+            attention_mask = full_ids["attention_mask"].squeeze(0)
+
+            # Skip examples whose assistant response was truncated away entirely;
+            # with every label masked, the loss for such an example would be NaN.
+            if prompt_len >= input_ids.shape[0]:
+                continue
+
+            # Labels: -100 for prompt tokens (masked), actual ids for assistant tokens
+            labels = input_ids.clone()
+            labels[:prompt_len] = -100
+
+ self.examples.append(
+ {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ }
+ )
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, idx):
+ return self.examples[idx]
+
+
+def collate_fn(batch):
+ """Pad batch to same length."""
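+    # Padding uses id 0 rather than tokenizer.pad_token_id; that is safe here because
+    # padded positions get attention_mask 0 and label -100, so they never affect the loss.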
+ max_len = max(ex["input_ids"].shape[0] for ex in batch)
+
+ input_ids = []
+ attention_mask = []
+ labels = []
+
+ for ex in batch:
+ pad_len = max_len - ex["input_ids"].shape[0]
+ input_ids.append(torch.cat([ex["input_ids"], torch.zeros(pad_len, dtype=torch.long)]))
+ attention_mask.append(
+ torch.cat([ex["attention_mask"], torch.zeros(pad_len, dtype=torch.long)])
+ )
+ labels.append(torch.cat([ex["labels"], torch.full((pad_len,), -100, dtype=torch.long)]))
+
+ return {
+ "input_ids": torch.stack(input_ids),
+ "attention_mask": torch.stack(attention_mask),
+ "labels": torch.stack(labels),
+ }
+
+
+def evaluate(model, dataloader, device):
+ """Compute average loss on validation set."""
+ model.eval()
+ total_loss = 0
+ total_steps = 0
+
+ with torch.no_grad():
+ for batch in dataloader:
+ batch = {k: v.to(device) for k, v in batch.items()}
+ outputs = model(**batch)
+ total_loss += outputs.loss.item()
+ total_steps += 1
+
+ return total_loss / max(total_steps, 1)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Train generative hallucination span detector")
+ parser.add_argument("--code-data-path", type=str, required=True)
+ parser.add_argument("--model-name", type=str, default="Qwen/Qwen3.5-2B")
+ parser.add_argument("--output-dir", type=str, default="output/generative_detector")
+ parser.add_argument("--batch-size", type=int, default=2)
+ parser.add_argument("--epochs", type=int, default=3)
+ parser.add_argument("--learning-rate", type=float, default=2e-4)
+ parser.add_argument("--max-length", type=int, default=4096)
+ parser.add_argument("--lora-r", type=int, default=16)
+ parser.add_argument("--lora-alpha", type=int, default=32)
+ parser.add_argument("--lora-dropout", type=float, default=0.05)
+ parser.add_argument("--warmup-ratio", type=float, default=0.05)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+
+ # Build training pairs
+ print(f"Loading data from {args.code_data_path}")
+ pairs = build_training_pairs(args.code_data_path)
+
+ train_pairs = [p for p in pairs if p["split"] == "train"]
+ dev_pairs = [p for p in pairs if p["split"] == "dev"]
+ test_pairs = [p for p in pairs if p["split"] == "test"]
+
+ print(f"Train: {len(train_pairs)}, Dev: {len(dev_pairs)}, Test (held out): {len(test_pairs)}")
+
+ random.shuffle(train_pairs)
+
+ # Load tokenizer and model
+ print(f"Loading model: {args.model_name}")
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.model_name,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ )
+
+ # Apply LoRA
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ ],
+ )
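+    # Note: these target_modules names follow Llama/Qwen-style attention and MLP
+    # projections; other decoder architectures may name their linear layers differently.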
+ model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # Build datasets
+ print("Tokenizing datasets...")
+ train_dataset = SFTDataset(train_pairs, tokenizer, max_length=args.max_length)
+ dev_dataset = SFTDataset(dev_pairs, tokenizer, max_length=args.max_length)
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
+ )
+ dev_loader = DataLoader(
+ dev_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
+ )
+
+ print(f"Train examples: {len(train_dataset)}, Dev examples: {len(dev_dataset)}")
+
+ # Optimizer and scheduler
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
+ total_steps = (len(train_loader) // args.gradient_accumulation_steps) * args.epochs
+ warmup_steps = int(total_steps * args.warmup_ratio)
+ scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
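+    # With the defaults above (batch-size 2, gradient-accumulation-steps 4) each
+    # optimizer step sees an effective batch of 2 * 4 = 8 sequences.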
+
+ # Training loop
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ best_dev_loss = float("inf")
+
+ print(f"\nTraining for {args.epochs} epochs ({total_steps} steps)...")
+
+ for epoch in range(args.epochs):
+ model.train()
+ epoch_loss = 0
+ optimizer.zero_grad()
+
+ for step, batch in enumerate(train_loader):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ outputs = model(**batch)
+ loss = outputs.loss / args.gradient_accumulation_steps
+ loss.backward()
+ epoch_loss += outputs.loss.item()
+
+ if (step + 1) % args.gradient_accumulation_steps == 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ if (step + 1) % 100 == 0:
+ avg_loss = epoch_loss / (step + 1)
+ print(
+ f" Epoch {epoch + 1} step {step + 1}/{len(train_loader)} loss={avg_loss:.4f}"
+ )
+
+ # Evaluate
+ train_avg = epoch_loss / len(train_loader)
+ dev_loss = evaluate(model, dev_loader, device)
+ print(f"Epoch {epoch + 1}: train_loss={train_avg:.4f}, dev_loss={dev_loss:.4f}")
+
+ # Save best
+ if dev_loss < best_dev_loss:
+ best_dev_loss = dev_loss
+ model.save_pretrained(output_dir / "best")
+ tokenizer.save_pretrained(output_dir / "best")
+ print(f" Saved best model (dev_loss={dev_loss:.4f})")
+
+ # Save final
+ model.save_pretrained(output_dir / "final")
+ tokenizer.save_pretrained(output_dir / "final")
+ print(f"\nTraining complete. Best dev loss: {best_dev_loss:.4f}")
+ print(f"Models saved to {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/conftest.py b/tests/conftest.py
index a0cecaf..4b5c624 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,43 +1 @@
"""Shared fixtures for pytest tests."""
-
-from unittest.mock import MagicMock
-
-import pytest
-import torch
-
-
-@pytest.fixture
-def mock_tokenizer():
- """Create a mock tokenizer for testing."""
- tokenizer = MagicMock()
- tokenizer.encode.return_value = [101, 102, 103, 104, 105]
-
- # Mock tokenizer call to return encoding
- tokenizer.return_value = {
- "input_ids": torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108]]),
- "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]),
- "offset_mapping": torch.tensor(
- [
- [0, 0], # [CLS]
- [0, 4], # "This"
- [5, 7], # "is"
- [8, 9], # "a"
- [10, 16], # "prompt"
- [0, 0], # [SEP]
- [0, 4], # "This"
- [5, 12], # "answer"
- ]
- ),
- }
-
- return tokenizer
-
-
-@pytest.fixture
-def mock_model():
- """Create a mock model for testing."""
- model = MagicMock()
- mock_output = MagicMock()
- mock_output.logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]])
- model.return_value = mock_output
- return model
diff --git a/tests/run_pytest.py b/tests/run_pytest.py
old mode 100644
new mode 100755
diff --git a/tests/test_inference_pytest.py b/tests/test_inference_pytest.py
index 546852b..e6f7c34 100644
--- a/tests/test_inference_pytest.py
+++ b/tests/test_inference_pytest.py
@@ -5,6 +5,7 @@
import pytest
import torch
+from lettucedetect.datasets.hallucination_dataset import HallucinationDataset
from lettucedetect.detectors.prompt_utils import PromptUtils
from lettucedetect.detectors.transformer import TransformerDetector
from lettucedetect.models.inference import HallucinationDetector
@@ -129,19 +130,15 @@ def test_init(self):
def test_predict(self):
"""Test predict method."""
-
- # Create a proper mock encoding with input_ids as a tensor attribute
- class MockEncoding:
- def __init__(self):
- self.input_ids = torch.tensor([[101, 102, 103]])
-
- mock_encoding = MockEncoding()
- mock_labels = torch.tensor([0, 0, 0])
- mock_offsets = torch.tensor([[0, 0], [0, 1], [1, 2]])
- mock_answer_start = 1
-
- # Patch the _predict method to avoid the actual implementation
- with patch.object(TransformerDetector, "_predict", return_value=[]):
+ # Patch internals to avoid actual model inference
+ with (
+ patch.object(
+ TransformerDetector,
+ "_group_passages_into_chunks",
+ return_value=[["This is a test context."]],
+ ),
+ patch.object(TransformerDetector, "_predict_single", return_value=[]),
+ ):
detector = TransformerDetector(model_path="dummy_path")
context = ["This is a test context."]
answer = "This is a test answer."
@@ -175,3 +172,191 @@ def test_form_prompt_without_question(self):
# Check that the prompt contains the text to summarize
assert "This is a text to summarize." in prompt
assert "Summarize" in prompt
+
+
+class TestChunking:
+ """Tests for automatic context chunking when input exceeds max_length."""
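+
+    # Strategy under test (informal summary): when the formatted prompt would exceed
+    # max_length, passages are regrouped into chunks that fit, each chunk is scored
+    # separately, and per-token results are merged by taking the maximum probability.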
+
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ """Set up a TransformerDetector with mocked model/tokenizer."""
+ with (
+ patch(
+ "lettucedetect.detectors.transformer.AutoTokenizer.from_pretrained",
+ ) as tok_cls,
+ patch(
+ "lettucedetect.detectors.transformer.AutoModelForTokenClassification.from_pretrained",
+ ) as model_cls,
+ ):
+ # Set up a realistic tokenizer mock
+ self.mock_tokenizer = MagicMock()
+ self.mock_model = MagicMock()
+ tok_cls.return_value = self.mock_tokenizer
+ model_cls.return_value = self.mock_model
+
+ self.detector = TransformerDetector(model_path="dummy", max_length=32)
+ yield
+
+ def test_aggregation_uses_max(self):
+ """Max aggregation: highest hallucination prob across chunks wins."""
+ # Simulate two chunks with different probabilities for the same tokens
+ chunk1_tokens = [
+ {"token": "The", "pred": 0, "prob": 0.1},
+ {"token": "answer", "pred": 0, "prob": 0.3},
+ {"token": "is", "pred": 1, "prob": 0.8},
+ ]
+ chunk2_tokens = [
+ {"token": "The", "pred": 0, "prob": 0.2},
+ {"token": "answer", "pred": 1, "prob": 0.6},
+ {"token": "is", "pred": 0, "prob": 0.4},
+ ]
+
+ with patch.object(self.detector, "_predict_single") as mock_single:
+ mock_single.side_effect = [chunk1_tokens, chunk2_tokens]
+
+ # Mock _build_spans_from_tokens for the spans path
+ with patch.object(self.detector, "_build_spans_from_tokens", return_value=[]):
+ result = self.detector._predict_chunked(
+ ["chunk1", "chunk2"], "The answer is", output_format="tokens"
+ )
+
+ assert len(result) == 3
+        # Token 0: max(0.1, 0.2) = 0.2 → pred 0
+ assert result[0]["prob"] == 0.2
+ assert result[0]["pred"] == 0
+        # Token 1: max(0.3, 0.6) = 0.6 → pred 1 (≥ 0.5)
+ assert result[1]["prob"] == 0.6
+ assert result[1]["pred"] == 1
+        # Token 2: max(0.8, 0.4) = 0.8 → pred 1
+ assert result[2]["prob"] == 0.8
+ assert result[2]["pred"] == 1
+
+ def test_group_passages_single_group(self):
+ """When all passages fit, _group_passages_into_chunks returns one group."""
+
+ def tokenizer_side_effect(text, **kwargs):
+ # Approximate: 1 token per word
+ words = len(text.split()) if text else 1
+ return {"input_ids": torch.zeros(1, words, dtype=torch.long)}
+
+ self.mock_tokenizer.side_effect = tokenizer_side_effect
+
+ with patch(
+ "lettucedetect.detectors.transformer.PromptUtils.format_context",
+ return_value="short prompt with passages",
+ ):
+ groups = self.detector._group_passages_into_chunks(
+ ["passage one", "passage two"], "What?", "short answer"
+ )
+ assert len(groups) == 1
+ assert groups[0] == ["passage one", "passage two"]
+
+ def test_group_passages_multiple_groups(self):
+ """When passages exceed budget, they should be split into multiple groups."""
+        # max_length=32, answer=5 tokens → total_budget = 32-5-3 = 24
+        # instruction_overhead = 10 tokens → passage_budget = 24-10 = 14
+        # Each passage = 10 tokens → 2 passages won't fit (10+1+10=21 > 14), so 1 per group
+
+ def tokenizer_side_effect(text, **kwargs):
+ if text == "short answer":
+ return {"input_ids": torch.zeros(1, 5, dtype=torch.long)}
+ if text.startswith("passage"):
+ return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
+            # "full_prompt" → 50 tokens (exceeds budget, triggers chunking)
+ if "full_prompt" in text:
+ return {"input_ids": torch.zeros(1, 50, dtype=torch.long)}
+            # "minimal" → 10 tokens (instruction overhead)
+ return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
+
+ self.mock_tokenizer.side_effect = tokenizer_side_effect
+
+ format_calls = [0]
+
+ def format_side_effect(ctx, q, lang):
+ format_calls[0] += 1
+ if format_calls[0] == 1:
+                return "full_prompt_long"  # First call: all passages → triggers chunking
+            return "minimal"  # Second call: [""] → measures instruction overhead
+
+ with patch(
+ "lettucedetect.detectors.transformer.PromptUtils.format_context",
+ side_effect=format_side_effect,
+ ):
+ groups = self.detector._group_passages_into_chunks(
+ ["passage A", "passage B", "passage C"], "What?", "short answer"
+ )
+
+ # With passage_budget = 14 and each passage = 10 tokens,
+ # each group holds 1 passage (10 < 14, but 10+1+10 = 21 > 14)
+ assert len(groups) == 3
+ assert groups[0] == ["passage A"]
+ assert groups[1] == ["passage B"]
+ assert groups[2] == ["passage C"]
+
+ def test_predict_uses_passage_chunking(self):
+ """predict() should use _group_passages_into_chunks for chunking."""
+ with (
+ patch.object(
+ self.detector,
+ "_group_passages_into_chunks",
+ return_value=[["p1"], ["p2"]],
+ ) as mock_group,
+ patch.object(
+ self.detector,
+ "_predict_chunked",
+ return_value=[{"token": "x", "pred": 0, "prob": 0.1}],
+ ) as mock_chunked,
+ patch(
+ "lettucedetect.detectors.transformer.PromptUtils.format_context",
+ side_effect=lambda ctx, q, lang: f"prompt_{ctx[0]}",
+ ),
+ ):
+ self.detector.predict(["p1", "p2"], "answer", "What?", "tokens")
+ mock_group.assert_called_once()
+ mock_chunked.assert_called_once_with(["prompt_p1", "prompt_p2"], "answer", "tokens")
+
+
+class TestAnswerStartToken:
+ """Tests for the answer_start_token fix in prepare_tokenized_input."""
+
+ def test_answer_start_token_basic(self):
+ """answer_start_token should point to the first answer token."""
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ context = "The capital of France is Paris."
+ answer = "Paris is the capital."
+
+ _, _, offsets, answer_start = HallucinationDataset.prepare_tokenized_input(
+ tokenizer, context, answer, max_length=512
+ )
+
+ # answer_start should be within bounds
+ assert answer_start > 0
+ assert answer_start < offsets.size(0)
+ # The offset at answer_start should be non-zero (actual token, not special)
+ assert offsets[answer_start][1].item() > offsets[answer_start][0].item()
+
+ def test_answer_start_token_with_truncation(self):
+ """answer_start_token should be correct even when context is truncated."""
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ # Create a very long context that will be truncated at max_length=32
+ context = "word " * 200 # ~200 tokens
+ answer = "short answer"
+
+ encoding, _, offsets, answer_start = HallucinationDataset.prepare_tokenized_input(
+ tokenizer, context, answer, max_length=32
+ )
+
+ total_len = encoding["input_ids"].shape[1]
+ assert total_len == 32 # Should be truncated to max_length
+
+ # answer_start should be within the actual sequence
+ assert answer_start > 0
+ assert answer_start < total_len
+
+ # The answer tokens should be at the end (before trailing [SEP])
+ # Verify by checking that the text at answer_start offset is non-empty
+ assert offsets[answer_start][1].item() > offsets[answer_start][0].item()