Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
47b0a4e
initial version with all dbstore interface implemented, but not all t…
mydmdm Oct 31, 2025
b8940fe
interface implemented, but still have some issues for the periodical …
mydmdm Nov 3, 2025
96a0d58
configurable retry added
mydmdm Nov 3, 2025
a0b6833
support periodic background tasks for attempt timeout checking
mydmdm Nov 4, 2025
cbd6498
update error messages
mydmdm Nov 4, 2025
a477bec
only corner cases for status propagation
mydmdm Nov 4, 2025
f7fe24a
all tests passed with some FIXME
mydmdm Nov 4, 2025
1857e39
reuse test_memory.py
mydmdm Nov 5, 2025
5950276
fix typo
mydmdm Nov 5, 2025
70838ee
Merge remote-tracking branch 'origin/main' into dev/database-store
mydmdm Nov 5, 2025
93814a9
Enhance rollout and attempt handling with type unification and new re…
mydmdm Nov 6, 2025
294a1cc
Merge remote-tracking branch 'origin/main' into dev/database-store
mydmdm Nov 6, 2025
6131c1f
rename to SqlLightningStore
mydmdm Nov 6, 2025
b6312db
fix pre-commit warnings
mydmdm Nov 6, 2025
c21e065
Update agentlightning/store/database/orm/rollout.py
mydmdm Nov 6, 2025
5983eec
Update agentlightning/store/database/retry_helper.py
mydmdm Nov 6, 2025
3162ed6
fix lint issue
mydmdm Nov 6, 2025
c1b7827
fix lint error
mydmdm Nov 6, 2025
fe2d052
To break the big transaction inside timeout healthy checking
mydmdm Nov 7, 2025
5fc8ee0
fix lint errors
mydmdm Nov 7, 2025
de8df23
Merge branch 'dev/database-store' of https://github.com/microsoft/age…
mydmdm Nov 7, 2025
2eb207e
fix lint warning and update uv.lock
mydmdm Nov 11, 2025
dc2c5e5
Merge branch 'main' into dev/database-store
mydmdm Nov 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,9 @@ cython_debug/

# Claude
.claude/*.local.json

# Temporary and backup files
*.tmp
*.bak
*.backup

2 changes: 2 additions & 0 deletions agentlightning/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from .client_server import LightningStoreClient, LightningStoreServer
from .memory import InMemoryLightningStore
from .threading import LightningStoreThreaded
from .database import DatabaseLightningStore

__all__ = [
"LightningStore",
"LightningStoreClient",
"LightningStoreServer",
"InMemoryLightningStore",
"LightningStoreThreaded",
"DatabaseLightningStore",
]
5 changes: 5 additions & 0 deletions agentlightning/store/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .dbstore import DatabaseLightningStore

__all__ = [
"DatabaseLightningStore",
]
615 changes: 615 additions & 0 deletions agentlightning/store/database/dbstore.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions agentlightning/store/database/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft. All rights reserved.

from .base import (
SqlAlchemyBase,
AttemptStatusUpdateMessage,
)

from .rollout import RolloutInDB
from .attempt import AttemptInDB, SpanSeqIdInDB
from .resources import ResourcesUpdateInDB
from .span import SpanInDB

__all__ = [
"SqlAlchemyBase",
"AttemptStatusUpdateMessage",
"RolloutInDB",
"AttemptInDB",
"ResourcesUpdateInDB",
"SpanSeqIdInDB",
"SpanInDB",
]
215 changes: 215 additions & 0 deletions agentlightning/store/database/orm/attempt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
from typing import Any, Dict, List, Optional
import time
import uuid
import hashlib
import logging
from dataclasses import InitVar
from sqlalchemy import String, Integer, Float, JSON
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from agentlightning.types import Attempt
from .base import SqlAlchemyBase, AttemptStatusUpdateMessage

logger = logging.getLogger(__name__)


def _generate_attempt_id() -> str:
"""We don't need that long because attempts are limited to rollouts."""
short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:8]
return "at-" + short_id


class AttemptInDB(SqlAlchemyBase):
__tablename__ = "attempts"

rollout_id: Mapped[str] = mapped_column(String, nullable=False)
attempt_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_attempt_id)
sequence_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False)
end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None)
status: Mapped[str] = mapped_column(String, default="preparing", nullable=False)
worker_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None)
last_heartbeat_time: Mapped[Optional[float]] = mapped_column(Float, nullable=False, default_factory=time.time)
attempt_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None)

# addition columns for processing
max_duration: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None) # maximum duration allowed for this attempt in seconds
max_heartbeat_interval: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None) # maximum allowed heartbeat interval in seconds

def as_attempt(self) -> Attempt:
return Attempt(
rollout_id=self.rollout_id,
attempt_id=self.attempt_id,
sequence_id=self.sequence_id,
start_time=self.start_time,
end_time=self.end_time,
status=self.status, # type: ignore
worker_id=self.worker_id,
last_heartbeat_time=self.last_heartbeat_time,
metadata=self.attempt_metadata if self.attempt_metadata is not None else {},
)

def _validate_status_message(self, msg: Dict[str, Any]) -> None:
"""This function validates the status update message from caller.
Raises ValueError if the message is invalid.
"""
if "event" not in msg:
raise ValueError("Status update message must contain 'event' field.")
if "timestamp" not in msg:
msg["timestamp"] = time.time()
if msg["event"] not in [
"user_update", # user update attempt status via dbstore.update_attempt()
"span_received", # new span received
"single_step_timeout", # single step timeout detected (from last span heartbeat)
"overall_timeout", # overall timeout detected
]:
raise ValueError(f"Unsupported event type: {msg['event']}")
if msg["event"] == "user_update" and "new_status" not in msg:
raise ValueError("User update event must contain 'new_status' field.")

def get_finished_statuses(self) -> List[str]:
"""This function returns the list of statuses that are considered finished.
"""
return [
"succeeded",
"failed",
"timeout",
]

def update_status(self, msg: Dict[str, Any]) -> Optional[AttemptStatusUpdateMessage]:
Copy link

Copilot AI Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Copilot uses AI. Check for mistakes.
"""This function updates the status of the attempt based on the event.
Args:
msg: A dictionary containing the status update message. It must contain an "event" field, and optionally a "new_status" field.
More details about the message format can be found in the `_validate_status_message`() method.
current_time: The current time to use for updating timestamps. If None, uses time.time().
Returns:
A dictionary containing the status update message: {"event": "attempt_status_updated", "old_status": old_status, "new_status": new_status}.
IF no meaningful status update is performed, returns None.
Raises:
ValueError: If the event is not recognized or the status transition is invalid.
NotImplementedError: If the event handling is not implemented for the current status.
RuntimeError: If the new status is not set after processing the event.
"""
self._validate_status_message(msg)
event = msg["event"]
current_time = msg.get("timestamp", time.time())
old_status = self.status
new_status = msg.get("new_status", None)

# Step 1: Determine the new status based on the event and current status
if event == "user_update":
if not new_status:
raise ValueError("new_status must be provided for user_update event.")
elif event == "span_received":
self.last_heartbeat_time = current_time
if old_status in ["preparing", "unresponsive", "running"]:
new_status = "running"
elif old_status in self.get_finished_statuses():
logger.warning(f"Span received after attempt is already in status {self.status}. No status update performed.")
return # no further status update needed
else:
raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.")
elif event == "single_step_timeout":
if old_status in ["preparing", "running", ]:
new_status = "unresponsive"
else:
logger.warning(f"Single step timeout detected but attempt is in status {self.status}. No status update performed.")
return # no further status update needed
elif event == "overall_timeout":
if old_status not in self.get_finished_statuses():
new_status = "timeout"
else:
logger.warning(f"Overall timeout detected but attempt is in status {self.status}. No status update performed.")
return # no further status update needed
else:
raise NotImplementedError(f"Event {event} is not implemented for status update.")

# Step 2: Update the status
if not new_status:
raise RuntimeError(f"new_status should not be {new_status} after processing event for {event} on status {old_status}.")
if new_status == old_status:
return # no status change
if new_status in self.get_finished_statuses():
# when attempt is finished, set end_time
self.end_time = current_time
self.status = new_status

# Step 3: Return the status update info for further processing
return AttemptStatusUpdateMessage(
attempt_id=self.attempt_id,
rollout_id=self.rollout_id,
timestamp=current_time,
old_status=old_status,
new_status=new_status,
)

@classmethod
async def get_latest_attempt_for_rollout(cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str) -> Optional[Attempt]:
async with session_factory() as session:
async with session.begin():
result = await session.scalars(
select(cls)
.where(cls.rollout_id == rollout_id)
.order_by(cls.sequence_id.desc())
.limit(1)
)
attempt_obj = result.one_or_none()
if attempt_obj is None:
return None
return attempt_obj.as_attempt()


@classmethod
async def get_attempts_for_rollout(cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str) -> List[Attempt]:
async with session_factory() as session:
async with session.begin():
result = await session.scalars(
select(cls)
.where(cls.rollout_id == rollout_id)
.order_by(cls.sequence_id.asc())
)
return [attempt.as_attempt() for attempt in result.all()]




class SpanSeqIdInDB(SqlAlchemyBase):
__tablename__ = "span_sequence"

rollout_id: Mapped[str] = mapped_column(nullable=False, primary_key=True)

# FIXME InMemoryLightningStore let all attempts under the same rollout share the same span sequence for sorting
# attempt_id: Mapped[str] = mapped_column(nullable=False)
attempt_id: InitVar[str] # not mapped column, just for type hinting
Copy link

Copilot AI Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states 'not mapped column, just for type hinting', but InitVar is a dataclass field that passes initialization values without storing them. The comment could be clearer about why this field exists if it's not used in the class implementation.

Suggested change
attempt_id: InitVar[str] # not mapped column, just for type hinting

Copilot uses AI. Check for mistakes.

current_sequence: Mapped[int] = mapped_column(default=1, nullable=False)

# Versioning for optimistic concurrency control
version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
__mapper_args__ = {
"version_id_col": version_id,
# "primary_key": [rollout_id, attempt_id],
# "primary_key": [rollout_id],
}

@classmethod
async def get_next_sequence_id(cls: type[SpanSeqIdInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str, attempt_id: str, external_seq_id: Optional[int] = None) -> int:
"""Get the next sequence ID with retries to handle race conditions.
IF external_seq_id is provided and is greater than current_sequence, set current_sequence to external_seq_id.
"""
async with session_factory() as session:
async with session.begin():
seq_obj = await session.get(cls, rollout_id)
# seq_obj = await session.get(cls, [rollout_id, attempt_id])
if seq_obj is None:
raise ValueError(f"Rollout {rollout_id} not found")
else:
current_seq = external_seq_id if external_seq_id is not None and external_seq_id > seq_obj.current_sequence else seq_obj.current_sequence
seq_obj.current_sequence = current_seq + 1
await session.flush()
return current_seq
Loading