flashinfer_bench.apply
flashinfer_bench.apply provides a tool that serves two purposes:
- Apply the best-performing kernel from the FlashInfer Trace database to the LLM engine.
- Trace a kernel in the LLM engine and dump its inputs in FlashInfer Trace’s workload format.
- flashinfer_bench.apply(def_name_or_resolver: str | Callable[[...], str]) → Callable[[Callable[[...], Any]], Callable[[...], Any]]
- flashinfer_bench.apply(def_name_or_resolver: str | Callable[[...], str], *, args: Tuple[Any, ...], kwargs: Dict[str, Any] | None = None, fallback: Callable[[...], Any] | None = None) → Any
Decorator/function for routing to the best-performing kernel recorded in the FlashInfer Trace database.
This API can be used in two modes:
- Decorator mode (only def_name_or_resolver provided): returns a decorator that wraps a kernel function with a router. The router selects the best-performing candidate according to the function’s runtime arguments.
- Function mode (args or kwargs provided, optionally fallback): immediately resolves and calls the best-performing kernel and returns its result.
The calling convention (value-returning vs destination-passing) is determined by the number of arguments:
- If len(args) == len(inputs): value-returning style; the solution returns the outputs.
- If len(args) == len(inputs) + len(outputs): destination-passing style; the outputs are pre-allocated and passed as arguments.
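The argument-count rule above can be illustrated with a small, self-contained sketch. It is not the library's implementation: the helper name `calling_convention`, its parameters, and the choice to raise `ValueError` on a count mismatch are assumptions made for illustration only.

```python
from typing import Any, Sequence


def calling_convention(args: Sequence[Any], num_inputs: int, num_outputs: int) -> str:
    """Classify a call by its positional-argument count (hypothetical helper)."""
    if len(args) == num_inputs:
        # Only inputs were passed: the kernel allocates and returns its outputs.
        return "value-returning"
    if len(args) == num_inputs + num_outputs:
        # Inputs plus pre-allocated outputs were passed: the kernel writes in place.
        return "destination-passing"
    # Assumption: any other count is treated as an error in this sketch.
    raise ValueError(
        f"expected {num_inputs} or {num_inputs + num_outputs} arguments, got {len(args)}"
    )


# A GEMM-like definition with two inputs (A, B) and one output (C).
style_two_args = calling_convention(("A", "B"), num_inputs=2, num_outputs=1)
style_three_args = calling_convention(("A", "B", "C"), num_inputs=2, num_outputs=1)
```

With two inputs and one output, passing `(A, B)` is classified as value-returning and `(A, B, C)` as destination-passing, which mirrors the two function-mode examples further down.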
- Parameters:
def_name_or_resolver (Union[str, Callable[..., str]]) – The kernel name, or a resolver fn(*args, **kwargs) -> str that maps runtime arguments to a kernel name (definition name).
args (Tuple[Any, ...], optional) – Only used in function mode. The positional runtime arguments to feed into the selected kernel. The number of arguments determines the calling convention.
kwargs (Dict[str, Any], optional) – Only used in function mode. The keyword runtime arguments to feed into the selected kernel. The number of arguments determines the calling convention.
fallback (Optional[Callable[..., Any]], optional) – Only used in function mode. A fallback function to invoke when no matching kernel is found in the Trace database.
- Returns:
Decorator mode: a decorator that transforms the target kernel function into a routed version.
Function mode: the return value produced by the selected (or fallback) kernel. For destination-passing style, returns None.
- Return type:
Union[Callable[[Callable[…, Any]], Callable[…, Any]], Any]
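To make the two modes concrete, here is a toy model of a dual-mode router written in plain Python. It does not use flashinfer_bench and should not be read as its implementation: `toy_apply`, `TOY_KERNELS`, and the kernel names are hypothetical, a dictionary stands in for the Trace database, and two behaviors are assumptions of this sketch (the decorated function acts as the fallback in decorator mode, and a miss without a fallback raises `LookupError`).

```python
import functools
from typing import Any, Callable, Dict, Optional, Tuple, Union

# Hypothetical stand-in for the Trace database: definition name -> "best" kernel.
TOY_KERNELS: Dict[str, Callable[..., Any]] = {"add_int": lambda a, b: a + b}

NameOrResolver = Union[str, Callable[..., str]]


def toy_apply(
    name_or_resolver: NameOrResolver,
    *,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    fallback: Optional[Callable[..., Any]] = None,
) -> Any:
    def dispatch(call_args, call_kwargs, default):
        # A resolver sees the runtime arguments; a plain string is used as-is.
        if callable(name_or_resolver):
            name = name_or_resolver(*call_args, **call_kwargs)
        else:
            name = name_or_resolver
        kernel = TOY_KERNELS.get(name, default)
        if kernel is None:
            # Assumption of this sketch: a miss with no fallback is an error.
            raise LookupError(f"no kernel or fallback for {name!r}")
        return kernel(*call_args, **call_kwargs)

    if args is None and kwargs is None:
        # Decorator mode: return a decorator. Assumption: the wrapped
        # function serves as the fallback when nothing matches.
        def decorator(fn):
            @functools.wraps(fn)
            def routed(*a, **kw):
                return dispatch(a, kw, fn)

            return routed

        return decorator

    # Function mode: resolve and call immediately.
    return dispatch(args or (), kwargs or {}, fallback)


# Decorator mode with a resolver that inspects the runtime arguments.
@toy_apply(lambda a, b: "add_int" if isinstance(a, int) else "add_other")
def add(a, b):
    return ("fallback", a, b)


routed_result = add(1, 2)        # resolver picks "add_int", found in TOY_KERNELS
fallback_result = add(1.5, 2)    # resolver picks "add_other", not found
direct_result = toy_apply("add_int", args=(1, 2))  # function mode, positional
miss_result = toy_apply(
    "missing", kwargs={"a": 5, "b": 2}, fallback=lambda a, b: a - b
)                                # function mode, keyword arguments, fallback used
```

The single entry point decides its mode from which keyword arguments are present, which is why `args`, `kwargs`, and `fallback` are keyword-only in this sketch, as they are in the signature above.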
Examples
Decorator mode with a fixed name
>>> @apply("gemm_bf16")
... def gemm_bf16(A, B):
...     return A @ B.T
Decorator mode with a resolver
>>> @apply(lambda A, B: f"gemm_n{B.shape[0]}_k{B.shape[1]}")
... def gemm_bf16(A, B):
...     return A @ B.T
Function mode (value-returning)
>>> out = apply(
...     "gemm_bf16",
...     args=(A, B),
...     fallback=lambda A, B: A @ B.T,
... )
Function mode (destination-passing)
>>> C = torch.empty(M, N, device=A.device, dtype=A.dtype)
>>> apply(
...     "gemm_bf16",
...     args=(A, B, C),  # C is pre-allocated output
...     fallback=lambda *args: my_gemm_dps(*args),
... )
Function mode with kwargs
>>> out = apply(
...     "gemm_bf16",
...     kwargs={"A": A, "B": B},
...     fallback=lambda A, B: A @ B.T,
... )
- flashinfer_bench.enable_apply(dataset_path: str | None = None, apply_config: ApplyConfig | ApplyConfigRegistry | None = None) → ApplyRuntime
Enable apply functionality globally and return an ApplyRuntime instance that manages it.
The apply runtime is process-level and supports nesting. Calling this function from the main thread is recommended.
- Parameters:
dataset_path (str, optional) – Path to the dataset/trace_set directory
apply_config (Union[ApplyConfig, ApplyConfigRegistry], optional) – Configuration for the apply runtime. Can be:
- ApplyConfig: a single config used as the default for all definitions.
- ApplyConfigRegistry: a registry with per-definition configs.
If None, the default ApplyConfigRegistry is used.
- Returns:
The newly created ApplyRuntime instance that has been pushed onto the global stack.
- Return type:
ApplyRuntime
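The three accepted forms of apply_config (a single config, a registry, or None) can be pictured as being normalized into one registry that answers per-definition lookups. The sketch below is a toy model under that assumption, not the library's code: `ToyConfig`, `ToyRegistry`, `normalize`, and the placeholder default for `max_atol` are all hypothetical, and only the `max_atol` field name is taken from the examples on this page.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass(frozen=True)
class ToyConfig:
    # Placeholder default chosen for this sketch; not the library's default.
    max_atol: float = 1e-2


class ToyRegistry:
    """Per-definition configs with a default for unregistered names."""

    def __init__(self, default: Optional[ToyConfig] = None) -> None:
        self._default = default if default is not None else ToyConfig()
        self._per_definition: Dict[str, ToyConfig] = {}

    def register(self, def_name: str, config: ToyConfig) -> None:
        self._per_definition[def_name] = config

    def get(self, def_name: str) -> ToyConfig:
        # Fall back to the default when no per-definition config exists.
        return self._per_definition.get(def_name, self._default)


def normalize(apply_config: Union[ToyConfig, ToyRegistry, None] = None) -> ToyRegistry:
    if apply_config is None:
        return ToyRegistry()                      # default registry
    if isinstance(apply_config, ToyConfig):
        return ToyRegistry(default=apply_config)  # one config for all definitions
    return apply_config                           # already a registry


# A single config becomes the default for every definition name.
single = normalize(ToyConfig(max_atol=1e-3))

# A registry keeps per-definition overrides and a default for everything else.
registry = ToyRegistry()
registry.register("gemm_bf16", ToyConfig(max_atol=1e-5))
per_definition = normalize(registry)
```

Here `single.get("rmsnorm_d4096")` and `single.get("gemm_bf16")` return the same config, while `per_definition.get("gemm_bf16")` returns the override and any other name returns the registry default.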
Examples
>>> # Direct usage with single config
>>> enable_apply("/path/to/trace_set", ApplyConfig(max_atol=1e-3))
>>> out = apply("rmsnorm_d4096", args=(...), kwargs={...}, fallback=ref_fn)
>>> disable_apply()

>>> # Usage with per-definition configs
>>> registry = get_default_registry()
>>> registry.register("mla_paged", ApplyConfig(max_atol=1e-3, on_miss_policy="use_def_best"))
>>> registry.register("gemm_bf16", ApplyConfig(aot_ratio=0.8))
>>> enable_apply("/path/to/trace_set", registry)

>>> # Context manager usage
>>> with enable_apply("/path/to/trace_set", cfg):
...     out = apply("rmsnorm_d4096", args=(...), kwargs={...}, fallback=ref_fn)
>>> # Apply is now disabled.
- flashinfer_bench.disable_apply() → None
Disable the current apply runtime and restore the previous one (if any).
Pops the top runtime from the global stack and restores the previous runtime (if any) as the active instance. Safe to call even if no apply runtime is active.
See the enable_apply function for examples.
- Return type:
None
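The nesting behavior described for enable_apply and disable_apply (push on enable, pop on disable, no error when nothing is active, context-manager support) can be modeled with a plain list used as a stack. This is a minimal sketch of that idea only: `ToyRuntime`, `toy_enable`, `toy_disable`, and `active_runtime` are hypothetical names, and the real ApplyRuntime may manage its state differently.

```python
from typing import List, Optional


class ToyRuntime:
    """Hypothetical runtime object that can also be used as a context manager."""

    def __init__(self, dataset_path: Optional[str]) -> None:
        self.dataset_path = dataset_path

    def __enter__(self) -> "ToyRuntime":
        # The runtime was already pushed by toy_enable(); nothing else to do.
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        toy_disable()   # leaving the block pops the top runtime
        return False    # do not swallow exceptions


_runtime_stack: List[ToyRuntime] = []  # process-level stack in this sketch


def toy_enable(dataset_path: Optional[str] = None) -> ToyRuntime:
    runtime = ToyRuntime(dataset_path)
    _runtime_stack.append(runtime)     # the new runtime becomes the active one
    return runtime


def toy_disable() -> None:
    if _runtime_stack:                 # safe to call when nothing is active
        _runtime_stack.pop()


def active_runtime() -> Optional[ToyRuntime]:
    return _runtime_stack[-1] if _runtime_stack else None


toy_disable()                                  # no-op on an empty stack
outer = toy_enable("/path/to/trace_set")
with toy_enable("/path/to/other_trace_set") as inner:
    active_inside = active_runtime()           # the nested runtime
active_after_block = active_runtime()          # the outer runtime is restored
toy_disable()
active_at_end = active_runtime()               # nothing is active any more
```

Because the stack is shared by the whole process in this model, enabling and disabling from a single thread keeps the push and pop order predictable, which is consistent with the main-thread recommendation above.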