From 3a1a1d7250937ad59c97c1d09c376e9711991014 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 17 Dec 2025 15:00:29 +0000 Subject: [PATCH 1/4] Implement __setitem__ for Cubed arrays --- .github/workflows/array-api-tests.yml | 2 - cubed/core/array.py | 5 + cubed/core/indexing.py | 119 +++++++++++++++++++++++- cubed/core/ops.py | 126 ++++++++++++++++++++++++++ cubed/tests/test_indexing.py | 38 ++++++++ 5 files changed, 287 insertions(+), 3 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 51cc6c41d..a163a3f1f 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -123,7 +123,6 @@ jobs: array_api_tests/test_signatures.py::test_func_signature[take_along_axis] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - array_api_tests/test_signatures.py::test_array_method_signature[__setitem__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] # edge case failures (https://github.com/cubed-dev/cubed/issues/420) @@ -144,7 +143,6 @@ jobs: array_api_tests/test_manipulation_functions.py::TestExpandDims::test_expand_dims_tuples # not implemented - array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_indexing_functions.py::test_take_along_axis diff --git a/cubed/core/array.py b/cubed/core/array.py index 88e05af52..b853d6ab0 100644 --- a/cubed/core/array.py +++ b/cubed/core/array.py @@ -244,6 +244,11 @@ def __getitem__(self: T_ChunkedArray, key, /) -> T_ChunkedArray: return index(self, key) + def __setitem__(self: T_ChunkedArray, key, value, /) -> None: + from cubed.core.indexing import setitem + + setitem(self, key, value) + def __repr__(self): return f"cubed.core.CoreArray<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>" diff --git a/cubed/core/indexing.py b/cubed/core/indexing.py index b0721745b..d1d7af121 100644 --- a/cubed/core/indexing.py +++ b/cubed/core/indexing.py @@ -8,7 +8,14 @@ from cubed.backend_array_api import backend_array_to_numpy_array from cubed.core.array import CoreArray -from cubed.core.ops import general_blockwise, map_selection, merge_chunks +from cubed.core.ops import ( + _create_zarr_indexer, + general_blockwise, + map_blocks, + map_selection, + map_selection_update, + merge_chunks, +) from cubed.primitive.blockwise import ChunkKey, FunctionArgs from cubed.utils import array_size, normalize_chunks @@ -289,3 +296,113 @@ def shape(self) -> tuple[int, ...]: The number of blocks per axis. """ return self.array.numblocks + + +def setitem(x, key, value, /) -> None: + from cubed import Array + + if isinstance(value, Array) and value.size == 1: + value = as_pyscalar(value) + + if isinstance(value, (bool, int, float, complex)): + out = setitem_scalar(x, key, value) + else: + out = setitem_array(x, key, value) + + # mutate the array + x._plan = out._plan + + +def as_pyscalar(x): + # based on https://github.com/data-apis/array-api/issues/815 + + import cubed + + if x.size != 1: + raise ValueError("Can't convert array with size!=1 to a python scalar") + + axes = tuple(i for i, a in enumerate(x.shape) if a == 1) + if len(axes) > 0: + x = cubed.squeeze(x, axis=axes) + if cubed.isdtype(x.dtype, "real floating"): + return float(x) + elif cubed.isdtype(x.dtype, "complex floating"): + return complex(x) + elif cubed.isdtype(x.dtype, "integral"): + return int(x) + elif cubed.isdtype(x.dtype, "bool"): + return bool(x) + else: + raise ValueError(f"Can't convert array with dtype {x.dtype} to a python scalar") + + +def setitem_scalar(source: "Array", key, value): + """Set scalar value on Zarr array indexing by key.""" + + from cubed import Array + + # check that value is a scalar, so we don't have to worry about chunk selection, broadcasting, etc + if isinstance(value, Array): + raise NotImplementedError("Only scalar values are supported for set") + + chunks = source.chunks + idx = ndindex.ndindex(key) + idx = idx.expand(source.shape) + selection = idx.raw + indexer = _create_zarr_indexer(selection, source.shape, source.chunksize) + output_blocks = map( + lambda chunk_projection: list(chunk_projection[0]), list(indexer) + ) + chunk_selections = {cp.chunk_coords: cp.chunk_selection for cp in indexer} + + return map_blocks( + _setitem_scalar, + source, + dtype=source.dtype, + chunks=chunks, + output_blocks=output_blocks, + value=value, + chunk_selections=chunk_selections, + ) + + +def _setitem_scalar(a, value=None, chunk_selections=None, block_id=None): + a[chunk_selections[block_id]] = value + return a + + +def setitem_array(source: "Array", key, value): + """Set value on Zarr array indexing by key.""" + + idx = ndindex.ndindex(key) + idx = idx.expand(source.shape) + selection = idx.raw + indexer = _create_zarr_indexer(selection, source.shape, source.chunksize) + chunk_selections = {cp.chunk_coords: cp.chunk_selection for cp in indexer} + chunk_out_selections = {cp.chunk_coords: cp.out_selection for cp in indexer} + + def selection_function(out_key): + out_coords = out_key.coords + return chunk_out_selections[out_coords] + + max_num_input_blocks = 1 # TODO + + out = map_selection_update( + _setitem_array, + selection_function, + value, + source, + source.shape, + source.dtype, + source.chunks, + max_num_input_blocks=max_num_input_blocks, + chunk_selections=chunk_selections, + ) + + return out + + +def _setitem_array(a, out, chunk_selections=None, block_id=None): + if block_id in chunk_selections: + out[chunk_selections[block_id]] = a + return out diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 80aa3f6c1..b02754aad 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -708,6 +708,54 @@ def _assemble_index_chunk( return out +def _assemble_index_chunk_update( + arrays, + y, + dtype=None, + func=None, + selection_function=None, + in_shape=None, + in_chunksize=None, + block_id=None, + **kwargs, +): + assert not isinstance(arrays, list), ( + "index expects an iterator of array blocks, not a list" + ) + + try: + # compute the selection on x required to get the relevant chunk for out_coords + out_coords = block_id + in_sel = selection_function(ChunkKey("out", out_coords)) + + # use a Zarr indexer to convert this to input coordinates + indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize) + + shape = indexer.shape + out = nxp.empty(shape, dtype=dtype) + except KeyError: + # TODO: better way than this + shape = (0, 0) + out = nxp.empty(shape, dtype=dtype) + + if array_size(shape) > 0: + _, lchunk_selection, lout_selection, *_ = zip(*indexer) + for ai, chunk_select, out_select in zip( + arrays, lchunk_selection, lout_selection + ): + if IS_IMMUTABLE_ARRAY: + out = out.at[out_select].set(ai[chunk_select]) + else: + out[out_select] = ai[chunk_select] + + if func is not None: + if has_keyword(func, "block_id"): + out = func(out, y, block_id=block_id, **kwargs) + else: + out = func(out, y, **kwargs) + return out + + def map_selection( func, selection_function, @@ -774,6 +822,84 @@ def back_key_function(out_key: ChunkKey) -> FunctionArgs[Iterator[ChunkKey]]: return out +def map_selection_update( + func, + selection_function, + x, + y, + shape, + dtype, + chunks, + max_num_input_blocks, + **kwargs, +) -> "Array": + """ + Apply a function to selected subsets of an input array using standard NumPy indexing notation. + + Parameters + ---------- + func : callable + Function to apply to every block to produce the output array. + Must accept ``block_id`` as a keyword argument (with same meaning as for ``map_blocks``). + selection_function : callable + A function that maps an output chunk key to one or more selections on the input array. + x: Array + The input array. + y: Array + The input array. + shape : tuple + Shape of the output array. + dtype : np.dtype + The ``dtype`` of the output array. + chunks : tuple + Chunk shape of blocks in the output array. + max_num_input_blocks : int + The maximum number of input blocks read from the input array. + """ + + def back_key_function(out_key: ChunkKey) -> FunctionArgs[Iterator[ChunkKey]]: + # compute the selection on x required to get the relevant chunk for out_key + try: + in_sel = selection_function(out_key) + + # use a Zarr indexer to convert selection to input coordinates + indexer = _create_zarr_indexer(in_sel, x.shape, x.chunksize) + + return FunctionArgs( + iter(tuple(ChunkKey(x.name, cp.chunk_coords) for cp in indexer)), + ChunkKey(y.name, out_key.coords), + output_name=out_key.name, + ) + except KeyError: + return FunctionArgs( + iter([]), + ChunkKey(y.name, out_key.coords), + output_name=out_key.name, + ) + + num_input_blocks = (max_num_input_blocks, 1) + + out = general_blockwise( + _assemble_index_chunk_update, + back_key_function, + x, + y, + shapes=[shape], + dtypes=[dtype], + chunkss=[chunks], + extra_func_kwargs=dict(func=func, dtype=x.dtype), + num_input_blocks=num_input_blocks, + selection_function=selection_function, + in_shape=x.shape, + in_chunksize=x.chunksize, + **kwargs, + ) + from cubed import Array + + assert isinstance(out, Array) # single output + return out + + def map_blocks( func, *args: "Array", diff --git a/cubed/tests/test_indexing.py b/cubed/tests/test_indexing.py index 9acfdc50f..7bc890906 100644 --- a/cubed/tests/test_indexing.py +++ b/cubed/tests/test_indexing.py @@ -109,3 +109,41 @@ def test_blocks(): x.blocks[[0, 1], [0, 1]] with pytest.raises(IndexError, match="out of bounds"): x.blocks[100, 100] + + +def test_setitem_scalar(spec): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(3, 3), + spec=spec, + ) + a[:, 1] = -1 + x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) + x[:, 1] = -1 + assert_array_equal(a.compute(), x) + + +def test_setitem_single_element_array(spec): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(3, 3), + spec=spec, + ) + b = xp.asarray([-1], spec=spec) + a[:, 1] = b + x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) + x[:, 1] = -1 + assert_array_equal(a.compute(), x) + + +def test_setitem_array(spec): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(3, 3), + spec=spec, + ) + b = xp.asarray([-1, -1, -1, -1], chunks=(3,), spec=spec) + a[:, 1] = b + x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) + x[:, 1] = -1 + assert_array_equal(a.compute(), x) From df7779b06ccb40f115a3597867fbc44759a3c4d4 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 10 Apr 2026 10:28:43 +0100 Subject: [PATCH 2/4] Handle setitem for JAX using `at` --- cubed/core/indexing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cubed/core/indexing.py b/cubed/core/indexing.py index d1d7af121..fd1f023f0 100644 --- a/cubed/core/indexing.py +++ b/cubed/core/indexing.py @@ -6,7 +6,7 @@ import ndindex import numpy as np -from cubed.backend_array_api import backend_array_to_numpy_array +from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, backend_array_to_numpy_array from cubed.core.array import CoreArray from cubed.core.ops import ( _create_zarr_indexer, @@ -404,5 +404,8 @@ def selection_function(out_key): def _setitem_array(a, out, chunk_selections=None, block_id=None): if block_id in chunk_selections: - out[chunk_selections[block_id]] = a + if IS_IMMUTABLE_ARRAY: + out = out.at[chunk_selections[block_id]].set(a) + else: + out[chunk_selections[block_id]] = a return out From bee67b846200ff7b1926043aba38a0b8f5f9b945 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 10 Apr 2026 10:32:06 +0100 Subject: [PATCH 3/4] Suppress mypy warning --- cubed/core/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index b02754aad..0daeaef6f 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -867,13 +867,13 @@ def back_key_function(out_key: ChunkKey) -> FunctionArgs[Iterator[ChunkKey]]: return FunctionArgs( iter(tuple(ChunkKey(x.name, cp.chunk_coords) for cp in indexer)), - ChunkKey(y.name, out_key.coords), + ChunkKey(y.name, out_key.coords), # type: ignore output_name=out_key.name, ) except KeyError: return FunctionArgs( iter([]), - ChunkKey(y.name, out_key.coords), + ChunkKey(y.name, out_key.coords), # type: ignore output_name=out_key.name, ) From 0bd232c0c1d1623f3d2c0a470438c2d230dc025c Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 10 Apr 2026 11:13:32 +0100 Subject: [PATCH 4/4] Handle setitem for JAX using `at` --- cubed/core/indexing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cubed/core/indexing.py b/cubed/core/indexing.py index fd1f023f0..24b219865 100644 --- a/cubed/core/indexing.py +++ b/cubed/core/indexing.py @@ -367,7 +367,10 @@ def setitem_scalar(source: "Array", key, value): def _setitem_scalar(a, value=None, chunk_selections=None, block_id=None): - a[chunk_selections[block_id]] = value + if IS_IMMUTABLE_ARRAY: + a = a.at[chunk_selections[block_id]].set(value) + else: + a[chunk_selections[block_id]] = value return a