From c7ea222f554405d3fa0ae36533e239bdb6f81ece Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 29 Jan 2026 17:45:01 -0800 Subject: [PATCH 1/3] Introduce Searching Component --- .../searching/__init__.py | 15 ++ .../searching/history/__init__.py | 26 ++++ .../searching/history/records.py | 67 ++++++++ .../searching/history/store.py | 143 ++++++++++++++++++ .../searching/mutation/__init__.py | 19 +++ .../searching/mutation/mutator.py | 70 +++++++++ .../searching/sampling/__init__.py | 19 +++ .../searching/sampling/sampler.py | 58 +++++++ .../searching/strategy/__init__.py | 19 +++ .../searching/strategy/strategy.py | 69 +++++++++ 10 files changed, 505 insertions(+) create mode 100644 triton_kernel_agent/opt_worker_component/searching/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/history/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/history/records.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/history/store.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py diff --git a/triton_kernel_agent/opt_worker_component/searching/__init__.py b/triton_kernel_agent/opt_worker_component/searching/__init__.py new file mode 100644 index 0000000..179bf78 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Searching infrastructure for Optimization Kernel.""" diff --git a/triton_kernel_agent/opt_worker_component/searching/history/__init__.py b/triton_kernel_agent/opt_worker_component/searching/history/__init__.py new file mode 100644 index 0000000..2b4eb24 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""History module for tracking optimization attempts. + +Provides persistent storage for kernel optimization attempts, enabling: +- Resume: Continue runs after interruption +- History: Track what was tried and outcomes +- Learning: Use past attempts to guide exploration +""" + +from .records import AttemptRecord, Outcome +from .store import AttemptStore, JsonAttemptStore + +__all__ = ["AttemptRecord", "Outcome", "AttemptStore", "JsonAttemptStore"] diff --git a/triton_kernel_agent/opt_worker_component/searching/history/records.py b/triton_kernel_agent/opt_worker_component/searching/history/records.py new file mode 100644 index 0000000..8da6565 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/records.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data records for tracking optimization attempts.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + + +class Outcome(Enum): + """Result of an optimization attempt.""" + + IMPROVED = "improved" + REGRESSED = "regressed" + FAILED = "failed" + + +@dataclass +class AttemptRecord: + """A single optimization attempt.""" + + id: str + kernel_code: str + time_ms: float + outcome: Outcome + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + parent_id: str | None = None + + def __repr__(self) -> str: + return f"AttemptRecord(id={self.id}, time_ms={self.time_ms:.4f}, outcome={self.outcome.value})" + + def to_dict(self) -> dict: + """Serialize to dictionary for JSON storage.""" + return { + "id": self.id, + "kernel_code": self.kernel_code, + "time_ms": self.time_ms, + "outcome": self.outcome.value, + "created_at": self.created_at.isoformat(), + "parent_id": self.parent_id, + } + + @staticmethod + def from_dict(data: dict) -> AttemptRecord: + """Deserialize from dictionary.""" + return AttemptRecord( + id=data["id"], + kernel_code=data["kernel_code"], + time_ms=data["time_ms"], + outcome=Outcome(data["outcome"]), + created_at=datetime.fromisoformat(data["created_at"]), + parent_id=data.get("parent_id"), + ) diff --git a/triton_kernel_agent/opt_worker_component/searching/history/store.py b/triton_kernel_agent/opt_worker_component/searching/history/store.py new file mode 100644 index 0000000..8dd17ca --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/history/store.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Storage interface and implementations for optimization attempts. + +The attempt store provides persistent storage for kernel optimization +attempts discovered during the search process. This enables: + +- Resume: Continue optimization runs after interruption +- History: Track what was tried and what worked/failed +- Learning: Use past attempts to guide future exploration +- Analysis: Understand optimization trajectories post-hoc + +Thread/process safety: +- Only the main optimization loop should write to the store +- Workers return results via queue; manager calls add() +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Protocol + +from .records import AttemptRecord, Outcome + + +class AttemptStore(Protocol): + """Interface for storing and querying optimization attempts. + + Implementations must provide: + - add(): Store a new attempt + - get(): Retrieve by ID + - get_recent(): Get recent attempts for history context + - get_top_k(): Get best performers for parent selection + - get_best(): Get single best attempt + - count(): Count total attempts + """ + + def add(self, attempt: AttemptRecord) -> None: + """Store an attempt.""" + ... + + def get(self, attempt_id: str) -> AttemptRecord | None: + """Get an attempt by ID.""" + ... + + def get_recent(self, n: int) -> list[AttemptRecord]: + """Get the n most recent attempts (newest first).""" + ... + + def get_top_k(self, k: int) -> list[AttemptRecord]: + """Get the k best attempts by time_ms (fastest first).""" + ... + + def get_best(self) -> AttemptRecord | None: + """Get the attempt with the lowest time_ms.""" + ... + + def count(self) -> int: + """Count total attempts in the store.""" + ... + + +class JsonAttemptStore: + """JSON file-based implementation of AttemptStore.""" + + def __init__(self, path: Path | str) -> None: + self.path = Path(path) + self._attempts: list[AttemptRecord] = [] + self._id_index: dict[str, int] = {} + self._load() + + def _load(self) -> None: + """Load attempts from JSON file if it exists. + + Falls back to empty store if the file is corrupted (e.g., partial write). + """ + if self.path.exists(): + try: + with open(self.path) as f: + data = json.load(f) + self._attempts = [AttemptRecord.from_dict(d) for d in data] + self._id_index = {a.id: i for i, a in enumerate(self._attempts)} + except (json.JSONDecodeError, KeyError) as e: + import warnings + warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}") + self._attempts = [] + self._id_index = {} + + def _save(self) -> None: + """Save attempts to JSON file.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + with open(self.path, "w") as f: + json.dump([a.to_dict() for a in self._attempts], f, indent=2) + + def add(self, attempt: AttemptRecord) -> None: + """Store an attempt and persist to disk.""" + self._attempts.append(attempt) + self._id_index[attempt.id] = len(self._attempts) - 1 + self._save() + + def get(self, attempt_id: str) -> AttemptRecord | None: + """Get an attempt by ID.""" + idx = self._id_index.get(attempt_id) + if idx is not None: + return self._attempts[idx] + return None + + def get_recent(self, n: int) -> list[AttemptRecord]: + """Get the n most recent attempts (newest first).""" + return list(reversed(self._attempts[-n:])) + + def get_top_k(self, k: int) -> list[AttemptRecord]: + """Get the k best attempts by time_ms (fastest first). + + Ties are broken by created_at (oldest first) for deterministic ordering. + """ + valid = [a for a in self._attempts if a.outcome != Outcome.FAILED] + sorted_by_time = sorted(valid, key=lambda a: (a.time_ms, a.created_at)) + return sorted_by_time[:k] + + def get_best(self) -> AttemptRecord | None: + """Get the attempt with the lowest time_ms (excluding failed).""" + valid = [a for a in self._attempts if a.outcome != Outcome.FAILED] + if not valid: + return None + return min(valid, key=lambda a: a.time_ms) + + def count(self) -> int: + """Count total attempts in the store.""" + return len(self._attempts) diff --git a/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py b/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py new file mode 100644 index 0000000..5c78328 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/mutation/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mutation module for building kernel optimization prompts.""" + +from .mutator import Mutator, SimpleMutator + +__all__ = ["Mutator", "SimpleMutator"] diff --git a/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py new file mode 100644 index 0000000..00e89d6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mutation strategies for generating kernel optimization prompts. + +Mutators build prompts for the LLM to optimize kernels, including: +- The parent kernel to improve +- History of what was tried before +- Any additional context (bottleneck analysis, inspirations, etc.) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from ..history import AttemptRecord + + +class Mutator(Protocol): + """Interface for building optimization prompts.""" + + def build_prompt( + self, + parent: AttemptRecord, + history: list[AttemptRecord] | None = None, + ) -> str: + """Build a prompt for the LLM to optimize the kernel. + + Args: + parent: The kernel to optimize. + history: Previous attempts, ordered oldest-first. + """ + ... + + +class SimpleMutator: + """Minimal mutator: basic prompt with kernel and history.""" + + def build_prompt( + self, + parent: AttemptRecord, + history: list[AttemptRecord] | None = None, + ) -> str: + lines = [ + "# Optimize this Triton kernel\n", + f"Current performance: {parent.time_ms:.4f}ms\n", + ] + + if history: + lines.append("\n## Recent attempts:\n") + for a in history[-3:]: + lines.append(f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n") + + lines.append("\n## Kernel:\n```python\n") + lines.append(parent.kernel_code) + lines.append("\n```\n") + + return "".join(lines) diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py new file mode 100644 index 0000000..db292fc --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling module for selecting parents from optimization history.""" + +from .sampler import Sampler, SimpleSampler + +__all__ = ["Sampler", "SimpleSampler"] diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py new file mode 100644 index 0000000..2913f1b --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sampling strategies for selecting parents and inspirations from history. + +Samplers control how we select: +- Parents: Which kernel to optimize next +- Inspirations: Which kernels to show as few-shot examples to the LLM +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from ..history import AttemptRecord, AttemptStore + + +class Sampler(Protocol): + """Interface for sampling from optimization history.""" + + def sample_parent(self) -> AttemptRecord | None: + """Select a parent for the next optimization attempt.""" + ... + + def get_top_inspirations( + self, + n: int, + ) -> list[AttemptRecord]: + """Get top-performing attempts for few-shot prompting.""" + ... + + +class SimpleSampler: + """Example implementation: returns best parent, top-k inspirations.""" + + def __init__(self, store: AttemptStore) -> None: + self.store = store + + def sample_parent(self) -> AttemptRecord | None: + return self.store.get_best() + + def get_top_inspirations( + self, + n: int, + ) -> list[AttemptRecord]: + return self.store.get_top_k(n) diff --git a/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py b/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py new file mode 100644 index 0000000..dbe9402 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/strategy/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategy module for controlling the optimization loop.""" + +from .strategy import SimpleStrategy, Strategy + +__all__ = ["Strategy", "SimpleStrategy"] diff --git a/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py new file mode 100644 index 0000000..8b99ed1 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Strategies for controlling the optimization search loop. + +Strategies decide: +- Which parent to optimize next +- When to stop (convergence, plateau, max rounds) +- How to select the next generation of candidates +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from ..history import AttemptRecord, AttemptStore + from ..sampling import Sampler + + +class Strategy(Protocol): + """Interface for optimization loop control.""" + + def next_parent(self) -> AttemptRecord | None: + """Select the next parent to optimize from.""" + ... + + def record_result(self, attempt: AttemptRecord) -> None: + """Record an optimization attempt result.""" + ... + + def get_best(self) -> AttemptRecord | None: + """Get the best attempt so far.""" + ... + + def should_stop(self, round_num: int, max_rounds: int) -> bool: + """Check if optimization should terminate early.""" + ... + + +class SimpleStrategy: + """Example Implementation: always pick best, stop at max rounds.""" + + def __init__(self, store: AttemptStore, sampler: Sampler) -> None: + self.store = store + self.sampler = sampler + + def next_parent(self) -> AttemptRecord | None: + return self.sampler.sample_parent() + + def record_result(self, attempt: AttemptRecord) -> None: + self.store.add(attempt) + + def get_best(self) -> AttemptRecord | None: + return self.store.get_best() + + def should_stop(self, round_num: int, max_rounds: int) -> bool: + return round_num >= max_rounds From e61c48612b1c0e7d5115eca7beefbcf74e921c5e Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Sun, 1 Feb 2026 23:55:15 -0800 Subject: [PATCH 2/3] fix ruff --- .../opt_worker_component/searching/history/store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/triton_kernel_agent/opt_worker_component/searching/history/store.py b/triton_kernel_agent/opt_worker_component/searching/history/store.py index 8dd17ca..a9d4153 100644 --- a/triton_kernel_agent/opt_worker_component/searching/history/store.py +++ b/triton_kernel_agent/opt_worker_component/searching/history/store.py @@ -95,6 +95,7 @@ def _load(self) -> None: self._id_index = {a.id: i for i, a in enumerate(self._attempts)} except (json.JSONDecodeError, KeyError) as e: import warnings + warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}") self._attempts = [] self._id_index = {} From 52f3ec0fe8464ffd68d2f9965b736742aebb9112 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 5 Feb 2026 15:33:47 -0800 Subject: [PATCH 3/3] fix issue --- .../searching/history/records.py | 4 +-- .../searching/history/store.py | 24 +++-------------- .../searching/mutation/mutator.py | 26 +++++++------------ .../searching/sampling/__init__.py | 4 +-- .../searching/sampling/sampler.py | 11 +++----- .../searching/strategy/strategy.py | 9 +++---- 6 files changed, 22 insertions(+), 56 deletions(-) diff --git a/triton_kernel_agent/opt_worker_component/searching/history/records.py b/triton_kernel_agent/opt_worker_component/searching/history/records.py index 8da6565..a74344c 100644 --- a/triton_kernel_agent/opt_worker_component/searching/history/records.py +++ b/triton_kernel_agent/opt_worker_component/searching/history/records.py @@ -14,8 +14,6 @@ """Data records for tracking optimization attempts.""" -from __future__ import annotations - from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum @@ -55,7 +53,7 @@ def to_dict(self) -> dict: } @staticmethod - def from_dict(data: dict) -> AttemptRecord: + def from_dict(data: dict) -> "AttemptRecord": """Deserialize from dictionary.""" return AttemptRecord( id=data["id"], diff --git a/triton_kernel_agent/opt_worker_component/searching/history/store.py b/triton_kernel_agent/opt_worker_component/searching/history/store.py index a9d4153..237ced7 100644 --- a/triton_kernel_agent/opt_worker_component/searching/history/store.py +++ b/triton_kernel_agent/opt_worker_component/searching/history/store.py @@ -27,8 +27,6 @@ - Workers return results via queue; manager calls add() """ -from __future__ import annotations - import json from pathlib import Path from typing import Protocol @@ -41,7 +39,6 @@ class AttemptStore(Protocol): Implementations must provide: - add(): Store a new attempt - - get(): Retrieve by ID - get_recent(): Get recent attempts for history context - get_top_k(): Get best performers for parent selection - get_best(): Get single best attempt @@ -52,12 +49,8 @@ def add(self, attempt: AttemptRecord) -> None: """Store an attempt.""" ... - def get(self, attempt_id: str) -> AttemptRecord | None: - """Get an attempt by ID.""" - ... - def get_recent(self, n: int) -> list[AttemptRecord]: - """Get the n most recent attempts (newest first).""" + """Get the n most recent attempts (oldest first).""" ... def get_top_k(self, k: int) -> list[AttemptRecord]: @@ -79,7 +72,6 @@ class JsonAttemptStore: def __init__(self, path: Path | str) -> None: self.path = Path(path) self._attempts: list[AttemptRecord] = [] - self._id_index: dict[str, int] = {} self._load() def _load(self) -> None: @@ -92,13 +84,11 @@ def _load(self) -> None: with open(self.path) as f: data = json.load(f) self._attempts = [AttemptRecord.from_dict(d) for d in data] - self._id_index = {a.id: i for i, a in enumerate(self._attempts)} except (json.JSONDecodeError, KeyError) as e: import warnings warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}") self._attempts = [] - self._id_index = {} def _save(self) -> None: """Save attempts to JSON file.""" @@ -109,19 +99,11 @@ def _save(self) -> None: def add(self, attempt: AttemptRecord) -> None: """Store an attempt and persist to disk.""" self._attempts.append(attempt) - self._id_index[attempt.id] = len(self._attempts) - 1 self._save() - def get(self, attempt_id: str) -> AttemptRecord | None: - """Get an attempt by ID.""" - idx = self._id_index.get(attempt_id) - if idx is not None: - return self._attempts[idx] - return None - def get_recent(self, n: int) -> list[AttemptRecord]: - """Get the n most recent attempts (newest first).""" - return list(reversed(self._attempts[-n:])) + """Get the n most recent attempts (oldest first).""" + return self._attempts[-n:] def get_top_k(self, k: int) -> list[AttemptRecord]: """Get the k best attempts by time_ms (fastest first). diff --git a/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py index 00e89d6..e77990f 100644 --- a/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py +++ b/triton_kernel_agent/opt_worker_component/searching/mutation/mutator.py @@ -20,27 +20,19 @@ - Any additional context (bottleneck analysis, inspirations, etc.) """ -from __future__ import annotations +from typing import Protocol -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from ..history import AttemptRecord +from ..history import AttemptRecord, AttemptStore class Mutator(Protocol): """Interface for building optimization prompts.""" - def build_prompt( - self, - parent: AttemptRecord, - history: list[AttemptRecord] | None = None, - ) -> str: + def build_prompt(self, parent: AttemptRecord) -> str: """Build a prompt for the LLM to optimize the kernel. Args: parent: The kernel to optimize. - history: Previous attempts, ordered oldest-first. """ ... @@ -48,19 +40,19 @@ def build_prompt( class SimpleMutator: """Minimal mutator: basic prompt with kernel and history.""" - def build_prompt( - self, - parent: AttemptRecord, - history: list[AttemptRecord] | None = None, - ) -> str: + def __init__(self, store: AttemptStore) -> None: + self.store = store + + def build_prompt(self, parent: AttemptRecord) -> str: lines = [ "# Optimize this Triton kernel\n", f"Current performance: {parent.time_ms:.4f}ms\n", ] + history = self.store.get_recent(3) if history: lines.append("\n## Recent attempts:\n") - for a in history[-3:]: + for a in history: lines.append(f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n") lines.append("\n## Kernel:\n```python\n") diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py index db292fc..bcfa475 100644 --- a/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/__init__.py @@ -14,6 +14,6 @@ """Sampling module for selecting parents from optimization history.""" -from .sampler import Sampler, SimpleSampler +from .sampler import BestSampler, Sampler -__all__ = ["Sampler", "SimpleSampler"] +__all__ = ["Sampler", "BestSampler"] diff --git a/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py index 2913f1b..792317f 100644 --- a/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py +++ b/triton_kernel_agent/opt_worker_component/searching/sampling/sampler.py @@ -19,12 +19,9 @@ - Inspirations: Which kernels to show as few-shot examples to the LLM """ -from __future__ import annotations +from typing import Protocol -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from ..history import AttemptRecord, AttemptStore +from ..history import AttemptRecord, AttemptStore class Sampler(Protocol): @@ -42,8 +39,8 @@ def get_top_inspirations( ... -class SimpleSampler: - """Example implementation: returns best parent, top-k inspirations.""" +class BestSampler: + """Sampler that always returns the best parent and top-k inspirations.""" def __init__(self, store: AttemptStore) -> None: self.store = store diff --git a/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py index 8b99ed1..90181e1 100644 --- a/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py +++ b/triton_kernel_agent/opt_worker_component/searching/strategy/strategy.py @@ -20,13 +20,10 @@ - How to select the next generation of candidates """ -from __future__ import annotations +from typing import Protocol -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from ..history import AttemptRecord, AttemptStore - from ..sampling import Sampler +from ..history import AttemptRecord, AttemptStore +from ..sampling import Sampler class Strategy(Protocol):