Pet project and research playground for LLM decoding acceleration.
This repository compares exact and approximate decoding strategies under a single benchmark harness, with reproducible configs, JSONL outputs, and report scripts.
`sp_samp/` remains the historical benchmark stack. `jointadaspec/` is the new thesis-focused implementation of JointAdaSpec: joint adaptive control of draft length and verification threshold through a tabular MDP.
- Side-by-side comparison of Baseline, Speculative Sampling, AutoJudge, Consensus AutoJudge, Top-K, SpecExec, and JointAdaSpec.
- Paper-aligned AutoJudge implementation (GSM8K label mining + LogisticRegression calibration).
- New `jointadaspec/` stack for joint adaptive control of draft length and verification threshold via a tabular MDP.
- Long-run friendly workflow (resume keys, checkpoints, strict result schema validation).
- Real benchmark reports are versioned in `reports/`.
| Method | Exact vs target distribution | Main idea |
|---|---|---|
| `baseline` | exact | Target-only decoding |
| `speculative` | exact | Draft proposes, target verifies |
| `autojudge` | approximate | Judge can accept some mismatches |
| `consensus_autojudge` | approximate | Two drafts + consensus gate decide accept / escalate / fallback |
| `topk` | approximate | Accept mismatch if target token in top-k |
| `specexec` | exact | Parallel speculative branches + cache reuse |
| `jointadaspec` | approximate | Policy jointly chooses draft stop/continue and fuzzy verification threshold |
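The "exact" label on the `speculative` row comes from the standard speculative-sampling acceptance rule: accept a draft token `x ~ q` with probability `min(1, p(x) / q(x))`, otherwise resample from the normalized residual `max(0, p - q)`. The sketch below shows that rule over toy dict distributions; it is not the `sp_samp/` implementation. `topk_accept` illustrates one reading of the `topk` row (keep the draft token if it falls in the target's top-k), which is an assumption about how that method is defined here.

```python
import random
from typing import Dict

Dist = Dict[int, float]  # token id -> probability (toy stand-in for model logits)


def sample_from(dist: Dist, rng: random.Random) -> int:
    """Draw one token from an (unnormalized) distribution."""
    r = rng.random() * sum(dist.values())
    acc = 0.0
    for token, weight in dist.items():
        acc += weight
        if r < acc:
            return token
    return max(dist, key=dist.get)  # guard against floating-point round-off


def speculative_step(p: Dist, q: Dist, rng: random.Random) -> int:
    """One exact speculative-sampling step: draft from q, verify against target p."""
    x = sample_from(q, rng)
    if rng.random() < min(1.0, p.get(x, 0.0) / q[x]):
        return x  # draft token accepted
    # Rejected: resample from the normalized residual max(0, p - q).
    residual = {t: max(0.0, pt - q.get(t, 0.0)) for t, pt in p.items()}
    return sample_from(residual, rng)


def topk_accept(draft_token: int, p: Dist, k: int) -> bool:
    """Hypothetical top-k relaxation: accept if the draft token is in the target's top-k."""
    top = sorted(p, key=p.get, reverse=True)[:k]
    return draft_token in top
```

Sampling `speculative_step` many times reproduces the target distribution `p` regardless of `q`; the draft only changes how often the cheap path is taken. The top-k relaxation gives up that guarantee in exchange for fewer rejections.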
JointAdaSpec formulates speculative decoding control as a finite MDP over the state s = (H, K_prev, k) and solves the joint policy by sparse value iteration.
- `H`: entropy of the draft distribution
- `K_prev`: delayed `KL(q || p)` from the latest available target-verification step
- `k`: current position in the draft window
- `a_length ∈ {stop, continue}`
- `a_verif ∈ {1.00, 1.22, 1.49, 1.82, 2.22, 2.71, 3.30, 4.00}`
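A minimal sketch of value iteration over a sparse transition table, with the joint action space built as the product of `a_length` and `a_verif` (2 × 8 = 16 actions). The dict-of-lists table layout, the discount factor `gamma`, and the integer action encoding are assumptions for illustration, not the `jointadaspec/mdp/` API.

```python
from itertools import product
from typing import Dict, List, Tuple

# Joint action space from the lists above: a_length x a_verif = 16 actions.
LENGTH_ACTIONS = ("stop", "continue")
VERIF_THRESHOLDS = (1.00, 1.22, 1.49, 1.82, 2.22, 2.71, 3.30, 4.00)
JOINT_ACTIONS = list(product(LENGTH_ACTIONS, VERIF_THRESHOLDS))

# Hypothetical sparse layout: (state, action) -> [(next_state, probability, reward), ...].
# (state, action) pairs never observed in the traces are simply absent.
Transitions = Dict[Tuple[int, int], List[Tuple[int, float, float]]]


def value_iteration(
    n_states: int,
    transitions: Transitions,
    gamma: float = 0.95,  # assumed discount factor
    tol: float = 1e-8,
    max_iter: int = 10_000,
) -> Tuple[List[float], List[int]]:
    """Synchronous value iteration; states with no observed action keep V = 0."""
    # Index observed actions per state once, so sweeps touch only stored pairs.
    actions_by_state: Dict[int, List[int]] = {s: [] for s in range(n_states)}
    for s, a in transitions:
        actions_by_state[s].append(a)

    def q_value(s: int, a: int, v: List[float]) -> float:
        return sum(prob * (r + gamma * v[s2]) for s2, prob, r in transitions[(s, a)])

    values = [0.0] * n_states
    for _ in range(max_iter):
        new_values = [
            max((q_value(s, a, values) for a in actions_by_state[s]), default=0.0)
            for s in range(n_states)
        ]
        delta = max(abs(x - y) for x, y in zip(new_values, values))
        values = new_values
        if delta < tol:
            break

    # Greedy policy; -1 marks states without any observed action.
    policy = [
        max(acts, key=lambda a, s=s: q_value(s, a, values)) if acts else -1
        for s, acts in actions_by_state.items()
    ]
    return values, policy
```

A solved action index `a` maps back to `JOINT_ACTIONS[a]`, i.e. an `(a_length, a_verif)` pair.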
The implementation lives in `jointadaspec/` and is split into:
- `jointadaspec/core/`: features, verification rules, decoder base classes
- `jointadaspec/mdp/`: discrete spaces, trace collection, MDP estimation, value iteration
- `jointadaspec/inference/`: learned policy lookup and online JointAdaSpec decoder
- `jointadaspec/baselines/`: `target_only`, `speculative`, `fuzzy_sd`, `adaptive_length`, and both cascade baselines
- `jointadaspec/analysis/`: empirical checks for C1-C4 and N1-N2
- `jointadaspec/metrics/`: speed, quality, Pareto helpers
- `jointadaspec/utils/`: model loading, dataset loading, manifests, structured logging
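Online inference reduces to discretizing the live state `(H, K_prev, k)` and looking up the solved joint action. A minimal sketch, assuming bin edges chosen by hand and a dict-backed table; `StateBinner`, `PolicyTable`, the edge values, and the `("stop", 1.00)` fallback for unseen states are all hypothetical, not the `jointadaspec/inference/` API.

```python
from bisect import bisect_right
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

StateKey = Tuple[int, int, int]
Action = Tuple[str, float]  # (a_length, a_verif)


@dataclass(frozen=True)
class StateBinner:
    """Maps continuous (H, K_prev) and integer k onto a discrete state key."""
    h_edges: Sequence[float]   # hypothetical entropy bin edges (ascending)
    kl_edges: Sequence[float]  # hypothetical K_prev bin edges (ascending)
    max_k: int                 # positions at or beyond the cap share one bin

    def __call__(self, h: float, k_prev: float, k: int) -> StateKey:
        return (
            bisect_right(self.h_edges, h),
            bisect_right(self.kl_edges, k_prev),
            min(k, self.max_k),
        )


@dataclass
class PolicyTable:
    binner: StateBinner
    table: Dict[StateKey, Action]
    default: Action = ("stop", 1.00)  # hypothetical fallback for states absent from the table

    def act(self, h: float, k_prev: float, k: int) -> Action:
        return self.table.get(self.binner(h, k_prev, k), self.default)
```

With `len(edges)` edges per feature, `bisect_right` yields `len(edges) + 1` bins, so values below the first edge and above the last one each get their own bin.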
The pipeline is stage-based:
- `scripts/01_collect_traces.py` collects exploratory SD traces on held-out prompts.
- `scripts/02_solve_mdp.py` estimates transitions/rewards, solves the joint MDP, and saves both cascade baselines.
- `scripts/03_benchmark.py` benchmarks JointAdaSpec against target-only, speculative, fuzzy-SD, adaptive-length, and the two cascades, writing seeded, resume-safe JSONL output.
- `scripts/04_verify_conditions.py` checks empirical monotonicity/supermodularity conditions and writes diagnostic plots.
- `scripts/05_write_manifest.py` writes reproducibility manifests with git SHA, config YAML, seeds, and artifact hashes.
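The individual conditions C1-C4 and N1-N2 are not spelled out in this README, so the following is only a generic sketch of the two kinds of empirical test the condition stage mentions: monotonicity along one axis, and supermodularity (increasing differences) of a 2-D table such as values indexed by a state-feature bin and a threshold index. Function names and the `tol` parameter are hypothetical.

```python
from typing import List, Sequence, Tuple


def monotone_violations(values: Sequence[float], tol: float = 0.0) -> List[int]:
    """Indices i where values[i + 1] < values[i] - tol, i.e. non-decreasing order breaks."""
    return [i for i in range(len(values) - 1) if values[i + 1] < values[i] - tol]


def supermodularity_violations(
    f: Sequence[Sequence[float]], tol: float = 0.0
) -> List[Tuple[int, int]]:
    """Cells (i, j) whose adjacent 2x2 cross-difference is negative.

    On a 2-D grid, f is supermodular (has increasing differences) iff every
    f[i+1][j+1] - f[i+1][j] - f[i][j+1] + f[i][j] >= 0, because larger
    rectangles telescope into sums of adjacent ones.
    """
    bad = []
    for i in range(len(f) - 1):
        for j in range(len(f[i]) - 1):
            cross = f[i + 1][j + 1] - f[i + 1][j] - f[i][j + 1] + f[i][j]
            if cross < -tol:
                bad.append((i, j))
    return bad
```

Reporting the violating cells rather than a single boolean makes it easier to see whether a failed check is a few noisy sparse cells or a systematic pattern.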
Minimal reproduction on the Qwen 14B / 0.5B setup:
```bash
make jointadaspec-full MODEL_PAIR=qwen14b_0p5b JOINTADA_MAX_TRACES=10 JOINTADA_MAX_SAMPLES=10 JOINTADA_N_SEEDS=1
```

Outputs include traces, policy NPZ files, cascade policies, seeded benchmark JSONL, conditions JSON, manifests, Pareto plots, threshold-surface heatmaps, and ablation charts.
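Policies are stored as NPZ files. A minimal round-trip sketch with `numpy.savez` and `numpy.load` for a tabular policy; the array names (`length_action`, `verif_index`, `value`) and the `(H bins, K_prev bins, k positions)` shape are hypothetical, not the repository's actual file layout.

```python
import io

import numpy as np

VERIF_GRID = np.array([1.00, 1.22, 1.49, 1.82, 2.22, 2.71, 3.30, 4.00])


def save_policy(file, length_action: np.ndarray, verif_index: np.ndarray, value: np.ndarray) -> None:
    """Write one solved policy to an .npz archive (array names are hypothetical)."""
    np.savez(file, length_action=length_action, verif_index=verif_index, value=value)


def load_policy(file) -> dict:
    """Read every array from an .npz archive into a plain dict."""
    with np.load(file) as data:
        return {name: data[name] for name in data.files}


if __name__ == "__main__":
    shape = (4, 3, 5)  # hypothetical (H bins, K_prev bins, k positions)
    rng = np.random.default_rng(0)
    buf = io.BytesIO()  # stands in for a path such as policy.npz
    save_policy(buf, rng.integers(0, 2, shape), rng.integers(0, 8, shape), rng.random(shape))
    buf.seek(0)
    policy = load_policy(buf)
    thresholds = VERIF_GRID[policy["verif_index"]]  # index array -> threshold values
    print(sorted(policy), thresholds.shape)
```

Storing threshold indices rather than float thresholds keeps the table exact and lets the grid change without rewriting policies.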
Conservative quality-aware local rerun on Qwen 7B / 1.5B:
```bash
make jointadaspec-full MODEL_PAIR=qwen7b_1p5b_quality JOINTADA_MAX_TRACES=500 JOINTADA_MAX_SAMPLES=100 JOINTADA_N_SEEDS=3 JOINTADA_MAX_NEW_TOKENS=256
```

This config uses local checkpoints under `models/` and adds state-dependent quality-risk reward shaping through `quality_risk_K` and `quality_risk_k`.
Held-out quality validation using the fixed 2026-05-05 policy:
```bash
make jointadaspec-quality-heldout-full JOINTADA_DATE=2026-05-08
```

The validation uses GSM8K `test_start_index=100`, 500 prompts, and 3 seeds, and pre-registers `cascade_verif_then_length` as the primary quality decoder.
Best substantive `jointadaspec/` snapshot: 2026-04-14, Qwen2.5 7B -> 1.5B, RTX 5090.
Artifacts:
- `reports/jointadaspec_qwen_7b_1p5b_2026-04-14.md`
- `reports/jointadaspec_qwen_runs_through_2026-04-28.md`
- `reports/jointadaspec_quality_qwen7b_1p5b_quality_2026-05-05.md`
- `outputs/jointadaspec_qwen_2026-04-14/01_traces_gsm8k/`
- `outputs/jointadaspec_qwen_2026-04-14/02_solve/`
- `outputs/jointadaspec_qwen_2026-04-14/03_bench_gsm8k/`
- `outputs/jointadaspec_qwen_2026-04-14/04_bench_livecodebench/`
Why this remains the headline benchmark:
- 3000 collected traces
- 479880 one-step transition records
- `kappa` sweep: 0.0, 0.5, 1.0, 2.0, 5.0
- 100 prompts on GSM8K and LiveCodeBench
- value iteration converged for all saved policies
GSM8K throughput snapshot (100 prompts):
| Method | Speed (tok/s) | Acceptance | Speed vs Vanilla |
|---|---|---|---|
| Vanilla AR | 29.38 | 0.000 | 1.000 |
| Fixed SD | 17.26 | 0.801 | 0.588 |
| Fuzzy SD (T=4.0) | 19.09 | 0.915 | 0.650 |
| JointAdaSpec | 18.99 | 0.951 | 0.646 |
| SpecDecPP | 17.81 | 0.835 | 0.606 |
LiveCodeBench throughput snapshot (100 prompts):
| Method | Speed (tok/s) | Acceptance | Speed vs Vanilla |
|---|---|---|---|
| Vanilla AR | 14.36 | 0.000 | 1.000 |
| Fixed SD | 8.05 | 0.659 | 0.560 |
| Fuzzy SD (T=4.0) | 9.62 | 0.843 | 0.670 |
| JointAdaSpec | 10.34 | 0.890 | 0.720 |
| SpecDecPP | 9.04 | 0.763 | 0.629 |
Current interpretation:
- `JointAdaSpec` is operational end to end on the Qwen 7B -> 1.5B profile used for the first thesis traces.
- On this single-GPU setup and draft/target size ratio, it improves acceptance over the fixed-window baselines, but it does not yet outperform target-only decoding in throughput.
- The current `jointadaspec/` benchmark path writes resumable seeded JSONL with task-level GSM8K exact match and bootstrap confidence intervals, plus legacy `run.jsonl` prompt records for backward compatibility.
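A minimal percentile-bootstrap sketch for the confidence interval of a mean per-prompt score such as 0/1 exact match. It assumes prompts are resampled with replacement; the resample count, `alpha`, and resampling unit (prompt vs. seed) used by `jointadaspec/` may differ, and `bootstrap_ci` is a hypothetical helper name.

```python
import random
from statistics import mean
from typing import Sequence, Tuple


def bootstrap_ci(
    scores: Sequence[float],
    n_resamples: int = 2000,
    alpha: float = 0.05,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Return (point estimate, lower, upper) for the mean of per-prompt scores."""
    if not scores:
        raise ValueError("need at least one score")
    rng = random.Random(seed)
    n = len(scores)
    # Resample prompts with replacement and record the mean of each resample.
    stats = sorted(mean(rng.choices(scores, k=n)) for _ in range(n_resamples))
    lo_idx = int((alpha / 2) * n_resamples)
    hi_idx = n_resamples - 1 - lo_idx  # symmetric percentile cut
    return mean(scores), stats[lo_idx], stats[hi_idx]
```

Fixing `seed` keeps the interval reproducible across reruns, which matters when reports are regenerated from the same JSONL.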
Latest tracked artifacts through 2026-05-05 include smoke, partial, and one complete quality-aware hypothesis run:
| Run | Scope | Best observed line | Status |
|---|---|---|---|
| 2026-04-21, Qwen 7B -> 1.5B | schema/seeded smoke, 1 prompt, 1 seed, 8 new tokens | `cascade_verif_then_length`: 18.47 tok/s, 1.204x vs target-only | Complete smoke benchmark |
| 2026-04-28, Qwen 14B -> 0.5B | smoke, 5 prompts, 1 seed, 64 new tokens | `cascade_verif_then_length`: 13.15 tok/s, 0.725x; `jointadaspec`: 12.85 tok/s, 0.709x | Complete smoke benchmark |
| 2026-04-28, Qwen 7B -> 1.5B | full trace/solve attempt, 500 traces | policy and condition artifacts saved | Benchmark stage failed on Hugging Face SSL EOF |
| 2026-05-05, Qwen 7B -> 1.5B quality-aware | 500 traces, 100 GSM8K prompts, 3 seeds | `cascade_verif_then_length`: EM 0.6267; `jointadaspec`: EM 0.6167; `target_only`: EM 0.5700 | Complete quality hypothesis run |
For the 2026-04-28 Qwen 14B -> 0.5B smoke, `target_only` reached 18.13 tok/s. Condition checks passed C2, C3, N1, and N2; C1 and C4 failed, which is recorded as an empirical result rather than a runtime failure.
Latest historical `sp_samp/` full Llama run: `2026-03-28-llama-48h-cgrid8` on RTX 5090.
Source reports:
- `reports/yandex_llama3_8b_3b_2026-03-28-llama-48h-cgrid8-gsm8k.md`
- `reports/yandex_llama3_8b_3b_2026-03-28-llama-48h-cgrid8-livecodebench.md`
GSM8K highlights (k=4):
| Method | Accuracy (%) | Speed (tok/s) |
|---|---|---|
| Baseline | 70.89 | 72.68 |
| Speculative | 71.89 | 40.68 |
| AutoJudge (t=0.14) | 78.67 | 45.98 |
| Top-K (all) | 75.67 | 59.29 |
LiveCodeBench highlights (throughput only):
| Method | Speed (tok/s) |
|---|---|
| Baseline | 71.52 |
| Speculative | 34.80 |
| AutoJudge (t=1.0) | 29.27 |
| Top-K (all) | 36.53 |
More context and historical runs: `docs/RESULTS.md`.
```bash
make setup
make check
make test
make bench-toy OUT=/tmp/bench_toy.jsonl
```

Optional tiny HF smoke:

```bash
make smoke-hf OUT=/tmp/smoke_hf.jsonl
```

JointAdaSpec toy validation:

```bash
.venv/bin/python -m pytest tests/test_features.py tests/test_verification.py tests/test_mdp_solver.py tests/test_inference.py tests/test_end_to_end.py -q
```

JointAdaSpec short smoke on the staged pipeline:

```bash
make jointadaspec-full MODEL_PAIR=qwen7b_1p5b JOINTADA_MAX_TRACES=10 JOINTADA_MAX_SAMPLES=10 JOINTADA_N_SEEDS=1
```

Paper-style Qwen sweep:

```bash
make paper-eval
```

Local Qwen 7B/1.5B sweep:

```bash
make local-eval
```

Local Llama 8B/3B sweep:

```bash
bash scripts/run_llama3_8b_3b_eval.sh
```

JointAdaSpec staged long-run on Qwen2.5 7B/1.5B for both GSM8K and LiveCodeBench:

```bash
bash scripts/run_jointadaspec_qwen_longrun.sh
```

Recommended monitoring:

```bash
tmux new -s jointadaspec48h
# inside the tmux session; tee needs the logs/ directory to exist
mkdir -p logs
bash scripts/run_jointadaspec_qwen_longrun.sh | tee -a logs/jointadaspec_$(date +%F).log
```

Validate any JSONL output:

```bash
.venv/bin/python scripts/validate_results_jsonl.py --path datasets/results.jsonl --strict
```

- `sp_samp/`: core implementations and HF adapters
- `jointadaspec/`: JointAdaSpec MDP training and inference stack
- `benchmarks/`: benchmark entrypoint and result logging
- `configs/`: model, method, and experiment presets
- `scripts/`: orchestration, validation, and report generation
- `tests/`: unit tests
- `reports/`: tracked benchmark artifacts
- `datasets/`: local datasets and run outputs (gitignored)
Method design notes:
- `docs/CONSENSUS_AUTOJUDGE.md`: disagreement-aware two-draft approximate decoding design
- Draft and target must use tokenizer-compatible vocab mapping.
- The AutoJudge paper's C grid for LogisticRegression is `1e-7..1e0` (8 values, one per power of ten).
- Reusing the same output file enables automatic resume by `resume_key`.
- Llama checkpoints in this environment are gated on Hugging Face; use `experiments/qwen25_7b_1p5b_jointadaspec` for ungated JointAdaSpec smoke and long runs.
- JointAdaSpec trace collection is intentionally exact and CPU-heavy at the control layer; long runs should be started in `tmux`.
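Automatic resume by `resume_key` amounts to skipping any job whose key already appears in the output JSONL and appending only new records. A minimal sketch under stated assumptions: the key format `method|prompt_id|seed`, the four required fields, and the helper names are hypothetical; the real schema is the one enforced by `scripts/validate_results_jsonl.py --strict`.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, Iterable, Set, Tuple

# Hypothetical minimal schema, for illustration only.
REQUIRED_KEYS = {"resume_key", "method", "prompt_id", "seed"}


def make_resume_key(method: str, prompt_id: str, seed: int) -> str:
    return f"{method}|{prompt_id}|{seed}"


def load_done_keys(path: Path) -> Set[str]:
    """Collect resume keys already present; fail loudly on malformed records."""
    done: Set[str] = set()
    if not path.exists():
        return done
    with path.open(encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            record = json.loads(line)
            missing = REQUIRED_KEYS - record.keys()
            if missing:
                raise ValueError(f"{path}:{line_no}: missing keys {sorted(missing)}")
            done.add(record["resume_key"])
    return done


def run_pending(
    path: Path,
    jobs: Iterable[Tuple[str, str, int]],
    run_fn: Callable[[str, str, int], Dict],
) -> int:
    """Run only jobs whose resume key is not yet in `path`; return how many ran."""
    done = load_done_keys(path)
    ran = 0
    with path.open("a", encoding="utf-8") as fh:
        for method, prompt_id, seed in jobs:
            key = make_resume_key(method, prompt_id, seed)
            if key in done:
                continue
            record = {"resume_key": key, "method": method, "prompt_id": prompt_id, "seed": seed}
            record.update(run_fn(method, prompt_id, seed))
            fh.write(json.dumps(record) + "\n")
            fh.flush()  # keep completed records on disk if a long run is interrupted
            done.add(key)
            ran += 1
    return ran


if __name__ == "__main__":
    def fake_run(method: str, prompt_id: str, seed: int) -> Dict:
        return {"tokens_per_s": 1.0}  # stand-in for a real benchmark call

    with tempfile.TemporaryDirectory() as tmp:
        out = Path(tmp) / "results.jsonl"
        jobs = [("speculative", "gsm8k-0", 0), ("speculative", "gsm8k-1", 0)]
        print(run_pending(out, jobs, fake_run))  # 2: fresh file
        print(run_pending(out, jobs + [("topk", "gsm8k-0", 0)], fake_run))  # 1: only the new job
```

Appending one flushed line per finished job means a crash loses at most the job in flight, and rerunning the same command picks up where it stopped.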
- Contribution guide: `CONTRIBUTING.md`
- Open issues and feature proposals: GitHub issue templates
- Current priorities: `docs/ROADMAP.md`
- Repository presentation checklist: `docs/GITHUB_SETUP.md`
MIT. See LICENSE.