From 48c9b0c21988a50506b521b1a961818dff2f525e Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 27 Jan 2026 00:03:51 -0800 Subject: [PATCH 1/4] Introducing Database module to kernel_opt --- .../kernel_opt/database/__init__.py | 18 ++ kernel_perf_agent/kernel_opt/database/base.py | 199 ++++++++++++++++++ .../database/code_samples/matadd.py | 75 +++++++ .../database/code_samples/matadd_perst.py | 89 ++++++++ .../code_samples/matadd_tma_device.py | 95 +++++++++ .../database/code_samples/matadd_tma_host.py | 71 +++++++ .../database/code_samples/matmul.py | 98 +++++++++ .../database/code_samples/matmul_sw.py | 105 +++++++++ .../database/code_samples/matmul_tma_host.py | 79 +++++++ .../database/docs/experimental_tma.md | 169 +++++++++++++++ .../kernel_opt/database/docs/on_device_tma.py | 56 +++++ .../kernel_opt/database/docs/on_host_tma.py | 49 +++++ .../kernel_opt/database/docs/persistence.py | 43 ++++ .../kernel_opt/database/docs/pid_swizzle.py | 37 ++++ .../kernel_opt/database/docs/tma.md | 150 +++++++++++++ 15 files changed, 1333 insertions(+) create mode 100644 kernel_perf_agent/kernel_opt/database/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/database/base.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/persistence.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/tma.md diff --git a/kernel_perf_agent/kernel_opt/database/__init__.py b/kernel_perf_agent/kernel_opt/database/__init__.py new file mode 100644 index 0000000..b214283 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database package.""" + +# Database package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/database/base.py b/kernel_perf_agent/kernel_opt/database/base.py new file mode 100644 index 0000000..ffb4a0c --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/base.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from kernel_perf_agent.kernel_opt.database.docs import ( + on_device_tma, + on_host_tma, + persistence, + pid_swizzle, +) + + +class OptNode: + def __init__(self, level: int, dsl: str, opt_desc: str) -> None: + """Initialize the optimization node with the given level, description, and DSL. + :param level: int, Level in the tree + :param dsl: str, DSL used in the node + :param opt_desc: str, Description of the optimization + :param opt_parents: List[str], Parent nodes description + :param opt_children: List[OptNode], Children nodes + """ + + self.level = level # int, Level in the tree + self.dsl = dsl + self.opt_desc = opt_desc # str, Root node description + self.opt_parents = [] # List[str], Parent nodes description + self.opt_children = [] # List[OptNode], Children nodes + + def add_children(self, child_nodes): + """Adds a child node to the current node.""" + self.opt_children.extend(child_nodes) + + def remove_children(self, child_nodes): + """Removes a child node from the current node.""" + for child in child_nodes: + if child in self.opt_children: + self.opt_children.remove(child) + + def add_parents(self, parent_nodes): + """Adds a child node to the current node.""" + self.opt_parents.extend(parent_nodes) + + def remove_parents(self, parent_nodes): + """Removes a child node from the current node.""" + for parent in parent_nodes: + if parent in self.opt_parents: + self.opt_parents.remove(parent) + + def __repr__(self): + """String representation of the node for easy printing.""" + return f"OptNode at level {self.level}: ({self.opt_desc})" + + +class OptHierarchy: + def __init__(self) -> None: + """Initialize the optimization hierarchy with the root node.""" + self.root = OptNode(level=0, dsl="text", opt_desc="root") + + def get_root(self): + return self.root + + def hard_initialize(self, common_path) -> None: + """Hard initialize the hierarchy with pre-programmed database.""" + + # Level 1 nodes - Latency, Memory, Utilization bottlenecks + optnode_latency = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including: + - Persistent programming style to minimize kernel launch overhead + - Software pipelining to improve instruction-level parallelism and reduce execution time + """, + ) + optnode_memory = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including: + - PID swizzling to enhance L2 cache locality + - Leveraging new architecture features, such as Tensor Memory Accelerator (TMA) to overlap memory transfers + with compute operations + """, + ) + optnode_utilization = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques + to increase resource utilization and occupancy rates, including: + - Leveraging Tensor Memory Accelerator (TMA) to overlap memory transfers with compute operations + - Enabling warp specializations to improve instruction-level parallelism and reduce register pressure + - Autotuning to identify and apply optimal kernel configurations that maximize resource usage + """, + ) + level_1_opts = [optnode_latency, optnode_memory, optnode_utilization] + self.root.add_children(level_1_opts) + optnode_latency.add_parents([self.root]) + optnode_memory.add_parents([self.root]) + optnode_utilization.add_parents([self.root]) + + # Level 2 nodes - TMA, PID swizzling, persistent programming style + optnode_host_TMA = OptNode( + level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA + ) + optnode_device_TMA = OptNode( + level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA + ) + optnode_PID_swizzling = OptNode( + level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE + ) + optnode_persistence = OptNode( + level=2, dsl="text", opt_desc=persistence.PERSISTENCE + ) + + optnode_latency.add_children([optnode_persistence]) + optnode_memory.add_children( + [ + optnode_host_TMA, + optnode_device_TMA, + optnode_PID_swizzling, + optnode_persistence, + ] + ) + optnode_utilization.add_children([optnode_persistence]) + + optnode_host_TMA.add_parents([optnode_memory]) + optnode_device_TMA.add_parents([optnode_memory]) + optnode_PID_swizzling.add_parents([optnode_memory]) + optnode_persistence.add_parents( + [optnode_latency, optnode_memory, optnode_utilization] + ) + + # Level 3 nodes - code example of each kernel + # common_path="../kernel_opt/database/code_samples/" + optnode_matmul = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text() + ) + optnode_matmul_pid_swizzling = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_sw.py").read_text(), + ) + optnode_matmul_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_tma_host.py").read_text(), + ) + optnode_matadd = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matadd.py").read_text() + ) + optnode_matadd_persistence = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_perst.py").read_text(), + ) + optnode_matadd_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_host.py").read_text(), + ) + optnode_matadd_tma_device = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_device.py").read_text(), + ) + + optnode_host_TMA.add_children( + [ + optnode_matmul, + optnode_matmul_tma_host, + optnode_matadd, + optnode_matadd_tma_host, + ] + ) + optnode_device_TMA.add_children([optnode_matadd, optnode_matadd_tma_device]) + optnode_PID_swizzling.add_children( + [optnode_matmul, optnode_matmul_pid_swizzling] + ) + optnode_persistence.add_children([optnode_matadd, optnode_matadd_persistence]) + + optnode_matmul.add_parents([optnode_host_TMA, optnode_PID_swizzling]) + optnode_matmul_pid_swizzling.add_parents([optnode_PID_swizzling]) + optnode_matmul_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd.add_parents( + [optnode_host_TMA, optnode_device_TMA, optnode_persistence] + ) + optnode_matadd_persistence.add_parents([optnode_persistence]) + optnode_matadd_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd_tma_device.add_parents([optnode_device_TMA]) diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py new file mode 100644 index 0000000..ec71868 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================ unoptimized matadd ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py new file mode 100644 index 0000000..99b4c79 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================== matadd with persistent programming style ================== +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + num_tiles = num_pid_m * num_pid_n + + # iterate over the program id with a stride of the total number of blocks + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + ( + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + ) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid = lambda meta: ( + min( + NUM_SMS, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + NUM_SMS=NUM_SMS, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py new file mode 100644 index 0000000..274e172 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ======== matadd with on-device Tensor Memory Accelerator (TMA) integration ========== +from typing import Optional + +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # device TMA + x_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + y_desc = tl.make_tensor_descriptor( + y_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + output_desc = tl.make_tensor_descriptor( + output_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py new file mode 100644 index 0000000..7b65f1b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ======== matadd with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + x_desc = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + y_desc = TensorDescriptor(y, y.shape, y.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + output_desc = TensorDescriptor( + output, output.shape, output.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py new file mode 100644 index 0000000..08f1225 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================ unoptimized matmul ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py new file mode 100644 index 0000000..61ca813 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ==================== matmul with PID swizzling ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +GROUP_SIZE_M = 8 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py new file mode 100644 index 0000000..0e2b0af --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ======== matmul with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K]) # TMA load of A + b = b_desc.load([k * BLOCK_SIZE_K, pid_n * BLOCK_SIZE_N]) # TMA load of B + accumulator = tl.dot(a, b, accumulator) + c = accumulator.to(tl.float16) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor(b, b.shape, b.stride(), [BLOCK_SIZE_K, BLOCK_SIZE_N]) + c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md new file mode 100644 index 0000000..7edbad0 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md @@ -0,0 +1,169 @@ + + +# Triton Tutorial: How to integrate NV TMA into kernels +## Background +TMA is a hardware unit introduced by the NV Hopper GPU. It takes over some of the data transfer work from softwares and thus improves the performance by freeing up warps or reducing register pressures etc. In practice, Triton kernel authors can update the kernel code by simply replacing `tl.load` and `tl.store` with TMA API calls to get this performance boost. + +## TMA APIs +TMA API is going through changes (from experimental to official) on upstream Triton. While we’re working out a plan to migrate, we’ll support the “old” experimental API that’s currently being used in our fbsource codebase. This tutorial will be based on the experimental API. + +TMA data load/store needs a TMA tensor descriptor object. The descriptor will describe the tensor address, strides, shapes etc. of the tensor to be copied (treat it as the CUDA `TensorMap` object). The descriptor itself needs to be stored somewhere. Depending on where we initialize the descriptor, we have two types of descriptors: on-host and on-device. The former allocates the memory on host memory, initializes descriptors there and then copies them by value to GMEM. The latter will allocate a big chunk of memory on GMEM, and then have each program to find their own offset and initialize descriptors there. + +To leverage TMA, we need to decide between on-host and on-device descriptors. That decision could be yet another topic. Here we quickly highlight a few key differences: +- Only on-device descriptors can handle dynamic shapes where not all programs are handling the same box size, which is typical in kernels like Jagged Flash Attention or HSTU. The reason is that on-host descriptors are initialized before kernel launch while on-device ones are initialized in the kernel where the box size is known. +- Torch Inductor, especially AOTI, currently only supports on-device descriptors +- On-device descriptors are initialized by every kernel program in the grid while on-host ones are initialized by host code so likely on-device descriptors take more compute resources +- Current on-device descriptors implementation (experimental API) might take more global memory because the number of programs is not necessarily known when allocating memory chunk for descriptors (e.g. depending on auto tuned BLOCK_SIZE_M), so we need to be conservative and allocate more memory + +Note: neither of these two types of TMA is necessarily faster than the other. It depends on actual use cases. + +Now for the sake of this tutorial we’ll start with on-device descriptors. And also we’ll use the example of copying 2d tensors as it’s the most common. + +With those premises, here’re the APIs to call: + +- Allocate memory chunk to store descriptors on host: +``` +TMA_DESC_SIZE = 128 # size in bytes used by a single descriptor, tunable +NUM_DESC_PER_PROGRAM = ... # how many different tensors to load/store by each program. e.g. 3 for GEMM `C=AB`, 4 for HSTU Q,K,V,O tensors +NUM_OF_PROGRAMS = ... # same as specified in kernel `grid`. If grid size is related to auto tune config, use a reasonable upper bound by hard coding "minimal block M size" etc. for now. +workspace = torch.empty( + TMA_DESC_SIZE * NUM_DESC_PER_PROGRAM * NUM_OF_PROGRAMS, + dtype=torch.uint8, + device="cuda",) +# then pass `workspace` to kernel +``` +- Initialize descriptor object: +``` +desc_ptr = workspace + TMA_DESC_SIZE * + TMA_DESC_SIZE * # in program offset in range [0,NUM_DESC_PER_PROGRAM) + + +tl.extra.cuda.experimental_device_tensormap_create2d( +desc_ptr=desc_ptr, +global_address=, # tensor to load into or store from +load_size=[BOX_SIZE_0, BOX_SIZE_1], # size of the 2D box to copy +global_size=[GLOBAL_SIZE_0, GLOBAL_SIZE_1], # this defines a "global box" in GMEM. TMA load/store won't go over this boundary if load_size is not divisble by global_size. e.g. Assuming GLOBAL_SIZE_0 == 1.5 * BLOCK_SIZE_0 and GLOBAL_SIZE_1 == BLOCK_SIZE_1, then: for TMA load, the second box will return a tensor of size (BLOCK_SIZE_0, BLOCK_SIZE_1) but the second half of the tensor is all 0; for TMA store, the second box will only have its first half written to GMEM. +element_ty= # usually tensor_ptr.dtype.element_ty +) +``` +- Acquire fence on a TensorMap/descriptor object: +``` +tl.extra.cuda.experimental_tensormap_fenceproxy_acquire() +``` +- Load data from GMEM to SMEM: +``` +x = tl._experimental_descriptor_load( + , #initialized, and acquired fence above + [OFFSET_0, OFFSET_1], # offset in "global box" for the 2D loading box to start from + [BOX_SIZE_0, BOX_SIZE_1], # keep the same as descriptor's `load_size` + ,) +``` +- Store data from SMEM to GMEM: +``` +tl._experimental_descriptor_store( + , #initialized, and acquired fence above + , #the tensor to be stored on GMEM + [OFFSET_0, OFFSET_1], # offset in "global box" for the 2D loading box to start from +) +``` + +## Example +### Store +Let’s assume we have the following non TMA store code now: + +``` +start_m = pid * BLOCK_M +offs_m = start_m + tl.arange(0, BLOCK_M) +offs_v_d = tl.arange(0, BLOCK_D_V) +off_o = Out + seq_start * stride_om + off_h * stride_oh # TMA will use Out as global address, and include seq_start * stride_om + off_h * stride_oh as part of offsets +out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] +tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + +# Essentially, it tries to store the tensor `acc` into this box: +# Out[ +# (seq_start + pid * BLOCK_M : seq_start + (pid+1) * BLOCK_M), +# (off_h * stride_oh : off_h * stride_oh + BLOCK_D_V) +# ] +# In other words, it's a box of size (BLOCK_M, BLOCK_D_V) starting at [seq_start + pid * BLOCK_M, off_h * stride_oh]. This will be the bases for our TMA desc init and load/store op. +# And the rows with dim0 larger than (seq_start + seq_len) will be masked. Note that (seq_start + seq_len) == seq_end, which we'll use in TMA store below +``` +The equivalent TMA store code would be: +``` +# pyre-ignore [20] +tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=device_desc_o, + global_address=Out, # Out is of shape (L, H, DimV) + load_size=[BLOCK_M, BLOCK_D_V], #box size as explained in comments above + global_size=[seq_end.to(tl.int32), H * DimV], # this eliminates the need for `mask`, TMA automatically take care of boundaries. + element_ty=Out.dtype.element_ty, +) +# pyre-ignore [20] +tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_o) +tl._experimental_descriptor_store( + device_desc_o, + acc, # acc needs to be casted to the right dtype + [ #offset as explained in comments above (where the box starts at) + (seq_start + pid * BLOCK_M).to(tl.int32), + (off_h * stride_oh).to(tl.int32), + ], + ) +``` +### Load +Assume we have this non TMA load code: +``` +Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + seq_start * stride_qm, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) +q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + + +# Essentially this tries to load this box into q: +# Q[ +# (seq_start + start_m : seq_start + start_m + BLOCK_M), +# (off_h * stride_qh : off_h * stride_qh + BLOCK_D_Q) +# ] +# In other words, it's a box of size (BLOCK_M, BLOCK_D_Q) starting at [seq_start + start_m, off_h * stride_qh]. This will be the bases for our TMA desc init and load/store op. +# And the rows with dim0 larger than seq_len will be filled with zero, with shape of q always being (BLOCK_M, BLOCK_D_Q). +``` +The equivalent TMA load code will be: +``` +# pyre-ignore [20] +tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=device_desc_q, + global_address=Q, # shape (L, H, DimQ) + load_size=[BLOCK_M,BLOCK_D_Q], #box size as explained in comments above + global_size=[seq_end.to(tl.int32), H * DimQ], # seq_end == seq_start + seq_len + element_ty=Q.dtype.element_ty, + ) +# pyre-ignore [20] + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_q) + + +q = tl._experimental_descriptor_load( + device_desc_q, + [ #offset as explained in comments above (where the box starts at) + (seq_start + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + [BLOCK_M,BLOCK_D_Q], + Q.dtype.element_ty, + ) +``` diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py new file mode 100644 index 0000000..45d1deb --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ON_DEVICE_TMA = """ +============================= On-Device Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Device TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the device. On-device descriptors +allocate a large chunk of memory in GMEM, and each program have to find its own offset +and initialize descriptors there. + +## How to integrate on-device TMA into a Triton program? +To enable on-device TMA in a Triton program, we need to add support from both the host and kernel programs. +In the host program, a global memory allocation is needed by adding the following function: +``` +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +triton.set_allocator(alloc_fn) +``` +In addition, we need to import the method `from typing import Optional`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we declare a TMA descriptor for each tensor and then use the descriptor to load and store the tensor in blocks. +An example of a TMA descriptor declaration is +``` +x_desc = tl.make_tensor_descriptor( + x_ptr, # the pointer to the tensor + shape=[M, N], # the shape of the tensor + strides=[stride_m, stride_n], # the stride of the tensor + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], # the block size of each TMA load/store +) +``` +An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py new file mode 100644 index 0000000..31b169d --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ON_HOST_TMA = """ +============================= On-Host Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Host TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the host. On-host descriptors +allocate memory in the host memory, initialize the descriptors there, and then copy +them by value to GMEM. + +## How to integrate on-host TMA into a Triton program? +To enable on-host TMA in a Triton program, we need to add support on both the host and kernel programs. +In the host program, we allocate a TMA descriptor for each tensor and pass the descriptor as an argument to the kernel. +An example of a TMA descriptor declaration is +``` +x_desc = TensorDescriptor( + x, # the pointer to the tensor + x.shape, # the shape of the tensor + x.stride(), # the stride of the tensor + [BLOCK_SIZE_M, BLOCK_SIZE_N] # the block size of each TMA load/store +) +``` +And in addition, we need to import the method `from triton.tools.tensor_descriptor import TensorDescriptor`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we use the TMA descriptor to load and store the tensor in blocks. An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/persistence.py b/kernel_perf_agent/kernel_opt/database/docs/persistence.py new file mode 100644 index 0000000..755743e --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/persistence.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +PERSISTENCE = """ +================================ Persistent Programming Style ======================================= +## What it is: +The persistent programming style in GPU is a kernel design pattern where a fixed number of +blocks is launched, typically equal to the number of streaming multiprocessors (SMs), +instead of launching blocks proportional to the problem size. This pattern is particularly effective +for large-scale computations where the problem size exceeds the GPU's parallel capacity. + +## Traditional Approach: +In an unoptimized Triton GPU kernel, the number of blocks launched is dependent on the input size, +typically calculated as `triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]` +in the grid argument. +Each block processes exactly one tile of work, and the number of blocks can be much larger +than the available hardware resources. + +## Persistent Approach: +In a persistent style implementation, a fixed number of blocks is launched, which can be the number +of streaming multiprocessors (SMs) on the GPU by calling `torch.cuda.get_device_properties("cuda").multi_processor_count`. +In the kernel code, each block iterates over the program ID with a stride equal to the total number of blocks, +ensuring that the computation is completed by a fixed number of blocks. +These blocks "persist" and loop until all work is completed. + +## Advantages: +* Better resource utilization: Matches hardware capabilities exactly +* Reduced launch overhead: Fewer kernel launches for large problems +* Improved occupancy: Keeps all SMs busy throughout execution +* Better cache locality: Blocks can reuse data across multiple iterations +* Load balancing: Work is distributed more evenly across SMs +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py new file mode 100644 index 0000000..acb14ed --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +PID_SWIZZLE = """ +===================================== PID Swizzling =========================================== +## What it is: +PID swizzling is a GPU optimization technique used in Triton programming that remaps +program identifiers (`pid_m` and `pid_n`) to create better memory access patterns, +specifically for L2 cache locality. This technique is commonly used in high-performance GPU kernels, +particularly for GEMM (General Matrix Multiply) operations in frameworks like Triton. + +## Traditional Approach: +The program launch order matters as it affects the L2 cache hit rate. +In an unoptimized GPU kernel, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] +block of the output tensor, and the program identifiers are arranged in a simple row-major ordering +by `pid_m = pid // num_pid_n` and `pid_n = pid % num_pid_n`. +This creates poor cache locality because adjacent programs access memory locations that are far apart. + +## PID Swizzling Approach: +PID swizzling forms "super-grouping" of programs with a fixed row size `GROUP_SIZE_M`. +The number of programs in a group is `GROUP_SIZE_M * num_pid_n`. +The `group_id` is calculated by dividing the program id by the number of programs in a group. +If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the row size of the last group is smaller +and can be calculated by subtracting `GROUP_SIZE_M * group_id` from `num_pid_m`. +The programs within a group are arranged in a column-major order. +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/tma.md b/kernel_perf_agent/kernel_opt/database/docs/tma.md new file mode 100644 index 0000000..89087d9 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/tma.md @@ -0,0 +1,150 @@ +**TMA (Tensor Memory Accelerator)** is a hardware feature in NVIDIA GPUs that accelerates memory transfers for tensor operations by providing more efficient block-based memory access patterns. + +What is TMA? +------------ + +TMA replaces traditional pointer-based memory access with **tensor descriptors** that describe the entire tensor layout, enabling the GPU hardware to optimize memory transfers automatically. + +Benefits of TMA: +---------------- + +* **Hardware-accelerated memory transfers** +* **Better memory coalescing** +* **Reduced memory access overhead** +* **Simplified memory access patterns** + +How to Add TMA to Triton Code +----------------------------- + +There are two approaches: **Host-side TMA** and **Device-side TMA**. + +### 1. Host-side TMA Implementation + +**Host-side setup:** + +``` +from triton.tools.tensor_descriptor import TensorDescriptor + +def matmul_with_tma(a, b): + # Create TMA descriptors on host + a_desc = TensorDescriptor( + a, # the tensor + a.shape, # tensor shape + a.stride(), # tensor strides + [BLOCK_SIZE_M, BLOCK_SIZE_K] # block size for TMA operations + ) + + b_desc = TensorDescriptor( + b, + b.shape, + b.stride(), + [BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = TensorDescriptor( + c, + c.shape, + c.stride(), + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Pass descriptors to kernel + kernel[grid](a_desc, b_desc, c_desc, ...) +``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_desc, b_desc, c_desc, ...): + pid = tl.program_id(axis=0) + # Calculate tile positions + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load using TMA descriptors + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) # offset coordinates + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute + accumulator = tl.dot(a, b) + + # Store using TMA descriptor + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator) +``` + +### 2. Device-side TMA Implementation + +**Host-side setup:** + +``` +from typing import Optional + +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +# Set custom allocator for TMA +triton.set_allocator(alloc_fn) +``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, ...): + # Create TMA descriptors in kernel + a_desc = tl.make_tensor_descriptor( + a_ptr, # pointer to tensor + shape=[M, K], # tensor shape + strides=[stride_am, stride_ak], # tensor strides + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K] # TMA block size + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + shape=[K, N], + strides=[stride_bk, stride_bn], + block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[stride_cm, stride_cn], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Use descriptors for memory operations + pid = tl.program_id(axis=0) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load blocks using TMA + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute and store + result = tl.dot(a, b) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], result) +``` + +Key Differences from Traditional Approach: +------------------------------------------ + +**Traditional:** + +``` +# Manual pointer arithmetic +offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak +a = tl.load(a_ptrs, mask=...) +``` + +**TMA:** + +``` +# Descriptor-based access +a = a_desc.load([pid_m * BLOCK_SIZE_M, k_offset]) +``` + +TMA simplifies memory access patterns and leverages hardware acceleration for better performance in tensor operations. From 84708fd8bca0ea81dc6e321f5f21d8f27ce4bb0f Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 27 Jan 2026 00:37:17 -0800 Subject: [PATCH 2/4] fix ruff --- .../kernel_opt/database/code_samples/matadd.py | 10 ++++++---- .../database/code_samples/matadd_perst.py | 16 ++++++++++------ .../database/code_samples/matadd_tma_device.py | 10 ++++++---- .../database/code_samples/matadd_tma_host.py | 10 ++++++---- .../kernel_opt/database/code_samples/matmul.py | 16 +++++++++------- .../database/code_samples/matmul_sw.py | 8 +++++--- .../database/code_samples/matmul_tma_host.py | 10 ++++++---- 7 files changed, 48 insertions(+), 32 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py index ec71868..20cf664 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py @@ -35,7 +35,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -58,9 +58,11 @@ def add(x: torch.Tensor, y: torch.Tensor): M, N = x.shape output = torch.empty((M, N), device=x.device, dtype=torch.float16) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py index 99b4c79..33015ea 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py @@ -68,12 +68,16 @@ def add(x: torch.Tensor, y: torch.Tensor): # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - grid = lambda meta: ( - min( - NUM_SMS, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ), - ) + + def grid(meta): + return ( + min( + NUM_SMS, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py index 274e172..3d4d6c1 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py @@ -37,7 +37,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -78,9 +78,11 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): triton.set_allocator(alloc_fn) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py index 7b65f1b..81b4761 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py @@ -34,7 +34,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -56,9 +56,11 @@ def add(x: torch.Tensor, y: torch.Tensor): output, output.shape, output.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N] ) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x_desc, y_desc, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py index 08f1225..7f78258 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -42,7 +42,7 @@ def matmul_kernel( BLOCK_SIZE_K: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -75,9 +75,11 @@ def matmul(a, b): K, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.float16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a, b, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py index 61ca813..1e2e468 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py @@ -81,9 +81,11 @@ def matmul(a, b): K, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.float16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a, b, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py index 0e2b0af..dc4f414 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py @@ -37,7 +37,7 @@ def matmul_kernel( BLOCK_SIZE_K: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -62,9 +62,11 @@ def matmul(a, b): b_desc = TensorDescriptor(b, b.shape, b.stride(), [BLOCK_SIZE_K, BLOCK_SIZE_N]) c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a_desc, b_desc, From 069a80debb3952dfef940c7580cc14f8562b26d7 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 28 Jan 2026 21:14:39 -0800 Subject: [PATCH 3/4] change base and add integration example --- kernel_perf_agent/kernel_opt/database/base.py | 112 +++++----- pyproject.toml | 2 + .../prescribing/RAG_based_prescriber.py | 197 ++++++++++++++++++ 3 files changed, 247 insertions(+), 64 deletions(-) create mode 100644 triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py diff --git a/kernel_perf_agent/kernel_opt/database/base.py b/kernel_perf_agent/kernel_opt/database/base.py index ffb4a0c..2ee543b 100644 --- a/kernel_perf_agent/kernel_opt/database/base.py +++ b/kernel_perf_agent/kernel_opt/database/base.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from pathlib import Path from kernel_perf_agent.kernel_opt.database.docs import ( @@ -25,55 +27,57 @@ class OptNode: def __init__(self, level: int, dsl: str, opt_desc: str) -> None: """Initialize the optimization node with the given level, description, and DSL. - :param level: int, Level in the tree - :param dsl: str, DSL used in the node - :param opt_desc: str, Description of the optimization - :param opt_parents: List[str], Parent nodes description - :param opt_children: List[OptNode], Children nodes - """ - self.level = level # int, Level in the tree + Args: + level: Level in the tree (0=root, 1=bottleneck type, 2=technique, 3=code example) + dsl: DSL used in the node (e.g., "text", "triton") + opt_desc: Description of the optimization or code example + """ + self.level = level self.dsl = dsl - self.opt_desc = opt_desc # str, Root node description - self.opt_parents = [] # List[str], Parent nodes description - self.opt_children = [] # List[OptNode], Children nodes + self.opt_desc = opt_desc + self.opt_parents: list[OptNode] = [] + self.opt_children: list[OptNode] = [] - def add_children(self, child_nodes): - """Adds a child node to the current node.""" + def add_children(self, child_nodes: list[OptNode]) -> None: + """Adds child nodes to this node.""" self.opt_children.extend(child_nodes) - def remove_children(self, child_nodes): - """Removes a child node from the current node.""" - for child in child_nodes: - if child in self.opt_children: - self.opt_children.remove(child) - - def add_parents(self, parent_nodes): - """Adds a child node to the current node.""" + def add_parents(self, parent_nodes: list[OptNode]) -> None: + """Adds parent nodes to this node.""" self.opt_parents.extend(parent_nodes) - def remove_parents(self, parent_nodes): - """Removes a child node from the current node.""" - for parent in parent_nodes: - if parent in self.opt_parents: - self.opt_parents.remove(parent) - - def __repr__(self): + def __repr__(self) -> str: """String representation of the node for easy printing.""" return f"OptNode at level {self.level}: ({self.opt_desc})" +def add_relation(parent: OptNode, children: list[OptNode]) -> None: + """Add parent-child relationship symmetrically. + + Args: + parent: The parent node + children: List of child nodes to add + """ + parent.add_children(children) + for child in children: + child.add_parents([parent]) + + class OptHierarchy: def __init__(self) -> None: """Initialize the optimization hierarchy with the root node.""" self.root = OptNode(level=0, dsl="text", opt_desc="root") - def get_root(self): + def get_root(self) -> OptNode: return self.root - def hard_initialize(self, common_path) -> None: - """Hard initialize the hierarchy with pre-programmed database.""" + def hard_initialize(self, common_path: Path) -> None: + """Hard initialize the hierarchy with pre-programmed database. + Args: + common_path: Path to the code_samples directory + """ # Level 1 nodes - Latency, Memory, Utilization bottlenecks optnode_latency = OptNode( level=1, @@ -102,11 +106,7 @@ def hard_initialize(self, common_path) -> None: - Autotuning to identify and apply optimal kernel configurations that maximize resource usage """, ) - level_1_opts = [optnode_latency, optnode_memory, optnode_utilization] - self.root.add_children(level_1_opts) - optnode_latency.add_parents([self.root]) - optnode_memory.add_parents([self.root]) - optnode_utilization.add_parents([self.root]) + add_relation(self.root, [optnode_latency, optnode_memory, optnode_utilization]) # Level 2 nodes - TMA, PID swizzling, persistent programming style optnode_host_TMA = OptNode( @@ -122,26 +122,19 @@ def hard_initialize(self, common_path) -> None: level=2, dsl="text", opt_desc=persistence.PERSISTENCE ) - optnode_latency.add_children([optnode_persistence]) - optnode_memory.add_children( + add_relation(optnode_latency, [optnode_persistence]) + add_relation( + optnode_memory, [ optnode_host_TMA, optnode_device_TMA, optnode_PID_swizzling, optnode_persistence, - ] - ) - optnode_utilization.add_children([optnode_persistence]) - - optnode_host_TMA.add_parents([optnode_memory]) - optnode_device_TMA.add_parents([optnode_memory]) - optnode_PID_swizzling.add_parents([optnode_memory]) - optnode_persistence.add_parents( - [optnode_latency, optnode_memory, optnode_utilization] + ], ) + add_relation(optnode_utilization, [optnode_persistence]) - # Level 3 nodes - code example of each kernel - # common_path="../kernel_opt/database/code_samples/" + # Level 3 nodes - code examples optnode_matmul = OptNode( level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text() ) @@ -174,26 +167,17 @@ def hard_initialize(self, common_path) -> None: opt_desc=Path(common_path / "matadd_tma_device.py").read_text(), ) - optnode_host_TMA.add_children( + add_relation( + optnode_host_TMA, [ optnode_matmul, optnode_matmul_tma_host, optnode_matadd, optnode_matadd_tma_host, - ] - ) - optnode_device_TMA.add_children([optnode_matadd, optnode_matadd_tma_device]) - optnode_PID_swizzling.add_children( - [optnode_matmul, optnode_matmul_pid_swizzling] + ], ) - optnode_persistence.add_children([optnode_matadd, optnode_matadd_persistence]) - - optnode_matmul.add_parents([optnode_host_TMA, optnode_PID_swizzling]) - optnode_matmul_pid_swizzling.add_parents([optnode_PID_swizzling]) - optnode_matmul_tma_host.add_parents([optnode_host_TMA]) - optnode_matadd.add_parents( - [optnode_host_TMA, optnode_device_TMA, optnode_persistence] + add_relation(optnode_device_TMA, [optnode_matadd, optnode_matadd_tma_device]) + add_relation( + optnode_PID_swizzling, [optnode_matmul, optnode_matmul_pid_swizzling] ) - optnode_matadd_persistence.add_parents([optnode_persistence]) - optnode_matadd_tma_host.add_parents([optnode_host_TMA]) - optnode_matadd_tma_device.add_parents([optnode_device_TMA]) + add_relation(optnode_persistence, [optnode_matadd, optnode_matadd_persistence]) diff --git a/pyproject.toml b/pyproject.toml index a31dc3a..82ef2de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ "python-dotenv", "gradio>=5.5.0", "requests", + "langchain-openai", + "numpy", ] [project.optional-dependencies] diff --git a/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py b/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py new file mode 100644 index 0000000..7800609 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAG-based prescriber for retrieving optimization patterns from the database. + +This module provides embedding-based retrieval of optimization patterns from +a hierarchical knowledge base. The database is organized in 3 levels: +- L1: Bottleneck types (Latency, Memory, Utilization) +- L2: Optimization techniques (TMA, PID swizzling, Persistence) +- L3: Code examples (matmul.py, matmul_sw.py, etc.) + +Usage: + prescriber = RAGPrescriber(database_path=Path("...")) + + # Retrieve best matching pattern for an optimization hint + opt_node, similarity_scores = prescriber.retrieve("use TMA for memory optimization") + + # Build context from the retrieved node (traverses to leaf code examples) + context = prescriber.build_context(opt_node) +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np + +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy, OptNode + + +class RAGPrescriber: + """ + Embedding-based retriever for optimization patterns. + + Uses OpenAI embeddings to find the most similar optimization node + based on the optimization prompt, then traverses to leaf nodes + to collect code examples. + """ + + def __init__( + self, + database_path: Path | None = None, + logger: logging.Logger | None = None, + ): + """ + Initialize the RAG prescriber. + + Args: + database_path: Path to code_samples directory. Defaults to + kernel_perf_agent/kernel_opt/database/code_samples + logger: Optional logger instance + """ + self.logger = logger or logging.getLogger(__name__) + self.opt_hierarchy: OptHierarchy | None = None + self._embeddings = None # Lazy-loaded + + # Default path + if database_path is None: + database_path = ( + Path(__file__).parent.parent.parent.parent + / "kernel_perf_agent" + / "kernel_opt" + / "database" + / "code_samples" + ) + + self._initialize_database(database_path) + + def _initialize_database(self, db_path: Path) -> None: + """Initialize the optimization hierarchy from disk.""" + if not db_path.exists(): + self.logger.warning(f"RAG database path not found: {db_path}") + return + + self.opt_hierarchy = OptHierarchy() + self.opt_hierarchy.hard_initialize(db_path) + self.logger.info(f"Initialized RAG database from {db_path}") + + def _get_embeddings(self): + """Lazy-load OpenAI embeddings.""" + if self._embeddings is None: + from langchain_openai import OpenAIEmbeddings + + self._embeddings = OpenAIEmbeddings() + return self._embeddings + + def _cosine_similarity(self, vec1, vec2) -> float: + """Compute cosine similarity between two vectors.""" + dot_product = np.dot(vec1, vec2) + norm_vec1 = np.linalg.norm(vec1) + norm_vec2 = np.linalg.norm(vec2) + if norm_vec1 == 0 or norm_vec2 == 0: + return 0.0 + return dot_product / (norm_vec1 * norm_vec2) + + def retrieve(self, opt_prompt: str) -> tuple[OptNode | None, dict[OptNode, float]]: + """ + Retrieve the most relevant optimization node using embedding similarity. + + Traverses the entire database tree, computes embedding similarity for + each node's description, and returns the node with highest similarity. + + Args: + opt_prompt: Description of the desired optimization + (e.g., "use TMA for memory optimization") + + Returns: + Tuple of (best_matching_node, similarity_scores_dict) + """ + if not self.opt_hierarchy: + self.logger.warning("Database not initialized") + return None, {} + + embeddings = self._get_embeddings() + key_embedding = embeddings.embed_query(opt_prompt) + + # Traverse tree and compute similarity for all nodes + root = self.opt_hierarchy.get_root() + cur_level = list(root.opt_children) + opt_similarity: dict[OptNode, float] = {} + + while cur_level: + child_level = [] + for node in cur_level: + opt_embedding = embeddings.embed_query(node.opt_desc) + opt_similarity[node] = self._cosine_similarity( + key_embedding, opt_embedding + ) + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + cur_level = child_level + + if not opt_similarity: + return None, {} + + # Get node with highest similarity + opt_similarity_sorted = sorted( + opt_similarity.items(), key=lambda item: item[1], reverse=True + ) + best_node = opt_similarity_sorted[0][0] + + self.logger.info( + f"Retrieved optimization pattern (similarity={opt_similarity_sorted[0][1]:.3f}): " + f"{best_node.opt_desc[:80]}..." + ) + + return best_node, opt_similarity + + def build_context(self, opt_node: OptNode) -> str: + """ + Build context by traversing from opt_node to all leaf nodes. + + Collects optimization descriptions and code examples from the + node and all its descendants. + + Args: + opt_node: Starting node (typically from retrieve()) + + Returns: + Context string with technique descriptions and code examples + """ + context = "" + leaf_reached = False + cur_level = [opt_node] + + while cur_level: + child_level = [] + for node in cur_level: + # Mark when we reach leaf nodes (code examples) + if not leaf_reached and not node.opt_children: + leaf_reached = True + context += ( + "\nHere are code examples before and after the optimization:\n" + ) + + context += node.opt_desc + + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + + cur_level = child_level + + return context From c57e13ca72e5a4b77808ba64a77d85473dd88efd Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 29 Jan 2026 15:57:45 -0800 Subject: [PATCH 4/4] update rag integration --- .../kernel_opt/database/__init__.py | 1 - .../kernel_opt/database/docs/__init__.py | 17 ++ pyproject.toml | 1 - .../prescribing/RAG_based_prescriber.py | 153 ++++++++++++------ 4 files changed, 123 insertions(+), 49 deletions(-) create mode 100644 kernel_perf_agent/kernel_opt/database/docs/__init__.py diff --git a/kernel_perf_agent/kernel_opt/database/__init__.py b/kernel_perf_agent/kernel_opt/database/__init__.py index b214283..fdfd396 100644 --- a/kernel_perf_agent/kernel_opt/database/__init__.py +++ b/kernel_perf_agent/kernel_opt/database/__init__.py @@ -14,5 +14,4 @@ """Database package.""" -# Database package __all__ = [] diff --git a/kernel_perf_agent/kernel_opt/database/docs/__init__.py b/kernel_perf_agent/kernel_opt/database/docs/__init__.py new file mode 100644 index 0000000..29f3159 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database Document package.""" + +__all__ = [] diff --git a/pyproject.toml b/pyproject.toml index 82ef2de..b34a309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "python-dotenv", "gradio>=5.5.0", "requests", - "langchain-openai", "numpy", ] diff --git a/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py b/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py index 7800609..6a6095d 100644 --- a/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py +++ b/triton_kernel_agent/opt_worker_component/prescribing/RAG_based_prescriber.py @@ -16,7 +16,7 @@ This module provides embedding-based retrieval of optimization patterns from a hierarchical knowledge base. The database is organized in 3 levels: -- L1: Bottleneck types (Latency, Memory, Utilization) +- L1: Bottleneck types (Compute, Memory, Under-utilization) - L2: Optimization techniques (TMA, PID swizzling, Persistence) - L3: Code examples (matmul.py, matmul_sw.py, etc.) @@ -64,12 +64,20 @@ def __init__( """ self.logger = logger or logging.getLogger(__name__) self.opt_hierarchy: OptHierarchy | None = None - self._embeddings = None # Lazy-loaded + self._openai_client = None # Lazy-loaded + self._node_embeddings: dict[ + OptNode, list[float] + ] = {} # Precomputed L1/L2 embeddings - # Default path + # Default path: navigate to project root (where pyproject.toml is) if database_path is None: + project_root = Path(__file__).resolve() + while project_root.parent != project_root: + if (project_root / "pyproject.toml").exists(): + break + project_root = project_root.parent database_path = ( - Path(__file__).parent.parent.parent.parent + project_root / "kernel_perf_agent" / "kernel_opt" / "database" @@ -86,15 +94,48 @@ def _initialize_database(self, db_path: Path) -> None: self.opt_hierarchy = OptHierarchy() self.opt_hierarchy.hard_initialize(db_path) + self._precompute_embeddings() self.logger.info(f"Initialized RAG database from {db_path}") - def _get_embeddings(self): - """Lazy-load OpenAI embeddings.""" - if self._embeddings is None: - from langchain_openai import OpenAIEmbeddings + def _precompute_embeddings(self) -> None: + """Precompute embeddings for L1/L2 nodes.""" + if not self.opt_hierarchy: + return + + root = self.opt_hierarchy.get_root() + cur_level = list(root.opt_children) - self._embeddings = OpenAIEmbeddings() - return self._embeddings + while cur_level: + child_level = [] + for node in cur_level: + # Only embed L1/L2 nodes (nodes with children), skip L3 code examples + if node.opt_children: + self._node_embeddings[node] = self._embed_query(node.opt_desc) + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + cur_level = child_level + + self.logger.info( + f"Precomputed embeddings for {len(self._node_embeddings)} L1/L2 nodes" + ) + + def _get_openai_client(self): + """Lazy-load OpenAI client.""" + if self._openai_client is None: + from openai import OpenAI + + self._openai_client = OpenAI() + return self._openai_client + + def _embed_query(self, text: str) -> list[float]: + """Get embedding for a text query.""" + client = self._get_openai_client() + response = client.embeddings.create( + input=text, + model="text-embedding-3-large", # text-embedding-3-small for lower cost + ) + return response.data[0].embedding def _cosine_similarity(self, vec1, vec2) -> float: """Compute cosine similarity between two vectors.""" @@ -109,8 +150,9 @@ def retrieve(self, opt_prompt: str) -> tuple[OptNode | None, dict[OptNode, float """ Retrieve the most relevant optimization node using embedding similarity. - Traverses the entire database tree, computes embedding similarity for - each node's description, and returns the node with highest similarity. + Computes similarity against precomputed L1/L2 node embeddings and returns + the node with highest similarity. Use build_context() to traverse down + to L3 code examples. Args: opt_prompt: Description of the desired optimization @@ -123,25 +165,14 @@ def retrieve(self, opt_prompt: str) -> tuple[OptNode | None, dict[OptNode, float self.logger.warning("Database not initialized") return None, {} - embeddings = self._get_embeddings() - key_embedding = embeddings.embed_query(opt_prompt) + key_embedding = self._embed_query(opt_prompt) - # Traverse tree and compute similarity for all nodes - root = self.opt_hierarchy.get_root() - cur_level = list(root.opt_children) + # Compute similarity against precomputed L1/L2 node embeddings opt_similarity: dict[OptNode, float] = {} - - while cur_level: - child_level = [] - for node in cur_level: - opt_embedding = embeddings.embed_query(node.opt_desc) - opt_similarity[node] = self._cosine_similarity( - key_embedding, opt_embedding - ) - for child in node.opt_children: - if child not in child_level: - child_level.append(child) - cur_level = child_level + for node, node_embedding in self._node_embeddings.items(): + opt_similarity[node] = self._cosine_similarity( + key_embedding, node_embedding + ) if not opt_similarity: return None, {} @@ -159,39 +190,67 @@ def retrieve(self, opt_prompt: str) -> tuple[OptNode | None, dict[OptNode, float return best_node, opt_similarity - def build_context(self, opt_node: OptNode) -> str: + def build_context( + self, + opt_node: OptNode, + max_chars: int = 8192, + max_code_examples: int = 2, + ) -> str: """ - Build context by traversing from opt_node to all leaf nodes. - - Collects optimization descriptions and code examples from the - node and all its descendants. + Build context from opt_node with technique description and code examples. Args: opt_node: Starting node (typically from retrieve()) + max_chars: Maximum character budget for context + max_code_examples: Maximum number of leaf code examples to include Returns: - Context string with technique descriptions and code examples + Context string with technique description and limited code examples """ - context = "" - leaf_reached = False - cur_level = [opt_node] + parts: list[str] = [] + code_examples: list[str] = [] + # BFS to collect technique descriptions (non-leaf) and code examples (leaf) + cur_level = [opt_node] while cur_level: child_level = [] for node in cur_level: - # Mark when we reach leaf nodes (code examples) - if not leaf_reached and not node.opt_children: - leaf_reached = True - context += ( - "\nHere are code examples before and after the optimization:\n" - ) - - context += node.opt_desc + if node.opt_children: + # Non-leaf: technique description + parts.append(node.opt_desc.strip()) + else: + # Leaf: code example + code_examples.append(node.opt_desc.strip()) for child in node.opt_children: if child not in child_level: child_level.append(child) - cur_level = child_level + # Build context with separators + context = "" + + # Add technique descriptions + if parts: + context += "## Optimization Technique\n\n" + context += "\n\n---\n\n".join(parts) + + # Add limited code examples + if code_examples: + selected_examples = code_examples[:max_code_examples] + if parts: + context += "\n\n---\n\n" # Separator between sections + context += "## Code Examples\n\n" + context += "\n\n---\n\n".join(selected_examples) + + if len(code_examples) > max_code_examples: + context += f"\n\n(Showing {max_code_examples} of {len(code_examples)} examples)" + + # Truncate if over budget + if len(context) > max_chars: + context = context[:max_chars] + "\n\n... (truncated)" + self.logger.warning( + f"Context truncated from {len(context)} to {max_chars} chars" + ) + return context