diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 45fb5e9f..a7bc914e 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -310,6 +310,22 @@ def memory_plan(self): "MVOUT1": 3, } + +class Step: + """One load->compute->store unit of the kernel body (see codegen_loops). + + Bundles the DMA, mask, index and compute buffers so the body can be an + ordered list of steps; the formerly ad-hoc mask/index buffers are just + fields here. + """ + __slots__ = ("applys", "dma_loads", + "loads", "compute", "stores", "dma_stores") + + def __init__(self, **buffers): + for name, buf in buffers.items(): + setattr(self, name, buf) + + class MLIRKernel(mlir_common.BaseMLIRKernel): overrides = ExtensionOverrides newvar_prefix = "%" @@ -321,11 +337,15 @@ def __init__(self, kernel_group, reason=None): self.spad_buffer = IndentedBuffer() self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() - self.applys = IndentedBuffer() - self.masks = IndentedBuffer() - self.dma_loads = IndentedBuffer() - self.dma_stores = IndentedBuffer() - self.indexed_buffer = IndentedBuffer() + # Kernel body = ordered load->compute->store steps; step 0 keeps the base + # loads/compute/stores (the CSE target default captured self.compute at init). + step0 = Step( + applys=IndentedBuffer(), + dma_loads=IndentedBuffer(), dma_stores=IndentedBuffer(), + loads=self.loads, compute=self.compute, stores=self.stores, + ) + self.steps = [step0] + self._bind_step(step0) self.global_vars = IndentedBuffer() self.header = IndentedBuffer() self.gem5_header = IndentedBuffer() @@ -546,19 +566,9 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE return index def load(self, name: str, index: sympy.Expr): - index, comptute_depedency = self.convert_indirect_indexing(index) + index, _ = self.convert_indirect_indexing(index) padding = self.get_padding_type() - # In case of special form of indirect access, we need to put load in dma_store buffer - if comptute_depedency: - apply_buffer = self.dma_stores - dma_buffer = self.dma_stores - load_buffer = self.dma_stores - else: - apply_buffer = None - dma_buffer = self.dma_loads - load_buffer = self.loads - # Extract dram info dram_var = self.kernel_group.args.input(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) @@ -566,7 +576,7 @@ def load(self, name: str, index: sympy.Expr): mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Extract sram info - local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, buffer=apply_buffer) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vmap.vlane_split_axis vlane_stride = local_tile_desc.vmap.vlane_stride tile_numel_per_lane = local_tile_desc.get_numel_per_lane() @@ -582,16 +592,10 @@ def load(self, name: str, index: sympy.Expr): code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, dram_shape, tile_shape, dram_stride, tile_stride, int(padding)) - self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching + self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching - if not comptute_depedency: - # Generate vector load instruction - with self.override_buffer_cse(buffer=load_buffer): - out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) - else: - # FIXME. Any good idea? - out = sram_var - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + with self.override_buffer_cse(buffer=self.loads): + out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape) self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] return out @@ -933,6 +937,30 @@ def index_expr(self, index, dtype): def codegen_global_init(self): return self.global_vars + def _bind_step(self, step): + # Make `step` the current emit sink: route the body buffers to its buffers + self.current_step = step + self.applys = step.applys + self.dma_loads = step.dma_loads + self.dma_stores = step.dma_stores + self.loads = step.loads + self.compute = step.compute + self.stores = step.stores + + def push_step(self): + # New load->compute->store step; later emits land here, steps bridge via spad + step = Step( + applys=IndentedBuffer(), + dma_loads=IndentedBuffer(), dma_stores=IndentedBuffer(), + loads=IndentedBuffer(), compute=IndentedBuffer(), stores=IndentedBuffer(), + ) + self.steps.append(step) + self._bind_step(step) + self.cse = self.cse.clone() # share name counter, fresh dedup cache (region-safe) + self.target_buffer_override.set(self.compute) + self.target_cse_override.set(self.cse) + return step + def codegen_loops(self): code = mlir_common.ParallelLoopBuffer() # Loop body part @@ -965,18 +993,18 @@ def codegen_loops(self): epilogue = reduction_loop.epilogue_line() code.writelines(reduction_lines) stack.enter_context(code.indent(attribute="{accumulation_loop=true}", suffix=epilogue)) - code.splice(self.applys) - code.splice(self.indexed_buffer) - code.splice(self.dma_loads) - # Compute body - code.writelines(self.compute_body_loop.lines()) - with contextlib.ExitStack() as stack: - stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) - code.splice(self.masks) - code.splice(self.loads) - code.splice(self.compute) - code.splice(self.stores) - code.splice(self.dma_stores) + for step in self.steps: + code.splice(step.applys) + code.splice(step.dma_loads) + # Compute body -- only steps that have one get the loop + epilogue + if any(b.getvalue() for b in (step.loads, step.compute, step.stores)): + code.writelines(self.compute_body_loop.lines()) + with contextlib.ExitStack() as stack: + stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + code.splice(step.loads) + code.splice(step.compute) + code.splice(step.stores) + code.splice(step.dma_stores) code.splice(self.reductions_suffix) # Non-outerloop end code.writeline(f"return") @@ -1450,7 +1478,7 @@ def get_mask(self): upper_bound = ops.constant(self.compute_body_loop.size, "index") step_vec = ops.step(self.compute_body_loop.step, "index") - with self.override_buffer_cse(buffer=self.masks, cse=self.mask_cse): + with self.override_buffer_cse(buffer=self.compute, cse=self.mask_cse): gap = ops.sub(upper_bound, self.compute_idx) gap_vec = ops.broadcast(gap, self.compute_body_loop.step) mask_var = ops.lt(step_vec, gap_vec) @@ -1473,37 +1501,38 @@ def convert_indirect_indexing(self, index :sympy.Expr): first_dim = indirect_dims[0] spad_vars = dict() compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims]) - target_dma_buffers = self.dma_stores if compute_dependecy else self.dma_loads - # Load indirect operands + # Store each newly-produced indirect index into spad, in its producing step for target_dim in indirect_dims: if target_dim in self.spad_buffer_dict: - sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] - else: - # FIXME. - var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0] - dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]] - - local_tile_desc = self.kernel_group.tile_desc - tile_numel_per_lane = local_tile_desc.get_numel_per_lane() - tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) - tile_vec = local_tile_desc.get_compute_vec_size() - vshape = f"vector<{var_info[0]}x{var_info[1]}>" - sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) - self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] - - # Store the indirect index variable - target_var = self.cse.varname_map[target_dim] - compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) - with self.override_buffer_cse(buffer=self.stores): - ops._store(target_var, sram_var, compute_index_var, tile_shape) + continue + var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0] + dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]] + local_tile_desc = self.kernel_group.tile_desc + tile_numel_per_lane = local_tile_desc.get_numel_per_lane() + tile_shape = local_tile_desc.get_mlir_shape(var_info[1]) + tile_vec = local_tile_desc.get_compute_vec_size() + vshape = f"vector<{var_info[0]}x{var_info[1]}>" + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim) + self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape] + target_var = self.cse.varname_map[target_dim] + compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) + with self.override_buffer_cse(buffer=self.stores): + ops._store(target_var, sram_var, compute_index_var, tile_shape) + + # Offset build runs after the index is in spad -> own step when just produced + if compute_dependecy: + self.push_step() + + # Build the offset (outer ops) in the current step, reading indices back from spad + for target_dim in indirect_dims: + sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim] mlir_dtype = vshape.split("x")[1][:-1] - with self.override_buffer_cse(buffer=target_dma_buffers): + with self.override_buffer_cse(buffer=self.dma_loads): out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape) spad_vars[target_dim] = out - with self.override_buffer_cse(buffer=target_dma_buffers): - # Apply stride + with self.override_buffer_cse(buffer=self.dma_loads): for arg in index.args: if "tmp" not in str(arg): continue @@ -1512,22 +1541,18 @@ def convert_indirect_indexing(self, index :sympy.Expr): coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype) spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff) index = index.replace(arg, 0) - - # Sum for dim, var in spad_vars.items(): if dim == first_dim: continue spad_vars[first_dim] = ops.add(spad_vars[first_dim], var) - # Store index var sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim] mlir_dtype = vshape.split("x")[1][:-1] - with self.override_buffer_cse(buffer=target_dma_buffers): - ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute... + with self.override_buffer_cse(buffer=self.dma_loads): + ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) - # Conversion mlir_dtype = self.var_info[spad_vars[first_dim]][1] - with self.override_buffer_cse(buffer=target_dma_buffers): + with self.override_buffer_cse(buffer=self.dma_loads): out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape) if mlir_dtype != "index": out = ops.index_cast(out, "index")