Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 95 additions & 70 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,22 @@ def memory_plan(self):
"MVOUT1": 3,
}


class Step:
    """Buffer bundle for one load->compute->store unit of the kernel body.

    A kernel body is kept as an ordered list of these steps (see
    ``codegen_loops``). Each step carries exactly the six buffers named in
    ``__slots__``; there are no dedicated mask or index fields.

    Buffers are passed as keyword arguments named after the slots. A keyword
    that is not a declared slot is rejected by ``setattr`` with
    ``AttributeError``. A slot that is not passed stays unset, so reading it
    later raises ``AttributeError`` as well.
    """

    __slots__ = (
        "applys",
        "dma_loads",
        "loads",
        "compute",
        "stores",
        "dma_stores",
    )

    def __init__(self, **buffers):
        # Assign in the caller's keyword order, one slot per keyword.
        for slot_name in buffers:
            setattr(self, slot_name, buffers[slot_name])


class MLIRKernel(mlir_common.BaseMLIRKernel):
overrides = ExtensionOverrides
newvar_prefix = "%"
Expand All @@ -321,11 +337,15 @@ def __init__(self, kernel_group, reason=None):
self.spad_buffer = IndentedBuffer()
self.reduction_prefix = IndentedBuffer()
self.reduction_suffix = IndentedBuffer()
self.applys = IndentedBuffer()
self.masks = IndentedBuffer()
self.dma_loads = IndentedBuffer()
self.dma_stores = IndentedBuffer()
self.indexed_buffer = IndentedBuffer()
# Kernel body = ordered load->compute->store steps; step 0 keeps the base
# loads/compute/stores (the CSE target default captured self.compute at init).
step0 = Step(
applys=IndentedBuffer(),
dma_loads=IndentedBuffer(), dma_stores=IndentedBuffer(),
loads=self.loads, compute=self.compute, stores=self.stores,
)
self.steps = [step0]
self._bind_step(step0)
self.global_vars = IndentedBuffer()
self.header = IndentedBuffer()
self.gem5_header = IndentedBuffer()
Expand Down Expand Up @@ -546,27 +566,17 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE
return index

def load(self, name: str, index: sympy.Expr):
index, comptute_depedency = self.convert_indirect_indexing(index)
index, _ = self.convert_indirect_indexing(index)
padding = self.get_padding_type()

# In case of special form of indirect access, we need to put load in dma_store buffer
if comptute_depedency:
apply_buffer = self.dma_stores
dma_buffer = self.dma_stores
load_buffer = self.dma_stores
else:
apply_buffer = None
dma_buffer = self.dma_loads
load_buffer = self.loads

# Extract dram info
dram_var = self.kernel_group.args.input(name)
dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
dtype = V.graph.get_dtype(name)
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]

# Extract sram info
local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, buffer=apply_buffer)
local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index)
vlane_split_axis = local_tile_desc.vmap.vlane_split_axis
vlane_stride = local_tile_desc.vmap.vlane_stride
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
Expand All @@ -582,16 +592,10 @@ def load(self, name: str, index: sympy.Expr):

code = self.emit_transfer("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
dram_shape, tile_shape, dram_stride, tile_stride, int(padding))
self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching
self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching

if not comptute_depedency:
# Generate vector load instruction
with self.override_buffer_cse(buffer=load_buffer):
out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape)
else:
# FIXME. Any good idea?
out = sram_var
self.register_var_info(out, [compute_vec_size, mlir_dtype])
with self.override_buffer_cse(buffer=self.loads):
out = ops._load(compute_vec_size, mlir_dtype, sram_var, compute_index_var, tile_shape)
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
return out

Expand Down Expand Up @@ -933,6 +937,30 @@ def index_expr(self, index, dtype):
def codegen_global_init(self):
return self.global_vars

def _bind_step(self, step):
# Make `step` the current emit sink: route the body buffers to its buffers
self.current_step = step
self.applys = step.applys
self.dma_loads = step.dma_loads
self.dma_stores = step.dma_stores
self.loads = step.loads
self.compute = step.compute
self.stores = step.stores

def push_step(self):
    """Open a new load->compute->store step and make it the current one.

    Builds a ``Step`` with six fresh ``IndentedBuffer`` objects, appends it to
    ``self.steps``, binds it via ``_bind_step``, replaces ``self.cse`` with
    ``self.cse.clone()``, and points ``target_buffer_override`` and
    ``target_cse_override`` at the new compute buffer and the new CSE.

    Returns:
        The newly created ``Step``.
    """
    sink_names = ("applys", "dma_loads", "dma_stores", "loads", "compute", "stores")
    fresh_step = Step(**{sink: IndentedBuffer() for sink in sink_names})
    self.steps.append(fresh_step)
    self._bind_step(fresh_step)
    # NOTE(review): per the original author's note, the clone shares the name
    # counter but starts with a fresh dedup cache -- clone() is not visible
    # here, confirm against its definition.
    self.cse = self.cse.clone()
    # Must run after _bind_step and the clone so the overrides see the new
    # step's compute buffer and the new CSE.
    self.target_buffer_override.set(self.compute)
    self.target_cse_override.set(self.cse)
    return fresh_step

def codegen_loops(self):
code = mlir_common.ParallelLoopBuffer()
# Loop body part
Expand Down Expand Up @@ -965,18 +993,18 @@ def codegen_loops(self):
epilogue = reduction_loop.epilogue_line()
code.writelines(reduction_lines)
stack.enter_context(code.indent(attribute="{accumulation_loop=true}", suffix=epilogue))
code.splice(self.applys)
code.splice(self.indexed_buffer)
code.splice(self.dma_loads)
# Compute body
code.writelines(self.compute_body_loop.lines())
with contextlib.ExitStack() as stack:
stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line()))
code.splice(self.masks)
code.splice(self.loads)
code.splice(self.compute)
code.splice(self.stores)
code.splice(self.dma_stores)
for step in self.steps:
code.splice(step.applys)
code.splice(step.dma_loads)
# Compute body -- only steps that have one get the loop + epilogue
if any(b.getvalue() for b in (step.loads, step.compute, step.stores)):
code.writelines(self.compute_body_loop.lines())
with contextlib.ExitStack() as stack:
stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line()))
code.splice(step.loads)
code.splice(step.compute)
code.splice(step.stores)
code.splice(step.dma_stores)
code.splice(self.reductions_suffix)
# Non-outerloop end
code.writeline(f"return")
Expand Down Expand Up @@ -1450,7 +1478,7 @@ def get_mask(self):
upper_bound = ops.constant(self.compute_body_loop.size, "index")
step_vec = ops.step(self.compute_body_loop.step, "index")

with self.override_buffer_cse(buffer=self.masks, cse=self.mask_cse):
with self.override_buffer_cse(buffer=self.compute, cse=self.mask_cse):
gap = ops.sub(upper_bound, self.compute_idx)
gap_vec = ops.broadcast(gap, self.compute_body_loop.step)
mask_var = ops.lt(step_vec, gap_vec)
Expand All @@ -1473,37 +1501,38 @@ def convert_indirect_indexing(self, index :sympy.Expr):
first_dim = indirect_dims[0]
spad_vars = dict()
compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims])
target_dma_buffers = self.dma_stores if compute_dependecy else self.dma_loads

# Load indirect operands
# Store each newly-produced indirect index into spad, in its producing step
for target_dim in indirect_dims:
if target_dim in self.spad_buffer_dict:
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
else:
# FIXME.
var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0]
dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]

local_tile_desc = self.kernel_group.tile_desc
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
tile_shape = local_tile_desc.get_mlir_shape(var_info[1])
tile_vec = local_tile_desc.get_compute_vec_size()
vshape = f"vector<{var_info[0]}x{var_info[1]}>"
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim)
self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]

# Store the indirect index variable
target_var = self.cse.varname_map[target_dim]
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
with self.override_buffer_cse(buffer=self.stores):
ops._store(target_var, sram_var, compute_index_var, tile_shape)
continue
var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0]
dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]
local_tile_desc = self.kernel_group.tile_desc
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
tile_shape = local_tile_desc.get_mlir_shape(var_info[1])
tile_vec = local_tile_desc.get_compute_vec_size()
vshape = f"vector<{var_info[0]}x{var_info[1]}>"
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim)
self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
target_var = self.cse.varname_map[target_dim]
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
with self.override_buffer_cse(buffer=self.stores):
ops._store(target_var, sram_var, compute_index_var, tile_shape)

# Offset build runs after the index is in spad -> own step when just produced
if compute_dependecy:
self.push_step()

# Build the offset (outer ops) in the current step, reading indices back from spad
for target_dim in indirect_dims:
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
mlir_dtype = vshape.split("x")[1][:-1]
with self.override_buffer_cse(buffer=target_dma_buffers):
with self.override_buffer_cse(buffer=self.dma_loads):
out = ops._load(tile_numel_per_lane, mlir_dtype, sram_var, sram_index_var, tile_shape)
spad_vars[target_dim] = out

with self.override_buffer_cse(buffer=target_dma_buffers):
# Apply stride
with self.override_buffer_cse(buffer=self.dma_loads):
for arg in index.args:
if "tmp" not in str(arg):
continue
Expand All @@ -1512,22 +1541,18 @@ def convert_indirect_indexing(self, index :sympy.Expr):
coeff = self.get_const_cse(int(arg.args[0]), coeff_dtype)
spad_vars[str(arg.args[1])] = ops.mul(spad_vars[str(arg.args[1])], coeff)
index = index.replace(arg, 0)

# Sum
for dim, var in spad_vars.items():
if dim == first_dim:
continue
spad_vars[first_dim] = ops.add(spad_vars[first_dim], var)

# Store index var
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[first_dim]
mlir_dtype = vshape.split("x")[1][:-1]
with self.override_buffer_cse(buffer=target_dma_buffers):
ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape) # FIXME. Maybe require fine grain compute...
with self.override_buffer_cse(buffer=self.dma_loads):
ops._store(spad_vars[first_dim], sram_var, sram_index_var, tile_shape)

# Conversion
mlir_dtype = self.var_info[spad_vars[first_dim]][1]
with self.override_buffer_cse(buffer=target_dma_buffers):
with self.override_buffer_cse(buffer=self.dma_loads):
out = ops._load(1, mlir_dtype, sram_var, sram_index_var, tile_shape)
if mlir_dtype != "index":
out = ops.index_cast(out, "index")
Expand Down