Merged
17 changes: 17 additions & 0 deletions kernel_perf_agent/kernel_opt/database/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Database package."""

__all__ = []
183 changes: 183 additions & 0 deletions kernel_perf_agent/kernel_opt/database/base.py
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from pathlib import Path

from kernel_perf_agent.kernel_opt.database.docs import (
    on_device_tma,
    on_host_tma,
    persistence,
    pid_swizzle,
)

Contributor: I don’t see a kernel_perf_agent/kernel_opt/database/docs/__init__.py added in this PR.

Contributor Author: Thanks for the catch, updated this in c57e13c.


class OptNode:
def __init__(self, level: int, dsl: str, opt_desc: str) -> None:
"""Initialize the optimization node with the given level, description, and DSL.
Args:
level: Level in the tree (0=root, 1=bottleneck type, 2=technique, 3=code example)
dsl: DSL used in the node (e.g., "text", "triton")
opt_desc: Description of the optimization or code example
"""
self.level = level
self.dsl = dsl
self.opt_desc = opt_desc
self.opt_parents: list[OptNode] = []
self.opt_children: list[OptNode] = []

def add_children(self, child_nodes: list[OptNode]) -> None:
"""Adds child nodes to this node."""
self.opt_children.extend(child_nodes)

def add_parents(self, parent_nodes: list[OptNode]) -> None:
"""Adds parent nodes to this node."""
self.opt_parents.extend(parent_nodes)

def __repr__(self) -> str:
"""String representation of the node for easy printing."""
return f"OptNode at level {self.level}: ({self.opt_desc})"


def add_relation(parent: OptNode, children: list[OptNode]) -> None:
"""Add parent-child relationship symmetrically.
Args:
parent: The parent node
children: List of child nodes to add
"""
parent.add_children(children)
for child in children:
child.add_parents([parent])
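

# A quick illustration of the symmetric linking (hypothetical nodes, shown as a
# doctest-style sketch rather than module-level code):
#
#     parent = OptNode(level=1, dsl="text", opt_desc="reduce memory traffic")
#     child = OptNode(level=2, dsl="text", opt_desc="PID swizzling")
#     add_relation(parent, [child])
#     assert child in parent.opt_children and parent in child.opt_parents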


class OptHierarchy:
def __init__(self) -> None:
"""Initialize the optimization hierarchy with the root node."""
self.root = OptNode(level=0, dsl="text", opt_desc="root")

def get_root(self) -> OptNode:
return self.root

def hard_initialize(self, common_path: Path) -> None:
"""Hard initialize the hierarchy with pre-programmed database.
Args:
common_path: Path to the code_samples directory
"""
# Level 1 nodes - Latency, Memory, Utilization bottlenecks
optnode_latency = OptNode(
level=1,
dsl="text",
opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including:
- Persistent programming style to minimize kernel launch overhead
- Software pipelining to improve instruction-level parallelism and reduce execution time
""",
)
optnode_memory = OptNode(
level=1,
dsl="text",
opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including:
- PID swizzling to enhance L2 cache locality
- Leveraging new architecture features, such as Tensor Memory Accelerator (TMA) to overlap memory transfers
with compute operations
""",
)
optnode_utilization = OptNode(
level=1,
dsl="text",
opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques
to increase resource utilization and occupancy rates, including:
- Leveraging Tensor Memory Accelerator (TMA) to overlap memory transfers with compute operations
- Enabling warp specializations to improve instruction-level parallelism and reduce register pressure
- Autotuning to identify and apply optimal kernel configurations that maximize resource usage
""",
)
add_relation(self.root, [optnode_latency, optnode_memory, optnode_utilization])

# Level 2 nodes - TMA, PID swizzling, persistent programming style
optnode_host_TMA = OptNode(
level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA
)
optnode_device_TMA = OptNode(
level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA
)
optnode_PID_swizzling = OptNode(
level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE
)
optnode_persistence = OptNode(
level=2, dsl="text", opt_desc=persistence.PERSISTENCE
)

add_relation(optnode_latency, [optnode_persistence])
add_relation(
optnode_memory,
[
optnode_host_TMA,
optnode_device_TMA,
optnode_PID_swizzling,
optnode_persistence,
],
)
add_relation(optnode_utilization, [optnode_persistence])

# Level 3 nodes - code examples
        optnode_matmul = OptNode(
            level=3, dsl="triton", opt_desc=(common_path / "matmul.py").read_text()
        )
        optnode_matmul_pid_swizzling = OptNode(
            level=3,
            dsl="triton",
            opt_desc=(common_path / "matmul_sw.py").read_text(),
        )
        optnode_matmul_tma_host = OptNode(
            level=3,
            dsl="triton",
            opt_desc=(common_path / "matmul_tma_host.py").read_text(),
        )
        optnode_matadd = OptNode(
            level=3, dsl="triton", opt_desc=(common_path / "matadd.py").read_text()
        )
        optnode_matadd_persistence = OptNode(
            level=3,
            dsl="triton",
            opt_desc=(common_path / "matadd_perst.py").read_text(),
        )
        optnode_matadd_tma_host = OptNode(
            level=3,
            dsl="triton",
            opt_desc=(common_path / "matadd_tma_host.py").read_text(),
        )
        optnode_matadd_tma_device = OptNode(
            level=3,
            dsl="triton",
            opt_desc=(common_path / "matadd_tma_device.py").read_text(),
        )

add_relation(
optnode_host_TMA,
[
optnode_matmul,
optnode_matmul_tma_host,
optnode_matadd,
optnode_matadd_tma_host,
],
)
add_relation(optnode_device_TMA, [optnode_matadd, optnode_matadd_tma_device])
add_relation(
optnode_PID_swizzling, [optnode_matmul, optnode_matmul_pid_swizzling]
)
add_relation(optnode_persistence, [optnode_matadd, optnode_matadd_persistence])
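

# Hedged usage sketch (hypothetical driver, not part of the file above): build
# the hierarchy and walk from the root down to the level-2 techniques. Assumes
# the code_samples directory from this PR sits next to base.py.
if __name__ == "__main__":
    hierarchy = OptHierarchy()
    hierarchy.hard_initialize(Path(__file__).parent / "code_samples")
    for bottleneck in hierarchy.get_root().opt_children:
        # Each level-1 node describes a bottleneck type; its children are techniques.
        print(bottleneck)
        for technique in bottleneck.opt_children:
            print("  ->", technique.opt_desc.strip()[:70])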
77 changes: 77 additions & 0 deletions kernel_perf_agent/kernel_opt/database/code_samples/matadd.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ============================ unoptimized matadd =================================
import torch
import triton
import triton.language as tl

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
M,
N,
stride_m,
stride_n,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(axis=0)
_num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

    # Pointer ranges for loading the blocks of x and y.
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

x = tl.load(x_ptrs, mask=data_mask, other=0.0)
y = tl.load(y_ptrs, mask=data_mask, other=0.0)
output = x + y
tl.store(output_ptrs, output, mask=data_mask)


def add(x: torch.Tensor, y: torch.Tensor):
M, N = x.shape
    # Allocate the output with the same dtype as the inputs.
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)

def grid(meta):
return (
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

add_kernel[grid](
x,
y,
output,
M,
N,
x.stride(0),
x.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
return output
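

# Minimal smoke test (illustrative, not part of the original sample): check the
# Triton kernel against eager torch addition. Assumes a CUDA-capable device.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn((1024, 1024), device=DEVICE, dtype=torch.float16)
    y = torch.randn((1024, 1024), device=DEVICE, dtype=torch.float16)
    torch.testing.assert_close(add(x, y), x + y)
    print("matadd matches torch")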
93 changes: 93 additions & 0 deletions kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ===================== matadd with persistent programming style ==================
import torch
import triton
import triton.language as tl

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
M,
N,
stride_m,
stride_n,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
NUM_SMS: tl.constexpr,
):
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

num_tiles = num_pid_m * num_pid_n

    # Each program starts at its own id and strides over the tiles by NUM_SMS,
    # so a fixed grid of persistent programs covers every tile.
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
pid_m = tile_id // num_pid_n
pid_n = tile_id % num_pid_n

        # Pointer ranges for loading the blocks of x and y.
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)
output_ptrs = output_ptr + (
offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
)
data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

x = tl.load(x_ptrs, mask=data_mask, other=0.0)
y = tl.load(y_ptrs, mask=data_mask, other=0.0)
output = x + y
tl.store(output_ptrs, output, mask=data_mask)


def add(x: torch.Tensor, y: torch.Tensor):
M, N = x.shape
    # Allocate the output with the same dtype as the inputs.
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)

# Get the number of streaming multiprocessors and use it to launch a fixed number of blocks
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

def grid(meta):
return (
min(
NUM_SMS,
triton.cdiv(M, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
),
)

add_kernel[grid](
x,
y,
output,
M,
N,
x.stride(0),
x.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
NUM_SMS=NUM_SMS,
)
return output
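

# Illustrative timing sketch (not part of the original sample): with the
# persistent launch the grid is capped at NUM_SMS programs, each looping over
# many tiles. triton.testing.do_bench reports milliseconds; assumes CUDA.
if __name__ == "__main__":
    x = torch.randn((4096, 4096), device=DEVICE, dtype=torch.float16)
    y = torch.randn((4096, 4096), device=DEVICE, dtype=torch.float16)
    torch.testing.assert_close(add(x, y), x + y)
    ms = triton.testing.do_bench(lambda: add(x, y))
    print(f"persistent matadd: {ms:.3f} ms")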