-
Notifications
You must be signed in to change notification settings - Fork 29
Introduce Searching Component (Protocol with Sample Implementation) #87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Searching infrastructure for Optimization Kernel.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """History module for tracking optimization attempts. | ||
|
|
||
| Provides persistent storage for kernel optimization attempts, enabling: | ||
| - Resume: Continue runs after interruption | ||
| - History: Track what was tried and outcomes | ||
| - Learning: Use past attempts to guide exploration | ||
| """ | ||
|
|
||
| from .records import AttemptRecord, Outcome | ||
| from .store import AttemptStore, JsonAttemptStore | ||
|
|
||
| __all__ = ["AttemptRecord", "Outcome", "AttemptStore", "JsonAttemptStore"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Data records for tracking optimization attempts.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from datetime import datetime, timezone | ||
| from enum import Enum | ||
|
|
||
|
|
||
class Outcome(Enum):
    """Result of an optimization attempt.

    Each member's value is the lowercase string that
    ``AttemptRecord.to_dict`` writes and ``AttemptRecord.from_dict`` parses
    back, so changing a value changes the stored JSON format.
    """

    # NOTE(review): presumably "faster than the parent/baseline" - confirm.
    IMPROVED = "improved"
    # NOTE(review): presumably "ran correctly but slower" - confirm.
    REGRESSED = "regressed"
    # JsonAttemptStore.get_top_k() and get_best() skip attempts with this outcome.
    FAILED = "failed"
|
|
||
|
|
||
@dataclass
class AttemptRecord:
    """A single optimization attempt.

    Instances round-trip through ``to_dict`` / ``from_dict`` so they can be
    persisted as JSON.
    """

    # Identifier of this attempt; JsonAttemptStore uses it as its lookup key.
    id: str
    # Kernel source text; SimpleMutator embeds it verbatim in the prompt.
    kernel_code: str
    # Timing in milliseconds; the store ranks attempts by this, lowest first.
    time_ms: float
    outcome: Outcome
    # Defaults to the current time in UTC (timezone-aware) at construction.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # NOTE(review): presumably the id of the attempt this one was derived
    # from, None for a root attempt - confirm against callers.
    parent_id: str | None = None

    def __repr__(self) -> str:
        """Return a compact summary that omits the (potentially long) kernel code."""
        return f"AttemptRecord(id={self.id}, time_ms={self.time_ms:.4f}, outcome={self.outcome.value})"

    def to_dict(self) -> dict:
        """Serialize to a dictionary of JSON-compatible values.

        ``outcome`` is stored as its string value and ``created_at`` as an
        ISO 8601 string.
        """
        return {
            "id": self.id,
            "kernel_code": self.kernel_code,
            "time_ms": self.time_ms,
            "outcome": self.outcome.value,
            "created_at": self.created_at.isoformat(),
            "parent_id": self.parent_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> AttemptRecord:
        """Build a record from a dictionary produced by ``to_dict``.

        A classmethod so that subclasses deserialize to their own type.

        Raises:
            KeyError: If a required key is missing (``parent_id`` is optional).
            ValueError: If ``outcome`` is not a known value or ``created_at``
                is not a valid ISO 8601 string.
        """
        return cls(
            id=data["id"],
            kernel_code=data["kernel_code"],
            time_ms=data["time_ms"],
            outcome=Outcome(data["outcome"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            parent_id=data.get("parent_id"),
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Storage interface and implementations for optimization attempts. | ||
|
|
||
| The attempt store provides persistent storage for kernel optimization | ||
| attempts discovered during the search process. This enables: | ||
|
|
||
| - Resume: Continue optimization runs after interruption | ||
| - History: Track what was tried and what worked/failed | ||
| - Learning: Use past attempts to guide future exploration | ||
| - Analysis: Understand optimization trajectories post-hoc | ||
|
|
||
| Thread/process safety: | ||
| - Only the main optimization loop should write to the store | ||
| - Workers return results via queue; manager calls add() | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| from pathlib import Path | ||
| from typing import Protocol | ||
|
|
||
| from .records import AttemptRecord, Outcome | ||
|
|
||
|
|
||
class AttemptStore(Protocol):
    """Interface for storing and querying optimization attempts.

    This is a structural ``typing.Protocol``: a class satisfies it by
    defining these methods with compatible signatures, without inheriting
    from it (``JsonAttemptStore`` does not subclass it).

    Implementations must provide:
    - add(): Store a new attempt
    - get(): Retrieve by ID
    - get_recent(): Get recent attempts for history context
    - get_top_k(): Get best performers for parent selection
    - get_best(): Get single best attempt
    - count(): Count total attempts
    """

    def add(self, attempt: AttemptRecord) -> None:
        """Store an attempt."""
        ...

    def get(self, attempt_id: str) -> AttemptRecord | None:
        """Get an attempt by ID, or None if there is no such attempt."""
        ...

    def get_recent(self, n: int) -> list[AttemptRecord]:
        """Get the n most recent attempts (newest first)."""
        ...

    def get_top_k(self, k: int) -> list[AttemptRecord]:
        """Get the k best attempts by time_ms (fastest first).

        ``JsonAttemptStore`` leaves out attempts whose outcome is FAILED.
        """
        ...

    def get_best(self) -> AttemptRecord | None:
        """Get the attempt with the lowest time_ms, or None if there is none.

        ``JsonAttemptStore`` leaves out attempts whose outcome is FAILED.
        """
        ...

    def count(self) -> int:
        """Count total attempts in the store."""
        ...
|
|
||
|
|
||
| class JsonAttemptStore: | ||
| """JSON file-based implementation of AttemptStore.""" | ||
|
|
||
| def __init__(self, path: Path | str) -> None: | ||
| self.path = Path(path) | ||
| self._attempts: list[AttemptRecord] = [] | ||
| self._id_index: dict[str, int] = {} | ||
| self._load() | ||
|
|
||
| def _load(self) -> None: | ||
| """Load attempts from JSON file if it exists. | ||
|
|
||
| Falls back to empty store if the file is corrupted (e.g., partial write). | ||
| """ | ||
| if self.path.exists(): | ||
| try: | ||
| with open(self.path) as f: | ||
| data = json.load(f) | ||
| self._attempts = [AttemptRecord.from_dict(d) for d in data] | ||
| self._id_index = {a.id: i for i, a in enumerate(self._attempts)} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not seeing why we need |
||
| except (json.JSONDecodeError, KeyError) as e: | ||
| import warnings | ||
|
|
||
| warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}") | ||
| self._attempts = [] | ||
| self._id_index = {} | ||
|
|
||
| def _save(self) -> None: | ||
| """Save attempts to JSON file.""" | ||
| self.path.parent.mkdir(parents=True, exist_ok=True) | ||
| with open(self.path, "w") as f: | ||
| json.dump([a.to_dict() for a in self._attempts], f, indent=2) | ||
|
|
||
def add(self, attempt: AttemptRecord) -> None:
    """Append ``attempt``, index it by its ID, and rewrite the store file."""
    position = len(self._attempts)  # slot the new record is about to occupy
    self._attempts.append(attempt)
    self._id_index[attempt.id] = position
    self._save()
|
|
||
def get(self, attempt_id: str) -> AttemptRecord | None:
    """Look up an attempt by ID; return None when the ID is unknown."""
    position = self._id_index.get(attempt_id)
    # Compare against None explicitly: position 0 is a valid hit.
    return None if position is None else self._attempts[position]
|
|
||
def get_recent(self, n: int) -> list[AttemptRecord]:
    """Get the n most recent attempts (newest first).

    Args:
        n: Maximum number of attempts to return. Values of zero or below
            yield an empty list; values above the store size yield
            everything.

    Returns:
        A new list, ordered from the last-added attempt backwards.
    """
    # Guard needed: a slice of [-0:] is the whole list, and a negative n
    # would turn into a slice from the front.
    if n <= 0:
        return []
    return self._attempts[-n:][::-1]
|
|
||
def get_top_k(self, k: int) -> list[AttemptRecord]:
    """Get the k best attempts by time_ms (fastest first).

    Attempts whose outcome is FAILED are left out. Ties are broken by
    created_at (oldest first) for deterministic ordering.

    Args:
        k: Maximum number of attempts to return. Values of zero or below
            yield an empty list.
    """
    # Without this guard a negative k would slice off the slowest |k|
    # entries and return the rest instead of returning nothing.
    if k <= 0:
        return []
    candidates = [a for a in self._attempts if a.outcome != Outcome.FAILED]
    candidates.sort(key=lambda a: (a.time_ms, a.created_at))
    return candidates[:k]
|
|
||
def get_best(self) -> AttemptRecord | None:
    """Get the attempt with the lowest time_ms (excluding failed).

    Ties are broken by created_at (oldest first), the same rule
    ``get_top_k`` uses, so the result equals ``get_top_k(1)[0]`` whenever
    a non-failed attempt exists.

    Returns:
        The best attempt, or None when there is no non-failed attempt.
    """
    return min(
        (a for a in self._attempts if a.outcome != Outcome.FAILED),
        key=lambda a: (a.time_ms, a.created_at),
        default=None,
    )
|
|
||
def count(self) -> int:
    """Count total attempts in the store.

    Every stored attempt is counted, whatever its outcome; no filtering
    is applied here.
    """
    return len(self._attempts)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Mutation module for building kernel optimization prompts.""" | ||
|
|
||
| from .mutator import Mutator, SimpleMutator | ||
|
|
||
| __all__ = ["Mutator", "SimpleMutator"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Mutation strategies for generating kernel optimization prompts. | ||
|
|
||
| Mutators build prompts for the LLM to optimize kernels, including: | ||
| - The parent kernel to improve | ||
| - History of what was tried before | ||
| - Any additional context (bottleneck analysis, inspirations, etc.) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Protocol | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ..history import AttemptRecord | ||
|
Comment on lines
+27
to
+28
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just import without checking? |
||
|
|
||
|
|
||
class Mutator(Protocol):
    """Interface for building optimization prompts.

    Structural protocol: any object with a compatible ``build_prompt``
    method qualifies (``SimpleMutator`` does not inherit from this class).
    """

    def build_prompt(
        self,
        parent: AttemptRecord,
        history: list[AttemptRecord] | None = None,
    ) -> str:
        """Build a prompt for the LLM to optimize the kernel.

        Args:
            parent: The kernel to optimize.
            history: Previous attempts, ordered oldest-first. An attempt
                store's ``get_recent`` returns newest-first, so reverse
                its result before passing it here. ``SimpleMutator``
                omits the history section for None or an empty list.

        Returns:
            The prompt text.
        """
        ...
|
|
||
|
|
||
| class SimpleMutator: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If callers pass Suggestion: either define a clear convention (history always oldest→newest), or have SimpleMutator treat history as newest-first and use history[:3], or sort by created_at inside the mutator.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! Update it to ordered as oldest-first
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it make sense to just have the mutator take in the store? We can be opinionated on how the mutator chooses history to start. Customization can be added later (e.g. n-history). |
||
| """Minimal mutator: basic prompt with kernel and history.""" | ||
|
|
||
def build_prompt(
    self,
    parent: AttemptRecord,
    history: list[AttemptRecord] | None = None,
    *,
    max_history: int = 3,
) -> str:
    """Build a prompt containing the parent kernel and a short history.

    Args:
        parent: The attempt to optimize; its ``time_ms`` and
            ``kernel_code`` are written into the prompt.
        history: Previous attempts, ordered oldest-first. Only the last
            ``max_history`` entries are listed. None or an empty list
            omits the history section.
        max_history: How many trailing history entries to show. Defaults
            to 3; zero or below omits the history section.

    Returns:
        The prompt text: a heading, the parent's timing, the optional
        "Recent attempts" list, and the kernel in a ``python`` code fence.
    """
    sections = [
        "# Optimize this Triton kernel\n",
        f"Current performance: {parent.time_ms:.4f}ms\n",
    ]

    # history[-0:] would be the entire list, so non-positive limits are
    # handled explicitly rather than left to the slice.
    recent = history[-max_history:] if history and max_history > 0 else []
    if recent:
        sections.append("\n## Recent attempts:\n")
        sections.extend(
            f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n" for a in recent
        )

    sections.append("\n## Kernel:\n```python\n")
    sections.append(parent.kernel_code)
    sections.append("\n```\n")

    return "".join(sections)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Sampling module for selecting parents from optimization history.""" | ||
|
|
||
| from .sampler import Sampler, SimpleSampler | ||
|
|
||
| __all__ = ["Sampler", "SimpleSampler"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need the `from __future__ import annotations` import?