diff --git a/triton_kernel_agent/opt_worker_component/searching/__init__.py b/triton_kernel_agent/opt_worker_component/searching/__init__.py new file mode 100644 index 0000000..179bf78 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Searching infrastructure for Optimization Kernel.""" diff --git a/triton_kernel_agent/opt_worker_component/searching/history/__init__.py b/triton_kernel_agent/opt_worker_component/searching/history/__init__.py new file mode 100644 index 0000000..2b4eb24 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""History module for tracking optimization attempts. + +Provides persistent storage for kernel optimization attempts, enabling: +- Resume: Continue runs after interruption +- History: Track what was tried and outcomes +- Learning: Use past attempts to guide exploration +""" + +from .records import AttemptRecord, Outcome +from .store import AttemptStore, JsonAttemptStore + +__all__ = ["AttemptRecord", "Outcome", "AttemptStore", "JsonAttemptStore"] diff --git a/triton_kernel_agent/opt_worker_component/searching/history/records.py b/triton_kernel_agent/opt_worker_component/searching/history/records.py new file mode 100644 index 0000000..a74344c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/records.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data records for tracking optimization attempts.""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + + +class Outcome(Enum): + """Result of an optimization attempt.""" + + IMPROVED = "improved" + REGRESSED = "regressed" + FAILED = "failed" + + +@dataclass +class AttemptRecord: + """A single optimization attempt.""" + + id: str + kernel_code: str + time_ms: float + outcome: Outcome + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + parent_id: str | None = None + + def __repr__(self) -> str: + return f"AttemptRecord(id={self.id}, time_ms={self.time_ms:.4f}, outcome={self.outcome.value})" + + def to_dict(self) -> dict: + """Serialize to dictionary for JSON storage.""" + return { + "id": self.id, + "kernel_code": self.kernel_code, + "time_ms": self.time_ms, + "outcome": self.outcome.value, + "created_at": self.created_at.isoformat(), + "parent_id": self.parent_id, + } + + @staticmethod + def from_dict(data: dict) -> "AttemptRecord": + """Deserialize from dictionary.""" + return AttemptRecord( + id=data["id"], + kernel_code=data["kernel_code"], + time_ms=data["time_ms"], + outcome=Outcome(data["outcome"]), + created_at=datetime.fromisoformat(data["created_at"]), + parent_id=data.get("parent_id"), + ) diff --git a/triton_kernel_agent/opt_worker_component/searching/history/store.py b/triton_kernel_agent/opt_worker_component/searching/history/store.py new file mode 100644 index 0000000..237ced7 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/store.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Storage interface and implementations for optimization attempts. + +The attempt store provides persistent storage for kernel optimization +attempts discovered during the search process. This enables: + +- Resume: Continue optimization runs after interruption +- History: Track what was tried and what worked/failed +- Learning: Use past attempts to guide future exploration +- Analysis: Understand optimization trajectories post-hoc + +Thread/process safety: +- Only the main optimization loop should write to the store +- Workers return results via queue; manager calls add() +""" + +import json +from pathlib import Path +from typing import Protocol + +from .records import AttemptRecord, Outcome + + +class AttemptStore(Protocol): + """Interface for storing and querying optimization attempts. + + Implementations must provide: + - add(): Store a new attempt + - get_recent(): Get recent attempts for history context + - get_top_k(): Get best performers for parent selection + - get_best(): Get single best attempt + - count(): Count total attempts + """ + + def add(self, attempt: AttemptRecord) -> None: + """Store an attempt.""" + ... + + def get_recent(self, n: int) -> list[AttemptRecord]: + """Get the n most recent attempts (oldest first).""" + ... + + def get_top_k(self, k: int) -> list[AttemptRecord]: + """Get the k best attempts by time_ms (fastest first).""" + ... + + def get_best(self) -> AttemptRecord | None: + """Get the attempt with the lowest time_ms.""" + ... + + def count(self) -> int: + """Count total attempts in the store.""" + ... + + +class JsonAttemptStore: + """JSON file-based implementation of AttemptStore.""" + + def __init__(self, path: Path | str) -> None: + self.path = Path(path) + self._attempts: list[AttemptRecord] = [] + self._load() + + def _load(self) -> None: + """Load attempts from JSON file if it exists. + + Falls back to empty store if the file is corrupted (e.g., partial write). + """ + if self.path.exists(): + try: + with open(self.path) as f: + data = json.load(f) + self._attempts = [AttemptRecord.from_dict(d) for d in data] + except (json.JSONDecodeError, KeyError) as e: + import warnings + + warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}") + self._attempts = [] + + def _save(self) -> None: + """Save attempts to JSON file.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + with open(self.path, "w") as f: + json.dump([a.to_dict() for a in self._attempts], f, indent=2) + + def add(self, attempt: AttemptRecord) -> None: + """Store an attempt and persist to disk.""" + self._attempts.append(attempt) + self._save() + + def get_recent(self, n: int) -> list[AttemptRecord]: + """Get the n most recent attempts (oldest first).""" + return self._attempts[-n:] + + def get_top_k(self, k: int) -> list[AttemptRecord]: + """Get the k best attempts by time_ms (fastest first). + + Ties are broken by created_at (oldest first) for deterministic ordering. + """ + valid = [a for a in self._attempts if a.outcome != Outcome.FAILED] + sorted_by_time = sorted(valid, key=lambda a: (a.time_ms, a.created_at)) + return sorted_by_time[:k] + + def get_best(self) -> AttemptRecord | None: + """Get the attempt with the lowest time_ms (excluding failed).""" + valid = [a for a in self._attempts if a.outcome != Outcome.FAILED] + if not valid: + return None + return min(valid, key=lambda a: a.time_ms) + + def count(self) -> int: + """Count total attempts in the store.""" + return len(self._attempts) diff --git a/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py b/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py new file mode 100644 index 0000000..5c78328 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mutation module for building kernel optimization prompts.""" + +from .mutator import Mutator, SimpleMutator + +__all__ = ["Mutator", "SimpleMutator"] diff --git a/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py new file mode 100644 index 0000000..e77990f --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mutation strategies for generating kernel optimization prompts. + +Mutators build prompts for the LLM to optimize kernels, including: +- The parent kernel to improve +- History of what was tried before +- Any additional context (bottleneck analysis, inspirations, etc.) +""" + +from typing import Protocol + +from ..history import AttemptRecord, AttemptStore + + +class Mutator(Protocol): + """Interface for building optimization prompts.""" + + def build_prompt(self, parent: AttemptRecord) -> str: + """Build a prompt for the LLM to optimize the kernel. + + Args: + parent: The kernel to optimize. + """ + ... + + +class SimpleMutator: + """Minimal mutator: basic prompt with kernel and history.""" + + def __init__(self, store: AttemptStore) -> None: + self.store = store + + def build_prompt(self, parent: AttemptRecord) -> str: + lines = [ + "# Optimize this Triton kernel\n", + f"Current performance: {parent.time_ms:.4f}ms\n", + ] + + history = self.store.get_recent(3) + if history: + lines.append("\n## Recent attempts:\n") + for a in history: + lines.append(f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n") + + lines.append("\n## Kernel:\n```python\n") + lines.append(parent.kernel_code) + lines.append("\n```\n") + + return "".join(lines) diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py new file mode 100644 index 0000000..bcfa475 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling module for selecting parents from optimization history.""" + +from .sampler import BestSampler, Sampler + +__all__ = ["Sampler", "BestSampler"] diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py new file mode 100644 index 0000000..792317f --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling strategies for selecting parents and inspirations from history. + +Samplers control how we select: +- Parents: Which kernel to optimize next +- Inspirations: Which kernels to show as few-shot examples to the LLM +""" + +from typing import Protocol + +from ..history import AttemptRecord, AttemptStore + + +class Sampler(Protocol): + """Interface for sampling from optimization history.""" + + def sample_parent(self) -> AttemptRecord | None: + """Select a parent for the next optimization attempt.""" + ... + + def get_top_inspirations( + self, + n: int, + ) -> list[AttemptRecord]: + """Get top-performing attempts for few-shot prompting.""" + ... + + +class BestSampler: + """Sampler that always returns the best parent and top-k inspirations.""" + + def __init__(self, store: AttemptStore) -> None: + self.store = store + + def sample_parent(self) -> AttemptRecord | None: + return self.store.get_best() + + def get_top_inspirations( + self, + n: int, + ) -> list[AttemptRecord]: + return self.store.get_top_k(n) diff --git a/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py b/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py new file mode 100644 index 0000000..dbe9402 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategy module for controlling the optimization loop.""" + +from .strategy import SimpleStrategy, Strategy + +__all__ = ["Strategy", "SimpleStrategy"] diff --git a/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py new file mode 100644 index 0000000..90181e1 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategies for controlling the optimization search loop. + +Strategies decide: +- Which parent to optimize next +- When to stop (convergence, plateau, max rounds) +- How to select the next generation of candidates +""" + +from typing import Protocol + +from ..history import AttemptRecord, AttemptStore +from ..sampling import Sampler + + +class Strategy(Protocol): + """Interface for optimization loop control.""" + + def next_parent(self) -> AttemptRecord | None: + """Select the next parent to optimize from.""" + ... + + def record_result(self, attempt: AttemptRecord) -> None: + """Record an optimization attempt result.""" + ... + + def get_best(self) -> AttemptRecord | None: + """Get the best attempt so far.""" + ... + + def should_stop(self, round_num: int, max_rounds: int) -> bool: + """Check if optimization should terminate early.""" + ... + + +class SimpleStrategy: + """Example Implementation: always pick best, stop at max rounds.""" + + def __init__(self, store: AttemptStore, sampler: Sampler) -> None: + self.store = store + self.sampler = sampler + + def next_parent(self) -> AttemptRecord | None: + return self.sampler.sample_parent() + + def record_result(self, attempt: AttemptRecord) -> None: + self.store.add(attempt) + + def get_best(self) -> AttemptRecord | None: + return self.store.get_best() + + def should_stop(self, round_num: int, max_rounds: int) -> bool: + return round_num >= max_rounds