diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..21a99f7e8f 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,10 +7,15 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Generic, List, TypeAlias, TypeGuard, TypeVar +from typing import Any, Callable, Generic, List, TypeAlias, TypeGuard, TypeVar from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import FunCall as GTFNIRFunCall +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import ( + Expr as GTFNIRExpr, + SymRef as GTFNIRSymRef, +) _Fun = TypeVar("_Fun", bound=itir.Expr) @@ -44,8 +49,8 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[_FunCallToSymRe assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument if isinstance(fun, str): return ( - isinstance(node, itir.FunCall) - and isinstance(node.fun, itir.SymRef) + isinstance(node, itir.FunCall | GTFNIRFunCall) + and isinstance(node.fun, itir.SymRef | GTFNIRSymRef) and node.fun.id == fun ) else: @@ -135,3 +140,14 @@ def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef] ): return True return False + + +def is_tuple_expr_of( + pred: Callable[[Any], bool], + expr: itir.Expr | GTFNIRExpr, +) -> bool: + if is_call_to(expr, "make_tuple"): + return all(is_tuple_expr_of(pred, arg) for arg in expr.args) + if is_call_to(expr, "tuple_get"): + return is_tuple_expr_of(pred, expr.args[1]) + return pred(expr) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py new file mode 100644 index 0000000000..2918404c51 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -0,0 +1,132 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.common import OffsetProvider +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import trace_shifts + + +@dataclasses.dataclass(frozen=True) +class CheckInOutField(PreserveLocationVisitor, NodeTranslator): + """ + Checks within a SetAt if any fields which are written to are also read with an offset and raises a ValueError in this case. + + Example: + >>> from gt4py.next.iterator.transforms import infer_domain + >>> from gt4py.next.type_system import type_specifications as ts + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) + >>> offset_provider = {"IOff": IDim} + >>> cartesian_domain = im.call("cartesian_domain")( + ... im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5) + ... ) + >>> ir = itir.Program( + ... id="test", + ... function_definitions=[], + ... params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + ... im.ref("inout") + ... ), + ... domain=cartesian_domain, + ... target=im.ref("inout"), + ... ), + ... ], + ... ) + >>> CheckInOutField.apply(ir, offset_provider=offset_provider) + Traceback (most recent call last): + ... + ValueError: The target inout is also read with an offset. + """ + + @classmethod + def apply( + cls, + program: itir.Program, + offset_provider: common.OffsetProvider | common.OffsetProviderType, + ): + return cls().visit(program, offset_provider=offset_provider) + + def visit_SetAt(self, node: itir.SetAt, **kwargs) -> itir.SetAt: + offset_provider = kwargs["offset_provider"] + + def extract_subexprs(expr): + """Return a list of all subexpressions in expr.args, including expr itself.""" + subexprs = [expr] + if isinstance(expr, itir.FunCall): + for arg in expr.args: + subexprs.extend(extract_subexprs(arg)) + return subexprs + + def visit_nested_make_tuple_tuple_get(expr): + """Recursively visit make_tuple and tuple_get expr and check all as_fieldop subexpressions.""" + if cpm.is_applied_as_fieldop(expr): + check_expr(expr.fun, expr.args, offset_provider) + elif cpm.is_call_to(expr, ("make_tuple", "tuple_get")): + for arg in expr.args: + visit_nested_make_tuple_tuple_get(arg) + + def filter_shifted_args( + shifts: list[set[tuple[itir.OffsetLiteral, ...]]], + args: list[itir.Expr], + offset_provider: OffsetProvider, + ) -> list[itir.Expr]: + """ + Filters out trivial shifts (empty or all horizontal/vertical with zero offset) + and returns filtered shifts and corresponding args. + """ + filtered = [ + arg + for shift, arg in zip(shifts, args) + if shift not in (set(), {()}) + and any( + offset_provider[off.value].kind # type: ignore[index] # mypy not smart enough + not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL} + or val.value != 0 + for off, val in ( + (pair for pair in shift if len(pair) == 2) # set case: skip () + if isinstance(shift, set) + else zip(shift[0::2], shift[1::2]) # tuple/list case + ) + ) + ] + return filtered if filtered else [] + + def check_expr( + fun: itir.FunCall, + args: list[itir.Expr], + offset_provider: OffsetProvider, + ) -> None: + shifts = trace_shifts.trace_stencil(fun.args[0], num_args=len(args)) + + shifted_args = filter_shifted_args(shifts, args, offset_provider) + target_subexprs = extract_subexprs(node.target) + for arg in shifted_args: + arg_subexprs = extract_subexprs(arg) + for subexpr in arg_subexprs: + if subexpr in target_subexprs: + raise ValueError(f"The target {node.target} is also read with an offset.") + if not cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.SymRef), arg): + raise ValueError( + f"Unexpected as_fieldop argument {arg}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first." + ) + + if cpm.is_applied_as_fieldop(node.expr): + check_expr(node.expr.fun, node.expr.args, offset_provider) + else: # Account for nested im.make_tuple and im.tuple_get + visit_nested_make_tuple_tuple_get(node.expr) + return node diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4b3a258396..26679d9c80 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -46,14 +46,6 @@ def _merge_arguments( return new_args -def _is_tuple_expr_of_literals(expr: itir.Expr): - if cpm.is_call_to(expr, "make_tuple"): - return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) - if cpm.is_call_to(expr, "tuple_get"): - return _is_tuple_expr_of_literals(expr.args[1]) - return isinstance(expr, itir.Literal) - - def _inline_as_fieldop_arg( arg: itir.Expr, *, uids: eve_utils.UIDGenerator ) -> tuple[itir.Expr, dict[str, itir.Expr]]: @@ -142,7 +134,7 @@ def fuse_as_fieldop( # transform scalar `if` into per-grid-point `if` # TODO(tehrengruber): revisit if we want to inline if_ arg = im.op_as_fieldop("if_")(*arg.args) - elif _is_tuple_expr_of_literals(arg): + elif cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: raise NotImplementedError() @@ -189,7 +181,7 @@ def fuse_as_fieldop( def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool: - if _is_tuple_expr_of_literals(node): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), node): return True if ( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e8ecdedc8e..2763cf061a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,6 +12,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + check_inout_field, concat_where, dead_code_elimination, fuse_as_fieldop, @@ -92,6 +93,8 @@ def apply_common_transforms( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = check_inout_field.CheckInOutField.apply(ir, offset_provider=offset_provider) + ir = remove_broadcast.RemoveBroadcast.apply(ir) ir = concat_where.transform_to_as_fieldop(ir) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 174ca8edb1..802dcf6a94 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -126,6 +126,7 @@ class CompiledProgramsPool: definition_stage: ffront_stages.ProgramDefinition program_type: ts_ffront.ProgramType static_params: Sequence[str] | None = None # not ordered + static_domain_sizes: bool = False _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index f7445461c0..e2ffecc9b1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,12 +8,13 @@ from __future__ import annotations -from typing import Callable, ClassVar, Optional, Union +from typing import ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common from gt4py.next.iterator import builtins +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef @@ -97,25 +98,6 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: - if ( - isinstance(expr, FunCall) - and isinstance(expr.fun, SymRef) - and expr.fun.id == "tuple_get" - and len(expr.args) == 2 - and _is_tuple_expr_of(pred, expr.args[1]) - ): - return True - if ( - isinstance(expr, FunCall) - and isinstance(expr.fun, SymRef) - and expr.fun.id == "make_tuple" - and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) - ): - return True - return pred(expr) - - class SidComposite(Expr): values: list[Expr] @@ -125,7 +107,7 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of( + or cpm.is_tuple_expr_of( lambda expr: isinstance(expr, (SymRef, Literal)) or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), el, @@ -139,9 +121,9 @@ def _values_validator( def _might_be_scalar_expr(expr: Expr) -> bool: if isinstance(expr, BinaryExpr): - return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + return all(cpm.is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) if isinstance(expr, UnaryExpr): - return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + return cpm.is_tuple_expr_of(_might_be_scalar_expr, expr.expr) if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) @@ -150,7 +132,7 @@ def _might_be_scalar_expr(expr: Expr) -> bool: return all(_might_be_scalar_expr(arg) for arg in expr.args) if isinstance(expr, CastExpr): return _might_be_scalar_expr(expr.obj_expr) - if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): return True return False @@ -183,7 +165,7 @@ def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] ) -> None: for inp in inputs: - if not _is_tuple_expr_of( + if not cpm.is_tuple_expr_of( lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) or ( isinstance(expr, FunCall) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index a445390583..c73869e73c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -56,27 +56,6 @@ _horizontal_dimension = "gtfn::unstructured::dim::horizontal" -def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "tuple_get" - and len(expr.args) == 2 - and _is_tuple_of_ref_or_literal(expr.args[1]) - ): - return True - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "make_tuple" - and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) - ): - return True - if isinstance(expr, (itir.SymRef, itir.Literal)): - return True - return False - - def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -587,13 +566,12 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - if _is_tuple_of_ref_or_literal(node.expr): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, (itir.SymRef, itir.Literal)), node.expr): node.expr = im.as_fieldop("deref", node.domain)(node.expr) itir_projector, extracted_expr = ir_utils_misc.extract_projector(node.expr) projector = self.visit(itir_projector, **kwargs) if itir_projector is not None else None node.expr = extracted_expr - assert cpm.is_applied_as_fieldop(node.expr), node.expr stencil = node.expr.fun.args[0] domain = node.domain diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py new file mode 100644 index 0000000000..c2bd0e48e3 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py @@ -0,0 +1,378 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Optional + +import pytest +from next_tests.toy_connectivity import e2v_conn + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.check_inout_field import CheckInOutField +from gt4py.next.type_system import type_specifications as ts + +float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +offset_provider = {"IOff": IDim} +i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) +cartesian_domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 5)}) + + +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=params, + declarations=declarations or [], + body=body, + ) + + +def test_check_inout_no_offset(): + # inout ← (⇑deref)(inout) + ir = program_factory( + params=[im.sym("inout", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + # Should not raise + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_zero_offset(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))(inout) + ir = program_factory( + params=[im.sym("inout", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( + im.ref("inout") + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + # Should not raise + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_e2v_zero_offset(): + # inout ← (⇑(λ(x) → ·⟪E2Vₒ, 0ₒ⟫(x)))(inout) + offset_provider = {"E2V": e2v_conn} # override + ir = program_factory( + params=[im.sym("inout", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 0)("x"))))( + im.ref("inout") + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_offset(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout) + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_shift_different_field(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 1ₒ⟫(y)))(inout, in); + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")) + ) + ) + )(im.ref("inout"), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_as_fieldop_arg(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))((⇑deref)(inout)) + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.as_fieldop(im.ref("deref"))(im.ref("inout")) + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises( + ValueError, + match=r"Unexpected as_fieldop argument \(⇑deref\)\(inout\). Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first.", + ): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_two_fields(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))((⇑deref)(inout), in) + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) + ) + ) + )(im.make_tuple(im.ref("inout"), im.ref("in"))), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_tuple(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(x)))({inout, in}) + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) + ) + ) + )(im.make_tuple(im.ref("inout"), im.ref("in"))), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_make_tuple_as_fieldop_in_arg(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(x)))({(⇑deref)(inout), in}) + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) + ) + ) + )(im.make_tuple(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in"))), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises( + ValueError, + match=r"Unexpected as_fieldop argument \{\(⇑deref\)\(inout\), in\}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first.", + ): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple(): + # {inout, inout2} ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(inout2)} + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("inout2", i_field_type)], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + im.as_fieldop(im.ref("deref"))(im.ref("inout2")), + ), + domain=cartesian_domain, + target=im.make_tuple(im.ref("inout"), im.ref("inout2")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target {inout, inout2} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_as_fieldop(): + # {inout, out} ← (⇑(λ(x, y) → {·⟪IOffₒ, 1ₒ⟫(x), ·y}))(inout, in) + ir = program_factory( + params=[ + im.sym("inout", i_field_type), + im.sym("in", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.make_tuple(im.deref(im.shift("IOff", 1)("x")), im.deref("y")) + ) + )(im.ref("inout"), im.ref("in")), + domain=cartesian_domain, + target=im.make_tuple(im.ref("inout"), im.ref("out")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target {inout, out} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get_make_tuple(): + # inout ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), as_fieldop(...)}[0] + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.tuple_get( + 0, + im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get(): + # inout ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(in)} + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[i_field_type] * 2)), + im.sym("in", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.ref("inout")) + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_tuple_get(): + # {inout[0], inout2} ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(inout2)} + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[i_field_type] * 2)), + im.sym("inout2", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.ref("inout")) + ), + im.as_fieldop(im.ref("deref"))(im.ref("inout2")), + ), + domain=cartesian_domain, + target=im.make_tuple(im.tuple_get(0, im.ref("inout")), im.ref("inout2")), + ), + ], + ) + + with pytest.raises( + ValueError, match="The target {inout\[0\], inout2} is also read with an offset." + ): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get_tuple(): + # inout[0] ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0][0]), (⇑deref)(in)} + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[ts.TupleType(types=[i_field_type] * 2)] * 2)), + im.sym("in", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.tuple_get(0, im.ref("inout"))) + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + domain=cartesian_domain, + target=im.tuple_get(0, im.ref("inout")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout\[0\] is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider)