Summary
Implement Sliding Window Attention (SWA) for efficient long-context inference. Each token attends only to a fixed window of previous tokens instead of all previous tokens.
Background
Full causal attention has O(n^2) complexity, making long sequences expensive. Sliding Window Attention limits each token to attend only to the last w tokens:
Token at position i attends to: [max(0, i - w + 1), i]
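For example, with w = 3 the token at position 4 attends to positions 2 through 4, while the token at position 1 can only reach back to position 0. A tiny helper (hypothetical, for illustration only) makes the boundary behaviour explicit:

```python
def window_bounds(i: int, w: int) -> tuple[int, int]:
    # Inclusive range [lo, hi] of key positions that query position i may attend to.
    return max(0, i - w + 1), i

assert window_bounds(4, 3) == (2, 4)   # steady state: exactly the last w tokens
assert window_bounds(1, 3) == (0, 1)   # near the start: fewer than w tokens exist
```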
Used in:
- Mistral 7B (window_size=4096)
- Longformer (local + global attention)
- BigBird (sparse attention patterns)
Benefits:
- O(n * w) memory instead of O(n^2)
- Constant memory per token during generation
- KV cache size bounded by window_size
Attention Pattern
```
Full Causal:     Sliding Window (w=3):
[1 0 0 0 0]      [1 0 0 0 0]
[1 1 0 0 0]      [1 1 0 0 0]
[1 1 1 0 0]      [1 1 1 0 0]
[1 1 1 1 0]      [0 1 1 1 0]  <- only last 3
[1 1 1 1 1]      [0 0 1 1 1]  <- only last 3
```
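The window pattern above can be generated in a few lines; this is an illustrative sketch (the helper name is made up here), not part of the proposed API:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # mask[i, j] is True when query position i may attend to key position j,
    # i.e. j <= i (causal) and i - j < window_size (window).
    i = torch.arange(seq_len).unsqueeze(1)   # query positions as a column
    j = torch.arange(seq_len).unsqueeze(0)   # key positions as a row
    return (j <= i) & (i - j < window_size)

print(sliding_window_causal_mask(5, 3).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```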
Proposed Implementation
Native Kernels
```
native/ops/nn/attention/
├── sdpa_causal.inl           # Existing
├── sdpa_sliding_window.inl   # NEW
└── sdpa_sliding_window.cuh   # NEW: optimized kernel
```
Python API
```python
# SDPA with sliding window
sdpa_sliding_window(
    Q, K, V,
    window_size=4096,
    scale=0.0,
    out=None,
)
```
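To pin down the intended semantics before the native kernel exists, here is a naive reference sketch of what sdpa_sliding_window() is presumably meant to compute; the tensor layout and the interpretation of scale=0.0 as "use 1/sqrt(head_dim)" are assumptions, not confirmed behaviour. An optimized kernel should match this output without materializing the full seq x seq score matrix.

```python
import math
import torch

def sdpa_sliding_window_reference(Q, K, V, window_size, scale=None):
    # Q, K, V: [batch, heads, seq_len, head_dim] (assumed layout)
    scale = scale or 1.0 / math.sqrt(Q.shape[-1])
    seq_len = Q.shape[-2]
    i = torch.arange(seq_len, device=Q.device).unsqueeze(1)
    j = torch.arange(seq_len, device=Q.device).unsqueeze(0)
    mask = (j <= i) & (i - j < window_size)         # sliding-window causal mask
    scores = (Q @ K.transpose(-2, -1)) * scale      # [batch, heads, seq, seq]
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V
```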
```python
# With fixed KV cache (for decode)
sdpa_sliding_window_fixed_cache(
    Q, K, V, out,
    context_len,
    window_size=4096,
    scale=0.0,
)
```

KV Cache Strategy
Two options for the KV cache with sliding window:
Option A: Ring Buffer
```python
# KV cache is a ring buffer of size window_size
# New KV written at position % window_size
kv_cache_pos = context_len % window_size
k_cache[kv_cache_pos] = k_new
v_cache[kv_cache_pos] = v_new
```
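To make Option A concrete, here is a hedged sketch of a single decode step with a ring-buffer cache; the shapes and helper name are assumptions for illustration. Note that cache entries do not need to be read in temporal order, because permuting keys and values together leaves the softmax-weighted sum unchanged.

```python
import math
import torch

def decode_step_ring_buffer(q_new, k_new, v_new, k_cache, v_cache,
                            context_len, window_size):
    # q_new, k_new, v_new: [heads, head_dim]
    # k_cache, v_cache:    [window_size, heads, head_dim] (ring buffers)
    pos = context_len % window_size             # slot for the current token
    k_cache[pos] = k_new                        # overwrites the oldest slot once full
    v_cache[pos] = v_new
    valid = min(context_len + 1, window_size)   # how many slots hold real data
    k, v = k_cache[:valid], v_cache[:valid]
    scale = 1.0 / math.sqrt(q_new.shape[-1])
    scores = torch.einsum("hd,whd->hw", q_new, k) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("hw,whd->hd", attn, v)
```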
Option B: Full Cache + Masking
```python
# Keep full cache, apply window mask in attention
# Simpler but uses more memory
attention_mask[i, j] = 1 if 0 <= (i - j) < window_size else 0   # window + causality
```
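A matching sketch for Option B at decode time; with the full history cached, the window mask degenerates to "attend to the last window_size entries", so a slice does the same job (again, shapes are assumptions):

```python
import math
import torch

def decode_step_full_cache(q_new, k_all, v_all, window_size):
    # q_new: [heads, head_dim]
    # k_all, v_all: [seen_tokens, heads, head_dim], already including the current token
    # The full history is kept, but only the trailing window can receive attention.
    k, v = k_all[-window_size:], v_all[-window_size:]
    scale = 1.0 / math.sqrt(q_new.shape[-1])
    scores = torch.einsum("hd,whd->hw", q_new, k) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("hw,whd->hd", attn, v)
```

The trade-off is visible in the two sketches: Option A's cache never grows past window_size entries, while Option B's grows with the sequence but needs no wrap-around bookkeeping.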
Tasks
- Implement sdpa_sliding_window() kernel
- Implement sdpa_sliding_window_fixed_cache() for decode
- Add ring buffer KV cache option
- Add Python bindings
- Add Python API in ops/nn.py
- Add tests
- Benchmark vs full attention
- Update Mistral model support
Memory Comparison (seq_len=32K, window=4K)
| Method | KV Cache Size | Attention Memory |
|---|---|---|
| Full Causal | 32K * dim | O(32K^2) |
| Sliding Window | 4K * dim | O(32K * 4K) |
| Savings | 8x | 8x |
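The 8x figures in the table are just the ratio seq_len / window_size; a quick arithmetic check with illustrative numbers:

```python
seq_len, window, dim = 32 * 1024, 4 * 1024, 4096   # "dim" as used in the table

print((seq_len * dim) // (window * dim))            # KV cache: 8x smaller
print((seq_len * seq_len) // (seq_len * window))    # attention scores: 8x fewer entries
```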
References
- Mistral: https://arxiv.org/abs/2310.06825
- Longformer: https://arxiv.org/abs/2004.05150
- Sliding Window Attention explained: https://huggingface.co/blog/mistral