From 7418314d1d02ae7133cff2d76f26822c9fff66c6 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 20 Apr 2026 13:24:30 -0500 Subject: [PATCH] Remove collect_common_factors_on_increment --- loopy/__init__.py | 2 - loopy/transform/arithmetic.py | 305 +--------------------------------- 2 files changed, 7 insertions(+), 300 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index 911f239e5..671ff436c 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -124,7 +124,6 @@ from loopy.tools import Optional, clear_in_mem_caches, memoize_on_disk, t_unit_to_python from loopy.transform.add_barrier import add_barrier from loopy.transform.arithmetic import ( - collect_common_factors_on_increment, fold_constants, ) from loopy.transform.batch import to_batched @@ -315,7 +314,6 @@ "change_arg_to_image", "chunk_iname", "clear_in_mem_caches", - "collect_common_factors_on_increment", "concatenate_arrays", "duplicate_inames", "expand_subst", diff --git a/loopy/transform/arithmetic.py b/loopy/transform/arithmetic.py index 8ad7d658b..f63f60c0e 100644 --- a/loopy/transform/arithmetic.py +++ b/loopy/transform/arithmetic.py @@ -24,15 +24,19 @@ """ -from loopy.diagnostic import LoopyError -from loopy.kernel import LoopKernel +from typing import TYPE_CHECKING + from loopy.translation_unit import for_each_kernel +if TYPE_CHECKING: + from loopy.kernel import LoopKernel + + # {{{ fold constants @for_each_kernel -def fold_constants(kernel): +def fold_constants(kernel: LoopKernel): from loopy.symbolic import ConstantFoldingMapper cfm = ConstantFoldingMapper() @@ -52,299 +56,4 @@ def fold_constants(kernel): # }}} -# {{{ collect_common_factors_on_increment - -# thus far undocumented -@for_each_kernel -def collect_common_factors_on_increment(kernel, var_name, vary_by_axes=()): - assert isinstance(kernel, LoopKernel) - # FIXME: Does not understand subst rules for now - if kernel.substitutions: - from loopy.transform.subst import expand_subst - kernel = expand_subst(kernel) - - if var_name in kernel.temporary_variables: - var_descr = kernel.temporary_variables[var_name] - elif var_name in kernel.arg_dict: - var_descr = kernel.arg_dict[var_name] - else: - raise NameError("array '%s' was not found" % var_name) - - # {{{ check/normalize vary_by_axes - - if isinstance(vary_by_axes, str): - vary_by_axes = vary_by_axes.split(",") - - from loopy.kernel.array import ArrayBase - if isinstance(var_descr, ArrayBase): - if var_descr.dim_names is not None: - name_to_index = { - name: idx - for idx, name in enumerate(var_descr.dim_names)} - else: - name_to_index = {} - - def map_ax_name_to_index(ax): - if isinstance(ax, str): - try: - return name_to_index[ax] - except KeyError: - raise LoopyError("axis name '%s' not understood " % ax) from None - else: - return ax - - vary_by_axes = [map_ax_name_to_index(ax) for ax in vary_by_axes] - - if ( - vary_by_axes - and - (min(vary_by_axes) < 0 - or - max(vary_by_axes) > var_descr.num_user_axes())): - raise LoopyError("vary_by_axes refers to out-of-bounds axis index") - - # }}} - - from pymbolic.mapper.substitutor import make_subst_func - from pymbolic.primitives import ( - Product, - Subscript, - Sum, - Variable, - flattened_product, - flattened_sum, - is_zero, - ) - - from loopy.symbolic import ( - SubstitutionMapper, - UnidirectionalUnifier, - get_dependencies, - ) - - # {{{ common factor key list maintenance - - # list of (index_key, common factors found) - common_factors = [] - - def find_unifiable_cf_index(index_key): - for i, (key, _val) in enumerate(common_factors): - unif = UnidirectionalUnifier( - lhs_mapping_candidates=get_dependencies(key)) - - unif_result = unif(key, index_key) - - if unif_result: - assert len(unif_result) == 1 - return i, unif_result[0] - - return None, None - - def extract_index_key(access_expr): - if isinstance(access_expr, Variable): - return () - - elif isinstance(access_expr, Subscript): - index = access_expr.index_tuple - return tuple(index[ax] for ax in vary_by_axes) - else: - raise ValueError("unexpected type of access_expr") - - def is_assignee(insn): - return var_name in insn.assignee_var_names() - - def iterate_as(cls, expr): - if isinstance(expr, cls): - yield from expr.children - else: - yield expr - - # }}} - - # {{{ find common factors - - from loopy.kernel.data import Assignment - - for insn in kernel.instructions: - if not is_assignee(insn): - continue - - if not isinstance(insn, Assignment): - raise LoopyError("'%s' modified by non-single-assignment" - % var_name) - - lhs = insn.assignee - rhs = insn.expression - - if is_zero(rhs): - continue - - index_key = extract_index_key(lhs) - cf_index, unif_result = find_unifiable_cf_index(index_key) - - if cf_index is None: - # {{{ doesn't exist yet - - assert unif_result is None - - my_common_factors = None - - for term in iterate_as(Sum, rhs): - if term == lhs: - continue - - for part in iterate_as(Product, term): - if var_name in get_dependencies(part): - raise LoopyError("unexpected dependency on '%s' " - "in RHS of instruction '%s'" - % (var_name, insn.id)) - - product_parts = set(iterate_as(Product, term)) - - if my_common_factors is None: - my_common_factors = product_parts - else: - my_common_factors = my_common_factors & product_parts - - if my_common_factors is not None: - common_factors.append((index_key, my_common_factors)) - - # }}} - else: - # {{{ match, filter existing common factors - - _, my_common_factors = common_factors[cf_index] - - unif_subst_map = SubstitutionMapper( - make_subst_func(unif_result.lmap)) - - for term in iterate_as(Sum, rhs): - if term == lhs: - continue - - for part in iterate_as(Product, term): - if var_name in get_dependencies(part): - raise LoopyError("unexpected dependency on '%s' " - "in RHS of instruction '%s'" - % (var_name, insn.id)) - - product_parts = set(iterate_as(Product, term)) - - my_common_factors = { - cf for cf in my_common_factors - if unif_subst_map(cf) in product_parts} - - common_factors[cf_index] = (index_key, my_common_factors) - - # }}} - - # }}} - - common_factors = [ - (ik, cf) for ik, cf in common_factors - if cf] - - if not common_factors: - raise LoopyError("no common factors found") - - # {{{ remove common factors - - new_insns = [] - - for insn in kernel.instructions: - if not isinstance(insn, Assignment) or not is_assignee(insn): - new_insns.append(insn) - continue - - index_key = extract_index_key(insn.assignee) - - lhs = insn.assignee - rhs = insn.expression - - if is_zero(rhs): - new_insns.append(insn) - continue - - index_key = extract_index_key(lhs) - cf_index, unif_result = find_unifiable_cf_index(index_key) - - if cf_index is None: - new_insns.append(insn) - continue - - _, my_common_factors = common_factors[cf_index] - - unif_subst_map = SubstitutionMapper( - make_subst_func(unif_result.lmap)) - - mapped_my_common_factors = { - unif_subst_map(cf) - for cf in my_common_factors} - - new_sum_terms = [] - - for term in iterate_as(Sum, rhs): - if term == lhs: - new_sum_terms.append(term) - continue - - new_sum_terms.append( - flattened_product([ - part - for part in iterate_as(Product, term) - if part not in mapped_my_common_factors - ])) - - new_insns.append( - insn.copy(expression=flattened_sum(new_sum_terms))) - - # }}} - - # {{{ substitute common factors into usage sites - - def find_substitution(expr): - if isinstance(expr, Subscript): - v = expr.aggregate.name - elif isinstance(expr, Variable): - v = expr.name - else: - return expr - - if v != var_name: - return expr - - index_key = extract_index_key(expr) - cf_index, unif_result = find_unifiable_cf_index(index_key) - - unif_subst_map = SubstitutionMapper( - make_subst_func(unif_result.lmap)) - - _, my_common_factors = common_factors[cf_index] - - if my_common_factors is not None: - return flattened_product( - [unif_subst_map(cf) for cf in my_common_factors] - + [expr]) - else: - return expr - - insns = new_insns - new_insns = [] - - subm = SubstitutionMapper(find_substitution) - - for insn in insns: - if not isinstance(insn, Assignment) or is_assignee(insn): - new_insns.append(insn) - continue - - new_insns.append(insn.with_transformed_expressions(subm)) - - # }}} - - return kernel.copy(instructions=new_insns) - -# }}} - - # vim: foldmethod=marker