diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 225c21fb18b07..ea7af01121a1b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -620,6 +620,8 @@ private[serializer] object KryoSerializer { "org.apache.spark.sql.columnar.CachedBatchSerializer", "org.apache.spark.sql.columnar.SimpleMetricsCachedBatchSerializer", "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatch", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer", "org.apache.spark.ml.attribute.Attribute", "org.apache.spark.ml.attribute.AttributeGroup", diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index b28cf2076f544..c3f1f05ac117d 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -57,6 +57,8 @@ subitems: - text: Caching Data url: sql-performance-tuning.html#caching-data + - text: Arrow Cache Format + url: sql-arrow-cache-format.html - text: Tuning Partitions url: sql-performance-tuning.html#tuning-partitions - text: Leveraging Statistics diff --git a/docs/sql-arrow-cache-format.md b/docs/sql-arrow-cache-format.md new file mode 100644 index 0000000000000..6e8be90bbc5aa --- /dev/null +++ b/docs/sql-arrow-cache-format.md @@ -0,0 +1,312 @@ +# Apache Arrow Cache Format for Spark + +## Overview + +Apache Spark supports using Apache Arrow as an alternative cache format for in-memory Dataset caching. This format provides improved performance for certain workloads, especially when working with columnar data sources like Parquet and ORC. + +## Benefits + +The Arrow cache format offers several advantages over the default cache format: + +- **Zero-copy reads** when input is already in Arrow format (e.g., Arrow-based data sources, re-caching Arrow cached data) +- **Better filter pushdown** with min/max statistics for partition pruning +- **Off-heap memory management** via Arrow allocators +- **Efficient compression** with zstd and lz4 codecs +- **Arrow ecosystem interoperability** for data sharing + +**Note**: Spark's built-in Parquet/ORC readers use internal column vectors (`OnHeapColumnVector`/`OffHeapColumnVector`), not Arrow format, so they don't benefit from zero-copy optimization. + +## Configuration + +To enable Arrow cache format, set the static configuration: + +```scala +spark.conf.set("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") +``` + +**Note**: This is a static configuration that must be set before the SparkSession is created. + +```scala +val spark = SparkSession.builder() + .appName("MyApp") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() +``` + +## Usage + +Once configured, use cache operations as normal: + +```scala +// Cache a DataFrame +val df = spark.read.parquet("data.parquet") +df.cache() + +// Use cached data +df.filter("age > 30").count() + +// Uncache when done +df.unpersist() +``` + +## Compression + +Arrow cache supports multiple compression codecs. Configure compression with: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") +``` + +Available options: +- `none` - No compression (fastest, largest size) +- `lz4` - LZ4 compression (fast, good compression) +- `zstd` - Zstandard compression (slower, best compression, **default**) + +For zstd, you can also configure the compression level: + +```scala +spark.conf.set("spark.sql.execution.arrow.compression.level", "3") // Default: 3, Range: 1-22 +``` + +## Vectorized Reader + +Enable vectorized reading for better performance with primitive types: + +```scala +spark.conf.set("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") +``` + +When enabled, cached data is read as columnar batches instead of rows, which can significantly improve performance for columnar operations. + +## Performance Characteristics + +In our benchmarks, the Arrow cache format performs best on the following workloads. Actual +results depend on data types, compression settings, and hardware, and the default cache format +can be faster in some cases (for example, with higher compression levels): + +1. **Filter-Heavy Workloads**: Queries with selective filters benefit from min/max statistics. +2. **Columnar Operations**: Aggregations and projections on cached data benefit from the Arrow format. +3. **Parquet/ORC Caching**: Arrow's batch processing helps even without the zero-copy path. +4. **Re-caching with Column Projection**: Dropping columns from Arrow-cached data preserves the + `ArrowColumnVector` format, enabling true zero-copy extraction and the largest gains. + +### Benchmark Results + +The numbers below are illustrative results from one run on an Apple M4 Max (OpenJDK 21.0.8) and +will vary with hardware, JDK, and compression settings. They are not a guarantee. For the +authoritative, regularly regenerated numbers, see +`sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt` and the `ArrowCacheBenchmark` suite. + +| Workload | Default Cache | Arrow Cache | Speedup | +|----------|--------------|-------------|---------| +| Write + Read (5M rows, 3 primitive columns) | 153.7 ns/row | 74.2 ns/row | **~2X faster** | +| Filter with stats (5M rows) | 100.1 ns/row | 70.8 ns/row | **~1.4X faster** | +| Columnar input from Parquet (2M rows, 3 primitive columns) | 195.3 ns/row | 113.1 ns/row | **~1.7X faster** | +| Re-cache with zero-copy (2M rows, 2 columns) | 123.3 ns/row | 38.5 ns/row | **~3.2X faster** | + +**Notes**: +- **Write + Read**: Significant improvement from efficient Arrow serialization and vectorized operations +- **Filter improvement**: Comes from min/max statistics enabling batch skipping during partition pruning +- **Parquet caching**: Shows improvement despite Spark's Parquet reader producing `OnHeapColumnVector`/`OffHeapColumnVector` rather than `ArrowColumnVector`, due to Arrow's efficient batch processing +- **Re-cache with zero-copy**: When caching a subset of columns from Arrow-cached data (e.g., `df.drop("column")`), the remaining columns preserve their `ArrowColumnVector` format, enabling true zero-copy extraction and achieving the best performance +- **Zero-copy benefits** only apply when input is already `ArrowColumnVector` (e.g., Python Arrow sources, re-caching Arrow cached data with column projection) + +## Supported Data Types + +Arrow cache supports all Spark SQL data types: + +### Primitive Types +- BooleanType +- ByteType, ShortType, IntegerType, LongType +- FloatType, DoubleType +- DecimalType (all precision/scale combinations) + +### Temporal Types +- DateType +- TimestampType +- TimestampNTZType + +### String and Binary +- StringType +- BinaryType + +### Complex Types +- ArrayType +- StructType +- MapType +- Nested combinations of the above + +## Statistics and Filter Pushdown + +Arrow cache automatically collects min/max statistics for the following types: +- Boolean +- Numeric types (Byte, Short, Int, Long, Float, Double) +- Decimal +- Date, Timestamp, and Timestamp without time zone (TIMESTAMP_NTZ) +- Time +- Year-month and day-time intervals +- String (using collation-aware comparison for collated strings) + +Other types (Binary, Variant, calendar intervals, and complex types such as +Array/Struct/Map) are cached but do not contribute min/max bounds, so they only +record null counts and sizes. + +These statistics enable partition pruning when filtering: + +```scala +val df = spark.range(10000000).cache() + +// This filter can skip batches using min/max statistics +df.filter("id > 5000000").count() +``` + +## Memory Management + +Arrow cache uses off-heap memory managed by Apache Arrow allocators. This is a fundamental design choice in Apache Arrow and is not configurable for on-heap memory. + +**Memory Efficiency**: +- Despite requiring off-heap memory, Arrow cache is often **more memory-efficient** than default cache: + - Efficient compression with zstd/lz4 codecs + - Compact columnar format without Java object overhead + - Better compression ratios, especially for strings and complex types +- If you have limited off-heap memory, increase `spark.executor.memoryOverhead` to allocate more off-heap memory + +**Memory Cleanup**: +Arrow memory is automatically cleaned up when: +- Tasks complete +- DataFrames are unpersisted +- SparkSession is stopped + +You can monitor Arrow memory usage through Spark metrics and the Spark UI. + +## Limitations and Considerations + +1. **Static Configuration**: Cache serializer must be set before SparkSession creation +2. **Memory Overhead**: Arrow format has small per-batch overhead +3. **Compatibility**: Cannot mix cache formats - recache needed when switching +4. **Compression Trade-off**: Higher compression = lower memory but slower reads + +## Migration from Default Cache + +To migrate from default cache to Arrow cache: + +1. **Stop your SparkSession** +2. **Uncache all DataFrames** (optional but recommended) +3. **Update SparkSession configuration**: + ```scala + val spark = SparkSession.builder() + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .getOrCreate() + ``` +4. **Recache your DataFrames** + +**Note**: Existing cached data will be invalidated when changing cache format. + +## Troubleshooting + +### Out of Memory Errors + +If you encounter OOM errors with Arrow cache: + +1. Reduce batch size: + ```scala + spark.conf.set("spark.sql.arrow.maxRecordsPerBatch", "5000") // Default: 10000 + ``` + +2. Enable compression: + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + ``` + +3. Reduce compression level: + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + ``` + +### Slow Performance + +If Arrow cache is slower than expected: + +1. Enable vectorized reader: + ```scala + spark.conf.set("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") + ``` + +2. Try different compression codec: + ```scala + spark.conf.set("spark.sql.execution.arrow.compression.codec", "lz4") // Faster than zstd + ``` + +3. Increase batch size (if memory allows): + ```scala + spark.conf.set("spark.sql.arrow.maxRecordsPerBatch", "20000") + ``` + +## Configuration Reference + +| Configuration | Default | Description | +|---------------|---------|-------------| +| `spark.sql.cache.serializer` | DefaultCachedBatchSerializer | Cache format serializer class | +| `spark.sql.execution.arrow.compression.codec` | `zstd` | Compression codec (none, lz4, zstd) | +| `spark.sql.execution.arrow.compression.level` | `3` | Zstd compression level (1-22) | +| `spark.sql.arrow.maxRecordsPerBatch` | `10000` | Maximum rows per Arrow batch | +| `spark.sql.inMemoryColumnarStorage.enableVectorizedReader` | `true` | Enable vectorized cache reading | + +## Example: Complete Application + +```scala +import org.apache.spark.sql.SparkSession + +object ArrowCacheExample { + def main(args: Array[String]): Unit = { + // Create SparkSession with Arrow cache + val spark = SparkSession.builder() + .appName("ArrowCacheExample") + .master("local[*]") + .config("spark.sql.cache.serializer", + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + .config("spark.sql.execution.arrow.compression.codec", "zstd") + .config("spark.sql.inMemoryColumnarStorage.enableVectorizedReader", "true") + .getOrCreate() + + try { + // Read columnar data source + val df = spark.read.parquet("large_dataset.parquet") + + // Cache with Arrow format + df.cache() + + // Queries benefit from zero-copy reads and statistics + val result1 = df.filter("age > 30").select("name", "age").count() + println(s"Filtered count: $result1") + + val result2 = df.groupBy("country").agg(sum("sales")).collect() + println(s"Aggregation result: ${result2.mkString(", ")}") + + // Uncache when done + df.unpersist() + + } finally { + spark.stop() + } + } +} +``` + +## Best Practices + +1. **Use with Columnar Sources**: Maximum benefit with Parquet/ORC +2. **Enable Statistics**: Let Arrow cache collect min/max for filter pushdown +3. **Monitor Memory**: Watch off-heap memory usage in production +4. **Test First**: Benchmark your workload before production deployment +5. **Compression**: Start with `lz4` for balanced performance +6. **Vectorization**: Enable vectorized reader for primitive-heavy workloads + +## Further Reading + +- [Apache Arrow Project](https://arrow.apache.org/) +- [Spark Caching Documentation](https://spark.apache.org/docs/latest/sql-performance-tuning.html#caching-data-in-memory) +- [Arrow IPC Format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index ebae89cbe5e0f..c7da8d11bf04a 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -32,6 +32,10 @@ memory usage and GC pressure. You can call `spark.catalog.uncacheTable("tableNam To list relations cached with an explicit name, use `spark.catalog.listCachedTables()`. Entries cached only via `Dataset.cache()` without a name are not included. +Spark supports two cache formats: +- **Default cache format**: The standard in-memory columnar cache (used by default). +- **Arrow cache format**: An Apache Arrow-based cache that can improve read performance for columnar workloads and enables Arrow ecosystem interoperability. See [Arrow Cache Format documentation](sql-arrow-cache-format.html) for details and configuration. + Configuration of in-memory caching can be done via `spark.conf.set` or by running `SET key=value` commands using SQL. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 1c1024fc0152e..3e5ba7eb21400 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -38,6 +38,50 @@ private[sql] object ArrowUtils { // todo: support more types. + /** + * Check if a Spark DataType is supported by Arrow. This recursively checks complex types + * (Array, Struct, Map). + * + * Note: This checks compatibility with toArrowField(), not toArrowType(). Types like + * GeometryType, GeographyType, and VariantType are not supported by toArrowType() (which only + * handles primitive Arrow types), but ARE supported by toArrowField() which converts them to + * Arrow Struct representations with metadata. Since Arrow cache uses toArrowField() via + * toArrowSchema() to create the schema, these types are supported. + */ + def isSupportedByArrow(dt: DataType): Boolean = { + dt match { + // Primitive types + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: StringType | BinaryType | NullType => + true + + // Decimal + case _: DecimalType => true + + // Temporal types + case DateType | TimestampType | TimestampNTZType | _: TimeType => true + + // Interval types + case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType => true + + // Complex types - recursively check element types + case ArrayType(elementType, _) => isSupportedByArrow(elementType) + case StructType(fields) => fields.forall(f => isSupportedByArrow(f.dataType)) + case MapType(keyType, valueType, _) => + isSupportedByArrow(keyType) && isSupportedByArrow(valueType) + + // Special types + // Note: These are not in toArrowType(), but are handled by toArrowField() + case udt: UserDefinedType[_] => isSupportedByArrow(udt.sqlType) + case _: GeometryType => true // Converted to Struct with srid + wkb fields + case _: GeographyType => true // Converted to Struct with srid + wkb fields + case _: VariantType => true // Converted to Struct with value + metadata fields + + // Unsupported types + case _ => false + } + } + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = TypeApiOps(dt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a21653a011b34..7fb32da7be253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4681,6 +4681,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ARROW_CACHE_PREFETCH_ENABLED = + buildConf("spark.sql.execution.arrow.cache.prefetch.enabled") + .doc("When true, Arrow cache read path prefetches and decompresses the next batch " + + "in a background thread while the current batch is being consumed. This can " + + "significantly improve read performance for compressed Arrow caches (e.g., ZSTD) " + + "by overlapping decompression with consumption. Increases memory usage by up to " + + "one additional batch worth of Arrow vectors.") + .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch") .doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " + @@ -8374,6 +8386,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def arrowPySparkUDFColumnarInputEnabled: Boolean = getConf(ARROW_PYSPARK_UDF_COLUMNAR_INPUT_ENABLED) + def arrowCachePrefetchEnabled: Boolean = getConf(ARROW_CACHE_PREFETCH_ENABLED) + def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int = getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 25114a93f04cc..66f211390125f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -156,7 +156,10 @@ object StaticSQLConf { "org.apache.spark.sql.columnar.CachedBatchSerializer. It will be used to " + "translate SQL data into a format that can more efficiently be cached. The underlying " + "API is subject to change so use with caution. Multiple classes cannot be specified. " + - "The class must have a no-arg constructor.") + "The class must have a no-arg constructor. Available implementations include: " + + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer (default) and " + + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer (Arrow format with " + + "zero-copy columnar reads and better Arrow ecosystem interoperability).") .version("3.1.0") .stringConf .createWithDefault("org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..db9250e8a6e91 --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-jdk21-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1811 1914 147 2.8 362.1 1.0X +Default cache - write + read (uncompressed) 1192 1193 1 4.2 238.3 1.5X +Arrow cache - write + read 1231 1255 33 4.1 246.3 1.5X +Arrow cache - write + read (zstd level -1) 1730 1746 23 2.9 346.0 1.0X +Arrow cache - write + read (zstd level 1) 1758 1763 8 2.8 351.5 1.0X +Arrow cache - write + read (zstd level 3) 1739 1755 23 2.9 347.7 1.0X + + +================================================================================================ +Cache with filter pushdown +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows + filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1556 1593 53 3.2 311.1 1.0X +Default cache - filter (uncompressed) 1270 1316 65 3.9 254.0 1.2X +Arrow cache - filter (with stats) 1341 1358 24 3.7 268.3 1.2X +Arrow cache - filter (zstd level -1) 1737 1757 29 2.9 347.3 0.9X +Arrow cache - filter (zstd level 1) 1839 1844 7 2.7 367.8 0.8X +Arrow cache - filter (zstd level 3) 1827 1829 4 2.7 365.3 0.9X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1498 1551 74 1.3 749.1 1.0X +Default cache - columnar input (uncompressed) 1263 1292 41 1.6 631.4 1.2X +Arrow cache - columnar input 1345 1345 1 1.5 672.4 1.1X +Arrow cache - columnar input (zstd level -1) 1655 1669 21 1.2 827.3 0.9X +Arrow cache - columnar input (zstd level 1) 1627 1674 67 1.2 813.3 0.9X +Arrow cache - columnar input (zstd level 3) 1657 1672 21 1.2 828.5 0.9X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 391 437 38 5.1 195.6 1.0X +Default cache - cache a cached DF (uncompressed) 201 223 31 10.0 100.3 2.0X +Arrow cache - cache a cached DF (zero-copy) 163 177 11 12.3 81.3 2.4X +Arrow cache - cache a cached DF (zstd level -1) 361 369 9 5.5 180.6 1.1X +Arrow cache - cache a cached DF (zstd level 1) 359 366 5 5.6 179.5 1.1X +Arrow cache - cache a cached DF (zstd level 3) 359 362 2 5.6 179.7 1.1X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 10720 10769 70 0.5 2143.9 1.0X +Default cache - select 1 of 20 (uncompressed) 4044 4120 107 1.2 808.8 2.7X +Arrow cache - select 1 of 20 4072 4222 213 1.2 814.4 2.6X +Arrow cache - select 1 of 20 (zstd level -1) 8822 8891 99 0.6 1764.3 1.2X +Arrow cache - select 1 of 20 (zstd level 1) 8937 8952 21 0.6 1787.4 1.2X +Arrow cache - select 1 of 20 (zstd level 3) 8888 8915 38 0.6 1777.7 1.2X + + + diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..f5fd776722f0d --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-jdk25-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1723 1784 86 2.9 344.5 1.0X +Default cache - write + read (uncompressed) 1082 1103 30 4.6 216.3 1.6X +Arrow cache - write + read 1109 1151 60 4.5 221.8 1.6X +Arrow cache - write + read (zstd level -1) 1709 1716 10 2.9 341.9 1.0X +Arrow cache - write + read (zstd level 1) 1684 1701 24 3.0 336.7 1.0X +Arrow cache - write + read (zstd level 3) 1744 1765 29 2.9 348.8 1.0X + + +================================================================================================ +Cache with filter pushdown +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows + filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1560 1590 42 3.2 312.1 1.0X +Default cache - filter (uncompressed) 1215 1264 69 4.1 242.9 1.3X +Arrow cache - filter (with stats) 1286 1300 19 3.9 257.3 1.2X +Arrow cache - filter (zstd level -1) 1693 1700 10 3.0 338.5 0.9X +Arrow cache - filter (zstd level 1) 1730 1742 16 2.9 346.1 0.9X +Arrow cache - filter (zstd level 3) 1750 1752 3 2.9 350.0 0.9X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1499 1552 76 1.3 749.3 1.0X +Default cache - columnar input (uncompressed) 1252 1261 12 1.6 626.0 1.2X +Arrow cache - columnar input 1215 1222 10 1.6 607.5 1.2X +Arrow cache - columnar input (zstd level -1) 1517 1521 7 1.3 758.3 1.0X +Arrow cache - columnar input (zstd level 1) 1499 1516 24 1.3 749.7 1.0X +Arrow cache - columnar input (zstd level 3) 1553 1561 13 1.3 776.3 1.0X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 444 451 4 4.5 222.1 1.0X +Default cache - cache a cached DF (uncompressed) 202 226 32 9.9 101.0 2.2X +Arrow cache - cache a cached DF (zero-copy) 160 176 21 12.5 79.8 2.8X +Arrow cache - cache a cached DF (zstd level -1) 359 378 20 5.6 179.6 1.2X +Arrow cache - cache a cached DF (zstd level 1) 356 365 16 5.6 178.1 1.2X +Arrow cache - cache a cached DF (zstd level 3) 355 360 6 5.6 177.6 1.3X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 10828 10856 40 0.5 2165.6 1.0X +Default cache - select 1 of 20 (uncompressed) 4124 4162 54 1.2 824.8 2.6X +Arrow cache - select 1 of 20 4110 4133 33 1.2 821.9 2.6X +Arrow cache - select 1 of 20 (zstd level -1) 8741 8756 20 0.6 1748.2 1.2X +Arrow cache - select 1 of 20 (zstd level 1) 8777 8816 55 0.6 1755.5 1.2X +Arrow cache - select 1 of 20 (zstd level 3) 8621 8625 4 0.6 1724.3 1.3X + + + diff --git a/sql/core/benchmarks/ArrowCacheBenchmark-results.txt b/sql/core/benchmarks/ArrowCacheBenchmark-results.txt new file mode 100644 index 0000000000000..4d34a8126f723 --- /dev/null +++ b/sql/core/benchmarks/ArrowCacheBenchmark-results.txt @@ -0,0 +1,85 @@ +================================================================================================ +Arrow Cache vs Default Cache +================================================================================================ + +================================================================================================ +Cache primitive types +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows with primitives: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +Default cache - write + read 1711 1828 166 2.9 342.2 1.0X +Default cache - write + read (uncompressed) 1182 1190 12 4.2 236.4 1.4X +Arrow cache - write + read 1158 1183 36 4.3 231.6 1.5X +Arrow cache - write + read (zstd level -1) 1714 1715 3 2.9 342.7 1.0X +Arrow cache - write + read (zstd level 1) 1721 1728 11 2.9 344.2 1.0X +Arrow cache - write + read (zstd level 3) 1729 1741 17 2.9 345.7 1.0X + + +================================================================================================ +Cache with filter pushdown +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows + filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Default cache - filter 1460 1478 26 3.4 291.9 1.0X +Default cache - filter (uncompressed) 1280 1343 90 3.9 256.0 1.1X +Arrow cache - filter (with stats) 1256 1259 4 4.0 251.1 1.2X +Arrow cache - filter (zstd level -1) 1714 1725 15 2.9 342.8 0.9X +Arrow cache - filter (zstd level 1) 1728 1743 21 2.9 345.6 0.8X +Arrow cache - filter (zstd level 3) 1773 1805 45 2.8 354.6 0.8X + + +================================================================================================ +Cache columnar input (Parquet) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 2M rows from Parquet: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - columnar input 1441 1488 66 1.4 720.6 1.0X +Default cache - columnar input (uncompressed) 1249 1255 8 1.6 624.6 1.2X +Arrow cache - columnar input 1252 1254 3 1.6 625.9 1.2X +Arrow cache - columnar input (zstd level -1) 1522 1526 5 1.3 761.1 0.9X +Arrow cache - columnar input (zstd level 1) 1534 1563 41 1.3 767.1 0.9X +Arrow cache - columnar input (zstd level 3) 1544 1569 35 1.3 772.2 0.9X + + +================================================================================================ +Re-cache Arrow cached data (zero-copy test) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Re-cache 2M rows (zero-copy): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Default cache - cache a cached DF 399 411 8 5.0 199.6 1.0X +Default cache - cache a cached DF (uncompressed) 203 220 20 9.9 101.3 2.0X +Arrow cache - cache a cached DF (zero-copy) 158 169 9 12.7 78.9 2.5X +Arrow cache - cache a cached DF (zstd level -1) 362 373 8 5.5 181.1 1.1X +Arrow cache - cache a cached DF (zstd level 1) 362 374 19 5.5 181.0 1.1X +Arrow cache - cache a cached DF (zstd level 3) 358 362 5 5.6 179.1 1.1X + + +================================================================================================ +Cache with column pruning (select 1 of 20 columns) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1015-azure +AMD EPYC 9V74 80-Core Processor +Cache 5M rows, select 1 column: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Default cache - select 1 of 20 columns 9691 9757 93 0.5 1938.3 1.0X +Default cache - select 1 of 20 (uncompressed) 4114 4225 157 1.2 822.8 2.4X +Arrow cache - select 1 of 20 4219 4228 13 1.2 843.9 2.3X +Arrow cache - select 1 of 20 (zstd level -1) 8774 8911 194 0.6 1754.7 1.1X +Arrow cache - select 1 of 20 (zstd level 1) 8957 8959 4 0.6 1791.3 1.1X +Arrow cache - select 1 of 20 (zstd level 3) 8810 8854 63 0.6 1761.9 1.1X + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala new file mode 100644 index 0000000000000..01391d0b74f42 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatch.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch + +/** + * A [[SimpleMetricsCachedBatch]] implementation that stores Arrow RecordBatch data + * in Apache Arrow IPC streaming format. + * + * The batch contains: + * - `numRows`: Number of rows in this batch + * - `arrowData`: Serialized Arrow RecordBatch in IPC streaming format (with optional compression) + * - `stats`: Per-column statistics for partition pruning (upperBound, lowerBound, nullCount, etc.) + * + * This format enables: + * - Zero-copy columnar reads when output is ColumnarBatch with ArrowColumnVector + * - Efficient interoperability with Arrow ecosystem + * - Off-heap memory management via Arrow allocators + * - Built-in compression support (zstd, lz4) at Arrow level + * + * @param numRows Number of rows in this cached batch + * @param arrowData Serialized Arrow RecordBatch in IPC streaming format + * @param stats Per-column statistics as InternalRow (5 fields per column: + * upperBound, lowerBound, nullCount, rowCount, sizeInBytes) + */ +case class ArrowCachedBatch( + numRows: Int, + arrowData: Array[Byte], + stats: InternalRow) extends SimpleMetricsCachedBatch { + + override def sizeInBytes: Long = arrowData.length +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala new file mode 100644 index 0000000000000..2c74cd6bff4fb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializer.scala @@ -0,0 +1,1358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * A [[CachedBatchSerializer]] that uses Apache Arrow as the cache format. + * + * This serializer: + * - Supports both row-based (InternalRow) and columnar (ColumnarBatch) input + * - Stores data in Arrow IPC streaming format with optional compression (zstd/lz4) + * - Enables zero-copy columnar reads when output is ColumnarBatch + * - Uses off-heap memory via Arrow allocators + * - Collects per-column statistics for partition pruning + * - Provides efficient interoperability with Arrow ecosystem + * + * Configuration options: + * - spark.sql.cache.serializer: Set to this class name to enable + * - spark.sql.execution.arrow.maxRecordsPerBatch: Max rows per cached batch + * - spark.sql.execution.arrow.compression.codec: Compression (none/zstd/lz4) + * - spark.sql.inMemoryColumnarStorage.enableVectorizedReader: Enable columnar output + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + // Check if all data types in the schema are supported by Arrow + schema.forall(attr => ArrowUtils.isSupportedByArrow(attr.dataType)) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { rowIterator => + new InternalRowToArrowCachedBatchIterator( + rowIterator, + schema, + sparkSchema, + maxRecordsPerBatch, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + // Capture config values on driver before RDD transformation + val sparkSchema = DataTypeUtils.fromAttributes(schema) + val timeZoneId = conf.sessionLocalTimeZone + val compressionCodecName = conf.arrowCompressionCodec + val compressionLevel = conf.arrowZstdCompressionLevel + + input.mapPartitionsInternal { batchIterator => + new ColumnarBatchToArrowCachedBatchIterator( + batchIterator, + schema, + sparkSchema, + timeZoneId, + compressionCodecName, + compressionLevel) + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = { + // Always support columnar output with Arrow + true + } + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = { + Option(Seq.fill(attributes.length)(classOf[ArrowColumnVector].getName)) + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + // Capture config on driver + val timeZoneId = conf.sessionLocalTimeZone + val prefetchEnabled = conf.arrowCachePrefetchEnabled + + input.mapPartitionsInternal { batchIterator => + val baseIter = new ArrowCachedBatchToColumnarBatchIterator( + batchIterator, + cacheSchema, + selectedSchema, + columnIndices, + timeZoneId) + if (prefetchEnabled) { + new ArrowPrefetchColumnarBatchIterator(baseIter) + } else { + baseIter + } + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val cacheSchema = DataTypeUtils.fromAttributes(cacheAttributes) + val selectedSchema = DataTypeUtils.fromAttributes(selectedAttributes) + val timeZoneId = conf.sessionLocalTimeZone + + // Calculate column indices for projection + val selectedIndices = selectedAttributes.map { attr => + cacheAttributes.indexWhere(_.exprId == attr.exprId) + }.toArray + + // Check if all selected types can use the fast path. + // Types not handled by ArrowColumnReader must use the fallback path. + val needsFallback = selectedSchema.fields.exists { f => + f.dataType match { + case _: ArrayType | _: StructType | _: MapType => true + case CalendarIntervalType | VariantType | NullType => true + case _: UserDefinedType[_] => true + case _ => false + } + } + + if (needsFallback) { + // Fall back to columnar-to-row conversion via ColumnarBatch for complex types. + // Use UnsafeProjection to convert ColumnarBatchRow to UnsafeRow. + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + .mapPartitionsInternal { batchIter => + val toUnsafe = org.apache.spark.sql.catalyst.expressions.UnsafeProjection.create( + selectedSchema) + batchIter.flatMap { batch => + val numRows = batch.numRows() + new Iterator[InternalRow] { + private var rowIdx = 0 + override def hasNext: Boolean = rowIdx < numRows + override def next(): InternalRow = { + val row = batch.getRow(rowIdx) + rowIdx += 1 + toUnsafe(row) + } + } + } + } + } else { + val prefetchEnabled = conf.arrowCachePrefetchEnabled + input.mapPartitionsInternal { batchIterator => + new ArrowCachedBatchToInternalRowIterator( + batchIterator, + cacheSchema, + selectedSchema, + selectedIndices, + timeZoneId, + prefetchEnabled) + } + } + } +} + +/** + * Companion object with shared utility methods for Arrow cache serialization. + */ +private object ArrowCachedBatchSerializer { + + // scalastyle:off caselocale + def createCompressionCodec( + codecName: String, + compressionLevel: Int): CompressionCodec = { + codecName.toLowerCase match { + case "none" => NoCompressionCodec.INSTANCE + case "zstd" => + val factory = CompressionCodec.Factory.INSTANCE + val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType() + factory.createCodec(codecType) + case "lz4" => + val factory = CompressionCodec.Factory.INSTANCE + val codecType = new Lz4CompressionCodec().getCodecType() + factory.createCodec(codecType) + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + } + // scalastyle:on caselocale + + def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + + def createColumnStats(dataType: DataType): ColumnStats = { + dataType match { + case BooleanType => new BooleanColumnStats + case ByteType => new ByteColumnStats + case ShortType => new ShortColumnStats + case IntegerType => new IntColumnStats + case DateType => new IntColumnStats // Date is stored as Int + case LongType => new LongColumnStats + case TimestampType => new LongColumnStats // Timestamp is stored as Long + case TimestampNTZType => new LongColumnStats // TimestampNTZ is stored as Long + case FloatType => new FloatColumnStats + case DoubleType => new DoubleColumnStats + case st: StringType => new StringColumnStats(st) + case BinaryType => new BinaryColumnStats + case dt: DecimalType => new DecimalColumnStats(dt) + case CalendarIntervalType => new IntervalColumnStats + case _: YearMonthIntervalType => new IntColumnStats // stored as Int + case _: DayTimeIntervalType => new LongColumnStats // stored as Long + case _: TimeType => new LongColumnStats // Time is stored as Long (nanoseconds) + case VariantType => new VariantColumnStats + case _ => new ObjectColumnStats(dataType) + } + } + + def buildStatisticsFromCollectors( + collectors: Array[ColumnStats], + schema: Seq[Attribute]): InternalRow = { + val stats = collectors.flatMap { collector => + val collected = collector.collectedStatistics + // ColumnStats returns: [lowerBound, upperBound, nullCount, count, sizeInBytes] + Seq(collected(0), collected(1), collected(2), collected(3), collected(4)) + } + InternalRow.fromSeq(stats.toSeq) + } + + def collectStatistics( + root: VectorSchemaRoot, + schema: Seq[Attribute]): InternalRow = { + val rowCount = root.getRowCount + val vectors = root.getFieldVectors.asScala.toSeq + + // Collect stats for each column: lowerBound, upperBound, nullCount, rowCount, sizeInBytes + val stats = schema.zip(vectors).flatMap { case (attr, vector) => + val nullCount = (0 until rowCount).count(i => vector.isNull(i)) + val sizeInBytes = vector.getBufferSize.toLong + + val (lower, upper) = attr.dataType match { + case BooleanType => calculateMinMaxBoolean(vector, rowCount) + case ByteType => calculateMinMaxByte(vector, rowCount) + case ShortType => calculateMinMaxShort(vector, rowCount) + case IntegerType => calculateMinMaxInt(vector, rowCount) + case DateType => calculateMinMaxDate(vector, rowCount) + case LongType => calculateMinMaxLong(vector, rowCount) + case TimestampType => calculateMinMaxTimestamp(vector, rowCount) + case TimestampNTZType => calculateMinMaxTimestampNTZ(vector, rowCount) + case FloatType => calculateMinMaxFloat(vector, rowCount) + case DoubleType => calculateMinMaxDouble(vector, rowCount) + case st: StringType => calculateMinMaxString(vector, rowCount, st.collationId) + case _: DecimalType => calculateMinMaxDecimal(vector, rowCount, attr.dataType) + case _: YearMonthIntervalType => calculateMinMaxYearMonthInterval(vector, rowCount) + case _: DayTimeIntervalType => calculateMinMaxDayTimeInterval(vector, rowCount) + case _: TimeType => calculateMinMaxTime(vector, rowCount) + case _ => (null, null) // Skip for binary, complex, and other unsupported types + } + + Seq(lower, upper, nullCount, rowCount, sizeInBytes) + } + + new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(stats.toArray) + } + + def calculateMinMaxBoolean( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = true + var max = false + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BitVector].get(i) != 0 + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxByte( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Byte.MaxValue + var max = Byte.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TinyIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxShort( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Short.MaxValue + var max = Short.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.SmallIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxInt( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDate( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.DateDayVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxLong( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.BigIntVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestamp( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroTZVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTimestampNTZ( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = + vector.asInstanceOf[org.apache.arrow.vector.TimeStampMicroVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxFloat( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Float.MaxValue + var max = Float.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float4Vector].get(i) + // Skip NaN: IEEE 754 comparisons with NaN are always false, so NaN never + // updates min/max in the row-based path (FloatColumnStats.gatherValueStats). + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDouble( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Double.MaxValue + var max = Double.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.Float8Vector].get(i) + // Skip NaN to match DoubleColumnStats.gatherValueStats. + if (!value.isNaN) { + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxString( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + collationId: Int = StringType.collationId): (Any, Any) = { + var min: org.apache.spark.unsafe.types.UTF8String = null + var max: org.apache.spark.unsafe.types.UTF8String = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bytes = vector.asInstanceOf[org.apache.arrow.vector.VarCharVector].get(i) + val value = org.apache.spark.unsafe.types.UTF8String.fromBytes(bytes) + if (!hasValue) { + min = value.clone() + max = value.clone() + hasValue = true + } else { + if (value.semanticCompare(min, collationId) < 0) min = value.clone() + if (value.semanticCompare(max, collationId) > 0) max = value.clone() + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDecimal( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int, + dataType: org.apache.spark.sql.types.DataType): (Any, Any) = { + val decimalType = dataType.asInstanceOf[DecimalType] + var min: org.apache.spark.sql.types.Decimal = null + var max: org.apache.spark.sql.types.Decimal = null + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val bigDecimal = vector.asInstanceOf[ + org.apache.arrow.vector.DecimalVector].getObject(i) + val value = org.apache.spark.sql.types.Decimal( + bigDecimal, decimalType.precision, decimalType.scale) + + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value.compareTo(min) < 0) min = value + if (value.compareTo(max) > 0) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxYearMonthInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Int.MaxValue + var max = Int.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.IntervalYearVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxDayTimeInterval( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = org.apache.arrow.vector.DurationVector.get( + vector.asInstanceOf[org.apache.arrow.vector.DurationVector].getDataBuffer, i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } + + def calculateMinMaxTime( + vector: org.apache.arrow.vector.FieldVector, + rowCount: Int): (Any, Any) = { + var min = Long.MaxValue + var max = Long.MinValue + var hasValue = false + + (0 until rowCount).foreach { i => + if (!vector.isNull(i)) { + val value = vector.asInstanceOf[org.apache.arrow.vector.TimeNanoVector].get(i) + if (!hasValue) { + min = value + max = value + hasValue = true + } else { + if (value < min) min = value + if (value > max) max = value + } + } + } + + if (hasValue) (min, max) else (null, null) + } +} + +/** + * Iterator that converts InternalRow to ArrowCachedBatch. + */ +private class InternalRowToArrowCachedBatchIterator( + rowIter: Iterator[InternalRow], + schema: Seq[Attribute], + sparkSchema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"InternalRowToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + private val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Create statistics collectors for each column + private val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + close() + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close() + false + } + + override def next(): ArrowCachedBatch = { + var rowCount = 0 + + // Reset statistics collectors for new batch + var idx = 0 + while (idx < statsCollectors.length) { + statsCollectors(idx) = ArrowCachedBatchSerializer.createColumnStats(schema(idx).dataType) + idx += 1 + } + + Utils.tryWithSafeFinally { + // Write rows to Arrow vectors and collect statistics incrementally + while (rowIter.hasNext && rowCount < maxRecordsPerBatch) { + val row = rowIter.next() + arrowWriter.write(row) + + // Collect statistics for this row + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + + rowCount += 1 + } + arrowWriter.finish() + + // Get the Arrow RecordBatch with compression + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + // Serialize to Arrow IPC format + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + + // Build statistics InternalRow from collected stats + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + } + } + + private def close(): Unit = { + root.close() + allocator.close() + } +} + +/** + * Iterator that converts ColumnarBatch to ArrowCachedBatch. + */ +private class ColumnarBatchToArrowCachedBatchIterator( + batchIter: Iterator[ColumnarBatch], + schema: Seq[Attribute], + sparkSchema: StructType, + timeZoneId: String, + compressionCodecName: String, + compressionLevel: Int) extends Iterator[ArrowCachedBatch] { + + private val compressionCodec = ArrowCachedBatchSerializer.createCompressionCodec( + compressionCodecName, + compressionLevel) + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ColumnarBatchToArrowCachedBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, false, false) + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ArrowCachedBatch = { + val batch = batchIter.next() + val rowCount = batch.numRows() + + // Check if batch is already Arrow-based for zero-copy path + val vectors = (0 until batch.numCols()).map(batch.column) + if (vectors.forall(_.isInstanceOf[ArrowColumnVector])) { + // Fast path: zero-copy extraction of Arrow RecordBatch + convertArrowBatchZeroCopy(batch, rowCount, schema, vectors) + } else { + // Slow path: convert to Arrow via rows + convertToArrowBatch(batch, rowCount, schema) + } + } + + private def convertArrowBatchZeroCopy( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute], + vectors: Seq[ColumnVector]): ArrowCachedBatch = { + // Zero-copy path: extract Arrow vectors directly from ArrowColumnVector + val arrowVectors = vectors.map( + _.asInstanceOf[ArrowColumnVector].getValueVector.asInstanceOf[ + org.apache.arrow.vector.FieldVector]) + + // Create a VectorSchemaRoot from the existing vectors + val root = new VectorSchemaRoot(arrowSchema, arrowVectors.asJava, rowCount) + + Utils.tryWithSafeFinally { + // Use VectorUnloader to create compressed RecordBatch + val unloader = new VectorUnloader(root, true, compressionCodec, true) + val recordBatch = unloader.getRecordBatch() + + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + // Note: We don't close the root here because we don't own the vectors + // They are owned by the input ColumnarBatch + } + } + + private def convertToArrowBatch( + batch: ColumnarBatch, + rowCount: Int, + schema: Seq[Attribute]): ArrowCachedBatch = { + // Convert columnar batch to rows, then to Arrow + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + val unloader = new VectorUnloader(root, true, compressionCodec, true) + + // Collect statistics inline during row iteration, same as InternalRowToArrow path + val statsCollectors: Array[ColumnStats] = schema.map { attr => + ArrowCachedBatchSerializer.createColumnStats(attr.dataType) + }.toArray + + Utils.tryWithSafeFinally { + val rowIterator = batch.rowIterator().asScala + while (rowIterator.hasNext) { + val row = rowIterator.next() + arrowWriter.write(row) + + // Collect statistics for this row inline + var i = 0 + while (i < statsCollectors.length) { + statsCollectors(i).gatherStats(row, i) + i += 1 + } + } + arrowWriter.finish() + + val recordBatch = unloader.getRecordBatch() + Utils.tryWithSafeFinally { + val arrowData = ArrowCachedBatchSerializer.serializeBatch(recordBatch) + val stats = ArrowCachedBatchSerializer.buildStatisticsFromCollectors( + statsCollectors, schema) + ArrowCachedBatch(rowCount, arrowData, stats) + } { + recordBatch.close() + } + } { + arrowWriter.reset() + root.close() + } + } +} + +/** + * Iterator that converts ArrowCachedBatch to ColumnarBatch. + */ +private class ArrowCachedBatchToColumnarBatchIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String) extends Iterator[ColumnarBatch] { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToColumnarBatchIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + // Track only the previous root to close it when next batch is produced + private var previousRoot: VectorSchemaRoot = null + + // Register cleanup - close remaining root and allocator when task completes + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + allocator.close() + } + } + + override def hasNext: Boolean = batchIter.hasNext + + override def next(): ColumnarBatch = { + // Close the previous root since it's been consumed + if (previousRoot != null) { + previousRoot.close() + previousRoot = null + } + + val cachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + + // Deserialize Arrow IPC data + val arrowData = cachedBatch.arrowData + val in = new ByteArrayInputStream(arrowData) + val readChannel = new ReadChannel(Channels.newChannel(in)) + + // Deserialize the RecordBatch + val recordBatch = MessageSerializer.deserializeRecordBatch(readChannel, allocator) + + Utils.tryWithSafeFinally { + // Create root and load batch + val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + // Track this root as the current/previous root + previousRoot = root + + val loader = new VectorLoader(root) + loader.load(recordBatch) + + // Wrap vectors in ArrowColumnVector and project to selected columns + val allColumns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + + val selectedColumns = columnIndices.map(allColumns(_)) + + new ColumnarBatch(selectedColumns, cachedBatch.numRows) + } { + recordBatch.close() + } + } +} + +/** + * A typed column reader that reads from an Arrow FieldVector and writes directly + * to an UnsafeRowWriter, avoiding per-row pattern matching overhead. + */ +private abstract class ArrowColumnReader { + def vector: org.apache.arrow.vector.FieldVector + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit + def setVector(v: org.apache.arrow.vector.FieldVector): Unit +} + +private object ArrowColumnReader { + import org.apache.arrow.vector._ + + def create(dataType: DataType): ArrowColumnReader = dataType match { + case BooleanType => new ArrowColumnReader { + private var _vector: BitVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[BitVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex) != 0) + } + case ByteType => new ArrowColumnReader { + private var _vector: TinyIntVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[TinyIntVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case ShortType => new ArrowColumnReader { + private var _vector: SmallIntVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[SmallIntVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case IntegerType | DateType | _: YearMonthIntervalType => new ArrowColumnReader { + private var _vector: FieldVector = _ + // Pre-bind accessor at setVector time to avoid per-row pattern match + private var _accessor: Int => Int = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v + _accessor = v match { + case iv: IntVector => iv.get + case dv: DateDayVector => dv.get + case iv: org.apache.arrow.vector.IntervalYearVector => iv.get + } + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _accessor(rowIndex)) + } + case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => + new ArrowColumnReader { + private var _vector: FieldVector = _ + private var _accessor: Int => Long = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v + _accessor = v match { + case bv: BigIntVector => bv.get(_) + case tv: TimeStampMicroTZVector => tv.get(_) + case tv: TimeStampMicroVector => tv.get(_) + case dv: org.apache.arrow.vector.DurationVector => + i => org.apache.arrow.vector.DurationVector.get(dv.getDataBuffer, i) + case tv: org.apache.arrow.vector.TimeNanoVector => tv.get(_) + } + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _accessor(rowIndex)) + } + case FloatType => new ArrowColumnReader { + private var _vector: Float4Vector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[Float4Vector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case DoubleType => new ArrowColumnReader { + private var _vector: Float8Vector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[Float8Vector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case _: StringType => new ArrowColumnReader { + private var _vector: VarCharVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[VarCharVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val bytes = _vector.get(rowIndex) + writer.write(ordinal, UTF8String.fromBytes(bytes)) + } + } + case BinaryType => new ArrowColumnReader { + private var _vector: VarBinaryVector = _ + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[VarBinaryVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = + writer.write(ordinal, _vector.get(rowIndex)) + } + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + // Fast path for compact decimals (precision <= 18): + // Read the unscaled long directly from the Arrow buffer, zero allocation. + // Arrow stores Decimal as 128-bit little-endian integer in 16 bytes. + // For compact decimals, the value fits in the lower 8 bytes. + new ArrowColumnReader { + private var _vector: DecimalVector = _ + private var _dataBuffer: org.apache.arrow.memory.ArrowBuf = _ + private val typeWidth = DecimalVector.TYPE_WIDTH // 16 bytes + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = { + _vector = v.asInstanceOf[DecimalVector] + _dataBuffer = _vector.getDataBuffer + } + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val startIndex = rowIndex.toLong * typeWidth + val unscaledLong = _dataBuffer.getLong(startIndex) + writer.write(ordinal, unscaledLong) + } + } + case dt: DecimalType => new ArrowColumnReader { + // Slow path for wide decimals (precision > 18): must go through BigDecimal + private var _vector: DecimalVector = _ + private val precision = dt.precision + private val scale = dt.scale + def vector: FieldVector = _vector + def setVector(v: FieldVector): Unit = _vector = v.asInstanceOf[DecimalVector] + def read(rowIndex: Int, ordinal: Int, writer: UnsafeRowWriter): Unit = { + val decimal = Decimal(_vector.getObject(rowIndex), precision, scale) + writer.write(ordinal, decimal, precision, scale) + } + } + case _ => + throw new UnsupportedOperationException( + s"Complex type $dataType is handled by the fallback path") + } +} + +/** + * Fast-path iterator that converts ArrowCachedBatch to InternalRow. + * Uses pre-built typed column readers to avoid per-row pattern matching, + * and writes directly to UnsafeRowWriter to avoid intermediate SpecificInternalRow. + * Only used for schemas without complex types (Array/Struct/Map). + */ +private class ArrowCachedBatchToInternalRowIterator( + batchIter: Iterator[CachedBatch], + cacheSchema: StructType, + selectedSchema: StructType, + columnIndices: Array[Int], + timeZoneId: String, + prefetchEnabled: Boolean = false) extends Iterator[InternalRow] { + + import java.util.concurrent.{Callable, ExecutionException, Future, Executors, + ExecutorService} + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"ArrowCachedBatchToInternalRowIterator-${TaskContext.get().taskAttemptId()}", + 0, + Long.MaxValue) + + private var currentRoot: VectorSchemaRoot = null + private var currentRowIndex: Int = 0 + private var currentRowCount: Int = 0 + + private val numFields = selectedSchema.length + private val arrowSchema = ArrowUtils.toArrowSchema(cacheSchema, timeZoneId, false, false) + + // Pre-build typed readers per column at init time -- no per-row pattern match + private val columnReaders: Array[ArrowColumnReader] = + selectedSchema.fields.map(f => ArrowColumnReader.create(f.dataType)) + + // Write directly to UnsafeRow -- no intermediate SpecificInternalRow + UnsafeProjection + private val rowWriter = new UnsafeRowWriter(numFields) + + // Prefetch support: deserialize the next batch in background while current batch is consumed + private val prefetchExecutor: ExecutorService = if (prefetchEnabled) { + Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-row-prefetch") + t.setDaemon(true) + t + }) + } else { + null + } + private var prefetchFuture: Future[VectorSchemaRoot] = _ + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + if (prefetchFuture != null) { + prefetchFuture.cancel(true) + prefetchFuture = null + } + if (prefetchExecutor != null) { + prefetchExecutor.shutdownNow() + } + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + allocator.close() + } + } + + override def hasNext: Boolean = { + if (currentRowIndex < currentRowCount) { + true + } else if (prefetchFuture != null || batchIter.hasNext) { + loadNextBatch() + currentRowIndex < currentRowCount + } else { + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + false + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("No more rows") + } + + rowWriter.reset() + rowWriter.zeroOutNullBytes() + + val rowIdx = currentRowIndex + var i = 0 + while (i < numFields) { + val reader = columnReaders(i) + if (reader.vector.isNull(rowIdx)) { + rowWriter.setNullAt(i) + } else { + reader.read(rowIdx, i, rowWriter) + } + i += 1 + } + + currentRowIndex += 1 + rowWriter.getRow() + } + + /** Deserialize a cached batch into a VectorSchemaRoot. */ + private def deserializeBatch(cachedBatch: ArrowCachedBatch): VectorSchemaRoot = { + val in = new ByteArrayInputStream(cachedBatch.arrowData) + val readChannel = new ReadChannel(Channels.newChannel(in)) + val recordBatch = MessageSerializer.deserializeRecordBatch(readChannel, allocator) + try { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val loader = new VectorLoader(root) + loader.load(recordBatch) + root + } finally { + recordBatch.close() + } + } + + /** Submit prefetch for the next batch if available. */ + private def submitPrefetch(): Unit = { + if (prefetchEnabled && batchIter.hasNext) { + val nextCachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + prefetchFuture = prefetchExecutor.submit(new Callable[VectorSchemaRoot] { + override def call(): VectorSchemaRoot = deserializeBatch(nextCachedBatch) + }) + } + } + + private def loadNextBatch(): Unit = { + if (currentRoot != null) { + currentRoot.close() + currentRoot = null + } + + val root = if (prefetchFuture != null) { + // Use the prefetched result + val r = try { + prefetchFuture.get() + } catch { + case e: ExecutionException => throw e.getCause + } + prefetchFuture = null + r + } else { + // No prefetch available, deserialize synchronously + val cachedBatch = batchIter.next().asInstanceOf[ArrowCachedBatch] + deserializeBatch(cachedBatch) + } + + currentRoot = root + + // Update pre-built readers with new vectors + var i = 0 + while (i < numFields) { + columnReaders(i).setVector(root.getVector(columnIndices(i))) + i += 1 + } + + currentRowIndex = 0 + currentRowCount = root.getRowCount + + // Start prefetching the next batch while this one is being consumed + submitPrefetch() + } +} + +/** + * Wraps an ArrowCachedBatchToColumnarBatchIterator with background prefetching. + * While the current ColumnarBatch is being consumed, the next batch is deserialized + * and decompressed in a background thread. This overlaps decompression with consumption + * and is most beneficial for compressed Arrow caches (e.g. ZSTD). + * + * Uses a single-thread executor to avoid per-batch thread creation overhead. + * + * Enabled via spark.sql.execution.arrow.cache.prefetch.enabled=true. + */ +private class ArrowPrefetchColumnarBatchIterator( + underlying: ArrowCachedBatchToColumnarBatchIterator) extends Iterator[ColumnarBatch] { + + import java.util.concurrent.{Callable, ExecutionException, Future, Executors} + + private val executor = Executors.newSingleThreadExecutor(r => { + val t = new Thread(r, "arrow-cache-prefetch") + t.setDaemon(true) + t + }) + + // The prefetched result (null means no more batches) + private var prefetchFuture: Future[ColumnarBatch] = _ + + // Register cleanup + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => + executor.shutdownNow() + } + } + + // Kick off prefetch of the first batch immediately + submitPrefetch() + + override def hasNext: Boolean = prefetchFuture != null + + override def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException("No more batches") + } + + // Wait for the prefetched batch + val batch = try { + prefetchFuture.get() + } catch { + case e: ExecutionException => throw e.getCause + } + + // Start prefetching the next batch + submitPrefetch() + + batch + } + + private def submitPrefetch(): Unit = { + if (underlying.hasNext) { + prefetchFuture = executor.submit(new Callable[ColumnarBatch] { + override def call(): ColumnarBatch = underlying.next() + }) + } else { + prefetchFuture = null + executor.shutdown() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 4e4b3667fa24f..eb17b022f890c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap, ColumnarRow} import org.apache.spark.unsafe.types.UTF8String class ColumnStatisticsSchema(a: Attribute) extends Serializable { @@ -358,8 +359,28 @@ private[columnar] final class ObjectColumnStats(dataType: DataType) extends Colu override def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { - val size = columnType.actualSize(row, ordinal) - sizeInBytes += size + // Check if this is a columnar complex type that doesn't support getSizeInBytes + val isColumnarComplexType = columnType match { + case _: ARRAY => + row.getArray(ordinal).isInstanceOf[ColumnarArray] + case _: MAP => + row.getMap(ordinal).isInstanceOf[ColumnarMap] + case struct: STRUCT => + row.getStruct(ordinal, struct.dataType.fields.length).isInstanceOf[ColumnarRow] + case _ => + false + } + + if (!isColumnarComplexType) { + // Normal path: calculate size for unsafe types + // (UnsafeArrayData/UnsafeMapData/UnsafeRow) + val size = columnType.actualSize(row, ordinal) + sizeInBytes += size + } + // else: Skip size calculation for columnar complex types + // (ColumnarArray/ColumnarMap/ColumnarRow). These are views into ColumnVectors + // and don't expose getSizeInBytes() + count += 1 } else { gatherNullStats() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 9c012dbd58e12..9bc5e803c0dd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -365,7 +365,7 @@ object InMemoryRelation { } /* Visible for testing */ - private[columnar] def clearSerializer(): Unit = synchronized { ser = None } + private[sql] def clearSerializer(): Unit = synchronized { ser = None } def apply( storageLevel: StorageLevel, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala new file mode 100644 index 0000000000000..55054ebcf5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ArrowCacheBenchmark.scala @@ -0,0 +1,805 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.internal.config.UI.UI_ENABLED +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +/** + * Benchmark to measure cache performance with Arrow format vs Default format. + * + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/Test/runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain " + * Results will be written to "benchmarks/ArrowCacheBenchmark-results.txt". + * }}} + */ +object ArrowCacheBenchmark extends SqlBasedBenchmark { + + // Do NOT access the inherited `spark` session - it uses default serializer + // Instead, create fresh sessions for each benchmark + + // Create separate sessions for each cache format since SPARK_CACHE_SERIALIZER is static + // CRITICAL: Can only have one active SparkContext at a time + private def createFreshSession(serializer: String): SparkSession = { + // Stop any existing session and clear the registry + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + + // CRITICAL: Clear the cached serializer instance in InMemoryRelation + // This singleton is stored statically and persists across sessions + org.apache.spark.sql.execution.columnar.InMemoryRelation.clearSerializer() + + SparkSession.builder() + .master("local[1]") + .appName(s"ArrowCacheBenchmark-$serializer") + .config(SQLConf.SHUFFLE_PARTITIONS.key, 1) + .config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1) + .config(UI_ENABLED.key, false) + .config(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, serializer) + .getOrCreate() + } + + private def cachePrimitiveTypes(): Unit = { + val numRows = 5000000 // 5M rows for faster benchmarking + runBenchmark("Cache primitive types") { + val benchmark = new Benchmark("Cache 5M rows with primitives", numRows, output = output) + + // Run Default cache benchmark (with compression - default) + benchmark.addCase("Default cache - write + read") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Default cache without compression + benchmark.addCase("Default cache - write + read (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache benchmark + benchmark.addCase("Arrow cache - write + read") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // NOTE: LZ4 compression benchmarks are commented out because Arrow's LZ4 implementation + // requires the optional lz4-java native library dependency. Without it, Arrow falls back + // to Apache Commons Compress pure-Java LZ4 implementation which is extremely slow + // (~50x slower than zstd). To enable fast LZ4 benchmarks, add this dependency to pom.xml: + // + // org.lz4 + // lz4-java + // 1.8.0 + // + + // // Run Arrow cache with lz4 compression benchmark + // benchmark.addCase("Arrow cache - write + read (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set("spark.sql.execution.arrow.compression.codec", "lz4") + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "id * 2L as long_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + // Run Arrow cache with zstd level -1 (fastest) compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "-1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd level 1 compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd level 3 (default) compression benchmark + benchmark.addCase("Arrow cache - write + read (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def cacheWithFilters(): Unit = { + val numRows = 5000000 // 5M rows + runBenchmark("Cache with filter pushdown") { + val benchmark = new Benchmark("Cache 5M rows + filter", numRows, output = output) + + // Default cache filter benchmark (with compression - default) + benchmark.addCase("Default cache - filter") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Default cache filter without compression + benchmark.addCase("Default cache - filter (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter benchmark + benchmark.addCase("Arrow cache - filter (with stats)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // // Arrow cache filter with lz4 compression + // benchmark.addCase("Arrow cache - filter (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set("spark.sql.execution.arrow.compression.codec", "lz4") + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() // Materialize + // df.filter("int_col > 2500000").count() + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + // Arrow cache filter with zstd level -1 (fastest) + benchmark.addCase("Arrow cache - filter (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "-1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter with zstd level 1 + benchmark.addCase("Arrow cache - filter (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Arrow cache filter with zstd level 3 + benchmark.addCase("Arrow cache - filter (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + val df = spark.range(numRows).selectExpr( + "id as int_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + df.filter("int_col > 2500000").count() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def cacheColumnarInput(): Unit = { + val numRows = 2000000 // 2M rows + withTempPath { dir => + val path = dir.getAbsolutePath + + // Write parquet file using a temporary session + val tempSpark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + tempSpark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ).write.parquet(path) + } finally { + tempSpark.stop() + } + + runBenchmark("Cache columnar input (Parquet)") { + val benchmark = new Benchmark("Cache 2M rows from Parquet", numRows, output = output) + + benchmark.addCase("Default cache - columnar input") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Default cache - columnar input (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // benchmark.addCase("Arrow cache - columnar input (lz4)") { _ => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set("spark.sql.execution.arrow.compression.codec", "lz4") + // val parquet = spark.read.parquet(path) + // parquet.cache() + // parquet.write.format("noop").mode("overwrite").save() // Force read all data + // parquet.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + benchmark.addCase("Arrow cache - columnar input (zstd level -1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "-1") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input (zstd level 1)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addCase("Arrow cache - columnar input (zstd level 3)") { _ => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + val parquet = spark.read.parquet(path) + parquet.cache() + parquet.write.format("noop").mode("overwrite").save() // Force read all data + parquet.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + } + + private def recacheArrowData(): Unit = { + val numRows = 2000000 // 2M rows + runBenchmark("Re-cache Arrow cached data (zero-copy test)") { + val benchmark = new Benchmark("Re-cache 2M rows (zero-copy)", numRows, output = output) + + benchmark.addTimerCase("Default cache - cache a cached DF") { timer => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Default cache - cache a cached DF (uncompressed)") { timer => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zero-copy)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + // Drop a column to create a different logical plan + // This preserves ArrowColumnVector for remaining columns, enabling zero-copy + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // benchmark.addTimerCase("Arrow cache - cache a cached DF (lz4)") { timer => + // val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + // try { + // spark.conf.set("spark.sql.execution.arrow.compression.codec", "lz4") + // // Create and cache initial data (NOT timed) + // val df = spark.range(numRows).selectExpr( + // "id as int_col", + // "id * 2L as long_col", + // "cast(id as double) as double_col" + // ) + // df.cache() + // df.write.format("noop").mode("overwrite").save() // Materialize + + // // START TIMING: Cache the cached DataFrame again + // val df2 = df.drop("double_col") + // timer.startTiming() + // df2.cache() + // df2.write.format("noop").mode("overwrite").save() // Force read all data + // timer.stopTiming() + + // df2.unpersist(blocking = true) + // df.unpersist(blocking = true) + // } finally { + // spark.stop() + // } + // } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level -1)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "-1") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level 1)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.addTimerCase("Arrow cache - cache a cached DF (zstd level 3)") { timer => + val spark = createFreshSession(classOf[ArrowCachedBatchSerializer].getName) + try { + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + // Create and cache initial data (NOT timed) + val df = spark.range(numRows).selectExpr( + "id as int_col", + "id * 2L as long_col", + "cast(id as double) as double_col" + ) + df.cache() + df.write.format("noop").mode("overwrite").save() // Materialize cache by reading all rows + + // START TIMING: Cache the cached DataFrame again + val df2 = df.drop("double_col") + timer.startTiming() + df2.cache() + df2.write.format("noop").mode("overwrite").save() // Force read all data + timer.stopTiming() + + df2.unpersist(blocking = true) + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + private def columnPruning(): Unit = { + val numRows = 5000000 // 5M rows + runBenchmark("Cache with column pruning (select 1 of 20 columns)") { + val benchmark = new Benchmark( + "Cache 5M rows, select 1 column", numRows, output = output) + + // Run Default cache benchmark (with compression - default) + benchmark.addCase("Default cache - select 1 of 20 columns") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + // Create DataFrame with 20 columns + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() // Materialize cache + + // Select only first column and count + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Default cache without compression + benchmark.addCase("Default cache - select 1 of 20 (uncompressed)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer") + try { + spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", "false") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache (no compression) + benchmark.addCase("Arrow cache - select 1 of 20") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level -1 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level -1)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "-1") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level 1 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level 1)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "1") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + // Run Arrow cache with zstd compression level 3 + benchmark.addCase("Arrow cache - select 1 of 20 (zstd level 3)") { _ => + val spark = createFreshSession( + "org.apache.spark.sql.execution.columnar.ArrowCachedBatchSerializer") + try { + spark.conf.set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set("spark.sql.execution.arrow.compression.codec", "zstd") + spark.conf.set("spark.sql.execution.arrow.compression.level", "3") + val df = spark.range(numRows).selectExpr( + (0 until 20).map(i => s"id + $i as col$i"): _* + ) + df.cache() + df.count() + df.select("col0").write.format("noop").mode("overwrite").save() + df.unpersist(blocking = true) + } finally { + spark.stop() + } + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Arrow Cache vs Default Cache") { + cachePrimitiveTypes() + cacheWithFilters() + cacheColumnarInput() + recacheArrowData() + columnPruning() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala new file mode 100644 index 0000000000000..e89ab83c764bc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ArrowCachedBatchSerializerSuite.scala @@ -0,0 +1,1835 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.sql.{Date, Timestamp} +import java.time.{Duration, LocalDateTime, LocalTime, Period} + +import org.apache.arrow.vector.{ + BigIntVector, BitVector, DateDayVector, DecimalVector, + Float4Vector, Float8Vector, IntVector, SmallIntVector, + TimeNanoVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, + VarBinaryVector, VarCharVector, VectorSchemaRoot} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.CalendarInterval + +/** UDT whose sqlType is Arrow-supported (ArrayType(DoubleType)). */ +private class SupportedUDT extends UserDefinedType[Array[Double]] { + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + override def serialize(obj: Array[Double]): Any = obj + override def deserialize(datum: Any): Array[Double] = datum.asInstanceOf[Array[Double]] + override def userClass: Class[Array[Double]] = classOf[Array[Double]] +} + +/** UDT whose sqlType is ObjectType - not supported by Arrow. */ +private class UnsupportedUDT extends UserDefinedType[AnyRef] { + override def sqlType: DataType = ObjectType(classOf[AnyRef]) + override def serialize(obj: AnyRef): Any = obj + override def deserialize(datum: Any): AnyRef = datum.asInstanceOf[AnyRef] + override def userClass: Class[AnyRef] = classOf[AnyRef] +} + +class ArrowCachedBatchSerializerSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + override protected def sparkConf = { + super.sparkConf + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, + classOf[ArrowCachedBatchSerializer].getName) + .set(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key, "false") + } + + // InMemoryRelation caches the serializer instance in a process-wide field that is initialized + // from spark.sql.cache.serializer only on first use. When another suite runs first in the same + // JVM, that field is already bound to DefaultCachedBatchSerializer, so reset it here to pick up + // the Arrow serializer configured above, and reset it again afterwards so we do not leak the + // Arrow serializer to later suites. + override def beforeAll(): Unit = { + super.beforeAll() + InMemoryRelation.clearSerializer() + } + + override def afterAll(): Unit = { + InMemoryRelation.clearSerializer() + super.afterAll() + } + + test("basic caching with primitive types") { + val df = Seq( + (1, 2L, 3.0f, 4.0, "hello"), + (5, 6L, 7.0f, 8.0, "world"), + (9, 10L, 11.0f, 12.0, "test") + ).toDF("a", "b", "c", "d", "e") + + df.cache() + checkAnswer(df, Seq( + Row(1, 2L, 3.0f, 4.0, "hello"), + Row(5, 6L, 7.0f, 8.0, "world"), + Row(9, 10L, 11.0f, 12.0, "test") + )) + + // Verify it was actually cached + assert(df.storageLevel.useMemory) + } + + test("caching with all primitive types") { + val df = Seq( + (true, 1.toByte, 2.toShort, 3, 4L, 5.0f, 6.0), + (false, 7.toByte, 8.toShort, 9, 10L, 11.0f, 12.0), + (true, 13.toByte, 14.toShort, 15, 16L, 17.0f, 18.0) + ).toDF("bool", "byte", "short", "int", "long", "float", "double") + + df.cache() + checkAnswer(df, Seq( + Row(true, 1.toByte, 2.toShort, 3, 4L, 5.0f, 6.0), + Row(false, 7.toByte, 8.toShort, 9, 10L, 11.0f, 12.0), + Row(true, 13.toByte, 14.toShort, 15, 16L, 17.0f, 18.0) + )) + } + + test("caching with null values") { + val df = Seq( + (Some(1), Some("a")), + (None, Some("b")), + (Some(3), None), + (None, None) + ).toDF("num", "str") + + df.cache() + checkAnswer(df, Seq( + Row(1, "a"), + Row(null, "b"), + Row(3, null), + Row(null, null) + )) + } + + test("caching with date and timestamp types") { + val date1 = Date.valueOf("2020-01-01") + val date2 = Date.valueOf("2021-06-15") + val ts1 = Timestamp.valueOf("2020-01-01 12:00:00") + val ts2 = Timestamp.valueOf("2021-06-15 15:30:45") + + val df = Seq( + (date1, ts1), + (date2, ts2) + ).toDF("date", "timestamp") + + df.cache() + checkAnswer(df, Seq( + Row(date1, ts1), + Row(date2, ts2) + )) + } + + test("caching with decimal types") { + val df = Seq( + BigDecimal("123.45"), + BigDecimal("678.90"), + BigDecimal("999.99") + ).toDF("decimal") + + df.cache() + checkAnswer(df, Seq( + Row(BigDecimal("123.45")), + Row(BigDecimal("678.90")), + Row(BigDecimal("999.99")) + )) + } + + test("caching with binary type") { + val df = Seq( + "hello".getBytes("UTF-8"), + "world".getBytes("UTF-8"), + "test".getBytes("UTF-8") + ).toDF("binary") + + df.cache() + val result = df.collect() + assert(result.length == 3) + assert(new String(result(0).getAs[Array[Byte]](0), "UTF-8") == "hello") + assert(new String(result(1).getAs[Array[Byte]](0), "UTF-8") == "world") + assert(new String(result(2).getAs[Array[Byte]](0), "UTF-8") == "test") + } + + test("caching with array type") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5, 6), + Seq(7, 8, 9) + ).toDF("array") + + df.cache() + checkAnswer(df, Seq( + Row(Seq(1, 2, 3)), + Row(Seq(4, 5, 6)), + Row(Seq(7, 8, 9)) + )) + } + + test("caching with struct type") { + val df = Seq( + (1, ("a", 10)), + (2, ("b", 20)), + (3, ("c", 30)) + ).toDF("id", "struct") + + df.cache() + checkAnswer(df, Seq( + Row(1, Row("a", 10)), + Row(2, Row("b", 20)), + Row(3, Row("c", 30)) + )) + } + + test("caching with map type") { + val df = Seq( + Map("a" -> 1, "b" -> 2), + Map("c" -> 3, "d" -> 4), + Map("e" -> 5, "f" -> 6) + ).toDF("map") + + df.cache() + checkAnswer(df, Seq( + Row(Map("a" -> 1, "b" -> 2)), + Row(Map("c" -> 3, "d" -> 4)), + Row(Map("e" -> 5, "f" -> 6)) + )) + } + + test("caching with nested complex types") { + val df = Seq( + (1, Seq(("a", Seq(1, 2)), ("b", Seq(3, 4)))), + (2, Seq(("c", Seq(5, 6)), ("d", Seq(7, 8)))) + ).toDF("id", "nested") + + df.cache() + checkAnswer(df, Seq( + Row(1, Seq(Row("a", Seq(1, 2)), Row("b", Seq(3, 4)))), + Row(2, Seq(Row("c", Seq(5, 6)), Row("d", Seq(7, 8)))) + )) + } + + test("caching with filter pushdown") { + val df = (1 to 100).map(i => (i, i * 2, s"str$i")).toDF("a", "b", "c") + df.cache() + + // This should use cached data with filter + val filtered = df.filter($"a" > 50) + checkAnswer(filtered, (51 to 100).map(i => Row(i, i * 2, s"str$i"))) + + // Verify cache was used + assert(filtered.queryExecution.executedPlan.toString.contains("InMemoryTableScan")) + } + + test("caching with column projection") { + val df = (1 to 100).map(i => (i, i * 2, i * 3, s"str$i")).toDF("a", "b", "c", "d") + df.cache() + + // Select subset of columns + val projected = df.select("a", "c") + checkAnswer(projected, (1 to 100).map(i => Row(i, i * 3))) + + // Verify cache was used + assert(projected.queryExecution.executedPlan.toString.contains("InMemoryTableScan")) + } + + test("caching with multiple batches") { + withSQLConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> "10") { + val df = (1 to 50).map(i => (i, s"str$i")).toDF("a", "b") + df.cache() + + checkAnswer(df, (1 to 50).map(i => Row(i, s"str$i"))) + + // Verify multiple batches were created + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + } + } + + test("uncache and recache") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + + // Cache + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + assert(df.storageLevel.useMemory) + + // Uncache + df.unpersist() + assert(!df.storageLevel.useMemory) + + // Recache + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + assert(df.storageLevel.useMemory) + } + + test("cache with aggregation") { + val df = Seq( + ("a", 1), + ("b", 2), + ("a", 3), + ("b", 4), + ("a", 5) + ).toDF("key", "value") + + df.cache() + + val agg = df.groupBy("key").sum("value") + checkAnswer(agg, Seq(Row("a", 9), Row("b", 6))) + } + + test("cache with join") { + val df1 = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value1") + val df2 = Seq((1, "x"), (2, "y"), (3, "z")).toDF("id", "value2") + + df1.cache() + df2.cache() + + val joined = df1.join(df2, "id") + checkAnswer(joined, Seq( + Row(1, "a", "x"), + Row(2, "b", "y"), + Row(3, "c", "z") + )) + } + + test("vectorized reader enabled") { + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + + // Verify vectorized reader is used + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + assert(inMemoryScan.get.supportsColumnar) + } + } + + test("compression codec - none") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "none") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("compression codec - zstd") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "zstd") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("compression codec - lz4") { + withSQLConf(SQLConf.ARROW_EXECUTION_COMPRESSION_CODEC.key -> "lz4") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "a"), Row(2, "b"), Row(3, "c"))) + } + } + + test("large dataset") { + val df = (1 to 10000).map(i => (i, i * 2, s"string$i")).toDF("a", "b", "c") + df.cache() + + checkAnswer( + df.filter($"a" > 9000), + (9001 to 10000).map(i => Row(i, i * 2, s"string$i")) + ) + } + + test("empty dataset") { + val df = Seq.empty[(Int, String)].toDF("id", "value") + df.cache() + checkAnswer(df, Seq.empty[Row]) + } + + test("single row") { + val df = Seq((1, "single")).toDF("id", "value") + df.cache() + checkAnswer(df, Seq(Row(1, "single"))) + } + + test("cache table command") { + withTempView("test_table") { + Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + .createOrReplaceTempView("test_table") + + sql("CACHE TABLE test_table") + + checkAnswer( + sql("SELECT * FROM test_table"), + Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")) + ) + + sql("UNCACHE TABLE test_table") + } + } + + test("columnar batch from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = (1 to 100).map(i => (i, i * 2, s"str$i")).toDF("a", "b", "c") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, (1 to 100).map(i => Row(i, i * 2, s"str$i"))) + } + } + + test("supportsColumnarInput with supported types") { + val serializer = new ArrowCachedBatchSerializer() + + // All primitive types should be supported + val primitiveSchema = Seq( + AttributeReference("bool", BooleanType)(), + AttributeReference("byte", ByteType)(), + AttributeReference("short", ShortType)(), + AttributeReference("int", IntegerType)(), + AttributeReference("long", LongType)(), + AttributeReference("float", FloatType)(), + AttributeReference("double", DoubleType)(), + AttributeReference("string", StringType)(), + AttributeReference("binary", BinaryType)() + ) + assert(serializer.supportsColumnarInput(primitiveSchema)) + + // Temporal types should be supported + val temporalSchema = Seq( + AttributeReference("date", DateType)(), + AttributeReference("timestamp", TimestampType)(), + AttributeReference("timestampNtz", TimestampNTZType)() + ) + assert(serializer.supportsColumnarInput(temporalSchema)) + + // Decimal should be supported + val decimalSchema = Seq( + AttributeReference("decimal", DecimalType(10, 2))() + ) + assert(serializer.supportsColumnarInput(decimalSchema)) + + // Complex types should be supported + val complexSchema = Seq( + AttributeReference("array", ArrayType(IntegerType))(), + AttributeReference("struct", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + )))(), + AttributeReference("map", MapType(StringType, IntegerType))() + ) + assert(serializer.supportsColumnarInput(complexSchema)) + + // Nested complex types should be supported + val nestedSchema = Seq( + AttributeReference("nested", ArrayType(StructType(Seq( + StructField("x", IntegerType), + StructField("y", ArrayType(StringType)) + ))))() + ) + assert(serializer.supportsColumnarInput(nestedSchema)) + } + + test("supportsColumnarInput correctly validates all types") { + // Verify that isSupportedByArrow handles all standard Spark SQL types + assert(ArrowUtils.isSupportedByArrow(BooleanType)) + assert(ArrowUtils.isSupportedByArrow(ByteType)) + assert(ArrowUtils.isSupportedByArrow(ShortType)) + assert(ArrowUtils.isSupportedByArrow(IntegerType)) + assert(ArrowUtils.isSupportedByArrow(LongType)) + assert(ArrowUtils.isSupportedByArrow(FloatType)) + assert(ArrowUtils.isSupportedByArrow(DoubleType)) + assert(ArrowUtils.isSupportedByArrow(StringType)) + assert(ArrowUtils.isSupportedByArrow(BinaryType)) + assert(ArrowUtils.isSupportedByArrow(DateType)) + assert(ArrowUtils.isSupportedByArrow(TimestampType)) + assert(ArrowUtils.isSupportedByArrow(TimestampNTZType)) + assert(ArrowUtils.isSupportedByArrow(DecimalType(10, 2))) + assert(ArrowUtils.isSupportedByArrow(NullType)) + assert(ArrowUtils.isSupportedByArrow(CalendarIntervalType)) + + // Complex types + assert(ArrowUtils.isSupportedByArrow(ArrayType(IntegerType))) + assert(ArrowUtils.isSupportedByArrow(StructType(Seq(StructField("x", IntegerType))))) + assert(ArrowUtils.isSupportedByArrow(MapType(StringType, IntegerType))) + + // Nested complex types + assert(ArrowUtils.isSupportedByArrow( + ArrayType(StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StringType)) + ))) + )) + + // UDT: delegates to sqlType - supported when sqlType is Arrow-compatible + // ExamplePointUDT.sqlType = ArrayType(DoubleType) -> supported + assert(ArrowUtils.isSupportedByArrow(new ExamplePointUDT()), + "UDT with Arrow-supported sqlType should be supported") + assert(ArrowUtils.isSupportedByArrow(new SupportedUDT()), + "UDT with ArrayType(DoubleType) sqlType should be supported") + // UDT with ObjectType sqlType -> not supported (ObjectType is internal, not an Arrow type) + assert(!ArrowUtils.isSupportedByArrow(new UnsupportedUDT()), + "UDT with ObjectType sqlType should not be supported") + } + + test("verify Arrow cache serializer is actually used") { + val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + df.cache() + df.count() // Materialize the cache + + // Verify the query plan uses InMemoryTableScan + val plan = df.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined, "InMemoryTableScan should be present in cached query plan") + + // Verify the serializer is ArrowCachedBatchSerializer + val serializer = inMemoryScan.get.relation.cacheBuilder.serializer + assert(serializer.isInstanceOf[ArrowCachedBatchSerializer], + s"Expected ArrowCachedBatchSerializer but got ${serializer.getClass.getName}") + } + + test("columnar input with array type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3)), + (2, Seq(4, 5, 6)), + (3, Seq(7, 8, 9)) + ).toDF("id", "array_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3)), + Row(2, Seq(4, 5, 6)), + Row(3, Seq(7, 8, 9)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with struct type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, ("a", 10)), + (2, ("b", 20)), + (3, ("c", 30)) + ).toDF("id", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Row("a", 10)), + Row(2, Row("b", 20)), + Row(3, Row("c", 30)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with map type from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Map("a" -> 1, "b" -> 2)), + (2, Map("c" -> 3, "d" -> 4)), + (3, Map("e" -> 5, "f" -> 6)) + ).toDF("id", "map_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Map("a" -> 1, "b" -> 2)), + Row(2, Map("c" -> 3, "d" -> 4)), + Row(3, Map("e" -> 5, "f" -> 6)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with nested complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(("a", Seq(1, 2)), ("b", Seq(3, 4)))), + (2, Seq(("c", Seq(5, 6)), ("d", Seq(7, 8)))) + ).toDF("id", "nested_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(Row("a", Seq(1, 2)), Row("b", Seq(3, 4)))), + Row(2, Seq(Row("c", Seq(5, 6)), Row("d", Seq(7, 8)))) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with array of structs from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(("apple", 1.5), ("banana", 2.0))), + (2, Seq(("orange", 1.8), ("grape", 3.5))), + (3, Seq(("mango", 2.5))) + ).toDF("id", "items") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(Row("apple", 1.5), Row("banana", 2.0))), + Row(2, Seq(Row("orange", 1.8), Row("grape", 3.5))), + Row(3, Seq(Row("mango", 2.5))) + )) + + // Verify cache was used and operations work + val filtered = cached.filter($"id" > 1) + checkAnswer(filtered, Seq( + Row(2, Seq(Row("orange", 1.8), Row("grape", 3.5))), + Row(3, Seq(Row("mango", 2.5))) + )) + } + } + + test("columnar input with struct containing arrays from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, ("user1", Seq("tag1", "tag2", "tag3"))), + (2, ("user2", Seq("tag4", "tag5"))), + (3, ("user3", Seq("tag6"))) + ).toDF("id", "user_info") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Row("user1", Seq("tag1", "tag2", "tag3"))), + Row(2, Row("user2", Seq("tag4", "tag5"))), + Row(3, Row("user3", Seq("tag6"))) + )) + + // Verify we can access nested fields + val extracted = cached.select($"id", $"user_info._1".as("name")) + checkAnswer(extracted, Seq( + Row(1, "user1"), + Row(2, "user2"), + Row(3, "user3") + )) + } + } + + test("columnar input with map of arrays from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Map("a" -> Seq(1, 2, 3), "b" -> Seq(4, 5))), + (2, Map("c" -> Seq(6, 7), "d" -> Seq(8, 9, 10))) + ).toDF("id", "map_of_arrays") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Map("a" -> Seq(1, 2, 3), "b" -> Seq(4, 5))), + Row(2, Map("c" -> Seq(6, 7), "d" -> Seq(8, 9, 10))) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with null values in complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Some(Seq(1, 2, 3)), Some(("a", 10))), + (2, None, Some(("b", 20))), + (3, Some(Seq(4, 5)), None), + (4, None, None) + ).toDF("id", "array_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(2, null, Row("b", 20)), + Row(3, Seq(4, 5), null), + Row(4, null, null) + )) + + // Verify filtering works with nulls + val filtered = cached.filter($"array_col".isNotNull) + checkAnswer(filtered, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(3, Seq(4, 5), null) + )) + } + } + + test("columnar input with empty arrays and maps from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3), Map("a" -> 1)), + (2, Seq.empty[Int], Map.empty[String, Int]), + (3, Seq(4), Map("b" -> 2, "c" -> 3)) + ).toDF("id", "array_col", "map_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Map("a" -> 1)), + Row(2, Seq.empty[Int], Map.empty[String, Int]), + Row(3, Seq(4), Map("b" -> 2, "c" -> 3)) + )) + + // Verify cache was used + assert(cached.storageLevel.useMemory) + } + } + + test("columnar input with deeply nested structures from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + // Create a deeply nested structure: Array[Struct[Map[String, Array[Int]]]] + val df = Seq( + (1, Seq( + (Map("x" -> Seq(1, 2)), "data1"), + (Map("y" -> Seq(3, 4, 5)), "data2") + )), + (2, Seq( + (Map("z" -> Seq(6)), "data3") + )) + ).toDF("id", "deep_nested") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq( + Row(Map("x" -> Seq(1, 2)), "data1"), + Row(Map("y" -> Seq(3, 4, 5)), "data2") + )), + Row(2, Seq( + Row(Map("z" -> Seq(6)), "data3") + )) + )) + + // Verify operations work on deeply nested data + val result = cached.filter($"id" === 1) + assert(result.count() === 1) + } + } + + test("columnar input with mixed primitive and complex types from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, "name1", 100L, Seq(1, 2, 3), Map("k1" -> "v1"), ("nested", 99)), + (2, "name2", 200L, Seq(4, 5), Map("k2" -> "v2"), ("nested2", 88)), + (3, "name3", 300L, Seq(6), Map("k3" -> "v3"), ("nested3", 77)) + ).toDF("id", "name", "value", "array_col", "map_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, "name1", 100L, Seq(1, 2, 3), Map("k1" -> "v1"), Row("nested", 99)), + Row(2, "name2", 200L, Seq(4, 5), Map("k2" -> "v2"), Row("nested2", 88)), + Row(3, "name3", 300L, Seq(6), Map("k3" -> "v3"), Row("nested3", 77)) + )) + + // Verify column projection works + val projected = cached.select("id", "array_col", "struct_col") + checkAnswer(projected, Seq( + Row(1, Seq(1, 2, 3), Row("nested", 99)), + Row(2, Seq(4, 5), Row("nested2", 88)), + Row(3, Seq(6), Row("nested3", 77)) + )) + } + } + + test("columnar input with large complex types dataset from parquet") { + withTempPath { dir => + val path = dir.getAbsolutePath + // Create a larger dataset with complex types + val df = (1 to 1000).map { i => + (i, Seq(i, i * 2, i * 3), Map(s"key$i" -> i * 10), (s"struct$i", i * 100)) + }.toDF("id", "array_col", "map_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache - should use columnar input path + val cached = spark.read.parquet(path).cache() + + // Verify a filtered subset + val filtered = cached.filter($"id" > 990) + assert(filtered.count() === 10) + + // Verify content of filtered data + val result = filtered.collect().sortBy(_.getInt(0)) + assert(result.length === 10) + assert(result(0).getInt(0) === 991) + assert(result(0).getAs[Seq[Int]](1) === Seq(991, 1982, 2973)) + } + } + + test("columnar input with vectorized reader and complex types") { + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "true") { + withTempPath { dir => + val path = dir.getAbsolutePath + val df = Seq( + (1, Seq(1, 2, 3), ("a", 10)), + (2, Seq(4, 5, 6), ("b", 20)), + (3, Seq(7, 8, 9), ("c", 30)) + ).toDF("id", "array_col", "struct_col") + + // Write as parquet (columnar format) + df.write.parquet(path) + + // Read and cache with vectorized reader enabled + val cached = spark.read.parquet(path).cache() + checkAnswer(cached, Seq( + Row(1, Seq(1, 2, 3), Row("a", 10)), + Row(2, Seq(4, 5, 6), Row("b", 20)), + Row(3, Seq(7, 8, 9), Row("c", 30)) + )) + + // Verify vectorized reader is used + val plan = cached.queryExecution.executedPlan + val inMemoryScan = plan.collectFirst { + case scan: InMemoryTableScanExec => scan + } + assert(inMemoryScan.isDefined) + assert(inMemoryScan.get.supportsColumnar) + } + } + } + + test("InternalRow path (readValueFromVector) handles all supported data types") { + // Exercises every explicit type arm in readValueFromVector via the InternalRow fallback path + // (CACHE_VECTORIZED_READER_ENABLED=false, set at suite level). Each case verifies a full + // cache write -> read roundtrip including both non-null and null values. + + // --- Primitive types --- + + // BooleanType: BitVector.get(i) != 0 + val boolDf = Seq(Some(true), None, Some(false)).toDF("v") + boolDf.cache() + checkAnswer(boolDf, Seq(Row(true), Row(null), Row(false))) + boolDf.unpersist() + + // ByteType: TinyIntVector.get(i) + val byteDf = Seq(Some(1.toByte), None, Some(10.toByte)).toDF("v") + byteDf.cache() + checkAnswer(byteDf, Seq(Row(1.toByte), Row(null), Row(10.toByte))) + byteDf.unpersist() + + // ShortType: SmallIntVector.get(i) + val shortDf = Seq(Some(1.toShort), None, Some(100.toShort)).toDF("v") + shortDf.cache() + checkAnswer(shortDf, Seq(Row(1.toShort), Row(null), Row(100.toShort))) + shortDf.unpersist() + + // IntegerType: IntVector.get(i) + val intDf = Seq(Some(42), None, Some(-7)).toDF("v") + intDf.cache() + checkAnswer(intDf, Seq(Row(42), Row(null), Row(-7))) + intDf.unpersist() + + // LongType: BigIntVector.get(i) + val longDf = Seq(Some(100L), None, Some(-50L)).toDF("v") + longDf.cache() + checkAnswer(longDf, Seq(Row(100L), Row(null), Row(-50L))) + longDf.unpersist() + + // FloatType: Float4Vector.get(i) + val floatDf = Seq(Some(3.14f), None, Some(-1.0f)).toDF("v") + floatDf.cache() + checkAnswer(floatDf, Seq(Row(3.14f), Row(null), Row(-1.0f))) + floatDf.unpersist() + + // DoubleType: Float8Vector.get(i) + val doubleDf = Seq(Some(2.718), None, Some(-1.0)).toDF("v") + doubleDf.cache() + checkAnswer(doubleDf, Seq(Row(2.718), Row(null), Row(-1.0))) + doubleDf.unpersist() + + // --- String and Binary types --- + + // StringType: VarCharVector.get(i) -> UTF8String.fromBytes + val stringDf = Seq(Some("hello"), None, Some("world")).toDF("v") + stringDf.cache() + checkAnswer(stringDf, Seq(Row("hello"), Row(null), Row("world"))) + stringDf.unpersist() + + // BinaryType: VarBinaryVector.get(i) + val bytes1 = "hello".getBytes("UTF-8") + val bytes2 = "world".getBytes("UTF-8") + val binaryDf = Seq(bytes1, bytes2).toDF("v") + binaryDf.cache() + val binaryResult = binaryDf.collect() + assert(binaryResult(0).getAs[Array[Byte]](0) sameElements bytes1) + assert(binaryResult(1).getAs[Array[Byte]](0) sameElements bytes2) + binaryDf.unpersist() + + // DecimalType (compact, precision <= 18): fast path reads unscaled long from Arrow buffer + val decDf = Seq(Some(BigDecimal("123.45")), None, Some(BigDecimal("678.90"))).toDF("v") + decDf.cache() + checkAnswer(decDf, Seq(Row(BigDecimal("123.45")), Row(null), Row(BigDecimal("678.90")))) + decDf.unpersist() + + // DecimalType (compact, negative values): verifies sign-bit correctness when reading + // lower 8 bytes of Arrow's 128-bit little-endian two's-complement buffer as signed Long + val negDecData = Seq( + new java.math.BigDecimal("-123.45"), + new java.math.BigDecimal("0.00"), + new java.math.BigDecimal("-999999.99")) + val negDecDf = spark.createDataFrame( + spark.sparkContext.parallelize(negDecData.map(Row(_))), + StructType(Seq(StructField("v", DecimalType(10, 2))))) + negDecDf.cache() + checkAnswer(negDecDf, negDecData.map(d => Row(d))) + negDecDf.unpersist() + + // DecimalType (wide, precision > 18): slow path via DecimalVector.getObject -> BigDecimal + val wideDecData = Seq( + new java.math.BigDecimal("12345678901234567890.1234567890"), + new java.math.BigDecimal("-99999999999999999999.9999999999"), + new java.math.BigDecimal("0.0000000001")) + val wideDecDf = spark.createDataFrame( + spark.sparkContext.parallelize(wideDecData.map(Row(_))), + StructType(Seq(StructField("v", DecimalType(30, 10))))) + wideDecDf.cache() + checkAnswer(wideDecDf, wideDecData.map(d => Row(d))) + wideDecDf.unpersist() + + // --- Date and time types --- + + // DateType: DateDayVector.get(i) (days since epoch) + val dateDf = Seq(Some(Date.valueOf("2020-01-01")), None, Some(Date.valueOf("2025-12-31"))) + .toDF("v") + dateDf.cache() + checkAnswer(dateDf, + Seq(Row(Date.valueOf("2020-01-01")), Row(null), Row(Date.valueOf("2025-12-31")))) + dateDf.unpersist() + + // TimestampType: TimeStampMicroTZVector.get(i) (microseconds since epoch) + val ts1 = Timestamp.valueOf("2020-01-01 12:00:00") + val ts2 = Timestamp.valueOf("2025-06-15 00:00:00") + val tsDf = Seq(Some(ts1), None, Some(ts2)).toDF("v") + tsDf.cache() + checkAnswer(tsDf, Seq(Row(ts1), Row(null), Row(ts2))) + tsDf.unpersist() + + // TimestampNTZType: TimeStampMicroVector.get(i) (microseconds, no timezone) + val ldt1 = LocalDateTime.of(2020, 1, 1, 12, 0) + val ldt2 = LocalDateTime.of(2025, 6, 15, 0, 0) + val tsNtzDf = Seq(Some(ldt1), None, Some(ldt2)).toDF("v") + tsNtzDf.cache() + checkAnswer(tsNtzDf, Seq(Row(ldt1), Row(null), Row(ldt2))) + tsNtzDf.unpersist() + + // --- Interval types --- + + // YearMonthIntervalType: IntervalYearVector.get(i) (months) + val ymiSql = "SELECT INTERVAL '1-1' YEAR TO MONTH AS ymi" + val ymiDf = spark.sql(ymiSql) + ymiDf.cache() + checkAnswer(ymiDf, spark.sql(ymiSql)) + ymiDf.unpersist() + + // DayTimeIntervalType: DurationVector.get(int) returns ArrowBuf; must use static form + val dtiSql = "SELECT INTERVAL '1' DAY AS dti" + val dtiDf = spark.sql(dtiSql) + dtiDf.cache() + checkAnswer(dtiDf, spark.sql(dtiSql)) + dtiDf.unpersist() + + // TimeType: TimeNanoVector.get(i) (nanoseconds since midnight) + val timeDf = Seq(LocalTime.of(12, 30, 45), LocalTime.of(0, 0, 0)).toDF("t") + timeDf.cache() + checkAnswer(timeDf, Seq(Row(LocalTime.of(12, 30, 45)), Row(LocalTime.of(0, 0, 0)))) + timeDf.unpersist() + + // CalendarIntervalType: ArrowColumnVector.getInterval(i) (IntervalMonthDayNanoVector) + val interval = new CalendarInterval(1, 2, 3000000L) // 1 month, 2 days, 3 ms + val ciSchema = StructType(Seq(StructField("ci", CalendarIntervalType, nullable = true))) + val ciDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(interval), Row(null))), ciSchema) + ciDf.cache() + checkAnswer(ciDf, Seq(Row(interval), Row(null))) + ciDf.unpersist() + + // --- Null type --- + + // NullType: row.setNullAt without dispatching into readValueFromVector + val nullSchema = StructType(Seq(StructField("n", NullType, nullable = true))) + val nullDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(null), Row(null))), nullSchema) + nullDf.cache() + checkAnswer(nullDf, Seq(Row(null), Row(null))) + nullDf.unpersist() + + // --- Complex types --- + + // ArrayType: ArrowColumnVector.getArray(i) (ListVector) + val arrayDf = Seq(Seq(1, 2, 3), Seq(4, 5, 6)).toDF("v") + arrayDf.cache() + checkAnswer(arrayDf, Seq(Row(Seq(1, 2, 3)), Row(Seq(4, 5, 6)))) + arrayDf.unpersist() + + // StructType: ArrowColumnVector.getStruct(i) (StructVector) + val structSql = + "SELECT named_struct('a', 1, 'b', 'x') AS v " + + "UNION ALL SELECT named_struct('a', 2, 'b', 'y') AS v" + val structDf = spark.sql(structSql) + structDf.cache() + checkAnswer(structDf, spark.sql(structSql)) + structDf.unpersist() + + // MapType: ArrowColumnVector.getMap(i) (MapVector) + val mapDf = Seq(Map(1 -> "a"), Map(2 -> "b")).toDF("v") + mapDf.cache() + checkAnswer(mapDf, Seq(Row(Map(1 -> "a")), Row(Map(2 -> "b")))) + mapDf.unpersist() + + // UserDefinedType: dispatches to readValueFromVector with udt.sqlType (ArrayType(DoubleType)) + val point = new ExamplePoint(1.0, 2.0) + val udtSchema = StructType(Seq(StructField("p", new ExamplePointUDT(), nullable = true))) + val udtDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(point), Row(null))), udtSchema) + udtDf.cache() + checkAnswer(udtDf, Seq(Row(point), Row(null))) + udtDf.unpersist() + + // VariantType: ArrowColumnVector.getVariant(i) (StructVector) + val variantDf = spark.sql("SELECT parse_json('{\"a\":1}') AS v") + variantDf.cache() + checkAnswer(variantDf.selectExpr("to_json(v)"), Seq(Row("{\"a\":1}"))) + variantDf.unpersist() + } + + // Helper: cache a single-column DataFrame (row path) and return its ArrowCachedBatch stats. + // Stats layout per column: [lowerBound(0), upperBound(1), nullCount(2), rowCount(3), size(4)]. + private def cachedStats(df: org.apache.spark.sql.DataFrame) + : org.apache.spark.sql.catalyst.InternalRow = { + df.count() // trigger cache population + val relation = df.queryExecution.executedPlan.collectFirst { + case scan: InMemoryTableScanExec => scan.relation + }.get + relation.cacheBuilder.cachedColumnBuffers.first().asInstanceOf[ArrowCachedBatch].stats + } + + // Helper: creates a single-column, single-partition DataFrame backed by an RDD. + // LocalRelation can split across multiple partitions, causing cachedStats to see only the first + // partition's stats. sc.parallelize(data, numSlices=1) forces exactly one partition. + private def singlePartDf(values: Seq[Any], dt: DataType): org.apache.spark.sql.DataFrame = + spark.createDataFrame( + spark.sparkContext.parallelize(values.map(v => Row(v)), 1), + StructType(Seq(StructField("v", dt, nullable = true)))) + + test("createColumnStats returns the correct ColumnStats subclass for each supported type") { + // Direct unit test: verify the stats class dispatched for each Spark type, which determines + // whether partition pruning via min/max bounds is enabled. + + // Orderable types: createColumnStats returns a stats class that tracks min/max bounds. + assert(ArrowCachedBatchSerializer.createColumnStats(BooleanType) + .isInstanceOf[BooleanColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(ByteType).isInstanceOf[ByteColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(ShortType).isInstanceOf[ShortColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(IntegerType).isInstanceOf[IntColumnStats]) + // DateType is stored as Int (days since epoch) -> IntColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(DateType).isInstanceOf[IntColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(LongType).isInstanceOf[LongColumnStats]) + // TimestampType/NTZ stored as Long (microseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(TimestampType) + .isInstanceOf[LongColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(TimestampNTZType) + .isInstanceOf[LongColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(FloatType).isInstanceOf[FloatColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(DoubleType).isInstanceOf[DoubleColumnStats]) + // StringType (all collations) -> StringColumnStats with collation-aware semantic comparison + assert(ArrowCachedBatchSerializer.createColumnStats(StringType).isInstanceOf[StringColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + new StringType(1)).isInstanceOf[StringColumnStats]) // collationId 1 = UTF8_LCASE + assert(ArrowCachedBatchSerializer.createColumnStats( + StringType("UNICODE")).isInstanceOf[StringColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(DecimalType(10, 2)) + .isInstanceOf[DecimalColumnStats]) + // YearMonthIntervalType stored as Int (months) -> IntColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats( + YearMonthIntervalType()).isInstanceOf[IntColumnStats]) + // DayTimeIntervalType stored as Long (microseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats( + DayTimeIntervalType()).isInstanceOf[LongColumnStats]) + // TimeType stored as Long (nanoseconds) -> LongColumnStats + assert(ArrowCachedBatchSerializer.createColumnStats(TimeType(6)).isInstanceOf[LongColumnStats]) + + // Non-orderable types: createColumnStats returns a stats class with null bounds. + assert(ArrowCachedBatchSerializer.createColumnStats(BinaryType).isInstanceOf[BinaryColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + CalendarIntervalType).isInstanceOf[IntervalColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(VariantType) + .isInstanceOf[VariantColumnStats]) + + // Complex types and UDT: no natural ordering -> ObjectColumnStats (null bounds). + assert(ArrowCachedBatchSerializer.createColumnStats( + ArrayType(IntegerType)).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + StructType(Seq(StructField("a", IntegerType)))).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + MapType(StringType, IntegerType)).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats( + new ExamplePointUDT()).isInstanceOf[ObjectColumnStats]) + assert(ArrowCachedBatchSerializer.createColumnStats(NullType).isInstanceOf[ObjectColumnStats]) + } + + test("row path stats: orderable types produce correct min/max bounds") { + // DataFrames use the row path (InternalRowToArrowCachedBatchIterator), exercising + // createColumnStats + buildStatisticsFromCollectors. singlePartDf ensures all values land + // in one cached batch so cachedStats.first() sees the global min and max. + + // BooleanType: lower=false, upper=true + val boolDf = singlePartDf(Seq(false, true), BooleanType).cache() + val boolStats = cachedStats(boolDf) + assert(!boolStats.isNullAt(0) && !boolStats.isNullAt(1)) + assert(!boolStats.getBoolean(0) && boolStats.getBoolean(1)) + boolDf.unpersist() + + // ByteType: lower=1, upper=10 + val byteDf = singlePartDf(Seq(1.toByte, 10.toByte), ByteType).cache() + val byteStats = cachedStats(byteDf) + assert(!byteStats.isNullAt(0) && !byteStats.isNullAt(1)) + assert(byteStats.getByte(0) == 1.toByte && byteStats.getByte(1) == 10.toByte) + byteDf.unpersist() + + // ShortType: lower=1, upper=10 + val shortDf = singlePartDf(Seq(1.toShort, 10.toShort), ShortType).cache() + val shortStats = cachedStats(shortDf) + assert(!shortStats.isNullAt(0) && !shortStats.isNullAt(1)) + assert(shortStats.getShort(0) == 1.toShort && shortStats.getShort(1) == 10.toShort) + shortDf.unpersist() + + // IntegerType: lower=1, upper=10 + val intDf = singlePartDf(Seq(1, 10), IntegerType).cache() + val intStats = cachedStats(intDf) + assert(!intStats.isNullAt(0) && !intStats.isNullAt(1)) + assert(intStats.getInt(0) == 1 && intStats.getInt(1) == 10) + intDf.unpersist() + + // DateType: stored as Int (days since epoch); 2020-01-01 < 2025-01-01 + val dateDf = singlePartDf( + Seq(Date.valueOf("2020-01-01"), Date.valueOf("2025-01-01")), DateType).cache() + val dateStats = cachedStats(dateDf) + assert(!dateStats.isNullAt(0) && !dateStats.isNullAt(1)) + assert(dateStats.getInt(0) < dateStats.getInt(1)) + dateDf.unpersist() + + // LongType: lower=1L, upper=10L + val longDf = singlePartDf(Seq(1L, 10L), LongType).cache() + val longStats = cachedStats(longDf) + assert(!longStats.isNullAt(0) && !longStats.isNullAt(1)) + assert(longStats.getLong(0) == 1L && longStats.getLong(1) == 10L) + longDf.unpersist() + + // TimestampType: stored as Long (microseconds since epoch); 2020 < 2025 + val tsDf = singlePartDf( + Seq(Timestamp.valueOf("2020-01-01 00:00:00"), Timestamp.valueOf("2025-01-01 00:00:00")), + TimestampType).cache() + val tsStats = cachedStats(tsDf) + assert(!tsStats.isNullAt(0) && !tsStats.isNullAt(1)) + assert(tsStats.getLong(0) < tsStats.getLong(1)) + tsDf.unpersist() + + // TimestampNTZType: stored as Long (microseconds since epoch); 2020 < 2025 + val tsNtzDf = singlePartDf( + Seq(LocalDateTime.of(2020, 1, 1, 0, 0), LocalDateTime.of(2025, 1, 1, 0, 0)), + TimestampNTZType).cache() + val tsNtzStats = cachedStats(tsNtzDf) + assert(!tsNtzStats.isNullAt(0) && !tsNtzStats.isNullAt(1)) + assert(tsNtzStats.getLong(0) < tsNtzStats.getLong(1)) + tsNtzDf.unpersist() + + // FloatType: NaN is included but IEEE 754 comparisons with NaN are always false, + // so NaN never updates min/max; lower=1.0f, upper=10.0f + val floatDf = singlePartDf(Seq(1.0f, Float.NaN, 10.0f), FloatType).cache() + val floatStats = cachedStats(floatDf) + assert(!floatStats.isNullAt(0) && !floatStats.isNullAt(1)) + assert(floatStats.getFloat(0) == 1.0f && floatStats.getFloat(1) == 10.0f) + floatDf.unpersist() + + // DoubleType: same NaN-exclusion behavior via IEEE 754; lower=1.0, upper=10.0 + val doubleDf = singlePartDf(Seq(1.0, Double.NaN, 10.0), DoubleType).cache() + val doubleStats = cachedStats(doubleDf) + assert(!doubleStats.isNullAt(0) && !doubleStats.isNullAt(1)) + assert(doubleStats.getDouble(0) == 1.0 && doubleStats.getDouble(1) == 10.0) + doubleDf.unpersist() + + // StringType (UTF8_BINARY): "apple" < "zebra" in binary order + val stringDf = singlePartDf(Seq("apple", "zebra"), StringType).cache() + val stringStats = cachedStats(stringDf) + assert(!stringStats.isNullAt(0) && !stringStats.isNullAt(1)) + assert(stringStats.getUTF8String(0).toString == "apple") + assert(stringStats.getUTF8String(1).toString == "zebra") + stringDf.unpersist() + + // Collated StringType (UTF8_LCASE): semantic min/max uses case-insensitive comparison. + // "Apple" and "zebra": case-insensitively "apple" < "zebra", so lower="Apple", upper="zebra". + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val collatedSchema = StructType(Seq(StructField("v", collatedStringType, nullable = true))) + val collatedDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("Apple"), Row("zebra")), 1), + collatedSchema).cache() + val collatedStats = cachedStats(collatedDf) + assert(!collatedStats.isNullAt(0), "lower bound should not be null for collated StringType") + assert(!collatedStats.isNullAt(1), "upper bound should not be null for collated StringType") + assert(collatedStats.getUTF8String(0).toString == "Apple") // semantic min + assert(collatedStats.getUTF8String(1).toString == "zebra") // semantic max + collatedDf.unpersist() + + // DecimalType(10,2): lower=1.23, upper=9.87 + val decimalDf = singlePartDf( + Seq(new java.math.BigDecimal("1.23"), new java.math.BigDecimal("9.87")), + DecimalType(10, 2)).cache() + val decimalStats = cachedStats(decimalDf) + assert(!decimalStats.isNullAt(0) && !decimalStats.isNullAt(1)) + assert(decimalStats.getDecimal(0, 10, 2).compareTo(decimalStats.getDecimal(1, 10, 2)) < 0) + decimalDf.unpersist() + + // YearMonthIntervalType: stored as Int (months); Period.of(1,0,0)=12mo < Period.of(2,0,0)=24mo + val ymiDf = singlePartDf( + Seq(Period.of(1, 0, 0), Period.of(2, 0, 0)), YearMonthIntervalType()).cache() + val ymiStats = cachedStats(ymiDf) + assert(!ymiStats.isNullAt(0) && !ymiStats.isNullAt(1)) + assert(ymiStats.getInt(0) < ymiStats.getInt(1)) + ymiDf.unpersist() + + // DayTimeIntervalType: stored as Long (microseconds); 1 day < 2 days + val dtiDf = singlePartDf( + Seq(Duration.ofDays(1), Duration.ofDays(2)), DayTimeIntervalType()).cache() + val dtiStats = cachedStats(dtiDf) + assert(!dtiStats.isNullAt(0) && !dtiStats.isNullAt(1)) + assert(dtiStats.getLong(0) < dtiStats.getLong(1)) + dtiDf.unpersist() + + // TimeType: stored as Long (nanoseconds); 08:00 < 20:00 + val timeDf = singlePartDf( + Seq(LocalTime.of(8, 0, 0), LocalTime.of(20, 0, 0)), TimeType(6)).cache() + val timeStats = cachedStats(timeDf) + assert(!timeStats.isNullAt(0) && !timeStats.isNullAt(1)) + assert(timeStats.getLong(0) < timeStats.getLong(1)) + timeDf.unpersist() + } + + test("row path stats: non-orderable types produce null lower and upper bounds") { + // Verifies that types without natural ordering return null bounds so that partition pruning + // is safely disabled for them, preventing incorrect data exclusion. + def assertNullBounds(df: org.apache.spark.sql.DataFrame): Unit = { + val stats = cachedStats(df) + assert(stats.isNullAt(0), "lower bound should be null for non-orderable type") + assert(stats.isNullAt(1), "upper bound should be null for non-orderable type") + df.unpersist() + } + + // BinaryType: no natural total ordering + assertNullBounds(Seq(Array[Byte](1, 2), Array[Byte](3, 4)).toDF("v").cache()) + + // CalendarIntervalType: unordered composite (months + days + nanoseconds) + val ciSchema = StructType(Seq(StructField("v", CalendarIntervalType, nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(new CalendarInterval(1, 2, 3000000L)), + Row(new CalendarInterval(2, 0, 0L)))), + ciSchema).cache()) + + // ArrayType: no natural ordering + assertNullBounds(spark.sql( + "SELECT array(1, 2) AS v UNION ALL SELECT array(3, 4) AS v" + ).cache()) + + // StructType: no natural ordering + assertNullBounds(spark.sql( + "SELECT named_struct('i', 1, 's', 'a') AS v " + + "UNION ALL SELECT named_struct('i', 2, 's', 'b') AS v" + ).cache()) + + // MapType: no natural ordering + assertNullBounds(spark.sql( + "SELECT map(1, 'a') AS v UNION ALL SELECT map(2, 'b') AS v" + ).cache()) + + // UserDefinedType: no natural ordering + val udtSchema = StructType(Seq(StructField("v", new ExamplePointUDT(), nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)), Row(null))), + udtSchema).cache()) + + // NullType: all values are null by definition + val nullSchema = StructType(Seq(StructField("v", NullType, nullable = true))) + assertNullBounds(spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(null), Row(null))), + nullSchema).cache()) + + // VariantType: no natural ordering + assertNullBounds(spark.sql("SELECT parse_json('{\"k\":1}') AS v").cache()) + } + + test("row path stats: all-NaN Float/Double column produces inverted sentinel bounds") { + // FloatColumnStats and DoubleColumnStats initialize upper=MinValue, lower=MaxValue as + // sentinels. IEEE 754 comparisons with NaN are always false, so NaN never beats either + // sentinel. When every value is NaN, the sentinels are returned unchanged: lower=MaxValue, + // upper=MinValue (lower > upper). This differs from the Arrow path, which returns null bounds + // for all-NaN input (because calculateMinMaxFloat/Double explicitly skips NaN with !_.isNaN + // and returns (null, null) when hasValue stays false). + val floatDf = singlePartDf(Seq(Float.NaN), FloatType).cache() + val floatStats = cachedStats(floatDf) + assert(!floatStats.isNullAt(0), + "FloatType lower should not be null for all-NaN (sentinel used)") + assert(!floatStats.isNullAt(1), + "FloatType upper should not be null for all-NaN (sentinel used)") + assert(floatStats.getFloat(0) == Float.MaxValue, + s"FloatType lower expected Float.MaxValue (sentinel), got ${floatStats.getFloat(0)}") + assert(floatStats.getFloat(1) == Float.MinValue, + s"FloatType upper expected Float.MinValue (sentinel), got ${floatStats.getFloat(1)}") + floatDf.unpersist() + + val doubleDf = singlePartDf(Seq(Double.NaN), DoubleType).cache() + val doubleStats = cachedStats(doubleDf) + assert(!doubleStats.isNullAt(0), + "DoubleType lower should not be null for all-NaN (sentinel used)") + assert(!doubleStats.isNullAt(1), + "DoubleType upper should not be null for all-NaN (sentinel used)") + assert(doubleStats.getDouble(0) == Double.MaxValue, + s"DoubleType lower expected Double.MaxValue (sentinel), got ${doubleStats.getDouble(0)}") + assert(doubleStats.getDouble(1) == Double.MinValue, + s"DoubleType upper expected Double.MinValue (sentinel), got ${doubleStats.getDouble(1)}") + doubleDf.unpersist() + } + + test("collectStatistics produces correct min/max bounds for all orderable types") { + // Direct unit test of ArrowCachedBatchSerializer.collectStatistics, which is invoked whenever + // the input ColumnarBatch contains ArrowColumnVector columns (zero-copy path in + // ColumnarBatchToArrowCachedBatchIterator). Three rows [low, mid, high] ensure min/max are + // correctly identified for each type. + val serializer = new ArrowCachedBatchSerializer() + + val schema = Seq( + AttributeReference("bool_col", BooleanType)(), // BitVector + AttributeReference("byte_col", ByteType)(), // TinyIntVector + AttributeReference("short_col", ShortType)(), // SmallIntVector + AttributeReference("float_col", FloatType)(), // Float4Vector + AttributeReference("double_col", DoubleType)(), // Float8Vector + AttributeReference("date_col", DateType)(), // DateDayVector (days since epoch) + AttributeReference("ts_col", TimestampType)(), // TimeStampMicroTZVector (microseconds) + AttributeReference("ts_ntz_col", TimestampNTZType)(),// TimeStampMicroVector (microseconds) + AttributeReference("int_col", IntegerType)(), // IntVector (standalone) + AttributeReference("long_col", LongType)(), // BigIntVector (standalone) + AttributeReference("decimal_col", DecimalType(10, 2))(), // DecimalVector + AttributeReference("ymi_col", YearMonthIntervalType())(), // IntervalYearVector (months) + AttributeReference("dti_col", DayTimeIntervalType())(), // DurationVector (microseconds) + AttributeReference("time_col", TimeType(6))() // TimeNanoVector (nanoseconds) + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val boolVector = root.getVector("bool_col").asInstanceOf[BitVector] + val byteVector = root.getVector("byte_col").asInstanceOf[TinyIntVector] + val shortVector = root.getVector("short_col").asInstanceOf[SmallIntVector] + val floatVector = root.getVector("float_col").asInstanceOf[Float4Vector] + val doubleVector = root.getVector("double_col").asInstanceOf[Float8Vector] + val dateVector = root.getVector("date_col").asInstanceOf[DateDayVector] + val tsVector = root.getVector("ts_col").asInstanceOf[TimeStampMicroTZVector] + val tsNtzVector = root.getVector("ts_ntz_col").asInstanceOf[TimeStampMicroVector] + val intVector = root.getVector("int_col").asInstanceOf[IntVector] + val longVector = root.getVector("long_col").asInstanceOf[BigIntVector] + val decimalVector = root.getVector("decimal_col").asInstanceOf[DecimalVector] + val ymiVector = root.getVector("ymi_col") + .asInstanceOf[org.apache.arrow.vector.IntervalYearVector] + val dtiVector = root.getVector("dti_col") + .asInstanceOf[org.apache.arrow.vector.DurationVector] + val timeVector = root.getVector("time_col").asInstanceOf[TimeNanoVector] + + // Row 0: low values + boolVector.setSafe(0, 0) // false + byteVector.setSafe(0, 1.toByte) + shortVector.setSafe(0, 100.toShort) + floatVector.setSafe(0, 1.0f) + doubleVector.setSafe(0, 1.0) + dateVector.setSafe(0, 18262) // 2020-01-01 + tsVector.setSafe(0, 1577836800000000L) // 2020-01-01 00:00:00 UTC in microseconds + tsNtzVector.setSafe(0, 1577836800000000L) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + decimalVector.setSafe(0, new java.math.BigDecimal("1.23")) + ymiVector.setSafe(0, 12) // 1 year = 12 months + dtiVector.setSafe(0, 86400000000L) // 1 day in microseconds + timeVector.setSafe(0, 28800000000000L) // 08:00:00 in nanoseconds + + // Row 1: mid values -- Float/Double use NaN to verify NaN is excluded from min/max + boolVector.setSafe(1, 1) // true -- becomes the max + byteVector.setSafe(1, 5.toByte) + shortVector.setSafe(1, 500.toShort) + floatVector.setSafe(1, Float.NaN) // NaN: must not affect lower=1.0f or upper=10.0f + doubleVector.setSafe(1, Double.NaN) // NaN: must not affect lower=1.0 or upper=10.0 + dateVector.setSafe(1, 19000) + tsVector.setSafe(1, 1700000000000000L) + tsNtzVector.setSafe(1, 1700000000000000L) + intVector.setSafe(1, 5) + longVector.setSafe(1, 5L) + decimalVector.setSafe(1, new java.math.BigDecimal("5.55")) + ymiVector.setSafe(1, 18) // 1.5 years = 18 months + dtiVector.setSafe(1, 172800000000L) // 2 days in microseconds + timeVector.setSafe(1, 43200000000000L) // 12:00:00 in nanoseconds + + // Row 2: high values + boolVector.setSafe(2, 0) // false again (3 rows; bool max stays true from row 1) + byteVector.setSafe(2, 10.toByte) + shortVector.setSafe(2, 1000.toShort) + floatVector.setSafe(2, 10.0f) + doubleVector.setSafe(2, 10.0) + dateVector.setSafe(2, 20000) + tsVector.setSafe(2, 1800000000000000L) + tsNtzVector.setSafe(2, 1800000000000000L) + intVector.setSafe(2, 10) + longVector.setSafe(2, 10L) + decimalVector.setSafe(2, new java.math.BigDecimal("9.87")) + ymiVector.setSafe(2, 24) // 2 years = 24 months + dtiVector.setSafe(2, 259200000000L) // 3 days in microseconds + timeVector.setSafe(2, 72000000000000L) // 20:00:00 in nanoseconds + + root.setRowCount(3) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // Stats layout: [lower(0), upper(1), nullCount(2), rowCount(3), sizeInBytes(4)] per column. + // col0 BooleanType (offset 0): lower=false, upper=true + assert(!stats.getBoolean(0), s"BooleanType lower expected false, got ${stats.getBoolean(0)}") + assert(stats.getBoolean(1), s"BooleanType upper expected true, got ${stats.getBoolean(1)}") + + // col1 ByteType (offset 5): lower=1, upper=10 + assert(stats.getByte(5) == 1.toByte, s"ByteType lower=${stats.getByte(5)}") + assert(stats.getByte(6) == 10.toByte, s"ByteType upper=${stats.getByte(6)}") + + // col2 ShortType (offset 10): lower=100, upper=1000 + assert(stats.getShort(10) == 100.toShort, s"ShortType lower=${stats.getShort(10)}") + assert(stats.getShort(11) == 1000.toShort, s"ShortType upper=${stats.getShort(11)}") + + // col3 FloatType (offset 15): lower=1.0f, upper=10.0f + assert(stats.getFloat(15) == 1.0f, s"FloatType lower=${stats.getFloat(15)}") + assert(stats.getFloat(16) == 10.0f, s"FloatType upper=${stats.getFloat(16)}") + + // col4 DoubleType (offset 20): lower=1.0, upper=10.0 + assert(stats.getDouble(20) == 1.0, s"DoubleType lower=${stats.getDouble(20)}") + assert(stats.getDouble(21) == 10.0, s"DoubleType upper=${stats.getDouble(21)}") + + // col5 DateType (offset 25): lower=18262 (2020-01-01), upper=20000 + assert(stats.getInt(25) == 18262, s"DateType lower=${stats.getInt(25)}") + assert(stats.getInt(26) == 20000, s"DateType upper=${stats.getInt(26)}") + + // col6 TimestampType (offset 30): lower < upper (microseconds since epoch) + assert(stats.getLong(30) == 1577836800000000L, + s"TimestampType lower=${stats.getLong(30)}") + assert(stats.getLong(31) == 1800000000000000L, + s"TimestampType upper=${stats.getLong(31)}") + + // col7 TimestampNTZType (offset 35): lower < upper (microseconds, no timezone) + assert(stats.getLong(35) == 1577836800000000L, + s"TimestampNTZType lower=${stats.getLong(35)}") + assert(stats.getLong(36) == 1800000000000000L, + s"TimestampNTZType upper=${stats.getLong(36)}") + + // col8 IntegerType (offset 40): lower=1, upper=10 + assert(stats.getInt(40) == 1, s"IntegerType lower=${stats.getInt(40)}") + assert(stats.getInt(41) == 10, s"IntegerType upper=${stats.getInt(41)}") + + // col9 LongType (offset 45): lower=1L, upper=10L + assert(stats.getLong(45) == 1L, s"LongType lower=${stats.getLong(45)}") + assert(stats.getLong(46) == 10L, s"LongType upper=${stats.getLong(46)}") + + // col10 DecimalType(10,2) (offset 50): lower=1.23, upper=9.87 + assert(stats.getDecimal(50, 10, 2).toJavaBigDecimal.compareTo( + new java.math.BigDecimal("1.23")) == 0, + s"DecimalType lower=${stats.getDecimal(50, 10, 2)}") + assert(stats.getDecimal(51, 10, 2).toJavaBigDecimal.compareTo( + new java.math.BigDecimal("9.87")) == 0, + s"DecimalType upper=${stats.getDecimal(51, 10, 2)}") + + // col11 YearMonthIntervalType (offset 55): lower=12 months (1yr), upper=24 months (2yr) + assert(stats.getInt(55) == 12, s"YearMonthIntervalType lower=${stats.getInt(55)}") + assert(stats.getInt(56) == 24, s"YearMonthIntervalType upper=${stats.getInt(56)}") + + // col12 DayTimeIntervalType (offset 60): lower=1 day, upper=3 days (in microseconds) + assert(stats.getLong(60) == 86400000000L, + s"DayTimeIntervalType lower=${stats.getLong(60)}") + assert(stats.getLong(61) == 259200000000L, + s"DayTimeIntervalType upper=${stats.getLong(61)}") + + // col13 TimeType (offset 65): lower=08:00:00 (28800000000000ns), + // upper=20:00:00 (72000000000000ns) + assert(stats.getLong(65) == 28800000000000L, + s"TimeType lower=${stats.getLong(65)}") + assert(stats.getLong(66) == 72000000000000L, + s"TimeType upper=${stats.getLong(66)}") + + // All null counts should be 0 + (0 until 14).foreach { col => + assert(stats.getInt(col * 5 + 2) == 0, s"nullCount for col$col should be 0") + } + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + test("collectStatistics produces correct min/max bounds for StringType") { + // StringType in Arrow is stored as VarCharVector (raw UTF-8 bytes). This test covers the + // two distinct code paths in calculateMinMaxString: binary (UTF8_BINARY) and collation-aware + // semantic (collated). The collated case directly exercises the Bug 2 fix: before the fix, + // `case StringType =>` (singleton) did not match collated types so they returned null bounds. + + // UTF8_BINARY: binary-order comparison. + // {"apple", "cherry", "banana"} -> lower=apple, upper=cherry + { + val schema = Seq(AttributeReference("str_col", StringType)()) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + try { + root.allocateNew() + val strVector = root.getVector("str_col").asInstanceOf[VarCharVector] + strVector.setSafe(0, "apple".getBytes("UTF-8"), 0, 5) + strVector.setSafe(1, "cherry".getBytes("UTF-8"), 0, 6) + strVector.setSafe(2, "banana".getBytes("UTF-8"), 0, 6) + root.setRowCount(3) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + assert(!stats.isNullAt(0), "UTF8_BINARY lower bound should not be null") + assert(!stats.isNullAt(1), "UTF8_BINARY upper bound should not be null") + assert(stats.getUTF8String(0).toString == "apple", + s"UTF8_BINARY lower expected 'apple', got ${stats.getUTF8String(0)}") + assert(stats.getUTF8String(1).toString == "cherry", + s"UTF8_BINARY upper expected 'cherry', got ${stats.getUTF8String(1)}") + root.close() + } catch { + case e: Exception => root.close(); throw e + } + } + + // UTF8_LCASE (collationId=1): case-insensitive semantic comparison. + // Data: {"Apple", "banana", "Cherry"} + // Binary order: "Apple"(A=65) < "Cherry"(C=67) < "banana"(b=98) -> binary max = "banana" + // Semantic order: apple < banana < cherry -> semantic max = "Cherry" + // Asserting upper == "Cherry" (not "banana") verifies collation-aware semanticCompare is used. + { + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val schema = Seq(AttributeReference("str_col", collatedStringType)()) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + try { + root.allocateNew() + val strVector = root.getVector("str_col").asInstanceOf[VarCharVector] + strVector.setSafe(0, "Apple".getBytes("UTF-8"), 0, 5) + strVector.setSafe(1, "banana".getBytes("UTF-8"), 0, 6) + strVector.setSafe(2, "Cherry".getBytes("UTF-8"), 0, 6) + root.setRowCount(3) + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + assert(!stats.isNullAt(0), "UTF8_LCASE lower bound should not be null") + assert(!stats.isNullAt(1), "UTF8_LCASE upper bound should not be null") + assert(stats.getUTF8String(0).toString == "Apple", + s"UTF8_LCASE lower expected 'Apple' (semantic min), got ${stats.getUTF8String(0)}") + // "Cherry" is the semantic max (case-insensitively: cherry > banana > apple). + // "banana" would be the binary max -- asserting "Cherry" proves semanticCompare is used. + assert(stats.getUTF8String(1).toString == "Cherry", + s"UTF8_LCASE upper expected 'Cherry' (semantic max), got ${stats.getUTF8String(1)}") + root.close() + } catch { + case e: Exception => root.close(); throw e + } + } + } + + test("collectStatistics returns null bounds when all Float/Double values are NaN") { + // When every non-null value in a Float or Double column is NaN, calculateMinMaxFloat/Double + // finds no valid (non-NaN) values. hasValue stays false -> returns (null, null) -> null bounds. + // Null bounds disable partition pruning, ensuring NaN-only batches are never incorrectly + // pruned. + val schema = Seq( + AttributeReference("float_col", FloatType)(), + AttributeReference("double_col", DoubleType)() + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val floatVector = root.getVector("float_col").asInstanceOf[Float4Vector] + val doubleVector = root.getVector("double_col").asInstanceOf[Float8Vector] + + floatVector.setSafe(0, Float.NaN) + floatVector.setSafe(1, Float.NaN) + doubleVector.setSafe(0, Double.NaN) + doubleVector.setSafe(1, Double.NaN) + root.setRowCount(2) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // FloatType (col0, offset 0): no valid values -> null bounds + assert(stats.isNullAt(0), "FloatType lower bound should be null when all values are NaN") + assert(stats.isNullAt(1), "FloatType upper bound should be null when all values are NaN") + + // DoubleType (col1, offset 5): no valid values -> null bounds + assert(stats.isNullAt(5), "DoubleType lower bound should be null when all values are NaN") + assert(stats.isNullAt(6), "DoubleType upper bound should be null when all values are NaN") + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + test("collectStatistics returns null bounds for non-orderable types") { + // BinaryType has no natural ordering, so its lower and upper bounds must be null. + // Null bounds disable partition pruning for those columns, preventing incorrect data exclusion. + // A control IntegerType column confirms bounds are per-type, not per-batch. + val schema = Seq( + AttributeReference("bin_col", BinaryType)(), // VarBinaryVector -- unordered + AttributeReference("int_col", IntegerType)() // IntVector -- orderable (control column) + ) + val sparkSchema = StructType(schema.map(a => StructField(a.name, a.dataType))) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + + try { + root.allocateNew() + val binVector = root.getVector("bin_col").asInstanceOf[VarBinaryVector] + val intVector = root.getVector("int_col").asInstanceOf[IntVector] + + binVector.setSafe(0, "hello".getBytes("UTF-8")) + binVector.setSafe(1, "world".getBytes("UTF-8")) + intVector.setSafe(0, 1) + intVector.setSafe(1, 10) + root.setRowCount(2) + + val stats = ArrowCachedBatchSerializer.collectStatistics(root, schema) + + // BinaryType (col0, offset 0): both bounds must be null -- no ordering defined + assert(stats.isNullAt(0), "BinaryType lower bound should be null") + assert(stats.isNullAt(1), "BinaryType upper bound should be null") + assert(stats.getInt(2) == 0, "BinaryType null count should be 0") + assert(stats.getInt(3) == 2, "BinaryType row count should be 2") + + // IntegerType (col1, offset 5): bounds should be non-null and correct + assert(!stats.isNullAt(5), "IntegerType lower bound should not be null") + assert(!stats.isNullAt(6), "IntegerType upper bound should not be null") + assert(stats.getInt(5) == 1, s"IntegerType lower=${stats.getInt(5)}") + assert(stats.getInt(6) == 10, s"IntegerType upper=${stats.getInt(6)}") + + root.close() + } catch { + case e: Exception => + root.close() + throw e + } + } + + // ------------------------------------------------------------------------- + // Collated string bug fixes + // ------------------------------------------------------------------------- + + test("caching collated string columns does not throw UnsupportedOperationException") { + // Bug: readValueFromVector used `case StringType =>` (singleton match) which only matches + // UTF8_BINARY. Collated StringType instances (e.g. UTF8_LCASE, UNICODE) are separate class + // instances and fell through to `case other => throw UnsupportedOperationException(...)`. + // Fix: use `case _: StringType =>` to match all string type instances. + Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach { collation => + withTable("tbl") { + sql(s"CACHE TABLE tbl AS SELECT col FROM VALUES " + + s"('hello' COLLATE $collation), ('world' COLLATE $collation) AS t(col)") + checkAnswer( + sql("SELECT col FROM tbl"), + Seq(Row("hello"), Row("world"))) + } + } + } + + test("caching collated string columns with null values reads correctly") { + // Verify that null collated string values are also handled correctly in readValueFromVector. + withTable("tbl") { + sql("CACHE TABLE tbl AS SELECT col FROM VALUES " + + "('a' COLLATE UTF8_LCASE), (null), ('B' COLLATE UTF8_LCASE) AS t(col)") + checkAnswer( + sql("SELECT col FROM tbl"), + Seq(Row("a"), Row(null), Row("B"))) + } + } + + test("filter on cached collated column uses correct semantic stats for partition pruning") { + // Bug: collectStatistics used `case StringType =>` (singleton), so collated string columns + // got null min/max stats. When InMemoryTableScanExec evaluated the partition filter + // (e.g. col = 'a') against null bounds, SQL null was coerced to false and the batch was + // incorrectly pruned, causing queries to return empty results even when matching rows exist. + // Fix: use `case st: StringType =>` and pass st.collationId to calculateMinMaxString so + // stats are computed with collation-aware semanticCompare, matching + // DefaultCachedBatchSerializer. + withTable("tbl") { + // Cache the table so InMemoryTableScanExec is used with partition-filter pushdown. + sql("CACHE TABLE tbl AS SELECT col FROM VALUES " + + "('a' COLLATE UTF8_LCASE), ('B' COLLATE UTF8_LCASE), ('c' COLLATE UTF8_LCASE) AS t(col)") + + // 'a' is in the table; with null stats (before fix) the batch would be incorrectly pruned. + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Seq(Row("a"))) + // 'B' is in the table; UTF8_LCASE: 'b' == 'B', so this matches 'B'. + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'B'"), Seq(Row("B"))) + // 'z' is not in the table; result should be empty (not incorrectly pruned to empty). + checkAnswer(sql("SELECT col FROM tbl WHERE col = 'z'"), Seq.empty) + } + } + + test("row path stats for collated strings use collation-aware semantic comparison") { + // Bug: createColumnStats used `case StringType =>` (singleton), so collated string columns + // got StringColumnStats(StringType) -- i.e., the wrong collation ID (UTF8_BINARY=0) -- instead + // of StringColumnStats(collatedType). Since StringColumnStats uses semanticCompare(collationId) + // for ordering, passing the wrong collation ID produced binary-order stats for collated + // columns, + // which could incorrectly prune batches for case-insensitive or locale-sensitive collations. + // Fix: use `case st: StringType => new StringColumnStats(st)`. + // + // Test: cache {"Apple", "banana", "Cherry"} with UTF8_LCASE. + // Binary order: "Apple" < "Cherry" < "banana" (uppercase < lowercase in ASCII). + // Semantic (case-insensitive) order: "Apple" < "banana" < "Cherry". + // So semantic lower="Apple", upper="Cherry"; binary lower="Apple", upper="banana". + // A filter WHERE col = 'cherry' should match "Cherry" semantically but not return empty. + val collatedStringType = new StringType(1) // collationId 1 = UTF8_LCASE + val schema = StructType(Seq(StructField("v", collatedStringType, nullable = true))) + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("Apple"), Row("banana"), Row("Cherry")), 1), + schema).cache() + val stats = cachedStats(df) + // With correct semantic stats: lower="Apple", upper="Cherry" (case-insensitive order) + // "apple" <= "Apple" <= "banana" <= "Cherry" <= "cherry" semantically. + assert(!stats.isNullAt(0), "lower bound should not be null for collated StringType") + assert(!stats.isNullAt(1), "upper bound should not be null for collated StringType") + assert(stats.getUTF8String(0).toString == "Apple") // semantic min (case-insensitive) + assert(stats.getUTF8String(1).toString == "Cherry") // semantic max (case-insensitive) + df.unpersist() + } +} + +/** + * Tests that ArrowCachedBatch and ArrowCachedBatchSerializer are registered in KryoSerializer. + * Without the registration, persisting with DISK_ONLY storage level would fail when + * spark.kryo.registrationRequired=true because Kryo rejects unregistered classes. + */ +class ArrowCachedBatchKryoRegistrationSuite extends QueryTest with SharedSparkSession { + + override def sparkConf: SparkConf = super.sparkConf + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, classOf[ArrowCachedBatchSerializer].getName) + .set("spark.kryo.registrationRequired", "true") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + override def beforeAll(): Unit = { + super.beforeAll() + InMemoryRelation.clearSerializer() + } + + override def afterAll(): Unit = { + InMemoryRelation.clearSerializer() + super.afterAll() + } + + test("ArrowCachedBatch and ArrowCachedBatchSerializer are registered in KryoSerializer") { + withTable("t1") { + sql("CREATE TABLE t1 AS SELECT 1 AS a") + checkAnswer(sql("SELECT * FROM t1").persist(StorageLevel.DISK_ONLY), Seq(Row(1))) + } + } +}