othakkar/tinformer

Tinformer

A tiny decoder-only transformer in JAX where every component, from scaled dot-product attention to autoregressive generation, is implemented from scratch. Supports Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) via a configurable num_kv_heads parameter. Built to trace every tensor shape through every layer, not to produce coherent text.
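
The num_kv_heads switch can be sketched as follows. This is a minimal standalone example, not the repo's attention.py. The function name, the (seq_len, num_heads, head_dim) layout and the convention that consecutive query heads share a K/V head are all assumptions. Each K/V head is repeated across its group of query heads, and the rest is ordinary causal scaled dot-product attention.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v, num_kv_heads):
    """Causal scaled dot-product attention with shared K/V heads (sketch).

    q:    (seq_len, num_heads, head_dim)
    k, v: (seq_len, num_kv_heads, head_dim)
    num_kv_heads == num_heads -> MHA, == 1 -> MQA, in between -> GQA.
    """
    seq_len, num_heads, head_dim = q.shape
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads

    # Assumed grouping: query heads [g*group_size, (g+1)*group_size) share K/V head g.
    k = jnp.repeat(k, group_size, axis=1)  # -> (seq_len, num_heads, head_dim)
    v = jnp.repeat(v, group_size, axis=1)

    # Per-head scores scaled by sqrt(head_dim): (num_heads, seq_len, seq_len)
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(head_dim)

    # Causal mask: query position i may only attend to key positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)

    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)  # (seq_len, num_heads, head_dim)


# Usage: 8 query heads sharing 2 K/V heads (GQA).
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
seq_len, num_heads, num_kv_heads, head_dim = 5, 8, 2, 16
q = jax.random.normal(kq, (seq_len, num_heads, head_dim))
k = jax.random.normal(kk, (seq_len, num_kv_heads, head_dim))
v = jax.random.normal(kv, (seq_len, num_kv_heads, head_dim))
out = grouped_query_attention(q, k, v, num_kv_heads)
print(out.shape)  # (5, 8, 16)
```

Setting num_kv_heads equal to num_heads recovers standard multi-head attention; as num_kv_heads drops, only the K/V tensors (and any cache built from them) shrink, while the number of query heads stays the same. Materializing the repeat with jnp.repeat keeps the sketch readable; it is not the only way to broadcast shared heads.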

Tags

| Tag | Description | Article |
|---|---|---|
| tinformer-from-scratch | Base transformer implementation | Tinformer: Building a Tiny Transformer from Scratch in JAX |
| kv-caching | KV caching with benchmarks | KV Caching from Scratch in JAX |
| mqa-gqa | Multi-Query & Grouped-Query Attention | MQA and GQA: Shrinking the KV Cache in Tinformer |
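
Why fewer K/V heads shrinks the cache comes down to arithmetic: per layer, the cache stores one key vector and one value vector per position per K/V head, so its size scales linearly with num_kv_heads. The sketch below assumes float32 storage and a hypothetical configuration (4 layers, 8 query heads of size 64, 1024 cached positions); the helper name kv_cache_bytes is illustrative and not part of Tinformer.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=4):
    """Memory held by a KV cache; bytes_per_elem=4 assumes float32."""
    # Factor of 2: one cached tensor for keys and one for values, per layer.
    return 2 * num_layers * batch_size * seq_len * num_kv_heads * head_dim * bytes_per_elem


# Hypothetical config: 4 layers, 8 query heads of size 64, 1024 cached positions.
for name, kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mib = kv_cache_bytes(num_layers=4, num_kv_heads=kv_heads, head_dim=64, seq_len=1024) / 2**20
    print(f"{name}: num_kv_heads={kv_heads} -> {mib:.0f} MiB")
# MHA: num_kv_heads=8 -> 16 MiB
# GQA: num_kv_heads=2 -> 4 MiB
# MQA: num_kv_heads=1 -> 2 MiB
```

With 8 query heads, GQA with 2 K/V heads cuts the cache by 4x and MQA by 8x relative to MHA, while the number of query heads is unchanged.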

Repo structure

.gitignore
LICENSE
README.md
generate.py                          # Naive + cached autoregressive generation
requirements.txt
src/
  config.py                          # TinformerConfig dataclass (incl. num_kv_heads for MQA/GQA)
  attention.py                       # Scaled dot-product attention + MHA/MQA/GQA (with KV cache)
  layernorm.py                       # Layer normalization
  ffn.py                             # Feed-forward network
  decoder.py                         # Decoder block (LN → MHA → residual → LN → FFN → residual)
  tinformer.py                       # Full model with KV cache support
benchmarks/
  benchmark_kv_caching.py            # Naive vs cached generation benchmarks
tests/
  test_shapes.py                     # Tensor shape verification
  test_causal_masking.py             # Causal mask validation
  test_attention_stability.py        # Numerical stability tests
  test_cached_generate.py            # KV cache correctness tests
  test_mqa_gqa.py                    # MQA/GQA configuration tests
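
The pre-LN ordering listed for decoder.py (LN → MHA → residual → LN → FFN → residual) can be sketched as a standalone block. Assumptions, none taken from the repo: plain multi-head attention (num_kv_heads == num_heads), a GELU feed-forward activation, the parameter-dict keys, and the hypothetical init_block helper. Shapes follow a single unbatched sequence of shape (seq_len, d_model).

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each position's feature vector to zero mean and unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta


def ffn(x, w1, b1, w2, b2):
    # Position-wise MLP; GELU is an assumed activation.
    return jax.nn.gelu(x @ w1 + b1) @ w2 + b2


def causal_mha(x, wq, wk, wv, wo, num_heads):
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):  # (seq_len, d_model) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    weights = jax.nn.softmax(jnp.where(mask, scores, -jnp.inf), axis=-1)
    out = (weights @ v).transpose(1, 0, 2).reshape(seq_len, d_model)
    return out @ wo


def decoder_block(x, p, num_heads):
    # Pre-LN layout: LN -> MHA -> residual, then LN -> FFN -> residual.
    x = x + causal_mha(layer_norm(x, p["ln1_g"], p["ln1_b"]),
                       p["wq"], p["wk"], p["wv"], p["wo"], num_heads)
    x = x + ffn(layer_norm(x, p["ln2_g"], p["ln2_b"]),
                p["w1"], p["b1"], p["w2"], p["b2"])
    return x


def init_block(key, d_model, d_ff):
    # Hypothetical initializer: scaled Gaussian weights, unit/zero LN params.
    keys = jax.random.split(key, 6)

    def dense(k, n_in, n_out):
        return jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)

    return {
        "ln1_g": jnp.ones(d_model), "ln1_b": jnp.zeros(d_model),
        "ln2_g": jnp.ones(d_model), "ln2_b": jnp.zeros(d_model),
        "wq": dense(keys[0], d_model, d_model), "wk": dense(keys[1], d_model, d_model),
        "wv": dense(keys[2], d_model, d_model), "wo": dense(keys[3], d_model, d_model),
        "w1": dense(keys[4], d_model, d_ff), "b1": jnp.zeros(d_ff),
        "w2": dense(keys[5], d_ff, d_model), "b2": jnp.zeros(d_model),
    }


params = init_block(jax.random.PRNGKey(0), d_model=32, d_ff=128)
x = jax.random.normal(jax.random.PRNGKey(1), (10, 32))  # (seq_len, d_model)
y = decoder_block(x, params, num_heads=4)
print(y.shape)  # (10, 32)
```

Each sublayer reads a normalized copy of x and adds its result back to the unnormalized residual stream, so input and output shapes match and blocks can be stacked directly.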

Quickstart

git clone https://github.com/othakkar/tinformer.git
cd tinformer
pip install -r requirements.txt
python -m generate

Run benchmarks

python -m benchmarks.benchmark_kv_caching
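
A rough way to see what the naive-vs-cached comparison measures is to count the token positions pushed through the model. This is an illustrative count under stated assumptions (prefill over the prompt produces the first new token; when cached, each later step feeds only the newest token), with a hypothetical helper name; it does not reproduce the benchmark or its results.

```python
def positions_processed(prompt_len, new_tokens, cached):
    """Count token positions fed through the model to generate new_tokens tokens."""
    if cached:
        # Prefill the prompt once (this yields the first new token), then
        # run one position per remaining token, reusing cached K/V.
        return prompt_len + (new_tokens - 1)
    # Naive: each step re-runs the whole sequence generated so far.
    return sum(prompt_len + t for t in range(new_tokens))


print(positions_processed(8, 32, cached=False))  # 752
print(positions_processed(8, 32, cached=True))   # 39
```

This is not a FLOP count: with a cache, each decode step still attends over every cached position, so per-step attention cost grows with context length. The saving comes from not recomputing the prefix's projections, attention rows and FFN activations at every step.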

Run tests

pytest tests/
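
The contents of test_cached_generate.py are not shown here, but the core property a KV-cache correctness test can check is easy to sketch: decoding one position at a time against a growing cache should reproduce the rows of full causal attention. The sketch below assumes single-head attention, and the function names are hypothetical, not the repo's.

```python
import jax
import jax.numpy as jnp


def full_causal_attention(q, k, v):
    # Single-head causal attention over the whole sequence at once.
    t = q.shape[0]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    mask = jnp.tril(jnp.ones((t, t), dtype=bool))
    return jax.nn.softmax(jnp.where(mask, scores, -jnp.inf), axis=-1) @ v


def cached_attention(q, k, v):
    # Decode one position at a time, appending each new K/V row to a cache.
    # No mask is needed: the cache only ever holds positions <= the current one.
    k_cache = jnp.zeros((0, k.shape[1]))
    v_cache = jnp.zeros((0, v.shape[1]))
    outputs = []
    for t in range(q.shape[0]):
        k_cache = jnp.concatenate([k_cache, k[t:t + 1]], axis=0)
        v_cache = jnp.concatenate([v_cache, v[t:t + 1]], axis=0)
        scores = q[t:t + 1] @ k_cache.T / jnp.sqrt(q.shape[-1])
        outputs.append(jax.nn.softmax(scores, axis=-1) @ v_cache)
    return jnp.concatenate(outputs, axis=0)


def test_cached_attention_matches_full():
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (6, 16))
    k = jax.random.normal(kk, (6, 16))
    v = jax.random.normal(kv, (6, 16))
    assert jnp.allclose(full_causal_attention(q, k, v),
                        cached_attention(q, k, v), atol=1e-5)
```

Saved under tests/ (for example as a hypothetical tests/test_cache_sketch.py), the test_-prefixed function would be collected by pytest tests/ alongside the existing suite.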

License

Apache 2.0
