From 3ec75a82d5a777397a11e3bcafa21294d85204c2 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 26 Jun 2026 22:12:22 +0900 Subject: [PATCH 1/2] [Frontend] Carry gather/scatter offset as an explicit transfer descriptor Replace the affine.apply{indirect_access} symbol smuggle with an explicit offset descriptor. convert_indirect_indexing returns the offset spad instead of folding sympy.Symbol(out) into the index; emit_transfer carries it as a togsim.transfer operand; decompose_transfer lifts that operand to a memref.dma_start {indirect_offset = @spad_symbol} attribute (memref.dma_start is a registered op and rejects an extra operand, but accepts the attribute); lower_dma_to_gemmini reads the attribute and resolves the global for CONFIG4 (drops _find_indirect); build_skeleton adds the offset spad to the gather DMA read_bufs so the offset-build -> gather dependency edge forms in the trace. The index stays clean (base only). Validated on both paths: Spike functional (computed-index gather + scatter allclose, pointwise/reduce regression 0) and the C++ trace timing path end-to-end (allclose; gather togsim_dma read_bufs now includes the offset spad). Indirect addressing (scattered-DMA timing) in the new trace path is a separate gap tracked in issue #284. 
Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01EEfyUDpkMLRYZ2NAMbb3jN --- .../mlir/mlir_codegen_backend.py | 37 +++++++++---------- .../mlir/passes/build_skeleton.py | 5 +++ .../mlir/passes/decompose_transfer.py | 7 +++- .../mlir/passes/lower_dma_to_gemmini.py | 35 +++++++----------- 4 files changed, 42 insertions(+), 42 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index a7bc914e..0b8330ce 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -566,7 +566,7 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE return index def load(self, name: str, index: sympy.Expr): - index, _ = self.convert_indirect_indexing(index) + index, offset_desc = self.convert_indirect_indexing(index) padding = self.get_padding_type() # Extract dram info @@ -591,7 +591,7 @@ def load(self, name: str, index: sympy.Expr): compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) + dram_shape, tile_shape, dram_stride, tile_stride, int(padding), offset=offset_desc) self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching with self.override_buffer_cse(buffer=self.loads): @@ -602,6 +602,7 @@ def load(self, name: str, index: sympy.Expr): def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs): dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + offset_desc = None # Handle scatter store if "tmp" in str(index): @@ -613,7 +614,7 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) if mode == "atomic_add": loaded_value = 
ops.load(name, index) value = ops.add(loaded_value, value) - index, _ = self.convert_indirect_indexing(index) + index, offset_desc = self.convert_indirect_indexing(index) dram_var = self.kernel_group.args.output(name) # Prepare dma instruction @@ -655,7 +656,7 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) # Generate DMA instruction code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, dram_stride, tile_stride, 0) + dram_shape, tile_shape, dram_stride, tile_stride, 0, offset=offset_desc) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -1358,7 +1359,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, dram_shape, tile_shape, dram_stride, tile_stride, padding, - subtile_size=None, async_type=None): + subtile_size=None, async_type=None, offset=None): """Emit a generic togsim.transfer op for a DMA whose access exceeds the 4D Gemmini descriptor limit. 
Carries the full N-D access (dram/tile strides + shapes) plus the SSA operands a memref.dma_start needs @@ -1399,12 +1400,16 @@ def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtyp if subtile_size: av = int(async_type) if async_type is not None else 1 attrs += f', subtile_size = {list(subtile_size)}, async = {av} : i64' - # operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride - return ( - f'"togsim.transfer"(%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, ' - f'%{tag}, %{dma_type}, %{vst}) {{{attrs}}} : ' - f'({dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index) -> ()' - ) + # operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride [, offset spad] + operands = (f'%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, ' + f'%{tag}, %{dma_type}, %{vst}') + optypes = f'{dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index' + if offset is not None: # indirect: per-position offset spad (decompose lifts it to a symbol attr) + offset_buf, offset_type = offset + operands += f', %{offset_buf}' + optypes += f', {offset_type}' + attrs += ', indirect = true' + return f'"togsim.transfer"({operands}) {{{attrs}}} : ({optypes}) -> ()' def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] @@ -1547,13 +1552,7 @@ def convert_indirect_indexing(self, index :sympy.Expr): spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] - mlir_dtype = vshape.split("x")[1][:-1] with self.override_buffer_cse(buffer=self.dma_loads): ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) - - mlir_dtype = self.var_info[spad_vars[first_dim]][1] - with self.override_buffer_cse(buffer=self.dma_loads): - out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape) - if mlir_dtype != 
"index": - out = ops.index_cast(out, "index") - return index + sympy.Symbol(str(out)), compute_dependecy + # Clean base index + the offset spad as an explicit transfer descriptor + return index, (sram_var, tile_shape) diff --git a/PyTorchSimFrontend/mlir/passes/build_skeleton.py b/PyTorchSimFrontend/mlir/passes/build_skeleton.py index 4c3d89cb..4174231d 100644 --- a/PyTorchSimFrontend/mlir/passes/build_skeleton.py +++ b/PyTorchSimFrontend/mlir/passes/build_skeleton.py @@ -441,6 +441,11 @@ def _emit_one_dma(ctx, op, node, builder, bufs, tags): spad_id = bufs.of(dep._global_of(f["src"] if node.is_write else f["dst"])) read_bufs = [spad_id] if node.is_write else [] write_bufs = [] if node.is_write else [spad_id] + if "indirect_offset" in op.attributes: # gather/scatter reads the offset spad -> dep on its build + from mlir.ir import FlatSymbolRefAttr + off_id = bufs.of(FlatSymbolRefAttr(op.attributes["indirect_offset"]).value) + if off_id not in read_bufs: + read_bufs = read_bufs + [off_id] tag_id = tags.bind(_value_key(f["tag"]), spad_id) _emit_dma(ctx, node, tag_id, dram_index, tag_index, read_bufs, write_bufs) diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py index 10b2edfb..dfb67bb5 100644 --- a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -91,7 +91,10 @@ def run(module, vectorlane=128, **_): targets.append(op.operation) for op in targets: - dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op.operands + op_operands = list(op.operands) + dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op_operands[:7] + # indirect: offset spad operand -> lift to a symbol attr (memref.dma_start can't take the operand) + offset_sym = op_operands[7].owner.attributes["name"] if len(op_operands) > 7 else None kind = op.attributes["dma_kind"].value # StringAttr -> "MVIN"/"MVOUT" vlane_axis = 
IntegerAttr(op.attributes["vlane_split_axis"]).value dram_stride = _int_array(op.attributes["dram_stride"]) @@ -127,6 +130,8 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at if st_attr is not None: attrs["subtile_size"] = st_attr attrs["async"] = async_attr + if offset_sym is not None: + attrs["indirect_offset"] = offset_sym Operation.create( "memref.dma_start", results=[], operands=operands, attributes=attrs) diff --git a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py index 998a6db5..51204b50 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py +++ b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py @@ -58,12 +58,18 @@ def run(module, timing=False): memref.dma_wait is erased in both modes (matches C++ DmaWaitOpLowering). """ from mlir.ir import (InsertionPoint, Operation, IntegerType, IndexType, - IntegerAttr, MemRefType) + IntegerAttr, MemRefType, FlatSymbolRefAttr, TypeAttr) from mlir.dialects import llvm, arith, memref i64 = IntegerType.get_signless(64) idx = IndexType.get() + # memref.global symbol -> type, to resolve the indirect_offset spad + sym2type = {} + for g in module.operation.regions[0].blocks[0].operations: + if g.operation.name == "memref.global": + sym2type[g.attributes["sym_name"].value] = MemRefType(TypeAttr(g.attributes["type"]).value) + def const_int(val): return IntegerAttr(val.owner.attributes["value"]).value @@ -119,9 +125,8 @@ def elem_addr_i64(memref_val, indices, mtype, elem_bytes): is_mvin = dma_type in (MVIN, MVIN2, MVIN3) elem_bytes = _elem_bytes(src_ty.element_type) - # Indirect (gather): the gather-side indices are src for mvin, dst for mvout. 
- gather_idx = src_idx if is_mvin else dst_idx - indirect, indirect_memref = _find_indirect(gather_idx) + # Indirect (gather): offset spad referenced by the indirect_offset symbol attr + indirect = "indirect_offset" in op.attributes tile_shape = _subtile(op) if tile_shape is None: @@ -155,9 +160,12 @@ def elem_addr_i64(memref_val, indices, mtype, elem_bytes): i64_const((spad4[2] << 32) | (spad4[3] & 0xFFFFFFFF))) if indirect: # CONFIG4: rs1 = indirect index-spad base address, rs2 = (elem_size<<16)|stride(1) + offset_sym = FlatSymbolRefAttr(op.attributes["indirect_offset"]).value + off_ty = sym2type[offset_sym] + indirect_memref = memref.GetGlobalOp(off_ty, offset_sym).result ind_base = memref.ExtractAlignedPointerAsIndexOp(indirect_memref).result ind_addr = arith.IndexCastOp(i64, ind_base).result - ind_esize = _elem_bytes(MemRefType(indirect_memref.type).element_type) + ind_esize = _elem_bytes(off_ty.element_type) asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (1 & 0xFFFF))) asm(dma_type, dram_addr, spad_addr) op.erase() @@ -189,23 +197,6 @@ def _elem_bytes(elem_type): return max(bits, 8) // 8 -def _find_indirect(indices): - """If a gather index is an affine.apply{indirect_access} whose operands include - index_cast(affine.load(%spad)), return (True, %spad memref); else (False, None).""" - for idx in indices: - ap = idx.owner - if getattr(ap, "name", None) != "affine.apply" or "indirect_access" not in ap.attributes: - continue - for operand in ap.operands: - ic = operand.owner - if getattr(ic, "name", None) != "arith.index_cast": - continue - ld = ic.operands[0].owner - if getattr(ld, "name", None) == "affine.load": - return True, ld.operands[0] # affine.load operand 0 == the index spad memref - return False, None - - def lower_text(text): if OP_NAME not in text: return text From c7553e5c340e9c0231ca843394cc194ebf24eeaa Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 26 Jun 2026 23:12:56 +0900 Subject: [PATCH 2/2] [Frontend] Clean up indirect 
access: symbol-set detection, CONFIG4 stride, compute-loop multi-dim offset Three refinements on top of the explicit offset descriptor: 1. Detect indirect access by an explicit symbol set instead of a "tmp" substring match. indirect_indexing now records str(index_var) in self.indirect_symbols; a _has_indirect(expr) helper tests the index free_symbols against it; the former "tmp"-string sites (store, get_dma_info, convert entry/indirect_dims/stride) use it. Removes the now-dead _find_indirect from lower_dma_to_gemmini. 2. Single indirect dim: pass the raw index spad and let the MVIN apply the gather stride per position (CONFIG4 offset_stride) instead of a vector_load/muli/vector_store round-trip that pre-multiplied the stride. emit_transfer carries offset_stride; decompose copies the attribute; lower programs CONFIG4 with it (was hardcoded to 1). 3. Multi indirect dim (e.g. x[ix, iy]): the offset is a sum of strided indices, which a single CONFIG4 channel cannot do, so the sum stays in the kernel -- but build it in the compute loop (chunked by compute_vec_size, not a single tile-wide vector) and store it to a DEDICATED offset spad so an index that is live elsewhere (x[ix, iy] + ix) is not clobbered. push_step separates the offset-build loop from the gather that reads it. Adds multi-dim gather and index-reuse cases to test_indirect_access. Validated: indirect/scatter/embedding + the two new multi-dim cases pass; add/matmul/softmax regression 0. 
Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01EEfyUDpkMLRYZ2NAMbb3jN --- .../mlir/mlir_codegen_backend.py | 66 ++++++++++++------- .../mlir/passes/decompose_transfer.py | 1 + .../mlir/passes/lower_dma_to_gemmini.py | 3 +- tests/ops/misc/test_indirect_access.py | 27 ++++++++ 4 files changed, 74 insertions(+), 23 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 0b8330ce..db9875c9 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -384,6 +384,7 @@ def __init__(self, kernel_group, reason=None): self.welford_reduce_out = None self.reduce_iterator = {} self.spad_buffer_dict = dict() + self.indirect_symbols = set() # CSE-var names bound as indirect indices self.base_vector_initialized = False self.loop_size = None @@ -605,7 +606,7 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) offset_desc = None # Handle scatter store - if "tmp" in str(index): + if self._has_indirect(index): # Convert the output buffer type to the inplace buffer arg_name = V.graph.scheduler.mutation_real_name.get(name, name) if arg_name not in self.kernel_group.args.inplace_buffers: @@ -788,8 +789,12 @@ def store_reduction(self, name, index, value): self.reductions_suffix.writeline(common.DeferredLine(name, code)) def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): + self.indirect_symbols.add(str(index_var)) # record the bound indirect symbol return str(index_var) + def _has_indirect(self, expr): + return any(s.name in self.indirect_symbols for s in expr.free_symbols) + def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): # In case of index expr, dimension size should be divisible by tile size if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): @@ -1227,7 +1232,7 @@ def get_dma_info(self, name, index, 
broadcast=True, store_reduction=False, buffe """ # Use loads as default if buffer is None: - buffer = self.applys if "tmp" not in str(index) else self.dma_loads + buffer = self.applys if not self._has_indirect(index) else self.dma_loads # TODO. kg_tile_desc = self.kernel_group.tile_desc @@ -1237,7 +1242,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe total_dims = [int(str(i)[5:]) for i in self.itervars] local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop - indirect_syms = [s for s in index.free_symbols if "tmp" in s.name] + indirect_syms = [s for s in index.free_symbols if s.name in self.indirect_symbols] index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True) indirect_dims = [f"{i}" for i in indirect_syms] @@ -1405,10 +1410,10 @@ def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtyp f'%{tag}, %{dma_type}, %{vst}') optypes = f'{dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index' if offset is not None: # indirect: per-position offset spad (decompose lifts it to a symbol attr) - offset_buf, offset_type = offset + offset_buf, offset_type, offset_stride = offset operands += f', %{offset_buf}' optypes += f', {offset_type}' - attrs += ', indirect = true' + attrs += f', indirect = true, offset_stride = {int(offset_stride)} : i64' return f'"togsim.transfer"({operands}) {{{attrs}}} : ({optypes}) -> ()' def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): @@ -1490,7 +1495,7 @@ def get_mask(self): return mask_shape, mask_var def convert_indirect_indexing(self, index :sympy.Expr): - if "tmp" not in str(index): + if not self._has_indirect(index): return index, None # Note: In case of indirect indexing, dimensions should be divisible by tile size @@ -1501,12 +1506,11 @@ def convert_indirect_indexing(self, index :sympy.Expr): raise 
mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") # Process start - indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)] + indirect_dims = [str(dim) for dim in index.free_symbols if str(dim) in self.indirect_symbols] indirect_dims.sort() first_dim = indirect_dims[0] spad_vars = dict() compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims]) - # Store each newly-produced indirect index into spad, in its producing step for target_dim in indirect_dims: if target_dim in self.spad_buffer_dict: @@ -1529,17 +1533,30 @@ def convert_indirect_indexing(self, index :sympy.Expr): if compute_dependecy: self.push_step() - # Build the offset (outer ops) in the current step, reading indices back from spad + # Single indirect dim: the raw index IS the offset; the MVIN applies offset_stride (CONFIG4) + if len(indirect_dims) == 1: + offset_stride = 1 + for arg in list(index.args): + if not self._has_indirect(arg): + continue + if arg.is_Mul and arg.args[0].is_number: + offset_stride = int(arg.args[0]) + index = index.replace(arg, 0) + sram_var, _, _, _, tile_shape, _ = self.spad_buffer_dict[first_dim] + return index, (sram_var, tile_shape, offset_stride) + + # Multi indirect dim: sum the strided indices in the compute loop (chunked by compute_vec_size) + local_tile_desc = self.kernel_group.tile_desc + compute_vec_size = local_tile_desc.get_compute_vec_size() for target_dim in indirect_dims: - sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] + sram_var, _, _, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] mlir_dtype = vshape.split("x")[1][:-1] - with self.override_buffer_cse(buffer=self.dma_loads): - out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape) - spad_vars[target_dim] = out - - with 
self.override_buffer_cse(buffer=self.dma_loads): + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + with self.override_buffer_cse(buffer=self.loads): + spad_vars[target_dim] = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) + with self.override_buffer_cse(buffer=self.compute): for arg in index.args: - if "tmp" not in str(arg): + if not self._has_indirect(arg): continue if arg.is_Mul and arg.args[0].is_number: coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1] @@ -1550,9 +1567,14 @@ def convert_indirect_indexing(self, index :sympy.Expr): if dim == first_dim: continue spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) - - sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] - with self.override_buffer_cse(buffer=self.dma_loads): - ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) - # Clean base index + the offset spad as an explicit transfer descriptor - return index, (sram_var, tile_shape) + # Summed offset goes to a DEDICATED spad (not an index buffer) to avoid clobbering a live index + var_info = [v for k, v in self.var_info.items() if str(k) == first_dim][0] + dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]] + off_shape = local_tile_desc.get_mlir_shape(var_info[1]) + off_sram, off_index = self.get_scratchpad_buffer( + dtype, "indirect_offset_" + first_dim, local_tile_desc, "indirect_offset_" + first_dim) + off_compute_index = ",".join(off_index.split(",")[:-1] + [f"%{self.compute_idx}"]) + with self.override_buffer_cse(buffer=self.stores): + ops._store(spad_vars[first_dim], off_sram, off_compute_index, off_shape) + self.push_step() # offset-build compute loop must finish before the gather reads it + return index, (off_sram, off_shape, 1) diff --git a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py index dfb67bb5..0bf04e30 100644 --- 
a/PyTorchSimFrontend/mlir/passes/decompose_transfer.py +++ b/PyTorchSimFrontend/mlir/passes/decompose_transfer.py @@ -132,6 +132,7 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at attrs["async"] = async_attr if offset_sym is not None: attrs["indirect_offset"] = offset_sym + attrs["offset_stride"] = op.attributes["offset_stride"] Operation.create( "memref.dma_start", results=[], operands=operands, attributes=attrs) diff --git a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py index 51204b50..5ca842c1 100644 --- a/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py +++ b/PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py @@ -166,7 +166,8 @@ def elem_addr_i64(memref_val, indices, mtype, elem_bytes): ind_base = memref.ExtractAlignedPointerAsIndexOp(indirect_memref).result ind_addr = arith.IndexCastOp(i64, ind_base).result ind_esize = _elem_bytes(off_ty.element_type) - asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (1 & 0xFFFF))) + off_stride = IntegerAttr(op.attributes["offset_stride"]).value + asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (off_stride & 0xFFFF))) asm(dma_type, dram_addr, spad_addr) op.erase() diff --git a/tests/ops/misc/test_indirect_access.py b/tests/ops/misc/test_indirect_access.py index f64fe50d..ae3a4ed4 100644 --- a/tests/ops/misc/test_indirect_access.py +++ b/tests/ops/misc/test_indirect_access.py @@ -52,6 +52,31 @@ def scatter_only(out, token_indices, weighted_output): res = opt_fn(out, token_indices, weighted_output) test_result("ScatterAdd(index_add_)", res, cpu_out) +def test_multidim_indirect(device, size=(64, 64), n=256): + torch.manual_seed(0) + def gather2d(x, ix, iy): + return x[ix, iy] + 1.0 + x = torch.randn(size, dtype=torch.float32).to(device=device) + ix = torch.randint(0, size[0], [n]).to(device=device) + iy = torch.randint(0, size[1], [n]).to(device=device) + opt_fn = 
torch.compile(dynamic=False)(gather2d) + res = opt_fn(x, ix, iy) + out = gather2d(x.cpu(), ix.cpu(), iy.cpu()) + test_result("Multi-dim Indirect (x[ix,iy])", res, out) + +def test_multidim_indirect_index_reuse(device, size=(64, 64), n=256): + torch.manual_seed(0) + def gather_reuse(x, ix, iy): + # ix is reused after the gather -> the offset spad must not clobber the index spad + return x[ix, iy] + ix.float() + x = torch.randn(size, dtype=torch.float32).to(device=device) + ix = torch.randint(0, size[0], [n]).to(device=device) + iy = torch.randint(0, size[1], [n]).to(device=device) + opt_fn = torch.compile(dynamic=False)(gather_reuse) + res = opt_fn(x, ix, iy) + out = gather_reuse(x.cpu(), ix.cpu(), iy.cpu()) + test_result("Multi-dim Indirect index reuse (x[ix,iy]+ix)", res, out) + def test_scatter_full(device, size=(128, 128)): def vectoradd(a, idx, b): a[idx, :] = b @@ -71,4 +96,6 @@ def vectoradd(a, idx, b): test_scatter_full(device, size=(2048, 2048)) test_scatter_add(device) test_indirect_vectoradd(device) + test_multidim_indirect(device) + test_multidim_indirect_index_reuse(device) #test_embedding(device, 1024, 2048) \ No newline at end of file