Add Knowledge Database to Kernel optimization #85
Merged
Changes from all commits (7):
- 48c9b0c: Introducing Database module to kernel_opt
- 84708fd: fix ruff
- 5d01d11: Merge branch 'main' into kaiming/opt_component_7_clean (kaiming-cheng)
- 069a80d: change base and add integration example
- c57e13c: update rag integration
- 83ed713: Merge branch 'main' into kaiming/opt_component_7_clean (kaiming-cheng)
- 2f10c69: Merge branch 'main' into kaiming/opt_component_7_clean (kaiming-cheng)
New file (+17 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Database package."""

__all__ = []
```
New file (+183 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from pathlib import Path

from kernel_perf_agent.kernel_opt.database.docs import (
    on_device_tma,
    on_host_tma,
    persistence,
    pid_swizzle,
)


class OptNode:
    def __init__(self, level: int, dsl: str, opt_desc: str) -> None:
        """Initialize the optimization node with the given level, DSL, and description.

        Args:
            level: Level in the tree (0=root, 1=bottleneck type, 2=technique, 3=code example)
            dsl: DSL used in the node (e.g., "text", "triton")
            opt_desc: Description of the optimization or code example
        """
        self.level = level
        self.dsl = dsl
        self.opt_desc = opt_desc
        self.opt_parents: list[OptNode] = []
        self.opt_children: list[OptNode] = []

    def add_children(self, child_nodes: list[OptNode]) -> None:
        """Add child nodes to this node."""
        self.opt_children.extend(child_nodes)

    def add_parents(self, parent_nodes: list[OptNode]) -> None:
        """Add parent nodes to this node."""
        self.opt_parents.extend(parent_nodes)

    def __repr__(self) -> str:
        """String representation of the node for easy printing."""
        return f"OptNode at level {self.level}: ({self.opt_desc})"


def add_relation(parent: OptNode, children: list[OptNode]) -> None:
    """Add a parent-child relationship symmetrically.

    Args:
        parent: The parent node
        children: List of child nodes to add
    """
    parent.add_children(children)
    for child in children:
        child.add_parents([parent])


class OptHierarchy:
    def __init__(self) -> None:
        """Initialize the optimization hierarchy with the root node."""
        self.root = OptNode(level=0, dsl="text", opt_desc="root")

    def get_root(self) -> OptNode:
        return self.root

    def hard_initialize(self, common_path: Path) -> None:
        """Hard-initialize the hierarchy with the pre-programmed database.

        Args:
            common_path: Path to the code_samples directory
        """
        # Level 1 nodes - latency, memory, and utilization bottlenecks
        optnode_latency = OptNode(
            level=1,
            dsl="text",
            opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including:
- Persistent programming style to minimize kernel launch overhead
- Software pipelining to improve instruction-level parallelism and reduce execution time
""",
        )
        optnode_memory = OptNode(
            level=1,
            dsl="text",
            opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including:
- PID swizzling to enhance L2 cache locality
- Leveraging new architecture features, such as the Tensor Memory Accelerator (TMA), to overlap memory transfers
with compute operations
""",
        )
        optnode_utilization = OptNode(
            level=1,
            dsl="text",
            opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques
to increase resource utilization and occupancy rates, including:
- Leveraging the Tensor Memory Accelerator (TMA) to overlap memory transfers with compute operations
- Enabling warp specialization to improve instruction-level parallelism and reduce register pressure
- Autotuning to identify and apply optimal kernel configurations that maximize resource usage
""",
        )
        add_relation(self.root, [optnode_latency, optnode_memory, optnode_utilization])

        # Level 2 nodes - TMA, PID swizzling, persistent programming style
        optnode_host_TMA = OptNode(
            level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA
        )
        optnode_device_TMA = OptNode(
            level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA
        )
        optnode_PID_swizzling = OptNode(
            level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE
        )
        optnode_persistence = OptNode(
            level=2, dsl="text", opt_desc=persistence.PERSISTENCE
        )

        add_relation(optnode_latency, [optnode_persistence])
        add_relation(
            optnode_memory,
            [
                optnode_host_TMA,
                optnode_device_TMA,
                optnode_PID_swizzling,
                optnode_persistence,
            ],
        )
        add_relation(optnode_utilization, [optnode_persistence])

        # Level 3 nodes - code examples
        optnode_matmul = OptNode(
            level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text()
        )
        optnode_matmul_pid_swizzling = OptNode(
            level=3,
            dsl="triton",
            opt_desc=Path(common_path / "matmul_sw.py").read_text(),
        )
        optnode_matmul_tma_host = OptNode(
            level=3,
            dsl="triton",
            opt_desc=Path(common_path / "matmul_tma_host.py").read_text(),
        )
        optnode_matadd = OptNode(
            level=3, dsl="triton", opt_desc=Path(common_path / "matadd.py").read_text()
        )
        optnode_matadd_persistence = OptNode(
            level=3,
            dsl="triton",
            opt_desc=Path(common_path / "matadd_perst.py").read_text(),
        )
        optnode_matadd_tma_host = OptNode(
            level=3,
            dsl="triton",
            opt_desc=Path(common_path / "matadd_tma_host.py").read_text(),
        )
        optnode_matadd_tma_device = OptNode(
            level=3,
            dsl="triton",
            opt_desc=Path(common_path / "matadd_tma_device.py").read_text(),
        )

        add_relation(
            optnode_host_TMA,
            [
                optnode_matmul,
                optnode_matmul_tma_host,
                optnode_matadd,
                optnode_matadd_tma_host,
            ],
        )
        add_relation(optnode_device_TMA, [optnode_matadd, optnode_matadd_tma_device])
        add_relation(
            optnode_PID_swizzling, [optnode_matmul, optnode_matmul_pid_swizzling]
        )
        add_relation(optnode_persistence, [optnode_matadd, optnode_matadd_persistence])
```
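The OptNode graph is a DAG rather than a strict tree: the persistence node, for instance, sits under the latency, memory, and utilization bottlenecks at once. A minimal standalone sketch of how a consumer might collect the level-2 techniques from such a graph (this re-declares a simplified stand-in for OptNode, since the real module imports the docs package and reads code samples from disk; the traversal helper is illustrative, not part of the PR):

```python
from collections import deque


class OptNode:
    """Simplified stand-in for the PR's OptNode: same fields, no docs/ imports."""

    def __init__(self, level: int, dsl: str, opt_desc: str) -> None:
        self.level = level
        self.dsl = dsl
        self.opt_desc = opt_desc
        self.opt_parents: list["OptNode"] = []
        self.opt_children: list["OptNode"] = []


def add_relation(parent: OptNode, children: list[OptNode]) -> None:
    # Mirrors the PR's helper: wire both directions so the DAG stays symmetric.
    parent.opt_children.extend(children)
    for child in children:
        child.opt_parents.append(parent)


def techniques_for(root: OptNode, target_level: int = 2) -> list[str]:
    # Breadth-first walk; a node reachable through several parents (like
    # persistence) must be reported only once, hence the `seen` set.
    seen: set[int] = set()
    found: list[str] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.level == target_level:
            found.append(node.opt_desc)
        queue.extend(node.opt_children)
    return found


root = OptNode(0, "text", "root")
latency = OptNode(1, "text", "latency bottleneck")
memory = OptNode(1, "text", "memory bottleneck")
persistence = OptNode(2, "text", "persistent programming style")
pid_swizzle = OptNode(2, "text", "PID swizzling")
add_relation(root, [latency, memory])
add_relation(latency, [persistence])
add_relation(memory, [persistence, pid_swizzle])

print(techniques_for(root))  # each technique appears exactly once
```

Deduplicating by `id(node)` keeps the walk linear even though persistence is reachable along two paths.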
kernel_perf_agent/kernel_opt/database/code_samples/matadd.py (77 additions, 0 deletions):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ============================ unoptimized matadd =================================
import torch
import triton
import triton.language as tl

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    M,
    N,
    stride_m,
    stride_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Range of pointers for loading the blocks of x and y.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
    y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
    output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
    data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    x = tl.load(x_ptrs, mask=data_mask, other=0.0)
    y = tl.load(y_ptrs, mask=data_mask, other=0.0)
    output = x + y
    tl.store(output_ptrs, output, mask=data_mask)


def add(x: torch.Tensor, y: torch.Tensor):
    M, N = x.shape
    output = torch.empty((M, N), device=x.device, dtype=torch.float16)

    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    add_kernel[grid](
        x,
        y,
        output,
        M,
        N,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return output
```
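The hierarchy references matmul_sw.py for PID swizzling, but that file is not shown in this diff. For context, the grouped-ordering mapping it presumably implements (as in the standard Triton matmul tutorial) can be checked in pure Python without a GPU; the function name and test sizes below are illustrative, not from the PR:

```python
def swizzle_pid(pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int) -> tuple[int, int]:
    """Grouped (swizzled) mapping from a linear program id to a (pid_m, pid_n) tile.

    Instead of walking a whole row of output tiles (which streams many columns
    of the second operand through L2), programs are grouped into super-rows of
    `group_size_m` tile-rows, improving L2 reuse in a matmul.
    """
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer than group_size_m tile-rows.
    group_size = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size
    pid_n = (pid % num_pid_in_group) // group_size
    return pid_m, pid_n


# Sanity check: the swizzled mapping is a bijection over all tiles.
num_pid_m, num_pid_n, group = 7, 5, 3
tiles = {swizzle_pid(pid, num_pid_m, num_pid_n, group) for pid in range(num_pid_m * num_pid_n)}
print(len(tiles))  # 35 distinct tiles, each covered exactly once
```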
kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py (93 additions, 0 deletions):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ===================== matadd with persistent programming style ==================
import torch
import triton
import triton.language as tl

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    M,
    N,
    stride_m,
    stride_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    num_tiles = num_pid_m * num_pid_n

    # Iterate over the tile ids with a stride equal to the total number of programs launched.
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n

        # Range of pointers for loading the blocks of x and y.
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
        y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
        output_ptrs = output_ptr + (
            offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        )
        data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        x = tl.load(x_ptrs, mask=data_mask, other=0.0)
        y = tl.load(y_ptrs, mask=data_mask, other=0.0)
        output = x + y
        tl.store(output_ptrs, output, mask=data_mask)


def add(x: torch.Tensor, y: torch.Tensor):
    M, N = x.shape
    output = torch.empty((M, N), device=x.device, dtype=torch.float16)

    # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    def grid(meta):
        return (
            min(
                NUM_SMS,
                triton.cdiv(M, meta["BLOCK_SIZE_M"])
                * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            ),
        )

    add_kernel[grid](
        x,
        y,
        output,
        M,
        N,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        NUM_SMS=NUM_SMS,
    )
    return output
```
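The key difference from the unoptimized matadd is the launch shape: at most min(NUM_SMS, num_tiles) programs are launched, and each loops over tiles with a stride of NUM_SMS instead of owning exactly one tile. That tile-assignment logic can be sanity-checked in pure Python without a GPU (the NUM_SMS value below is an arbitrary stand-in):

```python
def persistent_tiles(start_pid: int, num_tiles: int, num_sms: int) -> list[int]:
    # Mirrors `tl.range(start_pid, num_tiles, NUM_SMS)` in the kernel above:
    # each persistent program starts at its own pid and strides by the launch width.
    return list(range(start_pid, num_tiles, num_sms))


num_tiles, num_sms = 10, 4  # 10 tiles shared by 4 persistent programs
schedule = {pid: persistent_tiles(pid, num_tiles, num_sms) for pid in range(num_sms)}
# Each tile is visited by exactly one program, with no tile skipped.
covered = sorted(t for tiles in schedule.values() for t in tiles)
print(schedule[0], covered)  # [0, 4, 8] [0, 1, 2, ..., 9]
```

This grid-stride pattern is what lets the persistent version amortize launch overhead over all tiles a program processes.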
Review comment:

> I don't see a kernel_perf_agent/kernel_opt/database/docs/__init__.py added in this PR.

Reply:

> Thanks for the catch, updated this in c57e13c.