- prefill (chunked computation for variable-length sequences)
- decode (single-token generation with recurrent state update)
prefill
Axes (6 dimensions):
- total_seq_len, num_seqs: variable
- num_q_heads, num_k_heads, num_v_heads, head_size: constant
Inputs:
- q: query tensor [total_seq_len, num_q_heads, head_size]
- k: key tensor [total_seq_len, num_k_heads, head_size]
- v: value tensor [total_seq_len, num_v_heads, head_size]
- g: forget gate (alpha) [total_seq_len, num_sab_heads], float32, optional (defaults to ones)
- beta: update gate [total_seq_len, num_sab_heads], float32, optional (defaults to ones)
- cu_seqlens: cumulative sequence lengths [num_seqs + 1], int64
- initial_state: initial KV state [num_seqs, num_sab_heads, head_size, head_size], float32, optional
- scale: scale factor (scalar), optional (defaults to 1/sqrt(head_size))
Outputs:
- output: attention output [total_seq_len, num_o_heads, head_size]
- final_state: final KV state [num_seqs, num_sab_heads, head_size, head_size], float32
Derived head counts:
- num_sab_heads = max(num_q_heads, num_v_heads) (heads of the state, g, and beta)
- num_o_heads = max(num_q_heads, num_v_heads) (output heads)
Constraints:
- total_seq_len == cu_seqlens[-1]
- num_seqs == len(cu_seqlens) - 1
- For GQA: num_q_heads >= num_k_heads and num_q_heads % num_k_heads == 0
- For GVA: num_v_heads >= num_q_heads and num_v_heads % num_q_heads == 0
- num_k_heads == num_v_heads (keys and values must have the same number of heads)
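The shape rules above can be checked on the host before launching a prefill kernel. The following is a minimal sketch in plain Python; the `PrefillShapes` class and `prefill_derived_shapes` function are hypothetical names, not part of any library API. It checks the `cu_seqlens` constraints and the GQA/GVA divisibility rules, and derives `num_sab_heads` and `num_o_heads` as defined above. The requirement that `cu_seqlens` starts at 0 and never decreases is an assumed convention that the spec does not state.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PrefillShapes:
    num_seqs: int
    seq_lens: List[int]
    num_sab_heads: int
    num_o_heads: int


def prefill_derived_shapes(
    cu_seqlens: List[int],
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
) -> PrefillShapes:
    """Hypothetical helper: check the prefill shape rules and derive head counts."""
    # Assumed convention (not stated in the spec): offsets start at 0 and never decrease.
    if not cu_seqlens or cu_seqlens[0] != 0:
        raise ValueError("cu_seqlens must start at 0")
    if any(nxt < cur for cur, nxt in zip(cu_seqlens, cu_seqlens[1:])):
        raise ValueError("cu_seqlens must be non-decreasing")
    # Spec constraint: total_seq_len == cu_seqlens[-1].
    if total_seq_len != cu_seqlens[-1]:
        raise ValueError("total_seq_len must equal cu_seqlens[-1]")
    # GQA: query heads are grouped over key heads.
    if num_q_heads >= num_k_heads and num_q_heads % num_k_heads != 0:
        raise ValueError("GQA requires num_q_heads % num_k_heads == 0")
    # GVA: value heads are grouped over query heads.
    if num_v_heads >= num_q_heads and num_v_heads % num_q_heads != 0:
        raise ValueError("GVA requires num_v_heads % num_q_heads == 0")
    return PrefillShapes(
        # Spec constraint: num_seqs == len(cu_seqlens) - 1.
        num_seqs=len(cu_seqlens) - 1,
        # Length of sequence i is cu_seqlens[i + 1] - cu_seqlens[i].
        seq_lens=[nxt - cur for cur, nxt in zip(cu_seqlens, cu_seqlens[1:])],
        num_sab_heads=max(num_q_heads, num_v_heads),
        num_o_heads=max(num_q_heads, num_v_heads),
    )
```

For example, `cu_seqlens = [0, 3, 7]` packs two sequences of lengths 3 and 4 into one batch, so `total_seq_len` must be 7; with 8 query heads and 4 key/value heads, `num_sab_heads` and `num_o_heads` are both 8.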
Notes:
- The final state is in k-last layout [N, H, V, K], where N = num_seqs, H = num_sab_heads, and V = K = head_size
- Gate tensors (g, beta) are in float32 for numerical stability
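To make the roles of `g`, `beta`, `scale`, `cu_seqlens`, and the k-last state concrete, here is a naive token-by-token reference in plain Python. It assumes the standard gated delta rule recurrence, which this spec does not spell out: S_t = alpha_t * S_{t-1} + beta_t * (v_t - alpha_t * S_{t-1} k_t) k_t^T and o_t = S_t (scale * q_t), with S stored as [V][K]. It also assumes equal numbers of q, k, and v heads (no GQA/GVA grouping) and uses Python floats instead of float32. A real prefill kernel computes the same result in chunks; `naive_gdn_prefill` is a hypothetical name.

```python
import math
from typing import List, Optional, Tuple

Vec = List[float]
Mat = List[List[float]]  # k-last state for one head: Mat[v_idx][k_idx]


def naive_gdn_prefill(
    q: List[List[Vec]],  # [total_seq_len][num_heads][head_size]
    k: List[List[Vec]],
    v: List[List[Vec]],
    cu_seqlens: List[int],
    g: Optional[List[List[float]]] = None,  # alpha, [total_seq_len][num_heads]
    beta: Optional[List[List[float]]] = None,  # [total_seq_len][num_heads]
    initial_state: Optional[List[List[Mat]]] = None,  # [num_seqs][num_heads][V][K]
    scale: Optional[float] = None,
) -> Tuple[List[List[Vec]], List[List[Mat]]]:
    """Hypothetical reference: recurrent gated delta rule over packed sequences."""
    total, num_heads, d = len(q), len(q[0]), len(q[0][0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    out = [[[0.0] * d for _ in range(num_heads)] for _ in range(total)]
    final_states: List[List[Mat]] = []
    for n in range(len(cu_seqlens) - 1):
        states: List[Mat] = []
        for h in range(num_heads):
            # Each sequence starts from its own initial state (zeros if none given).
            S = ([row[:] for row in initial_state[n][h]] if initial_state is not None
                 else [[0.0] * d for _ in range(d)])
            for t in range(cu_seqlens[n], cu_seqlens[n + 1]):
                alpha = 1.0 if g is None else g[t][h]  # spec default: ones
                b = 1.0 if beta is None else beta[t][h]  # spec default: ones
                kt, vt, qt = k[t][h], v[t][h], q[t][h]
                for i in range(d):  # row i of the k-last state is value channel i
                    row = [alpha * x for x in S[i]]  # decay
                    pred = sum(row[j] * kt[j] for j in range(d))  # current prediction for v_t[i]
                    u = b * (vt[i] - pred)  # delta-rule correction
                    S[i] = [row[j] + u * kt[j] for j in range(d)]
                    out[t][h][i] = scale * sum(S[i][j] * qt[j] for j in range(d))
            states.append(S)
        final_states.append(states)
    return out, final_states
```

Because each row of a k-last state depends only on its own value channel, the update can be written row by row, which is also why the returned `final_states` already have the [N, H, V, K] ordering.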
decode
Single-token decoding with recurrent state update. The kernel uses the gating parameters (A_log, a, dt_bias, b) to compute the decay and update gates, and can optionally apply L2 normalization to q and k for numerical stability.

Axes (6 dimensions):
- batch_size: variable (number of sequences being decoded)
- num_q_heads, num_k_heads, num_v_heads, head_size: constant
Inputs:
- q: query tensor [batch_size, 1, num_q_heads, head_size], bfloat16
- k: key tensor [batch_size, 1, num_k_heads, head_size], bfloat16
- v: value tensor [batch_size, 1, num_v_heads, head_size], bfloat16
- state: recurrent state [batch_size, num_sab_heads, head_size, head_size], float32
- A_log: log decay parameter [num_sab_heads], float32
- a: input-dependent decay [batch_size, 1, num_sab_heads], bfloat16
- dt_bias: decay bias [num_sab_heads], bfloat16
- b: update gate input [batch_size, 1, num_sab_heads], bfloat16
- scale: scale factor (scalar), float32 (default: 1.0 or 1/sqrt(head_size))
- use_qk_l2norm: whether to apply L2 normalization to q and k, bool (default: true)
Outputs:
- output: attention output [batch_size, 1, num_o_heads, head_size], bfloat16
- new_state: updated recurrent state [batch_size, num_sab_heads, head_size, head_size], float32
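The spec says A_log, a, dt_bias, and b are turned into decay and update gates but does not give the formula. A common formulation in Gated DeltaNet layers (for example the one used in Qwen3-Next) is alpha = exp(-exp(A_log) * softplus(a + dt_bias)) and beta = sigmoid(b). The sketch below assumes exactly that formula and should be checked against the actual kernel; `decode_gates` is a hypothetical name, and the sketch computes in Python floats rather than the bfloat16/float32 mix listed above.

```python
import math
from typing import List, Tuple


def softplus(x: float) -> float:
    # log(1 + e^x); for large x the result is x to double precision, avoiding overflow.
    return x if x > 20.0 else math.log1p(math.exp(x))


def sigmoid(x: float) -> float:
    # Branch on the sign so math.exp never overflows.
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_gates(
    A_log: List[float],    # [num_sab_heads]
    a: List[float],        # [num_sab_heads], for one decoded token of one sequence
    dt_bias: List[float],  # [num_sab_heads]
    b: List[float],        # [num_sab_heads]
) -> Tuple[List[float], List[float]]:
    """Assumed gating: alpha = exp(-exp(A_log) * softplus(a + dt_bias)), beta = sigmoid(b)."""
    alpha, beta = [], []
    for h in range(len(A_log)):
        # The log-decay is negative, so alpha lies between 0 and 1.
        log_decay = -math.exp(A_log[h]) * softplus(a[h] + dt_bias[h])
        alpha.append(math.exp(log_decay))
        beta.append(sigmoid(b[h]))
    return alpha, beta
```

Under this formulation, A_log sets a per-head decay rate, `a + dt_bias` acts as an input-dependent step size, and larger steps give stronger forgetting (alpha closer to 0).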
Derived head counts:
- num_sab_heads = max(num_q_heads, num_v_heads) (heads of the state and of the gating inputs A_log, a, dt_bias, b)
- num_o_heads = num_v_heads (output heads follow value heads)
Constraints:
- For GVA: num_v_heads >= num_q_heads and num_v_heads % num_q_heads == 0
- num_k_heads == num_q_heads (keys and queries must have the same number of heads)
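Under GVA, several value (and therefore state and output) heads share one q/k head. The spec does not say which q/k head each value head reads; the sketch below assumes the `repeat_interleave` convention, where consecutive value heads share a q/k head. `gva_qk_head_for_value_head` is a hypothetical name.

```python
from typing import List


def gva_qk_head_for_value_head(num_q_heads: int, num_v_heads: int) -> List[int]:
    """Hypothetical helper: q/k head index read by each value head under GVA."""
    if num_v_heads < num_q_heads or num_v_heads % num_q_heads != 0:
        raise ValueError("GVA requires num_v_heads >= num_q_heads and num_v_heads % num_q_heads == 0")
    group = num_v_heads // num_q_heads
    # Assumed repeat_interleave layout: value heads [0, group) read q/k head 0, and so on.
    return [h // group for h in range(num_v_heads)]
```

With 2 query/key heads and 4 value heads this gives `[0, 0, 1, 1]`; with equal counts it reduces to the identity mapping.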
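State layouts (B = batch_size, H = num_sab_heads, V = K = head_size):
- k-last: [B, H, V, K], with the V dimension before the K dimension; faster for decode
- k-first: [B, H, K, V], with the K dimension before the V dimension

Notes:
- L2 normalization of q and k helps with numerical stability when head_size is large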
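A single decode step for one (sequence, head) pair can be sketched as follows, with the state in k-last layout. This assumes the standard gated delta rule update (decay the state by alpha, correct it toward v by beta along k, then read out with the scaled query) and an L2 normalization that divides by max(norm, eps); neither the update nor the epsilon handling is specified above, and `decode_step_k_last`, `l2_normalize`, and `k_last_to_k_first` are hypothetical names. Applying it to a full batch is a loop over `batch_size` and `num_sab_heads`, with q and k taken from the q/k head that serves each value head.

```python
import math
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]


def l2_normalize(x: Vec, eps: float = 1e-6) -> Vec:
    # Assumed epsilon handling: divide by max(norm, eps) to avoid division by zero.
    norm = math.sqrt(sum(t * t for t in x))
    return [t / max(norm, eps) for t in x]


def decode_step_k_last(
    state: Mat,  # k-last [V][K] for one (sequence, head); not modified in place
    q: Vec, k: Vec, v: Vec,
    alpha: float, beta: float, scale: float,
    use_qk_l2norm: bool = True,
) -> Tuple[Vec, Mat]:
    if use_qk_l2norm:
        q, k = l2_normalize(q), l2_normalize(k)
    new_state: Mat = []
    out: Vec = []
    for row, v_i in zip(state, v):  # each k-last row is one value channel, contiguous over K
        decayed = [alpha * s for s in row]
        pred = sum(s * k_j for s, k_j in zip(decayed, k))
        u = beta * (v_i - pred)  # delta-rule correction for this value channel
        new_row = [s + u * k_j for s, k_j in zip(decayed, k)]
        new_state.append(new_row)
        out.append(scale * sum(s * q_j for s, q_j in zip(new_row, q)))
    return out, new_state


def k_last_to_k_first(state: Mat) -> Mat:
    # [V][K] -> [K][V]: transpose the last two axes.
    return [list(col) for col in zip(*state)]
```

In this loop both reductions (against k and against q) run along a contiguous k-last row, which is one plausible reason the k-last layout suits decode.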

