diff --git a/kernel_perf_agent/__init__.py b/kernel_perf_agent/__init__.py index 1f49766..47bb96e 100644 --- a/kernel_perf_agent/__init__.py +++ b/kernel_perf_agent/__init__.py @@ -14,5 +14,4 @@ """Kernel Performance Agent package.""" -# "Kernel Performance Agent package __all__ = [] diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py new file mode 100644 index 0000000..d8db477 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diagnose Prompt Module for Hardware Bottleneck Analysis. + +""" + +__all__: list[str] = [] diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py new file mode 100644 index 0000000..48351a6 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import logging +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import ( + GPU_SPECS_DATABASE, +) + +__all__ = ["GPU_SPECS_DATABASE", "get_gpu_specs"] + +logger = logging.getLogger(__name__) + + +def get_gpu_specs(gpu_name: str) -> dict[str, Any] | None: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name. Must exactly match a key in GPU_SPECS_DATABASE. + + Returns: + Dictionary with GPU specifications, or None if GPU is not in the database. 
+        When successful, contains:
+        - name: GPU name
+        - architecture: GPU architecture (e.g., "Ampere", "Hopper")
+        - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS
+        - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS
+        - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported)
+        - peak_memory_bw_gbps: Peak memory bandwidth in GB/s
+        - sm_count: Number of streaming multiprocessors
+        - max_threads_per_sm: Maximum threads per SM
+        - l1_cache_kb: L1 cache size in KB per SM
+        - l2_cache_mb: Total L2 cache size in MB
+        - memory_gb: Total GPU memory in GB
+        - memory_type: Memory type (e.g., "HBM2e", "GDDR6X")
+
+    Examples:
+        >>> specs = get_gpu_specs("NVIDIA A100 SXM4 80GB")
+        >>> if specs:
+        ...     print(f"SM Count: {specs['sm_count']}")
+    """
+    if gpu_name in GPU_SPECS_DATABASE:
+        return GPU_SPECS_DATABASE[gpu_name].copy()
+
+    logger.warning(
+        "Unknown GPU: '%s'. Optimization will be disabled. Available GPUs: %s",
+        gpu_name,
+        ", ".join(GPU_SPECS_DATABASE.keys()),
+    )
+    return None
+
+
+if __name__ == "__main__":
+    print("GPU Specifications Module")
+    print("=" * 60)
+
+    # Show all available GPUs
+    print("Available GPU specifications in database:")
+    for gpu_name in sorted(GPU_SPECS_DATABASE.keys()):
+        print(f"  - {gpu_name}")
+
+    # Example usage
+    print(f"\n{'=' * 60}")
+    example_gpu = "NVIDIA A100 SXM4 80GB"
+    specs = get_gpu_specs(example_gpu)
+    if specs:
+        print(f"\nExample specs for {example_gpu}:")
+        print(f"  - Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s")
+        print(f"  - Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS")
+        print(f"  - SM Count: {specs['sm_count']}")
diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py
new file mode 100644
index 0000000..cbc616d
--- /dev/null
+++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+GPU Specifications Database
+
+This module contains the GPU hardware specifications database used for
+performance analysis and bottleneck identification. It includes specific
+SKU variants of multi-SKU GPUs such as the A100 and H100.
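+
+Entries are plain dictionaries keyed by the exact device name. A minimal,
+illustrative lookup (the value is the one recorded below):
+
+    >>> from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import (
+    ...     GPU_SPECS_DATABASE,
+    ... )
+    >>> GPU_SPECS_DATABASE["NVIDIA A100 SXM4 80GB"]["peak_memory_bw_gbps"]
+    2039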
+
+Sources:
+- NVIDIA official specifications and datasheets
+- TechPowerUp GPU Database
+- Manufacturer datasheets
+
+Last Updated: January 2026
+"""
+
+GPU_SPECS_DATABASE: dict[str, dict[str, object]] = {
+    # NVIDIA A100 SKUs - SXM4 Variants
+    "NVIDIA A100 SXM4 40GB": {
+        "name": "NVIDIA A100 SXM4 40GB",
+        "architecture": "Ampere",
+        "peak_fp32_tflops": 19.5,
+        "peak_fp16_tflops": 312.0,  # Without sparsity
+        "peak_bf16_tflops": 312.0,  # Without sparsity
+        "peak_memory_bw_gbps": 1555,
+        "sm_count": 108,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 192,
+        "l2_cache_mb": 40,
+        "memory_gb": 40,
+        "memory_type": "HBM2e",
+        "form_factor": "SXM4",
+        "tdp_w": 400,
+    },
+    "NVIDIA A100 SXM4 80GB": {
+        "name": "NVIDIA A100 SXM4 80GB",
+        "architecture": "Ampere",
+        "peak_fp32_tflops": 19.5,
+        "peak_fp16_tflops": 312.0,  # Without sparsity
+        "peak_bf16_tflops": 312.0,  # Without sparsity
+        "peak_memory_bw_gbps": 2039,
+        "sm_count": 108,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 192,
+        "l2_cache_mb": 40,
+        "memory_gb": 80,
+        "memory_type": "HBM2e",
+        "form_factor": "SXM4",
+        "tdp_w": 400,
+    },
+    # NVIDIA A100 SKUs - PCIe Variants
+    "NVIDIA A100 PCIe 40GB": {
+        "name": "NVIDIA A100 PCIe 40GB",
+        "architecture": "Ampere",
+        "peak_fp32_tflops": 19.5,
+        "peak_fp16_tflops": 312.0,  # Without sparsity
+        "peak_bf16_tflops": 312.0,  # Without sparsity
+        "peak_memory_bw_gbps": 1555,
+        "sm_count": 108,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 192,
+        "l2_cache_mb": 40,
+        "memory_gb": 40,
+        "memory_type": "HBM2e",
+        "form_factor": "PCIe",
+        "tdp_w": 250,
+    },
+    "NVIDIA A100 PCIe 80GB": {
+        "name": "NVIDIA A100 PCIe 80GB",
+        "architecture": "Ampere",
+        "peak_fp32_tflops": 19.5,
+        "peak_fp16_tflops": 312.0,  # Without sparsity
+        "peak_bf16_tflops": 312.0,  # Without sparsity
+        "peak_memory_bw_gbps": 1935,
+        "sm_count": 108,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 192,
+        "l2_cache_mb": 40,
+        "memory_gb": 80,
+        "memory_type": "HBM2e",
+        "form_factor": "PCIe",
+        "tdp_w": 300,
+    },
+    # NVIDIA H100 SKUs - SXM5 Variant
+    "NVIDIA H100 SXM5 80GB": {
+        "name": "NVIDIA H100 SXM5 80GB",
+        "architecture": "Hopper",
+        "peak_fp32_tflops": 67.0,
+        "peak_fp16_tflops": 1979.0,  # With sparsity
+        "peak_bf16_tflops": 1979.0,  # With sparsity
+        "peak_memory_bw_gbps": 3350,
+        "sm_count": 132,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 256,
+        "l2_cache_mb": 50,
+        "memory_gb": 80,
+        "memory_type": "HBM3",
+        "form_factor": "SXM5",
+        "tdp_w": 700,
+    },
+    # NVIDIA H100 SKUs - PCIe Variant
+    "NVIDIA H100 PCIe 80GB": {
+        "name": "NVIDIA H100 PCIe 80GB",
+        "architecture": "Hopper",
+        "peak_fp32_tflops": 51.0,
+        "peak_fp16_tflops": 1513.0,  # With sparsity
+        "peak_bf16_tflops": 1513.0,  # With sparsity
+        "peak_memory_bw_gbps": 2000,
+        "sm_count": 114,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 256,
+        "l2_cache_mb": 50,
+        "memory_gb": 80,
+        "memory_type": "HBM2e",
+        "form_factor": "PCIe",
+        "tdp_w": 350,
+    },
+    # NVIDIA H100 SKUs - NVL Variant (for LLM inference)
+    "NVIDIA H100 NVL 94GB": {
+        "name": "NVIDIA H100 NVL 94GB",
+        "architecture": "Hopper",
+        "peak_fp32_tflops": 60.0,
+        "peak_fp16_tflops": 1671.0,  # With sparsity
+        "peak_bf16_tflops": 1671.0,  # With sparsity
+        "peak_memory_bw_gbps": 3900,
+        "sm_count": 132,
+        "max_threads_per_sm": 2048,
+        "l1_cache_kb": 256,
+        "l2_cache_mb": 50,
+        "memory_gb": 94,
+        "memory_type": "HBM3",
+        "form_factor": "PCIe",
+        "tdp_w": 400,
+    },
+    # NVIDIA RTX 4090
+    "NVIDIA RTX 4090": {
+        "name": "NVIDIA RTX 4090",
+        "architecture": "Ada Lovelace",
+        "peak_fp32_tflops": 82.58,
+ "peak_fp16_tflops": 82.58, + "peak_bf16_tflops": 82.58, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + "form_factor": "PCIe", + "tdp_w": 450, + }, + # NVIDIA RTX 5080 + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 56.28, + "peak_fp16_tflops": 56.28, + "peak_bf16_tflops": 56.28, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + "form_factor": "PCIe", + "tdp_w": 360, + }, +} diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompt.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompt.py new file mode 100644 index 0000000..67fbb13 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompt.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Bottleneck Analysis Prompt Builder + +Provides prompt templates and parsing utilities for LLM-based bottleneck analysis +of NCU profiling metrics. + +Bottleneck Categories: +- memory: Memory bandwidth is the limiting factor +- compute: Compute throughput is the limiting factor +- underutilized: Neither saturated (<60% both), indicating stalls/occupancy issues + +Metric definitions are in metric_schema.py. +""" + +import json +import re +from dataclasses import dataclass, field +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.metric_schema import ( + GPU_MEMORY_FIELDS, + GPU_SPEC_FIELDS, + NCU_METRIC_SECTIONS, +) +from kernel_perf_agent.kernel_opt.roofline.ncu_roofline import RooflineResult + +BOTTLENECK_CATEGORIES = {"memory", "compute", "underutilized"} + + +# ============================================================================= +# Data Classes +# ============================================================================= + + +@dataclass +class BottleneckResult: + """A single bottleneck analysis.""" + + category: str + summary: str + reasoning: str + root_causes: list[dict[str, Any]] = field(default_factory=list) + recommended_fixes: list[dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "category": self.category, + "summary": self.summary, + "reasoning": self.reasoning, + "root_causes": self.root_causes, + "recommended_fixes": self.recommended_fixes, + } + + +# ============================================================================= +# Prompt Template +# ============================================================================= + + +BOTTLENECK_PROMPT = """\ +You are a GPU performance expert analyzing Triton kernel profiling data. + +## Task +Analyze the NCU metrics and identify {num_bottlenecks} performance bottleneck(s). 
For each, classify as: +- **memory**: Memory bandwidth is the limiting factor +- **compute**: Compute throughput is the limiting factor +- **underutilized**: Neither saturated (<60% both), indicating stalls/occupancy issues + +## GPU Specifications +{gpu_specs} + +## Roofline Analysis +- Bottleneck: {roofline_bottleneck} +- Compute SOL: {compute_sol:.1f}% +- Memory SOL: {memory_sol:.1f}% +- Efficiency: {efficiency:.1f}% +- Headroom: {headroom:.1f}% +- At Roofline: {at_roofline} +- Tensor Cores: {uses_tc} +- Warnings: {roofline_warnings} + +## NCU Metrics +{ncu_metrics} + +## Kernel Code +```python +{kernel_code} +``` + +## Output (JSON array, no markdown fence) +[ + {{ + "category": "memory" | "compute" | "underutilized", + "summary": "One-line summary", + "reasoning": "Explanation citing metrics", + "root_causes": [ + {{ + "cause": "Description", + "evidence": [{{"metric": "name", "value": 0.0, "interpretation": "meaning"}}] + }} + ], + "recommended_fixes": [ + {{"fix": "Actionable instruction", "rationale": "Why"}} + ] + }} +] + +Requirements: +- Provide exactly {num_bottlenecks} bottleneck analysis object(s) in the array. +- Order by importance (most critical first). +- Each bottleneck should have exactly {num_causes} root cause(s) and {num_fixes} corresponding fix(es). +- Keep summaries and reasoning concise and grounded in the provided metrics. +""" + + +# ============================================================================= +# Prompt Building +# ============================================================================= + + +def _fmt_value(v: Any) -> str: + """Format a value for display in prompts.""" + if isinstance(v, float): + return f"{v:.3g}" + if isinstance(v, int): + return str(v) + return str(v) + + +def _format_gpu_specs(gpu_specs: dict[str, Any]) -> str: + """Format GPU specifications using metric_schema definitions.""" + lines = [] + + for label, key, unit in GPU_SPEC_FIELDS: + value = gpu_specs.get(key) + if value is not None: + lines.append(f"- {label}: {_fmt_value(value)}{unit}") + + for label, size_key, type_key, unit in GPU_MEMORY_FIELDS: + size = gpu_specs.get(size_key) + mem_type = gpu_specs.get(type_key, "") + if size is not None: + type_str = f" {mem_type}" if mem_type else "" + lines.append(f"- {label}: {_fmt_value(size)}{unit}{type_str}") + + return "\n".join(lines) if lines else "N/A" + + +def _format_ncu_metrics(ncu_metrics: dict[str, Any]) -> str: + """Format NCU metrics grouped by section using metric_schema definitions.""" + lines = [] + + for section_name, metric_defs in NCU_METRIC_SECTIONS.items(): + section_lines = [] + for label, key, unit in metric_defs: + value = ncu_metrics.get(key) + if value is not None: + section_lines.append(f" - {label}: {_fmt_value(value)}{unit}") + + if section_lines: + lines.append(f"### {section_name}") + lines.extend(section_lines) + + schema_keys = {key for _, key, _ in sum(NCU_METRIC_SECTIONS.values(), [])} + other_keys = sorted(set(ncu_metrics.keys()) - schema_keys) + if other_keys: + lines.append("### Other Metrics") + for key in other_keys: + value = ncu_metrics[key] + lines.append(f" - {key}: {_fmt_value(value)}") + + return "\n".join(lines) if lines else "N/A" + + +def build_bottleneck_prompt( + kernel_code: str, + ncu_metrics: dict[str, Any], + roofline: RooflineResult, + gpu_specs: dict[str, Any], + num_bottlenecks: int = 2, + num_causes: int = 2, + num_fixes: int = 1, +) -> str: + """Build the bottleneck analysis prompt for the LLM. + + Args: + kernel_code: The Triton kernel source code. 
+ ncu_metrics: NCU profiling metrics dictionary. + roofline: Roofline analysis result. + gpu_specs: GPU hardware specifications. + num_bottlenecks: Number of bottlenecks to request. + num_causes: Number of root causes per bottleneck. + num_fixes: Number of recommended fixes per bottleneck. + + Returns: + Formatted prompt string for the LLM. + """ + return BOTTLENECK_PROMPT.format( + num_bottlenecks=num_bottlenecks, + num_causes=num_causes, + num_fixes=num_fixes, + gpu_specs=_format_gpu_specs(gpu_specs), + roofline_bottleneck=roofline.bottleneck, + compute_sol=roofline.compute_sol_pct, + memory_sol=roofline.memory_sol_pct, + efficiency=roofline.efficiency_pct, + headroom=roofline.headroom_pct, + at_roofline="Yes" if roofline.at_roofline else "No", + uses_tc="Yes" if roofline.uses_tensor_cores else "No", + roofline_warnings="; ".join(roofline.warnings) or "None", + ncu_metrics=_format_ncu_metrics(ncu_metrics), + kernel_code=kernel_code, + ) + + +# ============================================================================= +# Response Parsing +# ============================================================================= + + +def parse_bottleneck_response( + response: str, + fallback_category: str = "underutilized", +) -> list[BottleneckResult]: + """Parse LLM response into a list of BottleneckResult. + + Args: + response: Raw LLM response text. + fallback_category: Category to use if parsing fails. + + Returns: + List of BottleneckResult. Empty list if parsing fails completely. + """ + # Try to find JSON array + array_match = re.search(r"\[[\s\S]*\]", response) + if array_match: + try: + data = json.loads(array_match.group()) + if isinstance(data, list): + return _parse_bottleneck_list(data, fallback_category) + except json.JSONDecodeError: + pass + + # Fall back to single object + obj_match = re.search(r"\{[\s\S]*\}", response) + if obj_match: + try: + data = json.loads(obj_match.group()) + if isinstance(data, dict): + return _parse_bottleneck_list([data], fallback_category) + except json.JSONDecodeError: + pass + + return [] + + +def _parse_bottleneck_list( + items: list[dict[str, Any]], + fallback_category: str, +) -> list[BottleneckResult]: + """Parse a list of bottleneck dicts into BottleneckResult objects.""" + results = [] + for item in items: + category = item.get("category", fallback_category) + if category not in BOTTLENECK_CATEGORIES: + category = fallback_category + + root_causes = [ + {"cause": rc.get("cause", "Unknown"), "evidence": rc.get("evidence", [])} + for rc in item.get("root_causes", []) + ] + + fixes = [ + {"fix": f.get("fix", ""), "rationale": f.get("rationale", "")} + for f in item.get("recommended_fixes", []) + ] + + results.append( + BottleneckResult( + category=category, + summary=item.get("summary", f"{category}-bound"), + reasoning=item.get("reasoning", ""), + root_causes=root_causes, + recommended_fixes=fixes, + ) + ) + + return results diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py new file mode 100644 index 0000000..64d1d67 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Metric Schema Definitions for NCU Profiling and GPU Specifications. + +This module provides the single source of truth for: +- NCU profiling metric definitions (keys, labels, units) +- GPU specification field definitions + +Schema Format: List of tuples (display_label, key, unit_suffix) +- display_label: Human-readable name shown in prompts +- key: NCU metric key or GPU spec dictionary key +- unit_suffix: Unit to append after value (e.g., "%", " GB/s", " bytes") +""" + +from typing import Dict, List, Tuple + +# Type alias for metric definition: (label, key, unit) +MetricDef = Tuple[str, str, str] + +# ============================================================================= +# GPU Specification Fields +# ============================================================================= + +GPU_SPEC_FIELDS: List[MetricDef] = [ + ("Name", "name", ""), + ("Architecture", "architecture", ""), + ("Peak Memory Bandwidth", "peak_memory_bw_gbps", " GB/s"), + ("Peak FP32 Performance", "peak_fp32_tflops", " TFLOPS"), + ("Peak FP16 Performance", "peak_fp16_tflops", " TFLOPS"), + ("SM Count", "sm_count", ""), + ("Max Threads per SM", "max_threads_per_sm", ""), + ("L1 Cache per SM", "l1_cache_kb", " KB"), + ("L2 Cache (Total)", "l2_cache_mb", " MB"), +] + +# Special case: Memory Size has two fields combined +GPU_MEMORY_FIELDS: List[Tuple[str, str, str, str]] = [ + # (label, size_key, type_key, size_unit) + ("Memory Size", "memory_gb", "memory_type", " GB"), +] + +# ============================================================================= +# NCU Profiling Metric Sections +# ============================================================================= + +NCU_METRIC_SECTIONS: Dict[str, List[MetricDef]] = { + "SM & Compute Utilization": [ + ("SM Cycles Active", "sm__cycles_active.avg", ""), + ("Warp Active", "sm__warps_active.avg.pct_of_peak_sustained_active", "%"), + ("Total Instructions Executed", "sm__inst_executed.sum", ""), + ( + "Tensor Core Utilization", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "%", + ), + ( + "Tensor Core Pipeline Active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ], + "Memory Bandwidth & Cache": [ + ( + "DRAM Throughput", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bandwidth", "dram__bytes.sum.per_second", " bytes/sec"), + ( + "GPU DRAM Throughput", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bytes Read", "dram__bytes_read.sum", " bytes"), + ("DRAM Bytes Write", "dram__bytes_write.sum", " bytes"), + ("L1 Cache Hit Rate", "l1tex__t_sector_hit_rate.pct", "%"), + ( + "L1 Throughput", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ("L2 Cache Hit Rate", "lts__t_sector_hit_rate.pct", "%"), + ( + "L2 Throughput", + "lts__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ], + "Memory Access Patterns": [ + ( + "Memory Coalescing", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "%", + ), + ( + "Branch Uniformity", + "smsp__sass_average_branch_targets_threads_uniform.pct", + "%", + 
), + ], + "Occupancy & Resources": [ + ("Occupancy Limited By Blocks", "launch__occupancy_limit_blocks", ""), + ("Occupancy Limited By Registers", "launch__occupancy_limit_registers", ""), + ( + "Occupancy Limited By Shared Memory", + "launch__occupancy_limit_shared_mem", + "", + ), + ("Registers per Thread", "launch__registers_per_thread", ""), + ( + "Shared Memory per Block", + "launch__shared_mem_per_block_allocated", + " bytes", + ), + ], + "Stall Metrics (Warp Issue Stalls)": [ + ( + "Short Scoreboard Stalls", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Long Scoreboard Stalls", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Barrier Stalls", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "%", + ), + ( + "Branch Resolving Stalls", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "%", + ), + ], +} diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py index d177194..d0a028c 100644 --- a/kernel_perf_agent/kernel_opt/profiler/__init__.py +++ b/kernel_perf_agent/kernel_opt/profiler/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""kernel_perf_agent package.""" +"""NCU profiling module for kernel performance analysis.""" -# Kernel Perf Agent package __all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 4ce8568..4b1bf83 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -255,12 +255,8 @@ def _apply_selection_policy( Returns: DataFrame with a single row based on the policy """ - if df.empty: + if len(df) <= 1: return df - - if len(df) == 1: - return df - if policy == MetricSelectionPolicy.FIRST: return df.iloc[[0]] elif policy == MetricSelectionPolicy.LAST: @@ -317,7 +313,7 @@ def load_ncu_metrics( extra_keep: Optional[Sequence[str]] = ("Kernel Name",), coerce_numeric: bool = True, name_list: Optional[Sequence[str]] = None, - select: Union[str, MetricSelectionPolicy] = MetricSelectionPolicy.LAST, + select: MetricSelectionPolicy = MetricSelectionPolicy.LAST, ) -> pd.DataFrame: """ Load and parse NCU metrics from CSV file. @@ -328,32 +324,19 @@ def load_ncu_metrics( extra_keep: Additional columns to keep (e.g., "Kernel Name") coerce_numeric: Convert metric values to numeric name_list: Filter by kernel names (substring match) - select: Selection policy when multiple rows per name. - Can be MetricSelectionPolicy enum or string ("first", "last", "max_cycles") + select: Selection policy when multiple rows per name Returns: DataFrame with parsed metrics Raises: FileNotFoundError: If CSV file not found - ValueError: If no requested columns found in CSV or invalid select value + ValueError: If no requested columns found in CSV """ csv_path = Path(csv_path) if not csv_path.exists(): raise FileNotFoundError(f"CSV not found: {csv_path}") - # Convert string to enum if needed - if isinstance(select, str): - try: - policy = MetricSelectionPolicy(select) - except ValueError: - raise ValueError( - f"Invalid select value: {select}. 
" - f"Must be one of: {[p.value for p in MetricSelectionPolicy]}" - ) - else: - policy = select - df = pd.read_csv(csv_path, comment="=", low_memory=False) metric_cols = list(columns) if columns is not None else METRIC_COLUMNS @@ -383,14 +366,11 @@ def load_ncu_metrics( .apply(pd.to_numeric, errors="coerce") ) - # Filter by kernel name list if provided - if name_list: - sub = _filter_by_kernel_names(sub, name_list, policy, keep_cols) - else: - # Apply selection to all rows if no name filter - sub = _apply_selection_policy(sub, policy) - - return sub + return ( + _filter_by_kernel_names(sub, name_list, select, keep_cols) + if name_list + else _apply_selection_policy(sub, select) + ) def metrics_to_prompt( diff --git a/kernel_perf_agent/kernel_opt/roofline/__init__.py b/kernel_perf_agent/kernel_opt/roofline/__init__.py new file mode 100644 index 0000000..f3d8afe --- /dev/null +++ b/kernel_perf_agent/kernel_opt/roofline/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Roofline analysis module for kernel performance optimization.""" + +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/roofline/ncu_roofline.py b/kernel_perf_agent/kernel_opt/roofline/ncu_roofline.py new file mode 100644 index 0000000..6e2131c --- /dev/null +++ b/kernel_perf_agent/kernel_opt/roofline/ncu_roofline.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Roofline Analysis Module using NCU SOL (Speed of Light) Metrics. + +This module uses NCU's built-in SOL metrics to determine kernel efficiency +relative to hardware limits + +NCU SOL metrics directly measure how close performance is to peak: +- Compute SOL: SM throughput as % of peak +- Memory SOL: DRAM throughput as % of peak + +Updated in January 2026 +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + + +# NCU metrics needed for roofline analysis +# Note: The profiler (ncu_profiler.py) collects these and more metrics. +# This list documents the minimum required for roofline decisions. 
+
+NCU_ROOFLINE_METRICS = [
+    # Primary SOL metrics (Speed of Light)
+    "gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed",  # Memory SOL
+    "sm__throughput.avg.pct_of_peak_sustained_elapsed",  # Compute SOL
+    # Tensor core detection
+    "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active",
+]
+
+
+@dataclass
+class RooflineConfig:
+    """Configuration for roofline analysis."""
+
+    threshold_pct: float = 90.0  # SOL % to consider at roofline
+    early_stop: bool = True  # Stop optimization when at roofline
+    convergence_rounds: int = 5  # Rounds without improvement to trigger stop
+    min_improvement_pct: float = 0.1  # Minimum improvement to continue
+    tensor_core_threshold: float = 5.0  # Min TC activity % to consider TC usage
+    underutilized_threshold: float = 60.0  # Both SOL < this % = underutilized
+
+
+@dataclass
+class RooflineResult:
+    """Result of roofline analysis using NCU SOL metrics."""
+
+    # SOL metrics from NCU (primary)
+    compute_sol_pct: float  # SM throughput as % of peak
+    memory_sol_pct: float  # DRAM throughput as % of peak
+
+    # Derived efficiency (max of compute/memory SOL)
+    efficiency_pct: float  # Primary efficiency metric for decisions
+    at_roofline: bool  # True if efficiency >= threshold_pct
+    headroom_pct: float  # 100 - efficiency
+
+    # Classification
+    bottleneck: str  # "memory" | "compute" | "underutilized"
+    uses_tensor_cores: bool  # Whether TC is active
+
+    # Data quality
+    warnings: list[str] = field(default_factory=list)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return the result as a plain dict (e.g., for rendering prompts)."""
+        return {
+            "compute_sol_pct": self.compute_sol_pct,
+            "memory_sol_pct": self.memory_sol_pct,
+            "efficiency_pct": self.efficiency_pct,
+            "at_roofline": self.at_roofline,
+            "headroom_pct": self.headroom_pct,
+            "bottleneck": self.bottleneck,
+            "uses_tensor_cores": self.uses_tensor_cores,
+            "warnings": list(self.warnings),
+        }
+
+
+class RooflineAnalyzer:
+    """Analyzes kernel performance using NCU SOL metrics."""
+
+    def __init__(
+        self,
+        config: RooflineConfig | None = None,
+        logger: logging.Logger | None = None,
+    ):
+        """
+        Initialize the roofline analyzer.
+
+        Args:
+            config: Roofline configuration (defaults to RooflineConfig())
+            logger: Logger instance
+        """
+        self.config = config or RooflineConfig()
+        self.logger = logger or logging.getLogger(__name__)
+        self._efficiency_history: list[float] = []
+
+    def _is_using_tensor_cores(self, ncu_metrics: dict[str, Any]) -> bool:
+        """Detect tensor core usage from NCU metrics."""
+        tc_cycles = ncu_metrics.get(
+            "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active", 0
+        )
+        return tc_cycles > self.config.tensor_core_threshold
+
+    def _classify_bottleneck(self, compute_sol: float, memory_sol: float) -> str:
+        """
+        Classify bottleneck based on SOL metrics.
+
+        The HIGHER SOL value indicates the bottleneck: the resource closer to its
+        peak is the one limiting further speedup.
+        If both are below the threshold, the kernel is underutilized (could be occupancy,
+        instruction mix, launch config, dependency stalls, etc.).
+        """
+        threshold = self.config.underutilized_threshold
+
+        # Both low = underutilized (neither resource is saturated)
+        if memory_sol < threshold and compute_sol < threshold:
+            return "underutilized"
+
+        # Return whichever is closer to its peak
+        if memory_sol >= compute_sol:
+            return "memory"
+        else:
+            return "compute"
+
+    def analyze(
+        self,
+        ncu_metrics: dict[str, Any],
+    ) -> RooflineResult:
+        """
+        Analyze kernel performance using NCU SOL metrics.
+ + Args: + ncu_metrics: NCU profiling metrics dictionary + + Returns: + RooflineResult with SOL-based efficiency analysis + """ + warnings: list[str] = [] + + # Extract SOL metrics with missing-key detection + compute_key = "sm__throughput.avg.pct_of_peak_sustained_elapsed" + memory_key = "gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed" + + compute_missing = compute_key not in ncu_metrics + memory_missing = memory_key not in ncu_metrics + + if compute_missing: + self.logger.warning("Compute SOL metric missing from NCU data") + warnings.append("Compute SOL metric missing") + if memory_missing: + self.logger.warning("Memory SOL metric missing from NCU data") + warnings.append("Memory SOL metric missing") + + # Fail only if both keys are absent + if compute_missing and memory_missing: + return RooflineResult( + compute_sol_pct=0, + memory_sol_pct=0, + efficiency_pct=0, + at_roofline=False, + headroom_pct=100, + bottleneck="unknown", + uses_tensor_cores=False, + warnings=["Analysis failed - no SOL metrics in NCU data"], + ) + + compute_sol = ncu_metrics.get(compute_key, 0) + memory_sol = ncu_metrics.get(memory_key, 0) + + # Primary efficiency: use max of compute/memory + efficiency = max(compute_sol, memory_sol) + + # Tensor core detection + uses_tc = self._is_using_tensor_cores(ncu_metrics) + + # Classify bottleneck + bottleneck = self._classify_bottleneck(compute_sol, memory_sol) + + # Check if at roofline + at_roofline = efficiency >= self.config.threshold_pct + + return RooflineResult( + compute_sol_pct=compute_sol, + memory_sol_pct=memory_sol, + efficiency_pct=efficiency, + at_roofline=at_roofline, + headroom_pct=max(0, 100 - efficiency), + bottleneck=bottleneck, + uses_tensor_cores=uses_tc, + warnings=warnings, + ) + + def should_stop(self, result: RooflineResult) -> tuple[bool, str]: + """ + Check if optimization should stop based on SOL efficiency and convergence. 
+ + Args: + result: RooflineResult from analyze() + + Returns: + Tuple of (should_stop, reason) + """ + self._efficiency_history.append(result.efficiency_pct) + + # Condition 1: At roofline threshold (if early_stop enabled) + if self.config.early_stop and result.at_roofline: + return ( + True, + f"At roofline ({result.efficiency_pct:.1f}% SOL >= " + f"{self.config.threshold_pct}%)", + ) + + # Condition 2: Efficiency converged (no improvement for N rounds) + if len(self._efficiency_history) >= self.config.convergence_rounds: + recent = self._efficiency_history[-self.config.convergence_rounds :] + improvement = max(recent) - min(recent) + if improvement < self.config.min_improvement_pct: + return ( + True, + f"Converged (improvement {improvement:.2f}% < " + f"{self.config.min_improvement_pct}%)", + ) + + return False, "" + + def reset_history(self) -> None: + """Reset efficiency history for a new optimization run.""" + self._efficiency_history = [] + + +def format_roofline_summary(result: RooflineResult) -> str: + """Format a human-readable summary of roofline analysis.""" + lines = [ + "=== Roofline Analysis ===", + f"SOL Efficiency: {result.efficiency_pct:.1f}%", + f" Compute SOL: {result.compute_sol_pct:.1f}%", + f" Memory SOL: {result.memory_sol_pct:.1f}%", + f" Bottleneck: {result.bottleneck}", + f" Tensor Cores: {'Yes' if result.uses_tensor_cores else 'No'}", + "", + ] + + if result.at_roofline: + lines.append("Status: AT ROOFLINE") + else: + lines.append(f"Headroom: {result.headroom_pct:.1f}%") + + if result.warnings: + lines.append(f"Warnings: {'; '.join(result.warnings)}") + + return "\n".join(lines) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py index 733216c..fe1e84f 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -288,12 +288,11 @@ def main(): args = _parse_args() device = torch.device(args.device) - dtype_map = { + dtype = { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, - } - dtype = dtype_map[args.dtype] + }[args.dtype] if not args.quiet: print("=" * 80) diff --git a/triton_kernel_agent/opt_worker_component/orchestrator/__init__.py b/triton_kernel_agent/opt_worker_component/orchestrator/__init__.py new file mode 100644 index 0000000..86bd3ce --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/orchestrator/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Optimization orchestration components.""" + +from .optimization_orchestrator import OptimizationOrchestrator + +__all__ = ["OptimizationOrchestrator"] diff --git a/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py new file mode 100644 index 0000000..1d5d9f8 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/orchestrator/optimization_orchestrator.py @@ -0,0 +1,591 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Main optimization orchestration logic.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.judger_prompt import BottleneckResult +from kernel_perf_agent.kernel_opt.roofline.ncu_roofline import RooflineAnalyzer +from triton_kernel_agent.prompt_manager import PromptManager +from triton_kernel_agent.worker import VerificationWorker +from triton_kernel_agent.worker_util import ( + _call_llm, + _extract_code_from_response, + _write_kernel_file, +) +from utils.providers.base import BaseProvider + + +def _get_triton_kernel_metrics(ncu_metrics: dict[str, Any]) -> dict[str, Any]: + """ + Extract metrics for the Triton kernel, filtering out PyTorch kernels. + + NCU profiles all CUDA kernels including PyTorch internals (at::*). + This function finds the actual Triton kernel metrics. + + Args: + ncu_metrics: Dict keyed by kernel name with metric dicts as values + + Returns: + Flat metrics dict for the Triton kernel, or first non-PyTorch kernel + """ + if not ncu_metrics: + return {} + + # Filter out PyTorch kernels (they start with "at::" or "void at::") + triton_kernels = { + name: metrics + for name, metrics in ncu_metrics.items() + if not name.startswith("at::") and not name.startswith("void at::") + } + + if triton_kernels: + # Return the first Triton kernel's metrics + return next(iter(triton_kernels.values())) + + # Fallback: return first kernel if no Triton kernel found + return next(iter(ncu_metrics.values()), {}) + + +class OptimizationOrchestrator: + """Orchestrates the main optimization loop.""" + + def __init__( + self, + profiler: Any, + benchmarker: Any, + bottleneck_analyzer: Any, + verification_worker: VerificationWorker, + prompt_manager: PromptManager, + provider: BaseProvider, + model: str, + high_reasoning_effort: bool, + kernel_file: Path, + gpu_specs: dict[str, Any] | None, + pytorch_baseline_time: float | None, + artifact_dir: Path, + output_dir: Path, + logger: logging.Logger, + roofline_analyzer: RooflineAnalyzer, + divergence_threshold: float = 50.0, + ): + """ + Initialize optimization orchestrator. 
+ + Args: + profiler: KernelProfiler instance + benchmarker: Benchmark instance (handles both kernel and PyTorch benchmarking) + bottleneck_analyzer: BottleneckAnalyzer instance + verification_worker: VerificationWorker for verify + refine + prompt_manager: PromptManager for building optimization prompts + provider: LLM provider instance + model: Model name for LLM calls + high_reasoning_effort: Whether to use high reasoning effort + kernel_file: Path to kernel file for writing + gpu_specs: GPU specifications for optimization prompt + pytorch_baseline_time: Pre-computed PyTorch baseline + artifact_dir: Directory for optimization artifacts + output_dir: Directory for final output (best_kernel.py) + logger: Logger instance + roofline_analyzer: RooflineAnalyzer for optimization guidance and early termination + divergence_threshold: Max % worse performance before reverting to best kernel + """ + # Components + self.profiler = profiler + self.benchmarker = benchmarker + self.bottleneck_analyzer = bottleneck_analyzer + self.verification_worker = verification_worker + self.prompt_manager = prompt_manager + + # LLM configuration + self.provider = provider + self.model = model + self.high_reasoning_effort = high_reasoning_effort + + # File configuration + self.kernel_file = kernel_file + + # Configuration + self.gpu_specs = gpu_specs + self.pytorch_baseline_time = pytorch_baseline_time + self.divergence_threshold = divergence_threshold + + # Paths + self.artifact_dir = artifact_dir + self.output_dir = output_dir + + # Logger + self.logger = logger + + # Optional roofline analyzer + self.roofline_analyzer = roofline_analyzer + + def optimize_kernel( + self, + kernel_code: str, + problem_file: Path, + test_code: str, + known_kernel_time: float | None = None, + max_opt_rounds: int = 5, + ) -> tuple[bool, str, dict[str, Any]]: + """ + Main optimization loop. 
+ + Args: + kernel_code: Initial kernel code + problem_file: Path to problem file + test_code: Test code for verification + known_kernel_time: Known performance of kernel_code in ms + max_opt_rounds: Maximum optimization rounds + + Returns: + Tuple of (success, best_kernel_code, performance_metrics) + """ + self.logger.info("=" * 80) + self.logger.info("Starting hardware-guided optimization") + self.logger.info("=" * 80) + + # Initialize state + current_kernel = kernel_code + best_kernel = kernel_code + best_time = float("inf") + error_feedback = "" + best_ncu_metrics: dict[str, Any] | None = None + best_bottleneck_category: str | None = None + best_round_num: int = 0 + early_stop_reason = "" + + # Reset roofline history for new optimization run + self.roofline_analyzer.reset_history() + + # Extract problem description + problem_description = problem_file.read_text() + self.logger.info(f"Problem: {problem_description[:100]}...") + + # Benchmark baseline and PyTorch + best_time, baseline_results, pytorch_baseline_time = self._benchmark_baseline( + kernel_code, problem_file, known_kernel_time + ) + + # Optimization rounds + for round_num in range(1, max_opt_rounds + 1): + self.logger.info("") + self.logger.info("=" * 80) + self.logger.info(f"ROUND {round_num}/{max_opt_rounds}") + self.logger.info("=" * 80) + + # Profile and analyze bottleneck + bottleneck_results, roofline_result, ncu_metrics = ( + self._profile_and_analyze(current_kernel, problem_file, round_num) + ) + + # Log roofline for the kernel we just profiled + roofline_check = None + if ncu_metrics: + flat_metrics = _get_triton_kernel_metrics(ncu_metrics) + roofline_check = self.roofline_analyzer.analyze( + ncu_metrics=flat_metrics, + ) + self.logger.info( + f"[{round_num}] Roofline (kernel_round_{round_num - 1}): " + f"{roofline_check.bottleneck}-bound, {roofline_check.efficiency_pct:.1f}% SOL " + f"(Compute: {roofline_check.compute_sol_pct:.1f}%, " + f"Memory: {roofline_check.memory_sol_pct:.1f}%)" + ) + + if not bottleneck_results: + self.logger.warning( + f"[{round_num}] No analysis available, skipping round" + ) + continue + + # Build optimization prompt using PromptManager with correct API + primary = bottleneck_results[0] + opt_prompt = self.prompt_manager.render_kernel_optimization_prompt( + problem_description=problem_description, + kernel_code=current_kernel, + gpu_specs=self.gpu_specs, + roofline=roofline_result.to_dict() if roofline_result else {}, + category=primary.category, + summary=primary.summary, + reasoning=primary.reasoning, + root_cause=primary.root_causes[0] if primary.root_causes else {}, + recommended_fix=primary.recommended_fixes[0] + if primary.recommended_fixes + else {}, + pytorch_baseline_ms=pytorch_baseline_time, + current_best_ms=best_time, + error_feedback=error_feedback if error_feedback else None, + ) + + # Save prompt + prompt_file = self.artifact_dir / f"round{round_num:03d}_opt_prompt.txt" + with open(prompt_file, "w") as f: + f.write(opt_prompt) + + # Generate optimized kernel + optimized_kernel = self._generate_optimized_kernel(opt_prompt, round_num) + if not optimized_kernel: + error_feedback = "Failed to extract valid kernel code. Please provide complete kernel wrapped in ```python blocks." + continue + + # Verify and refine + success, optimized_kernel, verify_error = self._verify_and_refine( + optimized_kernel, test_code, problem_description, round_num + ) + if not success: + error_feedback = ( + verify_error or "Previous attempt failed correctness check." 
+ ) + continue + + error_feedback = "" + + # Save and benchmark + kernel_file_round = self.artifact_dir / f"kernel_round_{round_num}.py" + kernel_file_round.write_text(optimized_kernel) + + bench_results = self.benchmarker.benchmark_kernel( + kernel_file_round, problem_file + ) + new_time = bench_results["time_ms"] + + # Update kernels based on performance + old_best_time = best_time + current_kernel, best_kernel, best_time = self._update_kernels( + optimized_kernel, + new_time, + current_kernel, + best_kernel, + best_time, + round_num, + ) + + # Track metadata when new best is found + if best_time < old_best_time: + best_round_num = round_num + best_bottleneck_category = primary.category + if ncu_metrics: + best_ncu_metrics = ncu_metrics + + # Early termination check (using roofline computed at start of round) + if ncu_metrics and roofline_check: + should_stop, stop_reason = self.roofline_analyzer.should_stop( + roofline_check + ) + if should_stop and self.roofline_analyzer.config.early_stop: + self.logger.info( + f"[{round_num}] 🎯 Early termination: {stop_reason}" + ) + early_stop_reason = stop_reason + break + + # Profile the final kernel to get its roofline + if best_round_num > 0: + final_kernel_file = self.artifact_dir / f"kernel_round_{best_round_num}.py" + if final_kernel_file.exists(): + self.logger.info(f"Profiling final best kernel (round {best_round_num})...") + final_profiler_results = self.profiler.profile_kernel( + final_kernel_file, problem_file, best_round_num + ) + if final_profiler_results and final_profiler_results.metrics: + best_ncu_metrics = final_profiler_results.metrics + final_flat_metrics = _get_triton_kernel_metrics(best_ncu_metrics) + final_roofline = self.roofline_analyzer.analyze( + ncu_metrics=final_flat_metrics, + ) + self.logger.info( + f"Final roofline (kernel_round_{best_round_num}): " + f"{final_roofline.bottleneck}-bound, {final_roofline.efficiency_pct:.1f}% SOL " + f"(Compute: {final_roofline.compute_sol_pct:.1f}%, " + f"Memory: {final_roofline.memory_sol_pct:.1f}%)" + ) + + # Final results + return self._finalize_results( + best_kernel, + best_time, + baseline_results, + pytorch_baseline_time, + max_opt_rounds, + best_ncu_metrics, + best_bottleneck_category, + best_round_num, + early_stop_reason, + ) + + def _benchmark_baseline( + self, kernel_code: str, problem_file: Path, known_kernel_time: float | None + ) -> tuple[float, dict[str, float], float | None]: + """Benchmark baseline kernel and PyTorch.""" + if known_kernel_time and known_kernel_time != float("inf"): + best_time = known_kernel_time + baseline_results = {"time_ms": known_kernel_time, "speedup": 1.0} + self.logger.info(f"📊 Using known kernel time: {best_time:.4f} ms") + else: + _write_kernel_file(self.kernel_file, kernel_code, self.logger) + kernel_file_round = self.artifact_dir / "kernel_round_0.py" + kernel_file_round.write_text(kernel_code) + + baseline_results = self.benchmarker.benchmark_kernel( + kernel_file_round, problem_file + ) + best_time = baseline_results["time_ms"] + self.logger.info(f"📊 Baseline time: {best_time:.4f} ms") + + # PyTorch baseline + if self.pytorch_baseline_time is not None: + pytorch_baseline_time = self.pytorch_baseline_time + if pytorch_baseline_time != float("inf"): + self.logger.info( + f"📊 PyTorch baseline: {pytorch_baseline_time:.4f} ms (pre-computed)" + ) + else: + pytorch_baseline_time = None + else: + pytorch_results = self.benchmarker.benchmark_pytorch(problem_file) + pytorch_baseline_time = pytorch_results.get("time_ms", float("inf")) + if 
pytorch_baseline_time != float("inf"): + self.logger.info(f"📊 PyTorch baseline: {pytorch_baseline_time:.4f} ms") + else: + pytorch_baseline_time = None + + return best_time, baseline_results, pytorch_baseline_time + + def _profile_and_analyze( + self, + current_kernel: str, + problem_file: Path, + round_num: int, + ) -> tuple[list[BottleneckResult] | None, Any | None, dict[str, Any] | None]: + """Profile kernel and analyze bottlenecks. + + Returns: + Tuple of (bottleneck_results, roofline_result, ncu_metrics). + All can be None if profiling fails. + """ + self.logger.info(f"[{round_num}] Profiling current kernel with NCU...") + kernel_file_round = self.artifact_dir / f"kernel_round_{round_num - 1}.py" + kernel_file_round.write_text(current_kernel) + + profiler_results = self.profiler.profile_kernel( + kernel_file_round, problem_file, round_num + ) + + if profiler_results is None: + self.logger.warning(f"[{round_num}] Profiling failed") + return None, None, None + + ncu_metrics = profiler_results.metrics + + if not ncu_metrics: + return None, None, ncu_metrics + + # Run roofline analysis + flat_metrics = next(iter(ncu_metrics.values()), {}) if ncu_metrics else {} + roofline_result = self.bottleneck_analyzer.roofline.analyze(flat_metrics) + + # Run bottleneck analysis + self.logger.info(f"[{round_num}] Analyzing bottleneck...") + bottleneck_results = self.bottleneck_analyzer.analyze( + current_kernel, ncu_metrics, round_num, roofline_result + ) + + if bottleneck_results: + strategy_file = self.artifact_dir / f"round{round_num:03d}_strategy.json" + with open(strategy_file, "w") as f: + json.dump([r.to_dict() for r in bottleneck_results], f, indent=2) + return bottleneck_results, roofline_result, ncu_metrics + + return None, roofline_result, ncu_metrics + + def _generate_optimized_kernel(self, opt_prompt: str, round_num: int) -> str | None: + """Generate optimized kernel from LLM.""" + self.logger.info(f"[{round_num}] Generating optimized kernel...") + try: + messages = [{"role": "user", "content": opt_prompt}] + response_text = _call_llm( + provider=self.provider, + model=self.model, + messages=messages, + high_reasoning_effort=self.high_reasoning_effort, + logger=self.logger, + max_tokens=24576, + ) + + # Save response + response_file = self.artifact_dir / f"round{round_num:03d}_opt_reply.txt" + with open(response_file, "w") as f: + f.write(response_text) + + # Extract code + optimized_kernel = _extract_code_from_response( + response_text=response_text, + logger=self.logger, + ) + + if not optimized_kernel or len(optimized_kernel) < 100: + self.logger.warning( + f"[{round_num}] Failed to extract valid kernel code" + ) + return None + + return optimized_kernel + + except Exception as e: + self.logger.error(f"[{round_num}] LLM call failed: {e}") + return None + + def _verify_and_refine( + self, + optimized_kernel: str, + test_code: str, + problem_description: str, + round_num: int, + ) -> tuple[bool, str, str]: + """ + Verify kernel correctness with refinement attempts. 
+ + Returns: + Tuple of (success, final_kernel, error_feedback) + """ + self.logger.info(f"[{round_num}] Verifying correctness...") + success, final_kernel, error_feedback = ( + self.verification_worker.verify_with_refinement( + kernel_code=optimized_kernel, + test_code=test_code, + problem_description=problem_description, + ) + ) + + if success: + self.logger.info(f"[{round_num}] ✅ Correctness check passed") + else: + self.logger.warning(f"[{round_num}] ❌ Correctness check failed") + + return success, final_kernel, error_feedback + + def _update_kernels( + self, + optimized_kernel: str, + new_time: float, + current_kernel: str, + best_kernel: str, + best_time: float, + round_num: int, + ) -> tuple[str, str, float]: + """Update current and best kernels based on performance.""" + if new_time < best_time: + # New best found + speedup = best_time / new_time + improvement = (best_time - new_time) / best_time * 100 + self.logger.info( + f"[{round_num}] 🎉 NEW BEST! {new_time:.4f} ms (speedup: {speedup:.2f}x, improvement: {improvement:.1f}%)" + ) + return optimized_kernel, optimized_kernel, new_time + else: + # Check for excessive divergence + divergence = (new_time - best_time) / best_time * 100 + + if divergence > self.divergence_threshold: + self.logger.warning( + f"[{round_num}] ⚠️ EXCESSIVE DIVERGENCE: {new_time:.4f} ms is {divergence:.1f}% worse" + ) + self.logger.warning(f"[{round_num}] 🔄 REVERTING to best kernel") + return best_kernel, best_kernel, best_time + else: + self.logger.info( + f"[{round_num}] No improvement: {new_time:.4f} ms vs best {best_time:.4f} ms" + ) + return optimized_kernel, best_kernel, best_time + + def _finalize_results( + self, + best_kernel: str, + best_time: float, + baseline_results: dict[str, float], + pytorch_baseline_time: float | None, + rounds: int, + ncu_metrics: dict[str, Any] | None = None, + bottleneck_category: str | None = None, + best_round: int = 0, + early_stop_reason: str = "", + ) -> tuple[bool, str, dict[str, Any]]: + """Finalize and log optimization results.""" + self.logger.info("") + self.logger.info("=" * 80) + self.logger.info("OPTIMIZATION COMPLETE") + if early_stop_reason: + self.logger.info(f" (Early termination: {early_stop_reason})") + self.logger.info("=" * 80) + + baseline_speedup = baseline_results["time_ms"] / best_time + improvement_percent = ( + (baseline_results["time_ms"] - best_time) + / baseline_results["time_ms"] + * 100 + ) + + self.logger.info("📊 Final Results:") + self.logger.info(f" Best time: {best_time:.4f} ms") + self.logger.info(f" Baseline time: {baseline_results['time_ms']:.4f} ms") + self.logger.info(f" Speedup vs baseline: {baseline_speedup:.2f}x") + + if pytorch_baseline_time and pytorch_baseline_time != float("inf"): + pytorch_speedup = pytorch_baseline_time / best_time + self.logger.info(f" PyTorch baseline: {pytorch_baseline_time:.4f} ms") + self.logger.info(f" Speedup vs PyTorch: {pytorch_speedup:.2f}x") + + self.logger.info(f" Improvement: {improvement_percent:.1f}%") + self.logger.info("") + + # Save best kernel + best_kernel_file = self.output_dir / "best_kernel.py" + best_kernel_file.write_text(best_kernel) + + perf_metrics = { + "baseline_time_ms": baseline_results["time_ms"], + "best_time_ms": best_time, + "speedup": baseline_speedup, + "rounds": rounds, + } + + if bottleneck_category: + perf_metrics["bottleneck_addressed"] = bottleneck_category + + # Add NCU metrics if available + if ncu_metrics: + kernel_metrics = next(iter(ncu_metrics.values()), {}) + perf_metrics["memory_throughput"] = 
kernel_metrics.get( + "dram__throughput.avg.pct_of_peak_sustained_elapsed" + ) + perf_metrics["compute_throughput"] = kernel_metrics.get( + "sm__throughput.avg.pct_of_peak_sustained_elapsed" + ) + + if bottleneck_category: + perf_metrics["bottleneck_category"] = bottleneck_category + + if early_stop_reason: + perf_metrics["early_stop_reason"] = early_stop_reason + + success = best_time != float("inf") + return success, best_kernel, perf_metrics diff --git a/triton_kernel_agent/opt_worker_component/prescribing/__init__.py b/triton_kernel_agent/opt_worker_component/prescribing/__init__.py new file mode 100644 index 0000000..c515b2a --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/prescribing/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prescribing module for kernel optimization.""" + +from .bottleneck_analyzer import BottleneckAnalyzer + +__all__ = ["BottleneckAnalyzer"] diff --git a/triton_kernel_agent/opt_worker_component/prescribing/bottleneck_analyzer.py b/triton_kernel_agent/opt_worker_component/prescribing/bottleneck_analyzer.py new file mode 100644 index 0000000..b774f0a --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/prescribing/bottleneck_analyzer.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Bottleneck Analyzer - LLM-based NCU profiling analysis. 
+ +This module orchestrates LLM calls for bottleneck analysis using: +- judger_prompt.py: Prompt template, parsing, BottleneckResult dataclass +- ncu_roofline.py: Roofline analysis using NCU SOL metrics + +Bottleneck Categories: +- memory: Memory bandwidth is the limiting factor +- compute: Compute throughput is the limiting factor +- underutilized: Neither saturated (<60% both), indicating stalls/occupancy issues +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.judger_prompt import ( + BottleneckResult, + build_bottleneck_prompt, + parse_bottleneck_response, +) +from kernel_perf_agent.kernel_opt.roofline.ncu_roofline import RooflineAnalyzer +from triton_kernel_agent.worker_util import _call_llm, _save_debug_file +from utils.providers.base import BaseProvider + + +class BottleneckAnalyzer: + """LLM-based bottleneck analyzer using NCU metrics.""" + + def __init__( + self, + provider: BaseProvider, + model: str, + gpu_specs: dict[str, Any], + logs_dir: Path | None = None, + logger: logging.Logger | None = None, + num_bottlenecks: int = 1, + num_causes: int = 2, + num_fixes: int = 1, + enable_debug: bool = True, + ): + """ + Initialize bottleneck analyzer. + + Args: + provider: LLM provider instance + model: Model name for LLM calls + gpu_specs: GPU hardware specifications + logs_dir: Directory for saving debug files + logger: Logger instance + num_bottlenecks: Number of bottlenecks to request from LLM + num_causes: Number of root causes per bottleneck + num_fixes: Number of recommended fixes per bottleneck + enable_debug: Whether to save debug files (prompts/responses) + """ + self.provider = provider + self.model = model + self.gpu_specs = gpu_specs + self.logs_dir = logs_dir + self.logger = logger or logging.getLogger(__name__) + self.num_bottlenecks = num_bottlenecks + self.num_causes = num_causes + self.num_fixes = num_fixes + self.enable_debug = enable_debug + self.roofline = RooflineAnalyzer(logger=logger) + + def analyze( + self, + kernel_code: str, + ncu_metrics: dict[str, Any], + round_num: int = 0, + roofline_result: Any = None, + ) -> list[BottleneckResult]: + """ + Analyze kernel bottlenecks using LLM. + + Args: + kernel_code: The Triton kernel source code + ncu_metrics: NCU profiling metrics dictionary + round_num: Current optimization round (for logging) + roofline_result: Pre-computed RooflineResult (if None, computed internally) + + Returns: + List of BottleneckResult (ordered by importance). + Empty list if analysis fails. 
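+
+        Example (illustrative sketch; assumes a configured provider, GPU specs,
+        and a parsed NCU metrics dict are already available):
+            >>> analyzer = BottleneckAnalyzer(provider, model_name, gpu_specs)
+            >>> results = analyzer.analyze(kernel_code, ncu_metrics, round_num=1)
+            >>> if results:
+            ...     print(f"Primary bottleneck: {results[0].category}")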
+ """ + if roofline_result is None: + # Filter out PyTorch kernels (at::*) and get Triton kernel metrics + if ncu_metrics: + triton_kernels = { + name: metrics + for name, metrics in ncu_metrics.items() + if not name.startswith("at::") and not name.startswith("void at::") + } + flat_metrics = ( + next(iter(triton_kernels.values())) + if triton_kernels + else next(iter(ncu_metrics.values()), {}) + ) + else: + flat_metrics = {} + roofline_result = self.roofline.analyze(flat_metrics) + + prompt = build_bottleneck_prompt( + kernel_code=kernel_code, + ncu_metrics=ncu_metrics, + roofline=roofline_result, + gpu_specs=self.gpu_specs, + num_bottlenecks=self.num_bottlenecks, + num_causes=self.num_causes, + num_fixes=self.num_fixes, + ) + + response = _call_llm( + provider=self.provider, + model=self.model, + messages=[{"role": "user", "content": prompt}], + logger=self.logger, + max_tokens=16384, + ) + + if self.enable_debug and self.logs_dir: + _save_debug_file( + self.logs_dir / f"round{round_num:03d}_bottleneck_prompt.txt", + prompt, + self.logger, + ) + _save_debug_file( + self.logs_dir / f"round{round_num:03d}_bottleneck_response.txt", + response, + self.logger, + ) + + results = parse_bottleneck_response(response) + + if results: + categories = [r.category for r in results] + self.logger.info(f"[{round_num}] Bottlenecks: {', '.join(categories)}") + else: + self.logger.warning(f"[{round_num}] Failed to parse bottleneck response") + + return results diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py index 9534fc9..7252aea 100644 --- a/triton_kernel_agent/prompt_manager.py +++ b/triton_kernel_agent/prompt_manager.py @@ -88,6 +88,7 @@ def _load_templates(self): "test_generation": "test_generation.j2", "kernel_generation": "kernel_generation.j2", "kernel_refinement": "kernel_refinement.j2", + "kernel_optimization": "kernel_optimization.j2", "triton_guidelines": "triton_guidelines.j2", } @@ -194,6 +195,64 @@ def render_kernel_refinement_prompt( no_cusolver=no_cusolver, ) + def render_kernel_optimization_prompt( + self, + problem_description: str, + kernel_code: str, + gpu_specs: dict, + roofline: dict, + category: str, + summary: str, + reasoning: str, + root_cause: dict, + recommended_fix: dict, + pytorch_baseline_ms: float | None = None, + current_best_ms: float | None = None, + error_feedback: str | None = None, + ) -> str: + """ + Render the kernel optimization prompt. 
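+
+        The bottleneck-related arguments (category, summary, reasoning,
+        root_cause, recommended_fix) are packed into a single "bottleneck"
+        dict before rendering, so the template can reference them as one
+        object alongside the GPU specs and roofline data.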
+ + Args: + problem_description: Description of the problem + kernel_code: Current kernel implementation + gpu_specs: GPU hardware specifications dict + roofline: Roofline analysis result dict with keys: + bottleneck, compute_sol_pct, memory_sol_pct, efficiency_pct, + headroom_pct, at_roofline, uses_tensor_cores, warnings + category: Bottleneck category ("memory", "compute", "underutilized") + summary: One-line bottleneck summary + reasoning: Explanation citing metrics + root_cause: Single root cause dict {"cause": "...", "evidence": [...]} + recommended_fix: Single fix dict {"fix": "...", "rationale": "..."} + pytorch_baseline_ms: PyTorch Eager baseline time in ms + current_best_ms: Current best kernel time in ms (for iterative opt) + error_feedback: Error message from previous failed attempt + + Returns: + Rendered prompt string + """ + template = self.templates["kernel_optimization"] + + bottleneck = { + "category": category, + "summary": summary, + "reasoning": reasoning, + "root_cause": root_cause, + "recommended_fix": recommended_fix, + } + + return template.render( + problem_description=problem_description, + kernel_code=kernel_code, + gpu_specs=gpu_specs, + roofline=roofline, + bottleneck=bottleneck, + pytorch_baseline_ms=pytorch_baseline_ms, + current_best_ms=current_best_ms, + error_feedback=error_feedback, + ) + def render_triton_guidelines(self) -> str: """ Render the Triton guidelines. diff --git a/triton_kernel_agent/templates/kernel_optimization.j2 b/triton_kernel_agent/templates/kernel_optimization.j2 new file mode 100644 index 0000000..92fe699 --- /dev/null +++ b/triton_kernel_agent/templates/kernel_optimization.j2 @@ -0,0 +1,101 @@ +{# +Copyright (c) Meta Platforms, Inc. and affiliates. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +TASK: Optimize the following Triton kernel based on hardware profiling analysis to achieve better performance. 
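+
+{# The sections below are conditional: GPU specs, roofline analysis, baseline
+   timings, and error feedback are rendered only when the caller provides the
+   corresponding variables. #}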
+ +{% if gpu_specs %} +## TARGET GPU +- GPU: {{ gpu_specs.name }} +- Architecture: {{ gpu_specs.architecture }} +- Peak Memory Bandwidth: {{ gpu_specs.peak_memory_bw_gbps }} GB/s +- Peak FP32: {{ gpu_specs.peak_fp32_tflops }} TFLOPS +- Peak FP16: {{ gpu_specs.peak_fp16_tflops }} TFLOPS +- Peak BF16: {{ gpu_specs.peak_bf16_tflops }} TFLOPS +- SM Count: {{ gpu_specs.sm_count }} +- Max Threads per SM: {{ gpu_specs.max_threads_per_sm }} +- L1 Cache per SM: {{ gpu_specs.l1_cache_kb }} KB +- L2 Cache: {{ gpu_specs.l2_cache_mb }} MB +- Memory: {{ gpu_specs.memory_gb }} GB {{ gpu_specs.memory_type }} +{% endif %} + +## PROBLEM DESCRIPTION +{{ problem_description }} +{% if pytorch_baseline_ms %} +PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms +{% endif %} + +## CURRENT KERNEL +```python +{{ kernel_code }} +``` + +{% if roofline %} +## ROOFLINE ANALYSIS +- Primary Bottleneck: {{ roofline.bottleneck | upper }} +- Compute SOL: {{ "%.1f"|format(roofline.compute_sol_pct) }}% +- Memory SOL: {{ "%.1f"|format(roofline.memory_sol_pct) }}% +- Efficiency: {{ "%.1f"|format(roofline.efficiency_pct) }}% (headroom: {{ "%.1f"|format(roofline.headroom_pct) }}%) +- At Roofline: {{ "Yes" if roofline.at_roofline else "No" }} +- Tensor Cores: {{ "Active" if roofline.uses_tensor_cores else "Inactive" }} +{%- if roofline.warnings %} +- Warnings: {{ roofline.warnings | join("; ") }} +{%- endif %} +{% endif %} + +## BOTTLENECK ANALYSIS +### Category: {{ bottleneck.category | upper }} +{{ bottleneck.summary }} + +**Reasoning:** {{ bottleneck.reasoning }} + +**Root Cause:** {{ bottleneck.root_cause.cause }} +{%- if bottleneck.root_cause.evidence %} + Evidence: {% for e in bottleneck.root_cause.evidence %}{{ e.metric }}={{ e.value }}{% if not loop.last %}, {% endif %}{% endfor %} +{%- endif %} + +**Recommended Fix:** {{ bottleneck.recommended_fix.fix }} +{%- if bottleneck.recommended_fix.rationale %} ({{ bottleneck.recommended_fix.rationale }}){% endif %} + +{% if error_feedback %} +## PREVIOUS ATTEMPT FAILED +{{ error_feedback }} +{% endif %} + +## PERFORMANCE TARGET +{% if pytorch_baseline_ms %} +- PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms +{% endif %} +{% if current_best_ms %} +- Current best kernel: {{ "%.4f"|format(current_best_ms) }} ms +- Target: Improve by at least 10% (< {{ "%.4f"|format(current_best_ms * 0.9) }} ms) +{% else %} +- Target: Improve by at least 10% over Eager (< {{ "%.4f"|format(pytorch_baseline_ms * 0.9) }} ms) +{% endif %} +- Maintain numerical correctness (atol=1e-4 or rtol=1e-4) +- Preserve public API (same inputs/outputs, shapes, dtypes) + +## REQUIREMENTS +1. Apply the recommended fixes above to address the {{ bottleneck.category | upper }} bottleneck +2. The implementation must be a complete, valid Python file +3. Main function must be named 'kernel_function' wrapping the Triton kernel +4. Keep the wrapper free of PyTorch compute primitives + +## OUTPUT FORMAT +Output complete optimized kernel code in ```python blocks. +Include only: imports, Triton kernel (@triton.jit), wrapper function (kernel_function). +No testing code, benchmarks, or explanatory comments. + +Generate the complete optimized kernel implementation:
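+
+{# Illustrative rendering sketch (non-authoritative; the variable names below
+   are assumptions, not part of this diff):
+
+   prompt = prompt_manager.render_kernel_optimization_prompt(
+       problem_description=problem_description,
+       kernel_code=current_kernel,
+       gpu_specs=gpu_specs,
+       roofline=roofline_dict,
+       category=bottleneck.category,
+       summary=summary_text,
+       reasoning=reasoning_text,
+       root_cause={"cause": "...", "evidence": []},
+       recommended_fix={"fix": "...", "rationale": "..."},
+       current_best_ms=best_time_ms,
+   )
+#}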