flashinfer_bench.apply

flashinfer_bench.apply provides tooling for two purposes:

  1. Apply the best-performing kernel from the FlashInfer Trace database to the LLM engine.

  2. Trace kernels in the LLM engine and dump their inputs in the FlashInfer Trace workload format.

flashinfer_bench.apply(def_name_or_resolver: str | Callable[..., str]) → Callable[[Callable[..., Any]], Callable[..., Any]]
flashinfer_bench.apply(def_name_or_resolver: str | Callable[..., str], *, runtime_kwargs: Dict[str, Any], fallback: Callable[..., Any] | None) → Any

Decorator/function for routing to the best-performing kernel recorded in the FlashInfer Trace database.

This API can be used in two modes:

  1. Decorator mode (only def_name_or_resolver provided): returns a decorator that wraps a kernel function with a router. The router selects the best-performing candidate according to the function’s runtime arguments.

  2. Function mode (runtime_kwargs provided, optionally fallback): immediately resolves and calls the best-performing kernel and returns its result.
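
Both modes can share one entry point by dispatching on whether runtime_kwargs is supplied. The following is a minimal, self-contained sketch of that dispatch, not the library's implementation: toy_apply and the in-memory BEST_KERNELS registry are hypothetical stand-ins for flashinfer_bench.apply and the Trace database, and using the wrapped function itself as the fallback in decorator mode is an assumption of this sketch.

```python
import functools
import inspect
from typing import Any, Callable, Dict, Optional, Union

# Hypothetical stand-in for the Trace database: definition name -> best kernel.
BEST_KERNELS: Dict[str, Callable[..., Any]] = {}


def toy_apply(
    def_name_or_resolver: Union[str, Callable[..., str]],
    *,
    runtime_kwargs: Optional[Dict[str, Any]] = None,
    fallback: Optional[Callable[..., Any]] = None,
) -> Any:
    """Hypothetical dual-mode router sketch (not flashinfer_bench's code)."""

    def resolve(kwargs: Dict[str, Any]) -> str:
        # A resolver maps the runtime arguments to a definition name.
        if callable(def_name_or_resolver):
            return def_name_or_resolver(**kwargs)
        return def_name_or_resolver

    if runtime_kwargs is not None:
        # Function mode: resolve, call immediately, and return the result.
        kernel = BEST_KERNELS.get(resolve(runtime_kwargs))
        if kernel is None:
            if fallback is None:
                raise LookupError("no matching kernel and no fallback given")
            kernel = fallback
        return kernel(**runtime_kwargs)

    # Decorator mode: return a decorator that wraps fn with a router.
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def routed(*args: Any, **kwargs: Any) -> Any:
            # Normalize positional/keyword calls into one kwargs dict
            # (assumes fn has only named parameters, no *args/**kwargs).
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            call_kwargs = dict(bound.arguments)
            # Sketch assumption: the wrapped function is the fallback.
            kernel = BEST_KERNELS.get(resolve(call_kwargs), fn)
            return kernel(**call_kwargs)

        return routed

    return decorator
```

In this sketch the lookup happens on every call rather than at decoration time, so registering a tuned kernel later changes what an already-decorated function dispatches to.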

Parameters:
  • def_name_or_resolver (Union[str, Callable[..., str]]) – The definition (kernel) name, or a resolver fn(*args, **kwargs) -> str that maps the runtime arguments to a definition name.

  • runtime_kwargs (Dict[str, Any], optional) – Function mode only. The keyword arguments passed to the selected kernel. Providing this argument calls the kernel immediately instead of returning a decorator.

  • fallback (Optional[Callable[..., Any]], optional) – Only used in function mode. A fallback function to invoke when no matching kernel is found in the Trace database.

Returns:

  • Decorator mode: a decorator that transforms the target kernel function into a routed version.

  • Function mode: the return value produced by the selected (or fallback) kernel.

Return type:

Union[Callable[[Callable[..., Any]], Callable[..., Any]], Any]

Examples

Decorator mode with a fixed name

>>> import torch
>>> from flashinfer_bench import apply
>>> @apply("gemm_bf16")
... def gemm_bf16(A, B):
...     return torch.nn.functional.linear(A, B)

Decorator mode with a resolver

>>> @apply(lambda A, B: f"gemm_n{B.shape[0]}_k{B.shape[1]}")
... def gemm_bf16(A, B):
...     return torch.nn.functional.linear(A, B)

Function mode

>>> out = apply(
...     "gemm_bf16",
...     runtime_kwargs={"A": A, "B": B, "bias": None},
...     fallback=lambda A, B, bias: torch.nn.functional.linear(A, B, bias),
... )

flashinfer_bench.enable_apply(dataset_path: str | None = None, apply_config: ApplyConfig | None = None) → ApplyRuntime

Enable apply globally and return the ApplyRuntime instance that manages it.

There is only one global ApplyRuntime instance. This function must be called in the main thread.

Parameters:
  • dataset_path (str, optional) – Path to the dataset/traceset directory.

  • apply_config (ApplyConfig, optional) – Configuration for the apply runtime.

Returns:

The global ApplyRuntime instance managing the apply functionality.

Return type:

ApplyRuntime

Examples

>>> # Direct usage (cfg is an ApplyConfig; ref_fn is the reference implementation used as fallback)
>>> enable_apply("/path/to/traceset", cfg)
>>> # Apply is now enabled
>>> out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
>>> disable_apply()
>>> # Apply is now disabled.
>>> # Context manager usage
>>> with enable_apply("/path/to/traceset", cfg):
...     out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
>>> # Apply is now disabled.
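
A single global runtime that can also be used as a context manager is a common Python pattern. The sketch below illustrates that pattern, including the main-thread requirement stated above; ToyRuntime, toy_enable_apply, toy_disable_apply, and get_toy_runtime are hypothetical names rather than flashinfer_bench APIs, the sketch loads no traceset, and replacing any previous instance on re-enable is an assumption of the sketch.

```python
import threading
from typing import Optional


class ToyRuntime:
    """Hypothetical stand-in for ApplyRuntime; holds only the dataset path."""

    def __init__(self, dataset_path: Optional[str]) -> None:
        self.dataset_path = dataset_path

    def __enter__(self) -> "ToyRuntime":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Leaving the with-block disables apply again; exceptions propagate.
        toy_disable_apply()


# The single global instance; None means apply is disabled.
_GLOBAL_RUNTIME: Optional[ToyRuntime] = None


def toy_enable_apply(dataset_path: Optional[str] = None) -> ToyRuntime:
    global _GLOBAL_RUNTIME
    # Enforce the main-thread requirement.
    if threading.current_thread() is not threading.main_thread():
        raise RuntimeError("toy_enable_apply must be called in the main thread")
    # Sketch assumption: re-enabling replaces any previous global instance.
    _GLOBAL_RUNTIME = ToyRuntime(dataset_path)
    return _GLOBAL_RUNTIME


def toy_disable_apply() -> None:
    global _GLOBAL_RUNTIME
    _GLOBAL_RUNTIME = None


def get_toy_runtime() -> Optional[ToyRuntime]:
    return _GLOBAL_RUNTIME
```

Returning the runtime object from the enable call is what lets the same function serve both the direct and the with-statement usage shown above.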

flashinfer_bench.disable_apply() → None

Disable global apply functionality.

This function silently disables the global apply runtime by setting it to None. After calling this function, any subsequent calls to apply() will use fallback functions instead of the apply runtime.
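
The effect of disabling can be pictured as a guard at the top of function-mode dispatch: with no runtime, the database is never consulted and the fallback runs directly. This is a hypothetical sketch; _runtime, toy_set_runtime, and toy_apply_call are illustrative names, and a plain dict stands in for the Trace database.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical global runtime: None means apply is disabled.
_runtime: Optional[Dict[str, Callable[..., Any]]] = None


def toy_set_runtime(kernels: Optional[Dict[str, Callable[..., Any]]]) -> None:
    # Passing None mimics disabling the runtime in this sketch.
    global _runtime
    _runtime = kernels


def toy_apply_call(
    def_name: str,
    runtime_kwargs: Dict[str, Any],
    fallback: Callable[..., Any],
) -> Any:
    if _runtime is None:
        # Disabled: skip the database lookup and call the fallback directly.
        return fallback(**runtime_kwargs)
    # Enabled: use the recorded kernel if present, otherwise the fallback.
    kernel = _runtime.get(def_name, fallback)
    return kernel(**runtime_kwargs)
```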

Check out the enable_apply function for examples.

Return type:

None