From 9f52f6a527fee5d3ce0eb5d8963313e1ffc961c6 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 12 Jun 2026 12:55:41 -0700 Subject: [PATCH] cuda: add per-session mutable state rebinding Local agent serving needs to host multiple logical conversations on one CUDA-resident model without multiplying the model weights. Loading one AOTI module per conversation is not viable for large local models, while sharing the default mutable state across conversations would let KV/recurrent/conv buffers bleed between users. This adds the CUDA-private foundation for separating those concerns: weights remain owned by the loaded AOTI container, while mutable buffer FQNs can be registered as per-session state and rebound before execution. The path is fail-closed and dormant until a model opts in by creating a mutable-state context and validating coverage, so existing CUDA models keep their current behavior. The branch also wires the new source and unit coverage into both Buck and CMake so the primitive can land independently before any model-specific engine consumes it. 
--- backends/cuda/CMakeLists.txt | 14 +- backends/cuda/runtime/TARGETS | 27 + backends/cuda/runtime/cuda_backend.cpp | 18 +- backends/cuda/runtime/cuda_mutable_state.cpp | 721 +++++++++++++++ backends/cuda/runtime/cuda_mutable_state.h | 182 ++++ .../runtime/test/test_cuda_mutable_state.cpp | 827 ++++++++++++++++++ 6 files changed, 1783 insertions(+), 6 deletions(-) create mode 100644 backends/cuda/runtime/cuda_mutable_state.cpp create mode 100644 backends/cuda/runtime/cuda_mutable_state.h create mode 100644 backends/cuda/runtime/test/test_cuda_mutable_state.cpp diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 0ce48d85e92..4668e48b91e 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -184,7 +184,9 @@ install( ) # CUDA backend implementation -set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp + runtime/cuda_mutable_state.cpp +) if(_cuda_is_msvc_toolchain) # MSVC links aoti_cuda_backend into portable_lib without relying on C++ # symbols exported from aoti_cuda_shims.dll. 
@@ -236,3 +238,13 @@ install( EXPORT ExecuTorchTargets DESTINATION lib ) + +if(BUILD_TESTING) + include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + + et_cxx_test( + test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp + EXTRA_LIBS aoti_cuda_backend + ) + target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1) +endif() diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index f62780b29c2..1cdd430a020 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -1,4 +1,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") oncall("executorch") @@ -105,9 +107,11 @@ runtime.cxx_library( name = "cuda_backend", srcs = [ "cuda_backend.cpp", + "cuda_mutable_state.cpp", ], headers = [ "cuda_delegate_handle.h", + "cuda_mutable_state.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, @@ -135,3 +139,26 @@ runtime.cxx_library( ("cuda", None, "cuda-lazy"), ], ) + +cpp_unittest( + name = "test_cuda_mutable_state", + srcs = [ + "test/test_cuda_mutable_state.cpp", + ], + deps = [ + ":cuda_backend", + "//executorch/backends/aoti:aoti_common_slim", + "//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/factory:from_blob", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + preprocessor_flags = ["-DCUDA_AVAILABLE=1"], + keep_gpu_sections = True, + remote_execution = re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), +) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index b0a06c8e8a0..f7d095540ad 100644 --- 
a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -44,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -436,6 +437,8 @@ class ET_EXPERIMENTAL CudaBackend final kCudaGraphWarmupSteps); } + mutable_state_note_handle(handle); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -539,6 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final } } + ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle)); + // --------------------------------------------------------------- // CUDA graph REPLAY path — skip all tensor setup and just replay // --------------------------------------------------------------- @@ -826,6 +831,8 @@ class ET_EXPERIMENTAL CudaBackend final } cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_; + mutable_state_forget_handle(handle); + // The CUDA stream is managed by shared_ptr in the handle. // It will be automatically destroyed when the last handle using it // is destroyed. Just reset our reference. @@ -899,11 +906,12 @@ class ET_EXPERIMENTAL CudaBackend final // * Constants are assumed to be IMMUTABLE (parameters or read-only // buffers). The AOTI shim today does not expose a mutability bit // through GetConstantOriginalFQN, so we cannot detect or refuse - // to share mutable buffers (e.g. a per-method KV cache). If a - // future model exports the same FQN as a mutable buffer in - // multiple methods, mutations from one method WILL be visible to - // the other through the shared GPU memory. Callers that need - // per-method mutable state must currently use distinct FQNs. + // to share mutable buffers (for example, runtime caches). If a + // model exports the same FQN as a mutable buffer in multiple + // methods, mutations from one method WILL be visible to the other + // through the shared GPU memory. Callers that need isolated mutable + // state for shared FQNs must opt into cuda_mutable_state or use + // distinct FQNs. 
// TODO: when AOTInductor exposes a constant-type / mutability // query, refuse to share entries that are not PARAMETER or // non-mutable BUFFER. diff --git a/backends/cuda/runtime/cuda_mutable_state.cpp b/backends/cuda/runtime/cuda_mutable_state.cpp new file mode 100644 index 00000000000..3438bd5b453 --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.cpp @@ -0,0 +1,721 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +namespace aoti = ::executorch::backends::aoti; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::backends::aoti::slim::from_blob; +using ::executorch::backends::aoti::slim::SlimTensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +// AOTI internal constant names are per-handle; the exported FQN is the stable +// identity across methods. +struct Desc { + std::string internal_name; + std::vector sizes; + std::vector strides; + slimc10::ScalarType dtype{slimc10::ScalarType::Float}; + slimc10::Device device{slimc10::DeviceType::CUDA, 0}; + size_t nbytes{0}; +}; + +// Cached user-managed pairs for a (handle, session). 
+struct Bound { + std::vector> tensors; + std::vector pairs; +}; + +struct Context { + std::vector fqns; + std::unordered_set fqn_set; + + bool symbols_checked{false}; + bool symbols_available{false}; + bool handles_associated{false}; + + std::unordered_map template_ptr; + std::unordered_map template_nbytes; + std::unordered_map template_device; + int64_t total_bytes{0}; + + std::unordered_map> + desc; + std::unordered_set discovered_fqns; + Error build_error{Error::Ok}; + + std::unordered_set sessions; + int next_token{0}; + std::unordered_map> session_buf; + std::unordered_map> bound; + // A managed handle must not execute without an active session after sessions + // exist or after a previous rebind left session pointers installed. + std::unordered_set rebound_handles; +}; + +struct Manager { + std::mutex mu; + std::unordered_map contexts; + std::unordered_map handle_ctx; + MutableStateContext next_ctx{1}; +}; + +Manager& mgr() { + static Manager m; + return m; +} + +// Load scopes associate handles with a context; active scopes select a session +// for execute on the current thread. 
+thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext; +thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext; +thread_local int tl_active_token = kNoMutableSession; + +bool handle_has_symbols(CudaDelegateHandle* h) { + return h->get_num_constants && h->get_constant_name && + h->get_constant_original_fqn && h->extract_constants_map && + h->update_user_managed_constant_buffer_pairs; +} + +struct CudaDeviceGuard { + int prev_device{0}; + bool restore{false}; + + Error set(int device) { + if (device < 0) { + return Error::Ok; + } + cudaError_t err = cudaGetDevice(&prev_device); + if (err != cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaGetDevice failed"); + return Error::Internal; + } + if (prev_device == device) { + return Error::Ok; + } + err = cudaSetDevice(device); + if (err != cudaSuccess) { + ET_LOG( + Error, + "mutable_state: cudaSetDevice(%d) failed: %s", + device, + cudaGetErrorString(err)); + return Error::Internal; + } + restore = true; + return Error::Ok; + } + + ~CudaDeviceGuard() { + if (restore) { + cudaSetDevice(prev_device); + } + } +}; + +Result tensor_cuda_device_index(const SlimTensor& t) { + const slimc10::Device device = t.device(); + ET_CHECK_OR_RETURN_ERROR( + device.is_cuda(), + InvalidArgument, + "mutable_state: mutable buffer template must be on CUDA, got %s", + device.str().c_str()); + if (device.index() >= 0) { + return static_cast(device.index()); + } + cudaPointerAttributes attr{}; + const cudaError_t err = cudaPointerGetAttributes(&attr, t.data_ptr()); + if (err != cudaSuccess) { + cudaGetLastError(); + ET_LOG( + Error, + "mutable_state: cudaPointerGetAttributes failed for template pointer"); + return Error::Internal; + } + return attr.device; +} + +void cuda_free_on_pointer_device(void* ptr, bool synchronize) { + if (!ptr) { + return; + } + int device = -1; + cudaPointerAttributes attr{}; + const cudaError_t attr_err = cudaPointerGetAttributes(&attr, ptr); + if (attr_err == cudaSuccess) { + 
device = attr.device; + } else { + cudaGetLastError(); + } + + CudaDeviceGuard guard; + if (device >= 0 && guard.set(device) != Error::Ok) { + ET_LOG( + Error, + "mutable_state: freeing pointer %p without switching to device %d", + ptr, + device); + } + if (synchronize) { + const cudaError_t sync_err = cudaDeviceSynchronize(); + if (sync_err != cudaSuccess) { + ET_LOG( + Error, + "mutable_state: cudaDeviceSynchronize before free failed: %s", + cudaGetErrorString(sync_err)); + } + } + const cudaError_t free_err = cudaFree(ptr); + if (free_err != cudaSuccess) { + ET_LOG( + Error, + "mutable_state: cudaFree(%p) failed: %s", + ptr, + cudaGetErrorString(free_err)); + } +} + +bool validate_descriptors(const Context& c) { + bool ok = true; + std::unordered_map first_desc; + for (const auto& handle_descs : c.desc) { + for (const auto& fd : handle_descs.second) { + const std::string& fqn = fd.first; + const Desc& d = fd.second; + auto template_it = c.template_nbytes.find(fqn); + if (template_it == c.template_nbytes.end()) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' has no captured template", + fqn.c_str()); + ok = false; + continue; + } + if (d.nbytes > template_it->second) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' (%zu B) exceeds shared template " + "buffer (%zu B)", + fqn.c_str(), + d.nbytes, + template_it->second); + ok = false; + } + + auto inserted = first_desc.emplace(fqn, &d); + if (!inserted.second) { + const Desc& base = *inserted.first->second; + if (d.dtype != base.dtype || d.device != base.device) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' has incompatible dtype/device " + "across loaded methods", + fqn.c_str()); + ok = false; + } + } + } + } + return ok; +} + +Error validate_coverage_locked(Context& c) { + if (c.build_error != Error::Ok) { + return c.build_error; + } + if (!c.symbols_available) { + return Error::NotSupported; + } + if (c.fqns.empty()) { + ET_LOG(Error, "mutable_state: no mutable-buffer FQNs registered"); + 
c.build_error = Error::InvalidState; + return Error::InvalidState; + } + + bool ok = true; + for (const auto& fqn : c.fqns) { + if (c.discovered_fqns.find(fqn) == c.discovered_fqns.end()) { + ET_LOG( + Error, + "mutable_state: declared mutable buffer '%s' not found in any loaded " + "method's constants (FQN mismatch?)", + fqn.c_str()); + ok = false; + } + } + ok = validate_descriptors(c) && ok; + if (!ok) { + c.build_error = Error::InvalidProgram; + return Error::InvalidProgram; + } + return Error::Ok; +} + +// Captures descriptors and initial templates while the container still owns its +// default mutable buffers. +Error build_descriptors(Context& c, CudaDelegateHandle* h) { + auto container = h->container_handle; + + size_t n = 0; + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_num_constants(container, &n), + "mutable_state: get_num_constants failed"); + std::unordered_map fqn_to_internal; + for (size_t i = 0; i < n; ++i) { + const char* internal = nullptr; + const char* fqn = nullptr; + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_constant_name(container, i, &internal), + "mutable_state: get_constant_name failed"); + ET_CHECK_OK_OR_RETURN_ERROR( + h->get_constant_original_fqn(container, i, &fqn), + "mutable_state: get_constant_original_fqn failed"); + // Empty names are method-scoped constants; skip them. 
+ if (internal && internal[0] != '\0' && fqn && fqn[0] != '\0') { + fqn_to_internal[fqn] = internal; + } + } + + std::unordered_map extracted; + ET_CHECK_OK_OR_RETURN_ERROR( + h->extract_constants_map( + container, + reinterpret_cast(&extracted), + /*use_inactive=*/false), + "mutable_state: extract_constants_map failed"); + + auto& table = c.desc[h]; + for (const auto& fqn : c.fqns) { + auto it_name = fqn_to_internal.find(fqn); + auto it_t = extracted.find(fqn); + if (it_name == fqn_to_internal.end() || it_t == extracted.end()) { + continue; + } + auto* t = reinterpret_cast(it_t->second); + auto device_res = tensor_cuda_device_index(*t); + ET_CHECK_OK_OR_RETURN_ERROR(device_res.error()); + const int device = device_res.get(); + + Desc d; + d.internal_name = it_name->second; + d.sizes.assign(t->sizes().begin(), t->sizes().end()); + d.strides.assign(t->strides().begin(), t->strides().end()); + d.dtype = t->dtype(); + d.device = slimc10::Device(slimc10::DeviceType::CUDA, device); + d.nbytes = t->nbytes(); + table.emplace(fqn, std::move(d)); + c.discovered_fqns.insert(fqn); + + if (c.template_ptr.find(fqn) == c.template_ptr.end()) { + CudaDeviceGuard guard; + ET_CHECK_OK_OR_RETURN_ERROR(guard.set(device)); + + void* tpl = nullptr; + if (cudaMalloc(&tpl, t->nbytes()) != cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMalloc template '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy( + tpl, t->data_ptr(), t->nbytes(), cudaMemcpyDeviceToDevice) != + cudaSuccess) { + ET_LOG(Error, "mutable_state: cudaMemcpy template '%s'", fqn.c_str()); + cuda_free_on_pointer_device(tpl, /*synchronize=*/false); + return Error::Internal; + } + c.template_ptr[fqn] = tpl; + c.template_nbytes[fqn] = t->nbytes(); + c.template_device[fqn] = device; + c.total_bytes += static_cast(t->nbytes()); + } + } + return Error::Ok; +} + +// Allocates any missing per-FQN session buffers from the captured templates. 
+Error ensure_session_buffers(Context& c, int token) { + auto& buf = c.session_buf[token]; + for (const auto& kv : c.template_ptr) { + const std::string& fqn = kv.first; + if (buf.find(fqn) != buf.end()) { + continue; + } + void* tpl = kv.second; + size_t nbytes = c.template_nbytes[fqn]; + auto device_it = c.template_device.find(fqn); + if (device_it == c.template_device.end()) { + ET_LOG(Error, "mutable_state: no template device for '%s'", fqn.c_str()); + return Error::Internal; + } + CudaDeviceGuard guard; + ET_CHECK_OK_OR_RETURN_ERROR(guard.set(device_it->second)); + + void* p = nullptr; + if (cudaMalloc(&p, nbytes) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMalloc session buffer '%s'", fqn.c_str()); + return Error::Internal; + } + if (cudaMemcpy(p, tpl, nbytes, cudaMemcpyDeviceToDevice) != cudaSuccess) { + ET_LOG( + Error, "mutable_state: cudaMemcpy session buffer '%s'", fqn.c_str()); + cuda_free_on_pointer_device(p, /*synchronize=*/false); + return Error::Internal; + } + buf[fqn] = p; + } + return Error::Ok; +} + +Error ensure_bound(Context& c, CudaDelegateHandle* h, int token) { + if (c.bound[h].find(token) != c.bound[h].end()) { + return Error::Ok; + } + Bound b; + auto& buf = c.session_buf[token]; + for (const auto& fd : c.desc[h]) { + const std::string& fqn = fd.first; + const Desc& d = fd.second; + auto buf_it = buf.find(fqn); + if (buf_it == buf.end() || buf_it->second == nullptr) { + ET_LOG(Error, "mutable_state: no session buffer for '%s'", fqn.c_str()); + return Error::Internal; + } + auto template_it = c.template_nbytes.find(fqn); + if (template_it == c.template_nbytes.end() || + d.nbytes > template_it->second) { + ET_LOG( + Error, + "mutable_state: descriptor '%s' (%zu B) exceeds shared template " + "buffer (%zu B)", + fqn.c_str(), + d.nbytes, + template_it == c.template_nbytes.end() ? 
0 : template_it->second); + return Error::Internal; + } + void* ptr = buf_it->second; + auto st = std::make_unique(from_blob( + ptr, + ::executorch::runtime::makeArrayRef(d.sizes.data(), d.sizes.size()), + ::executorch::runtime::makeArrayRef(d.strides.data(), d.strides.size()), + d.dtype, + d.device)); + aoti::AOTInductorConstantMapEntry entry; + entry.name = d.internal_name.c_str(); + entry.handle = reinterpret_cast(st.get()); + b.pairs.push_back(entry); + b.tensors.push_back(std::move(st)); + } + c.bound[h].emplace(token, std::move(b)); + return Error::Ok; +} + +void free_session_buffers(Context& c, int token) { + auto it = c.session_buf.find(token); + if (it != c.session_buf.end()) { + for (auto& kv : it->second) { + if (kv.second) { + cuda_free_on_pointer_device(kv.second, /*synchronize=*/true); + } + } + c.session_buf.erase(it); + } + for (auto& hb : c.bound) { + hb.second.erase(token); + } + c.sessions.erase(token); +} + +} // namespace + +namespace detail { + +MutableStateContext mutable_state_create_context() { + auto& m = mgr(); + std::lock_guard g(m.mu); + MutableStateContext id = m.next_ctx++; + m.contexts.emplace(id, Context{}); + return id; +} + +void mutable_state_destroy_context(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + for (auto& kv : c.template_ptr) { + if (kv.second) { + cuda_free_on_pointer_device(kv.second, /*synchronize=*/true); + } + } + for (auto& sb : c.session_buf) { + for (auto& kv : sb.second) { + if (kv.second) { + cuda_free_on_pointer_device(kv.second, /*synchronize=*/true); + } + } + } + for (auto hit = m.handle_ctx.begin(); hit != m.handle_ctx.end();) { + hit = (hit->second == ctx) ? 
m.handle_ctx.erase(hit) : std::next(hit); + } + m.contexts.erase(it); +} + +void mutable_state_begin_load(MutableStateContext ctx) { + if (tl_loading_ctx != kInvalidMutableContext) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto active = m.contexts.find(tl_loading_ctx); + if (active != m.contexts.end()) { + active->second.build_error = Error::InvalidState; + } + auto nested = m.contexts.find(ctx); + if (nested != m.contexts.end()) { + nested->second.build_error = Error::InvalidState; + } + ET_LOG(Error, "mutable_state: nested load scopes are not supported"); + tl_loading_ctx = kInvalidMutableContext; + return; + } + tl_loading_ctx = ctx; +} + +void mutable_state_end_load() { + tl_loading_ctx = kInvalidMutableContext; +} + +void mutable_state_set_active(MutableStateContext ctx, int token) { + tl_active_ctx = ctx; + tl_active_token = token; +} + +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + if (c.handles_associated || !c.sessions.empty()) { + ET_LOG( + Error, + "mutable_state: mutable-buffer FQNs must be registered before load"); + c.build_error = Error::InvalidState; + return; + } + c.fqns = fqns; + c.fqn_set.clear(); + c.fqn_set.insert(fqns.begin(), fqns.end()); +} + +bool mutable_state_available(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it != m.contexts.end() && it->second.build_error == Error::Ok && + it->second.symbols_available; +} + +int64_t mutable_state_bytes_per_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + return it == m.contexts.end() ? 
0 : it->second.total_bytes; +} + +Error mutable_state_validate_coverage(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + return validate_coverage_locked(it->second); +} + +Result mutable_state_create_session(MutableStateContext ctx) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return Error::InvalidArgument; + } + Context& c = it->second; + ET_CHECK_OK_OR_RETURN_ERROR(validate_coverage_locked(c)); + int token = c.next_token++; + c.sessions.insert(token); + return token; +} + +void mutable_state_destroy_session(MutableStateContext ctx, int token) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + free_session_buffers(it->second, token); +} + +} // namespace detail + +void mutable_state_note_handle(CudaDelegateHandle* handle) { + MutableStateContext ctx = tl_loading_ctx; + if (ctx == kInvalidMutableContext) { + return; + } + auto& m = mgr(); + std::lock_guard g(m.mu); + auto it = m.contexts.find(ctx); + if (it == m.contexts.end()) { + return; + } + Context& c = it->second; + c.handles_associated = true; + m.handle_ctx[handle] = ctx; + bool ok = handle_has_symbols(handle); + c.symbols_available = c.symbols_checked ? 
(c.symbols_available && ok) : ok; + c.symbols_checked = true; + if (ok && !c.fqns.empty() && c.desc.find(handle) == c.desc.end()) { + Error e = build_descriptors(c, handle); + if (e != Error::Ok) { + c.build_error = e; + } + } +} + +void mutable_state_forget_handle(CudaDelegateHandle* handle) { + auto& m = mgr(); + std::lock_guard g(m.mu); + auto hit = m.handle_ctx.find(handle); + if (hit == m.handle_ctx.end()) { + return; + } + auto cit = m.contexts.find(hit->second); + if (cit != m.contexts.end()) { + cit->second.desc.erase(handle); + cit->second.bound.erase(handle); + cit->second.rebound_handles.erase(handle); + } + m.handle_ctx.erase(hit); +} + +Error mutable_state_rebind_for_execute(CudaDelegateHandle* handle) { + auto& m = mgr(); + std::lock_guard g(m.mu); + + auto hit = m.handle_ctx.find(handle); + if (tl_active_token == kNoMutableSession) { + if (hit == m.handle_ctx.end()) { + return Error::Ok; + } + auto cit = m.contexts.find(hit->second); + if (cit != m.contexts.end() && + (!cit->second.sessions.empty() || + cit->second.rebound_handles.find(handle) != + cit->second.rebound_handles.end())) { + ET_LOG( + Error, "mutable_state: active session is required for this handle"); + return Error::InvalidState; + } + return Error::Ok; + } + if (hit == m.handle_ctx.end()) { + ET_LOG( + Error, + "mutable_state: active session set but handle has no context (load " + "scope missed?)"); + return Error::Internal; + } + MutableStateContext ctx = hit->second; + if (ctx != tl_active_ctx) { + ET_LOG( + Error, + "mutable_state: active context mismatch (caller set a different context " + "active than the one executing)"); + return Error::Internal; + } + auto cit = m.contexts.find(ctx); + if (cit == m.contexts.end()) { + return Error::Internal; + } + Context& c = cit->second; + if (c.build_error != Error::Ok) { + return c.build_error; + } + if (!c.symbols_available) { + ET_LOG( + Error, "mutable_state: active session set but rebinding unavailable"); + return Error::NotSupported; 
+ } + const int token = tl_active_token; + if (c.sessions.find(token) == c.sessions.end()) { + ET_LOG(Error, "mutable_state: active session token was not created"); + return Error::InvalidArgument; + } + if (handle->cuda_graph_state.phase != CudaGraphPhase::Disabled) { + ET_LOG( + Error, + "mutable_state: per-session rebinding is not supported with CUDA graph"); + return Error::NotSupported; + } + if (c.desc.find(handle) == c.desc.end()) { + ET_LOG( + Error, + "mutable_state: no descriptors for handle (note_handle missed?)"); + return Error::Internal; + } + ET_CHECK_OK_OR_RETURN_ERROR(ensure_session_buffers(c, token)); + ET_CHECK_OK_OR_RETURN_ERROR(ensure_bound(c, handle, token)); + + const Bound& b = c.bound[handle][token]; + if (b.pairs.empty()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR( + handle->update_user_managed_constant_buffer_pairs( + handle->container_handle, + b.pairs.data(), + b.pairs.size(), + /*use_inactive=*/false, + /*validate_full_update=*/false), + "mutable_state: update_user_managed_constant_buffer_pairs failed"); + c.rebound_handles.insert(handle); + return Error::Ok; +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/cuda_mutable_state.h b/backends/cuda/runtime/cuda_mutable_state.h new file mode 100644 index 00000000000..07b9d3c898c --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +// CUDA-private support for running one loaded CUDA program with multiple +// isolated instances of its mutable buffers. Callers register mutable-buffer +// FQNs, create sessions, and execute with one active session selected. 
+ +namespace executorch { +namespace backends { +namespace cuda { + +struct CudaDelegateHandle; + +// Opaque per-loaded-program context id (0 = invalid). +using MutableStateContext = int; +constexpr MutableStateContext kInvalidMutableContext = 0; + +// Sentinel for execution without per-session rebinding. +constexpr int kNoMutableSession = -1; + +// Implementation entry points. Callers should use MutableStateContextOwner. +namespace detail { + +MutableStateContext mutable_state_create_context(); +void mutable_state_destroy_context(MutableStateContext ctx); +void mutable_state_begin_load(MutableStateContext ctx); +void mutable_state_end_load(); +void mutable_state_register_fqns( + MutableStateContext ctx, + const std::vector& fqns); +bool mutable_state_available(MutableStateContext ctx); +int64_t mutable_state_bytes_per_session(MutableStateContext ctx); +::executorch::runtime::Error mutable_state_validate_coverage( + MutableStateContext ctx); +::executorch::runtime::Result mutable_state_create_session( + MutableStateContext ctx); +void mutable_state_destroy_session(MutableStateContext ctx, int token); +void mutable_state_set_active(MutableStateContext ctx, int token); + +} // namespace detail + +// Caller-facing owner for one mutable-state context. 
+class ET_EXPERIMENTAL MutableStateContextOwner final { + class LoadScope final { + public: + explicit LoadScope(MutableStateContext ctx) { + detail::mutable_state_begin_load(ctx); + } + + ~LoadScope() { + detail::mutable_state_end_load(); + } + + LoadScope(const LoadScope&) = delete; + LoadScope& operator=(const LoadScope&) = delete; + }; + + class ActiveSessionScope final { + public: + ActiveSessionScope(MutableStateContext ctx, int token) { + detail::mutable_state_set_active(ctx, token); + } + + ~ActiveSessionScope() { + detail::mutable_state_set_active( + kInvalidMutableContext, kNoMutableSession); + } + + ActiveSessionScope(const ActiveSessionScope&) = delete; + ActiveSessionScope& operator=(const ActiveSessionScope&) = delete; + }; + + public: + MutableStateContextOwner() : ctx_(detail::mutable_state_create_context()) {} + + ~MutableStateContextOwner() { + destroy(); + } + + MutableStateContextOwner(const MutableStateContextOwner&) = delete; + MutableStateContextOwner& operator=(const MutableStateContextOwner&) = delete; + + MutableStateContextOwner(MutableStateContextOwner&& other) noexcept + : ctx_(std::exchange(other.ctx_, kInvalidMutableContext)) {} + + MutableStateContextOwner& operator=( + MutableStateContextOwner&& other) noexcept { + if (this != &other) { + destroy(); + ctx_ = std::exchange(other.ctx_, kInvalidMutableContext); + } + return *this; + } + + MutableStateContext get() const { + return ctx_; + } + + explicit operator bool() const { + return ctx_ != kInvalidMutableContext; + } + + void register_fqns(const std::vector& fqns) const { + detail::mutable_state_register_fqns(ctx_, fqns); + } + + // Associates delegate handles created by `fn` with this context. Register + // FQNs before entering the load scope. + template + auto with_load_scope(Fn&& fn) const -> decltype(std::forward(fn)()) { + LoadScope scope(ctx_); + return std::forward(fn)(); + } + + // Selects this context/session while `fn` executes. 
The caller is responsible + // for serializing execution that touches the same loaded program. + template + auto with_active_session(int token, Fn&& fn) const + -> decltype(std::forward(fn)()) { + ActiveSessionScope scope(ctx_, token); + return std::forward(fn)(); + } + + bool available() const { + return detail::mutable_state_available(ctx_); + } + + int64_t bytes_per_session() const { + return detail::mutable_state_bytes_per_session(ctx_); + } + + ::executorch::runtime::Error validate_coverage() const { + return detail::mutable_state_validate_coverage(ctx_); + } + + ::executorch::runtime::Result create_session() const { + return detail::mutable_state_create_session(ctx_); + } + + void destroy_session(int token) const { + detail::mutable_state_destroy_session(ctx_, token); + } + + private: + void destroy() { + if (ctx_ != kInvalidMutableContext) { + detail::mutable_state_destroy_context(ctx_); + ctx_ = kInvalidMutableContext; + } + } + + MutableStateContext ctx_ = kInvalidMutableContext; +}; + +// --- CudaBackend hooks ------------------------------------------------------- + +void mutable_state_note_handle(CudaDelegateHandle* handle); + +void mutable_state_forget_handle(CudaDelegateHandle* handle); + +::executorch::runtime::Error mutable_state_rebind_for_execute( + CudaDelegateHandle* handle); + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/test/test_cuda_mutable_state.cpp b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp new file mode 100644 index 00000000000..c7392e5dd55 --- /dev/null +++ b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp @@ -0,0 +1,827 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cu = ::executorch::backends::cuda; +namespace aoti = ::executorch::backends::aoti; +namespace slim = ::executorch::backends::aoti::slim; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::runtime::Error; + +namespace { + +Error fake_get_num_constants( + aoti::AOTInductorModelContainerHandle, + size_t* num_constants) { + *num_constants = 0; + return Error::Ok; +} + +Error fake_get_constant_name( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_extract_constants_map( + aoti::AOTInductorModelContainerHandle, + aoti::AOTInductorConstantMapHandle, + bool) { + return Error::Ok; +} + +Error fake_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle, + const aoti::AOTInductorConstantMapEntry*, + size_t, + bool, + bool) { + return Error::Ok; +} + +struct FakeContainer { + std::vector internal_names; + std::vector fqns; + std::unordered_map extracted; + size_t update_calls = 0; + size_t last_num_pairs = 0; + std::string last_name; + void* last_bound_data = nullptr; + size_t last_bound_nbytes = 0; + int last_bound_device_index = -1; + std::unordered_map bound_data_by_name; + std::unordered_map bound_device_by_name; +}; + +Error fake_container_get_num_constants( + aoti::AOTInductorModelContainerHandle container, + size_t* num_constants) { + auto* c = reinterpret_cast(container); + *num_constants = c->internal_names.size(); + return Error::Ok; +} + +Error fake_container_get_constant_name( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** name) { + auto* c = reinterpret_cast(container); + *name = + idx < c->internal_names.size() ? 
c->internal_names[idx].c_str() : nullptr; + return Error::Ok; +} + +Error fake_container_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** fqn) { + auto* c = reinterpret_cast(container); + *fqn = idx < c->fqns.size() ? c->fqns[idx].c_str() : nullptr; + return Error::Ok; +} + +Error fake_container_extract_constants_map( + aoti::AOTInductorModelContainerHandle container, + aoti::AOTInductorConstantMapHandle map_handle, + bool) { + auto* c = reinterpret_cast(container); + auto* out = reinterpret_cast< + std::unordered_map*>(map_handle); + *out = c->extracted; + return Error::Ok; +} + +Error fake_container_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle container, + const aoti::AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool, + bool) { + auto* c = reinterpret_cast(container); + c->update_calls++; + c->last_num_pairs = num_pairs; + if (num_pairs > 0) { + c->last_name = pairs[0].name; + auto* t = reinterpret_cast(pairs[0].handle); + c->last_bound_data = t->data_ptr(); + c->last_bound_nbytes = t->nbytes(); + c->last_bound_device_index = t->device().index(); + } + for (size_t i = 0; i < num_pairs; ++i) { + auto* t = reinterpret_cast(pairs[i].handle); + c->bound_data_by_name[pairs[i].name] = t->data_ptr(); + c->bound_device_by_name[pairs[i].name] = t->device().index(); + } + return Error::Ok; +} + +cu::CudaDelegateHandle fake_symbol_handle() { + cu::CudaDelegateHandle handle{}; + handle.get_num_constants = fake_get_num_constants; + handle.get_constant_name = fake_get_constant_name; + handle.get_constant_original_fqn = fake_get_constant_original_fqn; + handle.extract_constants_map = fake_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_update_user_managed_pairs; + return handle; +} + +cu::CudaDelegateHandle fake_container_handle(FakeContainer* container) { + cu::CudaDelegateHandle handle{}; + handle.container_handle = + reinterpret_cast(container); + 
handle.get_num_constants = fake_container_get_num_constants; + handle.get_constant_name = fake_container_get_constant_name; + handle.get_constant_original_fqn = fake_container_get_constant_original_fqn; + handle.extract_constants_map = fake_container_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_container_update_user_managed_pairs; + return handle; +} + +bool cuda_device_available() { + int device_count = 0; + const cudaError_t err = cudaGetDeviceCount(&device_count); + return err == cudaSuccess && device_count > 0; +} + +std::unique_ptr make_device_tensor( + const std::vector& values, + void** device_ptr, + int device_index = 0) { + *device_ptr = nullptr; + cudaError_t err = cudaMalloc(device_ptr, values.size() * sizeof(float)); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMalloc failed: " << cudaGetErrorString(err); + return nullptr; + } + err = cudaMemcpy( + *device_ptr, + values.data(), + values.size() * sizeof(float), + cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMemcpy failed: " << cudaGetErrorString(err); + cudaFree(*device_ptr); + *device_ptr = nullptr; + return nullptr; + } + return std::make_unique(slim::from_blob( + *device_ptr, + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CUDA, device_index))); +} + +std::unique_ptr make_cpu_tensor(std::vector& values) { + return std::make_unique(slim::from_blob( + values.data(), + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CPU, 0))); +} + +} // namespace + +TEST(CudaMutableStateTest, FallClosedDefaults) { + const cu::MutableStateContext bad = 999999; + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + + EXPECT_GT(c2.get(), c1.get()); + EXPECT_TRUE(c1); + EXPECT_FALSE(c1.available()); + EXPECT_EQ(c1.bytes_per_session(), 0); + EXPECT_EQ(cu::detail::mutable_state_bytes_per_session(bad), 0); + EXPECT_EQ( + 
cu::detail::mutable_state_validate_coverage(bad), Error::InvalidArgument); + EXPECT_EQ(c1.validate_coverage(), Error::NotSupported); + + c1.register_fqns({"a.b", "c.d"}); + EXPECT_EQ(c1.validate_coverage(), Error::NotSupported); + EXPECT_EQ( + cu::detail::mutable_state_create_session(bad).error(), + Error::InvalidArgument); + EXPECT_EQ(c1.create_session().error(), Error::NotSupported); + + cu::detail::mutable_state_destroy_session(bad, 0); + cu::detail::mutable_state_destroy_context(bad); +} + +TEST(CudaMutableStateTest, ForgetHandleDropsAssociation) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle{}; + + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + c.with_active_session(0, [&] { + EXPECT_EQ( + cu::mutable_state_rebind_for_execute(&handle), Error::NotSupported); + + cu::mutable_state_forget_handle(&handle); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::Internal); + }); +} + +TEST(CudaMutableStateTest, CreateSessionRejectsEmptyFqns) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidState); + EXPECT_EQ(c.validate_coverage(), Error::InvalidState); + EXPECT_FALSE(c.available()); +} + +TEST(CudaMutableStateTest, CreateSessionValidatesCoverageBeforeIssuingToken) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c.register_fqns({"missing.state"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidProgram); + EXPECT_EQ(c.validate_coverage(), Error::InvalidProgram); + EXPECT_FALSE(c.available()); +} + +TEST(CudaMutableStateTest, RegisterFqnsAfterLoadFailsClosed) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + 
c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + c.register_fqns({"late.state"}); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidState); + EXPECT_EQ(c.create_session().error(), Error::InvalidState); +} + +TEST(CudaMutableStateTest, NestedBeginLoadFailsClosed) { + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c1.with_load_scope([&] { + c2.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + }); + + EXPECT_EQ(c1.validate_coverage(), Error::InvalidState); + EXPECT_EQ(c2.validate_coverage(), Error::InvalidState); + EXPECT_FALSE(c1.available()); + EXPECT_FALSE(c2.available()); + EXPECT_EQ(c1.create_session().error(), Error::InvalidState); + EXPECT_EQ(c2.create_session().error(), Error::InvalidState); +} + +TEST(CudaMutableStateTest, OwnerLoadScopeClearsThreadLocalLoadState) { + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + cu::CudaDelegateHandle h1 = fake_symbol_handle(); + cu::CudaDelegateHandle h2 = fake_symbol_handle(); + + c1.with_load_scope([&] { cu::mutable_state_note_handle(&h1); }); + c2.with_load_scope([&] { cu::mutable_state_note_handle(&h2); }); + + EXPECT_TRUE(c1.available()); + EXPECT_TRUE(c2.available()); +} + +TEST(CudaMutableStateTest, RebindRejectsCudaGraphHandle) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = make_device_tensor({1.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_state"}; + container.fqns = {"model.state"}; + container.extracted["model.state"] = + reinterpret_cast(source_tensor.get()); + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + c.register_fqns({"model.state"}); + 
c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + + handle.cuda_graph_state.phase = cu::CudaGraphPhase::Warmup; + EXPECT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::NotSupported); + + c.destroy_session(token.get()); + cudaFree(source_ptr); +} + +TEST(CudaMutableStateTest, CapturesClonesAndRebindsDeviceBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = + make_device_tensor({1.0f, 2.0f, 3.0f, 4.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_state"}; + container.fqns = {"model.state"}; + container.extracted["model.state"] = + reinterpret_cast(source_tensor.get()); + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.bytes_per_session(), 4 * sizeof(float)); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + EXPECT_EQ( + c.with_active_session( + 123, [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::InvalidArgument); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::InvalidState); + + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + + EXPECT_EQ(container.update_calls, 1u); + EXPECT_EQ(container.last_num_pairs, 1u); + EXPECT_EQ(container.last_name, "internal_state"); + ASSERT_NE(container.last_bound_data, nullptr); + EXPECT_NE(container.last_bound_data, 
source_ptr); + EXPECT_EQ(container.last_bound_nbytes, 4 * sizeof(float)); + + std::vector cloned(4); + EXPECT_EQ( + cudaMemcpy( + cloned.data(), + container.last_bound_data, + cloned.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(cloned, (std::vector{1.0f, 2.0f, 3.0f, 4.0f})); + + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::InvalidState); + + EXPECT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + + c.destroy_session(token.get()); + cudaFree(source_ptr); +} + +TEST(CudaMutableStateTest, SharedFqnAcrossHandlesUsesSameSessionBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* prefill_ptr = nullptr; + void* decode_ptr = nullptr; + auto prefill_tensor = make_device_tensor({1.0f, 2.0f}, &prefill_ptr); + auto decode_tensor = make_device_tensor({9.0f, 8.0f}, &decode_ptr); + ASSERT_NE(prefill_tensor, nullptr); + ASSERT_NE(decode_tensor, nullptr); + ASSERT_NE(prefill_ptr, nullptr); + ASSERT_NE(decode_ptr, nullptr); + + FakeContainer prefill_container; + prefill_container.internal_names = {"prefill_internal_kv"}; + prefill_container.fqns = {"model.kv"}; + prefill_container.extracted["model.kv"] = + reinterpret_cast(prefill_tensor.get()); + cu::CudaDelegateHandle prefill_handle = + fake_container_handle(&prefill_container); + + FakeContainer decode_container; + decode_container.internal_names = {"decode_internal_kv"}; + decode_container.fqns = {"model.kv"}; + decode_container.extracted["model.kv"] = + reinterpret_cast(decode_tensor.get()); + cu::CudaDelegateHandle decode_handle = + fake_container_handle(&decode_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&prefill_handle); + cu::mutable_state_note_handle(&decode_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + 
+ auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&prefill_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&decode_handle); + }), + Error::Ok); + + ASSERT_NE(prefill_container.last_bound_data, nullptr); + ASSERT_NE(decode_container.last_bound_data, nullptr); + EXPECT_EQ(prefill_container.last_name, "prefill_internal_kv"); + EXPECT_EQ(decode_container.last_name, "decode_internal_kv"); + EXPECT_EQ( + prefill_container.last_bound_data, decode_container.last_bound_data); + EXPECT_NE(prefill_container.last_bound_data, prefill_ptr); + EXPECT_NE(decode_container.last_bound_data, decode_ptr); + + c.destroy_session(token.get()); + cudaFree(prefill_ptr); + cudaFree(decode_ptr); +} + +TEST(CudaMutableStateTest, SessionsStayIsolatedForSameHandleAndFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = make_device_tensor({0.0f, 0.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_kv"}; + container.fqns = {"model.kv"}; + container.extracted["model.kv"] = + reinterpret_cast(source_tensor.get()); + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto session_a = c.create_session(); + auto session_b = c.create_session(); + ASSERT_TRUE(session_a.ok()); + ASSERT_TRUE(session_b.ok()); + + void* a_ptr = nullptr; + const std::vector a_values = {1.0f, 2.0f}; + ASSERT_EQ( + c.with_active_session( + session_a.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + 
Error::Ok); + a_ptr = container.bound_data_by_name["internal_kv"]; + ASSERT_NE(a_ptr, nullptr); + ASSERT_EQ( + cudaMemcpy( + a_ptr, + a_values.data(), + a_values.size() * sizeof(float), + cudaMemcpyHostToDevice), + cudaSuccess); + + void* b_ptr = nullptr; + const std::vector b_values = {9.0f, 8.0f}; + ASSERT_EQ( + c.with_active_session( + session_b.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + b_ptr = container.bound_data_by_name["internal_kv"]; + ASSERT_NE(b_ptr, nullptr); + EXPECT_NE(a_ptr, b_ptr); + ASSERT_EQ( + cudaMemcpy( + b_ptr, + b_values.data(), + b_values.size() * sizeof(float), + cudaMemcpyHostToDevice), + cudaSuccess); + + ASSERT_EQ( + c.with_active_session( + session_a.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + EXPECT_EQ(container.bound_data_by_name["internal_kv"], a_ptr); + + std::vector read_a(2); + ASSERT_EQ( + cudaMemcpy( + read_a.data(), + a_ptr, + read_a.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(read_a, a_values); + + ASSERT_EQ( + c.with_active_session( + session_b.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + EXPECT_EQ(container.bound_data_by_name["internal_kv"], b_ptr); + + std::vector read_b(2); + ASSERT_EQ( + cudaMemcpy( + read_b.data(), + b_ptr, + read_b.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(read_b, b_values); + + c.destroy_session(session_a.get()); + c.destroy_session(session_b.get()); + cudaFree(source_ptr); +} + +TEST(CudaMutableStateTest, EmptyInternalNameIsSkipped) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* skipped_ptr = nullptr; + void* valid_ptr = nullptr; + auto skipped_tensor = make_device_tensor({9.0f, 8.0f}, &skipped_ptr); + auto valid_tensor = make_device_tensor({1.0f, 2.0f}, &valid_ptr); + ASSERT_NE(skipped_tensor, nullptr); + ASSERT_NE(valid_tensor, nullptr); + 
ASSERT_NE(skipped_ptr, nullptr); + ASSERT_NE(valid_ptr, nullptr); + + FakeContainer skipped_container; + skipped_container.internal_names = {""}; + skipped_container.fqns = {"model.kv"}; + skipped_container.extracted["model.kv"] = + reinterpret_cast(skipped_tensor.get()); + cu::CudaDelegateHandle skipped_handle = + fake_container_handle(&skipped_container); + + FakeContainer valid_container; + valid_container.internal_names = {"valid_internal_kv"}; + valid_container.fqns = {"model.kv"}; + valid_container.extracted["model.kv"] = + reinterpret_cast(valid_tensor.get()); + cu::CudaDelegateHandle valid_handle = fake_container_handle(&valid_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&skipped_handle); + cu::mutable_state_note_handle(&valid_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&skipped_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&valid_handle); + }), + Error::Ok); + EXPECT_EQ(skipped_container.update_calls, 0u); + + EXPECT_EQ(valid_container.update_calls, 1u); + EXPECT_EQ(valid_container.last_name, "valid_internal_kv"); + ASSERT_NE(valid_container.last_bound_data, nullptr); + EXPECT_NE(valid_container.last_bound_data, valid_ptr); + EXPECT_NE(valid_container.last_bound_data, skipped_ptr); + + c.destroy_session(token.get()); + cudaFree(skipped_ptr); + cudaFree(valid_ptr); +} + +TEST( + CudaMutableStateTest, + ValidateCoverageRejectsLargerDescriptorForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* small_ptr = nullptr; + void* large_ptr = nullptr; + auto small_tensor = make_device_tensor({1.0f}, &small_ptr); + auto large_tensor = 
make_device_tensor({1.0f, 2.0f}, &large_ptr); + ASSERT_NE(small_tensor, nullptr); + ASSERT_NE(large_tensor, nullptr); + ASSERT_NE(small_ptr, nullptr); + ASSERT_NE(large_ptr, nullptr); + + FakeContainer small_container; + small_container.internal_names = {"small_internal"}; + small_container.fqns = {"model.state"}; + small_container.extracted["model.state"] = + reinterpret_cast(small_tensor.get()); + cu::CudaDelegateHandle small_handle = fake_container_handle(&small_container); + + FakeContainer large_container; + large_container.internal_names = {"large_internal"}; + large_container.fqns = {"model.state"}; + large_container.extracted["model.state"] = + reinterpret_cast(large_tensor.get()); + cu::CudaDelegateHandle large_handle = fake_container_handle(&large_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&small_handle); + cu::mutable_state_note_handle(&large_handle); + }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidProgram); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidProgram); + EXPECT_EQ(large_container.update_calls, 0u); + + cudaFree(small_ptr); + cudaFree(large_ptr); +} + +TEST( + CudaMutableStateTest, + ValidateCoverageNormalizesUnspecifiedCudaDeviceForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + int current_device = 0; + ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess); + + void* unspecified_ptr = nullptr; + void* explicit_ptr = nullptr; + auto unspecified_tensor = + make_device_tensor({1.0f}, &unspecified_ptr, /*device_index=*/-1); + auto explicit_tensor = + make_device_tensor({2.0f}, &explicit_ptr, current_device); + ASSERT_NE(unspecified_tensor, nullptr); + ASSERT_NE(explicit_tensor, nullptr); + ASSERT_NE(unspecified_ptr, nullptr); + ASSERT_NE(explicit_ptr, nullptr); + + FakeContainer unspecified_container; +
unspecified_container.internal_names = {"unspecified_internal"}; + unspecified_container.fqns = {"model.state"}; + unspecified_container.extracted["model.state"] = + reinterpret_cast(unspecified_tensor.get()); + cu::CudaDelegateHandle unspecified_handle = + fake_container_handle(&unspecified_container); + + FakeContainer explicit_container; + explicit_container.internal_names = {"explicit_internal"}; + explicit_container.fqns = {"model.state"}; + explicit_container.extracted["model.state"] = + reinterpret_cast(explicit_tensor.get()); + cu::CudaDelegateHandle explicit_handle = + fake_container_handle(&explicit_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&unspecified_handle); + cu::mutable_state_note_handle(&explicit_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&unspecified_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&explicit_handle); + }), + Error::Ok); + + EXPECT_EQ(unspecified_container.last_bound_device_index, current_device); + EXPECT_EQ(explicit_container.last_bound_device_index, current_device); + EXPECT_EQ( + unspecified_container.bound_device_by_name["unspecified_internal"], + current_device); + EXPECT_EQ( + explicit_container.bound_device_by_name["explicit_internal"], + current_device); + + c.destroy_session(token.get()); + cudaFree(unspecified_ptr); + cudaFree(explicit_ptr); +} + +TEST(CudaMutableStateTest, BuildRejectsNonCudaDescriptorForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* cuda_ptr = nullptr; + auto cuda_tensor = make_device_tensor({1.0f}, &cuda_ptr); + ASSERT_NE(cuda_tensor, nullptr); + ASSERT_NE(cuda_ptr, nullptr); + + 
std::vector cpu_values = {1.0f}; + auto cpu_tensor = make_cpu_tensor(cpu_values); + ASSERT_NE(cpu_tensor, nullptr); + + FakeContainer cuda_container; + cuda_container.internal_names = {"cuda_internal"}; + cuda_container.fqns = {"model.state"}; + cuda_container.extracted["model.state"] = + reinterpret_cast(cuda_tensor.get()); + cu::CudaDelegateHandle cuda_handle = fake_container_handle(&cuda_container); + + FakeContainer cpu_container; + cpu_container.internal_names = {"cpu_internal"}; + cpu_container.fqns = {"model.state"}; + cpu_container.extracted["model.state"] = + reinterpret_cast(cpu_tensor.get()); + cu::CudaDelegateHandle cpu_handle = fake_container_handle(&cpu_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&cuda_handle); + cu::mutable_state_note_handle(&cpu_handle); + }); + + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidArgument); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidArgument); + EXPECT_EQ(cpu_container.update_calls, 0u); + + cudaFree(cuda_ptr); +}