Definition

pydantic model flashinfer_bench.data.Definition

Complete definition of a computational workload.

A Definition provides a formal, machine-readable specification for a computational workload. It defines the tensor formats, dimension semantics, and computational logic through a reference implementation. This serves as the single source of truth for kernel development and optimization.

JSON schema:
{
   "title": "Definition",
   "description": "Complete definition of a computational workload.\n\nA Definition provides a formal, machine-readable specification for a computational\nworkload. It defines the tensor formats, dimension semantics, and computational\nlogic through a reference implementation. This serves as the single source of\ntruth for kernel development and optimization.",
   "type": "object",
   "properties": {
      "name": {
         "description": "A unique, human-readable name for the kernel definition.",
         "minLength": 1,
         "title": "Name",
         "type": "string"
      },
      "op_type": {
         "description": "The general compute category (e.g., 'gemm', 'gqa_ragged', 'mla_paged', 'moe').",
         "minLength": 1,
         "title": "Op Type",
         "type": "string"
      },
      "axes": {
         "additionalProperties": {
            "anyOf": [
               {
                  "$ref": "#/$defs/AxisConst"
               },
               {
                  "$ref": "#/$defs/AxisVar"
               }
            ]
         },
         "description": "Dictionary of symbolic dimensions used in tensor shapes. The axes will be bound to the\ninput tensor dimensions at runtime.",
         "propertyNames": {
            "minLength": 1
         },
         "title": "Axes",
         "type": "object"
      },
      "inputs": {
         "additionalProperties": {
            "$ref": "#/$defs/TensorSpec"
         },
         "description": "Named input tensors required by this kernel.",
         "propertyNames": {
            "minLength": 1
         },
         "title": "Inputs",
         "type": "object"
      },
      "outputs": {
         "additionalProperties": {
            "$ref": "#/$defs/TensorSpec"
         },
         "description": "Named output tensors produced by this kernel.",
         "propertyNames": {
            "minLength": 1
         },
         "title": "Outputs",
         "type": "object"
      },
      "reference": {
         "description": "Reference implementation code. It defines the compute logic of the kernel. Must be a valid\nPython code with a 'run' function that takes the input tensors and returns the output tensors.",
         "minLength": 1,
         "title": "Reference",
         "type": "string"
      },
      "tags": {
         "anyOf": [
            {
               "items": {
                  "minLength": 1,
                  "type": "string"
               },
               "type": "array"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "Optional list of tags for grouping and filtering kernels. It's used in the FlashInfer-Bench\nwebsite.",
         "title": "Tags"
      },
      "description": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "Optional human-readable description of the kernel's purpose.",
         "title": "Description"
      },
      "constraints": {
         "anyOf": [
            {
               "items": {
                  "minLength": 1,
                  "type": "string"
               },
               "type": "array"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "Optional list of constraint expressions describing relationships between axes.",
         "title": "Constraints"
      }
   },
   "$defs": {
      "AxisConst": {
         "description": "Constant axis with a fixed value.\n\nA constant axis represents a dimension that has a fixed, compile-time known value.\nThis is useful for dimensions that don't vary across different instances of the\nsame kernel definition, such as embedding dimensions or hidden layer sizes.",
         "properties": {
            "type": {
               "const": "const",
               "default": "const",
               "description": "The type identifier for constant axes.",
               "title": "Type",
               "type": "string"
            },
            "value": {
               "description": "The constant integer value of this axis dimension.",
               "minimum": 0,
               "title": "Value",
               "type": "integer"
            },
            "description": {
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "An optional human-readable description explaining the purpose of this axis.",
               "title": "Description"
            }
         },
         "required": [
            "value"
         ],
         "title": "AxisConst",
         "type": "object"
      },
      "AxisVar": {
         "description": "Variable axis that can be specified at runtime.\n\nA variable axis represents a dimension whose value is determined at runtime\nbased on the actual input data. Its value will be bound to the input tensor\ndimension at runtime.",
         "properties": {
            "type": {
               "const": "var",
               "default": "var",
               "title": "Type",
               "type": "string"
            },
            "description": {
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "title": "Description"
            }
         },
         "title": "AxisVar",
         "type": "object"
      },
      "DType": {
         "description": "Supported data types for tensors.\n\nEnumeration of all data types that can be used in tensor specifications.\nIncludes both floating-point and integer types commonly used in machine\nlearning and high-performance computing applications.",
         "enum": [
            "float32",
            "float16",
            "bfloat16",
            "float8_e4m3fn",
            "float8_e5m2",
            "float4_e2m1",
            "int64",
            "int32",
            "int16",
            "int8",
            "bool"
         ],
         "title": "DType",
         "type": "string"
      },
      "TensorSpec": {
         "description": "Specification for a tensor including shape and data type, to use as input or output of a\nkernel.\n\nThis includes the symbolic shape (referencing defined axes) and the data type.\nScalars are represented with a None shape.",
         "properties": {
            "shape": {
               "anyOf": [
                  {
                     "items": {
                        "minLength": 1,
                        "type": "string"
                     },
                     "type": "array"
                  },
                  {
                     "type": "null"
                  }
               ],
               "description": "List of axis names defining the tensor shape. None for scalar values.",
               "title": "Shape"
            },
            "dtype": {
               "$ref": "#/$defs/DType",
               "description": "The data type of all elements in this tensor."
            },
            "description": {
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "An optional human-readable description of this tensor's purpose and usage.",
               "title": "Description"
            }
         },
         "required": [
            "shape",
            "dtype"
         ],
         "title": "TensorSpec",
         "type": "object"
      }
   },
   "required": [
      "name",
      "op_type",
      "axes",
      "inputs",
      "outputs",
      "reference"
   ]
}
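To make the schema above concrete, here is a minimal sketch of a payload that has all six required fields. The definition name, axis names, sizes, and reference body are hypothetical and chosen only for illustration; the check uses plain Python rather than the library itself.

```python
import json

# Hypothetical GEMM definition. The name, axes, and sizes are illustrative
# and are not taken from any real FlashInfer-Bench dataset.
definition = {
    "name": "gemm_n4096_k4096",
    "op_type": "gemm",
    "axes": {
        "M": {"type": "var", "description": "Number of rows in A"},
        "N": {"type": "const", "value": 4096},
        "K": {"type": "const", "value": 4096},
    },
    "inputs": {
        "A": {"shape": ["M", "K"], "dtype": "float16"},
        "B": {"shape": ["N", "K"], "dtype": "float16"},
    },
    "outputs": {
        "C": {"shape": ["M", "N"], "dtype": "float16"},
    },
    # The reference is stored as a string of Python source with a `run` function.
    "reference": (
        "import torch\n"
        "\n"
        "def run(A, B):\n"
        "    return torch.matmul(A, B.T)\n"
    ),
}

# The keys listed under "required" in the schema.
REQUIRED = ("name", "op_type", "axes", "inputs", "outputs", "reference")
missing = [key for key in REQUIRED if key not in definition]

# The payload is plain JSON, so it round-trips through the json module.
payload = json.dumps(definition, indent=2)
```

The optional fields (tags, description, constraints) are omitted here because they default to null. With the package installed, a dict like this would presumably be passed to the model through pydantic's usual validation entry points, but that step is not shown in this sketch.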

Fields:
  • name (str)

  • op_type (str)

  • axes (Dict[str, flashinfer_bench.data.definition.AxisConst | flashinfer_bench.data.definition.AxisVar])

  • inputs (Dict[str, flashinfer_bench.data.definition.TensorSpec])

  • outputs (Dict[str, flashinfer_bench.data.definition.TensorSpec])

  • reference (str)

  • tags (List[str] | None)

  • description (str | None)

  • constraints (List[str] | None)

field name: Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])] [Required]

A unique, human-readable name for the kernel definition.

Constraints:
  • min_length = 1

field op_type: Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])] [Required]

The general compute category (e.g., ‘gemm’, ‘gqa_ragged’, ‘mla_paged’, ‘moe’).

Constraints:
  • min_length = 1

field axes: Dict[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])], AxisConst | AxisVar] [Required]

Dictionary of symbolic dimensions used in tensor shapes. The axes will be bound to the input tensor dimensions at runtime.

field inputs: Dict[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])], TensorSpec] [Required]

Named input tensors required by this kernel.

field outputs: Dict[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])], TensorSpec] [Required]

Named output tensors produced by this kernel.

field reference: Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])] [Required]

Reference implementation code that defines the compute logic of the kernel. It must be valid Python code with a ‘run’ function that takes the input tensors and returns the output tensors.

Constraints:
  • min_length = 1
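As an illustration of the requirement that the reference contain a ‘run’ function, the following sketch parses a reference string and looks for a top-level function named run. The helper defines_run and the sample reference body are hypothetical; the library may validate this field differently or not at all.

```python
import ast

# Hypothetical reference source for a GEMM-style kernel. It is only parsed
# here, never executed, so torch does not need to be installed.
reference = (
    "import torch\n"
    "\n"
    "def run(A, B):\n"
    "    return torch.matmul(A, B.T)\n"
)


def defines_run(source: str) -> bool:
    """Return True if `source` parses and has a top-level function named `run`."""
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return False
    return any(
        isinstance(node, ast.FunctionDef) and node.name == "run"
        for node in tree.body
    )


has_run = defines_run(reference)
```

Parsing with ast checks syntax only; it does not confirm that run accepts the declared inputs or returns the declared outputs.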

field tags: List[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])]] | None = None

Optional list of tags for grouping and filtering kernels. Tags are used on the FlashInfer-Bench website.

field description: str | None = None

Optional human-readable description of the kernel’s purpose.

field constraints: List[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])]] | None = None

Optional list of constraint expressions describing relationships between axes.

get_const_axes() Dict[str, int]

Get all constant axes and their values.

Returns:

Dictionary mapping constant axis names to their fixed values.

Return type:

Dict[str, int]

get_var_axes() List[str]

Get all variable axis names.

Returns:

List of all variable axis names defined in this Definition.

Return type:

List[str]
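The split between constant and variable axes can be pictured with plain dictionaries. This is a sketch of the described behavior under the assumption that axes are keyed by name and carry a type of 'const' or 'var'; the helper names const_axes and var_axes are hypothetical and this is not the library's implementation.

```python
from typing import Dict, List

# Hypothetical axes in the JSON form used by the schema.
axes = {
    "M": {"type": "var"},
    "N": {"type": "const", "value": 4096},
    "K": {"type": "const", "value": 4096},
}


def const_axes(axes: Dict[str, dict]) -> Dict[str, int]:
    """Map each constant axis name to its fixed value."""
    return {name: ax["value"] for name, ax in axes.items() if ax["type"] == "const"}


def var_axes(axes: Dict[str, dict]) -> List[str]:
    """List the names of the variable axes in declaration order."""
    return [name for name, ax in axes.items() if ax["type"] == "var"]
```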

property get_var_axes_bindings: Dict[str, Tuple[str, int]]

Get the bindings of variable axes to input tensor dimensions.

Determines which input tensor and dimension index corresponds to each variable axis. If multiple input tensors share the same axis, the binding will be to the first tensor encountered.

Returns:

Dictionary mapping axis names to tuples of (input_tensor_name, dimension_index). Only includes variable axes that appear in input tensor shapes.

Return type:

Dict[str, Tuple[str, int]]
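The first-tensor-wins rule described above can be sketched in a few lines. The function var_axes_bindings and the sample axes and inputs are hypothetical, and the sketch assumes that "first tensor encountered" means dictionary insertion order; it is not the library's code.

```python
from typing import Dict, Tuple

# Hypothetical axes and inputs in the JSON form used by the schema.
axes = {
    "num_tokens": {"type": "var"},
    "num_pages": {"type": "var"},
    "batch_size": {"type": "var"},  # never used by an input shape
    "head_dim": {"type": "const", "value": 128},
}
inputs = {
    "q": {"shape": ["num_tokens", "head_dim"], "dtype": "bfloat16"},
    "k_cache": {"shape": ["num_pages", "head_dim"], "dtype": "bfloat16"},
    "q_scale": {"shape": ["num_tokens"], "dtype": "float32"},
    "sm_scale": {"shape": None, "dtype": "float32"},  # scalar
}


def var_axes_bindings(axes: dict, inputs: dict) -> Dict[str, Tuple[str, int]]:
    """Bind each variable axis to the first (tensor, dimension) that uses it."""
    bindings: Dict[str, Tuple[str, int]] = {}
    for tensor_name, spec in inputs.items():
        shape = spec["shape"]
        if shape is None:  # scalars have no dimensions to bind
            continue
        for dim, axis_name in enumerate(shape):
            is_var = axes[axis_name]["type"] == "var"
            if is_var and axis_name not in bindings:
                bindings[axis_name] = (tensor_name, dim)
    return bindings


bindings = var_axes_bindings(axes, inputs)
```

Here num_tokens binds to q rather than q_scale because q comes first, and batch_size is absent from the result because no input shape mentions it.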

get_input_shapes(var_values: Dict[str, int] | None = None) Dict[str, List[int]]

Get concrete input shapes given variable axis values.

Parameters:

var_values (Optional[Dict[str, int]], default=None) – Values for variable axes. If None, defaults to an empty dictionary.

Returns:

Dictionary mapping input tensor names to their concrete shapes.

Return type:

Dict[str, List[int]]

Raises:

ValueError – If a required variable axis value is missing from var_values.
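A rough sketch of how symbolic shapes might be resolved to concrete ones: constant axes contribute their fixed value, variable axes are looked up in var_values, and a missing value raises ValueError. The function input_shapes and the sample data are hypothetical. Scalar inputs are skipped because the text does not say how they are reported.

```python
from typing import Dict, List, Optional

# Hypothetical axes and inputs in the JSON form used by the schema.
axes = {
    "M": {"type": "var"},
    "N": {"type": "const", "value": 4096},
    "K": {"type": "const", "value": 4096},
}
inputs = {
    "A": {"shape": ["M", "K"], "dtype": "float16"},
    "B": {"shape": ["N", "K"], "dtype": "float16"},
}


def input_shapes(
    axes: dict, inputs: dict, var_values: Optional[Dict[str, int]] = None
) -> Dict[str, List[int]]:
    """Resolve each input's symbolic shape to a list of integers."""
    var_values = var_values or {}

    def resolve(axis_name: str) -> int:
        axis = axes[axis_name]
        if axis["type"] == "const":
            return axis["value"]
        if axis_name not in var_values:
            raise ValueError(f"Missing value for variable axis '{axis_name}'")
        return var_values[axis_name]

    shapes: Dict[str, List[int]] = {}
    for tensor_name, spec in inputs.items():
        if spec["shape"] is None:
            # Assumption: scalar handling is not described, so skip scalars.
            continue
        shapes[tensor_name] = [resolve(name) for name in spec["shape"]]
    return shapes


shapes = input_shapes(axes, inputs, {"M": 8})
```

Calling input_shapes(axes, inputs) without a value for M raises ValueError in this sketch, mirroring the documented behavior.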

get_output_shapes(var_values: Dict[str, int] | None = None) Dict[str, List[int]]

Get concrete output shapes given variable axis values.

Parameters:

var_values (Optional[Dict[str, int]], default=None) – Values for variable axes. If None, defaults to an empty dictionary.

Returns:

Dictionary mapping output tensor names to their concrete shapes.

Return type:

Dict[str, List[int]]

Raises:

ValueError – If a required variable axis value is missing from var_values.

pydantic model flashinfer_bench.data.AxisConst

Constant axis with a fixed value.

A constant axis represents a dimension that has a fixed, compile-time known value. This is useful for dimensions that don’t vary across different instances of the same kernel definition, such as embedding dimensions or hidden layer sizes.

JSON schema:
{
   "title": "AxisConst",
   "description": "Constant axis with a fixed value.\n\nA constant axis represents a dimension that has a fixed, compile-time known value.\nThis is useful for dimensions that don't vary across different instances of the\nsame kernel definition, such as embedding dimensions or hidden layer sizes.",
   "type": "object",
   "properties": {
      "type": {
         "const": "const",
         "default": "const",
         "description": "The type identifier for constant axes.",
         "title": "Type",
         "type": "string"
      },
      "value": {
         "description": "The constant integer value of this axis dimension.",
         "minimum": 0,
         "title": "Value",
         "type": "integer"
      },
      "description": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "An optional human-readable description explaining the purpose of this axis.",
         "title": "Description"
      }
   },
   "required": [
      "value"
   ]
}

Fields:
  • type (Literal['const'])

  • value (int)

  • description (str | None)

field type: Literal['const'] = 'const'

The type identifier for constant axes.

field value: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Ge(ge=0)])] [Required]

The constant integer value of this axis dimension.

Constraints:
  • ge = 0

field description: str | None = None

An optional human-readable description explaining the purpose of this axis.

pydantic model flashinfer_bench.data.AxisVar

Variable axis that can be specified at runtime.

A variable axis represents a dimension whose value is determined at runtime based on the actual input data. Its value will be bound to the input tensor dimension at runtime.

JSON schema:
{
   "title": "AxisVar",
   "description": "Variable axis that can be specified at runtime.\n\nA variable axis represents a dimension whose value is determined at runtime\nbased on the actual input data. Its value will be bound to the input tensor\ndimension at runtime.",
   "type": "object",
   "properties": {
      "type": {
         "const": "var",
         "default": "var",
         "title": "Type",
         "type": "string"
      },
      "description": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "title": "Description"
      }
   }
}

Fields:
  • type (Literal['var'])

  • description (str | None)

field type: Literal['var'] = 'var'

The type identifier for variable axes.

field description: str | None = None

An optional human-readable description explaining the purpose of this axis.

pydantic model flashinfer_bench.data.TensorSpec

Specification for a tensor including shape and data type, to use as input or output of a kernel.

This includes the symbolic shape (referencing defined axes) and the data type. Scalars are represented with a None shape.

JSON schema:
{
   "title": "TensorSpec",
   "description": "Specification for a tensor including shape and data type, to use as input or output of a\nkernel.\n\nThis includes the symbolic shape (referencing defined axes) and the data type.\nScalars are represented with a None shape.",
   "type": "object",
   "properties": {
      "shape": {
         "anyOf": [
            {
               "items": {
                  "minLength": 1,
                  "type": "string"
               },
               "type": "array"
            },
            {
               "type": "null"
            }
         ],
         "description": "List of axis names defining the tensor shape. None for scalar values.",
         "title": "Shape"
      },
      "dtype": {
         "$ref": "#/$defs/DType",
         "description": "The data type of all elements in this tensor."
      },
      "description": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "An optional human-readable description of this tensor's purpose and usage.",
         "title": "Description"
      }
   },
   "$defs": {
      "DType": {
         "description": "Supported data types for tensors.\n\nEnumeration of all data types that can be used in tensor specifications.\nIncludes both floating-point and integer types commonly used in machine\nlearning and high-performance computing applications.",
         "enum": [
            "float32",
            "float16",
            "bfloat16",
            "float8_e4m3fn",
            "float8_e5m2",
            "float4_e2m1",
            "int64",
            "int32",
            "int16",
            "int8",
            "bool"
         ],
         "title": "DType",
         "type": "string"
      }
   },
   "required": [
      "shape",
      "dtype"
   ]
}

Fields:
  • shape (List[str] | None)

  • dtype (flashinfer_bench.data.definition.DType)

  • description (str | None)

field shape: List[Annotated[str, FieldInfo(annotation=NoneType, required=True, metadata=[MinLen(min_length=1)])]] | None [Required]

List of axis names defining the tensor shape. None for scalar values.

field dtype: DType [Required]

The data type of all elements in this tensor.

field description: str | None = None

An optional human-readable description of this tensor’s purpose and usage.
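Because a shape is a list of axis names, a misspelled name is an easy mistake to make. The sketch below lists axis names used by tensor specs that are not defined in the axes dictionary. The helper undefined_axes and the sample data are hypothetical, and whether the library performs an equivalent check is an assumption.

```python
from typing import Dict, List

# Hypothetical axes and tensor specs in the JSON form used by the schema.
axes = {
    "batch": {"type": "var"},
    "hidden": {"type": "const", "value": 4096},
}
good_tensors = {
    "x": {"shape": ["batch", "hidden"], "dtype": "bfloat16"},
    "eps": {"shape": None, "dtype": "float32"},  # scalar: shape is None
}
bad_tensors = {
    "weight": {"shape": ["hiden"], "dtype": "bfloat16"},  # typo in axis name
}


def undefined_axes(axes: Dict[str, dict], tensors: Dict[str, dict]) -> List[str]:
    """Return the sorted axis names that tensors reference but axes lacks."""
    missing = set()
    for spec in tensors.values():
        # A None shape marks a scalar, which references no axes.
        for axis_name in spec["shape"] or []:
            if axis_name not in axes:
                missing.add(axis_name)
    return sorted(missing)
```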

class flashinfer_bench.data.definition.DType

Supported data types for tensors.

Enumeration of all data types that can be used in tensor specifications. Includes both floating-point and integer types commonly used in machine learning and high-performance computing applications.

FLOAT32 = 'float32'

32-bit IEEE 754 floating point.

FLOAT16 = 'float16'

16-bit IEEE 754 half-precision floating point.

BFLOAT16 = 'bfloat16'

16-bit Brain Floating Point format.

FLOAT8_E4M3FN = 'float8_e4m3fn'

8-bit floating point with 4 exponent bits and 3 mantissa bits.

FLOAT8_E5M2 = 'float8_e5m2'

8-bit floating point with 5 exponent bits and 2 mantissa bits.

FLOAT4_E2M1 = 'float4_e2m1'

4-bit floating point with 2 exponent bits and 1 mantissa bit.

INT64 = 'int64'

64-bit signed integer.

INT32 = 'int32'

32-bit signed integer.

INT16 = 'int16'

16-bit signed integer.

INT8 = 'int8'

8-bit signed integer.

BOOL = 'bool'

Boolean type.
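The bit widths stated for each type can be used to estimate tensor sizes. The table below copies the widths from the descriptions above; the helper packed_nbytes is hypothetical and assumes densely packed elements, which may not match how a given framework stores sub-byte types. bool is left out because its storage width is not stated.

```python
import math
from typing import Dict, List

# Nominal bit widths taken from the type descriptions above.
BITS_PER_ELEMENT: Dict[str, int] = {
    "float32": 32,
    "float16": 16,
    "bfloat16": 16,
    "float8_e4m3fn": 8,
    "float8_e5m2": 8,
    "float4_e2m1": 4,
    "int64": 64,
    "int32": 32,
    "int16": 16,
    "int8": 8,
}


def packed_nbytes(shape: List[int], dtype: str) -> int:
    """Estimate the size in bytes, assuming densely packed elements."""
    total_bits = math.prod(shape) * BITS_PER_ELEMENT[dtype]
    return (total_bits + 7) // 8  # round up to whole bytes


size_fp16 = packed_nbytes([8, 4096], "float16")
size_fp4 = packed_nbytes([3], "float4_e2m1")
```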
