This guide explains how to add Definitions and Solutions, capture Workloads, and record Evaluations by walking through each component of the Trace, ending with an end-to-end “apply at runtime” flow. A Trace is an atomic, immutable record of a single benchmark run. It links a specific Solution to a specific Definition, fixes the exact workload (input shapes + input data), and stores the complete evaluation. A folder of Definitions, Solutions, and Traces is your benchmark database.

Trace Schema (top level)

Field      | Type   | Required | Description
definition | string | Yes      | The name of the Definition used in this run.
solution   | string | Yes      | The name of the Solution tested.
workload   | object | Yes      | Concrete shapes and input data used for this run.
evaluation | object | Yes      | Results, logs, and environment snapshot.
More details about the schema are in FlashInfer Trace Schema.
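For orientation, a single Trace record looks roughly like the sketch below (shown as a Python dict). Only the four documented top-level keys are taken from the schema above; the concrete names and the nested workload/evaluation fields are illustrative assumptions.
# Hedged sketch of one Trace record; names and nested fields beyond the four
# documented top-level keys are illustrative assumptions.
trace = {
    "definition": "fused_add_rmsnorm_h4096",        # Definition name (assumed example)
    "solution": "fused_add_rmsnorm_triton_v1",      # Solution name (assumed example)
    "workload": {
        "axes": {"batch_size": 8},                  # var axes -> concrete int values
        "inputs": {"hidden_states": "...", "residual": "...", "weight": "..."},
    },
    "evaluation": {
        "status": "PASSED",                         # see the status mapping under Component 4
        "environment": {},                          # hardware + library version snapshot
    },
}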

Component 1: definition

What it is: The operator’s contract: axes (const/var), inputs/outputs, constraints, and a correct (not necessarily fast) reference implementation.
Identity rule: Two kernels fall under the same Definition if and only if:
  • They have the same axes,
  • Each axis has the same role (const vs var),
  • All const axes have the same values.
How to add a new kernel Definition:
  1. Refer to the schema; choose a name (<type>_<stage>_<axis tokens>) and a type, and write a clear description and helpful tags.
  2. Specify axes with type: const|var (+ value for const).
  3. Add constraints that relate axes to inputs (e.g., CSR shapes).
  4. Specify inputs/outputs (names, shapes by axes, dtypes, optional layouts).
  5. Provide a correct Python reference returning a tuple of outputs (an illustrative sketch follows this list).
  6. (Optional) Provide minimal tests that run the reference on tiny shapes.
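As an illustration of step 5, a correct-but-slow reference for a hypothetical fused_add_rmsnorm Definition could be a plain PyTorch function. The signature, output structure, and eps default below are assumptions; the authoritative interface is whatever the Definition specifies.
import torch

# Hedged sketch of a reference implementation (names and signature are assumptions).
def fused_add_rmsnorm_ref(hidden_states, residual, weight, eps=1e-6):
    # Correctness over speed: plain PyTorch with fp32 accumulation.
    x = (hidden_states + residual).float()
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (x * inv_rms).to(hidden_states.dtype) * weight
    new_residual = (hidden_states + residual).to(hidden_states.dtype)
    return (out, new_residual)  # the reference returns a tuple of outputs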

Component 2: solution

What it is: A concrete implementation of a Definition’s interface (Triton/CUDA/CUTLASS/PyTorch, etc.), plus metadata such as target architectures, library dependencies, and author (human or LLM).
Interface: Your function must take the Definition’s inputs and return the tuple of outputs.
How to add a Solution:
  1. Add the kernel implementation (matching the Definition’s signature; an illustrative sketch follows this list).
  2. Provide metadata co-located with the code, according to schema.
  3. Add unit tests vs reference across representative shapes.
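To make the interface concrete, here is a minimal PyTorch Solution for the hypothetical fused_add_rmsnorm Definition sketched above, with a tiny unit test against the reference. All names are illustrative; a production Solution would typically be Triton or CUDA, and its metadata follows the schema.
import torch

# Hedged sketch: same signature and return structure as the reference sketch above.
def fused_add_rmsnorm_torch(hidden_states, residual, weight, eps=1e-6):
    x = (hidden_states + residual).float()
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return ((x * inv_rms).to(hidden_states.dtype) * weight,
            (hidden_states + residual).to(hidden_states.dtype))

# Tiny unit test vs the reference on a small shape.
def test_matches_reference():
    h, r, w = torch.randn(2, 128), torch.randn(2, 128), torch.randn(128)
    out, res = fused_add_rmsnorm_torch(h, r, w)
    ref_out, ref_res = fused_add_rmsnorm_ref(h, r, w)  # reference from the previous sketch
    torch.testing.assert_close(out, ref_out)
    torch.testing.assert_close(res, ref_res)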

Component 3: workload

What it is: The concrete axes + input data that instantiate a Definition for one run.
Field  | Description
axes   | Map of var axis → concrete int value.
inputs | Map of input name → actual input.
How to capture workloads:

Env-vars (zero-code)

  1. Choose an output dataset root (optional):
export FIB_DATASET_PATH=/root/flashinfer-trace
# defaults to `~/.cache/flashinfer_bench/dataset` if unset
  2. Enable tracing and run your engine or script:
export FIB_ENABLE_TRACING=1
python run_engine.py  # your serving or batch script
By default, every kernel that has a matching Definition (and an associated tracing config) is traced.
  3. What gets saved & where (default layout):
$FIB_DATASET_PATH/
├── workloads/
│   └── <op_type>/
│       └── <definition_name>.jsonl   # workload records (FlashInfer Trace format)
└── blob/
    └── workloads/            # tensor payloads (safetensors, when dumped)
Writing tensors to file is async (background thread) to reduce runtime overhead.

Tracing in code (fine-grained control)

If you want to target a subset of kernels / customize policies:
import flashinfer_bench as fib

# 1) Pick which kernels to trace and how
from flashinfer_bench import TracingConfig

gqa_paged_prefill_config = TracingConfig(
    input_dump_policy="dump_non_float",   # keep scalar and int tensors; skip large float payloads
    filter_policy="shape_only",             # save first occurrence per input-shape signature
)

configs = {
    "gqa_paged_prefill_causal_h32_kv4_d128_ps1": gqa_paged_prefill_config,
    # more tracing config mappings...
}

# 2) Enable, run, then finalize
with fib.enable_tracing(dataset_path="/root/flashinfer-trace", tracing_configs=configs):
    run_engine()  # your inference loop
Policies you can use right away:
  • input_dump_policy: "dump_all", "dump_none", "dump_int32", or a list of input names to dump, like input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices", "sm_scale"].
  • filter_policy: "keep_all", "keep_first" (e.g., keep the first k calls), "keep_first_by_axes", "keep_none", or a custom callable Workload -> key (an illustrative sketch follows this list). These policies reduce disk and time while keeping representative samples.
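For the custom-callable option, a filter might derive a deduplication key from a few axes. The TracingConfig keywords below appear in the example above; the axis names and the workload.axes attribute access are assumptions made for illustration.
from flashinfer_bench import TracingConfig

# Hedged sketch of a custom filter_policy (Workload -> key): keep one workload
# per (batch_size, qo_len bucket). Axis names and the .axes attribute are assumptions.
def bucketed_key(workload):
    axes = workload.axes  # map of var axis -> concrete int value
    return (axes.get("batch_size"), axes.get("qo_len", 0) // 128)

custom_config = TracingConfig(
    input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices"],  # dump only the index tensors
    filter_policy=bucketed_key,
)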

Component 4: evaluation

What it is: The result bundle for one (definition, solution, workload) run.
How to benchmark to produce Evaluations: Run the benchmarker over the (definition, solution, workload) triples in your dataset.
Using the CLI:
flashinfer-bench run --local /path/to/flashinfer-trace
Using the Python API:
from flashinfer_bench.data import TraceSet
from flashinfer_bench.bench import Benchmark

# 1) Build TraceSet (definitions, solutions, workloads)
trace_set = TraceSet(root="./flashinfer-trace")  # scans for definitions, solutions, workloads

# 2) Run the benchmark (`config` is a BenchmarkConfig; see "Reproducibility" below)
benchmark = Benchmark(trace_set, config)
benchmark.run_all(save_results=True)
  • Device pool: One MultiProcessRunner is created per CUDA device.
  • Concurrency: For each definition and workload, the benchmark:
    • Picks up to K = min(#devices, #solutions) runners (round-robin).
    • Reference phase: in parallel, calls runner.run_ref(defn, wl, config) to build a baseline on each selected runner.
      • If a runner fails during reference, it is removed from the pool and the workload on that runner is skipped.
    • Solutions phase: distributes solutions round-robin across the runners that succeeded in the reference phase, calling runner.run_solution(sol, baseline_handle, config) in parallel.
  • Status mapping:
    • Successful run with numerics in tolerance → PASSED.
    • Output shape/dtype mismatch → INCORRECT_SHAPE / INCORRECT_DTYPE.
    • Numeric check fails → INCORRECT_NUMERICAL.
    • Runtime fault → RUNTIME_ERROR.
    • Build/compile fails → COMPILE_ERROR.
Each solution run returns an Evaluation; the benchmark immediately stages a Trace(def_name, workload, sol_name, evaluation) in memory. After benchmarking is done, the results can be used to rank solutions, visualize leaderboards, and drive apply at runtime.

Reproducibility

  • BenchmarkConfig controls iteration counts, warmup, tolerances, and timeouts (use your project’s defaults or tune per kernel; an illustrative sketch follows this list).
  • Environment snapshot: runners capture hardware and library versions into evaluation.environment.
  • Dead runner handling: any runner failing the reference is dropped for subsequent work; if all runners fail, a RuntimeError is raised.
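Purely as a mental model, and not the flashinfer_bench API, the knobs listed above map onto a config object roughly like the hypothetical dataclass below; every field name here is invented for illustration.
from dataclasses import dataclass

@dataclass
class IllustrativeBenchmarkConfig:  # hypothetical stand-in, NOT the real BenchmarkConfig
    warmup_iters: int = 10          # warmup runs before timing
    num_iters: int = 100            # timed iterations per (definition, solution, workload)
    rtol: float = 1e-2              # relative tolerance for the numeric check
    atol: float = 1e-3              # absolute tolerance
    timeout_seconds: float = 120.0  # per-run budget before a RUNTIME_ERROR is reported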

Putting it together: Trace lifecycle

  1. Add the Definition
    • Finalize axes (const vs var), constraints, I/O shapes, and reference.
    • Identity is locked by the axes set/roles/const values.
  2. Add one or more Solutions
    • Implement the exact interface; return the outputs as specified by the Definition.
    • Provide metadata and unit tests vs reference.
  3. Capture Workloads
    • Run with tracing (env-vars or code) over real requests to collect shapes and, when helpful, actual inputs (esp. ragged index tensors).
    • Curate a small but representative set (e.g., with filter_policy="shape_only" or "keep_first").
  4. Benchmark → Emit Traces
    • For each (definition, solution, workload) triple, run the benchmarker to produce one Trace JSON with evaluation.
    • Store logs and the environment snapshot alongside.
  5. Apply at runtime (end-to-end)
    • Use runtime substitution to dispatch to the best ranked Solution for the current shapes.

End-to-end “apply”

With apply, we can dynamically replace the kernels in the FlashInfer API with the best-performing ones from our traces. With adapters already written for FlashInfer, you can enable integration with minimal code changes.
export FIB_ENABLE_APPLY=1
export FIB_DATASET_PATH=/path/to/flashinfer-trace
python serve_or_benchmark.py
At call time, apply looks up the Definition, matches the current workload (axes and input data properties), and dispatches to the best Solution according to our Traces (with correctness constraints and numeric tolerances enforced).

Supporting kernels that don’t align with the Definition via adapters

Sometimes your production call site can’t be decorated directly, e.g., wrappers that keep internal state across plan()/run(), like BatchPrefillWithPagedKVCacheWrapper. FlashInfer-Bench ships built-in adapters that automatically patch common FlashInfer kernels when apply is enabled (no manual decoration needed), plus an imperative apply() API for custom integration patterns.
How it works: When you call enable_apply(), FlashInfer-Bench installs lightweight adapters that:
  1. Intercept FlashInfer wrapper methods (plan and run)
  2. Extract runtime parameters and match them to definitions
  3. Dispatch to the best-performing solution from your traces
  4. Fall back to the original FlashInfer implementation if no suitable solution exists
Supported kernels:
  • flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper (page_size=1)
  • flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper (causal=True, page_size=1)
  • flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper (causal=True)
  • flashinfer.norm.fused_add_rmsnorm
See flashinfer_bench/integration/flashinfer/adapters/ for the complete list and implementation details.

Imperative apply() API (Custom Integration)

For custom kernels or integration patterns not covered by the built-in adapters, use the function form of apply (a usage sketch follows the parameter list):
from flashinfer_bench import apply

result = apply(
    def_name_or_resolver,   # Union[str, Callable[..., str]]
    runtime_kwargs,         # Dict[str, Any]; all arguments must follow the kernel definition's interface
    fallback=None,          # Optional[Callable[..., Any]]
)
Parameters:
  • def_name_or_resolver: The kernel definition name (e.g., "gemm_bf16") or a resolver function that maps runtime arguments to a definition name.
  • runtime_kwargs: Dictionary of keyword arguments to pass to the selected kernel. Must match the kernel definition’s interface.
  • fallback: Optional fallback function to invoke when no matching kernel is found in the Trace database.
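A minimal usage sketch, assuming a definition named "gemm_bf16" (the example name above) whose interface takes two inputs a and b; those input names and the fallback are illustrative assumptions.
import torch
from flashinfer_bench import apply

a = torch.randn(1024, 512, dtype=torch.bfloat16, device="cuda")
b = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda")

# "gemm_bf16" is the example definition name from above; the input names
# ("a", "b") and the fallback are assumptions for illustration.
out = apply(
    "gemm_bf16",
    runtime_kwargs={"a": a, "b": b},
    fallback=lambda a, b: a @ b,  # invoked when no matching Solution exists in the Traces
)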

Example: Creating custom adapters (advanced)

If you want to create reusable adapters similar to the built-in FlashInfer integrations, study the real implementations:
  • flashinfer_bench/integration/flashinfer/adapters/gqa_paged_decode.py
  • flashinfer_bench/integration/flashinfer/adapters/rmsnorm.py
Key pattern:
  1. Use ContextStore to preserve state across plan()/run() calls
  2. Extract parameters in the plan wrapper and store them in context
  3. In the run wrapper, retrieve stored params and call apply() with runtime_kwargs
  4. Provide a fallback lambda that calls the original implementation
  5. Register your adapter with the PatchManager
Example structure from the RMSNorm adapter:
import torch

from flashinfer_bench.apply import apply
from flashinfer_bench.integration.patch_manager import PatchSpec
from flashinfer_bench.integration.utils import ArgBinder

def _def_name_resolver(weight):
    return f"fused_add_rmsnorm_h{weight.shape[0]}"

class RMSNormAdapter:
    def targets(self):
        return [
            PatchSpec(
                path="flashinfer.norm.fused_add_rmsnorm",
                kind="function",
                name="fused_add_rmsnorm",
                ctx_key="rmsnorm",
            )
        ]
    
    def make_wrapper(self, spec, orig):
        binder = ArgBinder.from_callable(orig)
        
        def wrapper(*args, **kwargs):
            bound = binder.bind(args, kwargs)
            
            # Compatibility checks
            if bound["input"].dtype != torch.bfloat16:
                return orig(*args, **kwargs)
            
            rk = {
                "hidden_states": bound["input"],
                "residual": bound["residual"],
                "weight": bound["weight"],
            }
            
            return apply(_def_name_resolver, runtime_kwargs=rk, fallback=lambda **_: orig(*args, **kwargs))
        
        return wrapper