FP8 block-scale MoE operation. The op covers expert routing and two grouped GEMMs.
## Axes

| Axis | Size |
|---|---|
| seq_len | var |
| num_experts | 256 |
| num_local_experts | 32 |
| hidden_size | 7168 |
| intermediate_size | 2048 |
| gemm1_out_size | 4096 |
| num_hidden_blocks | 56 |
| num_intermediate_blocks | 16 |
| num_gemm1_out_blocks | 32 |
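The three `*_blocks` axes are the corresponding sizes divided by what appears to be a 128-wide scaling block (7168 / 56 = 2048 / 16 = 4096 / 32 = 128). A quick sanity check, with the 128 block width as an assumption rather than something the spec states:

```python
# Each *_blocks axis equals its size divided by an assumed 128-wide scaling block.
BLOCK = 128
assert 7168 // BLOCK == 56  # hidden_size       -> num_hidden_blocks
assert 2048 // BLOCK == 16  # intermediate_size -> num_intermediate_blocks
assert 4096 // BLOCK == 32  # gemm1_out_size    -> num_gemm1_out_blocks
```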
## Signature

### Inputs
| Name | Type | Shape |
|---|---|---|
| routing_logits | float32 | [seq_len, num_experts] |
| routing_bias | bfloat16 | [num_experts] |
| hidden_states | float8_e4m3fn | [seq_len, hidden_size] |
| hidden_states_scale | float32 | [num_hidden_blocks, seq_len] |
| gemm1_weights | float8_e4m3fn | [num_local_experts, gemm1_out_size, hidden_size] |
| gemm1_weights_scale | float32 | [num_local_experts, num_gemm1_out_blocks, num_hidden_blocks] |
| gemm2_weights | float8_e4m3fn | [num_local_experts, hidden_size, intermediate_size] |
| gemm2_weights_scale | float32 | [num_local_experts, num_hidden_blocks, num_intermediate_blocks] |
| local_expert_offset | int32 | scalar |
| routed_scaling_factor | float32 | scalar |
### Outputs
| Name | Type | Shape |
|---|---|---|
| output | bfloat16 | [seq_len, hidden_size] |
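The scale shapes imply per-block dequantization. Assuming the 128-wide scaling blocks noted above, a gemm1 weight entry for local expert $e$ would dequantize as

$$
\hat{W}_1^{(e)}[i, j] = W_1^{(e)}[i, j] \cdot S_1^{(e)}\!\left[\left\lfloor \tfrac{i}{128} \right\rfloor, \left\lfloor \tfrac{j}{128} \right\rfloor\right],
$$

while hidden_states dequantize per (hidden-block, token), matching the transposed `[num_hidden_blocks, seq_len]` scale layout:

$$
\hat{x}[t, j] = x[t, j] \cdot s\!\left[\left\lfloor \tfrac{j}{128} \right\rfloor, t\right].
$$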
## Reference Implementation
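The reference implementation lives in the interactive editor and is not reproduced here. As an illustration only, below is a minimal PyTorch sketch of the op's semantics under explicit assumptions that the spec does not confirm: 128-wide scaling blocks, `TOP_K = 8` experts per token, sigmoid-plus-bias routing scores, and a SwiGLU activation between the two GEMMs (consistent with `gemm1_out_size == 2 * intermediate_size`). Each assumption is also marked in the code.

```python
# Minimal PyTorch sketch of the FP8 block-scale MoE op. Illustrative only;
# routing details and the activation are assumptions, not spec-confirmed.
import torch

BLOCK = 128  # ASSUMPTION: scale block width (7168 / 56 == 2048 / 16 == 128)
TOP_K = 8    # ASSUMPTION: experts activated per token; not given in the spec


def dequant_blockwise(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply per-[BLOCK x BLOCK] scales to an fp8 weight.

    w_fp8: [E, M, N] float8_e4m3fn; scale: [E, M // BLOCK, N // BLOCK] float32.
    """
    s = scale.repeat_interleave(BLOCK, dim=1).repeat_interleave(BLOCK, dim=2)
    return w_fp8.to(torch.float32) * s


def fp8_block_scale_moe(
    routing_logits, routing_bias, hidden_states, hidden_states_scale,
    gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale,
    local_expert_offset: int, routed_scaling_factor: float,
) -> torch.Tensor:
    seq_len, hidden_size = hidden_states.shape
    num_local_experts = gemm1_weights.shape[0]

    # Routing. ASSUMPTION: sigmoid scores plus bias, plain top-k, weights
    # renormalized and scaled by routed_scaling_factor.
    scores = torch.sigmoid(routing_logits.float()) + routing_bias.float()
    topk_scores, topk_ids = scores.topk(TOP_K, dim=-1)
    weights = topk_scores / topk_scores.sum(-1, keepdim=True) * routed_scaling_factor

    # Dequantize activations; hidden_states_scale is [num_hidden_blocks, seq_len].
    x_scale = hidden_states_scale.repeat_interleave(BLOCK, dim=0).T  # [seq, hidden]
    x = hidden_states.to(torch.float32) * x_scale

    w1 = dequant_blockwise(gemm1_weights, gemm1_weights_scale)
    w2 = dequant_blockwise(gemm2_weights, gemm2_weights_scale)

    out = torch.zeros(seq_len, hidden_size, dtype=torch.float32,
                      device=hidden_states.device)
    for e in range(num_local_experts):
        # Tokens routed to this local expert (global id = local id + offset).
        token, slot = torch.where(topk_ids == e + local_expert_offset)
        if token.numel() == 0:
            continue
        h1 = x[token] @ w1[e].T                       # grouped GEMM 1
        gate, up = h1.chunk(2, dim=-1)                # ASSUMPTION: SwiGLU split
        h2 = (torch.nn.functional.silu(gate) * up) @ w2[e].T  # grouped GEMM 2
        out.index_add_(0, token, h2 * weights[token, slot, None])
    return out.to(torch.bfloat16)
```

The sketch dequantizes everything to float32 and loops over experts for clarity; a real kernel would instead fold the per-block scales into the FP8 grouped GEMMs and batch the per-expert work.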
