Dev/database store #276
Open: mydmdm wants to merge 23 commits into main from dev/database-store
Commits:
47b0a4e initial version with all dbstore interface implemented, but not all t… (mydmdm)
b8940fe interface implemented, but still have some issues for the periodical … (mydmdm)
96a0d58 configurable retry added (mydmdm)
a0b6833 support periodic background tasks for attempt timeout checking (mydmdm)
cbd6498 update error messages (mydmdm)
a477bec only corner cases for status propagation (mydmdm)
f7fe24a all tests passed with some FIXME (mydmdm)
1857e39 reuse test_memory.py (mydmdm)
5950276 fix typo (mydmdm)
70838ee Merge remote-tracking branch 'origin/main' into dev/database-store (mydmdm)
93814a9 Enhance rollout and attempt handling with type unification and new re… (mydmdm)
294a1cc Merge remote-tracking branch 'origin/main' into dev/database-store (mydmdm)
6131c1f rename to SqlLightningStore (mydmdm)
b6312db fix pre-commit warnings (mydmdm)
c21e065 Update agentlightning/store/database/orm/rollout.py (mydmdm)
5983eec Update agentlightning/store/database/retry_helper.py (mydmdm)
3162ed6 fix lint issue (mydmdm)
c1b7827 fix lint error (mydmdm)
fe2d052 To break the big transaction inside timeout healthy checking (mydmdm)
5fc8ee0 fix lint errors (mydmdm)
de8df23 Merge branch 'dev/database-store' of https://github.com/microsoft/age… (mydmdm)
2eb207e fix lint warning and update uv.lock (mydmdm)
dc2c5e5 Merge branch 'main' into dev/database-store (mydmdm)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| from .sqlite import SqlLightningStore | ||
|
|
||
| __all__ = [ | ||
| "SqlLightningStore", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| from .attempt import AttemptInDB, SpanSeqIdInDB | ||
| from .base import ( | ||
| AttemptStatusUpdateMessage, | ||
| SqlAlchemyBase, | ||
| ) | ||
| from .resources import ResourcesUpdateInDB | ||
| from .rollout import RolloutInDB | ||
| from .span import SpanInDB | ||
|
|
||
| __all__ = [ | ||
| "SqlAlchemyBase", | ||
| "AttemptStatusUpdateMessage", | ||
| "RolloutInDB", | ||
| "AttemptInDB", | ||
| "ResourcesUpdateInDB", | ||
| "SpanSeqIdInDB", | ||
| "SpanInDB", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,251 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
| import logging | ||
| import time | ||
| import uuid | ||
| from dataclasses import InitVar | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| from sqlalchemy import JSON, Float, Integer, String, select | ||
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker | ||
| from sqlalchemy.orm import Mapped, mapped_column | ||
|
|
||
| from agentlightning.types import Attempt | ||
|
|
||
| from .base import AttemptStatusUpdateMessage, SqlAlchemyBase | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _generate_attempt_id() -> str: | ||
| """We don't need that long because attempts are limited to rollouts.""" | ||
| short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:8] | ||
| return "at-" + short_id | ||
|
|
||
|
|
||
| class AttemptInDB(SqlAlchemyBase): | ||
| __tablename__ = "attempts" | ||
|
|
||
| rollout_id: Mapped[str] = mapped_column(String, nullable=False) | ||
| attempt_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_attempt_id) | ||
| sequence_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0) | ||
| start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False) | ||
| end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None) | ||
| status: Mapped[str] = mapped_column(String, default="preparing", nullable=False) | ||
| worker_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None) | ||
| last_heartbeat_time: Mapped[Optional[float]] = mapped_column(Float, nullable=False, default_factory=time.time) | ||
| attempt_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None) | ||
|
|
||
| # additional columns for processing | ||
| max_duration: Mapped[Optional[float]] = mapped_column( | ||
| Float, nullable=True, default=None | ||
| ) # maximum duration allowed for this attempt in seconds | ||
| max_heartbeat_interval: Mapped[Optional[float]] = mapped_column( | ||
| Float, nullable=True, default=None | ||
| ) # maximum allowed heartbeat interval in seconds | ||
|
|
||
| version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1) | ||
| __mapper_args__ = { | ||
| "version_id_col": version_id, | ||
| } | ||
|
|
||
| def is_unresponsive(self, current_time: float) -> bool: | ||
| """Check if the attempt is unresponsive based on the last heartbeat time and max_heartbeat_interval.""" | ||
| if self.max_heartbeat_interval is None: | ||
| return False | ||
| if self.last_heartbeat_time is None: | ||
| return False | ||
| return (current_time - self.last_heartbeat_time) > self.max_heartbeat_interval | ||
|
|
||
| def is_timed_out(self, current_time: float) -> bool: | ||
| """Check if the attempt has timed out based on the start time and max_duration.""" | ||
| if self.max_duration is None: | ||
| return False | ||
| return (current_time - self.start_time) > self.max_duration | ||
|
|
||
| def as_attempt(self) -> Attempt: | ||
| return Attempt( | ||
| **self.model_dump( | ||
| exclude={"max_duration", "max_heartbeat_interval", "version_id"}, | ||
| mapper={"metadata": lambda obj: obj.attempt_metadata}, # type: ignore | ||
| ) | ||
| ) | ||
|
|
||
| def _validate_status_message(self, msg: Dict[str, Any]) -> None: | ||
| """This function validates the status update message from caller. | ||
| Raises ValueError if the message is invalid. | ||
| """ | ||
| if "event" not in msg: | ||
| raise ValueError("Status update message must contain 'event' field.") | ||
| if "timestamp" not in msg: | ||
| msg["timestamp"] = time.time() | ||
| if msg["event"] not in [ | ||
| "user_update", # user update attempt status via dbstore.update_attempt() | ||
| "span_received", # new span received | ||
| "single_step_timeout", # single step timeout detected (from last span heartbeat) | ||
| "overall_timeout", # overall timeout detected | ||
| ]: | ||
| raise ValueError(f"Unsupported event type: {msg['event']}") | ||
| if msg["event"] == "user_update" and "new_status" not in msg: | ||
| raise ValueError("User update event must contain 'new_status' field.") | ||
|
|
||
| def get_finished_statuses(self) -> List[str]: | ||
| """This function returns the list of statuses that are considered finished.""" | ||
| return [ | ||
| "succeeded", | ||
| "failed", | ||
| "timeout", | ||
| ] | ||
|
|
||
| def update_status(self, msg: Dict[str, Any]) -> Optional[AttemptStatusUpdateMessage]: | ||
|
||
| """This function updates the status of the attempt based on the event. | ||
| Args: | ||
| msg: A dictionary containing the status update message. It must contain an "event" field, and optionally a "new_status" field. | ||
| More details about the message format can be found in the `_validate_status_message()` method. | ||
| The optional "timestamp" field is used for updating timestamps. If absent, time.time() is used. | ||
| Returns: | ||
| An AttemptStatusUpdateMessage carrying attempt_id, rollout_id, timestamp, old_status, and new_status. | ||
| If no meaningful status update is performed, returns None. | ||
| Raises: | ||
| ValueError: If the event is not recognized or the status transition is invalid. | ||
| NotImplementedError: If the event handling is not implemented for the current status. | ||
| RuntimeError: If the new status is not set after processing the event. | ||
| """ | ||
| self._validate_status_message(msg) | ||
| event = msg["event"] | ||
| current_time = msg.get("timestamp", time.time()) | ||
| old_status = self.status | ||
| new_status = msg.get("new_status", None) | ||
|
|
||
| # Step 1: Determine the new status based on the event and current status | ||
| if event == "user_update": | ||
| if not new_status: | ||
| raise ValueError("new_status must be provided for user_update event.") | ||
| elif event == "span_received": | ||
| self.last_heartbeat_time = current_time | ||
| if old_status in ["preparing", "unresponsive", "running"]: | ||
| new_status = "running" | ||
| elif old_status in self.get_finished_statuses(): | ||
| logger.warning( | ||
| f"Span received after attempt is already in status {self.status}. No status update performed." | ||
| ) | ||
| return None # no further status update needed | ||
| else: | ||
| raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.") | ||
| elif event == "single_step_timeout": | ||
| if old_status in [ | ||
| "preparing", | ||
| "running", | ||
| ]: | ||
| new_status = "unresponsive" | ||
| else: | ||
| logger.warning( | ||
| f"Single step timeout detected but attempt is in status {self.status}. No status update performed." | ||
| ) | ||
| return None # no further status update needed | ||
| elif event == "overall_timeout": | ||
| if old_status not in self.get_finished_statuses(): | ||
| new_status = "timeout" | ||
| else: | ||
| logger.warning( | ||
| f"Overall timeout detected but attempt is in status {self.status}. No status update performed." | ||
| ) | ||
| return None # no further status update needed | ||
| else: | ||
| raise NotImplementedError(f"Event {event} is not implemented for status update.") | ||
|
|
||
| # Step 2: Update the status | ||
| if not new_status: | ||
| raise RuntimeError( | ||
| f"new_status should not be {new_status} after processing event for {event} on status {old_status}." | ||
| ) | ||
| if new_status == old_status: | ||
| return None # no status change | ||
| if new_status in self.get_finished_statuses(): | ||
| # when attempt is finished, set end_time | ||
| self.end_time = current_time | ||
| self.status = new_status | ||
|
|
||
| # Step 3: Return the status update info for further processing | ||
| return AttemptStatusUpdateMessage( | ||
| attempt_id=self.attempt_id, | ||
| rollout_id=self.rollout_id, | ||
| timestamp=current_time, | ||
| old_status=old_status, | ||
| new_status=new_status, | ||
| ) | ||
|
|
||
| @classmethod | ||
| async def get_latest_attempt_for_rollout( | ||
| cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str | ||
| ) -> Optional[Attempt]: | ||
| async with session_factory() as session: | ||
| async with session.begin(): | ||
| result = await session.scalars( | ||
| select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.desc()).limit(1) | ||
| ) | ||
| attempt_obj = result.one_or_none() | ||
| if attempt_obj is None: | ||
| return None | ||
| return attempt_obj.as_attempt() | ||
|
|
||
| @classmethod | ||
| async def get_attempts_for_rollout( | ||
| cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str | ||
| ) -> List[Attempt]: | ||
| async with session_factory() as session: | ||
| async with session.begin(): | ||
| result = await session.scalars( | ||
| select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.asc()) | ||
| ) | ||
| return [attempt.as_attempt() for attempt in result.all()] | ||
|
|
||
|
|
||
| class SpanSeqIdInDB(SqlAlchemyBase): | ||
| __tablename__ = "span_sequence" | ||
|
|
||
| rollout_id: Mapped[str] = mapped_column(nullable=False, primary_key=True) | ||
|
|
||
| # FIXME InMemoryLightningStore lets all attempts under the same rollout share the same span sequence for sorting | ||
| # attempt_id: Mapped[str] = mapped_column(nullable=False) | ||
| attempt_id: InitVar[str] # not mapped column, just for type hinting | ||
|
|
||
| current_sequence: Mapped[int] = mapped_column(default=1, nullable=False) | ||
|
|
||
| # Versioning for optimistic concurrency control | ||
| version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1) | ||
| __mapper_args__ = { | ||
| "version_id_col": version_id, | ||
| # "primary_key": [rollout_id, attempt_id], | ||
| # "primary_key": [rollout_id], | ||
| } | ||
|
|
||
| @classmethod | ||
| async def get_next_sequence_id( | ||
| cls: type[SpanSeqIdInDB], | ||
| session_factory: async_sessionmaker[AsyncSession], | ||
| rollout_id: str, | ||
| attempt_id: str, | ||
| external_seq_id: Optional[int] = None, | ||
| ) -> int: | ||
| """Get the next sequence ID with retries to handle race conditions. | ||
| If external_seq_id is provided and is greater than current_sequence, set current_sequence to external_seq_id. | ||
| """ | ||
| async with session_factory() as session: | ||
| async with session.begin(): | ||
| seq_obj = await session.get(cls, rollout_id) | ||
| # seq_obj = await session.get(cls, [rollout_id, attempt_id]) | ||
| if seq_obj is None: | ||
| raise ValueError(f"Rollout {rollout_id} not found") | ||
| else: | ||
| current_seq = ( | ||
| external_seq_id | ||
| if external_seq_id is not None and external_seq_id > seq_obj.current_sequence | ||
| else seq_obj.current_sequence | ||
| ) | ||
| seq_obj.current_sequence = current_seq + 1 | ||
| await session.flush() | ||
| return current_seq | ||
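The counter arithmetic in get_next_sequence_id can be read in isolation from the database session. The following is a minimal sketch of that branch only, assuming a hypothetical pure function next_sequence_id that is not part of the PR; it leaves out the session, the version_id check, and any retry handling.

```python
from typing import Optional, Tuple


def next_sequence_id(current_sequence: int, external_seq_id: Optional[int] = None) -> Tuple[int, int]:
    """Return (assigned_id, new_current_sequence).

    Hypothetical helper for illustration: an external ID is honored only when
    it is strictly greater than the stored counter, as in the diff above.
    """
    if external_seq_id is not None and external_seq_id > current_sequence:
        assigned = external_seq_id
    else:
        assigned = current_sequence
    # The stored counter always ends up one past the ID that was handed out.
    return assigned, assigned + 1


counter = 1  # matches the default of current_sequence in SpanSeqIdInDB
first, counter = next_sequence_id(counter)       # no external ID supplied
second, counter = next_sequence_id(counter, 10)  # external ID jumps ahead
third, counter = next_sequence_id(counter, 3)    # stale external ID is ignored
```

Under this reading, a caller that supplies an external ID lower than or equal to the stored counter still receives the stored value, so assigned IDs never decrease within a rollout.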
Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.
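One way to address this in update_status is to make every "no change" path return None explicitly. The sketch below restates the transition rules from the diff as a standalone function with explicit returns. The name next_status and the FINISHED tuple are hypothetical and used only for illustration; the sketch omits the timestamp handling and the AttemptStatusUpdateMessage construction, and it raises ValueError for unknown events, which the PR handles in a separate validation step.

```python
from typing import Optional

# Terminal statuses, as listed in get_finished_statuses() in the diff.
FINISHED = ("succeeded", "failed", "timeout")


def next_status(old_status: str, event: str, new_status: Optional[str] = None) -> Optional[str]:
    """Return the status to move to, or None when no status change applies."""
    if event == "user_update":
        if not new_status:
            raise ValueError("new_status must be provided for user_update event.")
        candidate = new_status
    elif event == "span_received":
        if old_status in FINISHED:
            return None  # late span on a finished attempt: ignore
        if old_status not in ("preparing", "unresponsive", "running"):
            raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.")
        candidate = "running"
    elif event == "single_step_timeout":
        if old_status not in ("preparing", "running"):
            return None  # only live attempts can become unresponsive
        candidate = "unresponsive"
    elif event == "overall_timeout":
        if old_status in FINISHED:
            return None  # finished attempts cannot time out
        candidate = "timeout"
    else:
        raise ValueError(f"Unsupported event type: {event}")
    # An unchanged status is also reported as "no update".
    return None if candidate == old_status else candidate
```

With every branch ending in either a raise or an explicit return, the Optional return type in the signature matches what a reader sees in the body.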