cuda: add per-session mutable state rebinding #20241
Dr. CI: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20241
❌ 1 New Failure, 26 Pending, 1 Unrelated Failure as of commit 9f52f6a with merge base d7ca5db. The unrelated failure was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Adds CUDA-private infrastructure to support per-session rebinding of mutable model state (e.g., KV/conv/recurrent buffers) while keeping model weights shared across logical conversations, enabling multi-session local serving without duplicating large CUDA-resident weights.
Changes:
- Introduces `cuda_mutable_state` context/session manager with descriptor capture and per-execute rebinding hooks.
- Integrates mutable-state hooks into `CudaBackend::init` (handle noting) and `CudaBackend::execute` (rebind-before-run).
- Adds a fail-closed unit test and wires it into both Buck and CMake builds.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| backends/cuda/runtime/cuda_mutable_state.h | Declares CUDA-private context/session API for mutable-state rebinding. |
| backends/cuda/runtime/cuda_mutable_state.cpp | Implements context/session tracking, template capture, buffer allocation, and rebind logic. |
| backends/cuda/runtime/cuda_backend.cpp | Calls into mutable-state hooks during init and before execute. |
| backends/cuda/runtime/test/test_cuda_mutable_state.cpp | Adds unit test covering fail-closed defaults. |
| backends/cuda/runtime/TARGETS | Adds new source/header to CUDA backend and registers Buck unit test target. |
| backends/cuda/CMakeLists.txt | Adds new source to CMake backend build and registers CMake unit test target. |
73fb7d1 to 4632963
@claude review the PR
Claude finished @Gasoonjia's task in 1m 53s.

Code Review: cuda, per-session mutable state rebinding

Reviewed all 6 changed files. The three earlier Copilot comments are already resolved in the current revision, each with a matching regression test. Nice work tightening those.

Findings

1. Shared per-FQN template/buffer can be under-sized if two methods declare the same mutable FQN with different shapes. Per-session KV/conv FQNs likely have consistent shapes across methods today, so this is latent rather than active, but it's exactly the kind of mismatch this fail-closed layer should catch. A cheap guard:

if (d.nbytes > c.template_nbytes[fqn]) {
  ET_LOG(Error,
         "mutable_state: descriptor '%s' (%zu B) exceeds shared template buffer (%zu B)",
         fqn.c_str(), d.nbytes, c.template_nbytes[fqn]);
  return Error::Internal;
}

2. If a thread has

3. Multi-session is serialized, not concurrent; make sure this is loud enough. Because the AOTI container is shared and

Minor / non-blocking

Overall this is solid, defensive, and well-documented. Finding #1 is the only one I'd consider addressing before the consuming engine lands; the rest are confirmations/notes.
4632963 to a9c517c
a9c517c to 0be0912
@claude Review the PR. Double-check whether it addressed the previous feedback from Copilot and Claude.
Claude finished @mergennachin's task in 1m 19s.

Re-review: did the PR address prior Copilot & Claude feedback?

Checked all 6 changed files.

Copilot round 1 (2026-06-12T20:16)

Copilot round 2 (2026-06-12T20:44)

Claude round 1 (2026-06-12T20:38): Finding #1 (the only one flagged as worth fixing pre-merge)

Findings #2 (stale thread-local active session after destroy) and #3 (multi-session serialization contract) were notes/confirmations, not change requests; both remain correctly fail-closed.

Build wiring

Remaining observations (non-blocking, no prior comment)

Summary: every actionable item from both Copilot rounds and Claude's blocking Finding #1 is addressed in code with corresponding tests. No regressions introduced. The two observations above are follow-ups, not merge blockers.
  }
  auto& m = mgr();
  std::lock_guard<std::mutex> g(m.mu);
This holds one process-wide lock across the external rebind call and the GPU allocations below. That forces every model in the process to take turns on the hottest part of execute, and it can stall everything if the rebind call blocks or re-enters. Use a per-context lock instead, and copy out the values you need (the pairs, container handle, function pointer) under the lock, then release it before making the rebind call and the allocations.
It is not a correctness blocker; moving to per-context locks/copy-out safely is a follow-up refactor.
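The copy-out pattern the reviewer describes can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `Context`, `RebindFn`, and `rebind_without_holding_lock` are hypothetical names, and plain host pointers stand in for the container handle and device buffers.

```cpp
#include <cassert>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the external rebind entry point.
using RebindFn = int (*)(void* container, const char* name, void* buffer);

// Hypothetical per-context state; the mutex guards only this context.
struct Context {
  std::mutex mu;
  void* container = nullptr;
  RebindFn rebind = nullptr;
  std::vector<std::pair<std::string, void*>> pairs;  // internal name -> buffer
};

// Copies everything the rebind loop needs while holding the context lock,
// then releases the lock before making any external call.
inline int rebind_without_holding_lock(Context& ctx) {
  void* container = nullptr;
  RebindFn rebind = nullptr;
  std::vector<std::pair<std::string, void*>> pairs;
  {
    std::lock_guard<std::mutex> g(ctx.mu);
    container = ctx.container;
    rebind = ctx.rebind;
    pairs = ctx.pairs;
  }  // lock released here

  if (rebind == nullptr) {
    return -1;  // fail closed: no rebind function available
  }
  for (const auto& p : pairs) {
    const int rc = rebind(container, p.first.c_str(), p.second);
    if (rc != 0) {
      return rc;  // stop at the first failure
    }
  }
  return 0;
}
```

Because the external call runs on copies, a rebind function that blocks or re-enters cannot stall other contexts. The trade-off is that the copied values are a snapshot, so a real implementation would still need to guarantee the buffers outlive the call.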
de9d3cd to cb55702
Thanks for the reviews, @shoumikhin and @Gasoonjia. @shoumikhin, I addressed all your comments except the per-context mutex; I'll do that in a follow-up PR.
Yes. The intended behavior is that prefill and decode share the same KV buffer for the same logical session. So if both methods declare the same mutable buffer FQN, then under one active session:
Different sessions still get different buffers, so state does not bleed across requests. I added a unit test.
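The sharing rule described above (same session and FQN share one buffer regardless of method, different sessions never share) can be modeled with a small host-side sketch. `SessionBuffers` and its `get` method are hypothetical names for illustration; the real code holds CUDA allocations rather than host vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical host-side model of per-session mutable buffers.
class SessionBuffers {
 public:
  // Returns the buffer for (session, fqn), allocating it on first use.
  // The method name (e.g. prefill or decode) is deliberately not part of
  // the key, so both methods see the same buffer within one session.
  std::vector<unsigned char>& get(
      int session, const std::string& fqn, std::size_t nbytes) {
    auto& slot = buffers_[std::make_pair(session, fqn)];
    if (!slot) {
      slot = std::make_unique<std::vector<unsigned char>>(nbytes);
    }
    return *slot;
  }

 private:
  std::map<
      std::pair<int, std::string>,
      std::unique_ptr<std::vector<unsigned char>>>
      buffers_;
};
```

For example, two lookups of `"kv_cache"` under session 1 return the same storage, while a lookup under session 2 returns a distinct allocation.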
Yes, valid concern. The CUDA graph path remains unchanged for normal single-session inference, but I reject cuda_graph during multi-session execution. A future optimization could support one captured graph per session, or recapture after rebinding.
cb55702 to 5fb0034
  auto* t = reinterpret_cast<SlimTensor*>(it_t->second);
  Desc d;
  d.internal_name = it_name->second;
  d.sizes.assign(t->sizes().begin(), t->sizes().end());
  d.strides.assign(t->strides().begin(), t->strides().end());
  d.dtype = t->dtype();
  d.device = t->device();
  d.nbytes = t->nbytes();
  table.emplace(fqn, std::move(d));
  c.discovered_fqns.insert(fqn);

  if (c.template_ptr.find(fqn) == c.template_ptr.end()) {
    // If a later FQN fails during this build, already captured templates are
    // released by mutable_state_destroy_context().
    auto device_res = tensor_cuda_device_index(*t);
    ET_CHECK_OK_OR_RETURN_ERROR(device_res.error());
    const int device = device_res.get();
    CudaDeviceGuard guard;
    ET_CHECK_OK_OR_RETURN_ERROR(guard.set(device));

  err = cudaMemcpy(
      *device_ptr,
      values.data(),
      values.size() * sizeof(float),
      cudaMemcpyHostToDevice);
  if (err != cudaSuccess) {
    ADD_FAILURE() << "cudaMemcpy failed: " << cudaGetErrorString(err);
    cudaFree(*device_ptr);
    *device_ptr = nullptr;
    return nullptr;
  }
  return std::make_unique<slim::SlimTensor>(slim::from_blob(
      *device_ptr,
      {static_cast<int64_t>(values.size())},
      slimc10::ScalarType::Float,
      slimc10::Device(slimc10::DeviceType::CUDA, 0)));
}

  // If a mutable-state session is active on this thread, rebind this
  // container's registered mutable buffers before running.
  ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));
5fb0034 to 0c6bc51
0c6bc51 to 35388a3
void mutable_state_register_fqns(
    MutableStateContext ctx,
    const std::vector<std::string>& fqns) {
  auto& m = mgr();
  std::lock_guard<std::mutex> g(m.mu);
  auto it = m.contexts.find(ctx);
  if (it == m.contexts.end()) {
    return;
  }
  it->second.fqns = fqns;
  it->second.fqn_set.clear();
  it->second.fqn_set.insert(fqns.begin(), fqns.end());
}

  if (!c.symbols_available) {
    ET_LOG(
        Error, "mutable_state: rebinding unavailable; cannot create session");
    return Error::NotSupported;
  }
  int token = c.next_token++;
  c.sessions.insert(token);
  return token;

  ActiveSessionScope(MutableStateContext ctx, int token) {
    detail::mutable_state_set_active(ctx, token);
  }

  ~ActiveSessionScope() {
    detail::mutable_state_set_active(
        kInvalidMutableContext, kNoMutableSession);
  }
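
To show how a scope object like the one in this hunk might interact with thread-local state, here is a self-contained sketch. It mirrors the set-on-construct, clear-on-destruct shape of the hunk, but `Active`, `active_on_this_thread`, `ScopedSession`, and the sentinel values are hypothetical and not taken from the PR.

```cpp
#include <cassert>

// Hypothetical sentinel values meaning "nothing is active".
constexpr int kNoContext = -1;
constexpr int kNoSession = -1;

struct Active {
  int ctx;
  int token;
};

// One record per thread; other threads never observe this thread's session.
inline Active& active_on_this_thread() {
  thread_local Active a{kNoContext, kNoSession};
  return a;
}

// RAII scope: marks a session active on construction and clears it on
// destruction, including on early returns inside the scope.
class ScopedSession {
 public:
  ScopedSession(int ctx, int token) {
    active_on_this_thread() = Active{ctx, token};
  }
  ~ScopedSession() {
    active_on_this_thread() = Active{kNoContext, kNoSession};
  }
  ScopedSession(const ScopedSession&) = delete;
  ScopedSession& operator=(const ScopedSession&) = delete;
};

inline bool session_is_active() {
  return active_on_this_thread().token != kNoSession;
}
```

Note that, like the hunk, this sketch clears the active session on destruction rather than restoring a previous one, so nested scopes on the same thread would leave the outer scope inactive after the inner one ends.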
    }
  }

  ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));
Local agent serving needs to host multiple logical conversations on one CUDA-resident model without multiplying the model weights. Loading one AOTI module per conversation is not viable for large local models, while sharing the default mutable state across conversations would let KV/recurrent/conv buffers bleed between users. This adds the CUDA-private foundation for separating those concerns: weights remain owned by the loaded AOTI container, while mutable buffer FQNs can be registered as per-session state and rebound before execution. The path is fail-closed and dormant until a model opts in by creating a mutable-state context and validating coverage, so existing CUDA models keep their current behavior. The branch also wires the new source and unit coverage into both Buck and CMake so the primitive can land independently before any model-specific engine consumes it.
35388a3 to 9f52f6a
Local agent serving needs to host multiple logical conversations on one CUDA-resident model without multiplying the model weights. Loading one AOTI module per conversation is not viable for large local models, while sharing the default mutable state across conversations would let KV/recurrent/conv buffers bleed between users.
This adds the CUDA-private foundation for separating those concerns: weights remain owned by the loaded AOTI container, while mutable buffer FQNs can be registered as per-session state and rebound before execution. The path is fail-closed and dormant until a model opts in by creating a mutable-state context and validating coverage, so existing CUDA models keep their current behavior.
The branch also wires the new source and fail-closed unit test into both Buck and CMake so the primitive can land independently before any model-specific engine consumes it.
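The "fail-closed and dormant until opt-in" behavior described above can be illustrated with a small host-only model. Everything here is an assumption made for illustration: `Status`, `MutableContext`, `create_session`, and `rebind_for_execute` are hypothetical names, and "coverage" is taken to mean that every registered FQN was discovered in the container, which the description does not spell out.

```cpp
#include <cassert>
#include <set>
#include <string>

enum class Status { Ok, NotSupported, InvalidState };

constexpr int kNoSession = -1;  // hypothetical "no active session" token

// Hypothetical model of one opted-in model's mutable-state context.
struct MutableContext {
  bool symbols_available = false;    // rebinding entry points were found
  std::set<std::string> registered;  // FQNs declared as per-session state
  std::set<std::string> discovered;  // FQNs actually found in the container
  std::set<int> sessions;            // live session tokens
  int next_token = 1;
};

// Creates a session only when rebinding is possible and coverage holds.
inline Status create_session(MutableContext& c, int* token_out) {
  if (!c.symbols_available) {
    return Status::NotSupported;  // fail closed: cannot rebind at all
  }
  for (const auto& fqn : c.registered) {
    if (c.discovered.count(fqn) == 0) {
      return Status::InvalidState;  // fail closed: an FQN is not covered
    }
  }
  *token_out = c.next_token++;
  c.sessions.insert(*token_out);
  return Status::Ok;
}

// Hook run before execute. With no context or no active session it does
// nothing, which keeps models that never opted in on their current path.
inline Status rebind_for_execute(const MutableContext* c, int active_token) {
  if (c == nullptr || active_token == kNoSession) {
    return Status::Ok;  // dormant
  }
  if (c->sessions.count(active_token) == 0) {
    return Status::InvalidState;  // unknown or destroyed session
  }
  return Status::Ok;  // a real implementation would rebind buffers here
}
```

The design point is that every uncertain case returns an error instead of silently falling back to the shared default buffers, since a silent fallback is what would let state bleed between conversations.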
#20001