[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage scope (#19423)
oraluben wants to merge 1 commit into apache:main
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for Metal cooperative_tensor intrinsics, enabling operations like fill, load, store, and multiply-accumulate for MetalPerformancePrimitives. The changes include C++ declarations, Python IR builder wrappers, and runtime storage scope definitions. Feedback suggests enhancing the Python operator definitions in python/tvm/tirx/op.py by adding comprehensive docstrings and a span parameter to improve maintainability and error reporting in TVM Script.
def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
):
    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
    )


def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
    )
There was a problem hiding this comment.
The new cooperative_tensor functions are missing docstrings and the span parameter. Adding these will improve maintainability and ensure better error reporting and source mapping in TVM Script, consistent with other TIR operators in this module.
def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
    span: Span | None = None,
):
    """Fill a cooperative_tensor with a given value.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    value : PrimExpr
        The value to fill.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Multi-line call form keeps this consistent with the other
    # cooperative_tensor wrappers and within the line-length limit.
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_fill",
        d,
        index,
        value,
        rows,
        cols,
        span=span,
    )
def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Load data from device or threadgroup memory into a cooperative_tensor.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the source data.
    stride : PrimExpr
        The stride of the source data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during load.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )
def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Store data from a cooperative_tensor to device or threadgroup memory.

    Parameters
    ----------
    d : PrimExpr
        The cooperative_tensor expression.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the destination data.
    stride : PrimExpr
        The stride of the destination data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during store.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )
def cooperative_tensor_multiply_accumulate(
d: Var,
index_d: PrimExpr,
a: Var,
index_a: PrimExpr,
b: Var,
index_b: PrimExpr,
c: Var,
index_c: PrimExpr,
M: int,
N: int,
K: int,
transpose_a: bool = False,
transpose_b: bool = False,
span: Span | None = None,
):
"""Multiply and accumulate matrices using cooperative_tensor.
Parameters
----------
d : Var
The destination cooperative_tensor variable.
index_d : PrimExpr
The index of the destination tensor.
a : Var
The first input cooperative_tensor variable.
index_a : PrimExpr
The index of the first input tensor.
b : Var
The second input cooperative_tensor variable.
index_b : PrimExpr
The index of the second input tensor.
c : Var
The third input cooperative_tensor variable.
index_c : PrimExpr
The index of the third input tensor.
M : int
The M dimension of the multiplication.
N : int
The N dimension of the multiplication.
K : int
The K dimension of the multiplication.
transpose_a : bool
Whether to transpose the first input matrix.
transpose_b : bool
Whether to transpose the second input matrix.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"handle",
"tirx.cooperative_tensor_multiply_accumulate",
d,
index_d,
a,
index_a,
b,
index_b,
c,
index_c,
M,
N,
K,
transpose_a,
transpose_b,
span=span,
)…orage scope Add TIR builtins for Metal cooperative_tensor operations (MetalPerformancePrimitives): - cooperative_tensor_fill: fill a cooperative_tensor with a value - cooperative_tensor_load: load from device/threadgroup memory - cooperative_tensor_store: store to device/threadgroup memory - cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via matmul2d Add metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor) for buffers backed by MPP cooperative_tensor registers, analogous to the existing metal.simdgroup scope but targeting the Metal 4 tensor operations API. These primitives enable code generation for MetalPerformancePrimitives matmul2d, which routes to NAX tensor cores on Apple M5 and falls back to simdgroup matrix instructions on M1-M4.
6f80a69 to
ae0618c
Compare
Summary
Add TIR builtins and storage scope for Metal cooperative_tensor operations (MetalPerformancePrimitives / Metal 4).
Motivation
Apple Metal 4 introduces MetalPerformancePrimitives (MPP) with
matmul2d using cooperative_tensor operands. On M5, this routes to NAX tensor cores; on M1-M4, it falls back to simdgroup matrix instructions. These TIR primitives enable backend codegen to emit MPP calls.

Changes
New TIR builtins
cooperative_tensor_fill(d, index, value, rows, cols)
cooperative_tensor_load(d, index, ptr, stride, rows, cols, transpose)
cooperative_tensor_store(d, index, ptr, stride, rows, cols, transpose)
cooperative_tensor_multiply_accumulate(d, di, a, ai, b, bi, c, ci, M, N, K, trans_a, trans_b)

New storage scope
metal.cooperative_tensor (StorageRank::kMetalCooperativeTensor)

Files changed

include/tvm/tirx/builtin.h — Op declarations
src/tirx/op/builtin.cc — Op registrations
python/tvm/tirx/op.py — Python wrappers
python/tvm/script/ir_builder/tirx/ir.py — Script parser exports
src/runtime/thread_storage_scope.h — StorageRank enum + scope parsing

These builtins mirror the existing
simdgroup_* builtins for the older Metal simdgroup matrix API, extended with M/N/K dimension parameters for the matmul2d descriptor.