Summary
Implement the retention mechanism from RetNet (Retentive Network) as an alternative to attention.
Background
RetNet (Microsoft, 2023) introduces "retention", a mechanism that combines the benefits of Transformers and RNNs:
- Parallel mode: fully parallel training, like attention
- Recurrent mode: O(1) per-token inference, like an RNN
- Chunkwise mode: a hybrid of the two for long sequences
Key advantages:
- Linear complexity O(n) vs attention's O(n²)
- Constant memory during inference
- No KV cache needed (uses recurrent state)
- Competitive performance with Transformers
Formula
Parallel Form (Training)
Retention(X) = (QK^T ⊙ D) V
where:
Q = XW_Q, K = XW_K, V = XW_V
D[i,j] = γ^(i-j) if i >= j else 0 (decay mask)
γ = 1 - 2^(-5 - arange(num_heads)) (per-head decay)
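As a reference for the parallel kernel, here is a minimal NumPy sketch of this form, assuming per-head inputs of shape [num_heads, seq_len, head_dim]; the shapes and function names are illustrative, not the final pygpukit API. (The paper also applies extra scaling/normalization for numerical stability; that is omitted here to match the formula above.)

```python
import numpy as np

def retention_init_decay(num_heads: int) -> np.ndarray:
    # gamma_h = 1 - 2^(-5 - h), one decay rate per head
    return 1.0 - 2.0 ** (-5.0 - np.arange(num_heads, dtype=np.float64))

def retention_parallel_ref(Q, K, V, gamma):
    # Q, K, V: [num_heads, seq_len, head_dim]; gamma: [num_heads]
    num_heads, seq_len, _ = Q.shape
    pos = np.arange(seq_len)
    out = np.empty_like(Q)
    for h in range(num_heads):
        # Decay mask: D[i, j] = gamma^(i - j) for i >= j, else 0
        D = np.tril(gamma[h] ** np.maximum(pos[:, None] - pos[None, :], 0))
        out[h] = (Q[h] @ K[h].T * D) @ V[h]
    return out
```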
Recurrent Form (Inference)
S_n = γ * S_{n-1} + K_n^T V_n (state update)
O_n = Q_n S_n (output)
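A matching single-token reference step under the same assumed shapes; the state layout follows the proposed API below ([num_heads, head_dim, head_dim]):

```python
def retention_recurrent_ref(q_t, k_t, v_t, gamma, state):
    # q_t, k_t, v_t: [num_heads, head_dim] for one token
    # gamma: [num_heads]; state: [num_heads, head_dim, head_dim]
    # S_n = gamma * S_{n-1} + K_n^T V_n  (per-head outer product)
    new_state = gamma[:, None, None] * state + k_t[:, :, None] * v_t[:, None, :]
    # O_n = Q_n S_n
    out = np.einsum("hd,hde->he", q_t, new_state)
    return out, new_state
```

Running this token by token reproduces the parallel output, since unrolling the state gives O_n = sum_{m <= n} γ^(n-m) (Q_n · K_m) V_m.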
Chunkwise Form (Hybrid)
Computes retention in parallel within each chunk and carries a recurrent state across chunks.
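A sketch of the chunkwise form under the same assumptions: the parallel product with the decay mask covers tokens inside a chunk, while a carried state (as in the recurrent form) covers everything before the chunk.

```python
def retention_chunkwise_ref(Q, K, V, gamma, chunk_size=512):
    # Q, K, V: [num_heads, seq_len, head_dim]; gamma: [num_heads]
    num_heads, seq_len, head_dim = Q.shape
    out = np.empty_like(Q)
    state = np.zeros((num_heads, head_dim, head_dim), dtype=Q.dtype)
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        B = end - start
        pos = np.arange(B)  # local positions 0..B-1
        for h in range(num_heads):
            g = gamma[h]
            Qc, Kc, Vc = Q[h, start:end], K[h, start:end], V[h, start:end]
            # Intra-chunk: parallel retention with the decay mask
            D = np.tril(g ** np.maximum(pos[:, None] - pos[None, :], 0))
            inner = (Qc @ Kc.T * D) @ Vc
            # Cross-chunk: decayed read of the carried state, scaled by gamma^(b+1)
            cross = (g ** (pos + 1))[:, None] * (Qc @ state[h])
            out[h, start:end] = inner + cross
            # State update: S = gamma^B * S_prev + sum_b gamma^(B-1-b) * K_b^T V_b
            state[h] = g ** B * state[h] + (Kc * (g ** (B - 1 - pos))[:, None]).T @ Vc
    return out
```

For identical inputs this matches retention_parallel_ref up to floating-point error; the proposed CUDA kernels would fuse the per-chunk work.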
Proposed Implementation
Native Kernels
native/ops/nn/retention/
├── retention_parallel.inl # Parallel mode (training)
├── retention_recurrent.inl # Recurrent mode (inference)
├── retention_chunkwise.inl # Hybrid mode
└── retention_kernels.cuh # CUDA kernels
Python API
```python
from pygpukit.ops.nn import (
    retention_init_decay,
    retention_parallel,
    retention_recurrent,
    retention_chunkwise,
)

# Initialize decay rates per head
gamma = retention_init_decay(num_heads)  # [num_heads]

# Parallel mode (training/prefill)
output = retention_parallel(Q, K, V, gamma)

# Recurrent mode (decode)
# state: [num_heads, head_dim, head_dim]
output, new_state = retention_recurrent(Q, K, V, gamma, state)

# Chunkwise mode (long sequences)
output = retention_chunkwise(Q, K, V, gamma, chunk_size=512)
```
State Management
```python
# Initialize recurrent state
state = retention_init_state(num_heads, head_dim, dtype="bfloat16")

# Decode loop
for token in tokens:
    q, k, v = compute_qkv(token)
    output, state = retention_recurrent(q, k, v, gamma, state)
```
Comparison with Attention
| Aspect | Attention | Retention |
|---|---|---|
| Training complexity | O(n²) | O(n) |
| Inference memory | O(n) KV cache | O(d²) state |
| Parallelizable | Yes | Yes (parallel mode) |
| Recurrent | No | Yes |
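To make the memory row concrete, a rough calculation with hypothetical shapes (32 heads, head_dim 128, bf16, a single layer):

```python
num_heads, head_dim, bytes_per_elem = 32, 128, 2  # hypothetical shapes, bf16

def kv_cache_bytes(seq_len):
    # Attention: K and V kept for every past token
    return 2 * seq_len * num_heads * head_dim * bytes_per_elem

def retention_state_bytes():
    # Retention: one [head_dim, head_dim] state per head, independent of seq_len
    return num_heads * head_dim * head_dim * bytes_per_elem

print(kv_cache_bytes(4096) // 2**20)    # 64 MiB at 4096 tokens, grows with n
print(retention_state_bytes() // 2**20) # 1 MiB, constant
```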
Tasks
- Implement retention_init_decay()
- Implement retention_parallel() kernel
- Implement retention_recurrent() kernel
- Implement retention_chunkwise() kernel
- Implement state management utilities
- Add Multi-Scale Retention (MSR) wrapper (see the sketch after this list)
- Add Python bindings
- Add Python API in ops/nn.py
- Add tests
- Add benchmark vs attention
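For the MSR wrapper task: in the paper, multi-scale retention runs each head with its own γ, applies GroupNorm across head outputs, and gates with swish before the output projection. A rough NumPy sketch on top of the reference functions above; the projection weights W_Q/W_K/W_V/W_G/W_O are placeholders and the learnable GroupNorm scale/shift is omitted.

```python
def swish(x):
    return x / (1.0 + np.exp(-x))

def group_norm_heads(y, eps=1e-6):
    # GroupNorm with one group per head: normalize each head's output per position
    mean = y.mean(axis=-1, keepdims=True)
    var = y.var(axis=-1, keepdims=True)
    return (y - mean) / np.sqrt(var + eps)

def multi_scale_retention(X, W_Q, W_K, W_V, W_G, W_O, gamma):
    # X: [seq_len, d_model]; W_*: [d_model, d_model]; gamma: [num_heads]
    seq_len, d_model = X.shape
    num_heads = gamma.shape[0]
    head_dim = d_model // num_heads
    def split_heads(W):
        return (X @ W).reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)
    Q, K, V = split_heads(W_Q), split_heads(W_K), split_heads(W_V)
    Y = retention_parallel_ref(Q, K, V, gamma)           # [num_heads, seq_len, head_dim]
    Y = group_norm_heads(Y)
    Y = Y.transpose(1, 0, 2).reshape(seq_len, d_model)   # concatenate heads
    return (swish(X @ W_G) * Y) @ W_O
```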
References
- RetNet paper: "Retentive Network: A Successor to Transformer for Large Language Models", https://arxiv.org/abs/2307.08621
- Official implementation (torchscale): https://github.com/microsoft/torchscale