diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java
index c197a4fd6f..320250f25d 100644
--- a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java
+++ b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java
@@ -20,7 +20,6 @@
 
 import org.apache.parquet.bytes.ByteBufferAllocator;
 import org.apache.parquet.bytes.BytesInput;
-import org.apache.parquet.bytes.BytesUtils;
 import org.apache.parquet.bytes.CapacityByteArrayOutputStream;
 import org.apache.parquet.column.Encoding;
 import org.apache.parquet.column.values.ValuesWriter;
@@ -29,10 +28,24 @@
 
 public abstract class ByteStreamSplitValuesWriter extends ValuesWriter {
 
+  /**
+   * Batch size for buffered scatter writes. Values are accumulated in a batch buffer
+   * and flushed as bulk {@code write(byte[], off, len)} calls to each stream, replacing
+   * N individual single-byte writes with one bulk write per stream per flush.
+   */
+  private static final int BATCH_SIZE = 128;
+
   protected final int numStreams;
   protected final int elementSizeInBytes;
   private final CapacityByteArrayOutputStream[] byteStreams;
 
+  // Batch buffers for int (4-byte) and long (8-byte) scatter writes.
+  // Only one of these is ever non-null per instance.
+  private int[] intBatch;
+  private long[] longBatch;
+  private byte[] scatterBuf;
+  private int batchCount;
+
   public ByteStreamSplitValuesWriter(
       int elementSizeInBytes, int initialCapacity, int pageSize, ByteBufferAllocator allocator) {
     if (elementSizeInBytes <= 0) {
@@ -53,7 +66,8 @@ public ByteStreamSplitValuesWriter(
 
   @Override
   public long getBufferedSize() {
-    long totalSize = 0;
+    // Include unflushed batch values without triggering a flush
+    long totalSize = (long) batchCount * elementSizeInBytes;
     for (CapacityByteArrayOutputStream stream : this.byteStreams) {
       totalSize += stream.size();
     }
@@ -62,6 +76,7 @@ public long getBufferedSize() {
 
   @Override
   public BytesInput getBytes() {
+    flushBatch();
     BytesInput[] allInputs = new BytesInput[this.numStreams];
     for (int i = 0; i < this.numStreams; ++i) {
       allInputs[i] = BytesInput.from(this.byteStreams[i]);
@@ -76,6 +91,7 @@ public Encoding getEncoding() {
 
   @Override
   public void reset() {
+    batchCount = 0;
     for (CapacityByteArrayOutputStream stream : this.byteStreams) {
       stream.reset();
     }
@@ -83,6 +99,7 @@ public void reset() {
 
   @Override
   public void close() {
+    batchCount = 0;
     for (CapacityByteArrayOutputStream stream : byteStreams) {
       stream.close();
     }
@@ -99,6 +116,71 @@ protected void scatterBytes(byte[] bytes) {
     }
   }
 
+  /**
+   * Buffer a 4-byte integer value for batched scatter to the byte streams.
+   * Values are accumulated until the batch is full, then flushed as bulk
+   * {@code write(byte[], off, len)} calls — one per stream.
+   */
+  protected void bufferInt(int v) {
+    if (intBatch == null) {
+      intBatch = new int[BATCH_SIZE];
+      scatterBuf = new byte[BATCH_SIZE];
+    }
+    intBatch[batchCount++] = v;
+    if (batchCount == BATCH_SIZE) {
+      flushIntBatch();
+    }
+  }
+
+  /**
+   * Buffer an 8-byte long value for batched scatter to the byte streams.
+   */
+  protected void bufferLong(long v) {
+    if (longBatch == null) {
+      longBatch = new long[BATCH_SIZE];
+      scatterBuf = new byte[BATCH_SIZE];
+    }
+    longBatch[batchCount++] = v;
+    if (batchCount == BATCH_SIZE) {
+      flushLongBatch();
+    }
+  }
+
+  private void flushBatch() {
+    if (batchCount == 0) return;
+    if (intBatch != null) {
+      flushIntBatch();
+    } else if (longBatch != null) {
+      flushLongBatch();
+    }
+  }
+
+  private void flushIntBatch() {
+    if (batchCount == 0) return;
+    final int count = batchCount;
+    for (int stream = 0; stream < 4; stream++) {
+      final int shift = stream << 3; // stream * 8
+      for (int i = 0; i < count; i++) {
+        scatterBuf[i] = (byte) (intBatch[i] >>> shift);
+      }
+      byteStreams[stream].write(scatterBuf, 0, count);
+    }
+    batchCount = 0;
+  }
+
+  private void flushLongBatch() {
+    if (batchCount == 0) return;
+    final int count = batchCount;
+    for (int stream = 0; stream < 8; stream++) {
+      final int shift = stream << 3; // stream * 8
+      for (int i = 0; i < count; i++) {
+        scatterBuf[i] = (byte) (longBatch[i] >>> shift);
+      }
+      byteStreams[stream].write(scatterBuf, 0, count);
+    }
+    batchCount = 0;
+  }
+
   @Override
   public long getAllocatedSize() {
     long totalCapacity = 0;
@@ -116,7 +198,7 @@ public FloatByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteB
 
   @Override
   public void writeFloat(float v) {
-    super.scatterBytes(BytesUtils.intToBytes(Float.floatToIntBits(v)));
+    bufferInt(Float.floatToIntBits(v));
   }
 
   @Override
@@ -133,7 +215,7 @@ public DoubleByteStreamSplitValuesWriter(int initialCapacity, int pageSize, Byte
 
   @Override
   public void writeDouble(double v) {
-    super.scatterBytes(BytesUtils.longToBytes(Double.doubleToLongBits(v)));
+    bufferLong(Double.doubleToLongBits(v));
   }
 
   @Override
@@ -149,7 +231,7 @@ public IntegerByteStreamSplitValuesWriter(int initialCapacity, int pageSize, Byt
 
   @Override
   public void writeInteger(int v) {
-    super.scatterBytes(BytesUtils.intToBytes(v));
+    bufferInt(v);
   }
 
   @Override
@@ -165,7 +247,7 @@ public LongByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteBu
 
   @Override
   public void writeLong(long v) {
-    super.scatterBytes(BytesUtils.longToBytes(v));
+    bufferLong(v);
   }
 
   @Override