
feat(nn): Add Retention / RetNet mechanism #172

@m96-chan

Description


Summary

Implement the retention mechanism from RetNet (Retentive Network) as an alternative to attention.

Background

RetNet (Microsoft, 2023) introduces "retention", a mechanism that combines the benefits of Transformers and RNNs:

  • Parallel mode: fully parallel training like attention, but with linear complexity
  • Recurrent mode: O(1) per-token inference like an RNN
  • Chunkwise mode: hybrid of the two for long sequences

Key advantages:

  • Linear complexity O(n) vs attention's O(n²)
  • Constant memory during inference
  • No KV cache needed (uses recurrent state)
  • Competitive performance with Transformers

Formula

Parallel Form (Training)

Retention(X) = (QK^T ⊙ D) V

where:
  Q = XW_Q,  K = XW_K,  V = XW_V
  D[i,j] = γ^(i-j) if i >= j else 0  (decay mask)
  γ = 1 - 2^(-5 - arange(num_heads))  (per-head decay)
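For reference, a minimal NumPy sketch of the parallel form for a single head. The function name and shapes are illustrative only, not the proposed pygpukit API:

import numpy as np

def retention_parallel_ref(Q, K, V, gamma):
    # Reference parallel retention for a single head.
    # Q, K, V: [seq_len, head_dim]; gamma: scalar decay for this head.
    n = Q.shape[0]
    idx = np.arange(n)
    # Causal decay mask: D[i, j] = gamma^(i-j) for i >= j, 0 otherwise.
    D = np.tril(gamma ** np.maximum(idx[:, None] - idx[None, :], 0))
    return (Q @ K.T * D) @ V

# Per-head decay rates, as defined above.
num_heads = 8
gamma = 1.0 - 2.0 ** (-5.0 - np.arange(num_heads))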

Recurrent Form (Inference)

S_n = γ * S_{n-1} + K_n^T V_n   (state update)
O_n = Q_n S_n                    (output)
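A per-token sketch of the recurrent update for a single head, under the same illustrative conventions (again, not the proposed API):

import numpy as np

def retention_recurrent_step(q, k, v, gamma, S):
    # One decode step for a single head.
    # q, k, v: [head_dim]; gamma: scalar; S: [head_dim, head_dim] recurrent state.
    S = gamma * S + np.outer(k, v)   # S_n = γ * S_{n-1} + K_n^T V_n
    o = q @ S                        # O_n = Q_n S_n
    return o, S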

Chunkwise Form (Hybrid)

Combines the parallel form within chunks with the recurrent form across chunk boundaries.
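A rough single-head reference sketch of that combination, assuming 0-based position t inside a chunk of size B and a sequence length divisible by the chunk size (illustrative only):

import numpy as np

def retention_chunkwise_ref(Q, K, V, gamma, chunk_size):
    # Reference chunkwise retention for a single head.
    # Q, K, V: [seq_len, head_dim]; seq_len assumed divisible by chunk_size.
    n, d = Q.shape
    B = chunk_size
    t = np.arange(B)
    # Intra-chunk decay mask, same as the parallel form but per chunk.
    D = np.tril(gamma ** np.maximum(t[:, None] - t[None, :], 0))
    S = np.zeros((d, d))
    out = np.empty_like(Q)
    for start in range(0, n, B):
        Qc, Kc, Vc = Q[start:start+B], K[start:start+B], V[start:start+B]
        inner = (Qc @ Kc.T * D) @ Vc                      # parallel within the chunk
        cross = (gamma ** (t + 1))[:, None] * (Qc @ S)    # contribution of earlier chunks
        out[start:start+B] = inner + cross
        # Carry the state across chunks: S <- γ^B S + Σ_j γ^(B-1-j) K_j^T V_j
        S = (gamma ** B) * S + (Kc * (gamma ** (B - 1 - t))[:, None]).T @ Vc
    return out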

Proposed Implementation

Native Kernels

native/ops/nn/retention/
├── retention_parallel.inl    # Parallel mode (training)
├── retention_recurrent.inl   # Recurrent mode (inference)  
├── retention_chunkwise.inl   # Hybrid mode
└── retention_kernels.cuh     # CUDA kernels

Python API

from pygpukit.ops.nn import (
    retention_init_decay,
    retention_parallel,
    retention_recurrent,
    retention_chunkwise,
)

# Initialize decay rates per head
gamma = retention_init_decay(num_heads)  # [num_heads]

# Parallel mode (training/prefill)
output = retention_parallel(Q, K, V, gamma)

# Recurrent mode (decode) 
# state: [num_heads, head_dim, head_dim]
output, new_state = retention_recurrent(Q, K, V, gamma, state)

# Chunkwise mode (long sequences)
output = retention_chunkwise(Q, K, V, gamma, chunk_size=512)

State Management

# Initialize recurrent state
state = retention_init_state(num_heads, head_dim, dtype="bfloat16")

# Decode loop
for token in tokens:
    q, k, v = compute_qkv(token)
    output, state = retention_recurrent(q, k, v, gamma, state)
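Since the parallel and recurrent modes should produce identical outputs, an equivalence check makes a natural starting point for the test task. A sketch using the single-head reference functions from the formula sections above (not the proposed API):

import numpy as np

def test_recurrent_matches_parallel():
    rng = np.random.default_rng(0)
    n, d, gamma = 16, 8, 1.0 - 2.0 ** -5  # decay of head 0
    Q, K, V = [rng.standard_normal((n, d)) for _ in range(3)]

    expected = retention_parallel_ref(Q, K, V, gamma)

    S = np.zeros((d, d))
    got = np.empty_like(expected)
    for t in range(n):
        got[t], S = retention_recurrent_step(Q[t], K[t], V[t], gamma, S)

    np.testing.assert_allclose(got, expected, rtol=1e-8, atol=1e-10)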

Comparison with Attention

Aspect Attention Retention
Training complexity O(n²) O(n)
Inference memory O(n) KV cache O(d²) state
Parallelizable Yes Yes (parallel mode)
Recurrent No Yes
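To make the inference-memory row concrete, a back-of-the-envelope comparison per layer with assumed dimensions (num_heads=32, head_dim=64, seq_len=4096, bf16):

num_heads, head_dim, seq_len, bytes_per_el = 32, 64, 4096, 2  # bf16

# Attention: K and V cached for every prefilled/generated token; grows with seq_len.
kv_cache = 2 * seq_len * num_heads * head_dim * bytes_per_el      # 33,554,432 B ≈ 32 MiB

# Retention: one head_dim x head_dim state per head; independent of sequence length.
retention_state = num_heads * head_dim * head_dim * bytes_per_el  # 262,144 B = 256 KiB

print(kv_cache / 2**20, retention_state / 2**10)  # 32.0 MiB vs 256.0 KiB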

Tasks

  • Implement retention_init_decay()
  • Implement retention_parallel() kernel
  • Implement retention_recurrent() kernel
  • Implement retention_chunkwise() kernel
  • Implement state management utilities
  • Add Multi-Scale Retention (MSR) wrapper
  • Add Python bindings
  • Add Python API in ops/nn.py
  • Add tests
  • Add benchmark vs attention

References

  • Sun et al., "Retentive Network: A Successor to Transformer for Large Language Models", 2023. arXiv:2307.08621
