diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/MergingSortedRowDataReader.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/MergingSortedRowDataReader.java
new file mode 100644
index 000000000000..228abb50e667
--- /dev/null
+++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/MergingSortedRowDataReader.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.NoSuchElementException;
+import org.apache.iceberg.BaseScanTaskGroup;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortField;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderComparators;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.io.CloseableGroup;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.CloseableIterator;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.source.metrics.TaskNumDeletes;
+import org.apache.iceberg.spark.source.metrics.TaskNumSplits;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.SortedMerge;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.ProjectingInternalRow;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
+
+/**
+ * A {@link PartitionReader} that reads multiple sorted files and merges them into a single sorted
+ * stream using a k-way heap merge ({@link SortedMerge}).
+ *
+ * <p>This reader is used when {@code preserve-data-ordering} is enabled and the task group contains
+ * multiple files that all have the same sort order.
+ *
+ * <p>Sort key columns absent from the requested projection are temporarily added to the read schema
+ * so that {@link SortOrderComparators} can access them during the merge. The extra columns are
+ * stripped from each row before it is returned to Spark.
+ */
+class MergingSortedRowDataReader implements PartitionReader<InternalRow> {
+  private static final Logger LOG = LoggerFactory.getLogger(MergingSortedRowDataReader.class);
+
+  private final CloseableGroup resources;
+  private final CloseableIterator<InternalRow> mergedIterator;
+  private final List<RowDataReader> fileReaders;
+  // non-null only when sort key columns were added to the read schema beyond what Spark projected
+  private final ProjectingInternalRow projectingRow;
+  private InternalRow current;
+
+  /**
+   * Opens one {@link RowDataReader} per file in the partition's task group and sets up the sorted
+   * merge over them.
+   *
+   * @param partition the input partition; its table must have a sort order and its task group must
+   *     contain more than one file
+   * @throws IllegalStateException if the table is unsorted or the task group has fewer than two
+   *     files
+   */
+  MergingSortedRowDataReader(SparkInputPartition partition) {
+    Table table = partition.table();
+    ScanTaskGroup<FileScanTask> taskGroup = partition.taskGroup();
+    Schema projection = partition.projection();
+    SortOrder sortOrder = table.sortOrder();
+
+    int numFiles = taskGroup.tasks().size();
+
+    Preconditions.checkState(
+        sortOrder.isSorted(), "Cannot create merging reader for unsorted table %s", table.name());
+    Preconditions.checkState(
+        numFiles > 1, "Merging reader requires multiple files, got %s", numFiles);
+
+    LOG.info(
+        "Creating merging reader for {} files with sort order {} in table {}",
+        numFiles,
+        sortOrder.orderId(),
+        table.name());
+
+    // Augment the projected schema with any sort key columns Spark did not request so that
+    // SortOrderComparators can access every sort key field during the merge.
+    Schema mergeReadSchema = mergeReadSchema(projection, sortOrder, table);
+    this.projectingRow = buildProjectingRow(projection, mergeReadSchema);
+
+    this.resources = new CloseableGroup();
+    try {
+      // Register every reader as soon as it exists so that a failure part-way through still
+      // closes the readers created before it.
+      List<RowDataReader> readers = Lists.newArrayListWithCapacity(numFiles);
+      for (FileScanTask task : taskGroup.tasks()) {
+        RowDataReader reader =
+            new RowDataReader(
+                table,
+                partition.io(),
+                new BaseScanTaskGroup<>(ImmutableList.of(task)),
+                mergeReadSchema,
+                partition.isCaseSensitive(),
+                partition.cacheDeleteFilesOnExecutors());
+        resources.addCloseable(reader);
+        readers.add(reader);
+      }
+      this.fileReaders = readers;
+
+      // Wrap each reader as a CloseableIterable and feed into SortedMerge.
+      List<CloseableIterable<InternalRow>> fileIterables =
+          readers.stream().map(this::readerToIterable).toList();
+      SortedMerge<InternalRow> sortedMerge =
+          new SortedMerge<>(buildComparator(mergeReadSchema, sortOrder), fileIterables);
+      resources.addCloseable(sortedMerge);
+      this.mergedIterator = sortedMerge.iterator();
+    } catch (RuntimeException e) {
+      // The caller never receives a reference to a reader whose constructor failed and so cannot
+      // close it; release whatever was opened before rethrowing.
+      try {
+        resources.close();
+      } catch (IOException | RuntimeException closeFailure) {
+        e.addSuppressed(closeFailure);
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Adapts a {@link RowDataReader} to a {@link CloseableIterable} for use with {@link SortedMerge}.
+   * Each row is copied before it enters the priority queue because Spark's Parquet/ORC readers
+   * reuse {@link InternalRow} instances for performance.
+   *
+   * <p>Every iterator obtained from the returned iterable advances the same underlying reader, so
+   * the iterable should be iterated only once.
+   */
+  private CloseableIterable<InternalRow> readerToIterable(RowDataReader reader) {
+    return CloseableIterable.withNoopClose(
+        () ->
+            new CloseableIterator<>() {
+              // true once reader.next() has been called for the row that next() will return
+              private boolean advanced = false;
+              private boolean hasNext = false;
+
+              @Override
+              public boolean hasNext() {
+                if (!advanced) {
+                  try {
+                    hasNext = reader.next();
+                    advanced = true;
+                  } catch (IOException e) {
+                    throw new UncheckedIOException("Failed to advance reader", e);
+                  }
+                }
+                return hasNext;
+              }
+
+              @Override
+              public InternalRow next() {
+                if (!hasNext()) {
+                  throw new NoSuchElementException("No more rows in reader");
+                }
+                advanced = false;
+                return reader.get().copy();
+              }
+
+              @Override
+              public void close() throws IOException {
+                reader.close();
+              }
+            });
+  }
+
+  @Override
+  public boolean next() throws IOException {
+    if (!mergedIterator.hasNext()) {
+      return false;
+    }
+
+    InternalRow merged = mergedIterator.next();
+    if (projectingRow == null) {
+      this.current = merged;
+    } else {
+      // Reuse the single projecting view; it exposes only the columns Spark asked for.
+      projectingRow.project(merged);
+      this.current = projectingRow;
+    }
+
+    return true;
+  }
+
+  @Override
+  public InternalRow get() {
+    return current;
+  }
+
+  @Override
+  public void close() throws IOException {
+    resources.close();
+  }
+
+  @Override
+  public CustomTaskMetric[] currentMetricsValues() {
+    long totalDeletes =
+        fileReaders.stream()
+            .flatMap(reader -> Arrays.stream(reader.currentMetricsValues()))
+            .filter(metric -> metric instanceof TaskNumDeletes)
+            .mapToLong(CustomTaskMetric::value)
+            .sum();
+    return new CustomTaskMetric[] {
+      new TaskNumSplits(fileReaders.size()), new TaskNumDeletes(totalDeletes)
+    };
+  }
+
+  /**
+   * Builds a comparator for merging {@link InternalRow}s by the given sort order. Uses {@link
+   * SortOrderComparators} which handles all transform types (identity, bucket, truncate), ASC/DESC
+   * directions, and null ordering. The two {@link InternalRowWrapper} instances are allocated once
+   * and reused — {@code wrap()} just updates an internal reference.
+   *
+   * <p>Because the wrappers are shared across calls, the returned comparator must not be used from
+   * multiple threads at once.
+   */
+  private static Comparator<InternalRow> buildComparator(
+      Schema mergeReadSchema, SortOrder sortOrder) {
+    StructType sparkSchema = SparkSchemaUtil.convert(mergeReadSchema);
+    Comparator<StructLike> keyComparator =
+        SortOrderComparators.forSchema(mergeReadSchema, sortOrder);
+    InternalRowWrapper left = new InternalRowWrapper(sparkSchema, mergeReadSchema.asStruct());
+    InternalRowWrapper right = new InternalRowWrapper(sparkSchema, mergeReadSchema.asStruct());
+    return (r1, r2) -> keyComparator.compare(left.wrap(r1), right.wrap(r2));
+  }
+
+  /**
+   * Returns a {@link ProjectingInternalRow} that remaps columns from the wider merge schema back to
+   * the requested projection, or {@code null} if no extra columns were added.
+   */
+  private static ProjectingInternalRow buildProjectingRow(Schema projection, Schema mergeSchema) {
+    if (projection.columns().size() == mergeSchema.columns().size()) {
+      return null;
+    }
+
+    List<Types.NestedField> mergeColumns = mergeSchema.columns();
+    // Scala's IndexedSeq[Int] appears to Java as IndexedSeq<Object>, hence the element type.
+    List<Object> positions = Lists.newArrayListWithCapacity(projection.columns().size());
+
+    for (Types.NestedField projected : projection.columns()) {
+      int fieldId = projected.fieldId();
+      boolean found = false;
+      for (int j = 0; j < mergeColumns.size(); j++) {
+        if (mergeColumns.get(j).fieldId() == fieldId) {
+          positions.add(j);
+          found = true;
+          break;
+        }
+      }
+      Preconditions.checkState(
+          found, "Projection field id=%s not found in merge read schema — this is a bug", fieldId);
+    }
+
+    StructType sparkSchema = SparkSchemaUtil.convert(projection);
+    return new ProjectingInternalRow(sparkSchema, JavaConverters.asScala(positions).toIndexedSeq());
+  }
+
+  /**
+   * Returns the schema to use when reading each file. This is the requested {@code projection}
+   * augmented with any sort key columns that are not already present, so the merge comparator can
+   * access every sort key field regardless of what Spark projected.
+   *
+   * <p>A source column referenced by several sort fields (for example {@code bucket(16, id)} and
+   * {@code id}) is added only once, because a schema cannot contain the same field twice.
+   */
+  private static Schema mergeReadSchema(Schema projection, SortOrder sortOrder, Table table) {
+    Schema tableSchema = table.schema();
+    List<Types.NestedField> missingFields = Lists.newArrayList();
+
+    for (SortField sortField : sortOrder.fields()) {
+      int fieldId = sortField.sourceId();
+      boolean alreadyRead =
+          projection.findField(fieldId) != null
+              || missingFields.stream().anyMatch(field -> field.fieldId() == fieldId);
+      if (alreadyRead) {
+        continue;
+      }
+
+      Types.NestedField tableField = tableSchema.findField(fieldId);
+      if (tableField != null) {
+        missingFields.add(tableField);
+      }
+    }
+
+    if (missingFields.isEmpty()) {
+      return projection;
+    }
+
+    return TypeUtil.join(projection, new Schema(missingFields));
+  }
+}
diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestMergingSortedRowDataReader.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestMergingSortedRowDataReader.java
new file mode 100644
index 000000000000..6ca114940397
--- /dev/null
+++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestMergingSortedRowDataReader.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.List;
+import org.apache.iceberg.BaseScanTaskGroup;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SchemaParser;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.data.FileHelpers;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.TestBase;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+/**
+ * Tests for {@link MergingSortedRowDataReader}: each test writes several individually sorted data
+ * files and checks that reading them as one task group yields rows in the table's sort order.
+ */
+class TestMergingSortedRowDataReader extends TestBase {
+
+  private static final TableIdentifier TABLE_IDENT =
+      TableIdentifier.of("default", "test_merging_reader");
+
+  private static final Schema SCHEMA =
+      new Schema(
+          required(1, "id", Types.IntegerType.get()), required(2, "data", Types.StringType.get()));
+
+  // same columns as SCHEMA, but with an optional id so that null sort keys can be written
+  private static final Schema NULLABLE_SCHEMA =
+      new Schema(
+          Types.NestedField.optional(1, "id", Types.IntegerType.get()),
+          required(2, "data", Types.StringType.get()));
+
+  private static final PartitionSpec SPEC = PartitionSpec.unpartitioned();
+
+  private Table table;
+
+  @TempDir private Path temp;
+
+  @BeforeEach
+  void before() {
+    table = catalog.createTable(TABLE_IDENT, SCHEMA, SPEC);
+    table.replaceSortOrder().asc("id").commit();
+  }
+
+  @AfterEach
+  void after() {
+    catalog.dropTable(TABLE_IDENT);
+  }
+
+  @Test
+  void mergeTwoSortedFiles() throws IOException {
+    DataFile file1 = writeDataFile(record(1, "a"), record(3, "c"), record(5, "e"));
+    DataFile file2 = writeDataFile(record(2, "b"), record(4, "d"), record(6, "f"));
+
+    table.newAppend().appendFile(file1).appendFile(file2).commit();
+
+    List<InternalRow> rows = readMerged(table);
+
+    assertThat(extractIds(rows)).containsExactly(1, 2, 3, 4, 5, 6);
+  }
+
+  @Test
+  void mergeWithDuplicateKeys() throws IOException {
+    DataFile file1 = writeDataFile(record(1, "a"), record(2, "b"));
+    DataFile file2 = writeDataFile(record(1, "c"), record(2, "d"));
+    DataFile file3 = writeDataFile(record(1, "e"), record(3, "f"));
+
+    table.newAppend().appendFile(file1).appendFile(file2).appendFile(file3).commit();
+
+    List<InternalRow> rows = readMerged(table);
+
+    assertThat(extractIds(rows)).containsExactly(1, 1, 1, 2, 2, 3);
+  }
+
+  @Test
+  void mergeDescendingOrder() throws IOException {
+    catalog.dropTable(TABLE_IDENT);
+    table = catalog.createTable(TABLE_IDENT, SCHEMA, SPEC);
+    table.replaceSortOrder().desc("id").commit();
+
+    DataFile file1 = writeDataFile(record(6, "f"), record(4, "d"));
+    DataFile file2 = writeDataFile(record(5, "e"), record(3, "c"), record(1, "a"));
+
+    table.newAppend().appendFile(file1).appendFile(file2).commit();
+
+    List<InternalRow> rows = readMerged(table);
+
+    assertThat(extractIds(rows)).containsExactly(6, 5, 4, 3, 1);
+  }
+
+  @Test
+  void mergeWithNulls() throws IOException {
+    catalog.dropTable(TABLE_IDENT);
+    table = catalog.createTable(TABLE_IDENT, NULLABLE_SCHEMA, SPEC);
+    table.replaceSortOrder().asc("id").commit();
+
+    DataFile file1 = writeDataFile(nullRecord("x"), record(3, "c"));
+    DataFile file2 = writeDataFile(nullRecord("y"), record(1, "a"), record(2, "b"));
+
+    table.newAppend().appendFile(file1).appendFile(file2).commit();
+
+    List<InternalRow> rows = readMerged(table);
+
+    // the two null ids are expected first, followed by the non-null ids in ascending order
+    assertThat(rows).hasSize(5);
+    assertThat(rows.get(0).isNullAt(0)).isTrue();
+    assertThat(rows.get(1).isNullAt(0)).isTrue();
+    assertThat(extractIds(rows.subList(2, 5))).containsExactly(1, 2, 3);
+  }
+
+  @Test
+  void mergeThreeFiles() throws IOException {
+    DataFile file1 = writeDataFile(record(1, "a"), record(4, "d"), record(7, "g"));
+    DataFile file2 = writeDataFile(record(2, "b"), record(5, "e"), record(8, "h"));
+    DataFile file3 = writeDataFile(record(3, "c"), record(6, "f"), record(9, "i"));
+
+    table.newAppend().appendFile(file1).appendFile(file2).appendFile(file3).commit();
+
+    List<InternalRow> rows = readMerged(table);
+
+    assertThat(extractIds(rows)).containsExactly(1, 2, 3, 4, 5, 6, 7, 8, 9);
+  }
+
+  /**
+   * Plans every data file of {@code tbl} into a single task group, reads it through a {@link
+   * MergingSortedRowDataReader} built from a partition that is given the full table schema, and
+   * returns copies of the rows in the order the reader produced them.
+   */
+  private List<InternalRow> readMerged(Table tbl) throws IOException {
+    tbl.refresh();
+
+    List<FileScanTask> fileTasks = Lists.newArrayList();
+    try (CloseableIterable<FileScanTask> tasks = tbl.newScan().planFiles()) {
+      tasks.forEach(fileTasks::add);
+    }
+
+    // the merging reader rejects task groups with a single file
+    assertThat(fileTasks).hasSizeGreaterThan(1);
+
+    BaseScanTaskGroup<FileScanTask> taskGroup = new BaseScanTaskGroup<>(fileTasks);
+
+    Broadcast<Table> tableBroadcast = sparkContext.broadcast(SerializableTableWithSize.copyOf(tbl));
+    Broadcast<SerializableFileIOWithSize> fileIOBroadcast =
+        sparkContext.broadcast(SerializableFileIOWithSize.wrap(tbl.io()));
+
+    SparkInputPartition partition =
+        new SparkInputPartition(
+            Types.StructType.of(),
+            taskGroup,
+            tableBroadcast,
+            fileIOBroadcast,
+            SchemaParser.toJson(tbl.schema()),
+            true,
+            new String[0],
+            false);
+
+    List<InternalRow> rows = Lists.newArrayList();
+    try (MergingSortedRowDataReader reader = new MergingSortedRowDataReader(partition)) {
+      while (reader.next()) {
+        // the reader may reuse the row it returns, so keep a copy
+        rows.add(reader.get().copy());
+      }
+    }
+
+    return rows;
+  }
+
+  /** Returns the value of column 0 of each row, or {@code null} where that column is null. */
+  private List<Integer> extractIds(List<InternalRow> rows) {
+    return rows.stream().map(row -> row.isNullAt(0) ? null : row.getInt(0)).toList();
+  }
+
+  /** Creates a {@link #SCHEMA} record with the given id and data. */
+  private Record record(int id, String data) {
+    GenericRecord record = GenericRecord.create(SCHEMA);
+    record.set(0, id);
+    record.set(1, data);
+    return record;
+  }
+
+  /** Creates a {@link #NULLABLE_SCHEMA} record whose id is null. */
+  private Record nullRecord(String data) {
+    GenericRecord record = GenericRecord.create(NULLABLE_SCHEMA);
+    record.set(0, null);
+    record.set(1, data);
+    return record;
+  }
+
+  /** Writes the records, in the given order, to a new data file under the temp directory. */
+  private DataFile writeDataFile(Record... records) throws IOException {
+    return FileHelpers.writeDataFile(
+        table,
+        Files.localOutput(File.createTempFile("junit", null, temp.toFile())),
+        Lists.newArrayList(records));
+  }
+}