gdn_prefill_qk4_v8_d128_k_last

gdn

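Gated Delta Net (GDN) prefill with a GVA configuration (4 query/key heads, 8 value heads) and a k-last state layout: the state has shape [N, H, V, K], where N = num_seqs, H = num_v_heads, and V = K = head_size, with the key dimension last. Captured from Qwen3 Next linear attention layers (TP=4).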

stage:prefill
status:verified
model:qwen3-next
layout:k-last
fi_api:flashinfer.gdn.chunk_gated_delta_rule
tp:4

Axes

Name            Value
total_seq_len   var
num_seqs        var
num_q_heads     4
num_k_heads     4
num_v_heads     8
head_size       128
len_cu_seqlens  var
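
The head counts and the k-last layout can be made concrete with a small sketch. It assumes that value heads are grouped consecutively onto query/key heads and that the state buffer is contiguous and row-major; neither is stated on this page, and both helper names are hypothetical.

```python
def qk_head_for_v_head(v_head, num_qk_heads=4, num_v_heads=8):
    """Query/key head used by a value head (assumes consecutive grouping)."""
    group = num_v_heads // num_qk_heads  # 2 value heads per query/key head here
    return v_head // group


def k_last_offset(n, h, v, k, H=8, V=128, K=128):
    """Flat offset of state[n, h, v, k] in a contiguous row-major [N, H, V, K] buffer."""
    return ((n * H + h) * V + v) * K + k


print([qk_head_for_v_head(h) for h in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]
# K is the fastest-varying axis: stepping k moves by one element.
print(k_last_offset(0, 0, 0, 1) - k_last_offset(0, 0, 0, 0))  # 1
# Stepping v skips a full row of K elements.
print(k_last_offset(0, 0, 1, 0) - k_last_offset(0, 0, 0, 0))  # 128
```

Under these assumptions, each row of the state holds the key-dimension vector for one value channel in contiguous memory.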

Signature

Inputs

Name        Type      Shape
q           bfloat16  [total_seq_len, num_q_heads, head_size]
k           bfloat16  [total_seq_len, num_k_heads, head_size]
v           bfloat16  [total_seq_len, num_v_heads, head_size]
state       float32   [num_seqs, num_v_heads, head_size, head_size]
A_log       float32   [num_v_heads]
a           bfloat16  [total_seq_len, num_v_heads]
dt_bias     float32   [num_v_heads]
b           bfloat16  [total_seq_len, num_v_heads]
cu_seqlens  int64     [len_cu_seqlens]
scale       float32   Scalar

Outputs

Name       Type      Shape
output     bfloat16  [total_seq_len, num_v_heads, head_size]
new_state  float32   [num_seqs, num_v_heads, head_size, head_size]
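
To show how the inputs and outputs relate, here is a minimal pure-Python sketch of a gated delta rule recurrence that fits this signature. It is an illustration under stated assumptions, not the page's reference implementation and not the FlashInfer kernel. The gating formulas (decay g = exp(-exp(A_log) * softplus(a + dt_bias)) and write strength beta = sigmoid(b)) follow the convention commonly seen in public Qwen3-Next code and are not given on this page. Any q/k normalization is assumed to happen before the call, value heads are assumed to be grouped consecutively onto query/key heads, nested lists stand in for tensors, and all function names are hypothetical.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # Stable form of 1 / (1 + exp(-x)).
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def gdn_step(S, q, k, v, a_t, b_t, A_log_h, dt_bias_h, scale):
    """One token, one value head. S is [V][K] (k-last) and is updated in place."""
    g = math.exp(-math.exp(A_log_h) * softplus(a_t + dt_bias_h))  # assumed decay in (0, 1)
    beta = sigmoid(b_t)                                           # assumed write strength
    K = len(k)
    out = []
    for row, v_i in zip(S, v):
        for j in range(K):
            row[j] *= g                                  # decay the old state
        pred = sum(row[j] * k[j] for j in range(K))      # value currently stored under k
        delta = beta * (v_i - pred)                      # delta-rule correction
        for j in range(K):
            row[j] += delta * k[j]                       # rank-1 state update
        out.append(scale * sum(row[j] * q[j] for j in range(K)))  # read out with q
    return out


def gdn_prefill(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale):
    """Run the recurrence over packed sequences; returns (output, new_state)."""
    num_v_heads, num_qk_heads = len(v[0]), len(q[0])
    group = num_v_heads // num_qk_heads  # assumed consecutive head grouping
    output = [None] * len(q)
    for n in range(len(cu_seqlens) - 1):
        for t in range(cu_seqlens[n], cu_seqlens[n + 1]):
            output[t] = [
                gdn_step(state[n][h], q[t][h // group], k[t][h // group], v[t][h],
                         a[t][h], b[t][h], A_log[h], dt_bias[h], scale)
                for h in range(num_v_heads)
            ]
    return output, state  # state was modified in place


# Tiny example: one sequence of 2 tokens, 1 head, K=2, V=1, zero initial state.
q = [[[1.0, 0.0]], [[1.0, 0.0]]]
k = [[[1.0, 0.0]], [[1.0, 0.0]]]
v = [[[2.0]], [[2.0]]]
state = [[[[0.0, 0.0]]]]
out, new_state = gdn_prefill(q, k, v, state, [0.0], [[0.0], [0.0]], [0.0],
                             [[0.0], [0.0]], [0, 2], 1.0)
print([round(o[0][0], 6) for o in out])  # [1.0, 1.25]
```

In the example, beta is 0.5 and g is 0.5, so the first token writes 1.0 into the state, and the second token decays it to 0.5 before correcting it halfway toward the target value 2.0.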

Constraints

  • len_cu_seqlens == num_seqs + 1
  • total_seq_len == cu_seqlens[-1].item()
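
As a small illustration of these constraints, the sketch below builds cu_seqlens from per-sequence lengths and checks both conditions. It assumes the usual packed-sequence convention, not stated on this page, that cu_seqlens starts at 0 and holds cumulative lengths; the helper names are hypothetical.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Prefix sums with a leading 0: sequence i occupies rows cu[i]:cu[i + 1]."""
    return [0] + list(accumulate(seq_lens))


def satisfies_constraints(cu_seqlens, num_seqs, total_seq_len):
    """Check the two listed constraints on cu_seqlens."""
    return (len(cu_seqlens) == num_seqs + 1
            and cu_seqlens[-1] == total_seq_len)


seq_lens = [5, 3, 7]
cu = build_cu_seqlens(seq_lens)
print(cu)                                   # [0, 5, 8, 15]
print(satisfies_constraints(cu, 3, 15))     # True
print(satisfies_constraints(cu, 3, 16))     # False
```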

Reference Implementation
