From ae0618c20ad856f63fa563d724258fe4cc8ff323 Mon Sep 17 00:00:00 2001
From: Yichen Yan
Date: Mon, 20 Apr 2026 13:17:18 +0800
Subject: [PATCH] [TIR] Add cooperative_tensor builtins and
 metal.cooperative_tensor storage scope

Add TIR builtins for Metal cooperative_tensor operations
(MetalPerformancePrimitives):

- cooperative_tensor_fill: fill a cooperative_tensor with a value
- cooperative_tensor_load: load from device/threadgroup memory
- cooperative_tensor_store: store to device/threadgroup memory
- cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via
  matmul2d

Add metal.cooperative_tensor storage scope
(StorageRank::kMetalCooperativeTensor) for buffers backed by MPP
cooperative_tensor registers, analogous to the existing metal.simdgroup
scope but targeting the Metal 4 tensor operations API.

These primitives enable code generation for MetalPerformancePrimitives
matmul2d, which routes to NAX tensor cores on Apple M5 and falls back to
simdgroup matrix instructions on M1-M4.
---
 include/tvm/tirx/builtin.h              |  45 ++++++++++
 python/tvm/script/ir_builder/tirx/ir.py |   8 ++
 python/tvm/tirx/op.py                   | 104 ++++++++++++++++++++++++
 src/runtime/thread_storage_scope.h      |   7 ++
 src/tirx/op/builtin.cc                  |  12 +++
 5 files changed, 176 insertions(+)

diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index d0d5b3d57e27..70cd0ead720e 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -788,6 +788,51 @@ TVM_DLL const Op& simdgroup_store();
  */
 TVM_DLL const Op& simdgroup_multiply_accumulate();
 
+// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
+
+/*!
+ * \brief Fill a cooperative_tensor with a given value.
+ *
+ * void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
+ *                              int rows, int cols);
+ */
+TVM_DLL const Op& cooperative_tensor_fill();
+
+/*!
+ * \brief Load data from device or threadgroup memory into a cooperative_tensor.
+ *
+ * void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
+ *                              PrimExpr stride, int rows, int cols,
+ *                              bool transpose_matrix,
+ *                              int mma_M, int mma_N, int mma_K,
+ *                              int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_load();
+
+/*!
+ * \brief Store data from a cooperative_tensor to device or threadgroup memory.
+ *
+ * void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
+ *                               PrimExpr stride, int rows, int cols,
+ *                               bool transpose_matrix,
+ *                               int mma_M, int mma_N, int mma_K,
+ *                               int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_store();
+
+/*!
+ * \brief Multiply and accumulate two matrices using cooperative_tensor
+ * (MetalPerformancePrimitives matmul2d).
+ *
+ * void cooperative_tensor_multiply_accumulate(
+ *     Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *     Var b, PrimExpr index_b, Var c, PrimExpr index_c,
+ *     int M, int N, int K, bool transpose_a, bool transpose_b);
+ */
+TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/script/ir_builder/tirx/ir.py b/python/tvm/script/ir_builder/tirx/ir.py
index ce88c563156f..5bae6caab436 100644
--- a/python/tvm/script/ir_builder/tirx/ir.py
+++ b/python/tvm/script/ir_builder/tirx/ir.py
@@ -1962,6 +1962,10 @@ def wrapped(*args, **kwargs) -> T:
 simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
 simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
 simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
+cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
+cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
+cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
+cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -2252,6 +2256,10 @@ def wrapped(*args, **kwargs):
     "simdgroup_load",
     "simdgroup_store",
     "simdgroup_multiply_accumulate",
+    "cooperative_tensor_fill",
+    "cooperative_tensor_load",
+    "cooperative_tensor_store",
+    "cooperative_tensor_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 6b4a636f3061..f24c2bbea145 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -1792,6 +1792,110 @@ def simdgroup_multiply_accumulate(
     )
 
 
+def cooperative_tensor_fill(
+    d: Var,
+    index: PrimExpr,
+    value: PrimExpr,
+    rows: int,
+    cols: int,
+):
+    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)
+
+
+def cooperative_tensor_load(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_load",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_store(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_store",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_multiply_accumulate(
+    d: Var,
+    index_d: PrimExpr,
+    a: Var,
+    index_a: PrimExpr,
+    b: Var,
+    index_b: PrimExpr,
+    c: Var,
+    index_c: PrimExpr,
+    M: int,
+    N: int,
+    K: int,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_multiply_accumulate",
+        d,
+        index_d,
+        a,
+        index_a,
+        b,
+        index_b,
+        c,
+        index_c,
+        M,
+        N,
+        K,
+        transpose_a,
+        transpose_b,
+    )
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector
 
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 313e4cfe484c..0155aa1ffd67 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -71,6 +71,8 @@ enum class StorageRank {
   kMMAMatrixC = 11,
   /*! \brief Metal SIMD group memory */
   kMetalSimdGroup = 12,
+  /*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
+  kMetalCooperativeTensor = 13,
 };
 
 /*!
@@ -129,6 +131,8 @@
         return "m16n8k8.matrixC" + tag;
       case StorageRank::kMetalSimdGroup:
         return "metal.simdgroup" + tag;
+      case StorageRank::kMetalCooperativeTensor:
+        return "metal.cooperative_tensor" + tag;
       default:
         TVM_FFI_THROW(InternalError) << "unknown storage scope";
         return "";
@@ -182,6 +186,9 @@
     } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
       r.rank = StorageRank::kMetalSimdGroup;
       r.tag = s.substr(15, std::string::npos);
+    } else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
+      r.rank = StorageRank::kMetalCooperativeTensor;
+      r.tag = s.substr(24, std::string::npos);
     } else {
       TVM_FFI_THROW(InternalError) << "unknown storage scope " << s;
     }
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 68f9ce219bb7..452dd2f10c3a 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -348,6 +348,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
 TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",