{
"name": "gemm_triton_h100_v1",
"definition": "gemm",
"description": "A high-performance GEMM implementation (C = A @ B.T) using Triton. Generated by one-shot inquiry with Gemini-2.5-Pro.",
"author": "gemini-2.5-pro-mystery-agent",
"spec": {
"language": "triton",
"target_hardware": [
"NVIDIA_H100"
],
"dependencies": [
"triton >= 2.3",
"torch"
],
"entry_point": "main.py::run"
},
"sources": [
{
"path": "main.py",
"content": "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8)\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef _gemm_kernel(\n A, B, C, M, N, K, stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr\n):\n # ... (Triton kernel logic as before)\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None]\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :]\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = A + (offs_am * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = B + (offs_bn * stride_bn + offs_k[:, None] * stride_bk)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n c = accumulator.to(C.dtype.element_ty)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef run(A, B):\n M, K = A.shape\n N, _ = B.shape\n C = torch.empty((M, N), device=A.device, dtype=A.dtype)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n _gemm_kernel[grid](A, B, C, M, N, K, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1))\n return C"
}
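,
{
"path": "example.py",
"description": "Illustrative usage sketch, not part of the original one-shot generation: verifies run(A, B) against the PyTorch reference torch.matmul(A, B.T), then times it with triton.testing.do_bench. File name, matrix sizes, and tolerances are assumed, arbitrary choices.",
"content": "import torch\nimport triton\n\nfrom main import run\n\n# Illustrative harness (an assumed sketch, not part of the original one-shot\n# generation): verifies run(A, B) against the PyTorch reference C = A @ B.T,\n# then times it with triton.testing.do_bench. Sizes and tolerances are\n# arbitrary illustrative choices.\nif __name__ == '__main__':\n    torch.manual_seed(0)\n    M, N, K = 4096, 4096, 4096\n    A = torch.randn((M, K), device='cuda', dtype=torch.float16)\n    B = torch.randn((N, K), device='cuda', dtype=torch.float16)\n\n    # Correctness: fp16 accumulation differences warrant a loose tolerance.\n    C_triton = run(A, B)\n    C_ref = torch.matmul(A, B.T)\n    torch.testing.assert_close(C_triton, C_ref, rtol=1e-2, atol=1e-2)\n\n    # Throughput: a GEMM performs 2*M*N*K flops; do_bench returns milliseconds.\n    ms = triton.testing.do_bench(lambda: run(A, B))\n    print(f'OK: {2 * M * N * K / (ms * 1e-3) / 1e12:.1f} TFLOP/s')"
}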
]
}