Skip to content

Commit 3ec0c0a

Browse files
puneetdixit200 authored and Deepak kudi committed
Fix v2 structured field selection
1 parent 6ce787d commit 3ec0c0a

3 files changed

Lines changed: 27 additions & 2 deletions

File tree

changes/3983.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed field selection for Zarr v2 arrays with structured dtypes.

src/zarr/core/array.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5408,6 +5408,7 @@ async def _get_selection(
54085408

54095409
# check fields are sensible
54105410
out_dtype = check_fields(fields, dtype)
5411+
has_fields = bool(fields)
54115412

54125413
# setup output buffer
54135414
if out is not None:
@@ -5425,6 +5426,11 @@ async def _get_selection(
54255426
dtype=out_dtype,
54265427
order=order,
54275428
)
5429+
read_buffer = (
5430+
prototype.nd_buffer.empty(shape=indexer.shape, dtype=dtype, order=order)
5431+
if has_fields
5432+
else out_buffer
5433+
)
54285434
if product(indexer.shape) > 0:
54295435
# need to use the order from the metadata for v2
54305436
_config = config
@@ -5457,7 +5463,7 @@ async def _get_selection(
54575463
)
54585464
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexed_chunks
54595465
],
5460-
out_buffer,
5466+
read_buffer,
54615467
drop_axes=indexer.drop_axes,
54625468
)
54635469
if _config.read_missing_chunks is False:
@@ -5475,6 +5481,8 @@ async def _get_selection(
54755481
f"missing chunks with the fill value.\n"
54765482
f"Missing chunks:\n{chunks_str}"
54775483
)
5484+
if has_fields:
5485+
out_buffer[...] = read_buffer[fields]
54785486
if isinstance(indexer, BasicIndexer) and indexer.shape == ():
54795487
return out_buffer.as_scalar()
54805488
return out_buffer.as_ndarray_like()

tests/test_v2.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Any, Literal
3+
from typing import Any, Literal, cast
44

55
import numpy as np
66
import pytest
@@ -252,6 +252,22 @@ def test_structured_dtype_roundtrip(fill_value: float | bytes, tmp_path: Path) -
252252
assert (a == za[:]).all()
253253

254254

255+
def test_structured_dtype_field_selection(tmp_path: Path) -> None:
256+
a = np.array(
257+
[("Rex", 9, 81.0), ("Fido", 3, 27.0)],
258+
dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
259+
)
260+
za = zarr.create_array(store=tmp_path / "data.zarr", data=a, zarr_format=2)
261+
262+
np.testing.assert_array_equal(cast(Any, za)["name"], a["name"])
263+
np.testing.assert_array_equal(
264+
za.get_basic_selection(slice(1, None), fields="age"), a["age"][1:]
265+
)
266+
np.testing.assert_array_equal(
267+
za.get_basic_selection(Ellipsis, fields=["name", "age"]), a[["name", "age"]]
268+
)
269+
270+
255271
@pytest.mark.parametrize(
256272
(
257273
"fill_value",

0 commit comments

Comments
 (0)