diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h index 4d1267e97bb9..c13d90165128 100644 --- a/include/tvm/tirx/transform.h +++ b/include/tvm/tirx/transform.h @@ -330,6 +330,12 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); */ TVM_DLL Pass BindTarget(Target target); +/*! + * \brief Convert ForKind::kParallel loops to blockIdx.x/threadIdx.x bindings on GPU targets. + * \return The pass. + */ +TVM_DLL Pass BindParallelLoopsToThreads(); + /*! * \brief Set a PrimFunc as the entry point if it is only function in IRModule. * \return The pass. @@ -354,6 +360,7 @@ TVM_DLL Pass Filter(ffi::TypedFunction fcond); * * \return The pass. */ + } // namespace transform } // namespace tirx } // namespace tvm diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index a63fb4346d34..434b5c4bfca1 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -89,6 +89,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), + tirx.transform.BindParallelLoopsToThreads(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), ] diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index f775c0dd1eac..e13ffaf8ee74 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -87,6 +87,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), + tirx.transform.BindParallelLoopsToThreads(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), ] diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 6e18558b0ecd..34861981e4b0 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -428,6 +428,17 @@ def VerifyMemory(): return _ffi_api.VerifyMemory() # type: ignore +def BindParallelLoopsToThreads(): + """Convert T.parallel loops to block/thread bindings for GPU PrimFuncs. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BindParallelLoopsToThreads() # type: ignore + + @_ffi.register_object("s_tir.transform.HoistIfThenElseConfig") class HoistIfThenElseConfig(_ir.Attrs): """Config for hoist if then else pass""" diff --git a/src/tirx/transform/bind_parallel_loops_to_threads.cc b/src/tirx/transform/bind_parallel_loops_to_threads.cc new file mode 100644 index 000000000000..15493d052449 --- /dev/null +++ b/src/tirx/transform/bind_parallel_loops_to_threads.cc @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file bind_parallel_loops_to_threads.cc + * \brief Convert ForKind::kParallel loops to GPU thread bindings. + * + * Semantics: + * - Only runs when the PrimFunc carries a `tvm::attr::kTarget` that refers to a GPU device. + * Functions without a target attribute are left unchanged (no ambient `Target::Current` guess). + * - The outermost `kParallel` loop in the function is rewritten to `blockIdx.x` / `threadIdx.x` + * `thread_extent` scopes, with a guard `if (global_idx < extent)` and no else-branch. + * - Nested `kParallel` loops (parallel inside parallel) are rejected: binding only the outer + * parallel nest would leave inner `kParallel` serial within the mapped kernel, which is + * almost never what users intend. + * - A `kParallel` that appears inside an existing thread environment (`thread_extent` / + * `virtual_thread`) is left unchanged so it does not introduce conflicting thread bindings. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tirx { +namespace { + +static bool IsGpuDeviceType(int dev_type) { + return dev_type == kDLCUDA || dev_type == kDLROCM || dev_type == kDLOpenCL || + dev_type == kDLVulkan || dev_type == kDLMetal || dev_type == kDLWebGPU; +} + +class ParallelLoopToThreadBindingMutator : public StmtExprMutator { + public: + explicit ParallelLoopToThreadBindingMutator(int64_t max_threads_per_block) + : max_threads_per_block_(max_threads_per_block) {} + + private: + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + bool prev = in_thread_env_; + in_thread_env_ = true; + Stmt ret = StmtExprMutator::VisitStmt_(op); + in_thread_env_ = prev; + return ret; + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt TransformParallelFor(const ForNode* for_node) { + if (in_thread_env_) { + return ffi::GetRef(for_node); + } + + DataType dtype = for_node->loop_var.dtype(); + PrimExpr min = cast(dtype, for_node->min); + PrimExpr extent = cast(dtype, for_node->extent); + PrimExpr max_threads = IntImm(dtype, max_threads_per_block_); + PrimExpr num_blocks = ceildiv(extent, max_threads); + + Var tx_var("threadIdx.x", dtype); + Var bx_var("blockIdx.x", dtype); + IterVar tx_iter(Range::FromMinExtent(IntImm(dtype, 0), max_threads), tx_var, + IterVarType::kThreadIndex, "threadIdx.x"); + IterVar bx_iter(Range::FromMinExtent(IntImm(dtype, 0), num_blocks), bx_var, + IterVarType::kThreadIndex, "blockIdx.x"); + + PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var); + PrimExpr mapped_idx = cast(dtype, min + global_idx); + Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}}); + mapped_body = IfThenElse(global_idx < extent, mapped_body); + + Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body); + Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx); + return body_with_bx; + } + + Stmt VisitStmt_(const ForNode* op) final { + if (op->kind == ForKind::kThreadBinding) { + bool prev = in_thread_env_; + in_thread_env_ = true; + Stmt ret = StmtExprMutator::VisitStmt_(op); + in_thread_env_ = prev; + return ret; + } + if (op->kind != ForKind::kParallel) { + return StmtExprMutator::VisitStmt_(op); + } + if (in_parallel_loop_) { + TVM_FFI_THROW(InternalError) + << "BindParallelLoopsToThreads does not support nested parallel loops. " + << "Inner parallel loops become serial once bound into a GPU kernel. " + << "Please rewrite the TIR to avoid nested T.parallel."; + } + bool prev_in_parallel = in_parallel_loop_; + in_parallel_loop_ = true; + For updated = Downcast(StmtExprMutator::VisitStmt_(op)); + in_parallel_loop_ = prev_in_parallel; + return TransformParallelFor(updated.get()); + } + + int64_t max_threads_per_block_; + bool in_thread_env_{false}; + bool in_parallel_loop_{false}; +}; + +} // namespace + +namespace transform { + +Pass BindParallelLoopsToThreads() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto opt_target = f->GetAttr(tvm::attr::kTarget); + if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) { + return f; + } + Target target = opt_target.value(); + + int64_t max_threads_per_block = 1024; + if (auto opt_max_threads = target->GetAttr("max_num_threads")) { + max_threads_per_block = opt_max_threads.value()->value; + } + + PrimFuncNode* n = f.CopyOnWrite(); + n->body = ParallelLoopToThreadBindingMutator(max_threads_per_block)(n->body); + return f; + }; + + return CreatePrimFuncPass(pass_func, 0, "tirx.BindParallelLoopsToThreads", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tirx.transform.BindParallelLoopsToThreads", BindParallelLoopsToThreads); +} + +} // namespace transform +} // namespace tirx +} // namespace tvm + diff --git a/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py b/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py new file mode 100644 index 000000000000..e3050b16e409 --- /dev/null +++ b/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for tirx.transform.BindParallelLoopsToThreads.""" + +import pytest + +import tvm +import tvm.testing +from tvm.script import ir as I +from tvm.script import tirx as T + + +def test_bind_parallel_skips_without_target(): + """PrimFuncs without tvm::attr::kTarget must be left unchanged (no Target::Current guess).""" + + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + tvm.ir.assert_structural_equal(after, Mod) + + +def test_bind_parallel_skips_non_gpu_target(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("llvm")}) + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + tvm.ir.assert_structural_equal(after, Mod) + + +def test_bind_parallel_cuda_wraps_parallel_in_thread_extents(): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("cuda")}) + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) + body = after["main"].body + assert isinstance(body, tvm.tirx.AttrStmt) + assert body.node.thread_tag == "blockIdx.x" + inner = body.body + assert isinstance(inner, tvm.tirx.AttrStmt) + assert inner.node.thread_tag == "threadIdx.x" + assert isinstance(inner.body, tvm.tirx.IfThenElse) + assert inner.body.else_case is None + + +def test_bind_parallel_nested_parallel_raises(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4, 4), "float32")): + T.func_attr({"target": T.target("cuda")}) + for i in T.parallel(4): + for j in T.parallel(4): + A[i, j] = T.float32(1) + + with pytest.raises(tvm.error.InternalError, match="nested parallel"): + tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + + +def test_bind_parallel_respects_max_num_threads(): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((256,), "float32")): + T.func_attr({"target": T.target({"kind": "cuda", "max_num_threads": 128})}) + for i in T.parallel(256): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) + inner = after["main"].body.body + assert isinstance(inner, tvm.tirx.AttrStmt) + assert inner.node.thread_tag == "threadIdx.x" + assert isinstance(inner.value, tvm.tirx.IntImm) + assert inner.value.value == 128 + + +if __name__ == "__main__": + tvm.testing.main()