Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 211 additions & 2 deletions daft_lance/lance_data_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pathlib
import warnings
from itertools import chain
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

import lance
from lance.fragment import FragmentMetadata
Expand All @@ -25,9 +25,12 @@
resolve_storage_version,
)

DAFT_IDEMPOTENCE_KEY = "daft.idempotence-key"

if TYPE_CHECKING:
from collections.abc import Iterator

from daft.checkpoint import IdempotentCommit
from daft.daft import IOConfig


Expand All @@ -49,6 +52,7 @@ def __init__(
use_legacy_format: bool | None = None,
enable_stable_row_ids: bool = False,
storage_options: dict[str, str] | None = None,
checkpoint: IdempotentCommit | None = None,
) -> None:
self._reject_unsupported_modes(mode, use_legacy_format)
if not isinstance(uri, (str, pathlib.Path)):
Expand All @@ -62,6 +66,10 @@ def __init__(
if storage_options is not None
else io_config_to_storage_options(self._io_config, self._table_uri)
)
if checkpoint is not None and mode != "append":
raise NotImplementedError("Lance checkpoint writes currently support mode='append' only.")
self._checkpoint = checkpoint
self._checkpoint_enabled = checkpoint is not None
self._init_lance_knobs(
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
Expand Down Expand Up @@ -197,6 +205,28 @@ def name(self) -> str:
def schema(self) -> Schema:
return self._schema

def checkpoint_file_format(self) -> str | None:
"""Return the checkpoint metadata format for this sink.

Daft core calls this while building the DataSink pipeline. Returning
"lance" means Daft should stage this sink's per-input write_results as
Lance checkpoint metadata. Returning None keeps the normal DataSink path.
"""
return "lance" if self._checkpoint_enabled else None

def __getstate__(self) -> dict[str, Any]:
"""Customize how this sink is pickled for Ray workers.

Ray workers call write(), but they do not call finalize(). They only
need _checkpoint_enabled so Daft still stages write_results. The
checkpoint store stays on the driver because it cannot be pickled and
is only needed later by finalize().
"""
state = self.__dict__.copy()
if self._checkpoint_enabled:
state["_checkpoint"] = None
return state

def _prepare_arrow_table(self, input_table: pa.Table) -> pa.Table:
target_schema = self._table_schema if self._table_schema is not None else self._pyarrow_schema
target_schema = self._blob.cast_target_schema(target_schema)
Expand Down Expand Up @@ -251,7 +281,114 @@ def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResu
if buffer.has_rows():
yield self._write_arrow_table(buffer.drain())

def finalize(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
@staticmethod
def _enum_name(value: Any) -> str:
"""Return the name for checkpoint enum values."""
if isinstance(value, str):
return value
name = getattr(value, "name", None)
if isinstance(name, str):
return name
text = str(value)
if "." in text:
return text.rsplit(".", 1)[-1]
raise TypeError(f"expected checkpoint enum or string, got {value!r}")

def _pending_checkpoint_ids(self) -> list[str]:
"""Return checkpoint ids that are sealed but not marked committed.

These ids point to staged Lance write_results. After the Lance append
transaction lands, we mark these ids committed so the store stops
returning their write_results as pending work.
"""
assert self._checkpoint is not None
return [
checkpoint.id
for checkpoint in self._checkpoint.store.list_checkpoints()
if self._enum_name(checkpoint.status) == "Checkpointed"
]

def _checkpointed_write_results(self) -> list[WriteResult[list[FragmentMetadata]]]:
"""Read staged Lance write_results from the checkpoint store.

Daft core stores DataSink output as Arrow IPC MicroPartitions. This
decodes those blobs and returns the original WriteResult objects, whose
result fields contain Lance fragments.
"""
assert self._checkpoint is not None

write_results: list[WriteResult[list[FragmentMetadata]]] = []
for file_metadata in self._checkpoint.store.get_checkpointed_files():
if self._enum_name(file_metadata.format) != "Lance":
raise RuntimeError(
"unexpected checkpoint metadata format for Lance write; "
f"expected Lance, got {file_metadata.format!r}"
)

try:
micropartition = MicroPartition.from_ipc_stream(file_metadata.data)
payload = micropartition.to_pydict()
write_results.extend(payload["write_results"])
except Exception as e:
raise RuntimeError(
"failed to decode Lance write_results from checkpoint store; "
"expected an Arrow IPC MicroPartition with a python `write_results` column"
) from e
return write_results

@staticmethod
def _dataset_version(dataset: Any) -> int:
"""Return the current Lance dataset version number."""
return int(getattr(dataset, "version", getattr(dataset, "latest_version", 0)))

def _dataset_stats_result(self, dataset: Any) -> MicroPartition:
"""Build the MicroPartition returned by write_lance."""
stats = dataset.stats.dataset_stats()
return MicroPartition.from_pydict(
{
"num_fragments": pa.array([stats["num_fragments"]], type=pa.int64()),
"num_deleted_rows": pa.array([stats["num_deleted_rows"]], type=pa.int64()),
"num_small_files": pa.array([stats["num_small_files"]], type=pa.int64()),
"version": pa.array([self._dataset_version(dataset)], type=pa.int64()),
}
)

def _idempotence_key_exists(self, dataset: Any) -> bool:
"""Return True if Lance already has a transaction for this commit key.

Checkpointed Lance commits write DAFT_IDEMPOTENCE_KEY into transaction
properties. Retries scan the history for the same value to avoid
appending the same fragments twice.
"""
assert self._checkpoint is not None

version = max(self._dataset_version(dataset), 1)
for transaction in dataset.get_transactions(version):
if transaction is None:
continue
properties = getattr(transaction, "transaction_properties", None) or {}
if properties.get(DAFT_IDEMPOTENCE_KEY) == self._checkpoint.idempotence_key:
return True
return False

def checkpoint_commit_exists(self) -> bool:
"""Driver-side pre-check used before running the write pipeline.

DataFrame.write_lance calls this before write_sink(). If the Lance
transaction history already has this idempotence key, the logical commit
has already landed. In that case, this retry does not run the pipeline,
so there are no current-run write_results. DataFrame.write_lance still
calls finalize([]), which lets the sink mark pending checkpoint ids from
the previous attempt committed.
"""
if not self._checkpoint_enabled:
return False
if self._checkpoint is None:
raise RuntimeError("checking a Lance checkpoint commit requires the driver-side CheckpointStore")
dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
return self._idempotence_key_exists(dataset)

def _finalize_uncheckpointed(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
"""Commits the fragments to the Lance dataset. Returns a DataFrame with the stats of the dataset."""
fragments = list(chain.from_iterable(write_result.result for write_result in write_results))

Expand All @@ -278,6 +415,78 @@ def finalize(self, write_results: list[WriteResult[list[FragmentMetadata]]]) ->
)
return stats_dict

def _finalize_checkpointed(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
"""Finalize a checkpointed append using write_results from the store.

The current run's write_results are not the source of truth. A retry may
skip inputs that were already checkpointed, so pending fragments must be
recovered from the checkpoint store.
"""
assert self._checkpoint is not None

dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)
pending_ids = self._pending_checkpoint_ids()

# Check the Lance history again at finalize time. The pre-check in
# DataFrame.write_lance runs before the pipeline starts, so it can miss:
# 1. A concurrent retry that commits the same idempotence key while this
# retry is still running.
# 2. A previous attempt that passed the pre-check, committed to Lance,
# and then crashed before mark_committed().
if self._idempotence_key_exists(dataset):
if pending_ids:
self._checkpoint.store.mark_committed(pending_ids)
return self._dataset_stats_result(dataset)

checkpointed_write_results = self._checkpointed_write_results()
if not checkpointed_write_results:
if write_results:
raise RuntimeError(
"write_lance checkpoint did not stage any Lance write_results. "
"Read the source with daft.CheckpointConfig using the same CheckpointStore."
)
return self._dataset_stats_result(dataset)

fragments = list(chain.from_iterable(write_result.result for write_result in checkpointed_write_results))
# It is possible to recover WriteResult objects whose result lists contain no
# Lance fragments. For example, an input can reach the sink after all
# rows were filtered out, so Daft still has a completed checkpoint
# boundary but Lance did not create any physical fragment files.
#
# In that case there is no data change to append. We avoid creating a
# no-op Lance transaction because it would add a new table version only
# to record that nothing was appended. The useful recovery action is to
# mark the sealed checkpoint ids committed so these inputs are not
# returned as pending work again.
if not fragments:
if pending_ids:
self._checkpoint.store.mark_committed(pending_ids)
return self._dataset_stats_result(dataset)

operation = lance.LanceOperation.Append(fragments)
transaction = lance.Transaction(
self._dataset_version(dataset),
operation,
transaction_properties={DAFT_IDEMPOTENCE_KEY: self._checkpoint.idempotence_key},
)
committed_dataset = lance.LanceDataset.commit(
self._table_uri,
transaction,
storage_options=self._storage_options,
)

if pending_ids:
self._checkpoint.store.mark_committed(pending_ids)
return self._dataset_stats_result(committed_dataset)

def finalize(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
"""Commits the fragments to the Lance dataset. Returns a DataFrame with the stats of the dataset."""
if self._checkpoint_enabled:
if self._checkpoint is None:
raise RuntimeError("checkpointed Lance finalization requires the driver-side CheckpointStore")
return self._finalize_checkpointed(write_results)
return self._finalize_uncheckpointed(write_results)


class _LanceFragmentBuffer:
"""Accumulates pyarrow tables until a row-count or byte-size threshold is hit."""
Expand Down
Loading
Loading