Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
47b0a4e
initial version with all dbstore interface implemented, but not all t…
mydmdm Oct 31, 2025
b8940fe
interface implemented, but still have some issues for the periodical …
mydmdm Nov 3, 2025
96a0d58
configurable retry added
mydmdm Nov 3, 2025
a0b6833
support periodic background tasks for attempt timeout checking
mydmdm Nov 4, 2025
cbd6498
update error messages
mydmdm Nov 4, 2025
a477bec
only corner cases for status propagation
mydmdm Nov 4, 2025
f7fe24a
all tests passed with some FIXME
mydmdm Nov 4, 2025
1857e39
reuse test_memory.py
mydmdm Nov 5, 2025
5950276
fix typo
mydmdm Nov 5, 2025
70838ee
Merge remote-tracking branch 'origin/main' into dev/database-store
mydmdm Nov 5, 2025
93814a9
Enhance rollout and attempt handling with type unification and new re…
mydmdm Nov 6, 2025
294a1cc
Merge remote-tracking branch 'origin/main' into dev/database-store
mydmdm Nov 6, 2025
6131c1f
rename to SqlLightningStore
mydmdm Nov 6, 2025
b6312db
fix pre-commit warnings
mydmdm Nov 6, 2025
c21e065
Update agentlightning/store/database/orm/rollout.py
mydmdm Nov 6, 2025
5983eec
Update agentlightning/store/database/retry_helper.py
mydmdm Nov 6, 2025
3162ed6
fix lint issue
mydmdm Nov 6, 2025
c1b7827
fix lint error
mydmdm Nov 6, 2025
fe2d052
To break the big transaction inside timeout healthy checking
mydmdm Nov 7, 2025
5fc8ee0
fix lint errors
mydmdm Nov 7, 2025
de8df23
Merge branch 'dev/database-store' of https://github.com/microsoft/age…
mydmdm Nov 7, 2025
2eb207e
fix lint warning and update uv.lock
mydmdm Nov 11, 2025
dc2c5e5
Merge branch 'main' into dev/database-store
mydmdm Nov 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ cython_debug/
# Claude
.claude/*.local.json

# Temporary and backup files
*.tmp
*.bak
*.backup

# Dashboard generated files
agentlightning/dashboard/**/*.css
agentlightning/dashboard/**/*.js
Expand Down
2 changes: 2 additions & 0 deletions agentlightning/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .base import LightningStore, LightningStoreCapabilities
from .client_server import LightningStoreClient, LightningStoreServer
from .database import SqlLightningStore
from .memory import InMemoryLightningStore
from .threading import LightningStoreThreaded

Expand All @@ -12,4 +13,5 @@
"LightningStoreServer",
"InMemoryLightningStore",
"LightningStoreThreaded",
"SqlLightningStore",
]
9 changes: 5 additions & 4 deletions agentlightning/store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional, Sequence, TypedDict
from typing import Any, Dict, List, Literal, Optional, Sequence, Union, TypedDict

from opentelemetry.sdk.trace import ReadableSpan

Expand Down Expand Up @@ -267,7 +267,7 @@ async def add_otel_span(

async def query_rollouts(
self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None
) -> List[Rollout]:
) -> List[Union[Rollout, AttemptedRollout]]:
"""Retrieve rollouts filtered by status and/or explicit identifiers.

Args:
Expand Down Expand Up @@ -297,7 +297,7 @@ async def query_attempts(self, rollout_id: str) -> List[Attempt]:
"""
raise NotImplementedError()

async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]:
async def get_rollout_by_id(self, rollout_id: str) -> Optional[Union[Rollout, AttemptedRollout]]:
"""Fetch a rollout by identifier without mutating its state.

Args:
Expand Down Expand Up @@ -457,6 +457,8 @@ async def update_resources(self, resources_id: str, resources: NamedResources) -
This API is typically used by algorithms that maintain mutable resources (e.g., model
checkpoints) under a stable identifier.

If `resources_id` does not exist, implementations should add it as a new snapshot.

Args:
resources_id: Identifier of the snapshot to replace.
resources: Updated mapping of resource names to payloads.
Expand All @@ -466,7 +468,6 @@ async def update_resources(self, resources_id: str, resources: NamedResources) -

Raises:
NotImplementedError: Subclasses must implement resource persistence.
ValueError: Implementations must raise when `resources_id` does not exist.
"""
raise NotImplementedError()

Expand Down
7 changes: 7 additions & 0 deletions agentlightning/store/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from .sqlite import SqlLightningStore

__all__ = [
"SqlLightningStore",
]
20 changes: 20 additions & 0 deletions agentlightning/store/database/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft. All rights reserved.

from .attempt import AttemptInDB, SpanSeqIdInDB
from .base import (
AttemptStatusUpdateMessage,
SqlAlchemyBase,
)
from .resources import ResourcesUpdateInDB
from .rollout import RolloutInDB
from .span import SpanInDB

__all__ = [
"SqlAlchemyBase",
"AttemptStatusUpdateMessage",
"RolloutInDB",
"AttemptInDB",
"ResourcesUpdateInDB",
"SpanSeqIdInDB",
"SpanInDB",
]
251 changes: 251 additions & 0 deletions agentlightning/store/database/orm/attempt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import hashlib
import logging
import time
import uuid
from dataclasses import InitVar
from typing import Any, Dict, List, Optional

from sqlalchemy import JSON, Float, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapped, mapped_column

from agentlightning.types import Attempt

from .base import AttemptStatusUpdateMessage, SqlAlchemyBase

logger = logging.getLogger(__name__)


def _generate_attempt_id() -> str:
"""We don't need that long because attempts are limited to rollouts."""
short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:8]
return "at-" + short_id


class AttemptInDB(SqlAlchemyBase):
    """ORM row representing one execution attempt of a rollout.

    Columns mirror ``agentlightning.types.Attempt``; the extra columns
    (``max_duration``, ``max_heartbeat_interval``, ``version_id``) exist only
    for the store's own timeout tracking and optimistic locking, and are
    excluded when converting back to the public model via :meth:`as_attempt`.
    """

    __tablename__ = "attempts"

    # Owning rollout; several attempts may share one rollout_id.
    rollout_id: Mapped[str] = mapped_column(String, nullable=False)
    attempt_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_attempt_id)
    # Position of this attempt within its rollout (queries order by it).
    sequence_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False)
    # Set by update_status() when the attempt reaches a finished status.
    end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None)
    status: Mapped[str] = mapped_column(String, default="preparing", nullable=False)
    worker_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None)
    # NOTE(review): annotated Optional but declared nullable=False with a
    # default_factory -- one of the two is presumably unintended; confirm.
    last_heartbeat_time: Mapped[Optional[float]] = mapped_column(Float, nullable=False, default_factory=time.time)
    # Stored under "attempt_metadata"; surfaced as "metadata" in as_attempt().
    attempt_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None)

    # Additional columns for processing (not part of the public Attempt type).
    max_duration: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True, default=None
    )  # maximum duration allowed for this attempt in seconds
    max_heartbeat_interval: Mapped[Optional[float]] = mapped_column(
        Float, nullable=True, default=None
    )  # maximum allowed heartbeat interval in seconds

    # Optimistic concurrency control: SQLAlchemy bumps version_id on every
    # UPDATE and rejects writes made against a stale version.
    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    __mapper_args__ = {
        "version_id_col": version_id,
    }

    def is_unresponsive(self, current_time: float) -> bool:
        """Check if the attempt is unresponsive based on the last heartbeat time and max_heartbeat_interval.

        Returns False when no heartbeat limit is configured or no heartbeat
        has been recorded.
        """
        if self.max_heartbeat_interval is None:
            return False
        if self.last_heartbeat_time is None:
            return False
        return (current_time - self.last_heartbeat_time) > self.max_heartbeat_interval

    def is_timed_out(self, current_time: float) -> bool:
        """Check if the attempt has timed out based on the start time and max_duration.

        Returns False when no overall duration limit is configured.
        """
        if self.max_duration is None:
            return False
        return (current_time - self.start_time) > self.max_duration

    def as_attempt(self) -> Attempt:
        """Convert this row into the public ``Attempt`` model.

        Store-internal columns are excluded, and ``attempt_metadata`` is
        remapped to the ``metadata`` field.
        """
        # NOTE: model_dump()/mapper semantics come from SqlAlchemyBase (not
        # shown here); presumably a pydantic-style dump with field remapping.
        return Attempt(
            **self.model_dump(
                exclude={"max_duration", "max_heartbeat_interval", "version_id"},
                mapper={"metadata": lambda obj: obj.attempt_metadata},  # type: ignore
            )
        )

    def _validate_status_message(self, msg: Dict[str, Any]) -> None:
        """Validate a status-update message from the caller.

        Mutates ``msg`` in place by filling in a missing ``"timestamp"``.

        Raises:
            ValueError: If the message is missing required fields or carries
                an unsupported event type.
        """
        if "event" not in msg:
            raise ValueError("Status update message must contain 'event' field.")
        if "timestamp" not in msg:
            msg["timestamp"] = time.time()
        if msg["event"] not in [
            "user_update",  # user update attempt status via dbstore.update_attempt()
            "span_received",  # new span received
            "single_step_timeout",  # single step timeout detected (from last span heartbeat)
            "overall_timeout",  # overall timeout detected
        ]:
            raise ValueError(f"Unsupported event type: {msg['event']}")
        if msg["event"] == "user_update" and "new_status" not in msg:
            raise ValueError("User update event must contain 'new_status' field.")

    def get_finished_statuses(self) -> List[str]:
        """Return the statuses that are considered terminal for an attempt."""
        return [
            "succeeded",
            "failed",
            "timeout",
        ]

    def update_status(self, msg: Dict[str, Any]) -> Optional[AttemptStatusUpdateMessage]:
        """Apply a status-update event to this attempt's status state machine.

        Args:
            msg: Status update message. It must contain an ``"event"`` field; a
                ``"user_update"`` event additionally requires ``"new_status"``.
                An optional ``"timestamp"`` field supplies the event time
                (defaults to ``time.time()``). See the
                `_validate_status_message()` method for the full format.

        Returns:
            An ``AttemptStatusUpdateMessage`` describing the transition
            (attempt/rollout ids, timestamp, old and new status).
            If no meaningful status update is performed, returns None.

        Raises:
            ValueError: If the event is not recognized or the status transition is invalid.
            NotImplementedError: If the event handling is not implemented for the current status.
            RuntimeError: If the new status is not set after processing the event.
        """
        self._validate_status_message(msg)
        event = msg["event"]
        current_time = msg.get("timestamp", time.time())
        old_status = self.status
        new_status = msg.get("new_status", None)

        # Step 1: Determine the new status based on the event and current status
        if event == "user_update":
            if not new_status:
                raise ValueError("new_status must be provided for user_update event.")
        elif event == "span_received":
            # Every span doubles as a heartbeat, even if no transition follows.
            self.last_heartbeat_time = current_time
            if old_status in ["preparing", "unresponsive", "running"]:
                new_status = "running"
            elif old_status in self.get_finished_statuses():
                logger.warning(
                    f"Span received after attempt is already in status {self.status}. No status update performed."
                )
                return  # no further status update needed
            else:
                raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.")
        elif event == "single_step_timeout":
            if old_status in [
                "preparing",
                "running",
            ]:
                new_status = "unresponsive"
            else:
                logger.warning(
                    f"Single step timeout detected but attempt is in status {self.status}. No status update performed."
                )
                return  # no further status update needed
        elif event == "overall_timeout":
            if old_status not in self.get_finished_statuses():
                new_status = "timeout"
            else:
                logger.warning(
                    f"Overall timeout detected but attempt is in status {self.status}. No status update performed."
                )
                return  # no further status update needed
        else:
            raise NotImplementedError(f"Event {event} is not implemented for status update.")

        # Step 2: Update the status
        if not new_status:
            raise RuntimeError(
                f"new_status should not be {new_status} after processing event for {event} on status {old_status}."
            )
        if new_status == old_status:
            return  # no status change
        if new_status in self.get_finished_statuses():
            # when attempt is finished, set end_time
            self.end_time = current_time
        self.status = new_status

        # Step 3: Return the status update info for further processing
        return AttemptStatusUpdateMessage(
            attempt_id=self.attempt_id,
            rollout_id=self.rollout_id,
            timestamp=current_time,
            old_status=old_status,
            new_status=new_status,
        )

    @classmethod
    async def get_latest_attempt_for_rollout(
        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
    ) -> Optional[Attempt]:
        """Fetch the most recent attempt (highest sequence_id) for a rollout.

        Returns None when the rollout has no attempts.
        """
        async with session_factory() as session:
            async with session.begin():
                result = await session.scalars(
                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.desc()).limit(1)
                )
                attempt_obj = result.one_or_none()
                if attempt_obj is None:
                    return None
                return attempt_obj.as_attempt()

    @classmethod
    async def get_attempts_for_rollout(
        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
    ) -> List[Attempt]:
        """Fetch all attempts for a rollout, ordered by ascending sequence_id."""
        async with session_factory() as session:
            async with session.begin():
                result = await session.scalars(
                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.asc())
                )
                return [attempt.as_attempt() for attempt in result.all()]


class SpanSeqIdInDB(SqlAlchemyBase):
    """Per-rollout monotonic counter used to assign span sequence numbers."""

    __tablename__ = "span_sequence"

    rollout_id: Mapped[str] = mapped_column(nullable=False, primary_key=True)

    # FIXME InMemoryLightningStore lets all attempts under the same rollout
    # share one span sequence for sorting, so attempt_id is deliberately NOT a
    # mapped column; InitVar keeps it in the generated __init__ signature only.
    attempt_id: InitVar[str]

    current_sequence: Mapped[int] = mapped_column(default=1, nullable=False)

    # Versioning for optimistic concurrency control: SQLAlchemy bumps
    # version_id on every UPDATE so stale concurrent writes fail rather than
    # silently overwriting each other.
    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    __mapper_args__ = {
        "version_id_col": version_id,
    }

    @classmethod
    async def get_next_sequence_id(
        cls: type[SpanSeqIdInDB],
        session_factory: async_sessionmaker[AsyncSession],
        rollout_id: str,
        attempt_id: str,
        external_seq_id: Optional[int] = None,
    ) -> int:
        """Reserve and return the next span sequence ID for *rollout_id*.

        If ``external_seq_id`` is provided and exceeds the stored counter, the
        counter jumps forward to it before being incremented. Races are handled
        by the optimistic ``version_id`` check (callers retry on conflict).

        Raises:
            ValueError: If no sequence row exists for ``rollout_id``.
        """
        async with session_factory() as session:
            async with session.begin():
                row = await session.get(cls, rollout_id)
                if row is None:
                    raise ValueError(f"Rollout {rollout_id} not found")
                issued = row.current_sequence
                if external_seq_id is not None and external_seq_id > issued:
                    issued = external_seq_id
                row.current_sequence = issued + 1
                await session.flush()
                return issued
Loading
Loading