-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[TIRX] Bind parallel loops to GPU threads before VerifyMemory #19363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zhils
wants to merge
5
commits into
apache:main
Choose a base branch
from
zhils:my-fix-branch
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
6f413c7
[TIRX] Bind parallel loops to GPU threads before VerifyMemory
zhils a405367
[TIRX][CUDA] Fix parallel loop binding and register nvcc pass config
zhils a65d799
Fix code review issues
zhils 8f7d42a
Trigger CI
zhils 12a78e2
[TIRX] Harden BindParallelLoopsToThreads and add pytest
zhils File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \file bind_parallel_loops_to_threads.cc | ||
| * \brief Convert ForKind::kParallel loops to GPU thread bindings. | ||
| * | ||
| * Semantics: | ||
| * - Only runs when the PrimFunc carries a `tvm::attr::kTarget` that refers to a GPU device. | ||
| * Functions without a target attribute are left unchanged (no ambient `Target::Current` guess). | ||
| * - The outermost `kParallel` loop in the function is rewritten to `blockIdx.x` / `threadIdx.x` | ||
| * `thread_extent` scopes, with a guard `if (global_idx < extent)` and no else-branch. | ||
| * - Nested `kParallel` loops (parallel inside parallel) are rejected: binding only the outer | ||
| * parallel nest would leave inner `kParallel` serial within the mapped kernel, which is | ||
| * almost never what users intend. | ||
| * - A `kParallel` that appears inside an existing thread environment (`thread_extent` / | ||
| * `virtual_thread`) is left unchanged so it does not introduce conflicting thread bindings. | ||
| */ | ||
|
|
||
| #include <tvm/ffi/function.h> | ||
| #include <tvm/ffi/reflection/registry.h> | ||
| #include <tvm/s_tir/stmt.h> | ||
| #include <tvm/target/target.h> | ||
| #include <tvm/tirx/op.h> | ||
| #include <tvm/tirx/stmt.h> | ||
| #include <tvm/tirx/stmt_functor.h> | ||
| #include <tvm/tirx/transform.h> | ||
|
|
||
| namespace tvm { | ||
| namespace tirx { | ||
| namespace { | ||
|
|
||
| static bool IsGpuDeviceType(int dev_type) { | ||
| return dev_type == kDLCUDA || dev_type == kDLROCM || dev_type == kDLOpenCL || | ||
| dev_type == kDLVulkan || dev_type == kDLMetal || dev_type == kDLWebGPU; | ||
| } | ||
|
|
||
| class ParallelLoopToThreadBindingMutator : public StmtExprMutator { | ||
| public: | ||
| explicit ParallelLoopToThreadBindingMutator(int64_t max_threads_per_block) | ||
| : max_threads_per_block_(max_threads_per_block) {} | ||
|
|
||
| private: | ||
| Stmt VisitStmt_(const AttrStmtNode* op) final { | ||
| if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { | ||
| bool prev = in_thread_env_; | ||
| in_thread_env_ = true; | ||
| Stmt ret = StmtExprMutator::VisitStmt_(op); | ||
| in_thread_env_ = prev; | ||
| return ret; | ||
| } | ||
| return StmtExprMutator::VisitStmt_(op); | ||
| } | ||
|
|
||
| Stmt TransformParallelFor(const ForNode* for_node) { | ||
| if (in_thread_env_) { | ||
| return ffi::GetRef<Stmt>(for_node); | ||
| } | ||
|
|
||
| DataType dtype = for_node->loop_var.dtype(); | ||
| PrimExpr min = cast(dtype, for_node->min); | ||
| PrimExpr extent = cast(dtype, for_node->extent); | ||
| PrimExpr max_threads = IntImm(dtype, max_threads_per_block_); | ||
| PrimExpr num_blocks = ceildiv(extent, max_threads); | ||
|
|
||
| Var tx_var("threadIdx.x", dtype); | ||
| Var bx_var("blockIdx.x", dtype); | ||
| IterVar tx_iter(Range::FromMinExtent(IntImm(dtype, 0), max_threads), tx_var, | ||
| IterVarType::kThreadIndex, "threadIdx.x"); | ||
| IterVar bx_iter(Range::FromMinExtent(IntImm(dtype, 0), num_blocks), bx_var, | ||
| IterVarType::kThreadIndex, "blockIdx.x"); | ||
|
|
||
| PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var); | ||
| PrimExpr mapped_idx = cast(dtype, min + global_idx); | ||
| Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}}); | ||
| mapped_body = IfThenElse(global_idx < extent, mapped_body); | ||
|
|
||
| Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body); | ||
| Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx); | ||
| return body_with_bx; | ||
| } | ||
|
|
||
| Stmt VisitStmt_(const ForNode* op) final { | ||
| if (op->kind == ForKind::kThreadBinding) { | ||
| bool prev = in_thread_env_; | ||
| in_thread_env_ = true; | ||
| Stmt ret = StmtExprMutator::VisitStmt_(op); | ||
| in_thread_env_ = prev; | ||
| return ret; | ||
| } | ||
| if (op->kind != ForKind::kParallel) { | ||
| return StmtExprMutator::VisitStmt_(op); | ||
| } | ||
| if (in_parallel_loop_) { | ||
| TVM_FFI_THROW(InternalError) | ||
| << "BindParallelLoopsToThreads does not support nested parallel loops. " | ||
| << "Inner parallel loops become serial once bound into a GPU kernel. " | ||
| << "Please rewrite the TIR to avoid nested T.parallel."; | ||
| } | ||
| bool prev_in_parallel = in_parallel_loop_; | ||
| in_parallel_loop_ = true; | ||
| For updated = Downcast<For>(StmtExprMutator::VisitStmt_(op)); | ||
| in_parallel_loop_ = prev_in_parallel; | ||
| return TransformParallelFor(updated.get()); | ||
| } | ||
|
|
||
| int64_t max_threads_per_block_; | ||
| bool in_thread_env_{false}; | ||
| bool in_parallel_loop_{false}; | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| namespace transform { | ||
|
|
||
| Pass BindParallelLoopsToThreads() { | ||
| auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | ||
| auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget); | ||
| if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) { | ||
| return f; | ||
| } | ||
| Target target = opt_target.value(); | ||
|
|
||
| int64_t max_threads_per_block = 1024; | ||
| if (auto opt_max_threads = target->GetAttr<Integer>("max_num_threads")) { | ||
| max_threads_per_block = opt_max_threads.value()->value; | ||
| } | ||
|
|
||
| PrimFuncNode* n = f.CopyOnWrite(); | ||
| n->body = ParallelLoopToThreadBindingMutator(max_threads_per_block)(n->body); | ||
| return f; | ||
| }; | ||
|
|
||
| return CreatePrimFuncPass(pass_func, 0, "tirx.BindParallelLoopsToThreads", {}); | ||
| } | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK() { | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::GlobalDef().def("tirx.transform.BindParallelLoopsToThreads", BindParallelLoopsToThreads); | ||
| } | ||
|
|
||
| } // namespace transform | ||
| } // namespace tirx | ||
| } // namespace tvm | ||
|
|
||
106 changes: 106 additions & 0 deletions
106
tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Tests for tirx.transform.BindParallelLoopsToThreads.""" | ||
|
|
||
| import pytest | ||
|
|
||
| import tvm | ||
| import tvm.testing | ||
| from tvm.script import ir as I | ||
| from tvm.script import tirx as T | ||
|
|
||
|
|
||
| def test_bind_parallel_skips_without_target(): | ||
| """PrimFuncs without tvm::attr::kTarget must be left unchanged (no Target::Current guess).""" | ||
|
|
||
| @I.ir_module | ||
| class Mod: | ||
| @T.prim_func | ||
| def main(A: T.Buffer((4,), "float32")): | ||
| for i in T.parallel(4): | ||
| A[i] = T.float32(1) | ||
|
|
||
| after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) | ||
| tvm.ir.assert_structural_equal(after, Mod) | ||
|
|
||
|
|
||
| def test_bind_parallel_skips_non_gpu_target(): | ||
| @I.ir_module | ||
| class Mod: | ||
| @T.prim_func | ||
| def main(A: T.Buffer((4,), "float32")): | ||
| T.func_attr({"target": T.target("llvm")}) | ||
| for i in T.parallel(4): | ||
| A[i] = T.float32(1) | ||
|
|
||
| after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) | ||
| tvm.ir.assert_structural_equal(after, Mod) | ||
|
|
||
|
|
||
| def test_bind_parallel_cuda_wraps_parallel_in_thread_extents(): | ||
| @I.ir_module | ||
| class Before: | ||
| @T.prim_func | ||
| def main(A: T.Buffer((4,), "float32")): | ||
| T.func_attr({"target": T.target("cuda")}) | ||
| for i in T.parallel(4): | ||
| A[i] = T.float32(1) | ||
|
|
||
| after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) | ||
| body = after["main"].body | ||
| assert isinstance(body, tvm.tirx.AttrStmt) | ||
| assert body.node.thread_tag == "blockIdx.x" | ||
| inner = body.body | ||
| assert isinstance(inner, tvm.tirx.AttrStmt) | ||
| assert inner.node.thread_tag == "threadIdx.x" | ||
| assert isinstance(inner.body, tvm.tirx.IfThenElse) | ||
| assert inner.body.else_case is None | ||
|
|
||
|
|
||
| def test_bind_parallel_nested_parallel_raises(): | ||
| @I.ir_module | ||
| class Mod: | ||
| @T.prim_func | ||
| def main(A: T.Buffer((4, 4), "float32")): | ||
| T.func_attr({"target": T.target("cuda")}) | ||
| for i in T.parallel(4): | ||
| for j in T.parallel(4): | ||
| A[i, j] = T.float32(1) | ||
|
|
||
| with pytest.raises(tvm.error.InternalError, match="nested parallel"): | ||
| tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) | ||
|
|
||
|
|
||
| def test_bind_parallel_respects_max_num_threads(): | ||
| @I.ir_module | ||
| class Before: | ||
| @T.prim_func | ||
| def main(A: T.Buffer((256,), "float32")): | ||
| T.func_attr({"target": T.target({"kind": "cuda", "max_num_threads": 128})}) | ||
| for i in T.parallel(256): | ||
| A[i] = T.float32(1) | ||
|
|
||
| after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) | ||
| inner = after["main"].body.body | ||
| assert isinstance(inner, tvm.tirx.AttrStmt) | ||
| assert inner.node.thread_tag == "threadIdx.x" | ||
| assert isinstance(inner.value, tvm.tirx.IntImm) | ||
| assert inner.value.value == 128 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.