Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions include/tvm/tirx/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,51 @@ TVM_DLL const Op& simdgroup_store();
*/
TVM_DLL const Op& simdgroup_multiply_accumulate();

// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)

/*!
 * \brief Fill a cooperative_tensor with a given value.
 *
 * void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
 *                              int rows, int cols);
 *
 * rows/cols give the logical tile shape held by `d`.
 */
TVM_DLL const Op& cooperative_tensor_fill();

/*!
 * \brief Load data from device or threadgroup memory into a cooperative_tensor.
 *
 * void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
 *                              PrimExpr stride, int rows, int cols,
 *                              bool transpose_matrix,
 *                              int mma_M, int mma_N, int mma_K,
 *                              int operand_role);
 *
 * mma_M/mma_N/mma_K describe the enclosing matmul tile.
 * operand_role: 0=left(A), 1=right(B), 2=destination(C)
 */
TVM_DLL const Op& cooperative_tensor_load();

/*!
 * \brief Store data from a cooperative_tensor to device or threadgroup memory.
 *
 * void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
 *                               PrimExpr stride, int rows, int cols,
 *                               bool transpose_matrix,
 *                               int mma_M, int mma_N, int mma_K,
 *                               int operand_role);
 *
 * mma_M/mma_N/mma_K describe the enclosing matmul tile.
 * operand_role: 0=left(A), 1=right(B), 2=destination(C)
 */
TVM_DLL const Op& cooperative_tensor_store();

/*!
 * \brief Multiply and accumulate two matrices using cooperative_tensor
 * (MetalPerformancePrimitives matmul2d).
 *
 * void cooperative_tensor_multiply_accumulate(
 *     Var d, PrimExpr index_d, Var a, PrimExpr index_a,
 *     Var b, PrimExpr index_b, Var c, PrimExpr index_c,
 *     int M, int N, int K, bool transpose_a, bool transpose_b);
 *
 * Computes d = a * b + c on (M, N, K) tiles; transpose_a/transpose_b
 * select transposed operand layouts.
 */
TVM_DLL const Op& cooperative_tensor_multiply_accumulate();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tirx/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,10 @@ def wrapped(*args, **kwargs) -> T:
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
# Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4).
cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
# Wrapped over two lines to stay within the 100-char line-length limit.
cooperative_tensor_multiply_accumulate = _op_wrapper(
    _tir_op.cooperative_tensor_multiply_accumulate
)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
Expand Down Expand Up @@ -2252,6 +2256,10 @@ def wrapped(*args, **kwargs):
"simdgroup_load",
"simdgroup_store",
"simdgroup_multiply_accumulate",
"cooperative_tensor_fill",
"cooperative_tensor_load",
"cooperative_tensor_store",
"cooperative_tensor_multiply_accumulate",
"create_barriers",
"mma_store",
"mma_fill",
Expand Down
104 changes: 104 additions & 0 deletions python/tvm/tirx/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,110 @@ def simdgroup_multiply_accumulate(
)


def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
    span=None,
):
    """Fill a cooperative_tensor with a given value.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable to fill.
    index : PrimExpr
        The index of the tensor fragment.
    value : PrimExpr
        The value every element is set to.
    rows : int
        The number of rows of the tile.
    cols : int
        The number of columns of the tile.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols, span=span
    )


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    mma_M: int = 0,
    mma_N: int = 0,
    mma_K: int = 0,
    operand_role: int = 0,
    span=None,
):
    """Load data from device or threadgroup memory into a cooperative_tensor.

    Parameters
    ----------
    d : Var
        The destination cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor fragment.
    ptr : PrimExpr
        The pointer to the source data.
    stride : PrimExpr
        The stride (leading dimension) of the source data.
    rows : int
        The number of rows of the tile.
    cols : int
        The number of columns of the tile.
    transpose_matrix : bool
        Whether to transpose the matrix during the load.
    mma_M : int
        The M dimension of the enclosing matmul tile.
    mma_N : int
        The N dimension of the enclosing matmul tile.
    mma_K : int
        The K dimension of the enclosing matmul tile.
    operand_role : int
        Which matmul operand the tensor holds: 0=left(A), 1=right(B),
        2=destination(C).
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        mma_M,
        mma_N,
        mma_K,
        operand_role,
        span=span,
    )


def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    mma_M: int = 0,
    mma_N: int = 0,
    mma_K: int = 0,
    operand_role: int = 0,
    span=None,
):
    """Store data from a cooperative_tensor to device or threadgroup memory.

    NOTE(review): the C++ declaration in builtin.h documents the first
    argument as ``Var d`` — confirm whether this annotation should be Var.

    Parameters
    ----------
    d : PrimExpr
        The source cooperative_tensor expression.
    index : PrimExpr
        The index of the tensor fragment.
    ptr : PrimExpr
        The pointer to the destination data.
    stride : PrimExpr
        The stride (leading dimension) of the destination data.
    rows : int
        The number of rows of the tile.
    cols : int
        The number of columns of the tile.
    transpose_matrix : bool
        Whether to transpose the matrix during the store.
    mma_M : int
        The M dimension of the enclosing matmul tile.
    mma_N : int
        The N dimension of the enclosing matmul tile.
    mma_K : int
        The K dimension of the enclosing matmul tile.
    operand_role : int
        Which matmul operand the tensor holds: 0=left(A), 1=right(B),
        2=destination(C).
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        mma_M,
        mma_N,
        mma_K,
        operand_role,
        span=span,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
    span=None,
):
    """Multiply and accumulate matrices using cooperative_tensor
    (MetalPerformancePrimitives matmul2d).

    Parameters
    ----------
    d : Var
        The destination cooperative_tensor variable.
    index_d : PrimExpr
        The index of the destination tensor fragment.
    a : Var
        The left-hand (A) cooperative_tensor variable.
    index_a : PrimExpr
        The index of the A tensor fragment.
    b : Var
        The right-hand (B) cooperative_tensor variable.
    index_b : PrimExpr
        The index of the B tensor fragment.
    c : Var
        The accumulator (C) cooperative_tensor variable.
    index_c : PrimExpr
        The index of the C tensor fragment.
    M : int
        The M dimension of the multiplication.
    N : int
        The N dimension of the multiplication.
    K : int
        The K dimension of the multiplication.
    transpose_a : bool
        Whether to transpose the A matrix.
    transpose_b : bool
        Whether to transpose the B matrix.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
        span=span,
    )
Comment on lines +1795 to +1896
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new cooperative_tensor functions are missing docstrings and the span parameter. Adding these will improve maintainability and ensure better error reporting and source mapping in TVM Script, consistent with other TIR operators in this module.

def cooperative_tensor_fill(
    d: Var,
    index: PrimExpr,
    value: PrimExpr,
    rows: int,
    cols: int,
    span: Span | None = None,
):
    """Fill a cooperative_tensor with a given value.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    value : PrimExpr
        The value to fill.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols, span=span)


def cooperative_tensor_load(
    d: Var,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Load data from device or threadgroup memory into a cooperative_tensor.

    Parameters
    ----------
    d : Var
        The cooperative_tensor variable.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the source data.
    stride : PrimExpr
        The stride of the source data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during load.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_load",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )


def cooperative_tensor_store(
    d: PrimExpr,
    index: PrimExpr,
    ptr: PrimExpr,
    stride: PrimExpr,
    rows: int,
    cols: int,
    transpose_matrix: bool = False,
    span: Span | None = None,
):
    """Store data from a cooperative_tensor to device or threadgroup memory.

    Parameters
    ----------
    d : PrimExpr
        The cooperative_tensor expression.
    index : PrimExpr
        The index of the tensor.
    ptr : PrimExpr
        The pointer to the destination data.
    stride : PrimExpr
        The stride of the destination data.
    rows : int
        The number of rows.
    cols : int
        The number of columns.
    transpose_matrix : bool
        Whether to transpose the matrix during store.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_store",
        d,
        index,
        ptr,
        stride,
        rows,
        cols,
        transpose_matrix,
        span=span,
    )


def cooperative_tensor_multiply_accumulate(
    d: Var,
    index_d: PrimExpr,
    a: Var,
    index_a: PrimExpr,
    b: Var,
    index_b: PrimExpr,
    c: Var,
    index_c: PrimExpr,
    M: int,
    N: int,
    K: int,
    transpose_a: bool = False,
    transpose_b: bool = False,
    span: Span | None = None,
):
    """Multiply and accumulate matrices using cooperative_tensor.

    Parameters
    ----------
    d : Var
        The destination cooperative_tensor variable.
    index_d : PrimExpr
        The index of the destination tensor.
    a : Var
        The first input cooperative_tensor variable.
    index_a : PrimExpr
        The index of the first input tensor.
    b : Var
        The second input cooperative_tensor variable.
    index_b : PrimExpr
        The index of the second input tensor.
    c : Var
        The third input cooperative_tensor variable.
    index_c : PrimExpr
        The index of the third input tensor.
    M : int
        The M dimension of the multiplication.
    N : int
        The N dimension of the multiplication.
    K : int
        The K dimension of the multiplication.
    transpose_a : bool
        Whether to transpose the first input matrix.
    transpose_b : bool
        Whether to transpose the second input matrix.
    span : Optional[Span]
        The location of this operator in the source code.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tirx.cooperative_tensor_multiply_accumulate",
        d,
        index_d,
        a,
        index_a,
        b,
        index_b,
        c,
        index_c,
        M,
        N,
        K,
        transpose_a,
        transpose_b,
        span=span,
    )



def vectorlow(dtype, vec):
"""Get the low level half of the vector

Expand Down
7 changes: 7 additions & 0 deletions src/runtime/thread_storage_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ enum class StorageRank {
kMMAMatrixC = 11,
/*! \brief Metal SIMD group memory */
kMetalSimdGroup = 12,
/*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
kMetalCooperativeTensor = 13,
};

/*!
Expand Down Expand Up @@ -129,6 +131,8 @@ struct StorageScope {
return "m16n8k8.matrixC" + tag;
case StorageRank::kMetalSimdGroup:
return "metal.simdgroup" + tag;
case StorageRank::kMetalCooperativeTensor:
return "metal.cooperative_tensor" + tag;
default:
TVM_FFI_THROW(InternalError) << "unknown storage scope";
return "";
Expand Down Expand Up @@ -182,6 +186,9 @@ struct StorageScope {
} else if (s.compare(0, 15, "metal.simdgroup") == 0) {
r.rank = StorageRank::kMetalSimdGroup;
r.tag = s.substr(15, std::string::npos);
} else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
r.rank = StorageRank::kMetalCooperativeTensor;
r.tag = s.substr(24, std::string::npos);
} else {
TVM_FFI_THROW(InternalError) << "unknown storage scope " << s;
}
Expand Down
12 changes: 12 additions & 0 deletions src/tirx/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Metal cooperative_tensor intrinsics (declared in include/tvm/tirx/builtin.h).
// All four are registered with an opaque call-effect kind, consistent with the
// simdgroup_* intrinsics above, since each call reads/writes memory through its
// handle arguments.
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Expand Down
Loading