Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def __init__(self, kernel_group, reason=None):
self.welford_reduce_out = None
self.reduce_iterator = {}
self.spad_buffer_dict = dict()
self.indirect_symbols = set() # CSE-var names bound as indirect indices
self.base_vector_initialized = False
self.loop_size = None

Expand Down Expand Up @@ -566,7 +567,7 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE
return index

def load(self, name: str, index: sympy.Expr):
index, _ = self.convert_indirect_indexing(index)
index, offset_desc = self.convert_indirect_indexing(index)
padding = self.get_padding_type()

# Extract dram info
Expand All @@ -591,7 +592,7 @@ def load(self, name: str, index: sympy.Expr):
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])

code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
dram_shape, tile_shape, dram_stride, tile_stride, int(padding))
dram_shape, tile_shape, dram_stride, tile_stride, int(padding), offset=offset_desc)
self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching

with self.override_buffer_cse(buffer=self.loads):
Expand All @@ -602,9 +603,10 @@ def load(self, name: str, index: sympy.Expr):
def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs):
dtype = V.graph.get_dtype(name)
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]
offset_desc = None

# Handle scatter store
if "tmp" in str(index):
if self._has_indirect(index):
# Convert the output buffer type to the inplace buffer
arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
if arg_name not in self.kernel_group.args.inplace_buffers:
Expand All @@ -613,7 +615,7 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs)
if mode == "atomic_add":
loaded_value = ops.load(name, index)
value = ops.add(loaded_value, value)
index, _ = self.convert_indirect_indexing(index)
index, offset_desc = self.convert_indirect_indexing(index)
dram_var = self.kernel_group.args.output(name)

# Prepare dma instruction
Expand Down Expand Up @@ -655,7 +657,7 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs)

# Generate DMA instruction
code = self.emit_transfer("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
dram_shape, tile_shape, dram_stride, tile_stride, 0)
dram_shape, tile_shape, dram_stride, tile_stride, 0, offset=offset_desc)
self.dma_stores.writeline(common.DeferredLine(name, code))

def reduction(self, dtype, src_dtype, reduction_type, value):
Expand Down Expand Up @@ -787,8 +789,12 @@ def store_reduction(self, name, index, value):
self.reductions_suffix.writeline(common.DeferredLine(name, code))

def indirect_indexing(self, index_var, size, check=True, wrap_neg=True):
    """Register ``index_var`` as an indirect index and return its name.

    The string form of ``index_var`` is added to ``self.indirect_symbols`` so
    that index expressions can later be checked for indirect symbols by
    name. ``size``, ``check`` and ``wrap_neg`` are accepted but not used in
    this method.
    """
    symbol_name = str(index_var)
    self.indirect_symbols.add(symbol_name)  # remember the bound indirect symbol
    return symbol_name

def _has_indirect(self, expr):
    """Return True if ``expr`` references a registered indirect symbol.

    A symbol counts as indirect when its ``name`` is present in
    ``self.indirect_symbols``.

    Objects that expose no ``free_symbols`` attribute (for example a plain
    Python ``int`` or ``None``) cannot reference any symbol, so they yield
    False instead of raising ``AttributeError``.
    """
    free_symbols = getattr(expr, "free_symbols", ())
    return any(sym.name in self.indirect_symbols for sym in free_symbols)

def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):
# In case of index expr, dimension size should be divisible by tile size
if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
Expand Down Expand Up @@ -1226,7 +1232,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
"""
# Use loads as default
if buffer is None:
buffer = self.applys if "tmp" not in str(index) else self.dma_loads
buffer = self.applys if not self._has_indirect(index) else self.dma_loads

# TODO.
kg_tile_desc = self.kernel_group.tile_desc
Expand All @@ -1236,7 +1242,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
total_dims = [int(str(i)[5:]) for i in self.itervars]
local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane)
local_dims.sort() # Assume that smaller index is placed in the outer loop
indirect_syms = [s for s in index.free_symbols if "tmp" in s.name]
indirect_syms = [s for s in index.free_symbols if s.name in self.indirect_symbols]
index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True)
indirect_dims = [f"{i}" for i in indirect_syms]

Expand Down Expand Up @@ -1358,7 +1364,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype,
dram_var, dram_index_var, sram_var, sram_index_var,
dram_shape, tile_shape, dram_stride, tile_stride, padding,
subtile_size=None, async_type=None):
subtile_size=None, async_type=None, offset=None):
"""Emit a generic togsim.transfer op for a DMA whose access exceeds the
4D Gemmini descriptor limit. Carries the full N-D access (dram/tile
strides + shapes) plus the SSA operands a memref.dma_start needs
Expand Down Expand Up @@ -1399,12 +1405,16 @@ def emit_transfer(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtyp
if subtile_size:
av = int(async_type) if async_type is not None else 1
attrs += f', subtile_size = {list(subtile_size)}, async = {av} : i64'
# operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride
return (
f'"togsim.transfer"(%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, '
f'%{tag}, %{dma_type}, %{vst}) {{{attrs}}} : '
f'({dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index) -> ()'
)
# operands: dram, dram_idx, sram, sram_idx, tag, dma_type, vlane_stride [, offset spad]
operands = (f'%{dram_var}, %{dram_index_var}, %{sram_var}, %{zero_cse}, '
f'%{tag}, %{dma_type}, %{vst}')
optypes = f'{dram_shape}, index, {tile_shape}, index, memref<1xi32>, index, index'
if offset is not None: # indirect: per-position offset spad (decompose lifts it to a symbol attr)
offset_buf, offset_type, offset_stride = offset
operands += f', %{offset_buf}'
optypes += f', {offset_type}'
attrs += f', indirect = true, offset_stride = {int(offset_stride)} : i64'
return f'"togsim.transfer"({operands}) {{{attrs}}} : ({optypes}) -> ()'

def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None):
c_type = mlir_common.DTYPE_TO_C[dtype]
Expand Down Expand Up @@ -1485,7 +1495,7 @@ def get_mask(self):
return mask_shape, mask_var

def convert_indirect_indexing(self, index :sympy.Expr):
if "tmp" not in str(index):
if not self._has_indirect(index):
return index, None

# Note: In case of indirect indexing, dimensions should be divisible by tile size
Expand All @@ -1496,12 +1506,11 @@ def convert_indirect_indexing(self, index :sympy.Expr):
raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")

# Process start
indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)]
indirect_dims = [str(dim) for dim in index.free_symbols if str(dim) in self.indirect_symbols]
indirect_dims.sort()
first_dim = indirect_dims[0]
spad_vars = dict()
compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims])

# Store each newly-produced indirect index into spad, in its producing step
for target_dim in indirect_dims:
if target_dim in self.spad_buffer_dict:
Expand All @@ -1524,17 +1533,30 @@ def convert_indirect_indexing(self, index :sympy.Expr):
if compute_dependecy:
self.push_step()

# Build the offset (outer ops) in the current step, reading indices back from spad
# Single indirect dim: the raw index IS the offset; the MVIN applies offset_stride (CONFIG4)
if len(indirect_dims) == 1:
offset_stride = 1
for arg in list(index.args):
if not self._has_indirect(arg):
continue
if arg.is_Mul and arg.args[0].is_number:
offset_stride = int(arg.args[0])
index = index.replace(arg, 0)
sram_var, _, _, _, tile_shape, _ = self.spad_buffer_dict[first_dim]
return index, (sram_var, tile_shape, offset_stride)

# Multi indirect dim: sum the strided indices in the compute loop (chunked by compute_vec_size)
local_tile_desc = self.kernel_group.tile_desc
compute_vec_size = local_tile_desc.get_compute_vec_size()
for target_dim in indirect_dims:
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
sram_var, _, _, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
mlir_dtype = vshape.split("x")[1][:-1]
with self.override_buffer_cse(buffer=self.dma_loads):
out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape)
spad_vars[target_dim] = out

with self.override_buffer_cse(buffer=self.dma_loads):
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
with self.override_buffer_cse(buffer=self.loads):
spad_vars[target_dim] = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape)
with self.override_buffer_cse(buffer=self.compute):
for arg in index.args:
if "tmp" not in str(arg):
if not self._has_indirect(arg):
continue
if arg.is_Mul and arg.args[0].is_number:
coeff_dtype = self.var_info[spad_vars[str(arg.args[1])]][1]
Expand All @@ -1545,15 +1567,14 @@ def convert_indirect_indexing(self, index :sympy.Expr):
if dim == first_dim:
continue
spad_vars[first_dim] = ops.add(spad_vars[first_dim], var)

sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim]
mlir_dtype = vshape.split("x")[1][:-1]
with self.override_buffer_cse(buffer=self.dma_loads):
ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape)

mlir_dtype = self.var_info[spad_vars[first_dim]][1]
with self.override_buffer_cse(buffer=self.dma_loads):
out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape)
if mlir_dtype != "index":
out = ops.index_cast(out, "index")
return index + sympy.Symbol(str(out)), compute_dependecy
# Summed offset goes to a DEDICATED spad (not an index buffer) to avoid clobbering a live index
var_info = [v for k, v in self.var_info.items() if str(k) == first_dim][0]
dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]
off_shape = local_tile_desc.get_mlir_shape(var_info[1])
off_sram, off_index = self.get_scratchpad_buffer(
dtype, "indirect_offset_" + first_dim, local_tile_desc, "indirect_offset_" + first_dim)
off_compute_index = ",".join(off_index.split(",")[:-1] + [f"%{self.compute_idx}"])
with self.override_buffer_cse(buffer=self.stores):
ops._store(spad_vars[first_dim], off_sram, off_compute_index, off_shape)
self.push_step() # offset-build compute loop must finish before the gather reads it
return index, (off_sram, off_shape, 1)
5 changes: 5 additions & 0 deletions PyTorchSimFrontend/mlir/passes/build_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ def _emit_one_dma(ctx, op, node, builder, bufs, tags):
spad_id = bufs.of(dep._global_of(f["src"] if node.is_write else f["dst"]))
read_bufs = [spad_id] if node.is_write else []
write_bufs = [] if node.is_write else [spad_id]
if "indirect_offset" in op.attributes: # gather/scatter reads the offset spad -> dep on its build
from mlir.ir import FlatSymbolRefAttr
off_id = bufs.of(FlatSymbolRefAttr(op.attributes["indirect_offset"]).value)
if off_id not in read_bufs:
read_bufs = read_bufs + [off_id]
tag_id = tags.bind(_value_key(f["tag"]), spad_id)
_emit_dma(ctx, node, tag_id, dram_index, tag_index, read_bufs, write_bufs)

Expand Down
8 changes: 7 additions & 1 deletion PyTorchSimFrontend/mlir/passes/decompose_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def run(module, vectorlane=128, **_):
targets.append(op.operation)

for op in targets:
dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op.operands
op_operands = list(op.operands)
dram, dram_idx, sram, sram_idx, tag, dma_type, vst = op_operands[:7]
# indirect: offset spad operand -> lift to a symbol attr (memref.dma_start can't take the operand)
offset_sym = op_operands[7].owner.attributes["name"] if len(op_operands) > 7 else None
kind = op.attributes["dma_kind"].value # StringAttr -> "MVIN"/"MVOUT"
vlane_axis = IntegerAttr(op.attributes["vlane_split_axis"]).value
dram_stride = _int_array(op.attributes["dram_stride"])
Expand Down Expand Up @@ -127,6 +130,9 @@ def _emit(sram_mem, sram_indices, dram_idx_val, vsa_val, dr_attr, tl_attr, st_at
if st_attr is not None:
attrs["subtile_size"] = st_attr
attrs["async"] = async_attr
if offset_sym is not None:
attrs["indirect_offset"] = offset_sym
attrs["offset_stride"] = op.attributes["offset_stride"]
Operation.create(
"memref.dma_start", results=[], operands=operands, attributes=attrs)

Expand Down
38 changes: 15 additions & 23 deletions PyTorchSimFrontend/mlir/passes/lower_dma_to_gemmini.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,18 @@ def run(module, timing=False):
memref.dma_wait is erased in both modes (matches C++ DmaWaitOpLowering).
"""
from mlir.ir import (InsertionPoint, Operation, IntegerType, IndexType,
IntegerAttr, MemRefType)
IntegerAttr, MemRefType, FlatSymbolRefAttr, TypeAttr)
from mlir.dialects import llvm, arith, memref

i64 = IntegerType.get_signless(64)
idx = IndexType.get()

# memref.global symbol -> type, to resolve the indirect_offset spad
sym2type = {}
for g in module.operation.regions[0].blocks[0].operations:
if g.operation.name == "memref.global":
sym2type[g.attributes["sym_name"].value] = MemRefType(TypeAttr(g.attributes["type"]).value)

def const_int(val):
return IntegerAttr(val.owner.attributes["value"]).value

Expand Down Expand Up @@ -119,9 +125,8 @@ def elem_addr_i64(memref_val, indices, mtype, elem_bytes):
is_mvin = dma_type in (MVIN, MVIN2, MVIN3)

elem_bytes = _elem_bytes(src_ty.element_type)
# Indirect (gather): the gather-side indices are src for mvin, dst for mvout.
gather_idx = src_idx if is_mvin else dst_idx
indirect, indirect_memref = _find_indirect(gather_idx)
# Indirect (gather): offset spad referenced by the indirect_offset symbol attr
indirect = "indirect_offset" in op.attributes

tile_shape = _subtile(op)
if tile_shape is None:
Expand Down Expand Up @@ -155,10 +160,14 @@ def elem_addr_i64(memref_val, indices, mtype, elem_bytes):
i64_const((spad4[2] << 32) | (spad4[3] & 0xFFFFFFFF)))
if indirect:
# CONFIG4: rs1 = indirect offset-spad base address, rs2 = (elem_size<<16)|offset_stride
offset_sym = FlatSymbolRefAttr(op.attributes["indirect_offset"]).value
off_ty = sym2type[offset_sym]
indirect_memref = memref.GetGlobalOp(off_ty, offset_sym).result
ind_base = memref.ExtractAlignedPointerAsIndexOp(indirect_memref).result
ind_addr = arith.IndexCastOp(i64, ind_base).result
ind_esize = _elem_bytes(MemRefType(indirect_memref.type).element_type)
asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (1 & 0xFFFF)))
ind_esize = _elem_bytes(off_ty.element_type)
off_stride = IntegerAttr(op.attributes["offset_stride"]).value
asm(CONFIG4, ind_addr, i64_const(((ind_esize & 0xFF) << 16) | (off_stride & 0xFFFF)))
asm(dma_type, dram_addr, spad_addr)
op.erase()

Expand Down Expand Up @@ -189,23 +198,6 @@ def _elem_bytes(elem_type):
return max(bits, 8) // 8


def _find_indirect(indices):
    """Locate the index spad behind an indirect gather index.

    Scans ``indices`` for a value whose owner is an ``affine.apply`` op
    carrying an ``indirect_access`` attribute. Among that op's operands, the
    first one produced by an ``arith.index_cast`` whose input comes from an
    ``affine.load`` identifies the spad.

    Returns ``(True, spad_memref)`` on a match, where ``spad_memref`` is
    operand 0 of the ``affine.load``; otherwise ``(False, None)``.
    """
    for index_value in indices:
        apply_op = index_value.owner
        is_apply = getattr(apply_op, "name", None) == "affine.apply"
        if not (is_apply and "indirect_access" in apply_op.attributes):
            continue
        for apply_operand in apply_op.operands:
            cast_op = apply_operand.owner
            if getattr(cast_op, "name", None) == "arith.index_cast":
                load_op = cast_op.operands[0].owner
                if getattr(load_op, "name", None) == "affine.load":
                    # operand 0 of affine.load is the memref being read
                    return True, load_op.operands[0]
    return False, None


def lower_text(text):
if OP_NAME not in text:
return text
Expand Down
27 changes: 27 additions & 0 deletions tests/ops/misc/test_indirect_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,31 @@ def scatter_only(out, token_indices, weighted_output):
res = opt_fn(out, token_indices, weighted_output)
test_result("ScatterAdd(index_add_)", res, cpu_out)

def test_multidim_indirect(device, size=(64, 64), n=256):
    """Compare a compiled two-index gather ``x[ix, iy] + 1.0`` with eager CPU.

    ``x`` is a random float32 tensor of shape ``size``; ``ix`` and ``iy`` are
    ``n`` random indices drawn within the first and second dimension.
    """
    torch.manual_seed(0)

    def gather2d(x, ix, iy):
        return x[ix, iy] + 1.0

    # Inputs are created in a fixed order so the seeded RNG stream is stable.
    x = torch.randn(size, dtype=torch.float32).to(device=device)
    ix = torch.randint(0, size[0], [n]).to(device=device)
    iy = torch.randint(0, size[1], [n]).to(device=device)

    compiled_gather = torch.compile(dynamic=False)(gather2d)
    device_result = compiled_gather(x, ix, iy)
    cpu_reference = gather2d(x.cpu(), ix.cpu(), iy.cpu())
    test_result("Multi-dim Indirect (x[ix,iy])", device_result, cpu_reference)

def test_multidim_indirect_index_reuse(device, size=(64, 64), n=256):
    """Compare compiled ``x[ix, iy] + ix.float()`` with eager CPU.

    ``ix`` is consumed again after the gather, so the result is only correct
    if the index values are still intact once the gather has run.
    """
    torch.manual_seed(0)

    def gather_reuse(x, ix, iy):
        # ix is reused after the gather -> the offset spad must not clobber the index spad
        return x[ix, iy] + ix.float()

    # Inputs are created in a fixed order so the seeded RNG stream is stable.
    x = torch.randn(size, dtype=torch.float32).to(device=device)
    ix = torch.randint(0, size[0], [n]).to(device=device)
    iy = torch.randint(0, size[1], [n]).to(device=device)

    compiled_gather = torch.compile(dynamic=False)(gather_reuse)
    device_result = compiled_gather(x, ix, iy)
    cpu_reference = gather_reuse(x.cpu(), ix.cpu(), iy.cpu())
    test_result("Multi-dim Indirect index reuse (x[ix,iy]+ix)", device_result, cpu_reference)

def test_scatter_full(device, size=(128, 128)):
def vectoradd(a, idx, b):
a[idx, :] = b
Expand All @@ -71,4 +96,6 @@ def vectoradd(a, idx, b):
test_scatter_full(device, size=(2048, 2048))
test_scatter_add(device)
test_indirect_vectoradd(device)
test_multidim_indirect(device)
test_multidim_indirect_index_reuse(device)
#test_embedding(device, 1024, 2048)