k << n.
Variants:
- indexer
- sparse_attention
indexer
Computes sparse attention scores using ReLU activation and learned weights, then selects top-K KV cache indices. Uses FP8 quantization with deep_gemm format. A reference sketch of the selection semantics follows the constraints below.

Axes (9 dimensions):
- batch_size, max_num_pages, num_pages: variable
- num_index_heads, index_head_dim, page_size, topk, kv_cache_num_heads, head_dim_with_scale: constant

Inputs:
- q_index_fp8: FP8 query for indexing [batch_size, num_index_heads, index_head_dim]
- k_index_cache_fp8: FP8 key index cache with scales [num_pages, page_size, kv_cache_num_heads, head_dim_with_scale]
- weights: learned head weights [batch_size, num_index_heads]
- seq_lens: sequence lengths [batch_size]
- block_table: page mapping [batch_size, max_num_pages]

Outputs:
- topk_indices: selected token indices [batch_size, topk], -1 indicates padding

Constraints:
- topk <= max_num_pages * page_size
- num_index_heads == 64, index_head_dim == 128 (deep_gemm requirement)
- head_dim_with_scale == 132 (128 + 4 scale bytes)
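The following is a minimal PyTorch sketch of what the indexer computes, not the kernel itself: it assumes inputs have already been dequantized to float (the real kernel consumes FP8 with inline scale bytes) and that kv_cache_num_heads == 1. The function name and gather logic are illustrative only.

```python
import torch

def indexer_reference(q_index, k_index_cache, weights, seq_lens, block_table, topk):
    """Hypothetical reference for the indexer, on dequantized float tensors.

    q_index:       [batch_size, num_index_heads, index_head_dim]
    k_index_cache: [num_pages, page_size, 1, index_head_dim]  (kv_cache_num_heads == 1)
    weights:       [batch_size, num_index_heads]
    seq_lens:      [batch_size] (int)
    block_table:   [batch_size, max_num_pages] (int)
    Returns topk_indices [batch_size, topk]; -1 marks padding slots.
    """
    batch_size, num_index_heads, head_dim = q_index.shape
    page_size = k_index_cache.shape[1]
    topk_indices = torch.full((batch_size, topk), -1, dtype=torch.long)

    for b in range(batch_size):
        n = int(seq_lens[b])
        num_used = (n + page_size - 1) // page_size
        pages = block_table[b, :num_used].long()
        # Gather this sequence's keys through the page table, drop page padding.
        keys = k_index_cache[pages].reshape(-1, head_dim)[:n]      # [n, head_dim]
        # Per-head relevance: ReLU of query-key dot products.
        scores = torch.relu(q_index[b] @ keys.T)                   # [num_index_heads, n]
        # Combine heads with the learned per-head weights.
        combined = (weights[b].unsqueeze(1) * scores).sum(dim=0)   # [n]
        k = min(topk, n)
        topk_indices[b, :k] = combined.topk(k).indices
    return topk_indices
```

When seq_lens[b] < topk, the tail of that row stays at -1, matching the padding convention above.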
sparse_attention
Performs MLA-style attention on the top-K selected KV entries. Works for both prefill (multiple tokens) and decode (one token per sequence): the computation is identical, only the first dimension differs. A reference sketch follows the constraints below.

Axes (7 dimensions):
- num_tokens, num_pages: variable
- num_qo_heads, head_dim_ckv, head_dim_kpe, page_size, topk: constant

Inputs:
- q_nope: query without positional encoding [num_tokens, num_qo_heads, head_dim_ckv]
- q_pe: query positional encoding [num_tokens, num_qo_heads, head_dim_kpe]
- ckv_cache: compressed KV cache [num_pages, page_size, head_dim_ckv]
- kpe_cache: key positional encoding cache [num_pages, page_size, head_dim_kpe]
- sparse_indices: top-K indices per token [num_tokens, topk], -1 indicates padding
- sm_scale: softmax scale (scalar)

Outputs:
- output: attention output [num_tokens, num_qo_heads, head_dim_ckv]
- lse: base-2 log-sum-exp of attention logits [num_tokens, num_qo_heads]

Constraints:
- sparse_indices.shape[0] == num_tokens
- sparse_indices.shape[-1] == topk
- ckv_cache.shape[1] == page_size
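A minimal PyTorch sketch of the sparse_attention semantics, assuming sparse_indices holds global token positions into the flattened paged caches and that every token has at least one valid (non -1) index; sparse_attention_reference is an illustrative name, not the library's API.

```python
import math
import torch

def sparse_attention_reference(q_nope, q_pe, ckv_cache, kpe_cache,
                               sparse_indices, sm_scale):
    """Hypothetical reference for sparse MLA attention over gathered entries.

    Returns (output [num_tokens, num_qo_heads, head_dim_ckv],
             lse    [num_tokens, num_qo_heads], in base 2 per the spec).
    """
    num_pages, page_size, head_dim_ckv = ckv_cache.shape
    ckv_flat = ckv_cache.reshape(num_pages * page_size, -1)
    kpe_flat = kpe_cache.reshape(num_pages * page_size, -1)

    valid = sparse_indices >= 0                  # [num_tokens, topk]
    safe = sparse_indices.clamp(min=0).long()
    ckv = ckv_flat[safe]                         # [num_tokens, topk, head_dim_ckv]
    kpe = kpe_flat[safe]                         # [num_tokens, topk, head_dim_kpe]

    # MLA: key and value share the compressed KV; the positional (RoPE)
    # part contributes to the logits only.
    logits = torch.einsum('thd,tkd->thk', q_nope, ckv)
    logits = logits + torch.einsum('thd,tkd->thk', q_pe, kpe)
    logits = logits * sm_scale
    # Mask padded slots; assumes >= 1 valid index per token to avoid NaNs.
    logits = logits.masked_fill(~valid.unsqueeze(1), float('-inf'))

    probs = torch.softmax(logits, dim=-1)
    output = torch.einsum('thk,tkd->thd', probs, ckv)
    lse = torch.logsumexp(logits, dim=-1) / math.log(2.0)  # base-2 lse
    return output, lse
```

Because the gather is driven purely by sparse_indices, the same code path covers prefill (num_tokens > batch_size) and decode (one token per sequence), as stated above.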

