flashinfer_bench.apply¶
flashinfer_bench.apply provides a tool that meets two needs:
- Apply the best-performing kernel from the FlashInfer Trace database to the LLM engine
- Trace a kernel in the LLM engine and dump its inputs in FlashInfer Trace's workload format
- flashinfer_bench.apply(def_name_or_resolver: str | Callable[[...], str]) → Callable[[Callable[[...], Any]], Callable[[...], Any]]¶
- flashinfer_bench.apply(def_name_or_resolver: str | Callable[[...], str], *, runtime_kwargs: Dict[str, Any], fallback: Callable[[...], Any] | None) → Any
Decorator/function for routing to the best-performing kernel recorded in the FlashInfer Trace database.
This API can be used in two modes:
- Decorator mode (only def_name_or_resolver provided): returns a decorator that wraps a kernel function with a router. The router selects the best-performing candidate according to the function's runtime arguments.
- Function mode (runtime_kwargs provided, optionally fallback): immediately resolves and calls the best-performing kernel and returns its result.
- Parameters:
def_name_or_resolver (Union[str, Callable[..., str]]) – The kernel name, or a resolver fn(*args, **kwargs) -> str that maps runtime arguments to a kernel name (definition name).
runtime_kwargs (Dict[str, Any], optional) – Only used in function mode. The runtime arguments to feed into the selected kernel. Use this to call the kernel immediately instead of returning a decorator.
fallback (Optional[Callable[..., Any]], optional) – Only used in function mode. A fallback function to invoke when no matching kernel is found in the Trace database.
- Returns:
Decorator mode: a decorator that transforms the target kernel function into a routed version.
Function mode: the return value produced by the selected (or fallback) kernel.
- Return type:
Union[Callable[[Callable[…, Any]], Callable[…, Any]], Any]
Examples
Decorator mode with a fixed name¶
>>> @apply("gemm_bf16")
... def gemm_bf16(A, B):
...     return torch.nn.functional.linear(A, B)
Decorator mode with a resolver¶
>>> @apply(lambda A, B: f"gemm_n{B.shape[0]}_k{B.shape[1]}")
... def gemm_bf16(A, B):
...     return torch.nn.functional.linear(A, B)
Function mode¶
>>> out = apply(
...     "gemm_bf16",
...     runtime_kwargs={"A": A, "B": B, "bias": None},
...     fallback=lambda **kw: torch.nn.functional.linear(**kw),
... )
- flashinfer_bench.enable_apply(dataset_path: str | None = None, apply_config: ApplyConfig | None = None) ApplyRuntime¶
Enable apply functionality globally and return an ApplyRuntime instance that manages it.
There is only one global ApplyRuntime instance. This function must be called in the main thread.
- Parameters:
dataset_path (str, optional) – Path to the dataset/traceset directory
apply_config (ApplyConfig, optional) – Configuration for the apply runtime
- Returns:
The global ApplyRuntime instance managing the apply functionality.
- Return type:
ApplyRuntime
Examples
>>> # Direct usage
>>> enable_apply("/path/to/traceset", cfg)
>>> # Apply is now enabled
>>> out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
>>> disable_apply()
>>> # Apply is now disabled.
>>> # Context manager usage
>>> with enable_apply("/path/to/traceset", cfg):
...     out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
>>> # Apply is now disabled.
- flashinfer_bench.disable_apply() None¶
Disable global apply functionality.
This function silently disables the global apply runtime by setting it to None. After calling this function, any subsequent calls to apply() will use fallback functions instead of the apply runtime.
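The enable/disable lifecycle described above (a single global runtime, set to None on disable, with apply() falling back when no runtime is active) can be sketched in pure Python. FakeRuntime, enable_sketch, and apply_or_fallback are illustrative stand-ins, not flashinfer_bench's API:

```python
from contextlib import AbstractContextManager
from typing import Optional


class FakeRuntime(AbstractContextManager):
    """Hypothetical stand-in for ApplyRuntime; __enter__ returns self."""

    def __exit__(self, *exc):
        # Leaving the `with` block disables apply, matching the
        # context-manager usage shown in the enable_apply examples.
        disable_sketch()
        return False


# The single global runtime instance.
_RUNTIME: Optional[FakeRuntime] = None


def enable_sketch() -> FakeRuntime:
    global _RUNTIME
    _RUNTIME = FakeRuntime()
    return _RUNTIME


def disable_sketch() -> None:
    # Setting the global runtime to None makes subsequent apply()
    # calls use their fallback functions instead.
    global _RUNTIME
    _RUNTIME = None


def apply_or_fallback(fallback):
    # Route through the runtime when enabled, else call the fallback.
    return "routed_kernel" if _RUNTIME is not None else fallback()
```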
See enable_apply for usage examples.
- Return type:
None