Overview
This document describes the JSON schema for a kernel Definition. The `Definition` provides a formal, machine-readable specification for a computational workload found in a model’s forward pass. It is designed to be the single source of truth that guides both human and agent-based kernel development. Specifically, this schema defines:
- Tensor Formats: The shape and data type (`dtype`) of each input and output tensor.
- Dimension Semantics: The distinction between constant (`const`) dimensions, fixed at compile time, and variable (`var`) dimensions, determined at runtime.
- Computational Logic: A clear, step-by-step reference implementation in plain PyTorch, which serves as the official mathematical specification of the kernel.
A `Definition` does not contain specific input data for its variable axes. That data is provided by the `workload` field of each `Trace`, which is used for benchmarking `Solution`s.
JSON Schema Description
Top-Level Object Structure
Field | Type | Required | Description |
---|---|---|---|
name | string | Yes | A unique, human-readable name for the kernel; it should include concrete problem information. Naming convention: `{op_type}_{props}_{constants}` (e.g., `gqa_paged_decode_h32_kv8_d128_ps1`). |
op_type | string | Yes | The general compute category. |
tags | array[string] (Default: `null`) | No | String tags associated with this definition, used for grouping and filtering. |
description | string (Default: `null`) | No | A brief, human-readable description of the definition and its purpose. |
axes | Dict[string, Union[AxisConst, AxisVar]] | Yes | An object mapping symbolic dimension names (e.g., `"M"`, `"N"`, `"K"`) to their definitions. Each value is either a constant or a variable axis. The axes are bound to the input tensor dimensions at runtime. |
inputs | Dict[string, TensorSpec] | Yes | Named input tensors (e.g., `"A"`, `"B"`). |
outputs | Dict[string, TensorSpec] | Yes | Named output tensors (e.g., `"C"`). |
reference | string | Yes | The reference implementation in PyTorch, serving as the mathematical specification. |
constraints | array[string] (Default: `null`) | No | An optional list of assertions describing relationships between axes. |
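To tie the fields together, the following is a minimal sketch of a complete definition for a GEMM kernel, built as a Python dict and serialized with the standard `json` module. The name `gemm_n4096_k4096`, the chosen axes, the `status:draft` tag, and the argument order of `run` are illustrative assumptions, not an entry from an existing definition set.

```python
import json

# Hypothetical GEMM definition: C[M, N] = A[M, K] @ B[K, N].
# The name, tags, and run() signature are illustrative assumptions.
REFERENCE = """import torch

def run(A, B):
    # Accumulate in float32, then cast back to the input dtype.
    C = torch.matmul(A.to(torch.float32), B.to(torch.float32))
    return C.to(A.dtype)
"""

definition = {
    "name": "gemm_n4096_k4096",  # {op_type}_{props}_{constants}
    "op_type": "gemm",
    "tags": ["status:draft"],
    "description": "Dense GEMM with a variable number of rows.",
    "axes": {
        "M": {"type": "var", "description": "Number of rows of A and C."},
        "N": {"type": "const", "value": 4096},
        "K": {"type": "const", "value": 4096},
    },
    "inputs": {
        "A": {"shape": ["M", "K"], "dtype": "bfloat16"},
        "B": {"shape": ["K", "N"], "dtype": "bfloat16"},
    },
    "outputs": {
        "C": {"shape": ["M", "N"], "dtype": "bfloat16"},
    },
    "reference": REFERENCE,
}

# Fields the top-level table marks as required.
REQUIRED_FIELDS = ("name", "op_type", "axes", "inputs", "outputs", "reference")
missing = [field for field in REQUIRED_FIELDS if field not in definition]
assert not missing, f"missing required fields: {missing}"

print(json.dumps(definition, indent=2))
```

Because every value is a plain string, number, list, dict, or `None`, the dict maps one-to-one onto the JSON document (with `None` becoming `null`).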
`op_type`: Compute Category
`op_type` is a string field used for grouping and filtering kernels. It represents the general compute characteristic of the kernel.
Currently supported `op_type`s are:
- Attention: `gqa_ragged`, `gqa_paged`, `mla_ragged`, `mla_paged`
- GEMM: `gemm`
- Misc: `rmsnorm`, `fused_add_rmsnorm`
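A tool that groups or validates definitions might encode the list above as a lookup table. The helper `category_of` below is a hypothetical sketch, not part of the schema; it simply mirrors the supported values and categories listed here.

```python
# Supported op_type values, grouped by the categories listed above.
SUPPORTED_OP_TYPES = {
    "Attention": ("gqa_ragged", "gqa_paged", "mla_ragged", "mla_paged"),
    "GEMM": ("gemm",),
    "Misc": ("rmsnorm", "fused_add_rmsnorm"),
}


def category_of(op_type: str) -> str:
    """Return the category of a supported op_type, or raise ValueError."""
    for category, op_types in SUPPORTED_OP_TYPES.items():
        if op_type in op_types:
            return category
    raise ValueError(f"unsupported op_type: {op_type!r}")


print(category_of("gqa_paged"))  # Attention
```

Keeping the table in one place means adding a new `op_type` is a single-line change.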
`tags`: Additional Attributes
`tags` is an array of strings that attaches searchable attributes to a definition. Tags use namespaced keys to keep meanings clear and filterable.
Each tag is either:
- a namespaced key–value string, `"<namespace>:<value>"`, or
- a flag without a value (e.g., `"fused"`).
Common tags include:
- `stage:*`: The computation stage this definition applies to. Examples: `stage:prefill`, `stage:decode`.
- `model:*`: Models known to use this definition (ideally derived by the system from references or traces). Examples: `model:llama-3.1-8b`, `model:deepseek-v3`.
- `quantization:*`: Quantization characteristics. In the simple case, encode the effective dtype. Examples: `quantization:float8_e4m3fn`, `quantization:int8`.
- `status:*`: Community or validation status. Examples: `status:verified`, `status:draft`, `status:deprecated`.
- `fused`: Flag tag indicating that the definition represents a fused kernel.
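Because a namespaced tag is just `"<namespace>:<value>"`, filtering reduces to splitting on the first colon. The helpers `parse_tag`, `has_tag`, and `has_flag` below are a hypothetical sketch of such filtering; they also treat a `null` tags field as an empty list, which is an assumption about how consumers handle the default.

```python
from typing import Iterable, Optional, Tuple


def parse_tag(tag: str) -> Tuple[Optional[str], str]:
    """Split a tag into (namespace, value); a flag returns (None, tag)."""
    namespace, sep, value = tag.partition(":")  # split on the first colon only
    if not sep:
        return None, tag  # flag tag such as "fused"
    return namespace, value  # namespaced tag such as "stage:decode"


def has_tag(tags: Optional[Iterable[str]], namespace: str, value: str = "*") -> bool:
    """True if any tag is in `namespace`; value="*" matches any value."""
    for tag in tags or []:  # tags defaults to null (None)
        ns, val = parse_tag(tag)
        if ns == namespace and value in ("*", val):
            return True
    return False


def has_flag(tags: Optional[Iterable[str]], flag: str) -> bool:
    """True if the value-less flag (e.g. "fused") is present."""
    return flag in (tags or [])


tags = ["stage:decode", "model:llama-3.1-8b", "status:verified", "fused"]
print(has_tag(tags, "stage", "decode"), has_flag(tags, "fused"))  # True True
```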
`axes`: Dimension Definitions
The `axes` object contains any number of keys, where each key is a symbolic dimension name (e.g., `"M"`, `"N"`, `"K"`) and the value is an object describing its type.
`type`: `const`
Represents a constant dimension.
Field | Type | Required | Description |
---|---|---|---|
type | string | Yes | Must be "const" |
value | integer | Yes | Constant value of the axis |
description | string | No | Brief description. |
`type`: `var`
Represents a variable axis whose value is determined by the input data at runtime.
Field | Type | Required | Description |
---|---|---|---|
type | string | Yes | Must be "var" |
description | string | No | Brief description. |
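At runtime, each axis name has to be resolved to a concrete size by matching the declared shapes against the actual input tensors. The function `bind_axes` below is a hypothetical sketch of one plausible binding rule, assuming runtime shapes are supplied as plain integer sequences (for example `tuple(tensor.shape)`): `const` axes must match their declared `value`, and a `var` axis takes the first size observed and must agree everywhere else it appears.

```python
from typing import Dict, Sequence


def bind_axes(axes: Dict[str, dict],
              inputs: Dict[str, dict],
              shapes: Dict[str, Sequence[int]]) -> Dict[str, int]:
    """Map each axis name to a concrete size using the runtime input shapes."""
    # Const axes are known up front from their declared value.
    bound = {name: ax["value"] for name, ax in axes.items() if ax["type"] == "const"}
    for tensor_name, spec in inputs.items():
        axis_names = spec["shape"]
        if axis_names is None:  # scalar input: no axes to bind
            continue
        actual = tuple(shapes[tensor_name])
        if len(actual) != len(axis_names):
            raise ValueError(f"{tensor_name}: expected rank {len(axis_names)}, got {len(actual)}")
        for axis, size in zip(axis_names, actual):
            if axis not in axes:
                raise KeyError(f"{tensor_name}: unknown axis {axis!r}")
            # First sighting binds a var axis; later sightings must agree.
            expected = bound.setdefault(axis, size)
            if expected != size:
                raise ValueError(f"{tensor_name}: axis {axis!r} is {size}, expected {expected}")
    return bound


# Example: a GEMM with a variable M and constant N, K.
axes = {
    "M": {"type": "var", "description": "Number of rows of A."},
    "N": {"type": "const", "value": 4096},
    "K": {"type": "const", "value": 4096},
}
inputs = {
    "A": {"shape": ["M", "K"], "dtype": "bfloat16"},
    "B": {"shape": ["K", "N"], "dtype": "bfloat16"},
}
print(bind_axes(axes, inputs, {"A": (16, 4096), "B": (4096, 4096)}))
# {'N': 4096, 'K': 4096, 'M': 16}
```

A mismatch, such as passing an `A` of shape `(16, 1024)`, raises `ValueError` because `K` is declared as the constant `4096`.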
`inputs`, `outputs`: Tensor Definitions
These fields describe the input and output tensors of the kernel. They contain any number of key-value pairs, where each key is the name of a tensor (e.g., `"A"`, `"B"`, `"C"`) and the value is a tensor description:
Field | Type | Required | Description |
---|---|---|---|
shape | array or null | Yes | List of axis names (strings). `null` represents a scalar; an empty array `[]` represents a 0-D tensor. |
dtype | string | Yes | Data type of the tensor. |
description | string | No | Brief description. |
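A loader can catch malformed tensor descriptions early by checking each one against the declared axes. The function `check_tensor_spec` below is a hypothetical sketch that verifies the two required fields are present, that `shape` is either `null` or a list of known axis names, and that `dtype` is a string; checking the `dtype` value itself against the allowed list is left out here.

```python
from typing import Dict, List


def check_tensor_spec(name: str, spec: dict, axes: Dict[str, dict]) -> List[str]:
    """Return the problems found in one TensorSpec (an empty list if none)."""
    problems = []
    for field in ("shape", "dtype"):
        if field not in spec:
            problems.append(f"{name}: missing required field {field!r}")
    shape = spec.get("shape")
    if shape is not None:  # None (JSON null) declares a scalar
        if not isinstance(shape, list):
            problems.append(f"{name}: 'shape' must be a list of axis names or null")
        else:
            for axis in shape:  # [] is valid: a 0-D tensor has no axes
                if not isinstance(axis, str) or axis not in axes:
                    problems.append(f"{name}: unknown axis {axis!r}")
    if "dtype" in spec and not isinstance(spec["dtype"], str):
        problems.append(f"{name}: 'dtype' must be a string")
    return problems


axes = {"M": {"type": "var"}, "K": {"type": "const", "value": 4096}}
print(check_tensor_spec("A", {"shape": ["M", "K"], "dtype": "bfloat16"}, axes))  # []
print(check_tensor_spec("C", {"shape": ["M", "N"]}, axes))
```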
`dtype`: Data Types
The following values are allowed for `dtype`:
- `float32`
- `float16`
- `bfloat16`
- `float8_e4m3fn`
- `float8_e5m2`
- `float4_e2m1`
- `int64`
- `int32`
- `int16`
- `int8`
- `bool`
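Benchmarks and kernels often need to size buffers from these dtypes once the axes are bound. The sketch below records the bit width of each allowed value and computes a tensor's storage size; the function name `tensor_nbytes` is hypothetical, and two entries are storage-layout assumptions this schema does not specify: `bool` is taken to occupy one byte (as in PyTorch), and `float4_e2m1` values are assumed to be packed two per byte.

```python
import math
from typing import Dict, List, Optional

# Storage bits per element for each allowed dtype.
# Assumptions: bool uses one byte; float4_e2m1 is packed two per byte.
DTYPE_BITS = {
    "float32": 32, "float16": 16, "bfloat16": 16,
    "float8_e4m3fn": 8, "float8_e5m2": 8, "float4_e2m1": 4,
    "int64": 64, "int32": 32, "int16": 16, "int8": 8, "bool": 8,
}
ALLOWED_DTYPES = frozenset(DTYPE_BITS)


def tensor_nbytes(shape: Optional[List[str]], dtype: str, sizes: Dict[str, int]) -> int:
    """Bytes needed to store a tensor once every axis in `shape` is bound.

    Returns 0 for a scalar (shape None), which is passed as a Python value
    rather than stored in a tensor buffer.
    """
    if dtype not in ALLOWED_DTYPES:
        raise ValueError(f"unsupported dtype: {dtype!r}")
    if shape is None:
        return 0
    numel = math.prod(sizes[axis] for axis in shape)  # [] gives 1 (0-D tensor)
    return (numel * DTYPE_BITS[dtype] + 7) // 8  # round up to whole bytes


print(tensor_nbytes(["M", "K"], "bfloat16", {"M": 16, "K": 4096}))  # 131072
```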
Scalar Values and 0-D Tensors
A tensor with shape `[]` (an empty array) represents a 0-D tensor.
To represent a scalar value, use shape `null`. A scalar input must receive Python scalar data (`int`, `float`, or `bool`), and a scalar output returns a Python scalar value.
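For example, the sketch below declares a 1-D tensor, a 0-D tensor, and a scalar, and prints how each `shape` serializes to JSON. The spec names (`x`, `scale`, `eps`) are hypothetical, and `accepts_scalar` encodes an assumed matching rule between Python scalars and dtypes (for instance, accepting an `int` for a floating-point scalar); the schema itself only requires a Python `int`, `float`, or `bool`.

```python
import json

# Hypothetical specs: a 1-D tensor, a 0-D tensor, and a Python scalar.
inputs = {
    "x": {"shape": ["N"], "dtype": "float32"},
    "scale": {"shape": [], "dtype": "float32"},
    "eps": {"shape": None, "dtype": "float32"},
}


def describe(spec: dict) -> str:
    """Classify a TensorSpec by its shape field."""
    shape = spec["shape"]
    if shape is None:
        return "Python scalar"
    return f"{len(shape)}-D tensor"  # [] gives "0-D tensor"


def accepts_scalar(value, dtype: str) -> bool:
    """Assumed rule for matching a Python scalar to a scalar spec's dtype."""
    if isinstance(value, bool):  # check bool first: bool subclasses int
        return dtype == "bool"
    if dtype.startswith("int"):
        return isinstance(value, int)
    if dtype.startswith(("float", "bfloat")):
        return isinstance(value, (int, float))
    return False


for name, spec in inputs.items():
    print(name, describe(spec), json.dumps(spec["shape"]))
# x 1-D tensor ["N"]
# scale 0-D tensor []
# eps Python scalar null
```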
`reference`: Reference Implementation
The `reference` field is a string that contains the reference implementation of the kernel in plain PyTorch.
- It must contain a global function named `run` as the entry point.
- This code defines the official mathematical specification of the kernel.
- It should avoid high-level abstractions (e.g., `torch.nn.functional`) in favor of explicit, step-by-step computations, to ensure maximum clarity for all consumers (human or agent).
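The sketch below shows what such a reference string might look like for an RMSNorm kernel, written step by step instead of calling `torch.nn.functional`, together with a stdlib `ast` check that the string defines a module-level `run`. The argument order of `run` and the choice to return the output tensor are assumptions for illustration; the reference is only parsed here, never executed, so the check runs without PyTorch installed.

```python
import ast

# Hypothetical RMSNorm reference kept as a string, as the schema requires.
REFERENCE = """
import torch

def run(hidden_states, weight, eps):
    # Accumulate in float32 for a numerically stable mean of squares.
    x = hidden_states.to(torch.float32)
    mean_sq = (x * x).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    out = x * inv_rms * weight.to(torch.float32)
    return out.to(hidden_states.dtype)
"""


def has_global_run(source: str) -> bool:
    """True if `source` defines a module-level (global) function named `run`."""
    tree = ast.parse(source)
    # Only top-level statements count; methods and nested functions do not.
    return any(isinstance(node, ast.FunctionDef) and node.name == "run"
               for node in tree.body)


print(has_global_run(REFERENCE))  # True
```

Parsing with `ast` rather than executing the string keeps validation cheap and side-effect free; running the reference against real inputs is a separate step.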