From 43589bda2c7f49c094809c4ffde47389cb4e9e8d Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Fri, 24 Oct 2025 12:32:19 +0200
Subject: [PATCH 1/6] Utilize memory allocator in ReaderProperties.GetStream

---
 internal/utils/buf_reader.go | 51 +++++++++++++++++++++++++++++++-----
 parquet/file/page_reader.go  |  6 ++++-
 parquet/reader_properties.go | 11 +++++---
 3 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/internal/utils/buf_reader.go b/internal/utils/buf_reader.go
index c222c8bd..7762f7db 100644
--- a/internal/utils/buf_reader.go
+++ b/internal/utils/buf_reader.go
@@ -22,6 +22,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+
+	"github.com/apache/arrow-go/v18/arrow/memory"
 )
 
 type Reader interface {
@@ -38,6 +40,7 @@ type byteReader struct {
 
 // NewByteReader creates a new ByteReader instance from the given byte slice.
 // It wraps the bytes.NewReader function to implement BufferedReader interface.
+// It does not own the underlying byte slice, so its Free method is a no-op.
 func NewByteReader(buf []byte) *byteReader {
 	r := bytes.NewReader(buf)
 	return &byteReader{
@@ -108,10 +111,43 @@ func (r *byteReader) Reset(Reader) {}
 
 func (r *byteReader) BufferSize() int { return len(r.buf) }
 
+func (r *byteReader) Free() {}
+
+// bytesBufferReader is a byte slice with a bytes reader wrapped around it.
+// It uses an allocator to allocate and free the underlying byte slice.
+type bytesBufferReader struct {
+	alloc memory.Allocator
+	byteReader
+}
+
+// NewBytesBufferReader creates a new bytesBufferReader with the given size and allocator.
+func NewBytesBufferReader(size int, alloc memory.Allocator) *bytesBufferReader {
+	buf := alloc.Allocate(size)
+	return &bytesBufferReader{
+		alloc: alloc,
+		byteReader: byteReader{
+			bytes.NewReader(buf),
+			buf,
+			0,
+		},
+	}
+}
+
+// Buffer returns the underlying byte slice.
+func (r *bytesBufferReader) Buffer() []byte {
+	return r.buf
+}
+
+// Free releases the underlying byte slice back to the allocator.
+func (r *bytesBufferReader) Free() {
+	r.alloc.Free(r.buf)
+}
+
 // bufferedReader is similar to bufio.Reader except
 // it will expand the buffer if necessary when asked to Peek
 // more bytes than are in the buffer
 type bufferedReader struct {
+	alloc    memory.Allocator // allocator used to allocate the buffer
 	bufferSz int
 	buf      []byte
 	r, w     int
@@ -122,9 +158,10 @@ type bufferedReader struct {
 // NewBufferedReader returns a buffered reader with similar semantics to bufio.Reader
 // except Peek will expand the internal buffer if needed rather than return
 // an error.
-func NewBufferedReader(rd Reader, sz int) *bufferedReader {
+func NewBufferedReader(rd Reader, sz int, alloc memory.Allocator) *bufferedReader {
 	r := &bufferedReader{
-		rd: rd,
+		alloc: alloc,
+		rd:    rd,
 	}
 	r.resizeBuffer(sz)
 	return r
@@ -140,11 +177,9 @@ func (b *bufferedReader) Reset(rd Reader) {
 
 func (b *bufferedReader) resetBuffer() {
 	if b.buf == nil {
-		b.buf = make([]byte, b.bufferSz)
+		b.buf = b.alloc.Allocate(b.bufferSz)
 	} else if b.bufferSz > cap(b.buf) {
-		buf := b.buf
-		b.buf = make([]byte, b.bufferSz)
-		copy(b.buf, buf)
+		b.buf = b.alloc.Reallocate(b.bufferSz, b.buf)
 	} else {
 		b.buf = b.buf[:b.bufferSz]
 	}
@@ -298,3 +333,7 @@ func (b *bufferedReader) Read(p []byte) (n int, err error) {
 	b.r += n
 	return n, nil
 }
+
+func (b *bufferedReader) Free() {
+	b.alloc.Free(b.buf)
+}
diff --git a/parquet/file/page_reader.go b/parquet/file/page_reader.go
index 1ba7ecbe..4bfd9fe3 100644
--- a/parquet/file/page_reader.go
+++ b/parquet/file/page_reader.go
@@ -383,6 +383,9 @@ func (p *serializedPageReader) Close() error {
 		p.dictPageBuffer.Release()
 		p.dataPageBuffer.Release()
 	}
+	if p.r != nil {
+		p.r.Free()
+	}
 	return nil
 }
 
@@ -550,7 +553,8 @@ func (p *serializedPageReader) GetDictionaryPage() (*DictionaryPage, error) {
 	readBufSize := min(int(p.dataOffset-p.baseOffset), p.r.BufferSize())
 	rd := utils.NewBufferedReader(
 		io.NewSectionReader(p.r.Outer(), p.dictOffset-p.baseOffset, p.dataOffset-p.baseOffset),
-		readBufSize)
+		readBufSize,
+		p.mem)
 	if err := p.readPageHeader(rd, hdr); err != nil {
 		return nil, err
 	}
diff --git a/parquet/reader_properties.go b/parquet/reader_properties.go
index 5b119dcf..11ead445 100644
--- a/parquet/reader_properties.go
+++ b/parquet/reader_properties.go
@@ -52,6 +52,7 @@ type BufferedReader interface {
 	Outer() utils.Reader
 	BufferSize() int
 	Reset(utils.Reader)
+	Free()
 	io.Reader
 }
 
@@ -74,17 +75,19 @@ func (r *ReaderProperties) Allocator() memory.Allocator { return r.alloc }
 // into a buffer in memory and return a bytes.NewReader for that buffer.
 func (r *ReaderProperties) GetStream(source io.ReaderAt, start, nbytes int64) (BufferedReader, error) {
 	if r.BufferedStreamEnabled {
-		return utils.NewBufferedReader(io.NewSectionReader(source, start, nbytes), int(r.BufferSize)), nil
+		return utils.NewBufferedReader(io.NewSectionReader(source, start, nbytes), int(r.BufferSize), r.alloc), nil
 	}
 
-	data := make([]byte, nbytes)
-	n, err := source.ReadAt(data, start)
+	buf := utils.NewBytesBufferReader(int(nbytes), r.alloc)
+	n, err := source.ReadAt(buf.Buffer(), start)
 	if err != nil {
+		buf.Free()
 		return nil, fmt.Errorf("parquet: tried reading from file, but got error: %w", err)
 	}
 	if n != int(nbytes) {
+		buf.Free()
 		return nil, fmt.Errorf("parquet: tried reading %d bytes starting at position %d from file but only got %d", nbytes, start, n)
 	}
-	return utils.NewByteReader(data), nil
+	return buf, nil
 }
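A minimal usage sketch for the allocator-aware GetStream introduced above (illustrative only, not part of the patch; `source` and `nbytes` are placeholder names, and any arrow/memory allocator can be passed to NewReaderProperties):

    props := parquet.NewReaderProperties(memory.NewCheckedAllocator(memory.DefaultAllocator))
    rdr, err := props.GetStream(source, 0, nbytes)
    if err != nil {
        return err
    }
    // The returned BufferedReader now owns an allocator-backed buffer;
    // Free returns it to the allocator (and is a no-op for non-owning readers).
    defer rdr.Free()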
From c9269043977e6d9c8df49220d66ad82f461c32e5 Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Fri, 24 Oct 2025 13:13:21 +0200
Subject: [PATCH 2/6] Ensure that decompressBuffer doesn't get reallocated by
 io.CopyN

---
 parquet/file/page_reader.go | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/parquet/file/page_reader.go b/parquet/file/page_reader.go
index 4bfd9fe3..c73dfe73 100644
--- a/parquet/file/page_reader.go
+++ b/parquet/file/page_reader.go
@@ -504,7 +504,16 @@ func (p *serializedPageReader) Page() Page {
 }
 
 func (p *serializedPageReader) decompress(rd io.Reader, lenCompressed int, buf []byte) ([]byte, error) {
-	p.decompressBuffer.ResizeNoShrink(lenCompressed)
+	// As of go1.25.3: There is an issue when bytes.Buffer and io.CopyN are used together. io.CopyN
+	// uses io.LimitReader, which does an additional read on the underlying reader to determine EOF.
+	// However, bytes.Buffer always attempts to read at least bytes.MinRead (which is 512 bytes) from the
+	// underlying reader, even if there is less data available than that. So even if there are no more bytes,
+	// the buffer must have at least bytes.MinRead capacity remaining to avoid a reallocation.
+	allocSize := lenCompressed
+	if p.decompressBuffer.Cap() < lenCompressed+bytes.MinRead {
+		allocSize = lenCompressed + bytes.MinRead
+	}
+	p.decompressBuffer.ResizeNoShrink(allocSize)
 	b := bytes.NewBuffer(p.decompressBuffer.Bytes()[:0])
 	if _, err := io.CopyN(b, rd, int64(lenCompressed)); err != nil {
 		return nil, err

From eec120a3c923a70ddb935e3000bc141bf8ef1fc7 Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Fri, 24 Oct 2025 19:33:31 +0200
Subject: [PATCH 3/6] fixup! Ensure that decompressBuffer doesn't get
 reallocated by io.CopyN

---
 parquet/file/page_reader.go | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/parquet/file/page_reader.go b/parquet/file/page_reader.go
index c73dfe73..5e35baae 100644
--- a/parquet/file/page_reader.go
+++ b/parquet/file/page_reader.go
@@ -509,11 +509,10 @@ func (p *serializedPageReader) decompress(rd io.Reader, lenCompressed int, buf [
 	// However, bytes.Buffer always attempts to read at least bytes.MinRead (which is 512 bytes) from the
 	// underlying reader, even if there is less data available than that. So even if there are no more bytes,
 	// the buffer must have at least bytes.MinRead capacity remaining to avoid a reallocation.
-	allocSize := lenCompressed
 	if p.decompressBuffer.Cap() < lenCompressed+bytes.MinRead {
-		allocSize = lenCompressed + bytes.MinRead
+		p.decompressBuffer.Reserve(lenCompressed + bytes.MinRead)
 	}
-	p.decompressBuffer.ResizeNoShrink(allocSize)
+	p.decompressBuffer.ResizeNoShrink(lenCompressed)
 	b := bytes.NewBuffer(p.decompressBuffer.Bytes()[:0])
 	if _, err := io.CopyN(b, rd, int64(lenCompressed)); err != nil {
 		return nil, err
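The interaction described in the comment above can be reproduced in isolation. A minimal standalone sketch, not part of the patch (behavior as observed on go1.25.x, per the comment):

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "strings"
    )

    func main() {
        storage := make([]byte, 0, 5) // capacity sized exactly to the 5-byte payload
        b := bytes.NewBuffer(storage)
        // io.CopyN wraps the source in an io.LimitReader; bytes.Buffer.ReadFrom then
        // grows the buffer so that at least bytes.MinRead (512) bytes are writable
        // before every read, abandoning the original 5-byte backing array.
        if _, err := io.CopyN(b, strings.NewReader("hello"), 5); err != nil {
            panic(err)
        }
        fmt.Println(b.Cap() >= bytes.MinRead) // true: the small backing array was replaced
    }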
From 8edb901956a3676f366c1a2fdaf030af4519b3bb Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Tue, 28 Oct 2025 09:13:13 +0100
Subject: [PATCH 4/6] fixup! Ensure that decompressBuffer doesn't get
 reallocated by io.CopyN

---
 internal/utils/buf_reader.go             |  6 +++
 parquet/file/page_reader.go              | 20 +++-----
 parquet/file/row_group_reader.go         |  2 +
 parquet/reader_writer_properties_test.go | 65 +++++++++++++++++++++++-
 4 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/internal/utils/buf_reader.go b/internal/utils/buf_reader.go
index 7762f7db..46a16da2 100644
--- a/internal/utils/buf_reader.go
+++ b/internal/utils/buf_reader.go
@@ -122,6 +122,9 @@ type bytesBufferReader struct {
 
 // NewBytesBufferReader creates a new bytesBufferReader with the given size and allocator.
 func NewBytesBufferReader(size int, alloc memory.Allocator) *bytesBufferReader {
+	if alloc == nil {
+		alloc = memory.DefaultAllocator
+	}
 	buf := alloc.Allocate(size)
 	return &bytesBufferReader{
 		alloc: alloc,
@@ -159,6 +162,9 @@ type bufferedReader struct {
 // except Peek will expand the internal buffer if needed rather than return
 // an error.
 func NewBufferedReader(rd Reader, sz int, alloc memory.Allocator) *bufferedReader {
+	if alloc == nil {
+		alloc = memory.DefaultAllocator
+	}
 	r := &bufferedReader{
 		alloc: alloc,
 		rd:    rd,
diff --git a/parquet/file/page_reader.go b/parquet/file/page_reader.go
index 5e35baae..89e34532 100644
--- a/parquet/file/page_reader.go
+++ b/parquet/file/page_reader.go
@@ -17,7 +17,6 @@
 package file
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 	"io"
@@ -504,21 +503,16 @@ func (p *serializedPageReader) Page() Page {
 }
 
 func (p *serializedPageReader) decompress(rd io.Reader, lenCompressed int, buf []byte) ([]byte, error) {
-	// As of go1.25.3: There is an issue when bytes.Buffer and io.CopyN are used together. io.CopyN
-	// uses io.LimitReader, which does an additional read on the underlying reader to determine EOF.
-	// However, bytes.Buffer always attempts to read at least bytes.MinRead (which is 512 bytes) from the
-	// underlying reader, even if there is less data available than that. So even if there are no more bytes,
-	// the buffer must have at least bytes.MinRead capacity remaining to avoid a reallocation.
-	if p.decompressBuffer.Cap() < lenCompressed+bytes.MinRead {
-		p.decompressBuffer.Reserve(lenCompressed + bytes.MinRead)
-	}
 	p.decompressBuffer.ResizeNoShrink(lenCompressed)
-	b := bytes.NewBuffer(p.decompressBuffer.Bytes()[:0])
-	if _, err := io.CopyN(b, rd, int64(lenCompressed)); err != nil {
+	data := p.decompressBuffer.Bytes()
+	n, err := io.ReadFull(rd, data)
+	if err != nil {
 		return nil, err
 	}
+	if n != lenCompressed {
+		return nil, fmt.Errorf("parquet: expected to read %d bytes but only read %d", lenCompressed, n)
+	}
 
-	data := p.decompressBuffer.Bytes()
 	if p.cryptoCtx.DataDecryptor != nil {
 		data = p.cryptoCtx.DataDecryptor.Decrypt(p.decompressBuffer.Bytes())
 	}
@@ -563,6 +557,7 @@ func (p *serializedPageReader) GetDictionaryPage() (*DictionaryPage, error) {
 		io.NewSectionReader(p.r.Outer(), p.dictOffset-p.baseOffset, p.dataOffset-p.baseOffset),
 		readBufSize,
 		p.mem)
+	defer rd.Free()
 	if err := p.readPageHeader(rd, hdr); err != nil {
 		return nil, err
 	}
@@ -774,6 +769,7 @@ func (p *serializedPageReader) Next() bool {
 			firstRowIdx := p.rowsSeen
 			p.rowsSeen += int64(dataHeader.GetNumValues())
 
+
 			data, err := p.decompress(p.r, lenCompressed, buf.Bytes())
 			if err != nil {
 				p.err = err
diff --git a/parquet/file/row_group_reader.go b/parquet/file/row_group_reader.go
index ea5f7098..bf75db46 100644
--- a/parquet/file/row_group_reader.go
+++ b/parquet/file/row_group_reader.go
@@ -134,11 +134,13 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) {
 		}
 
 		if r.fileDecryptor == nil {
+			stream.Free()
 			return nil, xerrors.New("column in rowgroup is encrypted, but no file decryptor")
 		}
 
 		const encryptedRowGroupsLimit = 32767
 		if i > encryptedRowGroupsLimit {
+			stream.Free()
 			return nil, xerrors.New("encrypted files cannot contain more than 32767 column chunks")
 		}
diff --git a/parquet/reader_writer_properties_test.go b/parquet/reader_writer_properties_test.go
index 00b26a83..a8e9b752 100644
--- a/parquet/reader_writer_properties_test.go
+++ b/parquet/reader_writer_properties_test.go
@@ -18,6 +18,7 @@ package parquet_test
 
 import (
 	"bytes"
+	"errors"
 	"testing"
 
 	"github.com/apache/arrow-go/v18/arrow/memory"
@@ -67,7 +68,67 @@ func TestReaderPropsGetStreamInsufficient(t *testing.T) {
 	buf := memory.NewBufferBytes([]byte(data))
 	rdr := bytes.NewReader(buf.Bytes())
 
-	props := parquet.NewReaderProperties(nil)
-	_, err := props.GetStream(rdr, 12, 15)
+	props1 := parquet.NewReaderProperties(nil)
+	_, err := props1.GetStream(rdr, 12, 15)
 	assert.Error(t, err)
 }
+
+type mockReaderAt struct{}
+
+func (m *mockReaderAt) ReadAt(p []byte, off int64) (int, error) {
+	return 0, errors.New("mock error")
+}
+
+func TestReaderPropsGetStreamWithAllocator(t *testing.T) {
+	pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
+	defer pool.AssertSize(t, 0)
+
+	data := "data to read"
+	buf := memory.NewBufferBytes([]byte(data))
+	rdr := bytes.NewReader(buf.Bytes())
+
+	// no leak on success
+	props := parquet.NewReaderProperties(pool)
+	bufRdr, err := props.GetStream(rdr, 0, int64(len(data)))
+	assert.NoError(t, err)
+	bufRdr.Free()
+
+	// no leak on reader error
+	_, err = props.GetStream(&mockReaderAt{}, 0, 10)
+	assert.Error(t, err)
+
+	// no leak on insufficient read
+	_, err = props.GetStream(rdr, 0, int64(len(data)+10))
+	assert.Error(t, err)
+}
+
+func TestReaderPropsGetStreamBufferedWithAllocator(t *testing.T) {
+	pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
+	defer pool.AssertSize(t, 0)
+
+	data := "data to read"
+	rdr := bytes.NewReader(memory.NewBufferBytes([]byte(data)).Bytes())
+
+	props := parquet.NewReaderProperties(pool)
+	props.BufferedStreamEnabled = true
+
+	buf := make([]byte, len(data))
+	bufRdr, err := props.GetStream(rdr, 0, int64(len(data)))
+	assert.NoError(t, err)
+	_, err = bufRdr.Read(buf)
+	assert.NoError(t, err)
+	bufRdr.Free()
+
+	bufRdr, err = props.GetStream(&mockReaderAt{}, 0, 10)
+	assert.NoError(t, err)
+	_, err = bufRdr.Read(buf)
+	assert.Error(t, err)
+	bufRdr.Free()
+
+	bufRdr, err = props.GetStream(rdr, 0, int64(len(data)+10))
+	assert.NoError(t, err)
+	n, err := bufRdr.Read(buf)
+	assert.NoError(t, err)
+	assert.NotEqual(t, len(data)+10, n)
+	bufRdr.Free()
+}
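For contrast with the sketch above: the io.ReadFull call that replaces io.CopyN never grows its destination. It fills exactly len(buf) bytes or returns an error (io.ErrUnexpectedEOF on a short read, io.EOF if nothing was read), so the preallocated decompressBuffer is reused as-is. A tiny standalone illustration (not part of the patch):

    dst := make([]byte, 5)
    n, err := io.ReadFull(strings.NewReader("hello, world"), dst)
    // n == 5, err == nil; dst's backing array is never reallocated or grown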
From c7bee5efdde478f7b8b021da299397d218fe68b5 Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Tue, 28 Oct 2025 13:01:59 +0100
Subject: [PATCH 5/6] Read uncompressed data directly into the page buffer

---
 internal/utils/buf_reader.go |  2 +
 parquet/file/page_reader.go  | 73 ++++++++++++++++++++++++++++--------
 parquet/reader_properties.go |  2 +
 3 files changed, 61 insertions(+), 16 deletions(-)

diff --git a/internal/utils/buf_reader.go b/internal/utils/buf_reader.go
index 46a16da2..7ba904d1 100644
--- a/internal/utils/buf_reader.go
+++ b/internal/utils/buf_reader.go
@@ -111,6 +111,8 @@ func (r *byteReader) Reset(Reader) {}
 
 func (r *byteReader) BufferSize() int { return len(r.buf) }
 
+func (r *byteReader) Buffered() int { return len(r.buf) - r.pos }
+
 func (r *byteReader) Free() {}
 
 // bytesBufferReader is a byte slice with a bytes reader wrapped around it.
diff --git a/parquet/file/page_reader.go b/parquet/file/page_reader.go
index 89e34532..9d528cb1 100644
--- a/parquet/file/page_reader.go
+++ b/parquet/file/page_reader.go
@@ -374,6 +374,8 @@ type serializedPageReader struct {
 	dataPageBuffer *memory.Buffer
 	dictPageBuffer *memory.Buffer
 	err            error
+
+	isCompressed bool
 }
 
 func (p *serializedPageReader) Close() error {
@@ -402,6 +404,7 @@ func (p *serializedPageReader) init(compressType compress.Compression, ctx *Cryp
 		return err
 	}
 	p.codec = codec
+	p.isCompressed = compressType != compress.Codecs.Uncompressed
 
 	if ctx != nil {
 		p.cryptoCtx = *ctx
@@ -444,6 +447,7 @@ func NewPageReader(r parquet.BufferedReader, nrows int64, compressType compress.
 		dictPageBuffer:    memory.NewResizableBuffer(mem),
 	}
 	rdr.decompressBuffer.ResizeNoShrink(defaultPageHeaderSize)
+	rdr.isCompressed = compressType != compress.Codecs.Uncompressed
 	if ctx != nil {
 		rdr.cryptoCtx = *ctx
 		rdr.initDecryption()
@@ -460,6 +464,8 @@ func (p *serializedPageReader) Reset(r parquet.BufferedReader, nrows int64, comp
 	if p.err != nil {
 		return
 	}
+	p.isCompressed = compressType != compress.Codecs.Uncompressed
+
 	if ctx != nil {
 		p.cryptoCtx = *ctx
 		p.initDecryption()
@@ -502,6 +508,36 @@ func (p *serializedPageReader) Page() Page {
 	return p.curPage
 }
 
+func (p *serializedPageReader) stealFromBuffer(br parquet.BufferedReader, lenUncompressed int) ([]byte, error) {
+	data, err := br.Peek(lenUncompressed)
+	if err != nil {
+		return nil, err
+	}
+	if p.cryptoCtx.DataDecryptor != nil {
+		data = p.cryptoCtx.DataDecryptor.Decrypt(data)
+	}
+	// advance the reader
+	_, err = br.Discard(lenUncompressed)
+	if err != nil && err != io.EOF {
+		return nil, err
+	}
+	return data, nil
+}
+
+func (p *serializedPageReader) readUncompressed(br parquet.BufferedReader, lenUncompressed int, buf []byte) ([]byte, error) {
+	n, err := io.ReadFull(br, buf[:lenUncompressed])
+	if err != nil {
+		return nil, err
+	}
+	if n != lenUncompressed {
+		return nil, fmt.Errorf("parquet: expected to read %d bytes but only read %d", lenUncompressed, n)
+	}
+	if p.cryptoCtx.DataDecryptor != nil {
+		buf = p.cryptoCtx.DataDecryptor.Decrypt(buf)
+	}
+	return buf, nil
+}
+
 func (p *serializedPageReader) decompress(rd io.Reader, lenCompressed int, buf []byte) ([]byte, error) {
 	p.decompressBuffer.ResizeNoShrink(lenCompressed)
 	data := p.decompressBuffer.Bytes()
@@ -583,12 +619,9 @@ func (p *serializedPageReader) GetDictionaryPage() (*DictionaryPage, error) {
 		return nil, errors.New("parquet: invalid page header (negative number of values)")
 	}
 
-	p.dictPageBuffer.ResizeNoShrink(lenUncompressed)
-	buf := memory.NewBufferBytes(p.dictPageBuffer.Bytes())
-
-	data, err := p.decompress(rd, lenCompressed, buf.Bytes())
+	data, err := p.getPageBytes(rd, p.isCompressed, lenCompressed, lenUncompressed, p.dictPageBuffer)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("parquet: could not read dictionary page data: %w", err)
 	}
 	if len(data) != lenUncompressed {
 		return nil, fmt.Errorf("parquet: metadata said %d bytes uncompressed dictionary page, got %d bytes", lenUncompressed, len(data))
@@ -596,7 +629,7 @@ func (p *serializedPageReader) GetDictionaryPage() (*DictionaryPage, error) {
 
 	return &DictionaryPage{
 		page: page{
-			buf:      buf,
+			buf:      memory.NewBufferBytes(data),
 			typ:      hdr.Type,
 			nvals:    dictHeader.GetNumValues(),
 			encoding: dictHeader.GetEncoding(),
@@ -693,6 +726,20 @@ func (p *serializedPageReader) SeekToPageWithRow(rowIdx int64) error {
 	return p.err
 }
 
+func (p *serializedPageReader) getPageBytes(
+	r parquet.BufferedReader, isCompressed bool, lenCompressed, lenUncompressed int, buffer *memory.Buffer,
+) ([]byte, error) {
+	if isCompressed {
+		buffer.ResizeNoShrink(lenUncompressed)
+		return p.decompress(r, lenCompressed, buffer.Bytes())
+	}
+	if r.Buffered() >= lenCompressed {
+		return p.stealFromBuffer(r, lenCompressed)
+	}
+	buffer.ResizeNoShrink(lenUncompressed)
+	return p.readUncompressed(r, lenCompressed, buffer.Bytes())
+}
+
 func (p *serializedPageReader) Next() bool {
 	// Loop here because there may be unhandled page types that we skip until
 	// finding a page that we do know what to do with
@@ -732,10 +779,7 @@ func (p *serializedPageReader) Next() bool {
 				return false
 			}
 
-			p.dictPageBuffer.ResizeNoShrink(lenUncompressed)
-			buf := memory.NewBufferBytes(p.dictPageBuffer.Bytes())
-
-			data, err := p.decompress(p.r, lenCompressed, buf.Bytes())
+			data, err := p.getPageBytes(p.r, p.isCompressed, lenCompressed, lenUncompressed, p.dictPageBuffer)
 			if err != nil {
 				p.err = err
 				return false
@@ -748,7 +792,7 @@ func (p *serializedPageReader) Next() bool {
 			// make dictionary page
 			p.curPage = &DictionaryPage{
 				page: page{
-					buf:      buf,
+					buf:      memory.NewBufferBytes(data),
 					typ:      p.curPageHdr.Type,
 					nvals:    dictHeader.GetNumValues(),
 					encoding: dictHeader.GetEncoding(),
@@ -764,13 +808,10 @@ func (p *serializedPageReader) Next() bool {
 				return false
 			}
 
-			p.dataPageBuffer.ResizeNoShrink(lenUncompressed)
-			buf := memory.NewBufferBytes(p.dataPageBuffer.Bytes())
-
 			firstRowIdx := p.rowsSeen
 			p.rowsSeen += int64(dataHeader.GetNumValues())
 
-			data, err := p.decompress(p.r, lenCompressed, buf.Bytes())
+			data, err := p.getPageBytes(p.r, p.isCompressed, lenCompressed, lenUncompressed, p.dataPageBuffer)
 			if err != nil {
 				p.err = err
 				return false
@@ -783,7 +824,7 @@ func (p *serializedPageReader) Next() bool {
 			// make datapagev1
 			p.curPage = &DataPageV1{
 				page: page{
-					buf:      buf,
+					buf:      memory.NewBufferBytes(data),
 					typ:      p.curPageHdr.Type,
 					nvals:    dataHeader.GetNumValues(),
 					encoding: dataHeader.GetEncoding(),
diff --git a/parquet/reader_properties.go b/parquet/reader_properties.go
index 11ead445..254cf217 100644
--- a/parquet/reader_properties.go
+++ b/parquet/reader_properties.go
@@ -50,6 +50,8 @@ type BufferedReader interface {
 	Peek(int) ([]byte, error)
 	Discard(int) (int, error)
 	Outer() utils.Reader
+	// Buffered returns the number of bytes already read and stored in the buffer
+	Buffered() int
 	BufferSize() int
 	Reset(utils.Reader)
 	Free()
 	io.Reader
 }
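The stealFromBuffer fast path above leans on the Peek/Discard contract: Peek returns a slice that aliases the reader's internal buffer without copying, and Discard then advances past those bytes; this is why decryption happens before the Discard call, while the borrowed slice is still valid. The same pattern with the standard library's bufio.Reader as a stand-in (standalone sketch, not part of the patch):

    package main

    import (
        "bufio"
        "fmt"
        "strings"
    )

    func main() {
        br := bufio.NewReaderSize(strings.NewReader("0123456789abcdef"), 64)
        data, _ := br.Peek(10)   // borrows 10 bytes of br's internal buffer; no copy
        fmt.Printf("%s\n", data) // consume the bytes now: Peek's slice is only
        // valid until the next read refills the buffer
        if _, err := br.Discard(10); err != nil { // advance past the borrowed bytes
            panic(err)
        }
    }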
From 1dde56d243f73085acdb9527c2e16d1cb537639d Mon Sep 17 00:00:00 2001
From: Daniel Adam
Date: Thu, 30 Oct 2025 19:05:14 +0100
Subject: [PATCH 6/6] Fix encryption for DataPageV2

---
 parquet/file/column_reader_test.go | 152 +++++++++++++++++++++++++++++
 1 file changed, 152 insertions(+)

diff --git a/parquet/file/column_reader_test.go b/parquet/file/column_reader_test.go
index 25c26bc8..51e68cf7 100644
--- a/parquet/file/column_reader_test.go
+++ b/parquet/file/column_reader_test.go
@@ -33,7 +33,9 @@ import (
 	"github.com/apache/arrow-go/v18/arrow/memory"
 	"github.com/apache/arrow-go/v18/internal/utils"
 	"github.com/apache/arrow-go/v18/parquet"
+	"github.com/apache/arrow-go/v18/parquet/compress"
 	"github.com/apache/arrow-go/v18/parquet/file"
+	"github.com/apache/arrow-go/v18/parquet/internal/encryption"
 	"github.com/apache/arrow-go/v18/parquet/internal/testutils"
 	"github.com/apache/arrow-go/v18/parquet/pqarrow"
 	"github.com/apache/arrow-go/v18/parquet/schema"
@@ -42,6 +44,17 @@ import (
 	"github.com/stretchr/testify/suite"
 )
 
+const (
+	FooterEncryptionKey    = "0123456789012345"
+	ColumnEncryptionKey1   = "1234567890123450"
+	ColumnEncryptionKey2   = "1234567890123451"
+	ColumnEncryptionKey3   = "1234567890123452"
+	FooterEncryptionKeyID  = "kf"
+	ColumnEncryptionKey1ID = "kc1"
+	ColumnEncryptionKey2ID = "kc2"
+	ColumnEncryptionKey3ID = "kc3"
+)
+
 func initValues(values reflect.Value) {
 	if values.Kind() != reflect.Slice {
 		panic("must init values with slice")
@@ -813,6 +826,145 @@ func TestFullSeekRow(t *testing.T) {
 	}
 }
 
+func checkDecryptedValues(t *testing.T, writerProps *parquet.WriterProperties, readProps *parquet.ReaderProperties) {
+	sc := arrow.NewSchema([]arrow.Field{
+		{Name: "c0", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+		{Name: "c1", Type: arrow.BinaryTypes.String, Nullable: true},
+		{Name: "c2", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64), Nullable: true},
+	}, nil)
+
+	tbl, err := array.TableFromJSON(mem, sc, []string{`[
+		{"c0": 1, "c1": "a", "c2": [1]},
+		{"c0": 2, "c1": "b", "c2": [1, 2]},
+		{"c0": 3, "c1": "c", "c2": [null]},
+		{"c0": null, "c1": "d", "c2": []},
+		{"c0": 5, "c1": null, "c2": [3, 3, 3]},
+		{"c0": 6, "c1": "f", "c2": null}
+	]`})
+	require.NoError(t, err)
+	defer tbl.Release()
+
+	schema := tbl.Schema()
+	arrWriterProps := pqarrow.NewArrowWriterProperties()
+
+	var buf bytes.Buffer
+	wr, err := pqarrow.NewFileWriter(schema, &buf, writerProps, arrWriterProps)
+	require.NoError(t, err)
+
+	require.NoError(t, wr.WriteTable(tbl, tbl.NumRows()))
+	require.NoError(t, wr.Close())
+
+	rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), file.WithReadProps(readProps))
+	require.NoError(t, err)
+	defer rdr.Close()
+
+	rgr := rdr.RowGroup(0)
+	col0, err := rgr.Column(0)
+	require.NoError(t, err)
+
+	icr := col0.(*file.Int64ColumnChunkReader)
+	// require.NoError(t, icr.SeekToRow(3)) // TODO: this causes a panic currently
+
+	vals := make([]int64, 6)
+	defLvls := make([]int16, 6)
+	repLvls := make([]int16, 6)
+
+	totalLvls, read, err := icr.ReadBatch(6, vals, defLvls, repLvls)
+	require.NoError(t, err)
+	assert.EqualValues(t, 6, totalLvls)
+	assert.EqualValues(t, 5, read)
+	assert.Equal(t, []int64{1, 2, 3, 5, 6}, vals[:read])
+	assert.Equal(t, []int16{1, 1, 1, 0, 1, 1}, defLvls[:totalLvls])
+	assert.Equal(t, []int16{0, 0, 0, 0, 0, 0}, repLvls[:totalLvls])
+
+	col1, err := rgr.Column(1)
+	require.NoError(t, err)
+
+	scr := col1.(*file.ByteArrayColumnChunkReader)
+
+	bavals := make([]parquet.ByteArray, 6)
+	badefLvls := make([]int16, 6)
+	barepLvls := make([]int16, 6)
+
+	totalLvls, read, err = scr.ReadBatch(6, bavals, badefLvls, barepLvls)
+	require.NoError(t, err)
+	assert.EqualValues(t, 6, totalLvls)
+	assert.EqualValues(t, 5, read)
+	expectedBAs := []parquet.ByteArray{
+		[]byte("a"),
+		[]byte("b"),
+		[]byte("c"),
+		[]byte("d"),
+		[]byte("f"),
+	}
+	assert.Equal(t, expectedBAs, bavals[:read])
+	assert.Equal(t, []int16{1, 1, 1, 1, 0, 1}, badefLvls[:totalLvls])
+	assert.Equal(t, []int16{0, 0, 0, 0, 0, 0}, barepLvls[:totalLvls])
+
+	col2, err := rgr.Column(2)
+	require.NoError(t, err)
+
+	lcr := col2.(*file.Int64ColumnChunkReader)
+	vals = make([]int64, 10)
+	defLvls = make([]int16, 10)
+	repLvls = make([]int16, 10)
+	totalLvls, read, err = lcr.ReadBatch(6, vals, defLvls, repLvls)
+	require.NoError(t, err)
+
+	assert.EqualValues(t, 6, totalLvls)
+	assert.EqualValues(t, 4, read)
+
+	assert.Equal(t, []int64{1, 1, 2, 3}, vals[:read])
+	assert.Equal(t, []int16{3, 3, 3, 2, 1, 3}, defLvls[:totalLvls])
+	assert.Equal(t, []int16{0, 0, 1, 0, 0, 0}, repLvls[:totalLvls])
+}
+
+func TestDecryptColumns(t *testing.T) {
+	encryptCols := make(parquet.ColumnPathToEncryptionPropsMap)
+	encryptCols["c0"] = parquet.NewColumnEncryptionProperties("c0", parquet.WithKey(ColumnEncryptionKey1), parquet.WithKeyID(ColumnEncryptionKey1ID))
+	encryptCols["c1"] = parquet.NewColumnEncryptionProperties("c1", parquet.WithKey(ColumnEncryptionKey2), parquet.WithKeyID(ColumnEncryptionKey2ID))
+	encryptCols["c2.list.element"] = parquet.NewColumnEncryptionProperties("c2.list.element", parquet.WithKey(ColumnEncryptionKey3), parquet.WithKeyID(ColumnEncryptionKey3ID))
+	encryptProps := parquet.NewFileEncryptionProperties(FooterEncryptionKey, parquet.WithFooterKeyMetadata(FooterEncryptionKeyID),
+		parquet.WithEncryptedColumns(encryptCols), parquet.WithAlg(parquet.AesCtr))
+
+	stringKr1 := make(encryption.StringKeyIDRetriever)
+	stringKr1.PutKey(FooterEncryptionKeyID, FooterEncryptionKey)
+	stringKr1.PutKey(ColumnEncryptionKey1ID, ColumnEncryptionKey1)
+	stringKr1.PutKey(ColumnEncryptionKey2ID, ColumnEncryptionKey2)
+	stringKr1.PutKey(ColumnEncryptionKey3ID, ColumnEncryptionKey3)
+	decryptProps := parquet.NewFileDecryptionProperties(parquet.WithKeyRetriever(stringKr1))
+
+	tests := []struct {
+		name            string
+		dataPageVersion parquet.DataPageVersion
+		bufferedStream  bool
+		compression     compress.Compression
+	}{
+		{"DataPageV2_BufferedRead", parquet.DataPageV2, true, compress.Codecs.Uncompressed},
+		{"DataPageV2_DirectRead", parquet.DataPageV2, false, compress.Codecs.Uncompressed},
+		{"DataPageV2_BufferedRead_Compressed", parquet.DataPageV2, true, compress.Codecs.Snappy},
+		{"DataPageV2_DirectRead_Compressed", parquet.DataPageV2, false, compress.Codecs.Snappy},
+		// {"DataPageV1_BufferedRead", parquet.DataPageV1, true, compress.Codecs.Uncompressed},
+		// {"DataPageV1_DirectRead", parquet.DataPageV1, false, compress.Codecs.Uncompressed},
+		// {"DataPageV1_BufferedRead_Compressed", parquet.DataPageV1, true, compress.Codecs.Snappy},
+		// {"DataPageV1_DirectRead_Compressed", parquet.DataPageV1, false, compress.Codecs.Snappy},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			writerProps := parquet.NewWriterProperties(
+				parquet.WithDataPageVersion(tt.dataPageVersion),
+				parquet.WithEncryptionProperties(encryptProps.Clone("")),
+				parquet.WithCompression(tt.compression),
+			)
+			readProps := parquet.NewReaderProperties(nil)
+			readProps.FileDecryptProps = decryptProps.Clone("")
+			readProps.BufferedStreamEnabled = tt.bufferedStream
+			checkDecryptedValues(t, writerProps, readProps)
+		})
+	}
+}
+
 func BenchmarkReadInt32Column(b *testing.B) {
 	// generate parquet with RLE-dictionary encoded int32 column
 	tempdir := b.TempDir()
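With the series applied, the new test matrix and the allocator tests can be exercised in isolation; assuming a checkout of the patched tree:

    go test ./parquet/file -run TestDecryptColumns -v
    go test ./parquet -run TestReaderPropsGetStream -v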