Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions daft_lance/lance_data_sink.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
import pathlib
import uuid
import warnings
from itertools import chain
from typing import TYPE_CHECKING, Literal
Expand Down Expand Up @@ -30,6 +32,8 @@

from daft.daft import IOConfig

logger = logging.getLogger(__name__)


class LanceDataSink(DataSink[list[FragmentMetadata]]):
"""WriteSink for writing data to a Lance dataset."""
Expand All @@ -49,6 +53,8 @@ def __init__(
use_legacy_format: bool | None = None,
enable_stable_row_ids: bool = False,
storage_options: dict[str, str] | None = None,
use_mem_wal: bool = False,
compact_after_write: bool = True,
) -> None:
self._reject_unsupported_modes(mode, use_legacy_format)
if not isinstance(uri, (str, pathlib.Path)):
Expand All @@ -72,6 +78,11 @@ def __init__(
self._pyarrow_schema = self._normalize_schema(schema)
self._init_blob_policy(blob_columns)

self._use_mem_wal = use_mem_wal
self._compact_after_write = compact_after_write
self._mem_wal_total_rows: int = 0
self._mem_wal_total_bytes: int = 0

self._version: int = 0
self._table_schema: pa.Schema | None = None
existing = self._absorb_existing_dataset()
Expand Down Expand Up @@ -227,8 +238,49 @@ def _write_arrow_table(self, table: pa.Table) -> WriteResult[list[FragmentMetada
)
return WriteResult(result=fragments, bytes_written=bytes_written, rows_written=wrapped.num_rows)

def _ensure_mem_wal_dataset(self) -> lance.LanceDataset:
    """Return a handle to the target dataset that is ready for MemWAL writes.

    The dataset at ``self._table_uri`` is opened if possible. When opening
    raises ``ValueError`` or ``OSError``, an empty dataset carrying the sink's
    effective schema is created in ``"create"`` mode instead. Afterwards, MemWAL
    is initialized (unsharded) unless the dataset already reports a
    non-negative ``num_shards`` in its MemWAL index details.

    Returns:
        The opened or newly created dataset.
    """
    try:
        ds = lance.dataset(self._table_uri, storage_options=self._storage_options)
    except (ValueError, OSError):
        # FileNotFoundError is a subclass of OSError, so it is covered here.
        # NOTE(review): every ValueError/OSError is treated as "dataset does not
        # exist yet"; an unrelated I/O failure would also fall through to the
        # create path below. Confirm that this is intended.
        ds = None

    if ds is None:
        # A zero-row table built straight from the schema bootstraps the dataset
        # with the right columns without writing any data.
        ds = lance.write_dataset(
            self._effective_pyarrow_schema.empty_table(),
            self._table_uri,
            mode="create",
            storage_options=self._storage_options,
            data_storage_version=self._data_storage_version,
            use_legacy_format=self._use_legacy_format,
        )

    details = ds.mem_wal_index_details()
    # Missing details, or a missing/negative shard count, is taken to mean that
    # MemWAL has not been set up on this dataset yet.
    if details is None or details.get("num_shards", -1) < 0:
        ds.initialize_mem_wal(unsharded=True)

    return ds

def _write_arrow_table_mem_wal(
    self, table: pa.Table, ds: lance.LanceDataset
) -> WriteResult[list[FragmentMetadata]]:
    """Push one Arrow table through a MemWAL writer and report what was written.

    A writer is opened on ``ds`` under a freshly generated UUID, the table is
    handed to it with ``put``, and the writer's ``wal_flush_bytes`` statistic
    (0 when absent) is reported as the byte count.

    Args:
        table: The Arrow table to write.
        ds: The dataset whose MemWAL writer is used.

    Returns:
        A ``WriteResult`` with an empty fragment list, the reported byte count,
        and ``table.num_rows`` as the row count.
    """
    # NOTE(review): the dataset is initialized with ``unsharded=True`` elsewhere
    # in this class, yet every call here uses a new random shard id. Confirm
    # this matches the intended MemWAL usage.
    writer_shard = str(uuid.uuid4())
    with ds.mem_wal_writer(writer_shard) as wal_writer:
        wal_writer.put(table)
        # Stats are sampled while the writer is still open.
        # NOTE(review): confirm the flush counter is already final at this point.
        flushed_bytes = wal_writer.stats().get("wal_flush_bytes", 0)
    return WriteResult(result=[], bytes_written=flushed_bytes, rows_written=table.num_rows)

def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResult[list[FragmentMetadata]]]:
    """Stream write results for the given micropartitions.

    Delegates to the MemWAL write path when the sink was configured with
    ``use_mem_wal``; otherwise delegates to the fragment-writing path.
    """
    write_path = self._write_mem_wal if self._use_mem_wal else self._write_cow
    yield from write_path(micropartitions)

def _write_cow(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResult[list[FragmentMetadata]]]:
buffer = _LanceFragmentBuffer(
max_rows=self._max_rows_per_file,
max_bytes=self._max_bytes_per_file,
Expand All @@ -251,8 +303,25 @@ def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResu
if buffer.has_rows():
yield self._write_arrow_table(buffer.drain())

def _write_mem_wal(
    self, micropartitions: Iterator[MicroPartition]
) -> Iterator[WriteResult[list[FragmentMetadata]]]:
    """Write each micropartition through the MemWAL path, yielding one result apiece.

    The dataset is prepared once up front. Every micropartition is converted to
    Arrow, normalized with ``_prepare_arrow_table``, passed through the blob
    wrapper, and written. Running row and byte totals are kept on this
    instance as a side effect.
    """
    dataset = self._ensure_mem_wal_dataset()
    for part in micropartitions:
        blob_wrapped = self._blob.wrap_table(self._prepare_arrow_table(part.to_arrow()))
        outcome = self._write_arrow_table_mem_wal(blob_wrapped, dataset)
        # These counters live on whichever sink instance executes this method.
        self._mem_wal_total_rows += blob_wrapped.num_rows
        self._mem_wal_total_bytes += outcome.bytes_written
        yield outcome

def finalize(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
    """Finish the write and return the resulting dataset statistics.

    Hands ``write_results`` to the MemWAL finalizer when the sink was configured
    with ``use_mem_wal``, and to the fragment-commit finalizer otherwise.
    """
    finisher = self._finalize_mem_wal if self._use_mem_wal else self._finalize_cow
    return finisher(write_results)

def _finalize_cow(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
fragments = list(chain.from_iterable(write_result.result for write_result in write_results))

operation: lance.LanceOperation.BaseOperation
Expand All @@ -278,6 +347,30 @@ def finalize(self, write_results: list[WriteResult[list[FragmentMetadata]]]) ->
)
return stats_dict

def _finalize_mem_wal(self, write_results: list[WriteResult[list[FragmentMetadata]]]) -> MicroPartition:
    """Finish a MemWAL write: optionally compact, then report dataset statistics.

    Args:
        write_results: The results produced by the write phase. Their
            ``rows_written`` and ``bytes_written`` values are summed for the
            log message.

    Returns:
        A single-row ``MicroPartition`` with the ``num_fragments``,
        ``num_deleted_rows``, ``num_small_files`` and ``version`` of the dataset
        as it stands after the optional compaction.
    """
    dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)

    if self._compact_after_write:
        # Totals are derived from the collected write results rather than from
        # counters mutated inside write(): those counters stay at zero on this
        # instance whenever write() ran on a different copy of the sink (for
        # example on a distributed worker).
        total_rows = sum(item.rows_written for item in write_results)
        total_bytes = sum(item.bytes_written for item in write_results)
        logger.info(
            "MemWAL write complete (%d rows, %d bytes). Running compaction.",
            total_rows,
            total_bytes,
        )
        # Imported here, as in the original code, rather than at module level.
        from daft_lance.lance_compaction import compact_files_internal

        compact_files_internal(dataset)
        # Re-open the dataset so the statistics below are read after compaction.
        dataset = lance.dataset(self._table_uri, storage_options=self._storage_options)

    stats = dataset.stats.dataset_stats()
    return MicroPartition.from_pydict(
        {
            "num_fragments": pa.array([stats["num_fragments"]], type=pa.int64()),
            "num_deleted_rows": pa.array([stats["num_deleted_rows"]], type=pa.int64()),
            "num_small_files": pa.array([stats["num_small_files"]], type=pa.int64()),
            "version": pa.array([dataset.version], type=pa.int64()),
        }
    )


class _LanceFragmentBuffer:
"""Accumulates pyarrow tables until a row-count or byte-size threshold is hit."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ readme = "README.md"
dependencies = [
"lance-namespace>=0.6.0",
"lance-namespace-urllib3-client>=0.6.0",
"pylance>=6.0.0"
"pylance>=7.0.0"
]

[dependency-groups]
Expand Down
Loading
Loading