Grouped Query Attention (GQA) with paged memory layout. GQA is an attention mechanism that reduces memory usage by grouping multiple query heads to share the same key-value heads; here it is combined with a paged KV cache, which stores keys and values in fixed-size pages addressed through index tables. This supports variable-length sequences and better memory utilization than a contiguous, preallocated KV cache.
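As a quick illustration of the grouping (the head counts below are examples, not values fixed by this spec), each query head h reads key-value head h // group_size:

```python
# Illustrative values only; the spec fixes num_qo_heads and num_kv_heads
# as constants but does not prescribe these particular numbers.
num_qo_heads, num_kv_heads = 32, 8
group_size = num_qo_heads // num_kv_heads   # 4 query heads share each KV head

# Query head h attends against KV head h // group_size.
kv_head_of = [h // group_size for h in range(num_qo_heads)]
# -> [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
```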
Variants:
prefill
Axes (8 dimensions):
total_q, num_pages, len_indptr, num_kv_indices: variable
num_qo_heads, num_kv_heads, head_dim, page_size: constant
Inputs (6 tensors + 1 scalar):
q: query tensor [total_q, num_qo_heads, head_dim]
k_cache, v_cache: paged KV cache [num_pages, page_size, num_kv_heads, head_dim]
qo_indptr [len_indptr], kv_indptr [len_indptr]: per-sequence offsets into q and kv_indices; kv_indices [num_kv_indices]: page IDs into the KV cache
sm_scale: softmax scale (scalar)
Outputs (2 tensors):
output: attention output [total_q, num_qo_heads, head_dim]
lse: log-sum-exp values [total_q, num_qo_heads]
Constraints:
total_q == qo_indptr[-1]
num_kv_indices == kv_indptr[-1]
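The following NumPy sketch (not part of the spec) pins down the index semantics of the prefill variant: qo_indptr slices rows of q per sequence, while kv_indptr slices kv_indices into each sequence's list of page IDs. Two assumptions are made that the definition above leaves open: every listed page is fully occupied (there is no last-page-length input), and no causal or other mask is applied; lse is taken as the natural log-sum-exp of the scaled scores.

```python
import numpy as np

def gqa_paged_prefill_ref(q, k_cache, v_cache, qo_indptr, kv_indptr,
                          kv_indices, sm_scale):
    """Naive reference: loops over sequences and heads, O(n^2) attention."""
    total_q, num_qo_heads, head_dim = q.shape
    _, page_size, num_kv_heads, _ = k_cache.shape
    group_size = num_qo_heads // num_kv_heads
    out = np.zeros_like(q, dtype=np.float32)
    lse = np.zeros((total_q, num_qo_heads), dtype=np.float32)
    for s in range(len(qo_indptr) - 1):          # one iteration per sequence
        q_lo, q_hi = qo_indptr[s], qo_indptr[s + 1]
        pages = kv_indices[kv_indptr[s]:kv_indptr[s + 1]]
        # Gather this sequence's pages into flat [kv_len, num_kv_heads, head_dim]
        # arrays; assumes every listed page is fully occupied (see lead-in).
        k = k_cache[pages].reshape(-1, num_kv_heads, head_dim)
        v = v_cache[pages].reshape(-1, num_kv_heads, head_dim)
        for h in range(num_qo_heads):
            kv_h = h // group_size               # GQA head grouping
            scores = (q[q_lo:q_hi, h] @ k[:, kv_h].T) * sm_scale  # [q_len, kv_len]
            m = scores.max(axis=-1, keepdims=True)
            p = np.exp(scores - m)               # numerically stable softmax
            denom = p.sum(axis=-1, keepdims=True)
            out[q_lo:q_hi, h] = (p / denom) @ v[:, kv_h]
            lse[q_lo:q_hi, h] = (m + np.log(denom))[:, 0]
    return out, lse
```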
decode
Axes (8 dimensions):
total_q, num_pages, len_indptr, num_kv_indices: variable
num_qo_heads, num_kv_heads, head_dim, page_size: constant
Inputs (5 tensors + 1 scalar):
q: query tensor [total_q, num_qo_heads, head_dim]
k_cache, v_cache: paged KV cache [num_pages, page_size, num_kv_heads, head_dim]
kv_indptr [len_indptr], kv_indices [num_kv_indices]: paging indices (per-sequence offsets and page IDs)
sm_scale: softmax scale (scalar)
Outputs (2 tensors):
output: attention output [total_q, num_qo_heads, head_dim]
lse: log-sum-exp values [total_q, num_qo_heads]
Constraints:
len_indptr == num_pages + 1
num_kv_indices == kv_indptr[-1]
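Decode differs only in that qo_indptr is gone: each sequence contributes exactly one query row. The sketch below assumes query row i pairs with the i-th segment of kv_indptr, which the constraints do not state explicitly; the same full-page and no-mask assumptions as in the prefill sketch apply.

```python
import numpy as np

def gqa_paged_decode_ref(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    """Naive reference: one query token per sequence (decode step)."""
    total_q, num_qo_heads, head_dim = q.shape
    _, page_size, num_kv_heads, _ = k_cache.shape
    group_size = num_qo_heads // num_kv_heads
    out = np.zeros_like(q, dtype=np.float32)
    lse = np.zeros((total_q, num_qo_heads), dtype=np.float32)
    for i in range(total_q):                     # assumption: q row i is sequence i
        pages = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
        k = k_cache[pages].reshape(-1, num_kv_heads, head_dim)
        v = v_cache[pages].reshape(-1, num_kv_heads, head_dim)
        for h in range(num_qo_heads):
            kv_h = h // group_size               # GQA head grouping
            scores = (k[:, kv_h] @ q[i, h]) * sm_scale   # [kv_len]
            m = scores.max()
            p = np.exp(scores - m)               # numerically stable softmax
            denom = p.sum()
            out[i, h] = (p / denom) @ v[:, kv_h]
            lse[i, h] = m + np.log(denom)
    return out, lse
```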