levvius/adaptive-speculative-decoding

Adaptive Speculative Decoding

CI License: MIT Python 3.11+

Pet project and research playground for LLM decoding acceleration.

This repository compares exact and approximate decoding strategies under a single benchmark harness, with reproducible configs, JSONL outputs, and report scripts.

sp_samp/ remains the historical benchmark stack. jointadaspec/ is the new thesis-focused implementation of JointAdaSpec: joint adaptive control of draft length and verification threshold through a tabular MDP.

Why This Project

  • Side-by-side comparison of Baseline, Speculative Sampling, AutoJudge, Consensus AutoJudge, Top-K, SpecExec, and JointAdaSpec.
  • Paper-aligned AutoJudge implementation (GSM8K label mining + LogisticRegression calibration).
  • New jointadaspec/ stack for joint adaptive control of draft length and verification threshold via a tabular MDP.
  • Long-run friendly workflow (resume keys, checkpoints, strict result schema validation).
  • Real benchmark reports are versioned in reports/.

Implemented Methods

| Method | Exact vs target distribution | Main idea |
|---|---|---|
| baseline | exact | Target-only decoding |
| speculative | exact | Draft proposes, target verifies |
| autojudge | approximate | Judge can accept some mismatches |
| consensus_autojudge | approximate | Two drafts + consensus gate decide accept / escalate / fallback |
| topk | approximate | Accept mismatch if target token in top-k |
| specexec | exact | Parallel speculative branches + cache reuse |
| jointadaspec | approximate | Policy jointly chooses draft stop/continue and fuzzy verification threshold |
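
To make the exact/approximate distinction concrete, here is a minimal sketch of two per-token acceptance tests. `speculative_accept` is the standard speculative-sampling test (accept with probability min(1, p/q)). `topk_accept` is one plausible reading of the top-k row above, not necessarily the rule implemented in sp_samp/. Both function names are hypothetical.

```python
import random


def speculative_accept(p_target: float, q_draft: float, rng: random.Random) -> bool:
    """Standard speculative-sampling test: accept with probability min(1, p/q).

    p_target and q_draft are the target and draft probabilities of the drafted token.
    On rejection, exact schemes resample from the normalized residual max(0, p - q);
    that step is omitted here.
    """
    if q_draft <= 0.0:
        return False
    return rng.random() < min(1.0, p_target / q_draft)


def topk_accept(draft_token: int, target_probs: list[float], k: int) -> bool:
    """Relaxed (approximate) check: accept the drafted token if it is among the
    k most likely tokens under the target distribution. Assumed interpretation."""
    ranked = sorted(range(len(target_probs)), key=lambda i: target_probs[i], reverse=True)
    return draft_token in ranked[:k]


rng = random.Random(0)
print(speculative_accept(0.5, 0.25, rng))   # ratio >= 1, so always accepted
print(topk_accept(0, [0.1, 0.2, 0.7], k=2)) # token 0 is ranked third, so rejected
```

The first test preserves the target distribution when paired with residual resampling; the second trades that guarantee for a higher acceptance rate.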

JointAdaSpec: Joint Adaptive Speculative Decoding

JointAdaSpec formulates speculative decoding control as a finite MDP over the state s = (H, K_prev, k) and solves for the joint policy by sparse value iteration. The state components and the two action components are:

  • H: entropy of the draft distribution
  • K_prev: delayed KL(q || p) from the latest available target-verification step
  • k: current position in the draft window
  • a_length ∈ {stop, continue}
  • a_verif ∈ {1.00, 1.22, 1.49, 1.82, 2.22, 2.71, 3.30, 4.00}
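
The state features and joint action set above can be sketched as follows. This is an illustration only: the bin edges, the function names, and the choice to form the joint action as a plain Cartesian product are assumptions, not the discretization used in jointadaspec/mdp/.

```python
import bisect
import itertools
import math

H_EDGES = [0.5, 1.0, 2.0]    # hypothetical entropy bin edges (nats)
KL_EDGES = [0.05, 0.2, 1.0]  # hypothetical KL bin edges (nats)
LENGTH_ACTIONS = ("stop", "continue")
VERIF_ACTIONS = (1.00, 1.22, 1.49, 1.82, 2.22, 2.71, 3.30, 4.00)


def entropy(q: list[float]) -> float:
    """Shannon entropy H(q) of the draft distribution, in nats."""
    return -sum(x * math.log(x) for x in q if x > 0.0)


def kl_divergence(q: list[float], p: list[float], eps: float = 1e-12) -> float:
    """KL(q || p) between draft q and target p; p is floored at eps to avoid log(0)."""
    return sum(qi * math.log(qi / max(pi, eps)) for qi, pi in zip(q, p) if qi > 0.0)


def discretize(h: float, k_prev: float, k: int) -> tuple[int, int, int]:
    """Map continuous features to a tabular state (H bin, K_prev bin, k)."""
    return (bisect.bisect_right(H_EDGES, h), bisect.bisect_right(KL_EDGES, k_prev), k)


# Assumed joint action set: every (a_length, a_verif) pair, 2 x 8 = 16 actions.
JOINT_ACTIONS = list(itertools.product(LENGTH_ACTIONS, VERIF_ACTIONS))

q = [0.7, 0.2, 0.1]
p = [0.6, 0.3, 0.1]
print(discretize(entropy(q), kl_divergence(q, p), k=2))
print(len(JOINT_ACTIONS))
```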

The implementation lives in jointadaspec/ and is split into:

  • jointadaspec/core/: features, verification rules, decoder base classes
  • jointadaspec/mdp/: discrete spaces, trace collection, MDP estimation, value iteration
  • jointadaspec/inference/: learned policy lookup and online JointAdaSpec decoder
  • jointadaspec/baselines/: target_only, speculative, fuzzy_sd, adaptive_length, and both cascade baselines
  • jointadaspec/analysis/: empirical checks for C1-C4 and N1-N2
  • jointadaspec/metrics/: speed, quality, Pareto helpers
  • jointadaspec/utils/: model loading, dataset loading, manifests, structured logging

The pipeline is stage-based:

  1. scripts/01_collect_traces.py collects exploratory SD traces on held-out prompts.
  2. scripts/02_solve_mdp.py estimates transitions/rewards, solves the joint MDP, and saves both cascade baselines.
  3. scripts/03_benchmark.py benchmarks JointAdaSpec against target-only, speculative, fuzzy-SD, adaptive-length, and the two cascades with seeded resume-safe JSONL output.
  4. scripts/04_verify_conditions.py checks empirical monotonicity/supermodularity conditions and writes diagnostic plots.
  5. scripts/05_write_manifest.py writes reproducibility manifests with git SHA, config YAML, seeds, and artifact hashes.
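
As a generic illustration of what stage 2 does, the sketch below runs tabular value iteration over a sparse transition table that stores only observed (state, action) outcomes. It is a textbook version under assumed names and an assumed discount factor, not the solver in scripts/02_solve_mdp.py.

```python
def value_iteration(transitions, gamma=0.95, tol=1e-8, max_iters=10_000):
    """Tabular value iteration with in-place (Gauss-Seidel) updates.

    transitions: {state: {action: [(next_state, probability, reward), ...]}}
    Only observed (state, action) pairs are stored, which keeps the table sparse.
    Returns (values, greedy_policy).
    """
    values = {s: 0.0 for s in transitions}

    def q_value(outcomes):
        # Expected one-step return; unseen next states are treated as value 0.
        return sum(p * (r + gamma * values.get(s_next, 0.0)) for s_next, p, r in outcomes)

    for _ in range(max_iters):
        delta = 0.0
        for s, actions in transitions.items():
            best = max(q_value(outcomes) for outcomes in actions.values())
            delta = max(delta, abs(best - values[s]))
            values[s] = best
        if delta < tol:  # converged: the largest update fell below the tolerance
            break

    policy = {
        s: max(actions, key=lambda a: q_value(actions[a]))
        for s, actions in transitions.items()
    }
    return values, policy


# Toy two-state MDP (hypothetical): continuing from state 0 earns reward 1 once.
toy = {
    0: {"continue": [(1, 1.0, 1.0)], "stop": [(0, 1.0, 0.0)]},
    1: {"stop": [(1, 1.0, 0.0)]},
}
values, policy = value_iteration(toy)
print(values, policy)
```

In the toy problem the greedy policy picks "continue" in state 0, because the immediate reward of 1 beats the discounted value of staying put.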

Minimal reproduction on the Qwen 14B / 0.5B setup:

make jointadaspec-full MODEL_PAIR=qwen14b_0p5b JOINTADA_MAX_TRACES=10 JOINTADA_MAX_SAMPLES=10 JOINTADA_N_SEEDS=1

Outputs include traces, policy NPZ files, cascade policies, seeded benchmark JSONL, conditions JSON, manifests, Pareto plots, threshold-surface heatmaps, and ablation charts.
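
For readers unfamiliar with artifact manifests, the following sketch shows one way to record SHA-256 hashes alongside a git SHA and seeds. The manifest layout, the function names, and the placeholder file are assumptions. The real schema is whatever scripts/05_write_manifest.py emits.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large trace or policy files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def build_manifest(artifact_dir: Path, git_sha: str, seeds: list[int]) -> dict:
    """Collect one hash per file under artifact_dir (hypothetical manifest layout)."""
    files = sorted(p for p in artifact_dir.rglob("*") if p.is_file())
    return {
        "git_sha": git_sha,
        "seeds": seeds,
        "artifacts": {p.relative_to(artifact_dir).as_posix(): sha256_file(p) for p in files},
    }


# Demo on a throwaway directory with a placeholder artifact.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "policy.npz").write_bytes(b"placeholder policy bytes")
    manifest = build_manifest(root, git_sha="0000000", seeds=[0])
    print(json.dumps(manifest, indent=2))
```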

Conservative quality-aware local rerun on Qwen 7B / 1.5B:

make jointadaspec-full MODEL_PAIR=qwen7b_1p5b_quality JOINTADA_MAX_TRACES=500 JOINTADA_MAX_SAMPLES=100 JOINTADA_N_SEEDS=3 JOINTADA_MAX_NEW_TOKENS=256

This config uses local checkpoints under models/ and adds state-dependent quality-risk reward shaping through quality_risk_K and quality_risk_k.

Held-out quality validation using the fixed 2026-05-05 policy:

make jointadaspec-quality-heldout-full JOINTADA_DATE=2026-05-08

The validation uses GSM8K test_start_index=100, 500 prompts, 3 seeds, and pre-registers cascade_verif_then_length as the primary quality decoder.

Latest Benchmark Snapshot

JointAdaSpec Result Status

Best substantive jointadaspec/ snapshot: 2026-04-14, Qwen2.5 7B -> 1.5B, RTX 5090.

Artifacts:

  • reports/jointadaspec_qwen_7b_1p5b_2026-04-14.md
  • reports/jointadaspec_qwen_runs_through_2026-04-28.md
  • reports/jointadaspec_quality_qwen7b_1p5b_quality_2026-05-05.md
  • outputs/jointadaspec_qwen_2026-04-14/01_traces_gsm8k/
  • outputs/jointadaspec_qwen_2026-04-14/02_solve/
  • outputs/jointadaspec_qwen_2026-04-14/03_bench_gsm8k/
  • outputs/jointadaspec_qwen_2026-04-14/04_bench_livecodebench/

Why this remains the headline benchmark:

  • 3000 collected traces
  • 479880 one-step transition records
  • kappa sweep: 0.0, 0.5, 1.0, 2.0, 5.0
  • 100 prompts on GSM8K and LiveCodeBench
  • value iteration converged for all saved policies

GSM8K throughput snapshot (100 prompts):

| Method | Speed (tok/s) | Acceptance | vs Vanilla |
|---|---|---|---|
| Vanilla AR | 29.38 | 0.000 | 1.000 |
| Fixed SD | 17.26 | 0.801 | 0.588 |
| Fuzzy SD (T=4.0) | 19.09 | 0.915 | 0.650 |
| JointAdaSpec | 18.99 | 0.951 | 0.646 |
| SpecDecPP | 17.81 | 0.835 | 0.606 |
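
The "vs Vanilla" column appears to be each method's speed divided by the Vanilla AR speed. The sketch below recomputes it from the rounded speeds in the table, so results can differ from the published column in the third decimal place. The function name is hypothetical.

```python
# Speeds (tok/s) copied from the GSM8K table above.
gsm8k_speed = {
    "Vanilla AR": 29.38,
    "Fixed SD": 17.26,
    "Fuzzy SD (T=4.0)": 19.09,
    "JointAdaSpec": 18.99,
    "SpecDecPP": 17.81,
}


def relative_throughput(speeds: dict[str, float], baseline: str = "Vanilla AR") -> dict[str, float]:
    """Return each method's throughput as a fraction of the baseline's throughput."""
    base = speeds[baseline]
    return {method: speed / base for method, speed in speeds.items()}


ratios = relative_throughput(gsm8k_speed)
for method, ratio in ratios.items():
    print(f"{method}: {ratio:.3f}")
```

A ratio below 1.0 means the method is slower than target-only decoding on this hardware.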

LiveCodeBench throughput snapshot (100 prompts):

| Method | Speed (tok/s) | Acceptance | vs Vanilla |
|---|---|---|---|
| Vanilla AR | 14.36 | 0.000 | 1.000 |
| Fixed SD | 8.05 | 0.659 | 0.560 |
| Fuzzy SD (T=4.0) | 9.62 | 0.843 | 0.670 |
| JointAdaSpec | 10.34 | 0.890 | 0.720 |
| SpecDecPP | 9.04 | 0.763 | 0.629 |

Current interpretation:

  • JointAdaSpec is operational end to end on the Qwen 7B -> 1.5B profile used for the first thesis traces.
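  • On this single-GPU setup and 7B -> 1.5B target/draft size ratio, it improves acceptance over the fixed-window baseline (0.951 vs 0.801 on GSM8K, 0.890 vs 0.659 on LiveCodeBench), but it does not yet outperform target-only decoding in throughput (0.646x and 0.720x of Vanilla AR).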
  • The current jointadaspec/ benchmark path writes resumable seeded JSONL with task-level GSM8K exact match and bootstrap confidence intervals, plus legacy run.jsonl prompt records for backward compatibility.
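
As an illustration of the bootstrap confidence intervals mentioned above, here is a minimal percentile-bootstrap sketch over per-prompt exact-match outcomes. The percentile method, the resample count, and the function name are assumptions; the benchmark script may compute its intervals differently.

```python
import random


def bootstrap_ci(outcomes, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap for the mean of 0/1 exact-match outcomes.

    Returns (mean, lower, upper) for a (1 - alpha) confidence interval.
    """
    rng = random.Random(seed)
    n = len(outcomes)
    # Resample prompts with replacement and record the exact-match rate each time.
    means = sorted(sum(rng.choices(outcomes, k=n)) / n for _ in range(n_resamples))
    lower = means[int((alpha / 2) * n_resamples)]
    upper = means[int((1 - alpha / 2) * n_resamples) - 1]
    return sum(outcomes) / n, lower, upper


# Hypothetical outcomes: 60 correct answers out of 100 prompts.
em, lo, hi = bootstrap_ci([1] * 60 + [0] * 40)
print(f"EM = {em:.3f}, 95% CI = [{lo:.3f}, {hi:.3f}]")
```

Fixing the seed keeps the interval reproducible across reruns.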

Latest tracked artifacts through 2026-05-05 include two smoke runs, one partial run, and one complete quality-aware hypothesis run:

| Run | Scope | Best observed line | Status |
|---|---|---|---|
| 2026-04-21, Qwen 7B -> 1.5B | schema/seeded smoke, 1 prompt, 1 seed, 8 new tokens | cascade_verif_then_length: 18.47 tok/s, 1.204x vs target-only | Complete smoke benchmark |
| 2026-04-28, Qwen 14B -> 0.5B | smoke, 5 prompts, 1 seed, 64 new tokens | cascade_verif_then_length: 13.15 tok/s, 0.725x; jointadaspec: 12.85 tok/s, 0.709x | Complete smoke benchmark |
| 2026-04-28, Qwen 7B -> 1.5B | full trace/solve attempt, 500 traces | policy and condition artifacts saved | Benchmark stage failed on HuggingFace SSL EOF |
| 2026-05-05, Qwen 7B -> 1.5B quality-aware | 500 traces, 100 GSM8K prompts, 3 seeds | cascade_verif_then_length: EM 0.6267; jointadaspec: EM 0.6167; target_only: EM 0.5700 | Complete quality hypothesis run |

For the 2026-04-28 Qwen 14B -> 0.5B smoke, target_only reached 18.13 tok/s. Condition checks passed c2, c3, n1, and n2; c1 and c4 failed, which is recorded as an empirical result rather than a runtime failure.

Historical sp_samp Snapshot

Latest historical sp_samp/ full Llama run: 2026-03-28-llama-48h-cgrid8 on RTX 5090.

Source reports:

  • reports/yandex_llama3_8b_3b_2026-03-28-llama-48h-cgrid8-gsm8k.md
  • reports/yandex_llama3_8b_3b_2026-03-28-llama-48h-cgrid8-livecodebench.md

GSM8K highlights (k=4):

| Method | Accuracy (%) | Speed (tok/s) |
|---|---|---|
| Baseline | 70.89 | 72.68 |
| Speculative | 71.89 | 40.68 |
| AutoJudge (t=0.14) | 78.67 | 45.98 |
| Top-K (all) | 75.67 | 59.29 |

LiveCodeBench highlights (throughput only):

| Method | Speed (tok/s) |
|---|---|
| Baseline | 71.52 |
| Speculative | 34.80 |
| AutoJudge (t=1.0) | 29.27 |
| Top-K (all) | 36.53 |

More context and historical runs: docs/RESULTS.md.

Quick Start (5 Minutes)

make setup
make check
make test
make bench-toy OUT=/tmp/bench_toy.jsonl

Optional tiny HF smoke:

make smoke-hf OUT=/tmp/smoke_hf.jsonl

JointAdaSpec toy validation:

.venv/bin/python -m pytest tests/test_features.py tests/test_verification.py tests/test_mdp_solver.py tests/test_inference.py tests/test_end_to_end.py -q

JointAdaSpec short smoke on the staged pipeline:

make jointadaspec-full MODEL_PAIR=qwen7b_1p5b JOINTADA_MAX_TRACES=10 JOINTADA_MAX_SAMPLES=10 JOINTADA_N_SEEDS=1

Reproduce Main Runs

Paper-style Qwen sweep:

make paper-eval

Local Qwen 7B/1.5B sweep:

make local-eval

Local Llama 8B/3B sweep:

bash scripts/run_llama3_8b_3b_eval.sh

JointAdaSpec staged long-run on Qwen2.5 7B/1.5B for both GSM8K and LiveCodeBench:

bash scripts/run_jointadaspec_qwen_longrun.sh

Recommended monitoring:

tmux new -s jointadaspec48h
bash scripts/run_jointadaspec_qwen_longrun.sh | tee -a logs/jointadaspec_$(date +%F).log

Validate any JSONL output:

.venv/bin/python scripts/validate_results_jsonl.py --path datasets/results.jsonl --strict

Project Structure

  • sp_samp/: core implementations and HF adapters
  • jointadaspec/: JointAdaSpec MDP training and inference stack
  • benchmarks/: benchmark entrypoint and result logging
  • configs/: model, method, and experiment presets
  • scripts/: orchestration, validation, and report generation
  • tests/: unit tests
  • reports/: tracked benchmark artifacts
  • datasets/: local datasets and run outputs (gitignored)

Method design notes:

  • docs/CONSENSUS_AUTOJUDGE.md - disagreement-aware two-draft approximate decoding design

Constraints and Repro Notes

  • Draft and target must use tokenizer-compatible vocab mapping.
  • AutoJudge paper C-grid policy is 1e-7..1e0 (8 values).
  • Reusing the same output file enables automatic resume by resume_key.
  • Llama checkpoints in this environment are gated on Hugging Face; use experiments/qwen25_7b_1p5b_jointadaspec for ungated JointAdaSpec smoke and long runs.
  • JointAdaSpec trace collection is intentionally exact and CPU-heavy at the control layer; long runs should be started in tmux.
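
The resume behavior noted above can be sketched as follows: read the existing JSONL, collect the resume_key of every finished record, and skip those jobs on restart. Only the resume_key field name comes from this README. The key format, the job structure, and the function names are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def completed_keys(path: Path) -> set:
    """Collect resume_key values from an existing JSONL file, if there is one."""
    done = set()
    if not path.exists():
        return done
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate a truncated last line left by an interrupted run
            if isinstance(record, dict) and record.get("resume_key") is not None:
                done.add(record["resume_key"])
    return done


def pending_jobs(jobs: list[dict], path: Path) -> list[dict]:
    """Keep only jobs whose resume_key has not been written yet."""
    done = completed_keys(path)
    return [job for job in jobs if job["resume_key"] not in done]


# Demo: one finished record plus one line cut off mid-write.
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "results.jsonl"
    out.write_text(
        '{"resume_key": "gsm8k:0:seed0", "speed": 1.0}\n{"resume_key": "gsm8k:1:se',
        encoding="utf-8",
    )
    jobs = [{"resume_key": f"gsm8k:{i}:seed0"} for i in range(3)]
    remaining = [job["resume_key"] for job in pending_jobs(jobs, out)]
    print(remaining)
```

Skipping unparseable lines means a job interrupted mid-write is simply rerun rather than counted as done.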

For Reviewers and Contributors

  • Contribution guide: CONTRIBUTING.md
  • Open issues and feature proposals: GitHub issue templates
  • Current priorities: docs/ROADMAP.md
  • Repository presentation checklist: docs/GITHUB_SETUP.md

License

MIT. See LICENSE.

About

Adaptive speculative decoding for LLM inference latency optimization
