[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage scope#19423

Open

oraluben wants to merge 1 commit into apache:main from oraluben:metal-cooperative-tensor-upstream

Conversation

@oraluben
Contributor

Summary

Add TIR builtins and storage scope for Metal cooperative_tensor operations (MetalPerformancePrimitives / Metal 4).

Motivation

Apple Metal 4 introduces MetalPerformancePrimitives (MPP) with matmul2d using cooperative_tensor operands. On M5, this routes to NAX tensor cores; on M1-M4, it falls back to simdgroup matrix instructions. These TIR primitives enable backend codegen to emit MPP calls.

Changes

New TIR builtins

  • cooperative_tensor_fill(d, index, value, rows, cols)
  • cooperative_tensor_load(d, index, ptr, stride, rows, cols, transpose)
  • cooperative_tensor_store(d, index, ptr, stride, rows, cols, transpose)
  • cooperative_tensor_multiply_accumulate(d, di, a, ai, b, bi, c, ci, M, N, K, trans_a, trans_b)
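To make the intended contract of the last builtin concrete: `cooperative_tensor_multiply_accumulate` describes a `D = A·B + C` update over an M×N×K tile, with optional transposes of A and B. The following is a plain-Python reference sketch of that semantics, not TVM code; the helper name is hypothetical.

```python
# Hypothetical reference semantics for cooperative_tensor_multiply_accumulate:
# d[m][n] = sum_k a[m][k] * b[k][n] + c[m][n], with optional transposes.

def matmul2d_reference(a, b, c, M, N, K, trans_a=False, trans_b=False):
    """Compute D = A @ B + C over an M x N x K tile (list-of-lists matrices)."""
    get_a = (lambda m, k: a[k][m]) if trans_a else (lambda m, k: a[m][k])
    get_b = (lambda k, n: b[n][k]) if trans_b else (lambda k, n: b[k][n])
    return [
        [sum(get_a(m, k) * get_b(k, n) for k in range(K)) + c[m][n]
         for n in range(N)]
        for m in range(M)
    ]

# 2x2 example with A = I, so D = B + C.
A = [[1, 0], [0, 1]]
B = [[1, 2], [3, 4]]
C = [[10, 10], [10, 10]]
print(matmul2d_reference(A, B, C, 2, 2, 2))  # [[11, 12], [13, 14]]
```

The explicit M/N/K parameters mirror the matmul2d descriptor described in this PR; the d/a/b/c operands in the builtin are tile handles rather than Python lists.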

New storage scope

  • metal.cooperative_tensor (StorageRank::kMetalCooperativeTensor)
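For readers unfamiliar with TVM storage scopes: scope strings such as "metal.cooperative_tensor" are parsed into a StorageRank at runtime. A simplified Python sketch of that mapping follows; the enum values and table here are illustrative, not the actual contents of src/runtime/thread_storage_scope.h.

```python
# Hypothetical sketch of storage-scope string parsing, mirroring the shape
# (not the code) of src/runtime/thread_storage_scope.h.
from enum import Enum

class StorageRank(Enum):
    GLOBAL = 0
    SHARED = 1
    METAL_SIMDGROUP = 2            # existing metal.simdgroup scope
    METAL_COOPERATIVE_TENSOR = 3   # scope added by this PR

_SCOPE_TABLE = {
    "global": StorageRank.GLOBAL,
    "shared": StorageRank.SHARED,
    "metal.simdgroup": StorageRank.METAL_SIMDGROUP,
    "metal.cooperative_tensor": StorageRank.METAL_COOPERATIVE_TENSOR,
}

def parse_scope(scope: str) -> StorageRank:
    """Map a storage-scope string to its rank, rejecting unknown scopes."""
    try:
        return _SCOPE_TABLE[scope]
    except KeyError:
        raise ValueError(f"unknown storage scope: {scope}")

print(parse_scope("metal.cooperative_tensor").name)  # METAL_COOPERATIVE_TENSOR
```

Buffers allocated in this scope are backed by MPP cooperative_tensor registers rather than device or threadgroup memory, which is why a dedicated rank is needed.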

Files changed

  • include/tvm/tirx/builtin.h — Op declarations
  • src/tirx/op/builtin.cc — Op registrations
  • python/tvm/tirx/op.py — Python wrappers
  • python/tvm/script/ir_builder/tirx/ir.py — Script parser exports
  • src/runtime/thread_storage_scope.h — StorageRank enum + scope parsing

These builtins mirror the existing simdgroup_* builtins for the older Metal simdgroup matrix API, extended with M/N/K dimension parameters for the matmul2d descriptor.
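The load/store builtins above copy a rows×cols tile between strided memory and a cooperative_tensor, with `transpose_matrix` swapping the roles of rows and columns. A plain-Python sketch of the load contract (function and buffer names are hypothetical; store is the inverse operation):

```python
# Hypothetical sketch of the cooperative_tensor_load contract: copy a
# rows x cols tile out of a flat buffer with the given row stride,
# optionally transposing it.

def coop_load(flat, base, stride, rows, cols, transpose=False):
    """Gather a rows x cols tile starting at flat[base] with row stride."""
    tile = [[flat[base + r * stride + c] for c in range(cols)]
            for r in range(rows)]
    if transpose:
        tile = [list(col) for col in zip(*tile)]
    return tile

buf = list(range(100))
print(coop_load(buf, 0, 10, 2, 3))        # [[0, 1, 2], [10, 11, 12]]
print(coop_load(buf, 0, 10, 2, 3, True))  # [[0, 10], [1, 11], [2, 12]]
```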


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Metal cooperative_tensor intrinsics, enabling operations like fill, load, store, and multiply-accumulate for MetalPerformancePrimitives. The changes include C++ declarations, Python IR builder wrappers, and runtime storage scope definitions. Feedback suggests enhancing the Python operator definitions in python/tvm/tirx/op.py by adding comprehensive docstrings and a span parameter to improve maintainability and error reporting in TVM Script.

Comment thread python/tvm/tirx/op.py
Comment on lines +1795 to +1880
def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
):
    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
    )


def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
):
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
    )

medium

The new cooperative_tensor functions are missing docstrings and the span parameter. Adding these will improve maintainability and ensure better error reporting and source mapping in TVM Script, consistent with other TIR operators in this module.

def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
    span: Span | None = None,
):
    """Fill a cooperative_tensor with a given value.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    value : PrimExpr
        The value to fill.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols, span=span)


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Load data from device or threadgroup memory into a cooperative_tensor.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the source data.
    stride : PrimExpr
        The stride of the source data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during load.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )


def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Store data from a cooperative_tensor to device or threadgroup memory.

    Parameters
    ----------
    d : PrimExpr
        The cooperative_tensor expression.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the destination data.
    stride : PrimExpr
        The stride of the destination data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during store.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
    span: Span | None = None,
):
    """Multiply and accumulate matrices using cooperative_tensor.

    Parameters
    ----------
    d : Var
        The destination cooperative_tensor variable.
    index_d : PrimExpr
        The index of the destination tensor.
    a : Var
        The first input cooperative_tensor variable.
    index_a : PrimExpr
        The index of the first input tensor.
    b : Var
        The second input cooperative_tensor variable.
    index_b : PrimExpr
        The index of the second input tensor.
    c : Var
        The third input cooperative_tensor variable.
    index_c : PrimExpr
        The index of the third input tensor.
    M : int
        The M dimension of the multiplication.
    N : int
        The N dimension of the multiplication.
    K : int
        The K dimension of the multiplication.
    transpose_a : bool
        Whether to transpose the first input matrix.
    transpose_b : bool
        Whether to transpose the second input matrix.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
        span=span,
    )

…orage scope

Add TIR builtins for Metal cooperative_tensor operations (MetalPerformancePrimitives):
- cooperative_tensor_fill: fill a cooperative_tensor with a value
- cooperative_tensor_load: load from device/threadgroup memory
- cooperative_tensor_store: store to device/threadgroup memory
- cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via matmul2d

Add metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor)
for buffers backed by MPP cooperative_tensor registers, analogous to the existing
metal.simdgroup scope but targeting the Metal 4 tensor operations API.

These primitives enable code generation for MetalPerformancePrimitives matmul2d,
which routes to NAX tensor cores on Apple M5 and falls back to simdgroup matrix
instructions on M1-M4.
@oraluben force-pushed the metal-cooperative-tensor-upstream branch from 6f80a69 to ae0618c on April 20, 2026 07:17