From 6f413c793820b61b5a3a02fac749a636bc844d88 Mon Sep 17 00:00:00 2001 From: Hardy <18930861549@163.com> Date: Mon, 6 Apr 2026 15:52:59 +0800 Subject: [PATCH 1/5] [TIRX] Bind parallel loops to GPU threads before VerifyMemory `VerifyMemory` on GPU targets treats direct accesses outside thread environments as illegal. In the ScatterValue CUDA lowering path, `topi.scatter_elements` emits `ForKind::kParallel` loops without explicit thread bindings, which triggers false host-memory access failures (e.g. "Did you forget to bind?") during TIR verification. This change adds a new `tirx` pass (`BindParallelLoopsToThreads`) and inserts it before `VerifyMemory` in the `s_tir` pipelines (including adreno). The pass rewrites parallel loops into `blockIdx.x/threadIdx.x` thread-extent regions, substitutes loop vars with global thread indices, and adds bounds checks for non-divisible extents. This preserves correctness while ensuring GPU kernels pass memory verification for this path. --- .tmp_scatter_cuda_check.py | 34 ++++++ include/tvm/tirx/transform.h | 7 ++ python/tvm/contrib/nvcc.py | 21 ++++ python/tvm/s_tir/backend/adreno/pipeline.py | 1 + python/tvm/s_tir/pipeline.py | 1 + python/tvm/tirx/transform/transform.py | 11 ++ src/tirx/analysis/verify_memory.cc | 2 +- .../bind_parallel_loops_to_threads.cc | 111 ++++++++++++++++++ 8 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 .tmp_scatter_cuda_check.py create mode 100644 src/tirx/transform/bind_parallel_loops_to_threads.cc diff --git a/.tmp_scatter_cuda_check.py b/.tmp_scatter_cuda_check.py new file mode 100644 index 000000000000..34e84b335298 --- /dev/null +++ b/.tmp_scatter_cuda_check.py @@ -0,0 +1,34 @@ +import numpy as np +import torch +from torch import nn +from torch.export import export +import tvm +from tvm import relax +from tvm.relax.frontend.torch import from_exported_program +from tvm.relax.backend.cuda import get_default_pipeline + +class ScatterValue(nn.Module): + def forward(self, x, index): + return x.scatter(1, index, 0.5) + +torch.manual_seed(0) +x = torch.randn(4, 8, dtype=torch.float32) +idx = torch.randint(0, 8, (4, 2), dtype=torch.int64) + +mod = from_exported_program(export(ScatterValue(), args=(x, idx))) +tgt = tvm.target.Target('cuda') +with tgt: + mod = get_default_pipeline(tgt)(mod) + +ex = relax.build(mod, tgt, relax_pipeline=None) +vm = relax.VirtualMachine(ex, tvm.cuda(0)) +out = vm['main']( + tvm.runtime.tensor(x.numpy(), device=tvm.cuda(0)), + tvm.runtime.tensor(idx.numpy(), device=tvm.cuda(0)), +) +out_np = out.numpy() if hasattr(out, 'numpy') else out[0].numpy() +ref_np = ScatterValue()(x, idx).numpy() + +print('shape_match', out_np.shape == ref_np.shape) +print('allclose', np.allclose(out_np, ref_np, rtol=1e-5, atol=1e-6)) +print('max_abs_diff', float(np.max(np.abs(out_np - ref_np)))) diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h index 4d1267e97bb9..c13d90165128 100644 --- a/include/tvm/tirx/transform.h +++ b/include/tvm/tirx/transform.h @@ -330,6 +330,12 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); */ TVM_DLL Pass BindTarget(Target target); +/*! + * \brief Convert ForKind::kParallel loops to blockIdx.x/threadIdx.x bindings on GPU targets. + * \return The pass. + */ +TVM_DLL Pass BindParallelLoopsToThreads(); + /*! * \brief Set a PrimFunc as the entry point if it is only function in IRModule. * \return The pass. @@ -354,6 +360,7 @@ TVM_DLL Pass Filter(ffi::TypedFunction fcond); * * \return The pass. */ + } // namespace transform } // namespace tirx } // namespace tvm diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 1f933a63fc7b..d0287a0e4cd4 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -177,6 +177,27 @@ def _compile_cuda_nvcc( else: raise ValueError("options must be str or list of str") + # Optional workaround for NVCC host compiler version checks on Windows. + # Priority: + # 1) PassContext config: cuda.nvcc_allow_unsupported_compiler (bool) + # 2) Environment variable: TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER in {"1","true","on","yes"} + # 3) Default: False + allow_unsupported_compiler = False + if "cuda.nvcc_allow_unsupported_compiler" in pass_context.config: + allow_unsupported_compiler = bool( + pass_context.config["cuda.nvcc_allow_unsupported_compiler"] + ) + else: + env_val = os.environ.get("TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER", "").strip().lower() + allow_unsupported_compiler = env_val in {"1", "true", "on", "yes"} + + if ( + platform.system() == "Windows" + and allow_unsupported_compiler + and "-allow-unsupported-compiler" not in cmd + ): + cmd += ["-allow-unsupported-compiler"] + cmd += ["-o", file_target] if not use_nvshmem: cmd += [temp_code] diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index a63fb4346d34..434b5c4bfca1 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -89,6 +89,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), + tirx.transform.BindParallelLoopsToThreads(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), ] diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index f775c0dd1eac..e13ffaf8ee74 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -87,6 +87,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), + tirx.transform.BindParallelLoopsToThreads(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), ] diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 6e18558b0ecd..34861981e4b0 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -428,6 +428,17 @@ def VerifyMemory(): return _ffi_api.VerifyMemory() # type: ignore +def BindParallelLoopsToThreads(): + """Convert T.parallel loops to block/thread bindings for GPU PrimFuncs. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.BindParallelLoopsToThreads() # type: ignore + + @_ffi.register_object("s_tir.transform.HoistIfThenElseConfig") class HoistIfThenElseConfig(_ir.Attrs): """Config for hoist if then else pass""" diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index a6c3c0ef3552..f5b667e4e9fb 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -60,7 +60,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { void Run() { if (!IsGPUDevice(dev_type_)) return; StmtExprVisitor::VisitStmt(func_->body); - } + } /// Verification result std::vector Errors() const { return errs_; } diff --git a/src/tirx/transform/bind_parallel_loops_to_threads.cc b/src/tirx/transform/bind_parallel_loops_to_threads.cc new file mode 100644 index 000000000000..03577f79e616 --- /dev/null +++ b/src/tirx/transform/bind_parallel_loops_to_threads.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file bind_parallel_loops_to_threads.cc + * \brief Convert ForKind::kParallel loops to GPU thread bindings. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tirx { +namespace { + +static bool IsGpuDeviceType(int dev_type) { + return dev_type == kDLCUDA || dev_type == kDLROCM || dev_type == kDLOpenCL || + dev_type == kDLVulkan || dev_type == kDLMetal || dev_type == kDLWebGPU; +} + +class ParallelLoopToThreadBindingMutator : public StmtExprMutator { + public: + explicit ParallelLoopToThreadBindingMutator(int64_t max_threads_per_block) + : max_threads_per_block_(max_threads_per_block) {} + + private: + Stmt VisitStmt_(const ForNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + const auto* for_node = stmt.as(); + TVM_FFI_ICHECK(for_node != nullptr); + if (for_node->kind != ForKind::kParallel) { + return stmt; + } + + DataType dtype = for_node->loop_var.dtype(); + PrimExpr extent = cast(dtype, for_node->extent); + PrimExpr max_threads = IntImm(dtype, max_threads_per_block_); + PrimExpr num_blocks = ceildiv(extent, max_threads); + + Var tx_var("threadIdx.x", dtype); + Var bx_var("blockIdx.x", dtype); + IterVar tx_iter(Range::FromMinExtent(IntImm(dtype, 0), max_threads), tx_var, + IterVarType::kThreadIndex, "threadIdx.x"); + IterVar bx_iter(Range::FromMinExtent(IntImm(dtype, 0), num_blocks), bx_var, + IterVarType::kThreadIndex, "blockIdx.x"); + + PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var); + Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), global_idx}}); + mapped_body = IfThenElse(global_idx < extent, mapped_body, Evaluate(IntImm(DataType::Int(32), 0))); + + Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body); + Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx); + return body_with_bx; + } + + int64_t max_threads_per_block_; +}; + +} // namespace + +namespace transform { + +Pass BindParallelLoopsToThreads() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto opt_target = f->GetAttr(tvm::attr::kTarget); + if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) { + return f; + } + + int64_t max_threads_per_block = 1024; + if (auto opt_max_threads = opt_target.value()->GetAttr("max_num_threads")) { + max_threads_per_block = opt_max_threads.value()->value; + } + + PrimFuncNode* n = f.CopyOnWrite(); + n->body = ParallelLoopToThreadBindingMutator(max_threads_per_block)(n->body); + return f; + }; + + return CreatePrimFuncPass(pass_func, 0, "tirx.BindParallelLoopsToThreads", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tirx.transform.BindParallelLoopsToThreads", BindParallelLoopsToThreads); +} + +} // namespace transform +} // namespace tirx +} // namespace tvm + From a405367cec3ad411dd2eb08d4bd1c57a852843e7 Mon Sep 17 00:00:00 2001 From: Hardy <18930861549@163.com> Date: Mon, 6 Apr 2026 16:32:40 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=EF=BB=BF[TIRX][CUDA]=20Fix=20parallel=20lo?= =?UTF-8?q?op=20binding=20and=20register=20nvcc=20pass=20config?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix three correctness/configuration issues in the GPU parallel-loop binding path used before VerifyMemory. First, preserve non-zero loop mins by mapping parallel indices as min + global_idx instead of global_idx. Second, avoid rewriting parallel loops when already inside a thread environment to prevent invalid nested bindings. Third, register cuda.nvcc_allow_unsupported_compiler as a valid PassContext key so the NVCC workaround can be enabled via config without raising Invalid config option. Made-with: Cursor --- src/target/opt/build_cuda_on.cc | 1 + .../bind_parallel_loops_to_threads.cc | 47 +++++++++++++++---- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 4e312e93a462..388efad4080a 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -108,5 +108,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("target.build.cuda", BuildCUDA); } TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", ffi::String); +TVM_REGISTER_PASS_CONFIG_OPTION("cuda.nvcc_allow_unsupported_compiler", Bool); } // namespace codegen } // namespace tvm diff --git a/src/tirx/transform/bind_parallel_loops_to_threads.cc b/src/tirx/transform/bind_parallel_loops_to_threads.cc index 03577f79e616..2db4cb57c9ec 100644 --- a/src/tirx/transform/bind_parallel_loops_to_threads.cc +++ b/src/tirx/transform/bind_parallel_loops_to_threads.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -44,15 +45,24 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { : max_threads_per_block_(max_threads_per_block) {} private: - Stmt VisitStmt_(const ForNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - const auto* for_node = stmt.as(); - TVM_FFI_ICHECK(for_node != nullptr); - if (for_node->kind != ForKind::kParallel) { - return stmt; + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + bool prev = in_thread_env_; + in_thread_env_ = true; + Stmt ret = StmtExprMutator::VisitStmt_(op); + in_thread_env_ = prev; + return ret; + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt TransformParallelFor(const ForNode* for_node) { + if (in_thread_env_) { + return ffi::GetRef(for_node); } DataType dtype = for_node->loop_var.dtype(); + PrimExpr min = cast(dtype, for_node->min); PrimExpr extent = cast(dtype, for_node->extent); PrimExpr max_threads = IntImm(dtype, max_threads_per_block_); PrimExpr num_blocks = ceildiv(extent, max_threads); @@ -65,7 +75,8 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { IterVarType::kThreadIndex, "blockIdx.x"); PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var); - Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), global_idx}}); + PrimExpr mapped_idx = cast(dtype, min + global_idx); + Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}}); mapped_body = IfThenElse(global_idx < extent, mapped_body, Evaluate(IntImm(DataType::Int(32), 0))); Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body); @@ -73,7 +84,24 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { return body_with_bx; } + Stmt VisitStmt_(const ForNode* op) final { + if (op->kind == ForKind::kThreadBinding) { + bool prev = in_thread_env_; + in_thread_env_ = true; + Stmt ret = StmtExprMutator::VisitStmt_(op); + in_thread_env_ = prev; + return ret; + } + if (op->kind != ForKind::kParallel) { + return StmtExprMutator::VisitStmt_(op); + } + // First mutate inside this loop, then rewrite the current parallel loop. + For updated = Downcast(StmtExprMutator::VisitStmt_(op)); + return TransformParallelFor(updated.get()); + } + int64_t max_threads_per_block_; + bool in_thread_env_{false}; }; } // namespace @@ -83,12 +111,13 @@ namespace transform { Pass BindParallelLoopsToThreads() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); - if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) { + Target target = opt_target.value_or(Target::Current(/*allow_none=*/true)); + if (!target.defined() || !IsGpuDeviceType(target->GetTargetDeviceType())) { return f; } int64_t max_threads_per_block = 1024; - if (auto opt_max_threads = opt_target.value()->GetAttr("max_num_threads")) { + if (auto opt_max_threads = target->GetAttr("max_num_threads")) { max_threads_per_block = opt_max_threads.value()->value; } From a65d7999c72c83e94d21f9cea9b2675abc2c37ed Mon Sep 17 00:00:00 2001 From: zhils <18930861549@163.com> Date: Tue, 7 Apr 2026 18:32:13 +0800 Subject: [PATCH 3/5] Fix code review issues - Add kDLWebGPU to IsGPUDevice in verify_memory.cc - Remove redundant Var wrapper in loop_partition.cc - Fix nested parallel loop handling in bind_parallel_loops_to_threads.cc --- src/s_tir/transform/loop_partition.cc | 2 +- src/tirx/analysis/verify_memory.cc | 2 +- src/tirx/transform/bind_parallel_loops_to_threads.cc | 8 +++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index 718ced207a04..a040a79ede07 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -766,7 +766,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore - return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); + return Substitute(body, {{for_node->loop_var, make_const(DataType::Int(32), 0)}}); } else { TVM_FFI_ICHECK(for_node->kind != ForKind::kThreadBinding); auto new_loop = ffi::make_object(*for_node); diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index f5b667e4e9fb..91f454f1f683 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -150,7 +150,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. static bool IsGPUDevice(int dev_type) { return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || - kDLMetal == dev_type || kDLROCM == dev_type; + kDLMetal == dev_type || kDLROCM == dev_type || kDLWebGPU == dev_type; } private: diff --git a/src/tirx/transform/bind_parallel_loops_to_threads.cc b/src/tirx/transform/bind_parallel_loops_to_threads.cc index 2db4cb57c9ec..f1ae774de424 100644 --- a/src/tirx/transform/bind_parallel_loops_to_threads.cc +++ b/src/tirx/transform/bind_parallel_loops_to_threads.cc @@ -95,13 +95,19 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { if (op->kind != ForKind::kParallel) { return StmtExprMutator::VisitStmt_(op); } - // First mutate inside this loop, then rewrite the current parallel loop. + if (in_parallel_loop_) { + return StmtExprMutator::VisitStmt_(op); + } + bool prev_in_parallel = in_parallel_loop_; + in_parallel_loop_ = true; For updated = Downcast(StmtExprMutator::VisitStmt_(op)); + in_parallel_loop_ = prev_in_parallel; return TransformParallelFor(updated.get()); } int64_t max_threads_per_block_; bool in_thread_env_{false}; + bool in_parallel_loop_{false}; }; } // namespace From 8f7d42a6f731ca4312206cc06d875b216e227603 Mon Sep 17 00:00:00 2001 From: zhils <18930861549@163.com> Date: Wed, 8 Apr 2026 11:38:46 +0800 Subject: [PATCH 4/5] Trigger CI From 12a78e29792d8511c51d2c95c58fa7d3c60abdc4 Mon Sep 17 00:00:00 2001 From: zhils <18930861549@163.com> Date: Mon, 20 Apr 2026 12:36:23 +0800 Subject: [PATCH 5/5] [TIRX] Harden BindParallelLoopsToThreads and add pytest - Document pass semantics; reject nested T.parallel; no Target::Current fallback - Use IfThenElse without else; include tvm/ffi/function.h for TVM_FFI_THROW - Add test_tir_transform_bind_parallel_loops_to_threads.py - Remove local debug script .tmp_scatter_cuda_check.py - Revert unrelated nvcc/loop_partition/verify_memory/build_cuda changes (split to other PRs) Made-with: Cursor --- .tmp_scatter_cuda_check.py | 34 ------ python/tvm/contrib/nvcc.py | 21 ---- src/s_tir/transform/loop_partition.cc | 2 +- src/target/opt/build_cuda_on.cc | 1 - src/tirx/analysis/verify_memory.cc | 4 +- .../bind_parallel_loops_to_threads.cc | 23 +++- ...ransform_bind_parallel_loops_to_threads.py | 106 ++++++++++++++++++ 7 files changed, 128 insertions(+), 63 deletions(-) delete mode 100644 .tmp_scatter_cuda_check.py create mode 100644 tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py diff --git a/.tmp_scatter_cuda_check.py b/.tmp_scatter_cuda_check.py deleted file mode 100644 index 34e84b335298..000000000000 --- a/.tmp_scatter_cuda_check.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np -import torch -from torch import nn -from torch.export import export -import tvm -from tvm import relax -from tvm.relax.frontend.torch import from_exported_program -from tvm.relax.backend.cuda import get_default_pipeline - -class ScatterValue(nn.Module): - def forward(self, x, index): - return x.scatter(1, index, 0.5) - -torch.manual_seed(0) -x = torch.randn(4, 8, dtype=torch.float32) -idx = torch.randint(0, 8, (4, 2), dtype=torch.int64) - -mod = from_exported_program(export(ScatterValue(), args=(x, idx))) -tgt = tvm.target.Target('cuda') -with tgt: - mod = get_default_pipeline(tgt)(mod) - -ex = relax.build(mod, tgt, relax_pipeline=None) -vm = relax.VirtualMachine(ex, tvm.cuda(0)) -out = vm['main']( - tvm.runtime.tensor(x.numpy(), device=tvm.cuda(0)), - tvm.runtime.tensor(idx.numpy(), device=tvm.cuda(0)), -) -out_np = out.numpy() if hasattr(out, 'numpy') else out[0].numpy() -ref_np = ScatterValue()(x, idx).numpy() - -print('shape_match', out_np.shape == ref_np.shape) -print('allclose', np.allclose(out_np, ref_np, rtol=1e-5, atol=1e-6)) -print('max_abs_diff', float(np.max(np.abs(out_np - ref_np)))) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index d0287a0e4cd4..1f933a63fc7b 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -177,27 +177,6 @@ def _compile_cuda_nvcc( else: raise ValueError("options must be str or list of str") - # Optional workaround for NVCC host compiler version checks on Windows. - # Priority: - # 1) PassContext config: cuda.nvcc_allow_unsupported_compiler (bool) - # 2) Environment variable: TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER in {"1","true","on","yes"} - # 3) Default: False - allow_unsupported_compiler = False - if "cuda.nvcc_allow_unsupported_compiler" in pass_context.config: - allow_unsupported_compiler = bool( - pass_context.config["cuda.nvcc_allow_unsupported_compiler"] - ) - else: - env_val = os.environ.get("TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER", "").strip().lower() - allow_unsupported_compiler = env_val in {"1", "true", "on", "yes"} - - if ( - platform.system() == "Windows" - and allow_unsupported_compiler - and "-allow-unsupported-compiler" not in cmd - ): - cmd += ["-allow-unsupported-compiler"] - cmd += ["-o", file_target] if not use_nvshmem: cmd += [temp_code] diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index a040a79ede07..718ced207a04 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -766,7 +766,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore - return Substitute(body, {{for_node->loop_var, make_const(DataType::Int(32), 0)}}); + return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { TVM_FFI_ICHECK(for_node->kind != ForKind::kThreadBinding); auto new_loop = ffi::make_object(*for_node); diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 388efad4080a..4e312e93a462 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -108,6 +108,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("target.build.cuda", BuildCUDA); } TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", ffi::String); -TVM_REGISTER_PASS_CONFIG_OPTION("cuda.nvcc_allow_unsupported_compiler", Bool); } // namespace codegen } // namespace tvm diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index 91f454f1f683..a6c3c0ef3552 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -60,7 +60,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { void Run() { if (!IsGPUDevice(dev_type_)) return; StmtExprVisitor::VisitStmt(func_->body); - } + } /// Verification result std::vector Errors() const { return errs_; } @@ -150,7 +150,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. static bool IsGPUDevice(int dev_type) { return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || - kDLMetal == dev_type || kDLROCM == dev_type || kDLWebGPU == dev_type; + kDLMetal == dev_type || kDLROCM == dev_type; } private: diff --git a/src/tirx/transform/bind_parallel_loops_to_threads.cc b/src/tirx/transform/bind_parallel_loops_to_threads.cc index f1ae774de424..15493d052449 100644 --- a/src/tirx/transform/bind_parallel_loops_to_threads.cc +++ b/src/tirx/transform/bind_parallel_loops_to_threads.cc @@ -20,8 +20,20 @@ /*! * \file bind_parallel_loops_to_threads.cc * \brief Convert ForKind::kParallel loops to GPU thread bindings. + * + * Semantics: + * - Only runs when the PrimFunc carries a `tvm::attr::kTarget` that refers to a GPU device. + * Functions without a target attribute are left unchanged (no ambient `Target::Current` guess). + * - The outermost `kParallel` loop in the function is rewritten to `blockIdx.x` / `threadIdx.x` + * `thread_extent` scopes, with a guard `if (global_idx < extent)` and no else-branch. + * - Nested `kParallel` loops (parallel inside parallel) are rejected: binding only the outer + * parallel nest would leave inner `kParallel` serial within the mapped kernel, which is + * almost never what users intend. + * - A `kParallel` that appears inside an existing thread environment (`thread_extent` / + * `virtual_thread`) is left unchanged so it does not introduce conflicting thread bindings. */ +#include #include #include #include @@ -77,7 +89,7 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var); PrimExpr mapped_idx = cast(dtype, min + global_idx); Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}}); - mapped_body = IfThenElse(global_idx < extent, mapped_body, Evaluate(IntImm(DataType::Int(32), 0))); + mapped_body = IfThenElse(global_idx < extent, mapped_body); Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body); Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx); @@ -96,7 +108,10 @@ class ParallelLoopToThreadBindingMutator : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } if (in_parallel_loop_) { - return StmtExprMutator::VisitStmt_(op); + TVM_FFI_THROW(InternalError) + << "BindParallelLoopsToThreads does not support nested parallel loops. " + << "Inner parallel loops become serial once bound into a GPU kernel. " + << "Please rewrite the TIR to avoid nested T.parallel."; } bool prev_in_parallel = in_parallel_loop_; in_parallel_loop_ = true; @@ -117,10 +132,10 @@ namespace transform { Pass BindParallelLoopsToThreads() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); - Target target = opt_target.value_or(Target::Current(/*allow_none=*/true)); - if (!target.defined() || !IsGpuDeviceType(target->GetTargetDeviceType())) { + if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) { return f; } + Target target = opt_target.value(); int64_t max_threads_per_block = 1024; if (auto opt_max_threads = target->GetAttr("max_num_threads")) { diff --git a/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py b/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py new file mode 100644 index 000000000000..e3050b16e409 --- /dev/null +++ b/tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for tirx.transform.BindParallelLoopsToThreads.""" + +import pytest + +import tvm +import tvm.testing +from tvm.script import ir as I +from tvm.script import tirx as T + + +def test_bind_parallel_skips_without_target(): + """PrimFuncs without tvm::attr::kTarget must be left unchanged (no Target::Current guess).""" + + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + tvm.ir.assert_structural_equal(after, Mod) + + +def test_bind_parallel_skips_non_gpu_target(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("llvm")}) + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + tvm.ir.assert_structural_equal(after, Mod) + + +def test_bind_parallel_cuda_wraps_parallel_in_thread_extents(): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("cuda")}) + for i in T.parallel(4): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) + body = after["main"].body + assert isinstance(body, tvm.tirx.AttrStmt) + assert body.node.thread_tag == "blockIdx.x" + inner = body.body + assert isinstance(inner, tvm.tirx.AttrStmt) + assert inner.node.thread_tag == "threadIdx.x" + assert isinstance(inner.body, tvm.tirx.IfThenElse) + assert inner.body.else_case is None + + +def test_bind_parallel_nested_parallel_raises(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((4, 4), "float32")): + T.func_attr({"target": T.target("cuda")}) + for i in T.parallel(4): + for j in T.parallel(4): + A[i, j] = T.float32(1) + + with pytest.raises(tvm.error.InternalError, match="nested parallel"): + tvm.tirx.transform.BindParallelLoopsToThreads()(Mod) + + +def test_bind_parallel_respects_max_num_threads(): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((256,), "float32")): + T.func_attr({"target": T.target({"kind": "cuda", "max_num_threads": 128})}) + for i in T.parallel(256): + A[i] = T.float32(1) + + after = tvm.tirx.transform.BindParallelLoopsToThreads()(Before) + inner = after["main"].body.body + assert isinstance(inner, tvm.tirx.AttrStmt) + assert inner.node.thread_tag == "threadIdx.x" + assert isinstance(inner.value, tvm.tirx.IntImm) + assert inner.value.value == 128 + + +if __name__ == "__main__": + tvm.testing.main()