A Trace records one benchmark run: it binds a Solution to a specific Definition, fixes the exact workload (input shapes + input data), and stores the complete evaluation. A folder of Definitions, Solutions, and Traces is your benchmark database.
Trace Schema (top level)

| Field | Type | Required | Description |
|---|---|---|---|
| `definition` | string | Yes | The name of the Definition used in this run. |
| `solution` | string | Yes | The name of the Solution tested. |
| `workload` | object | Yes | Concrete shapes and input data used for this run. |
| `evaluation` | object | Yes | Results, logs, and environment snapshot. |
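For orientation, a Trace with this shape might look like the following (all values are hypothetical, and the on-disk serialization may differ):

```python
# Hypothetical Trace record matching the schema above; actual field
# contents and serialization format depend on your dataset.
trace = {
    "definition": "gqa_paged_decode",
    "solution": "triton_gqa_decode_v2",
    "workload": {
        "axes": {"batch_size": 8, "seq_len": 4096},
        "inputs": {"kv_indptr": "tensors/kv_indptr.bin"},
    },
    "evaluation": {
        "status": "PASSED",
        "latency_ms": 0.42,
        "environment": {"gpu": "H100", "cuda": "12.4"},
    },
}
```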
Component 1: definition
What it is: The operator’s contract: axes (const/var), inputs/outputs, constraints, and a correct (not necessarily fast) `reference`.
Identity rule: Two kernels are under the same Definition iff:
- They have the same axes,
- Each axis has the same role (`const` vs `var`),
- All `const` axes have the same values.
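As a small illustration with hypothetical GEMM axes, the identity check reduces to comparing (role, value) per axis:

```python
# Hypothetical axis specs: identity = axes set + roles + const values.
a = {"M": ("var", None), "N": ("const", 4096), "K": ("const", 4096)}
b = {"M": ("var", None), "N": ("const", 4096), "K": ("const", 4096)}
c = {"M": ("var", None), "N": ("const", 8192), "K": ("const", 4096)}

assert a == b  # same axes, roles, and const values -> same Definition
assert a != c  # different const value for N -> different Definition
```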
How to add a Definition:
- Refer to the schema, choose a `name` (`<type>_<stage>_<axis tokens>`) and `type`; write a clear `description` and helpful `tags`.
- Specify `axes` with `type: const|var` (plus `value` for const).
- Add `constraints` that relate axes to inputs (e.g., CSR shapes).
- Specify `inputs`/`outputs` (names, shapes by axes, dtypes, optional layouts).
- Provide a correct Python `reference` returning a tuple of outputs (see the sketch below).
- (Optional) Provide minimal tests that run the `reference` on tiny shapes.
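To ground this, a minimal sketch of a `reference` for a hypothetical RMSNorm Definition (the definition name and axes are assumptions; only the tuple-return contract comes from the list above):

```python
import torch

# Hypothetical reference for an "rmsnorm_fwd_h" Definition: correct,
# not fast. Axes: hidden_size (const), num_tokens (var).
def reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # x: [num_tokens, hidden_size], weight: [hidden_size]
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps) * weight
    return (y,)  # references return a tuple of outputs
```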
Component 2: solution
What it is: A concrete implementation of a Definition’s interface (Triton/CUDA/CUTLASS/PyTorch, etc.), plus metadata including target architectures, libraries, and author (human or LLM).
Interface: Your function must take the Definition’s `inputs` and return the tuple of `outputs`.
How to add a Solution:
- Add the implementation of the kernel (matching signature).
- Provide metadata co-located with the code, according to the schema.
- Add unit tests vs the `reference` across representative shapes (see the sketch below).
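A hedged Solution sketch against the hypothetical RMSNorm Definition above, with a tiny unit test vs its `reference` (a real Solution would ship an optimized Triton/CUDA kernel):

```python
import torch

# Hypothetical Solution: same signature as the Definition's inputs,
# returning the same tuple of outputs as the reference.
def rmsnorm_solution(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Stand-in math; this only demonstrates the interface contract.
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    return (y,)

# Unit test vs the reference (from the Definition sketch above) on a
# tiny representative shape.
x, w = torch.randn(4, 128), torch.randn(128)
torch.testing.assert_close(rmsnorm_solution(x, w)[0], reference(x, w)[0])
```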
Component 3: workload
What it is: The concrete axes + input data that instantiate a Definition for one run.

| Field | Description |
|---|---|
| `axes` | Map of var axis → concrete int value. |
| `inputs` | Map of input name → actual input. |
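For instance, a workload for the hypothetical RMSNorm Definition might look like this (names are illustrative):

```python
import torch

# Hypothetical workload: concrete var-axis values plus actual inputs.
workload = {
    "axes": {"num_tokens": 4},
    "inputs": {"x": torch.randn(4, 128), "weight": torch.randn(128)},
}
```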
Env-vars (zero-code)
To capture workloads with zero code changes, drive tracing through environment variables (see the sketch below):
- Choose an output dataset root (optional).
- Enable tracing and run your engine or script.
- Inspect what gets saved and where (default layout).
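The variable names below are invented for illustration (the real ones live in the project’s docs); this only sketches the zero-code flow:

```python
import os
import subprocess

# HYPOTHETICAL variable names -- check your installation's docs for
# the real ones; this only illustrates the zero-code flow.
env = dict(os.environ)
env["FIB_DATASET_ROOT"] = "/data/fib-traces"  # optional output root
env["FIB_ENABLE_TRACING"] = "1"               # turn tracing on

# Run the unmodified engine/script; workloads are captured as it runs.
subprocess.run(["python", "serve.py"], env=env, check=True)
```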
Tracing in code (fine-grained control)
If you want to target a subset of kernels or customize policies (see the sketch below):
- `input_dump_policy`: `"dump_all"`, `"dump_none"`, `"dump_int32"`, or a list of input names to dump, e.g. `input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices", "sm_scale"]`.
- `filter_policy`: `"keep_all"`, `"keep_first"` (e.g., first k calls), `"keep_first_by_axes"`, `"keep_none"`, or a custom callable `Workload -> key`. These policies reduce disk usage and runtime while keeping representative samples.
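A hedged sketch, assuming a tracing entry point that accepts these policies (the function name and import path are illustrative; only the two policy parameters come from the text above):

```python
# HYPOTHETICAL entry point -- the policy parameters are real, the
# function name and import location are assumptions.
from flashinfer_bench import enable_tracing

enable_tracing(
    # Dump only the ragged index tensors and the softmax scale.
    input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices", "sm_scale"],
    # Keep one workload per distinct axes combination.
    filter_policy="keep_first_by_axes",
)
```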
Component 4: evaluation
What it is: The result bundle for one `(definition, solution, workload)` run.
How to benchmark to produce Evaluations: run the benchmarker over your `(definition, solution, workload)` triples in the dataset.
Using the CLI:
- Device pool: One `MultiProcessRunner` is created per CUDA device.
- Concurrency: For each definition and workload, the benchmark:
  - Picks up to `K = min(#devices, #solutions)` runners (round-robin).
  - Reference phase: in parallel, calls `runner.run_ref(defn, wl, config)` to build a baseline on each selected runner. If a runner fails during the reference, it is removed from the pool and the workload on that runner is skipped.
  - Solutions phase: distributes solutions round-robin across the runners that succeeded in the reference phase, calling `runner.run_solution(sol, baseline_handle, config)` in parallel.
- Status mapping:
  - Successful run with numerics in tolerance → `PASSED`.
  - Output shape/dtype mismatch → `INCORRECT_SHAPE` / `INCORRECT_DTYPE`.
  - Numeric check fails → `INCORRECT_NUMERICAL`.
  - Runtime fault → `RUNTIME_ERROR`.
  - Build/compile fails → `COMPILE_ERROR`.
Each solution run yields an `Evaluation`; the benchmark immediately stages a `Trace(def_name, workload, sol_name, evaluation)` in memory.
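To make the scheduling above concrete, here is a simplified sketch of the two phases per (definition, workload); only `run_ref`/`run_solution` mirror the calls named above, the rest is illustrative:

```python
# Simplified sketch of the per-(definition, workload) schedule.
def bench_one(defn, wl, solutions, runners, config):
    k = min(len(runners), len(solutions))
    selected = [runners[i % len(runners)] for i in range(k)]  # round-robin

    # Reference phase: build a baseline on each selected runner.
    alive, baselines = [], {}
    for r in selected:
        try:
            baselines[r] = r.run_ref(defn, wl, config)
            alive.append(r)
        except Exception:
            runners.remove(r)  # dead runner: drop from the pool
    if not alive:
        raise RuntimeError("all runners failed the reference phase")

    # Solutions phase: round-robin solutions over surviving runners.
    for i, sol in enumerate(solutions):
        runner = alive[i % len(alive)]
        evaluation = runner.run_solution(sol, baselines[runner], config)
        yield (defn, wl, sol, evaluation)  # staged as a Trace in memory
```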
After benchmarking completes, the results can be used to rank solutions, visualize leaderboards, and drive `apply` at runtime.
Reproducibility
- `BenchmarkConfig` controls iteration counts, warmup, tolerances, and timeouts (use your project’s defaults or tune per kernel).
- Environment snapshot: runners capture hardware and library versions into `evaluation.environment`.
- Dead runner handling: any runner failing the reference is dropped for subsequent work; if all runners fail, a `RuntimeError` is raised.
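As a hedged illustration of those knobs (field names are hypothetical; only the categories of settings come from the bullet above):

```python
from dataclasses import dataclass

# HYPOTHETICAL field names, mirroring the knobs described above.
@dataclass
class BenchmarkConfigSketch:
    num_iters: int = 100      # timed iterations per solution
    warmup_iters: int = 10    # untimed warmup iterations
    rtol: float = 1e-2        # relative numeric tolerance
    atol: float = 1e-3        # absolute numeric tolerance
    timeout_s: float = 120.0  # per-run timeout
```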
Putting it together: Trace lifecycle

1. Add the Definition
   - Finalize axes (`const` vs `var`), constraints, I/O shapes, and the `reference`.
   - Identity is locked by the axes set/roles/const values.
2. Add one or more Solutions
   - Implement the exact interface; return the tuple of outputs.
   - Provide metadata and unit tests vs the `reference`.
3. Capture Workloads
   - Run with tracing (env-vars or code) over real requests to collect shapes and, when helpful, actual inputs (especially ragged index tensors).
   - Curate a small but representative set (use `shape_only` or `keep_first_k`).
4. Benchmark → Emit Traces
   - For each `(definition, solution, workload)` triple, run the benchmarker to produce one Trace JSON with `evaluation`.
   - Store logs and the environment snapshot alongside.
5. Apply at runtime (end-to-end)
   - Use runtime substitution to dispatch to the best-ranked Solution for the current shapes.
End-to-end “apply”
With `apply`, we can dynamically replace the kernels behind the FlashInfer API with the best-performing ones from our Traces. With adapters already written for FlashInfer, you can enable the integration with minimal code changes.
`apply` looks up the Definition, matches the current workload (axes and input-data properties), and dispatches to the best Solution according to our Traces, with correctness constraints and numeric tolerances enforced.
Supporting kernels that don’t align with the Definition via adapters
Sometimes your production call site can’t be decorated directly, e.g., wrappers that keep internal state across `plan()`/`run()` such as `BatchPrefillWithPagedKVCacheWrapper`. FlashInfer-Bench provides built-in adapters for common FlashInfer kernels, and you can also use the imperative `apply()` API for custom integration patterns.
Built-in FlashInfer Integration (Recommended)
FlashInfer-Bench automatically patches common FlashInfer kernels when you enable apply; no manual decoration is needed.
How it works: when you call `enable_apply()`, FlashInfer-Bench installs lightweight adapters that:
- Intercept FlashInfer wrapper methods (`plan` and `run`)
- Extract runtime parameters and match them to Definitions
- Dispatch to the best-performing Solution from your traces
- Fall back to the original FlashInfer implementation if no suitable Solution exists
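Enabling the integration is a single call at startup. A minimal sketch, assuming `enable_apply` is exported from the top-level package (the import location is an assumption):

```python
import flashinfer_bench as fib

# Install the built-in adapters: subsequent FlashInfer wrapper
# plan()/run() calls dispatch to the best traced Solution and fall
# back to stock FlashInfer when no suitable Solution exists.
fib.enable_apply()
```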
The following entry points are covered out of the box:
- `flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` (page_size=1)
- `flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` (causal=True, page_size=1)
- `flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` (causal=True)
- `flashinfer.norm.fused_add_rmsnorm`

See `flashinfer_bench/integration/flashinfer/adapters/` for the complete list and implementation details.
Imperative `apply()` API (Custom Integration)
For custom kernels or integration patterns not covered by the built-in adapters, use the function form of `apply` (usage sketch below):
- `def_name_or_resolver`: The kernel definition name (e.g., `"gemm_bf16"`) or a resolver function that maps runtime arguments to a definition name.
- `runtime_kwargs`: Dictionary of keyword arguments to pass to the selected kernel; must match the kernel definition’s interface.
- `fallback`: Optional fallback function to invoke when no matching kernel is found in the Trace database.
Example: Creating custom adapters (advanced)
If you want to create reusable adapters similar to the built-in FlashInfer integrations, study the real implementations:
- `flashinfer_bench/integration/flashinfer/adapters/gqa_paged_decode.py`
- `flashinfer_bench/integration/flashinfer/adapters/rmsnorm.py`

The common pattern (sketched after this list):
- Use `ContextStore` to preserve state across `plan()`/`run()` calls
- Extract parameters in the `plan` wrapper and store them in context
- In the `run` wrapper, retrieve stored params and call `apply()` with `runtime_kwargs`
- Provide a fallback lambda that calls the original implementation
- Register your adapter with the `PatchManager`
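Below is a schematic rendering of that checklist. `ContextStore` and `PatchManager` are named above, but every method signature here (`set`, `get`, `register`) is an assumption; consult the real adapter files for the actual APIs:

```python
from flashinfer_bench import apply  # import path is an assumption

class RMSNormAdapterSketch:
    """Schematic plan()/run() adapter; all method names on the
    ContextStore/PatchManager collaborators are hypothetical."""

    def __init__(self, ctx_store, original_plan, original_run):
        self.ctx = ctx_store          # preserves state across plan()/run()
        self.original_plan = original_plan
        self.original_run = original_run

    def plan(self, wrapper, *args, **kwargs):
        # Extract planning parameters and stash them in the context.
        self.ctx.set(wrapper, {"args": args, "kwargs": kwargs})
        return self.original_plan(wrapper, *args, **kwargs)

    def run(self, wrapper, *args, **kwargs):
        # Retrieve stored params and dispatch through apply().
        planned = self.ctx.get(wrapper)
        runtime_kwargs = {**planned["kwargs"], **kwargs}
        return apply(
            "fused_add_rmsnorm",          # definition name is illustrative
            runtime_kwargs=runtime_kwargs,
            # Fallback: call the original implementation unchanged.
            fallback=lambda **_: self.original_run(wrapper, *args, **kwargs),
        )

# Registration with the patch manager (method name is hypothetical):
# patch_manager.register(RMSNormAdapterSketch(...))
```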