feat(nn): Add Sliding Window / Local Attention for long context #170

@m96-chan

Summary

Implement Sliding Window Attention (SWA) for efficient long-context inference. Each token attends only to a fixed-size window of the most recent tokens instead of the entire prefix.

Background

Full causal attention costs O(n^2) time and memory in the sequence length n, making long sequences expensive. Sliding Window Attention limits each token to attending only to the last w tokens:

Token at position i attends to: [max(0, i - w + 1), i]

Used in:

  • Mistral 7B (window_size=4096)
  • Longformer (local + global attention)
  • BigBird (sparse attention patterns)

Benefits:

  • O(n * w) memory instead of O(n^2)
  • Constant memory per token during generation
  • KV cache size bounded by window_size

Attention Pattern

Full Causal:          Sliding Window (w=3):
[1 0 0 0 0]           [1 0 0 0 0]
[1 1 0 0 0]           [1 1 0 0 0]
[1 1 1 0 0]           [1 1 1 0 0]
[1 1 1 1 0]           [0 1 1 1 0]  <- only last 3
[1 1 1 1 1]           [0 0 1 1 1]  <- only last 3
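
This pattern is easy to reproduce for sanity checks. A minimal NumPy sketch of the masking rule (illustration only, not the proposed kernel):

import numpy as np

def sliding_window_mask(seq_len, window_size):
    # mask[i, j] == 1 iff token i may attend to token j: j must lie in the
    # causal past and within the last window_size positions.
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return ((j <= i) & (i - j < window_size)).astype(np.int8)

print(sliding_window_mask(5, 3))  # reproduces the w=3 matrix above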

Proposed Implementation

Native Kernels

native/ops/nn/attention/
├── sdpa_causal.inl           # Existing
├── sdpa_sliding_window.inl   # NEW
└── sdpa_sliding_window.cuh   # NEW: optimized kernel

Python API

# SDPA with sliding window
sdpa_sliding_window(
    Q, K, V,
    window_size=4096,
    scale=0.0,
    out=None,
)

# With fixed KV cache (for decode)
sdpa_sliding_window_fixed_cache(
    Q, K, V, out,
    context_len,
    window_size=4096,
    scale=0.0,
)
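
As a sketch of the intended semantics, here is a plain NumPy reference, not the native kernel. It assumes single-head [seq_len, head_dim] inputs and that scale=0.0 means "use 1/sqrt(head_dim)"; both are conventions assumed here for illustration, not confirmed by this proposal:

import numpy as np

def sdpa_sliding_window_ref(Q, K, V, window_size, scale=0.0):
    # Reference semantics: causal attention restricted to the last
    # window_size positions, then standard softmax(Q K^T * scale) @ V.
    seq_len, head_dim = Q.shape
    if scale == 0.0:
        scale = 1.0 / np.sqrt(head_dim)  # assumed default convention
    scores = (Q @ K.T) * scale
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    scores = np.where((j <= i) & (i - j < window_size), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V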

KV Cache Strategy

Two options for KV cache with sliding window:

Option A: Ring Buffer

# KV cache is a ring buffer holding only the last window_size entries.
# K/V for the token at absolute position context_len overwrite the
# oldest entry at index context_len % window_size.
kv_cache_pos = context_len % window_size
k_cache[kv_cache_pos] = k_new
v_cache[kv_cache_pos] = v_new
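
A minimal sketch of what such a ring buffer could look like for a single head (hypothetical helper, not part of the proposed API). Note the attention kernel must then track each slot's absolute position, since slot order no longer matches token order:

import numpy as np

class RingKVCache:
    # Fixed-size KV cache: holds only the last window_size entries.
    def __init__(self, window_size, head_dim):
        self.window_size = window_size
        self.k = np.zeros((window_size, head_dim), dtype=np.float32)
        self.v = np.zeros((window_size, head_dim), dtype=np.float32)
        self.pos = np.full(window_size, -1, dtype=np.int64)  # absolute position per slot

    def append(self, context_len, k_new, v_new):
        # Token at absolute position context_len overwrites the oldest slot.
        slot = context_len % self.window_size
        self.k[slot] = k_new
        self.v[slot] = v_new
        self.pos[slot] = context_len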

Option B: Full Cache + Masking

# Keep the full cache and apply a window mask inside attention.
# Simpler, but cache memory still grows with context length.
attention_mask[i, j] = 1 if 0 <= (i - j) < window_size else 0
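
With a full cache, single-token decode under the window mask reduces to attending over a slice of the cache. A hedged sketch (assumes k_cache[:context_len] already contains the current token's K/V, and q is a single query vector):

import numpy as np

def decode_step_full_cache(q, k_cache, v_cache, context_len, window_size, scale):
    # The window mask degenerates to slicing the last
    # min(context_len, window_size) cached entries.
    start = max(0, context_len - window_size)
    k, v = k_cache[start:context_len], v_cache[start:context_len]
    scores = (k @ q) * scale
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v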

Tasks

  • Implement sdpa_sliding_window() kernel
  • Implement sdpa_sliding_window_fixed_cache() for decode
  • Add ring buffer KV cache option
  • Add Python bindings
  • Add Python API in ops/nn.py
  • Add tests
  • Benchmark vs full attention
  • Update Mistral model support

Memory Comparison (seq_len=32K, window=4K)

Method           KV Cache Size   Attention Memory
Full Causal      32K * dim       O(32K^2)
Sliding Window   4K * dim        O(32K * 4K)
Savings          8x              8x
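
Back-of-the-envelope check of the 8x figure (model dims here are illustrative assumptions, roughly Mistral-7B-like, not taken from this repo):

n_layers, n_kv_heads, head_dim = 32, 8, 128  # assumed dims
bytes_per_elem = 2                            # fp16
per_token = n_layers * 2 * n_kv_heads * head_dim * bytes_per_elem  # K and V
full = 32 * 1024 * per_token      # seq_len = 32K -> 4.0 GiB
windowed = 4 * 1024 * per_token   # window  = 4K  -> 0.5 GiB
print(full / windowed)            # 8.0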
