Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import org.apache.parquet.bytes.ByteBufferAllocator;
import org.apache.parquet.bytes.BytesInput;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.bytes.CapacityByteArrayOutputStream;
import org.apache.parquet.column.Encoding;
import org.apache.parquet.column.values.ValuesWriter;
Expand All @@ -29,10 +28,24 @@

public abstract class ByteStreamSplitValuesWriter extends ValuesWriter {

/**
 * Batch size for buffered scatter writes. Values are accumulated in a batch buffer
 * and flushed as bulk {@code write(byte[], off, len)} calls to each stream, replacing
 * N individual single-byte writes with one bulk write per stream per flush.
 * This is also the length of every batch array and of the scratch buffer.
 */
private static final int BATCH_SIZE = 128;

// Number of entries of byteStreams that getBytes() reads when assembling the output.
protected final int numStreams;
// Width of one value in bytes; also used by getBufferedSize() to account for
// values that are still sitting in the batch.
protected final int elementSizeInBytes;
// One output stream per byte position: byte k of every value goes to byteStreams[k].
private final CapacityByteArrayOutputStream[] byteStreams;

// Batch buffers for int (4-byte) and long (8-byte) scatter writes.
// Each is allocated lazily on the first bufferInt/bufferLong call; an instance is
// expected to use only one of the two, because both share batchCount.
private int[] intBatch;
private long[] longBatch;
// Scratch buffer reused during a flush: holds the bytes of one byte position for
// all batched values before they are written to the matching stream in one call.
private byte[] scatterBuf;
// Number of values currently held in the active batch (0 when nothing is pending).
private int batchCount;

public ByteStreamSplitValuesWriter(
int elementSizeInBytes, int initialCapacity, int pageSize, ByteBufferAllocator allocator) {
if (elementSizeInBytes <= 0) {
Expand All @@ -53,7 +66,8 @@ public ByteStreamSplitValuesWriter(

@Override
public long getBufferedSize() {
long totalSize = 0;
// Include unflushed batch values without triggering a flush
long totalSize = (long) batchCount * elementSizeInBytes;
for (CapacityByteArrayOutputStream stream : this.byteStreams) {
totalSize += stream.size();
}
Expand All @@ -62,6 +76,7 @@ public long getBufferedSize() {

@Override
public BytesInput getBytes() {
flushBatch();
BytesInput[] allInputs = new BytesInput[this.numStreams];
for (int i = 0; i < this.numStreams; ++i) {
allInputs[i] = BytesInput.from(this.byteStreams[i]);
Expand All @@ -76,13 +91,15 @@ public Encoding getEncoding() {

/**
 * Clears the writer for reuse: values still waiting in the batch are dropped
 * (not flushed) and every byte stream is reset.
 */
@Override
public void reset() {
  // Forget pending batched values first; they must not leak into the next page.
  this.batchCount = 0;
  for (int idx = 0; idx < this.byteStreams.length; idx++) {
    this.byteStreams[idx].reset();
  }
}

@Override
public void close() {
batchCount = 0;
for (CapacityByteArrayOutputStream stream : byteStreams) {
stream.close();
}
Expand All @@ -99,6 +116,71 @@ protected void scatterBytes(byte[] bytes) {
}
}

/**
 * Queues one 4-byte value for a later batched scatter into the byte streams.
 *
 * <p>The int batch and the shared scratch buffer are created on first use. Once
 * {@code BATCH_SIZE} values have been queued, the batch is flushed immediately.
 *
 * @param v the value to queue, already converted to its raw 32-bit pattern
 */
protected void bufferInt(int v) {
  int[] pending = intBatch;
  if (pending == null) {
    // Lazy allocation: writers that never see an int pay nothing for the batch.
    pending = new int[BATCH_SIZE];
    intBatch = pending;
    scatterBuf = new byte[BATCH_SIZE];
  }
  pending[batchCount++] = v;
  if (batchCount != BATCH_SIZE) {
    return;
  }
  // Batch is full: push it out so the next value starts at slot 0.
  flushIntBatch();
}

/**
 * Queues one 8-byte value for a later batched scatter into the byte streams.
 *
 * <p>The long batch and the shared scratch buffer are created on first use. Once
 * {@code BATCH_SIZE} values have been queued, the batch is flushed immediately.
 *
 * @param v the value to queue, already converted to its raw 64-bit pattern
 */
protected void bufferLong(long v) {
  long[] pending = longBatch;
  if (pending == null) {
    // Lazy allocation: writers that never see a long pay nothing for the batch.
    pending = new long[BATCH_SIZE];
    longBatch = pending;
    scatterBuf = new byte[BATCH_SIZE];
  }
  pending[batchCount++] = v;
  if (batchCount != BATCH_SIZE) {
    return;
  }
  // Batch is full: push it out so the next value starts at slot 0.
  flushLongBatch();
}

/**
 * Writes any pending batched values to the byte streams, choosing the int or
 * long flush according to which batch array has been allocated. Does nothing
 * when no values are pending.
 */
private void flushBatch() {
  if (batchCount == 0) {
    return;
  }
  if (intBatch != null) {
    flushIntBatch();
    return;
  }
  if (longBatch != null) {
    flushLongBatch();
  }
}

/**
 * Scatters the pending 4-byte values into the first four byte streams and
 * empties the batch.
 *
 * <p>For each byte position (least-significant byte first, going to stream 0),
 * that byte of every pending value is gathered into the scratch buffer and
 * written to the matching stream with a single bulk write.
 */
private void flushIntBatch() {
  final int pending = batchCount;
  if (pending == 0) {
    return;
  }
  final int[] values = intBatch;
  final byte[] lane = scatterBuf;
  int shift = 0;
  for (int position = 0; position < Integer.BYTES; position++, shift += Byte.SIZE) {
    // Narrowing to byte keeps only the low 8 bits of the shifted value.
    for (int k = 0; k < pending; k++) {
      lane[k] = (byte) (values[k] >>> shift);
    }
    byteStreams[position].write(lane, 0, pending);
  }
  batchCount = 0;
}

/**
 * Scatters the pending 8-byte values into the first eight byte streams and
 * empties the batch.
 *
 * <p>For each byte position (least-significant byte first, going to stream 0),
 * that byte of every pending value is gathered into the scratch buffer and
 * written to the matching stream with a single bulk write.
 */
private void flushLongBatch() {
  final int pending = batchCount;
  if (pending == 0) {
    return;
  }
  final long[] values = longBatch;
  final byte[] lane = scatterBuf;
  int shift = 0;
  for (int position = 0; position < Long.BYTES; position++, shift += Byte.SIZE) {
    // Narrowing to byte keeps only the low 8 bits of the shifted value.
    for (int k = 0; k < pending; k++) {
      lane[k] = (byte) (values[k] >>> shift);
    }
    byteStreams[position].write(lane, 0, pending);
  }
  batchCount = 0;
}

@Override
public long getAllocatedSize() {
long totalCapacity = 0;
Expand All @@ -116,7 +198,7 @@ public FloatByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteB

/**
 * Writes one float by queuing its raw IEEE 754 bit pattern for batched scatter.
 *
 * @param v the value to write
 */
@Override
public void writeFloat(float v) {
  // Use only the batched path. Also calling scatterBytes would emit the value a
  // second time, ahead of values that are still waiting in the batch.
  bufferInt(Float.floatToIntBits(v));
}

@Override
Expand All @@ -133,7 +215,7 @@ public DoubleByteStreamSplitValuesWriter(int initialCapacity, int pageSize, Byte

/**
 * Writes one double by queuing its raw IEEE 754 bit pattern for batched scatter.
 *
 * @param v the value to write
 */
@Override
public void writeDouble(double v) {
  // Use only the batched path. Also calling scatterBytes would emit the value a
  // second time, ahead of values that are still waiting in the batch.
  bufferLong(Double.doubleToLongBits(v));
}

@Override
Expand All @@ -149,7 +231,7 @@ public IntegerByteStreamSplitValuesWriter(int initialCapacity, int pageSize, Byt

/**
 * Writes one 32-bit integer by queuing it for batched scatter.
 *
 * @param v the value to write
 */
@Override
public void writeInteger(int v) {
  // Use only the batched path. Also calling scatterBytes would emit the value a
  // second time, ahead of values that are still waiting in the batch.
  bufferInt(v);
}

@Override
Expand All @@ -165,7 +247,7 @@ public LongByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteBu

/**
 * Writes one 64-bit integer by queuing it for batched scatter.
 *
 * @param v the value to write
 */
@Override
public void writeLong(long v) {
  // Use only the batched path. Also calling scatterBytes would emit the value a
  // second time, ahead of values that are still waiting in the batch.
  bufferLong(v);
}

@Override
Expand Down