[TIRX] Bind parallel loops to GPU threads before VerifyMemory #19363
zhils wants to merge 5 commits into apache:main
Conversation
`VerifyMemory` on GPU targets treats direct accesses outside thread environments as illegal. In the ScatterValue CUDA lowering path, `topi.scatter_elements` emits `ForKind::kParallel` loops without explicit thread bindings, which triggers false host-memory access failures (e.g. "Did you forget to bind?") during TIR verification. This change adds a new `tirx` pass (`BindParallelLoopsToThreads`) and inserts it before `VerifyMemory` in the `s_tir` pipelines (including adreno). The pass rewrites parallel loops into `blockIdx.x/threadIdx.x` thread-extent regions, substitutes loop vars with global thread indices, and adds bounds checks for non-divisible extents. This preserves correctness while ensuring GPU kernels pass memory verification for this path.
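To make the transform concrete, here is a pure-Python sketch (not TVM code) of how a `kParallel` loop of extent `N` maps onto a 1-D `blockIdx.x`/`threadIdx.x` grid with a tail bounds check. The function name and `block_dim` default are illustrative assumptions, not the pass's actual API.

```python
# Conceptual emulation of BindParallelLoopsToThreads on a single loop:
# `for i in parallel(extent): body(i)` becomes a flat GPU-style launch.
def run_parallel_as_threads(extent, body, block_dim=256):
    # Grid size is the ceiling division of the extent by the block size.
    num_blocks = (extent + block_dim - 1) // block_dim
    for block_idx in range(num_blocks):        # plays the role of blockIdx.x
        for thread_idx in range(block_dim):    # plays the role of threadIdx.x
            global_idx = block_idx * block_dim + thread_idx
            # Bounds check guards the tail when extent is not divisible by
            # block_dim, mirroring the IfThenElse the pass inserts.
            if global_idx < extent:
                body(global_idx)

out = []
run_parallel_as_threads(10, out.append, block_dim=4)  # visits indices 0..9 once each
```

Each thread computes one flat `global_idx`; only in-range indices execute the loop body, so correctness is preserved for non-divisible extents.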
Code Review
This pull request introduces the BindParallelLoopsToThreads pass, which converts ForKind::kParallel loops into GPU block and thread bindings, and integrates this pass into the S-TIR pipelines. Additionally, it provides a configuration option to allow unsupported host compilers for NVCC on Windows and adds a functional test for scatter operations on CUDA. Review feedback identifies a critical issue regarding the handling of nested parallel loops which could lead to invalid GPU register bindings, an inconsistency in GPU device type definitions between files, and a minor code redundancy in the loop variable substitution logic.
Fix three correctness/configuration issues in the GPU parallel-loop binding path used before `VerifyMemory`. First, preserve non-zero loop mins by mapping parallel indices as `min + global_idx` instead of `global_idx`. Second, avoid rewriting parallel loops when already inside a thread environment, to prevent invalid nested bindings. Third, register `cuda.nvcc_allow_unsupported_compiler` as a valid `PassContext` key so the NVCC workaround can be enabled via config without raising `Invalid config option`. Made-with: Cursor
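The first fix above can be illustrated with a small sketch: a loop `for i in range(min, min + extent)` must map each thread's flat index back as `min + global_idx`, otherwise non-zero loop mins silently shift the iteration domain. This is a plain-Python illustration; the names are assumptions, not the pass's API.

```python
# Indices a GPU-style launch would visit after the min-preserving fix.
def mapped_indices(loop_min, extent, block_dim=4):
    num_blocks = (extent + block_dim - 1) // block_dim
    visited = []
    for block_idx in range(num_blocks):
        for thread_idx in range(block_dim):
            global_idx = block_idx * block_dim + thread_idx
            if global_idx < extent:
                # The bug mapped i -> global_idx; the fix adds the loop min.
                visited.append(loop_min + global_idx)
    return visited

# A loop `for i in range(5, 8)` should touch exactly 5, 6, 7.
```

Without the `loop_min +` offset, the same loop would incorrectly touch indices 0, 1, 2.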
- Add `kDLWebGPU` to `IsGPUDevice` in `verify_memory.cc`
- Remove redundant `Var` wrapper in `loop_partition.cc`
- Fix nested parallel loop handling in `bind_parallel_loops_to_threads.cc`
tlopex left a comment
A few things to address here:
- Please remove `.tmp_scatter_cuda_check.py`. This looks like a local debugging script and should not be committed. It should be replaced with a proper pytest test instead.
- This PR currently mixes several unrelated changes: the NVCC `-allow-unsupported-compiler` workaround, the `loop_partition.cc` `Var{}` fix, and the `kDLWebGPU` change in `verify_memory.cc` are all independent and should be split into separate PRs.
- Please clarify the semantics for nested parallel loops. When `in_parallel_loop_` is already true, inner parallel loops are left as `kParallel`, but once they are inside a GPU kernel they effectively become serial. If that is intentional, it would be good to document it explicitly; otherwise it may be better to reject this case.
- The `else` branch in `IfThenElse` looks unnecessary. `Evaluate(IntImm(DataType::Int(32), 0))` is a no-op, and TIR already supports `IfThenElse` without an `else`, so this can just be `IfThenElse(global_idx < extent, mapped_body)`.
- For target-less `PrimFunc`s, please avoid falling back to `Target::Current()`. If the function has no target attribute, it is better to leave it unchanged rather than guessing from ambient context. That would also be consistent with how other recent passes handle target-less functions.
- This PR also needs proper pytest coverage for the new pass.
Thanks for the review. I removed `.tmp_scatter_cuda_check.py` and added `tests/python/tirx-transform/test_tir_transform_bind_parallel_loops_to_threads.py` for the pass. The NVCC workaround, the `loop_partition` `Var{}` fix, and the `verify_memory` WebGPU change are reverted here and will be submitted as separate PRs. Nested `T.parallel` is now rejected with a clear error; the pass doc comment explains the semantics (including leaving `kParallel` unchanged under an existing thread env). `IfThenElse` no longer has a no-op `else`. `PrimFunc`s without `kTarget` are left unchanged (no `Target::Current` fallback).
- Document pass semantics; reject nested `T.parallel`; no `Target::Current` fallback
- Use `IfThenElse` without `else`; include `tvm/ffi/function.h` for `TVM_FFI_THROW`
- Add `test_tir_transform_bind_parallel_loops_to_threads.py`
- Remove local debug script `.tmp_scatter_cuda_check.py`
- Revert unrelated nvcc/loop_partition/verify_memory/build_cuda changes (split into other PRs)

Made-with: Cursor