Batched Native Sparse Attention (DSA) with sparse top-k KV-cache selection, captured from DeepSeek-V3.2 with tensor-parallel size 8. Sparse indexing selects only the top-k KV-cache entries per query token for the attention computation. This is the page-size-64 variant; it works for both the prefill and decode stages.
Axes
| Axis | Value |
|---|---|
| num_tokens | var |
| num_qo_heads | 16 |
| head_dim_ckv | 512 |
| head_dim_kpe | 64 |
| page_size | 64 |
| topk | 2048 |
| num_pages | var |
Signature
Inputs
| Name | Type | Shape |
|---|---|---|
| q_nope | bfloat16 | [num_tokens, num_qo_heads, head_dim_ckv] |
| q_pe | bfloat16 | [num_tokens, num_qo_heads, head_dim_kpe] |
| ckv_cache | bfloat16 | [num_pages, page_size, head_dim_ckv] |
| kpe_cache | bfloat16 | [num_pages, page_size, head_dim_kpe] |
| sparse_indices | int32 | [num_tokens, topk] |
| sm_scale | float32 | scalar |
Outputs
| Name | Type | Shape |
|---|---|---|
| output | bfloat16 | [num_tokens, num_qo_heads, head_dim_ckv] |
| lse | float32 | [num_tokens, num_qo_heads] |
Constraints
- sparse_indices.shape[0] == num_tokens
- sparse_indices.shape[-1] == topk
- ckv_cache.shape[1] == page_size
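
These invariants are cheap to assert before dispatch. A minimal sketch in PyTorch, using the tensor names from the signature above (the helper name `check_dsa_inputs` and the default axis values, taken from the Axes table, are illustrative, not part of the kernel's API):

```python
import torch

def check_dsa_inputs(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices,
                     page_size: int = 64, topk: int = 2048):
    """Assert the shape constraints listed above (a sketch, not the kernel's own checks)."""
    num_tokens = q_nope.shape[0]
    assert sparse_indices.shape[0] == num_tokens
    assert sparse_indices.shape[-1] == topk
    assert ckv_cache.shape[1] == page_size
    # Both caches share the same paged layout [num_pages, page_size, dim].
    assert kpe_cache.shape[:2] == ckv_cache.shape[:2]
    assert sparse_indices.dtype == torch.int32
```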
Reference Implementation
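A minimal PyTorch sketch of the intended semantics follows. It restates the computation as dense gather-then-attend and is not the captured kernel. Two details are assumptions, since the spec above does not fix them: `sparse_indices` is taken to hold flattened slot ids (`page_id * page_size + in_page_offset`) into the paged cache with negative entries as padding, and `lse` is taken to be the natural-log log-sum-exp of the scaled scores.

```python
import torch

def dsa_sparse_attention_ref(q_nope, q_pe, ckv_cache, kpe_cache,
                             sparse_indices, sm_scale):
    """Semantics sketch (non-performant) under the assumptions stated above."""
    # Flatten the paged caches to [num_pages * page_size, dim] so slot ids index rows.
    ckv = ckv_cache.reshape(-1, ckv_cache.shape[-1])   # [P * page_size, head_dim_ckv]
    kpe = kpe_cache.reshape(-1, kpe_cache.shape[-1])   # [P * page_size, head_dim_kpe]

    valid = sparse_indices >= 0                        # [num_tokens, topk], padding mask
    idx = sparse_indices.clamp(min=0).long()
    k_ckv = ckv[idx]                                   # [num_tokens, topk, head_dim_ckv]
    k_pe = kpe[idx]                                    # [num_tokens, topk, head_dim_kpe]

    # MLA "absorbed" attention: the key is (ckv ++ kpe); the value is ckv itself,
    # so the output keeps head_dim_ckv.
    scores = (torch.einsum("thd,tkd->thk", q_nope.float(), k_ckv.float())
              + torch.einsum("thd,tkd->thk", q_pe.float(), k_pe.float())) * sm_scale
    scores = scores.masked_fill(~valid[:, None, :], float("-inf"))

    lse = torch.logsumexp(scores, dim=-1)              # [num_tokens, num_qo_heads]
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("thk,tkd->thd", probs, k_ckv.float())
    return out.to(q_nope.dtype), lse
```

Because each query token attends to at most `topk` (2048) selected slots rather than the full cache, the score tensor stays at [num_tokens, num_qo_heads, topk] regardless of sequence length; that bound is what the sparse kernel exploits.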
