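DeepSeek Sparse Attention (DSA) with a paged memory layout. DSA is a two-stage sparse attention mechanism. First, an indexer scores the KV cache entries with a ReLU-based score and selects the top-K most relevant ones. Then MLA-style attention is computed only over the selected entries. This reduces the cost of the main attention step for each query token from O(n) to O(k), where n is the number of cached tokens and k (topk) is much smaller than n. The indexer still scores all n entries, but with a lightweight FP8 computation. Variants: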
  • indexer
  • sparse_attention
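To make the two stages concrete, here is a toy sketch on plain Python lists, with no paging and no FP8. It assumes a single query and an indexer score of the form sum over index heads h of w_h * ReLU(q_h · k). Ordinary scaled dot-product attention stands in for the MLA-style second stage. All function names are hypothetical.

```python
import math

def indexer_scores(q_heads, keys, weights):
    """Score every key: sum over index heads of w_h * relu(q_h . k)."""
    scores = []
    for k in keys:
        s = 0.0
        for q_h, w_h in zip(q_heads, weights):
            dot = sum(a * b for a, b in zip(q_h, k))
            s += w_h * max(dot, 0.0)  # ReLU before the weighted head sum
        scores.append(s)
    return scores

def topk_indices(scores, k):
    """Indices of the k highest scores (ties broken by lower index)."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return order[:k]

def attend(q, keys, values, idx, sm_scale):
    """Softmax attention restricted to the selected rows: O(k) work, not O(n)."""
    logits = [sm_scale * sum(a * b for a, b in zip(q, keys[i])) for i in idx]
    m = max(logits)  # subtract the max for numerical stability
    probs = [math.exp(l - m) for l in logits]
    total = sum(probs)
    dim = len(values[0])
    return [sum(p * values[i][d] for p, i in zip(probs, idx)) / total
            for d in range(dim)]

keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]
scores = indexer_scores([[1.0, 0.0], [0.0, 1.0]], keys, [1.0, 0.5])
idx = topk_indices(scores, 2)  # [0, 3]
out = attend([1.0, 0.0], keys, keys, idx, sm_scale=1.0)
```

The indexer touches every entry once, but `attend` only reads the k selected rows. That is where the O(n) to O(k) saving in the attention step comes from.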

indexer

Computes a relevance score for every cached token using a ReLU activation and learned per-head weights, then selects the top-K KV cache indices. Queries and keys use FP8 quantization in the deep_gemm format. Axes (9 dimensions):
  • batch_size, max_num_pages, num_pages: variable
  • num_index_heads, index_head_dim, page_size, topk, kv_cache_num_heads, head_dim_with_scale: constant
Inputs (5 tensors):
  • q_index_fp8: FP8 query for indexing [batch_size, num_index_heads, index_head_dim]
  • k_index_cache_fp8: FP8 key index cache with scales [num_pages, page_size, kv_cache_num_heads, head_dim_with_scale]
  • weights: learned head weights [batch_size, num_index_heads]
  • seq_lens: sequence lengths [batch_size]
  • block_table: page mapping [batch_size, max_num_pages]
Outputs (1 tensor):
  • topk_indices: selected token indices [batch_size, topk], -1 indicates padding
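The indexer's paged scoring and selection can be sketched as a float reference over already-dequantized keys. This is a hypothetical sketch, not the kernel. It assumes kv_cache_num_heads == 1, omits FP8 handling, and breaks score ties by lower token position. It also assumes topk_indices holds slots in the flattened cache (physical_page * page_size + offset) so that it can feed sparse_attention directly. If the kernel instead reports positions within the sequence, drop the final mapping step.

```python
def dsa_indexer_reference(q_index, k_cache, weights, seq_lens, block_table,
                          page_size, topk):
    """Hypothetical float reference of the indexer (keys already dequantized).

    q_index:     [batch][num_index_heads][dim]
    k_cache:     [num_pages][page_size][dim]  (kv_cache_num_heads == 1 assumed)
    weights:     [batch][num_index_heads]
    seq_lens:    [batch]
    block_table: [batch][max_num_pages]       logical page -> physical page
    Returns topk_indices as [batch][topk], padded with -1.
    """
    out = []
    for b, seq_len in enumerate(seq_lens):
        scores = []
        for t in range(seq_len):
            page = block_table[b][t // page_size]  # physical page holding token t
            key = k_cache[page][t % page_size]
            s = 0.0
            for q_h, w_h in zip(q_index[b], weights[b]):
                # ReLU of each head's dot product, weighted by that head's weight
                s += w_h * max(sum(a * c for a, c in zip(q_h, key)), 0.0)
            scores.append(s)
        k = min(topk, seq_len)  # short sequences yield fewer than topk entries
        best = sorted(range(seq_len), key=lambda t: (-scores[t], t))[:k]
        # Assumption: report slots in the flattened cache, not sequence positions.
        slots = [block_table[b][t // page_size] * page_size + t % page_size
                 for t in best]
        out.append(slots + [-1] * (topk - k))  # -1 marks padding
    return out
```

For example, with page_size = 2 and block_table row [2, 0], token 2 of that sequence lives on physical page 0 at offset 0, i.e. flattened slot 0.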
Constraints:
  • topk <= max_num_pages * page_size
  • num_index_heads == 64, index_head_dim == 128 (deep_gemm requirement)
  • head_dim_with_scale == 132 (128 FP8 bytes + 4 bytes for the per-token scale)
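The 132-byte row size can be illustrated with a small decoder. This sketch makes two assumptions that should be checked against deep_gemm before relying on it. First, the scale is a little-endian float32. Second, each page stores all page_size * 128 FP8 codes first, followed by page_size scales, rather than interleaving each scale after its 128-byte row. The FP8 format is taken to be E4M3FN, and the helper names are hypothetical.

```python
import struct

def e4m3fn_to_float(code):
    """Decode one FP8 E4M3FN byte (bias 7, no infinities, 0x7F/0xFF are NaN)."""
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> 3) & 0xF
    mant = code & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def dequantize_page(page_bytes, page_size, head_dim=128):
    """Split one packed page into dequantized rows.

    Assumed (hypothetical) layout: page_size * head_dim FP8 codes, then
    page_size little-endian float32 scales: page_size * (head_dim + 4) bytes.
    """
    assert len(page_bytes) == page_size * (head_dim + 4)
    scale_base = page_size * head_dim  # scales start after all FP8 codes
    rows = []
    for r in range(page_size):
        (scale,) = struct.unpack_from("<f", page_bytes, scale_base + 4 * r)
        codes = page_bytes[r * head_dim:(r + 1) * head_dim]
        rows.append([e4m3fn_to_float(c) * scale for c in codes])
    return rows
```

Whichever way the bytes are arranged inside a page, the per-token budget is the same. 128 one-byte codes plus a 4-byte scale give the 132 in head_dim_with_scale.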

sparse_attention

Performs MLA-style attention over the top-K selected KV entries. It works for both prefill (multiple tokens per sequence) and decode (one token per sequence). The computation is identical, and only the first dimension (num_tokens) differs. Axes (7 dimensions):
  • num_tokens, num_pages: variable
  • num_qo_heads, head_dim_ckv, head_dim_kpe, page_size, topk: constant
Inputs (5 tensors + 1 scalar):
  • q_nope: query without positional encoding [num_tokens, num_qo_heads, head_dim_ckv]
  • q_pe: query positional encoding [num_tokens, num_qo_heads, head_dim_kpe]
  • ckv_cache: compressed KV cache [num_pages, page_size, head_dim_ckv]
  • kpe_cache: key positional encoding cache [num_pages, page_size, head_dim_kpe]
  • sparse_indices: top-K indices per token [num_tokens, topk], -1 indicates padding
  • sm_scale: softmax scale (scalar)
Outputs (2 tensors):
  • output: attention output [num_tokens, num_qo_heads, head_dim_ckv]
  • lse: base-2 log-sum-exp [num_tokens, num_qo_heads]
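A base-2 log-sum-exp is the natural-log value divided by ln 2, i.e. lse = log2(sum_i e^{s_i}). Here we assume the s_i are the logits after multiplying by sm_scale. One common reason kernels return lse is to merge partial attention results computed over disjoint index sets. The sketch below shows that merge with hypothetical helper names.

```python
import math

def lse_base2(scores):
    """Base-2 log-sum-exp of (already scaled) logits: log2(sum_i e^{s_i})."""
    m = max(scores)
    return (m + math.log(sum(math.exp(s - m) for s in scores))) / math.log(2.0)

def merge_states(o_a, lse_a, o_b, lse_b):
    """Combine attention outputs computed over two disjoint key sets."""
    m = max(lse_a, lse_b)
    # Each partial softmax denominator is 2**lse; rescale relative to the max.
    wa, wb = 2.0 ** (lse_a - m), 2.0 ** (lse_b - m)
    lse = m + math.log2(wa + wb)
    out = [(wa * x + wb * y) / (wa + wb) for x, y in zip(o_a, o_b)]
    return out, lse
```

Merging two states with equal lse simply averages their outputs and adds 1 to the lse, since the combined denominator doubles.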
Constraints:
  • sparse_indices.shape[0] == num_tokens
  • sparse_indices.shape[-1] == topk
  • ckv_cache.shape[1] == page_size
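A pure-Python reference for the sparse MLA computation might look like the following. It is a hypothetical sketch under stated assumptions:
  • each sparse_indices entry is a slot in the flattened paged cache (page * page_size + offset);
  • the logit for each selected slot is sm_scale * (q_nope · ckv + q_pe · kpe);
  • the compressed ckv row doubles as the value vector, consistent with output having head_dim_ckv columns;
  • lse is returned in base 2;
  • a row whose indices are all -1 yields a zero output with lse = -inf.
Because prefill and decode differ only in num_tokens, the same loop serves both.

```python
import math

def sparse_mla_reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices,
                         sm_scale):
    """Hypothetical reference: returns (output [T][H][D_ckv], lse [T][H])."""
    page_size = len(ckv_cache[0])

    def row(cache, slot):
        # Flattened slot -> (page, offset) in the paged cache (assumed layout).
        return cache[slot // page_size][slot % page_size]

    output, lse = [], []
    for t, idx in enumerate(sparse_indices):
        valid = [i for i in idx if i >= 0]  # -1 entries are padding
        out_t, lse_t = [], []
        for qn, qp in zip(q_nope[t], q_pe[t]):  # one query per head
            if not valid:  # assumption: all padding -> zero output, lse = -inf
                out_t.append([0.0] * len(qn))
                lse_t.append(float("-inf"))
                continue
            # MLA logit: compressed-KV part plus positional part, then scale.
            logits = [sm_scale * (sum(a * b for a, b in zip(qn, row(ckv_cache, i)))
                                  + sum(a * b for a, b in zip(qp, row(kpe_cache, i))))
                      for i in valid]
            m = max(logits)
            p = [math.exp(l - m) for l in logits]
            z = sum(p)
            # The compressed KV row is also used as the value vector.
            out_t.append([sum(pj * row(ckv_cache, i)[d] for pj, i in zip(p, valid)) / z
                          for d in range(len(qn))])
            lse_t.append((m + math.log(z)) / math.log(2.0))  # base-2 lse
        output.append(out_t)
        lse.append(lse_t)
    return output, lse
```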