Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ jobs:
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
array_api_tests/test_signatures.py::test_array_method_signature[__setitem__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]

# edge case failures (https://github.com/cubed-dev/cubed/issues/420)
Expand All @@ -144,7 +143,6 @@ jobs:
array_api_tests/test_manipulation_functions.py::TestExpandDims::test_expand_dims_tuples

# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_indexing_functions.py::test_take_along_axis
Expand Down
5 changes: 5 additions & 0 deletions cubed/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ def __getitem__(self: T_ChunkedArray, key, /) -> T_ChunkedArray:

return index(self, key)

    def __setitem__(self: T_ChunkedArray, key, value, /) -> None:
        """Assign ``value`` to the selection ``key``, updating this array in place.

        Delegates to :func:`cubed.core.indexing.setitem`, which replaces this
        array's plan with one that produces the updated array.
        """
        # Local import — presumably to avoid a circular dependency between
        # array.py and indexing.py; confirm before moving to module level.
        from cubed.core.indexing import setitem

        setitem(self, key, value)

def __repr__(self):
return f"cubed.core.CoreArray<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>"

Expand Down
127 changes: 125 additions & 2 deletions cubed/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
import ndindex
import numpy as np

from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, backend_array_to_numpy_array
from cubed.core.array import CoreArray
from cubed.core.ops import general_blockwise, map_selection, merge_chunks
from cubed.core.ops import (
_create_zarr_indexer,
general_blockwise,
map_blocks,
map_selection,
map_selection_update,
merge_chunks,
)
from cubed.primitive.blockwise import ChunkKey, FunctionArgs
from cubed.utils import array_size, normalize_chunks

Expand Down Expand Up @@ -289,3 +296,119 @@ def shape(self) -> tuple[int, ...]:
The number of blocks per axis.
"""
return self.array.numblocks


def setitem(x, key, value, /) -> None:
    """Assign ``value`` at selection ``key`` on array ``x``, updating ``x`` in place.

    Python scalars (including one-element cubed Arrays, which are first
    lowered to Python scalars) take the scalar path; other values take the
    array path.
    """
    from cubed import Array

    # A one-element cubed Array behaves like a scalar for assignment.
    if isinstance(value, Array) and value.size == 1:
        value = as_pyscalar(value)

    if isinstance(value, (bool, int, float, complex)):
        updated = setitem_scalar(x, key, value)
    else:
        updated = setitem_array(x, key, value)

    # "Mutate" x by adopting the plan that computes the updated array.
    x._plan = updated._plan


def as_pyscalar(x):
    """Convert a size-1 cubed array to the corresponding Python scalar.

    Based on https://github.com/data-apis/array-api/issues/815

    Raises
    ------
    ValueError
        If ``x`` has more than one element, or its dtype has no Python
        scalar equivalent.
    """
    import cubed

    if x.size != 1:
        raise ValueError("Can't convert array with size!=1 to a python scalar")

    # Drop all length-1 dimensions so the scalar constructors see a 0-d array.
    unit_axes = tuple(axis for axis, dim in enumerate(x.shape) if dim == 1)
    if unit_axes:
        x = cubed.squeeze(x, axis=unit_axes)

    if cubed.isdtype(x.dtype, "real floating"):
        return float(x)
    if cubed.isdtype(x.dtype, "complex floating"):
        return complex(x)
    if cubed.isdtype(x.dtype, "integral"):
        return int(x)
    if cubed.isdtype(x.dtype, "bool"):
        return bool(x)
    raise ValueError(f"Can't convert array with dtype {x.dtype} to a python scalar")


def setitem_scalar(source: "Array", key, value):
    """Set scalar value on Zarr array indexing by key."""

    from cubed import Array

    # check that value is a scalar, so we don't have to worry about chunk selection, broadcasting, etc
    if isinstance(value, Array):
        raise NotImplementedError("Only scalar values are supported for set")

    # Expand the key into a full selection over the whole of source's shape.
    selection = ndindex.ndindex(key).expand(source.shape).raw
    projections = list(_create_zarr_indexer(selection, source.shape, source.chunksize))

    # Blocks touched by the selection, and the per-chunk region to write into.
    output_blocks = [list(cp.chunk_coords) for cp in projections]
    chunk_selections = {cp.chunk_coords: cp.chunk_selection for cp in projections}

    return map_blocks(
        _setitem_scalar,
        source,
        dtype=source.dtype,
        chunks=source.chunks,
        output_blocks=output_blocks,
        value=value,
        chunk_selections=chunk_selections,
    )


def _setitem_scalar(a, value=None, chunk_selections=None, block_id=None):
    """Write ``value`` into this block's selected region of chunk ``a``."""
    selection = chunk_selections[block_id]
    if IS_IMMUTABLE_ARRAY:
        # Functional-update backends (``.at[...].set``) return a new array.
        return a.at[selection].set(value)
    a[selection] = value
    return a


def setitem_array(source: "Array", key, value):
    """Set value on Zarr array indexing by key."""

    # Expand the key into a full selection over the whole of source's shape.
    selection = ndindex.ndindex(key).expand(source.shape).raw
    indexer = _create_zarr_indexer(selection, source.shape, source.chunksize)

    # Per-chunk: region of the chunk being written, and the matching
    # region of ``value`` supplying the data.
    chunk_selections = {}
    chunk_out_selections = {}
    for cp in indexer:
        chunk_selections[cp.chunk_coords] = cp.chunk_selection
        chunk_out_selections[cp.chunk_coords] = cp.out_selection

    def selection_function(out_key):
        # Selection on ``value`` needed for the output block at these coords.
        return chunk_out_selections[out_key.coords]

    max_num_input_blocks = 1  # TODO

    return map_selection_update(
        _setitem_array,
        selection_function,
        value,
        source,
        source.shape,
        source.dtype,
        source.chunks,
        max_num_input_blocks=max_num_input_blocks,
        chunk_selections=chunk_selections,
    )


def _setitem_array(a, out, chunk_selections=None, block_id=None):
    """Copy the assembled values ``a`` into the selected region of block ``out``."""
    # Blocks untouched by the selection pass through unchanged.
    if block_id not in chunk_selections:
        return out
    selection = chunk_selections[block_id]
    if IS_IMMUTABLE_ARRAY:
        # Functional-update backends (``.at[...].set``) return a new array.
        return out.at[selection].set(a)
    out[selection] = a
    return out
126 changes: 126 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,54 @@ def _assemble_index_chunk(
return out


def _assemble_index_chunk_update(
    arrays,
    y,
    dtype=None,
    func=None,
    selection_function=None,
    in_shape=None,
    in_chunksize=None,
    block_id=None,
    **kwargs,
):
    """Assemble the selected pieces of the input blocks for one output chunk,
    then apply ``func`` to the assembled values and the corresponding ``y`` block.

    ``arrays`` is a lazy iterator of the input blocks chosen by the back key
    function; ``block_id`` identifies the output chunk being produced.
    """
    assert not isinstance(arrays, list), (
        "index expects an iterator of array blocks, not a list"
    )

    try:
        # compute the selection on x required to get the relevant chunk for out_coords
        out_coords = block_id
        in_sel = selection_function(ChunkKey("out", out_coords))

        # use a Zarr indexer to convert this to input coordinates
        indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize)

        shape = indexer.shape
        out = nxp.empty(shape, dtype=dtype)
    except KeyError:
        # KeyError means selection_function has no selection for this output
        # chunk; a zero-size buffer makes the copy loop below a no-op.
        # TODO: better way than this
        shape = (0, 0)
        out = nxp.empty(shape, dtype=dtype)

    if array_size(shape) > 0:
        # Copy each input block's selected region into its place in ``out``.
        # NOTE(review): relies on ``arrays`` yielding blocks in the same order
        # as the indexer's chunk projections — confirm against back_key_function.
        _, lchunk_selection, lout_selection, *_ = zip(*indexer)
        for ai, chunk_select, out_select in zip(
            arrays, lchunk_selection, lout_selection
        ):
            if IS_IMMUTABLE_ARRAY:
                out = out.at[out_select].set(ai[chunk_select])
            else:
                out[out_select] = ai[chunk_select]

    if func is not None:
        # Only pass block_id through when func accepts it (map_blocks convention).
        if has_keyword(func, "block_id"):
            out = func(out, y, block_id=block_id, **kwargs)
        else:
            out = func(out, y, **kwargs)
    return out


def map_selection(
func,
selection_function,
Expand Down Expand Up @@ -774,6 +822,84 @@ def back_key_function(out_key: ChunkKey) -> FunctionArgs[Iterator[ChunkKey]]:
return out


def map_selection_update(
    func,
    selection_function,
    x,
    y,
    shape,
    dtype,
    chunks,
    max_num_input_blocks,
    **kwargs,
) -> "Array":
    """
    Update blocks of one array with selected subsets of another, using standard
    NumPy indexing notation to describe the selections.

    For each output chunk, the relevant pieces of ``x`` are assembled (via
    ``selection_function`` and a Zarr indexer) and passed, together with the
    corresponding block of ``y``, to ``func``.

    Parameters
    ----------
    func : callable
        Function applied per output block as ``func(assembled_x, y_block, ...)``.
        Must accept ``block_id`` as a keyword argument (with same meaning as for ``map_blocks``).
    selection_function : callable
        A function that maps an output chunk key to a selection on ``x``.
        May raise ``KeyError`` for output chunks that need no data from ``x``.
    x: Array
        The array supplying the values that are selected and assembled.
    y: Array
        The array whose block at the output coordinates is passed to ``func``.
    shape : tuple
        Shape of the output array.
    dtype : np.dtype
        The ``dtype`` of the output array.
    chunks : tuple
        Chunk shape of blocks in the output array.
    max_num_input_blocks : int
        The maximum number of input blocks read from the input array.
    """

    def back_key_function(out_key: ChunkKey) -> FunctionArgs[Iterator[ChunkKey]]:
        # compute the selection on x required to get the relevant chunk for out_key
        try:
            in_sel = selection_function(out_key)

            # use a Zarr indexer to convert selection to input coordinates
            indexer = _create_zarr_indexer(in_sel, x.shape, x.chunksize)

            return FunctionArgs(
                iter(tuple(ChunkKey(x.name, cp.chunk_coords) for cp in indexer)),
                ChunkKey(y.name, out_key.coords),  # type: ignore
                output_name=out_key.name,
            )
        except KeyError:
            # No selection for this output chunk: read nothing from x,
            # only the matching block of y.
            return FunctionArgs(
                iter([]),
                ChunkKey(y.name, out_key.coords),  # type: ignore
                output_name=out_key.name,
            )

    # Blocks read per output chunk: up to max_num_input_blocks from x, one from y.
    num_input_blocks = (max_num_input_blocks, 1)

    out = general_blockwise(
        _assemble_index_chunk_update,
        back_key_function,
        x,
        y,
        shapes=[shape],
        dtypes=[dtype],
        chunkss=[chunks],
        extra_func_kwargs=dict(func=func, dtype=x.dtype),
        num_input_blocks=num_input_blocks,
        selection_function=selection_function,
        in_shape=x.shape,
        in_chunksize=x.chunksize,
        **kwargs,
    )
    from cubed import Array

    assert isinstance(out, Array)  # single output
    return out


def map_blocks(
func,
*args: "Array",
Expand Down
38 changes: 38 additions & 0 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,41 @@ def test_blocks():
x.blocks[[0, 1], [0, 1]]
with pytest.raises(IndexError, match="out of bounds"):
x.blocks[100, 100]


def test_setitem_scalar(spec):
    """Assigning a Python scalar should match NumPy's __setitem__ semantics."""
    data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    a = xp.asarray(data, chunks=(3, 3), spec=spec)
    a[:, 1] = -1

    expected = np.array(data)
    expected[:, 1] = -1
    assert_array_equal(a.compute(), expected)


def test_setitem_single_element_array(spec):
    """A one-element array value should behave like a scalar assignment."""
    data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    a = xp.asarray(data, chunks=(3, 3), spec=spec)
    a[:, 1] = xp.asarray([-1], spec=spec)

    expected = np.array(data)
    expected[:, 1] = -1
    assert_array_equal(a.compute(), expected)


def test_setitem_array(spec):
    """Assigning a chunked array value should match NumPy semantics."""
    data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    a = xp.asarray(data, chunks=(3, 3), spec=spec)
    a[:, 1] = xp.asarray([-1, -1, -1, -1], chunks=(3,), spec=spec)

    expected = np.array(data)
    expected[:, 1] = -1
    assert_array_equal(a.compute(), expected)
Loading