Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions triton_kernel_agent/opt_worker_component/searching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Searching infrastructure for Optimization Kernel."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""History module for tracking optimization attempts.

Provides persistent storage for kernel optimization attempts, enabling:
- Resume: Continue runs after interruption
- History: Track what was tried and outcomes
- Learning: Use past attempts to guide exploration
"""

from .records import AttemptRecord, Outcome
from .store import AttemptStore, JsonAttemptStore

__all__ = ["AttemptRecord", "Outcome", "AttemptStore", "JsonAttemptStore"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data records for tracking optimization attempts."""

from __future__ import annotations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need `from __future__ import annotations` here?


from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum


class Outcome(Enum):
    """Classification of how an optimization attempt turned out.

    Each member's value is the lowercase string written out when a record
    is serialized, and the string accepted when one is read back.
    """

    # The attempt counted as an improvement.
    IMPROVED = "improved"
    # The attempt counted as a regression.
    REGRESSED = "regressed"
    # The attempt did not produce a usable result.
    FAILED = "failed"


@dataclass
class AttemptRecord:
    """A single optimization attempt.

    Attributes:
        id: Identifier of this attempt.
        kernel_code: Kernel source text associated with the attempt.
        time_ms: Timing recorded for the attempt, in milliseconds.
        outcome: How the attempt turned out.
        created_at: Creation timestamp; defaults to the current time in UTC.
        parent_id: Identifier of another attempt, or None (presumably the
            attempt this one was derived from -- confirm with callers).
    """

    id: str
    kernel_code: str
    time_ms: float
    outcome: Outcome
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    parent_id: str | None = None

    def __repr__(self) -> str:
        # Deliberately compact: the kernel source is left out of the repr.
        return f"AttemptRecord(id={self.id}, time_ms={self.time_ms:.4f}, outcome={self.outcome.value})"

    def to_dict(self) -> dict:
        """Serialize to a dictionary of JSON-compatible values.

        The outcome is stored as its string value and the timestamp as an
        ISO 8601 string, so the result can be passed straight to ``json``.
        """
        return {
            "id": self.id,
            "kernel_code": self.kernel_code,
            "time_ms": self.time_ms,
            "outcome": self.outcome.value,
            "created_at": self.created_at.isoformat(),
            "parent_id": self.parent_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> AttemptRecord:
        """Build a record from a dictionary produced by ``to_dict()``.

        This is a classmethod so that calling it on a subclass returns an
        instance of that subclass rather than the base class.

        Args:
            data: Mapping with the keys written by ``to_dict()``.
                ``parent_id`` is optional and defaults to None.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``outcome`` is not a known outcome value or
                ``created_at`` is not a valid ISO 8601 string.
        """
        return cls(
            id=data["id"],
            kernel_code=data["kernel_code"],
            time_ms=data["time_ms"],
            outcome=Outcome(data["outcome"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            parent_id=data.get("parent_id"),
        )
144 changes: 144 additions & 0 deletions triton_kernel_agent/opt_worker_component/searching/history/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Storage interface and implementations for optimization attempts.

The attempt store provides persistent storage for kernel optimization
attempts discovered during the search process. This enables:

- Resume: Continue optimization runs after interruption
- History: Track what was tried and what worked/failed
- Learning: Use past attempts to guide future exploration
- Analysis: Understand optimization trajectories post-hoc

Thread/process safety:
- Only the main optimization loop should write to the store
- Workers return results via queue; manager calls add()
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Protocol

from .records import AttemptRecord, Outcome


class AttemptStore(Protocol):
    """Structural interface for a store of optimization attempts.

    A conforming class supplies these six methods:
    - add(): record a new attempt
    - get(): look an attempt up by its ID
    - get_recent(): newest attempts, used as history context
    - get_top_k(): fastest attempts, used for parent selection
    - get_best(): the single fastest attempt
    - count(): how many attempts are stored
    """

    def add(self, attempt: AttemptRecord) -> None:
        """Record ``attempt`` in the store."""
        ...

    def get(self, attempt_id: str) -> AttemptRecord | None:
        """Return the attempt whose ID is ``attempt_id``, or None."""
        ...

    def get_recent(self, n: int) -> list[AttemptRecord]:
        """Return the ``n`` latest attempts, newest first."""
        ...

    def get_top_k(self, k: int) -> list[AttemptRecord]:
        """Return the ``k`` attempts with the smallest time_ms, fastest first."""
        ...

    def get_best(self) -> AttemptRecord | None:
        """Return the attempt with the smallest time_ms, or None."""
        ...

    def count(self) -> int:
        """Return the number of attempts held by the store."""
        ...


class JsonAttemptStore:
    """JSON file-based implementation of AttemptStore.

    Attempts are kept in memory in insertion order, and the complete list
    is rewritten to ``path`` on every ``add()``.
    """

    def __init__(self, path: Path | str) -> None:
        """Open the store at ``path``, loading any attempts saved there.

        Args:
            path: Location of the JSON file. It does not have to exist yet;
                it (and its parent directories) are created on first save.
        """
        self.path = Path(path)
        # Attempts in insertion order (oldest first).
        self._attempts: list[AttemptRecord] = []
        # Maps attempt id -> position in _attempts, so get() is a dict
        # lookup instead of a scan over the whole list.
        self._id_index: dict[str, int] = {}
        self._load()

    @staticmethod
    def _rank_key(attempt: AttemptRecord) -> tuple:
        """Sort key: fastest first, ties broken by oldest created_at."""
        return (attempt.time_ms, attempt.created_at)

    def _load(self) -> None:
        """Load attempts from the JSON file if it exists.

        Falls back to an empty store, with a warning, when the file cannot
        be decoded into records (e.g. after a partial write).
        """
        if not self.path.exists():
            return
        try:
            with open(self.path, encoding="utf-8") as f:
                data = json.load(f)
            attempts = [AttemptRecord.from_dict(d) for d in data]
            id_index = {a.id: i for i, a in enumerate(attempts)}
        # ValueError covers json.JSONDecodeError as well as an unknown
        # outcome value or a malformed timestamp; TypeError covers JSON of
        # the wrong shape (e.g. a bare number instead of a list of objects).
        except (ValueError, KeyError, TypeError) as e:
            import warnings

            warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}")
            attempts, id_index = [], {}
        # Assign only once everything parsed, so a failure half-way through
        # can never leave the list and the index out of step.
        self._attempts = attempts
        self._id_index = id_index

    def _save(self) -> None:
        """Write all attempts to the JSON file.

        The data is serialized before anything touches the disk, written to
        a sibling temporary file, and then moved over the real file. A
        failure therefore leaves the previous file intact instead of a
        truncated one.
        """
        payload = json.dumps([a.to_dict() for a in self._attempts], indent=2)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(self.path)

    def add(self, attempt: AttemptRecord) -> None:
        """Store an attempt and persist the whole store to disk."""
        self._attempts.append(attempt)
        self._id_index[attempt.id] = len(self._attempts) - 1
        self._save()

    def get(self, attempt_id: str) -> AttemptRecord | None:
        """Get an attempt by ID, or None if the ID is unknown."""
        idx = self._id_index.get(attempt_id)
        if idx is None:
            return None
        return self._attempts[idx]

    def get_recent(self, n: int) -> list[AttemptRecord]:
        """Get the n most recent attempts (newest first).

        Returns an empty list when ``n`` is zero or negative.
        """
        # Without this guard n == 0 would slice as [-0:], i.e. the whole list.
        if n <= 0:
            return []
        return self._attempts[-n:][::-1]

    def get_top_k(self, k: int) -> list[AttemptRecord]:
        """Get the k best attempts by time_ms (fastest first).

        Failed attempts are excluded. Ties are broken by created_at (oldest
        first) for deterministic ordering. Returns an empty list when ``k``
        is zero or negative.
        """
        if k <= 0:
            return []
        valid = [a for a in self._attempts if a.outcome != Outcome.FAILED]
        return sorted(valid, key=self._rank_key)[:k]

    def get_best(self) -> AttemptRecord | None:
        """Get the attempt with the lowest time_ms (excluding failed).

        Uses the same tie-break as ``get_top_k`` so that the result always
        equals ``get_top_k(1)[0]``. Returns None when nothing qualifies.
        """
        valid = [a for a in self._attempts if a.outcome != Outcome.FAILED]
        if not valid:
            return None
        return min(valid, key=self._rank_key)

    def count(self) -> int:
        """Count total attempts in the store (failed ones included)."""
        return len(self._attempts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mutation module for building kernel optimization prompts."""

from .mutator import Mutator, SimpleMutator

__all__ = ["Mutator", "SimpleMutator"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mutation strategies for generating kernel optimization prompts.

Mutators build prompts for the LLM to optimize kernels, including:
- The parent kernel to improve
- History of what was tried before
- Any additional context (bottleneck analysis, inspirations, etc.)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
from ..history import AttemptRecord
Comment on lines +27 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just import `AttemptRecord` directly, without the `TYPE_CHECKING` guard?



class Mutator(Protocol):
    """Structural interface for objects that build optimization prompts."""

    def build_prompt(
        self,
        parent: AttemptRecord,
        history: list[AttemptRecord] | None = None,
    ) -> str:
        """Produce the LLM prompt asking for an optimized kernel.

        Args:
            parent: Attempt holding the kernel that should be optimized.
            history: Earlier attempts in oldest-first order, or None.
        """
        ...


class SimpleMutator:
    """Minimal mutator: basic prompt with kernel and history.

    ``history`` is read as oldest-first, the convention documented on
    ``Mutator.build_prompt``, so the trailing entries are the most recent
    ones. ``JsonAttemptStore.get_recent()`` returns newest-first; reverse
    its result before passing it here.
    """

    def __init__(self, max_history: int = 3) -> None:
        """Configure the mutator.

        Args:
            max_history: Maximum number of trailing history entries to list
                in the prompt. Zero omits the history section entirely.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        if max_history < 0:
            raise ValueError(f"max_history must be >= 0, got {max_history}")
        self.max_history = max_history

    def build_prompt(
        self,
        parent: AttemptRecord,
        history: list[AttemptRecord] | None = None,
    ) -> str:
        """Build the optimization prompt for ``parent``.

        Args:
            parent: Attempt whose kernel code and timing go into the prompt.
            history: Previous attempts, ordered oldest-first. Only the last
                ``max_history`` entries are listed.

        Returns:
            The prompt text: a heading, the parent's timing, an optional
            "Recent attempts" section, and the kernel in a fenced block.
        """
        lines = [
            "# Optimize this Triton kernel\n",
            f"Current performance: {parent.time_ms:.4f}ms\n",
        ]

        # max_history == 0 must be handled explicitly: history[-0:] would
        # select the entire list rather than nothing.
        recent = history[-self.max_history :] if history and self.max_history > 0 else []
        if recent:
            lines.append("\n## Recent attempts:\n")
            lines.extend(f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n" for a in recent)

        lines.append("\n## Kernel:\n```python\n")
        lines.append(parent.kernel_code)
        lines.append("\n```\n")

        return "".join(lines)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling module for selecting parents from optimization history."""

from .sampler import Sampler, SimpleSampler

__all__ = ["Sampler", "SimpleSampler"]
Loading