From 356534091be81c44841be8f1723cb7e0cb360995 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 10 May 2026 11:29:33 -0700 Subject: [PATCH 1/3] pyopencl fake np: Cast then/else to a common dtype. pyopencl.array does not allow array branches with unequal dtypes. --- arraycontext/impl/pyopencl/fake_numpy.py | 32 +++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 20a2d8cb..8de9904b 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -460,9 +460,39 @@ def absolute(self, a): # {{{ sorting, searching, and counting def where(self, criterion, then, else_): - def where_inner(inner_crit, inner_then, inner_else): + + def where_inner( + inner_crit: ArrayOrScalar, + inner_then: ArrayOrScalar, + inner_else: ArrayOrScalar, + ) -> ArrayOrScalar: if isinstance(inner_crit, bool | np.bool_): return inner_then if inner_crit else inner_else + + # pyopencl's if_positive does not support then, else branches with + # unequal dtypes -> cast them to a common dtype. 
+ inner_then_dtype = ( + inner_then.dtype + if isinstance(inner_then, cl_array.Array) + else np.dtype(type(inner_then)) + ) + inner_else_dtype = ( + inner_else.dtype + if isinstance(inner_else, cl_array.Array) + else np.dtype(type(inner_else)) + ) + dtype = np.promote_types(inner_then_dtype, inner_else_dtype) + inner_then = ( + inner_then.astype(dtype) + if isinstance(inner_then, cl_array.Array) + else dtype.type(inner_then) + ) + inner_else = ( + inner_else.astype(dtype) + if isinstance(inner_else, cl_array.Array) + else dtype.type(inner_else) + ) + return cl_array.if_positive(inner_crit != 0, inner_then, inner_else, queue=self._array_context.queue) From 766d2dc17ca4aca0b5a50c61212b5eb29cc2f24a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 10 May 2026 11:33:49 -0700 Subject: [PATCH 2/3] Account non-contiguous arys in PyOpenCLActx.to_numpy. --- arraycontext/impl/pyopencl/__init__.py | 48 ++++++++++++++++++++++++-- test/test_arraycontext.py | 19 ++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 653bb6a4..06753c52 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -37,6 +37,8 @@ import numpy as np from typing_extensions import Self, override +from pytools import memoize_method + from arraycontext.container.traversal import ( rec_map_array_container, rec_map_container, @@ -62,7 +64,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping - from numpy.typing import NDArray + from numpy.typing import DTypeLike, NDArray import loopy as lp import pyopencl as cl @@ -263,12 +265,54 @@ def to_numpy(self, array: Array) -> np.ndarray: def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: ... 
+ @memoize_method + def _get_to_numpy_noncontiguous_copy_kernel( + self, dtype: DTypeLike, ndim: int + ) -> lp.TranslationUnit: + """ + Returns a translation unit containing a loopy kernel that: + + - Accepts a PyOpenCL array ``input`` with per-axis strides exposed as + ``s0, s1, ..., s{ndim-1}``. + - Produces a contiguous, row-major (C-order) output array ``output`` of + the same shape, with elements copied from the corresponding + coordinates in ``input``. + """ + + import loopy as lp + + from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS + + t_unit = lp.make_copy_kernel( + ["c"] * ndim, [f"stride:s{i}" for i in range(ndim)] + ) + t_unit = lp.add_dtypes(t_unit, {"input": dtype}) + new_args = [ + *t_unit.default_entrypoint.args, + *[lp.ValueArg(f"s{i}", dtype=np.uint64) for i in range(ndim)], + ] + t_unit = t_unit.with_kernel(t_unit.default_entrypoint.copy(args=new_args)) + t_unit = lp.set_options(t_unit, _DEFAULT_LOOPY_OPTIONS) + return t_unit + @override def to_numpy(self, array: ArrayOrContainerOrScalar ) -> NumpyOrContainerOrScalar: def _to_numpy(ary): - return ary.get(queue=self.queue) + if ary.flags.forc: + # pyopencl supports host transfers only for contiguous arrays. 
+ return ary.get(queue=self.queue) + + result = self.call_loopy( + self._get_to_numpy_noncontiguous_copy_kernel(ary.dtype, ary.ndim), + input=ary, + **{ + f"s{i}": stride // ary.dtype.itemsize + for i, stride in enumerate(ary.strides) + }, + )["output"] + return result.get(queue=self.queue) return with_array_context( self._rec_map_container(_to_numpy, array), diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 9948d71f..de143985 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1659,6 +1659,25 @@ def test_linspace(actx_factory: ArrayContextFactory, args, kwargs): assert np.allclose(actx_linspace, np_linspace) +# {{{ test_to_numpy_transpose + +def test_to_numpy_transpose(actx_factory: ArrayContextFactory): + # failed before to_numpy handled non-contiguous arrays for the + # pyopencl actx -- cl_array.Array.transpose generates non-contiguous + # arrays requiring non-trivial logic for device-to-host copies. + actx = actx_factory() + rng = np.random.default_rng() + np_ary = rng.random((256, 256, 256)) + ary = actx.from_numpy(np_ary) + axis_perm = (0, 2, 1) + + np.testing.assert_allclose( + actx.to_numpy(actx.np.transpose(ary, axis_perm)), + np.transpose(np_ary, axis_perm)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: From 2a87adf78e2a62ed26fda4eda440629b996f45b1 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 10 May 2026 12:39:56 -0700 Subject: [PATCH 3/3] Update baseline to overcome some typing issues in pyopencl. 
--- .basedpyright/baseline.json | 198 +++++++++++++++++++++--------------- 1 file changed, 115 insertions(+), 83 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 7ce3ca80..8e075edb 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -3215,6 +3215,22 @@ "lineCount": 3 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 17, + "endColumn": 36, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 17, + "endColumn": 31, + "lineCount": 1 + } + }, { "code": "reportUnknownParameterType", "range": { @@ -3242,16 +3258,104 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 19, - "endColumn": 26, + "startColumn": 15, + "endColumn": 24, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 15, + "endColumn": 29, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 23, + "endColumn": 30, "lineCount": 1 } }, { "code": "reportUnknownVariableType", "range": { - "startColumn": 19, - "endColumn": 44, + "startColumn": 23, + "endColumn": 48, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 61, + "endColumn": 70, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 61, + "endColumn": 70, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 72, + "endColumn": 80, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 72, + "endColumn": 80, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 39, + "endColumn": 48, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 39, + "endColumn": 57, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 
27, + "endColumn": 33, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 47, + "endColumn": 58, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 47, + "endColumn": 58, "lineCount": 1 } }, @@ -3264,7 +3368,7 @@ } }, { - "code": "reportUnknownArgumentType", + "code": "reportArgumentType", "range": { "startColumn": 36, "endColumn": 45, @@ -4714,79 +4818,15 @@ } }, { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 12, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 24, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 36, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 36, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 48, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 48, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 23, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", + "code": "reportReturnType", "range": { - "startColumn": 40, - "endColumn": 55, - "lineCount": 1 + "startColumn": 19, + "endColumn": 52, + "lineCount": 2 } }, { - "code": "reportUnknownArgumentType", + "code": "reportArgumentType", "range": { "startColumn": 57, "endColumn": 67, @@ -4794,7 +4834,7 @@ } }, { - "code": "reportUnknownArgumentType", + "code": "reportArgumentType", "range": { "startColumn": 69, "endColumn": 79, @@ -4808,14 +4848,6 @@ "endColumn": 80, "lineCount": 1 } - }, - { - "code": 
"reportUnknownArgumentType", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } } ], "./arraycontext/impl/pyopencl/taggable_cl_array.py": [