diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 0ce48d85e92..4668e48b91e 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -184,7 +184,9 @@ install( ) # CUDA backend implementation -set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp + runtime/cuda_mutable_state.cpp +) if(_cuda_is_msvc_toolchain) # MSVC links aoti_cuda_backend into portable_lib without relying on C++ # symbols exported from aoti_cuda_shims.dll. @@ -236,3 +238,13 @@ install( EXPORT ExecuTorchTargets DESTINATION lib ) + +if(BUILD_TESTING) + include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + + et_cxx_test( + test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp + EXTRA_LIBS aoti_cuda_backend + ) + target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1) +endif() diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index f62780b29c2..1cdd430a020 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -1,4 +1,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") oncall("executorch") @@ -105,9 +107,11 @@ runtime.cxx_library( name = "cuda_backend", srcs = [ "cuda_backend.cpp", + "cuda_mutable_state.cpp", ], headers = [ "cuda_delegate_handle.h", + "cuda_mutable_state.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, @@ -135,3 +139,26 @@ runtime.cxx_library( ("cuda", None, "cuda-lazy"), ], ) + +cpp_unittest( + name = "test_cuda_mutable_state", + srcs = [ + "test/test_cuda_mutable_state.cpp", + ], + deps = [ + ":cuda_backend", + "//executorch/backends/aoti:aoti_common_slim", + 
"//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/factory:from_blob", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + preprocessor_flags = ["-DCUDA_AVAILABLE=1"], + keep_gpu_sections = True, + remote_execution = re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), +) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index b0a06c8e8a0..f7d095540ad 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -44,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -436,6 +437,8 @@ class ET_EXPERIMENTAL CudaBackend final kCudaGraphWarmupSteps); } + mutable_state_note_handle(handle); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -539,6 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final } } + ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle)); + // --------------------------------------------------------------- // CUDA graph REPLAY path — skip all tensor setup and just replay // --------------------------------------------------------------- @@ -826,6 +831,8 @@ class ET_EXPERIMENTAL CudaBackend final } cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_; + mutable_state_forget_handle(handle); + // The CUDA stream is managed by shared_ptr in the handle. // It will be automatically destroyed when the last handle using it // is destroyed. Just reset our reference. @@ -899,11 +906,12 @@ class ET_EXPERIMENTAL CudaBackend final // * Constants are assumed to be IMMUTABLE (parameters or read-only // buffers). The AOTI shim today does not expose a mutability bit // through GetConstantOriginalFQN, so we cannot detect or refuse - // to share mutable buffers (e.g. a per-method KV cache). 
If a - // future model exports the same FQN as a mutable buffer in - // multiple methods, mutations from one method WILL be visible to - // the other through the shared GPU memory. Callers that need - // per-method mutable state must currently use distinct FQNs. + // to share mutable buffers (for example, runtime caches). If a + // model exports the same FQN as a mutable buffer in multiple + // methods, mutations from one method WILL be visible to the other + // through the shared GPU memory. Callers that need isolated mutable + // state for shared FQNs must opt into cuda_mutable_state or use + // distinct FQNs. // TODO: when AOTInductor exposes a constant-type / mutability // query, refuse to share entries that are not PARAMETER or // non-mutable BUFFER. diff --git a/backends/cuda/runtime/cuda_mutable_state.cpp b/backends/cuda/runtime/cuda_mutable_state.cpp new file mode 100644 index 00000000000..3438bd5b453 --- /dev/null +++ b/backends/cuda/runtime/cuda_mutable_state.cpp @@ -0,0 +1,721 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +namespace aoti = ::executorch::backends::aoti; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::backends::aoti::slim::from_blob; +using ::executorch::backends::aoti::slim::SlimTensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +// AOTI internal constant names are per-handle; the exported FQN is the stable +// identity across methods. 
+// Per-handle description of one mutable constant. build_descriptors() fills
+// it from the tensor returned by extract_constants_map, and ensure_bound()
+// uses it to view a session buffer with the same layout.
+struct Desc {
+  // AOTI-internal constant name; becomes the `name` of the rebind pair.
+  std::string internal_name;
+  std::vector sizes;
+  std::vector strides;
+  slimc10::ScalarType dtype{slimc10::ScalarType::Float};
+  slimc10::Device device{slimc10::DeviceType::CUDA, 0};
+  // Byte size reported by the extracted tensor's nbytes().
+  size_t nbytes{0};
+};
+
+// Cached user-managed pairs for a (handle, session).
+struct Bound {
+  // Owns the SlimTensor views created in ensure_bound(); each entry of
+  // `pairs` holds a raw pointer to one of them.
+  std::vector> tensors;
+  // Entries handed to update_user_managed_constant_buffer_pairs. Each `name`
+  // points into the corresponding Desc::internal_name.
+  std::vector pairs;
+};
+
+// Bookkeeping for one context id. Every visible accessor reaches it through
+// Manager::contexts after locking Manager::mu.
+struct Context {
+  // FQNs supplied by mutable_state_register_fqns().
+  std::vector fqns;
+  // Same FQNs as a set. NOTE(review): only written in the code shown here;
+  // no reader is visible.
+  std::unordered_set fqn_set;
+
+  // symbols_available is the AND of handle_has_symbols() over every handle
+  // noted so far; symbols_checked records that at least one was noted.
+  bool symbols_checked{false};
+  bool symbols_available{false};
+  // Set once a handle is noted; register_fqns() is rejected afterwards.
+  bool handles_associated{false};
+
+  // Per-FQN template captured by build_descriptors() from the first handle
+  // that exposes the FQN: device copy, its byte size, and its device index.
+  std::unordered_map template_ptr;
+  std::unordered_map template_nbytes;
+  std::unordered_map template_device;
+  // Sum of template sizes; reported by mutable_state_bytes_per_session().
+  int64_t total_bytes{0};
+
+  // handle -> FQN -> descriptor.
+  std::unordered_map>
+      desc;
+  // Registered FQNs that were found in at least one handle's constants.
+  std::unordered_set discovered_fqns;
+  // First/last recorded failure. Nothing visible resets it to Ok, and both
+  // validation and rebinding return it when it is set.
+  Error build_error{Error::Ok};
+
+  // Tokens of live sessions, and the next token to hand out.
+  std::unordered_set sessions;
+  int next_token{0};
+  // token -> FQN -> device buffer, filled by ensure_session_buffers().
+  std::unordered_map> session_buf;
+  // handle -> token -> cached pairs, filled by ensure_bound().
+  std::unordered_map> bound;
+  // A managed handle must not execute without an active session after sessions
+  // exist or after a previous rebind left session pointers installed.
+  std::unordered_set rebound_handles;
+};
+
+// Registry of all contexts plus the handle -> context association.
+struct Manager {
+  // Locked by the visible entry points before they touch the fields below.
+  std::mutex mu;
+  std::unordered_map contexts;
+  std::unordered_map handle_ctx;
+  // Ids start at 1 because 0 is kInvalidMutableContext.
+  MutableStateContext next_ctx{1};
+};
+
+// Returns the single Manager instance (function-local static).
+Manager& mgr() {
+  static Manager m;
+  return m;
+}
+
+// Load scopes associate handles with a context; active scopes select a session
+// for execute on the current thread.
+// Written by mutable_state_begin_load()/end_load(), read by note_handle().
+thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext;
+// Written by mutable_state_set_active(), read by rebind_for_execute().
+thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext;
+thread_local int tl_active_token = kNoMutableSession;
+
+// True when all five constant-management function pointers on the handle are
+// non-null.
+bool handle_has_symbols(CudaDelegateHandle* h) {
+  return h->get_num_constants && h->get_constant_name &&
+      h->get_constant_original_fqn && h->extract_constants_map &&
+      h->update_user_managed_constant_buffer_pairs;
+}
+
+// Switches the current CUDA device in set() and, if set() actually changed
+// it, switches back in the destructor.
+struct CudaDeviceGuard {
+  int prev_device{0};
+  bool restore{false};
+
+  // Makes `device` current. A negative `device` is treated as "leave the
+  // current device alone" and returns Ok. Returns Internal if querying or
+  // setting the device fails.
+  Error set(int device) {
+    if (device < 0) {
+      return Error::Ok;
+    }
+    cudaError_t err = cudaGetDevice(&prev_device);
+    if (err != cudaSuccess) {
+      ET_LOG(Error, "mutable_state: cudaGetDevice failed");
+      return Error::Internal;
+    }
+    if (prev_device == device) {
+      return Error::Ok;
+    }
+    err = cudaSetDevice(device);
+    if (err != cudaSuccess) {
+      ET_LOG(
+          Error,
+          "mutable_state: cudaSetDevice(%d) failed: %s",
+          device,
+          cudaGetErrorString(err));
+      return Error::Internal;
+    }
+    restore = true;
+    return Error::Ok;
+  }
+
+  // The result of restoring the previous device is not checked.
+  ~CudaDeviceGuard() {
+    if (restore) {
+      cudaSetDevice(prev_device);
+    }
+  }
+};
+
+// Returns the CUDA device index for `t`. Uses the index stored on the tensor
+// when it is non-negative; otherwise asks the CUDA runtime which device owns
+// the data pointer. Fails with InvalidArgument for non-CUDA tensors and with
+// Internal if the pointer query fails.
+Result tensor_cuda_device_index(const SlimTensor& t) {
+  const slimc10::Device device = t.device();
+  ET_CHECK_OR_RETURN_ERROR(
+      device.is_cuda(),
+      InvalidArgument,
+      "mutable_state: mutable buffer template must be on CUDA, got %s",
+      device.str().c_str());
+  if (device.index() >= 0) {
+    return static_cast(device.index());
+  }
+  cudaPointerAttributes attr{};
+  const cudaError_t err = cudaPointerGetAttributes(&attr, t.data_ptr());
+  if (err != cudaSuccess) {
+    // Result intentionally discarded; presumably called to clear the
+    // runtime's recorded error. TODO confirm.
+    cudaGetLastError();
+    ET_LOG(
+        Error,
+        "mutable_state: cudaPointerGetAttributes failed for template pointer");
+    return Error::Internal;
+  }
+  return attr.device;
+}
+
+// Best-effort cudaFree of `ptr`. Tries to make the device that owns the
+// pointer current first, optionally calls cudaDeviceSynchronize() before the
+// free, and only logs on failure: no error is returned to the caller.
+void cuda_free_on_pointer_device(void* ptr, bool synchronize) {
+  if (!ptr) {
+    return;
+  }
+  int device = -1;
+  cudaPointerAttributes attr{};
+  const cudaError_t attr_err = cudaPointerGetAttributes(&attr, ptr);
+  if (attr_err == cudaSuccess) {
+    device = attr.device;
+  } else {
+    cudaGetLastError();
+  }
+
+  CudaDeviceGuard guard;
+  // If switching fails the free is still attempted on the current device.
+  if (device >= 0 && guard.set(device) != Error::Ok) {
+    ET_LOG(
+        Error,
+        "mutable_state: freeing pointer %p without switching to device %d",
+        ptr,
+        device);
+  }
+  if (synchronize) {
+    const cudaError_t sync_err = cudaDeviceSynchronize();
+    if (sync_err != cudaSuccess) {
+      ET_LOG(
+          Error,
+          "mutable_state: cudaDeviceSynchronize before free failed: %s",
+          cudaGetErrorString(sync_err));
+    }
+  }
+  const cudaError_t free_err = cudaFree(ptr);
+  if (free_err != cudaSuccess) {
+    ET_LOG(
+        Error,
+        "mutable_state: cudaFree(%p) failed: %s",
+        ptr,
+        cudaGetErrorString(free_err));
+  }
+}
+
+// Checks every descriptor of every handle in `c`: it must have a captured
+// template, must not be larger than that template, and must agree on dtype
+// and device with the other descriptors recorded for the same FQN. Sizes and
+// strides are not compared. All problems are logged; returns false if any
+// was found.
+bool validate_descriptors(const Context& c) {
+  bool ok = true;
+  // First descriptor encountered per FQN; later ones are compared to it.
+  std::unordered_map first_desc;
+  for (const auto& handle_descs : c.desc) {
+    for (const auto& fd : handle_descs.second) {
+      const std::string& fqn = fd.first;
+      const Desc& d = fd.second;
+      auto template_it = c.template_nbytes.find(fqn);
+      if (template_it == c.template_nbytes.end()) {
+        ET_LOG(
+            Error,
+            "mutable_state: descriptor '%s' has no captured template",
+            fqn.c_str());
+        ok = false;
+        continue;
+      }
+      if (d.nbytes > template_it->second) {
+        ET_LOG(
+            Error,
+            "mutable_state: descriptor '%s' (%zu B) exceeds shared template "
+            "buffer (%zu B)",
+            fqn.c_str(),
+            d.nbytes,
+            template_it->second);
+        ok = false;
+      }
+
+      auto inserted = first_desc.emplace(fqn, &d);
+      if (!inserted.second) {
+        const Desc& base = *inserted.first->second;
+        if (d.dtype != base.dtype || d.device != base.device) {
+          ET_LOG(
+              Error,
+              "mutable_state: descriptor '%s' has incompatible dtype/device "
+              "across loaded methods",
+              fqn.c_str());
+          ok = false;
+        }
+      }
+    }
+  }
+  return ok;
+}
+
+// Verifies that the context is usable for sessions. Both visible callers hold
+// mgr().mu. Returns, in order: the recorded build_error if any; NotSupported
+// if the required symbols are unavailable; InvalidState if no FQNs were
+// registered; InvalidProgram if a registered FQN was not discovered in any
+// handle or validate_descriptors() fails; otherwise Ok.
+// NOTE(review): the InvalidState/InvalidProgram results are stored in
+// build_error, so they persist even if more handles are loaded later.
+// Confirm that is intended.
+Error validate_coverage_locked(Context& c) {
+  if (c.build_error != Error::Ok) {
+    return c.build_error;
+  }
+  if (!c.symbols_available) {
+    return Error::NotSupported;
+  }
+  if (c.fqns.empty()) {
+    ET_LOG(Error, "mutable_state: no mutable-buffer FQNs registered");
+    c.build_error = Error::InvalidState;
+    return Error::InvalidState;
+  }
+
+  bool ok = true;
+  for (const auto& fqn : c.fqns) {
+    if (c.discovered_fqns.find(fqn) == c.discovered_fqns.end()) {
+      ET_LOG(
+          Error,
+          "mutable_state: declared mutable buffer '%s' not found in any loaded "
+          "method's constants (FQN mismatch?)",
+          fqn.c_str());
+      ok = false;
+    }
+  }
+  // Run descriptor validation unconditionally so all problems get logged.
+  ok = validate_descriptors(c) && ok;
+  if (!ok) {
+    c.build_error = Error::InvalidProgram;
+    return Error::InvalidProgram;
+  }
+  return Error::Ok;
+}
+
+// Captures descriptors and initial templates while the container still owns its
+// default mutable buffers.
+//
+// For each registered FQN that handle `h` exposes, records a Desc in
+// c.desc[h]. The first time an FQN is seen in the context, also allocates a
+// device buffer and copies the constant's current contents into it as the
+// template for future sessions. Returns the first error encountered; entries
+// recorded before the failure are left in place.
+Error build_descriptors(Context& c, CudaDelegateHandle* h) {
+  auto container = h->container_handle;
+
+  size_t n = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      h->get_num_constants(container, &n),
+      "mutable_state: get_num_constants failed");
+  // Map exported FQN -> AOTI-internal name for this handle.
+  std::unordered_map fqn_to_internal;
+  for (size_t i = 0; i < n; ++i) {
+    const char* internal = nullptr;
+    const char* fqn = nullptr;
+    ET_CHECK_OK_OR_RETURN_ERROR(
+        h->get_constant_name(container, i, &internal),
+        "mutable_state: get_constant_name failed");
+    ET_CHECK_OK_OR_RETURN_ERROR(
+        h->get_constant_original_fqn(container, i, &fqn),
+        "mutable_state: get_constant_original_fqn failed");
+    // Empty names are method-scoped constants; skip them.
+    if (internal && internal[0] != '\0' && fqn && fqn[0] != '\0') {
+      fqn_to_internal[fqn] = internal;
+    }
+  }
+
+  // Looked up by FQN below, so the shim is expected to key this map by FQN.
+  std::unordered_map extracted;
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      h->extract_constants_map(
+          container,
+          reinterpret_cast(&extracted),
+          /*use_inactive=*/false),
+      "mutable_state: extract_constants_map failed");
+
+  auto& table = c.desc[h];
+  for (const auto& fqn : c.fqns) {
+    auto it_name = fqn_to_internal.find(fqn);
+    auto it_t = extracted.find(fqn);
+    // This handle does not carry the FQN; coverage is checked later across
+    // all handles.
+    if (it_name == fqn_to_internal.end() || it_t == extracted.end()) {
+      continue;
+    }
+    auto* t = reinterpret_cast(it_t->second);
+    auto device_res = tensor_cuda_device_index(*t);
+    ET_CHECK_OK_OR_RETURN_ERROR(device_res.error());
+    const int device = device_res.get();
+
+    Desc d;
+    d.internal_name = it_name->second;
+    d.sizes.assign(t->sizes().begin(), t->sizes().end());
+    d.strides.assign(t->strides().begin(), t->strides().end());
+    d.dtype = t->dtype();
+    d.device = slimc10::Device(slimc10::DeviceType::CUDA, device);
+    d.nbytes = t->nbytes();
+    table.emplace(fqn, std::move(d));
+    c.discovered_fqns.insert(fqn);
+
+    // Only the first handle exposing an FQN contributes the template.
+    if (c.template_ptr.find(fqn) == c.template_ptr.end()) {
+      CudaDeviceGuard guard;
+      ET_CHECK_OK_OR_RETURN_ERROR(guard.set(device));
+
+      void* tpl = nullptr;
+      if (cudaMalloc(&tpl, t->nbytes()) != cudaSuccess) {
+        ET_LOG(Error, "mutable_state: cudaMalloc template '%s'", fqn.c_str());
+        return Error::Internal;
+      }
+      if (cudaMemcpy(
+              tpl, t->data_ptr(), t->nbytes(), cudaMemcpyDeviceToDevice) !=
+          cudaSuccess) {
+        ET_LOG(Error, "mutable_state: cudaMemcpy template '%s'", fqn.c_str());
+        cuda_free_on_pointer_device(tpl, /*synchronize=*/false);
+        return Error::Internal;
+      }
+      c.template_ptr[fqn] = tpl;
+      c.template_nbytes[fqn] = t->nbytes();
+      c.template_device[fqn] = device;
+      c.total_bytes += static_cast(t->nbytes());
+    }
+  }
+  return Error::Ok;
+}
+
+// Allocates any missing per-FQN session buffers from the captured templates.
+// For every template in `c` that session `token` does not yet have a buffer
+// for, allocates a device buffer on the template's device and copies the
+// template into it. Buffers allocated before a failure stay in
+// c.session_buf[token].
+Error ensure_session_buffers(Context& c, int token) {
+  auto& buf = c.session_buf[token];
+  for (const auto& kv : c.template_ptr) {
+    const std::string& fqn = kv.first;
+    if (buf.find(fqn) != buf.end()) {
+      continue;
+    }
+    void* tpl = kv.second;
+    size_t nbytes = c.template_nbytes[fqn];
+    auto device_it = c.template_device.find(fqn);
+    if (device_it == c.template_device.end()) {
+      ET_LOG(Error, "mutable_state: no template device for '%s'", fqn.c_str());
+      return Error::Internal;
+    }
+    CudaDeviceGuard guard;
+    ET_CHECK_OK_OR_RETURN_ERROR(guard.set(device_it->second));
+
+    void* p = nullptr;
+    if (cudaMalloc(&p, nbytes) != cudaSuccess) {
+      ET_LOG(
+          Error, "mutable_state: cudaMalloc session buffer '%s'", fqn.c_str());
+      return Error::Internal;
+    }
+    if (cudaMemcpy(p, tpl, nbytes, cudaMemcpyDeviceToDevice) != cudaSuccess) {
+      ET_LOG(
+          Error, "mutable_state: cudaMemcpy session buffer '%s'", fqn.c_str());
+      cuda_free_on_pointer_device(p, /*synchronize=*/false);
+      return Error::Internal;
+    }
+    buf[fqn] = p;
+  }
+  return Error::Ok;
+}
+
+// Builds, once per (handle, token), the list of name/tensor pairs that point
+// handle `h`'s mutable constants at session `token`'s buffers, and caches it
+// in c.bound[h][token]. Expects ensure_session_buffers() to have run for the
+// token. Returns Internal if a buffer is missing or null, or if a descriptor
+// is larger than its template.
+// NOTE(review): a null session buffer is rejected here, so a zero-byte
+// mutable buffer could not be bound if cudaMalloc(0) yields nullptr.
+// TODO confirm.
+Error ensure_bound(Context& c, CudaDelegateHandle* h, int token) {
+  if (c.bound[h].find(token) != c.bound[h].end()) {
+    return Error::Ok;
+  }
+  Bound b;
+  auto& buf = c.session_buf[token];
+  for (const auto& fd : c.desc[h]) {
+    const std::string& fqn = fd.first;
+    const Desc& d = fd.second;
+    auto buf_it = buf.find(fqn);
+    if (buf_it == buf.end() || buf_it->second == nullptr) {
+      ET_LOG(Error, "mutable_state: no session buffer for '%s'", fqn.c_str());
+      return Error::Internal;
+    }
+    auto template_it = c.template_nbytes.find(fqn);
+    if (template_it == c.template_nbytes.end() ||
+        d.nbytes > template_it->second) {
+      ET_LOG(
+          Error,
+          "mutable_state: descriptor '%s' (%zu B) exceeds shared template "
+          "buffer (%zu B)",
+          fqn.c_str(),
+          d.nbytes,
+          template_it == c.template_nbytes.end() ? 0 : template_it->second);
+      return Error::Internal;
+    }
+    void* ptr = buf_it->second;
+    // View over the session buffer using this handle's own layout.
+    auto st = std::make_unique(from_blob(
+        ptr,
+        ::executorch::runtime::makeArrayRef(d.sizes.data(), d.sizes.size()),
+        ::executorch::runtime::makeArrayRef(d.strides.data(), d.strides.size()),
+        d.dtype,
+        d.device));
+    aoti::AOTInductorConstantMapEntry entry;
+    // Both pointers are borrowed: the name from the Desc in c.desc[h], the
+    // tensor from b.tensors, which is moved into the cache below.
+    entry.name = d.internal_name.c_str();
+    entry.handle = reinterpret_cast(st.get());
+    b.pairs.push_back(entry);
+    b.tensors.push_back(std::move(st));
+  }
+  c.bound[h].emplace(token, std::move(b));
+  return Error::Ok;
+}
+
+// Frees every device buffer of session `token` (synchronizing before each
+// free), drops the cached bindings for that token on all handles, and removes
+// the token from the live-session set. rebound_handles is left untouched.
+void free_session_buffers(Context& c, int token) {
+  auto it = c.session_buf.find(token);
+  if (it != c.session_buf.end()) {
+    for (auto& kv : it->second) {
+      if (kv.second) {
+        cuda_free_on_pointer_device(kv.second, /*synchronize=*/true);
+      }
+    }
+    c.session_buf.erase(it);
+  }
+  for (auto& hb : c.bound) {
+    hb.second.erase(token);
+  }
+  c.sessions.erase(token);
+}
+
+} // namespace
+
+namespace detail {
+
+// Registers an empty Context under a fresh id and returns the id.
+MutableStateContext mutable_state_create_context() {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  MutableStateContext id = m.next_ctx++;
+  m.contexts.emplace(id, Context{});
+  return id;
+}
+
+// Frees all template and session buffers of `ctx`, removes every
+// handle -> ctx association, and erases the context. Unknown ids are ignored.
+// NOTE(review): a handle that was rebound to session buffers keeps whatever
+// pointers were installed in its container, and after this call it no longer
+// appears in handle_ctx, so rebind_for_execute() without an active session
+// returns Ok for it. Confirm owners always destroy such handles first.
+void mutable_state_destroy_context(MutableStateContext ctx) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return;
+  }
+  Context& c = it->second;
+  for (auto& kv : c.template_ptr) {
+    if (kv.second) {
+      cuda_free_on_pointer_device(kv.second, /*synchronize=*/true);
+    }
+  }
+  for (auto& sb : c.session_buf) {
+    for (auto& kv : sb.second) {
+      if (kv.second) {
+        cuda_free_on_pointer_device(kv.second, /*synchronize=*/true);
+      }
+    }
+  }
+  for (auto hit = m.handle_ctx.begin(); hit != m.handle_ctx.end();) {
+    hit = (hit->second == ctx) ? m.handle_ctx.erase(hit) : std::next(hit);
+  }
+  m.contexts.erase(it);
+}
+
+// Marks `ctx` as the context that handles noted on this thread belong to.
+// If a load scope is already open on this thread, both the open context and
+// `ctx` get build_error = InvalidState and the thread's loading context is
+// cleared, so nothing further is associated.
+void mutable_state_begin_load(MutableStateContext ctx) {
+  if (tl_loading_ctx != kInvalidMutableContext) {
+    auto& m = mgr();
+    std::lock_guard g(m.mu);
+    auto active = m.contexts.find(tl_loading_ctx);
+    if (active != m.contexts.end()) {
+      active->second.build_error = Error::InvalidState;
+    }
+    auto nested = m.contexts.find(ctx);
+    if (nested != m.contexts.end()) {
+      nested->second.build_error = Error::InvalidState;
+    }
+    ET_LOG(Error, "mutable_state: nested load scopes are not supported");
+    tl_loading_ctx = kInvalidMutableContext;
+    return;
+  }
+  tl_loading_ctx = ctx;
+}
+
+// Clears this thread's loading context.
+void mutable_state_end_load() {
+  tl_loading_ctx = kInvalidMutableContext;
+}
+
+// Sets this thread's active context and session token. No validation happens
+// here; rebind_for_execute() checks them.
+void mutable_state_set_active(MutableStateContext ctx, int token) {
+  tl_active_ctx = ctx;
+  tl_active_token = token;
+}
+
+// Replaces the list of mutable-buffer FQNs for `ctx`. Must happen before any
+// handle is noted or session created; otherwise the call is rejected and
+// build_error is set to InvalidState. Unknown ids are ignored.
+void mutable_state_register_fqns(
+    MutableStateContext ctx,
+    const std::vector& fqns) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return;
+  }
+  Context& c = it->second;
+  if (c.handles_associated || !c.sessions.empty()) {
+    ET_LOG(
+        Error,
+        "mutable_state: mutable-buffer FQNs must be registered before load");
+    c.build_error = Error::InvalidState;
+    return;
+  }
+  c.fqns = fqns;
+  c.fqn_set.clear();
+  c.fqn_set.insert(fqns.begin(), fqns.end());
+}
+
+// True when `ctx` exists, has no recorded build error, and every noted handle
+// exposed the required symbols.
+bool mutable_state_available(MutableStateContext ctx) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  return it != m.contexts.end() && it->second.build_error == Error::Ok &&
+      it->second.symbols_available;
+}
+
+// Total bytes of captured templates for `ctx`, or 0 for an unknown id.
+int64_t mutable_state_bytes_per_session(MutableStateContext ctx) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  return it == m.contexts.end() ? 0 : it->second.total_bytes;
+}
+
+// Runs validate_coverage_locked() for `ctx`; InvalidArgument for unknown ids.
+Error mutable_state_validate_coverage(MutableStateContext ctx) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return Error::InvalidArgument;
+  }
+  return validate_coverage_locked(it->second);
+}
+
+// Validates coverage and, on success, returns a new session token. No device
+// memory is allocated here; buffers are created on the first rebind that uses
+// the token.
+Result mutable_state_create_session(MutableStateContext ctx) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return Error::InvalidArgument;
+  }
+  Context& c = it->second;
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_coverage_locked(c));
+  int token = c.next_token++;
+  c.sessions.insert(token);
+  return token;
+}
+
+// Releases session `token` of `ctx` via free_session_buffers(). Unknown ids
+// are ignored.
+void mutable_state_destroy_session(MutableStateContext ctx, int token) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return;
+  }
+  free_session_buffers(it->second, token);
+}
+
+} // namespace detail
+
+// Associates `handle` with the context whose load scope is open on this
+// thread; does nothing when no load scope is open or the context is gone.
+// Updates symbol availability and, when FQNs are registered and the handle
+// has the symbols, captures its descriptors. A capture failure is stored in
+// the context's build_error rather than returned.
+void mutable_state_note_handle(CudaDelegateHandle* handle) {
+  MutableStateContext ctx = tl_loading_ctx;
+  if (ctx == kInvalidMutableContext) {
+    return;
+  }
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto it = m.contexts.find(ctx);
+  if (it == m.contexts.end()) {
+    return;
+  }
+  Context& c = it->second;
+  c.handles_associated = true;
+  m.handle_ctx[handle] = ctx;
+  bool ok = handle_has_symbols(handle);
+  // One handle without the symbols makes the whole context unavailable.
+  c.symbols_available = c.symbols_checked ? (c.symbols_available && ok) : ok;
+  c.symbols_checked = true;
+  if (ok && !c.fqns.empty() && c.desc.find(handle) == c.desc.end()) {
+    Error e = build_descriptors(c, handle);
+    if (e != Error::Ok) {
+      c.build_error = e;
+    }
+  }
+}
+
+// Removes everything recorded for `handle`: its descriptors, cached bindings,
+// rebound marker, and its context association. Unknown handles are ignored.
+void mutable_state_forget_handle(CudaDelegateHandle* handle) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+  auto hit = m.handle_ctx.find(handle);
+  if (hit == m.handle_ctx.end()) {
+    return;
+  }
+  auto cit = m.contexts.find(hit->second);
+  if (cit != m.contexts.end()) {
+    cit->second.desc.erase(handle);
+    cit->second.bound.erase(handle);
+    cit->second.rebound_handles.erase(handle);
+  }
+  m.handle_ctx.erase(hit);
+}
+
+// Points `handle`'s mutable constants at the buffers of the session that is
+// active on this thread.
+//
+// Without an active session: Ok for unmanaged handles; InvalidState for a
+// managed handle whose context has live sessions or that was rebound before;
+// Ok otherwise.
+// With an active session: the handle must belong to the active context, the
+// context must be error-free with symbols available, the token must be live,
+// CUDA graph must be disabled for the handle, and descriptors must exist.
+// Session buffers and the pair list are created on demand, then installed
+// through update_user_managed_constant_buffer_pairs. The manager mutex is
+// held for the whole call.
+Error mutable_state_rebind_for_execute(CudaDelegateHandle* handle) {
+  auto& m = mgr();
+  std::lock_guard g(m.mu);
+
+  auto hit = m.handle_ctx.find(handle);
+  if (tl_active_token == kNoMutableSession) {
+    if (hit == m.handle_ctx.end()) {
+      return Error::Ok;
+    }
+    auto cit = m.contexts.find(hit->second);
+    if (cit != m.contexts.end() &&
+        (!cit->second.sessions.empty() ||
+         cit->second.rebound_handles.find(handle) !=
+             cit->second.rebound_handles.end())) {
+      ET_LOG(
+          Error, "mutable_state: active session is required for this handle");
+      return Error::InvalidState;
+    }
+    return Error::Ok;
+  }
+  if (hit == m.handle_ctx.end()) {
+    ET_LOG(
+        Error,
+        "mutable_state: active session set but handle has no context (load "
+        "scope missed?)");
+    return Error::Internal;
+  }
+  MutableStateContext ctx = hit->second;
+  if (ctx != tl_active_ctx) {
+    ET_LOG(
+        Error,
+        "mutable_state: active context mismatch (caller set a different context "
+        "active than the one executing)");
+    return Error::Internal;
+  }
+  auto cit = m.contexts.find(ctx);
+  if (cit == m.contexts.end()) {
+    return Error::Internal;
+  }
+  Context& c = cit->second;
+  if (c.build_error != Error::Ok) {
+    return c.build_error;
+  }
+  if (!c.symbols_available) {
+    ET_LOG(
+        Error, "mutable_state: active session set but rebinding unavailable");
+    return Error::NotSupported;
+  }
+  const int token = tl_active_token;
+  if (c.sessions.find(token) == c.sessions.end()) {
+    ET_LOG(Error, "mutable_state: active session token was not created");
+    return Error::InvalidArgument;
+  }
+  if (handle->cuda_graph_state.phase != CudaGraphPhase::Disabled) {
+    ET_LOG(
+        Error,
+        "mutable_state: per-session rebinding is not supported with CUDA graph");
+    return Error::NotSupported;
+  }
+  if (c.desc.find(handle) == c.desc.end()) {
+    ET_LOG(
+        Error,
+        "mutable_state: no descriptors for handle (note_handle missed?)");
+    return Error::Internal;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(ensure_session_buffers(c, token));
+  ET_CHECK_OK_OR_RETURN_ERROR(ensure_bound(c, handle, token));
+
+  const Bound& b = c.bound[handle][token];
+  // Handle carries none of the registered FQNs: nothing to install.
+  if (b.pairs.empty()) {
+    return Error::Ok;
+  }
+  // The pairs are re-installed on every call, even when the handle is
+  // already bound to this token.
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      handle->update_user_managed_constant_buffer_pairs(
+          handle->container_handle,
+          b.pairs.data(),
+          b.pairs.size(),
+          /*use_inactive=*/false,
+          /*validate_full_update=*/false),
+      "mutable_state: update_user_managed_constant_buffer_pairs failed");
+  c.rebound_handles.insert(handle);
+  return Error::Ok;
+}
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
diff --git a/backends/cuda/runtime/cuda_mutable_state.h b/backends/cuda/runtime/cuda_mutable_state.h
new file mode 100644
index 00000000000..07b9d3c898c
--- /dev/null
+++ b/backends/cuda/runtime/cuda_mutable_state.h
@@ -0,0 +1,182 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+// CUDA-private support for running one loaded CUDA program with multiple
+// isolated instances of its mutable buffers. Callers register mutable-buffer
+// FQNs, create sessions, and execute with one active session selected.
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+struct CudaDelegateHandle;
+
+// Opaque per-loaded-program context id (0 = invalid).
+using MutableStateContext = int;
+constexpr MutableStateContext kInvalidMutableContext = 0;
+
+// Sentinel for execution without per-session rebinding.
+constexpr int kNoMutableSession = -1;
+
+// Implementation entry points. Callers should use MutableStateContextOwner.
+namespace detail {
+
+// Creates an empty context and returns its id.
+MutableStateContext mutable_state_create_context();
+// Frees the context's template and session buffers and forgets its handles.
+// Unknown ids are ignored.
+void mutable_state_destroy_context(MutableStateContext ctx);
+// Opens a load scope on the calling thread: handles noted on this thread are
+// associated with `ctx` until mutable_state_end_load(). Nesting is an error
+// that is recorded on both contexts involved.
+void mutable_state_begin_load(MutableStateContext ctx);
+// Closes the calling thread's load scope.
+void mutable_state_end_load();
+// Sets the mutable-buffer FQNs for `ctx`. Rejected, with the error recorded
+// on the context, once a handle has been associated or a session exists.
+void mutable_state_register_fqns(
+    MutableStateContext ctx,
+    const std::vector& fqns);
+// True when `ctx` exists, has no recorded error, and every associated handle
+// provides the symbols needed for rebinding.
+bool mutable_state_available(MutableStateContext ctx);
+// Total size in bytes of the captured templates; 0 for an unknown id.
+int64_t mutable_state_bytes_per_session(MutableStateContext ctx);
+// Checks that every registered FQN was found in a loaded handle and that the
+// per-handle descriptors are compatible. InvalidArgument for an unknown id.
+::executorch::runtime::Error mutable_state_validate_coverage(
+    MutableStateContext ctx);
+// Validates coverage and returns a new session token. Device buffers for the
+// session are allocated on first use, not here.
+::executorch::runtime::Result mutable_state_create_session(
+    MutableStateContext ctx);
+// Frees the session's device buffers and cached bindings.
+void mutable_state_destroy_session(MutableStateContext ctx, int token);
+// Selects the context/session used by executes on the calling thread. Pass
+// kInvalidMutableContext and kNoMutableSession to clear the selection.
+void mutable_state_set_active(MutableStateContext ctx, int token);
+
+} // namespace detail
+
+// Caller-facing owner for one mutable-state context.
+// Creates a context on construction and destroys it on destruction.
+// Move-only; a moved-from owner holds kInvalidMutableContext and converts to
+// false.
+class ET_EXPERIMENTAL MutableStateContextOwner final {
+  // Opens a load scope for the lifetime of the object.
+  class LoadScope final {
+   public:
+    explicit LoadScope(MutableStateContext ctx) {
+      detail::mutable_state_begin_load(ctx);
+    }
+
+    ~LoadScope() {
+      detail::mutable_state_end_load();
+    }
+
+    LoadScope(const LoadScope&) = delete;
+    LoadScope& operator=(const LoadScope&) = delete;
+  };
+
+  // Selects a context/session for the lifetime of the object.
+  // NOTE(review): the destructor clears the selection instead of restoring
+  // the previous one, so after a nested with_active_session() returns the
+  // outer scope runs with no active session. Confirm nesting is not expected.
+  class ActiveSessionScope final {
+   public:
+    ActiveSessionScope(MutableStateContext ctx, int token) {
+      detail::mutable_state_set_active(ctx, token);
+    }
+
+    ~ActiveSessionScope() {
+      detail::mutable_state_set_active(
+          kInvalidMutableContext, kNoMutableSession);
+    }
+
+    ActiveSessionScope(const ActiveSessionScope&) = delete;
+    ActiveSessionScope& operator=(const ActiveSessionScope&) = delete;
+  };
+
+ public:
+  MutableStateContextOwner() : ctx_(detail::mutable_state_create_context()) {}
+
+  ~MutableStateContextOwner() {
+    destroy();
+  }
+
+  MutableStateContextOwner(const MutableStateContextOwner&) = delete;
+  MutableStateContextOwner& operator=(const MutableStateContextOwner&) = delete;
+
+  MutableStateContextOwner(MutableStateContextOwner&& other) noexcept
+      : ctx_(std::exchange(other.ctx_, kInvalidMutableContext)) {}
+
+  // Destroys the currently owned context before taking over `other`'s.
+  MutableStateContextOwner& operator=(
+      MutableStateContextOwner&& other) noexcept {
+    if (this != &other) {
+      destroy();
+      ctx_ = std::exchange(other.ctx_, kInvalidMutableContext);
+    }
+    return *this;
+  }
+
+  // Raw context id; kInvalidMutableContext after a move.
+  MutableStateContext get() const {
+    return ctx_;
+  }
+
+  explicit operator bool() const {
+    return ctx_ != kInvalidMutableContext;
+  }
+
+  // Declares which constants (by exported FQN) are per-session mutable state.
+  void register_fqns(const std::vector& fqns) const {
+    detail::mutable_state_register_fqns(ctx_, fqns);
+  }
+
+  // Associates delegate handles created by `fn` with this context. Register
+  // FQNs before entering the load scope.
+  template
+  auto with_load_scope(Fn&& fn) const -> decltype(std::forward(fn)()) {
+    LoadScope scope(ctx_);
+    return std::forward(fn)();
+  }
+
+  // Selects this context/session while `fn` executes. The caller is responsible
+  // for serializing execution that touches the same loaded program.
+  template
+  auto with_active_session(int token, Fn&& fn) const
+      -> decltype(std::forward(fn)()) {
+    ActiveSessionScope scope(ctx_, token);
+    return std::forward(fn)();
+  }
+
+  // The members below forward to the detail:: function of the same purpose
+  // for this owner's context.
+  bool available() const {
+    return detail::mutable_state_available(ctx_);
+  }
+
+  int64_t bytes_per_session() const {
+    return detail::mutable_state_bytes_per_session(ctx_);
+  }
+
+  ::executorch::runtime::Error validate_coverage() const {
+    return detail::mutable_state_validate_coverage(ctx_);
+  }
+
+  ::executorch::runtime::Result create_session() const {
+    return detail::mutable_state_create_session(ctx_);
+  }
+
+  void destroy_session(int token) const {
+    detail::mutable_state_destroy_session(ctx_, token);
+  }
+
+ private:
+  // Destroys the owned context, if any, and marks this owner invalid.
+  void destroy() {
+    if (ctx_ != kInvalidMutableContext) {
+      detail::mutable_state_destroy_context(ctx_);
+      ctx_ = kInvalidMutableContext;
+    }
+  }
+
+  MutableStateContext ctx_ = kInvalidMutableContext;
+};
+
+// --- CudaBackend hooks -------------------------------------------------------
+
+// Records `handle` under the context whose load scope is open on the calling
+// thread, if any.
+void mutable_state_note_handle(CudaDelegateHandle* handle);
+
+// Drops all mutable-state bookkeeping for `handle`.
+void mutable_state_forget_handle(CudaDelegateHandle* handle);
+
+// Installs the active session's buffers into `handle` before it executes.
+// Returns an error when the handle requires a session and none is active.
+::executorch::runtime::Error mutable_state_rebind_for_execute(
+    CudaDelegateHandle* handle);
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
diff --git a/backends/cuda/runtime/test/test_cuda_mutable_state.cpp b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp
new file mode 100644
index 00000000000..c7392e5dd55
--- /dev/null
+++ b/backends/cuda/runtime/test/test_cuda_mutable_state.cpp
@@ -0,0 +1,827 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cu = ::executorch::backends::cuda; +namespace aoti = ::executorch::backends::aoti; +namespace slim = ::executorch::backends::aoti::slim; +namespace slimc10 = ::executorch::backends::aoti::slim::c10; +using ::executorch::runtime::Error; + +namespace { + +Error fake_get_num_constants( + aoti::AOTInductorModelContainerHandle, + size_t* num_constants) { + *num_constants = 0; + return Error::Ok; +} + +Error fake_get_constant_name( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle, + size_t, + const char**) { + return Error::Ok; +} + +Error fake_extract_constants_map( + aoti::AOTInductorModelContainerHandle, + aoti::AOTInductorConstantMapHandle, + bool) { + return Error::Ok; +} + +Error fake_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle, + const aoti::AOTInductorConstantMapEntry*, + size_t, + bool, + bool) { + return Error::Ok; +} + +struct FakeContainer { + std::vector internal_names; + std::vector fqns; + std::unordered_map extracted; + size_t update_calls = 0; + size_t last_num_pairs = 0; + std::string last_name; + void* last_bound_data = nullptr; + size_t last_bound_nbytes = 0; + int last_bound_device_index = -1; + std::unordered_map bound_data_by_name; + std::unordered_map bound_device_by_name; +}; + +Error fake_container_get_num_constants( + aoti::AOTInductorModelContainerHandle container, + size_t* num_constants) { + auto* c = reinterpret_cast(container); + *num_constants = c->internal_names.size(); + return Error::Ok; +} + +Error fake_container_get_constant_name( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** name) { + auto* c = reinterpret_cast(container); + *name = + idx < c->internal_names.size() ? 
c->internal_names[idx].c_str() : nullptr; + return Error::Ok; +} +/* Yields fqns[idx], or nullptr when idx is out of range. */ +Error fake_container_get_constant_original_fqn( + aoti::AOTInductorModelContainerHandle container, + size_t idx, + const char** fqn) { + auto* c = reinterpret_cast(container); + *fqn = idx < c->fqns.size() ? c->fqns[idx].c_str() : nullptr; + return Error::Ok; +} +/* Copies the container's extracted map into the map that map_handle is cast to. */ +Error fake_container_extract_constants_map( + aoti::AOTInductorModelContainerHandle container, + aoti::AOTInductorConstantMapHandle map_handle, + bool) { + auto* c = reinterpret_cast(container); + auto* out = reinterpret_cast< + std::unordered_map*>(map_handle); + *out = c->extracted; + return Error::Ok; +} +/* Records the call: counts it, keeps the first pair's name/data/size/device in the last_* fields, and stores every pair's data pointer and device index by name. */ +Error fake_container_update_user_managed_pairs( + aoti::AOTInductorModelContainerHandle container, + const aoti::AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool, + bool) { + auto* c = reinterpret_cast(container); + c->update_calls++; + c->last_num_pairs = num_pairs; + if (num_pairs > 0) { + c->last_name = pairs[0].name; + auto* t = reinterpret_cast(pairs[0].handle); + c->last_bound_data = t->data_ptr(); + c->last_bound_nbytes = t->nbytes(); + c->last_bound_device_index = t->device().index(); + } + for (size_t i = 0; i < num_pairs; ++i) { + auto* t = reinterpret_cast(pairs[i].handle); + c->bound_data_by_name[pairs[i].name] = t->data_ptr(); + c->bound_device_by_name[pairs[i].name] = t->device().index(); + } + return Error::Ok; +} +/* Handle wired to the container-less fake_* stubs; container_handle is left value-initialized. */ +cu::CudaDelegateHandle fake_symbol_handle() { + cu::CudaDelegateHandle handle{}; + handle.get_num_constants = fake_get_num_constants; + handle.get_constant_name = fake_get_constant_name; + handle.get_constant_original_fqn = fake_get_constant_original_fqn; + handle.extract_constants_map = fake_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_update_user_managed_pairs; + return handle; +} +/* Handle wired to the fake_container_* callbacks, with the given FakeContainer as its container_handle. */ +cu::CudaDelegateHandle fake_container_handle(FakeContainer* container) { + cu::CudaDelegateHandle handle{}; + handle.container_handle = + reinterpret_cast(container); +
handle.get_num_constants = fake_container_get_num_constants; + handle.get_constant_name = fake_container_get_constant_name; + handle.get_constant_original_fqn = fake_container_get_constant_original_fqn; + handle.extract_constants_map = fake_container_extract_constants_map; + handle.update_user_managed_constant_buffer_pairs = + fake_container_update_user_managed_pairs; + return handle; +} +/* True only when cudaGetDeviceCount succeeds and reports at least one device. */ +bool cuda_device_available() { + int device_count = 0; + const cudaError_t err = cudaGetDeviceCount(&device_count); + return err == cudaSuccess && device_count > 0; +} +/* Allocates a device buffer, copies values into it and wraps it via slim::from_blob as a 1-D Float CUDA tensor. On a CUDA error it records a test failure, leaves *device_ptr null and returns nullptr. The tests release *device_ptr with cudaFree themselves. */ +std::unique_ptr make_device_tensor( + const std::vector& values, + void** device_ptr, + int device_index = 0) { + *device_ptr = nullptr; + cudaError_t err = cudaMalloc(device_ptr, values.size() * sizeof(float)); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMalloc failed: " << cudaGetErrorString(err); + return nullptr; + } + err = cudaMemcpy( + *device_ptr, + values.data(), + values.size() * sizeof(float), + cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + ADD_FAILURE() << "cudaMemcpy failed: " << cudaGetErrorString(err); + cudaFree(*device_ptr); + *device_ptr = nullptr; + return nullptr; + } + return std::make_unique(slim::from_blob( + *device_ptr, + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CUDA, device_index))); +} +/* Wraps values.data() via slim::from_blob as a 1-D Float CPU tensor. */ +std::unique_ptr make_cpu_tensor(std::vector& values) { + return std::make_unique(slim::from_blob( + values.data(), + {static_cast(values.size())}, + slimc10::ScalarType::Float, + slimc10::Device(slimc10::DeviceType::CPU, 0))); +} + +} // namespace +/* A fresh context with no noted handle reports unavailable / NotSupported, and an unknown context id reports InvalidArgument or 0 bytes. */ +TEST(CudaMutableStateTest, FallClosedDefaults) { + const cu::MutableStateContext bad = 999999; + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + + EXPECT_GT(c2.get(), c1.get()); + EXPECT_TRUE(c1); + EXPECT_FALSE(c1.available()); + EXPECT_EQ(c1.bytes_per_session(), 0); + EXPECT_EQ(cu::detail::mutable_state_bytes_per_session(bad), 0); + EXPECT_EQ( +
cu::detail::mutable_state_validate_coverage(bad), Error::InvalidArgument); + EXPECT_EQ(c1.validate_coverage(), Error::NotSupported); + + c1.register_fqns({"a.b", "c.d"}); + EXPECT_EQ(c1.validate_coverage(), Error::NotSupported); + EXPECT_EQ( + cu::detail::mutable_state_create_session(bad).error(), + Error::InvalidArgument); + EXPECT_EQ(c1.create_session().error(), Error::NotSupported); + + cu::detail::mutable_state_destroy_session(bad, 0); + cu::detail::mutable_state_destroy_context(bad); +} +/* After mutable_state_forget_handle the same rebind call is expected to return Error::Internal rather than Error::NotSupported. */ +TEST(CudaMutableStateTest, ForgetHandleDropsAssociation) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle{}; + + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + c.with_active_session(0, [&] { + EXPECT_EQ( + cu::mutable_state_rebind_for_execute(&handle), Error::NotSupported); + + cu::mutable_state_forget_handle(&handle); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::Internal); + }); +} +/* With a handle noted but no FQNs registered, create_session and validate_coverage are expected to return InvalidState and the context stops reporting available(). */ +TEST(CudaMutableStateTest, CreateSessionRejectsEmptyFqns) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidState); + EXPECT_EQ(c.validate_coverage(), Error::InvalidState); + EXPECT_FALSE(c.available()); +} +/* A registered FQN that the noted handle does not expose (the stub handle reports zero constants) is expected to make create_session fail with InvalidProgram. */ +TEST(CudaMutableStateTest, CreateSessionValidatesCoverageBeforeIssuingToken) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c.register_fqns({"missing.state"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidProgram); + EXPECT_EQ(c.validate_coverage(), Error::InvalidProgram); + EXPECT_FALSE(c.available()); +} +/* Calling register_fqns after a handle was already noted is expected to flip the context to unavailable with InvalidState. */ +TEST(CudaMutableStateTest, RegisterFqnsAfterLoadFailsClosed) { + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + +
c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + c.register_fqns({"late.state"}); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidState); + EXPECT_EQ(c.create_session().error(), Error::InvalidState); +} +/* Opening one owner's load scope inside another's is expected to leave both contexts unavailable with InvalidState. */ +TEST(CudaMutableStateTest, NestedBeginLoadFailsClosed) { + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + cu::CudaDelegateHandle handle = fake_symbol_handle(); + + c1.with_load_scope([&] { + c2.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + }); + + EXPECT_EQ(c1.validate_coverage(), Error::InvalidState); + EXPECT_EQ(c2.validate_coverage(), Error::InvalidState); + EXPECT_FALSE(c1.available()); + EXPECT_FALSE(c2.available()); + EXPECT_EQ(c1.create_session().error(), Error::InvalidState); + EXPECT_EQ(c2.create_session().error(), Error::InvalidState); +} +/* Two back-to-back (non-nested) load scopes on different owners are both expected to end up available. */ +TEST(CudaMutableStateTest, OwnerLoadScopeClearsThreadLocalLoadState) { + cu::MutableStateContextOwner c1; + cu::MutableStateContextOwner c2; + cu::CudaDelegateHandle h1 = fake_symbol_handle(); + cu::CudaDelegateHandle h2 = fake_symbol_handle(); + + c1.with_load_scope([&] { cu::mutable_state_note_handle(&h1); }); + c2.with_load_scope([&] { cu::mutable_state_note_handle(&h2); }); + + EXPECT_TRUE(c1.available()); + EXPECT_TRUE(c2.available()); +} +/* Once the handle's cuda_graph_state.phase is set to Warmup, rebinding inside an active session is expected to return NotSupported. */ +TEST(CudaMutableStateTest, RebindRejectsCudaGraphHandle) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = make_device_tensor({1.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_state"}; + container.fqns = {"model.state"}; + container.extracted["model.state"] = + reinterpret_cast(source_tensor.get()); + cu::MutableStateContextOwner c; + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + c.register_fqns({"model.state"}); +
c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + + handle.cuda_graph_state.phase = cu::CudaGraphPhase::Warmup; + EXPECT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::NotSupported); + + c.destroy_session(token.get()); + cudaFree(source_ptr); +} + +TEST(CudaMutableStateTest, CapturesClonesAndRebindsDeviceBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = + make_device_tensor({1.0f, 2.0f, 3.0f, 4.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_state"}; + container.fqns = {"model.state"}; + container.extracted["model.state"] = + reinterpret_cast(source_tensor.get()); + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.bytes_per_session(), 4 * sizeof(float)); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + EXPECT_EQ( + c.with_active_session( + 123, [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::InvalidArgument); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::InvalidState); + + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + + EXPECT_EQ(container.update_calls, 1u); + EXPECT_EQ(container.last_num_pairs, 1u); + EXPECT_EQ(container.last_name, "internal_state"); + ASSERT_NE(container.last_bound_data, nullptr); + EXPECT_NE(container.last_bound_data, 
source_ptr); + EXPECT_EQ(container.last_bound_nbytes, 4 * sizeof(float)); + + std::vector cloned(4); + EXPECT_EQ( + cudaMemcpy( + cloned.data(), + container.last_bound_data, + cloned.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(cloned, (std::vector{1.0f, 2.0f, 3.0f, 4.0f})); + + EXPECT_EQ(cu::mutable_state_rebind_for_execute(&handle), Error::InvalidState); + + EXPECT_EQ( + c.with_active_session( + token.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + + c.destroy_session(token.get()); + cudaFree(source_ptr); +} +/* Two handles that expose the same FQN ("model.kv") under different internal names are expected to be rebound, within one session, to the same buffer, and that buffer is neither handle's original pointer. */ +TEST(CudaMutableStateTest, SharedFqnAcrossHandlesUsesSameSessionBuffer) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* prefill_ptr = nullptr; + void* decode_ptr = nullptr; + auto prefill_tensor = make_device_tensor({1.0f, 2.0f}, &prefill_ptr); + auto decode_tensor = make_device_tensor({9.0f, 8.0f}, &decode_ptr); + ASSERT_NE(prefill_tensor, nullptr); + ASSERT_NE(decode_tensor, nullptr); + ASSERT_NE(prefill_ptr, nullptr); + ASSERT_NE(decode_ptr, nullptr); + + FakeContainer prefill_container; + prefill_container.internal_names = {"prefill_internal_kv"}; + prefill_container.fqns = {"model.kv"}; + prefill_container.extracted["model.kv"] = + reinterpret_cast(prefill_tensor.get()); + cu::CudaDelegateHandle prefill_handle = + fake_container_handle(&prefill_container); + + FakeContainer decode_container; + decode_container.internal_names = {"decode_internal_kv"}; + decode_container.fqns = {"model.kv"}; + decode_container.extracted["model.kv"] = + reinterpret_cast(decode_tensor.get()); + cu::CudaDelegateHandle decode_handle = + fake_container_handle(&decode_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&prefill_handle); + cu::mutable_state_note_handle(&decode_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); +
+ auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&prefill_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&decode_handle); + }), + Error::Ok); + + ASSERT_NE(prefill_container.last_bound_data, nullptr); + ASSERT_NE(decode_container.last_bound_data, nullptr); + EXPECT_EQ(prefill_container.last_name, "prefill_internal_kv"); + EXPECT_EQ(decode_container.last_name, "decode_internal_kv"); + EXPECT_EQ( + prefill_container.last_bound_data, decode_container.last_bound_data); + EXPECT_NE(prefill_container.last_bound_data, prefill_ptr); + EXPECT_NE(decode_container.last_bound_data, decode_ptr); + + c.destroy_session(token.get()); + cudaFree(prefill_ptr); + cudaFree(decode_ptr); +} +/* Two sessions on the same handle and FQN are expected to get different buffers; data written through one session's pointer must still be readable after switching to the other session and back. */ +TEST(CudaMutableStateTest, SessionsStayIsolatedForSameHandleAndFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* source_ptr = nullptr; + auto source_tensor = make_device_tensor({0.0f, 0.0f}, &source_ptr); + ASSERT_NE(source_tensor, nullptr); + ASSERT_NE(source_ptr, nullptr); + + FakeContainer container; + container.internal_names = {"internal_kv"}; + container.fqns = {"model.kv"}; + container.extracted["model.kv"] = + reinterpret_cast(source_tensor.get()); + cu::CudaDelegateHandle handle = fake_container_handle(&container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { cu::mutable_state_note_handle(&handle); }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto session_a = c.create_session(); + auto session_b = c.create_session(); + ASSERT_TRUE(session_a.ok()); + ASSERT_TRUE(session_b.ok()); + + void* a_ptr = nullptr; + const std::vector a_values = {1.0f, 2.0f}; + ASSERT_EQ( + c.with_active_session( + session_a.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), +
Error::Ok); + a_ptr = container.bound_data_by_name["internal_kv"]; + ASSERT_NE(a_ptr, nullptr); + ASSERT_EQ( + cudaMemcpy( + a_ptr, + a_values.data(), + a_values.size() * sizeof(float), + cudaMemcpyHostToDevice), + cudaSuccess); + /* Session B: same handle and FQN, but the bound pointer must differ from session A's. */ + void* b_ptr = nullptr; + const std::vector b_values = {9.0f, 8.0f}; + ASSERT_EQ( + c.with_active_session( + session_b.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + b_ptr = container.bound_data_by_name["internal_kv"]; + ASSERT_NE(b_ptr, nullptr); + EXPECT_NE(a_ptr, b_ptr); + ASSERT_EQ( + cudaMemcpy( + b_ptr, + b_values.data(), + b_values.size() * sizeof(float), + cudaMemcpyHostToDevice), + cudaSuccess); + /* Switch back to A, then to B again: each rebind must restore that session's pointer and contents. */ + ASSERT_EQ( + c.with_active_session( + session_a.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + EXPECT_EQ(container.bound_data_by_name["internal_kv"], a_ptr); + + std::vector read_a(2); + ASSERT_EQ( + cudaMemcpy( + read_a.data(), + a_ptr, + read_a.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(read_a, a_values); + + ASSERT_EQ( + c.with_active_session( + session_b.get(), + [&] { return cu::mutable_state_rebind_for_execute(&handle); }), + Error::Ok); + EXPECT_EQ(container.bound_data_by_name["internal_kv"], b_ptr); + + std::vector read_b(2); + ASSERT_EQ( + cudaMemcpy( + read_b.data(), + b_ptr, + read_b.size() * sizeof(float), + cudaMemcpyDeviceToHost), + cudaSuccess); + EXPECT_EQ(read_b, b_values); + + c.destroy_session(session_a.get()); + c.destroy_session(session_b.get()); + cudaFree(source_ptr); +} +/* A handle whose constant has an empty internal name is expected to receive no update call, while a second handle sharing the FQN is still rebound to a buffer distinct from both source pointers. */ +TEST(CudaMutableStateTest, EmptyInternalNameIsSkipped) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* skipped_ptr = nullptr; + void* valid_ptr = nullptr; + auto skipped_tensor = make_device_tensor({9.0f, 8.0f}, &skipped_ptr); + auto valid_tensor = make_device_tensor({1.0f, 2.0f}, &valid_ptr); + ASSERT_NE(skipped_tensor, nullptr); + ASSERT_NE(valid_tensor, nullptr); +
ASSERT_NE(skipped_ptr, nullptr); + ASSERT_NE(valid_ptr, nullptr); + + FakeContainer skipped_container; + skipped_container.internal_names = {""}; + skipped_container.fqns = {"model.kv"}; + skipped_container.extracted["model.kv"] = + reinterpret_cast(skipped_tensor.get()); + cu::CudaDelegateHandle skipped_handle = + fake_container_handle(&skipped_container); + + FakeContainer valid_container; + valid_container.internal_names = {"valid_internal_kv"}; + valid_container.fqns = {"model.kv"}; + valid_container.extracted["model.kv"] = + reinterpret_cast(valid_tensor.get()); + cu::CudaDelegateHandle valid_handle = fake_container_handle(&valid_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.kv"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&skipped_handle); + cu::mutable_state_note_handle(&valid_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&skipped_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&valid_handle); + }), + Error::Ok); + EXPECT_EQ(skipped_container.update_calls, 0u); + + EXPECT_EQ(valid_container.update_calls, 1u); + EXPECT_EQ(valid_container.last_name, "valid_internal_kv"); + ASSERT_NE(valid_container.last_bound_data, nullptr); + EXPECT_NE(valid_container.last_bound_data, valid_ptr); + EXPECT_NE(valid_container.last_bound_data, skipped_ptr); + + c.destroy_session(token.get()); + cudaFree(skipped_ptr); + cudaFree(valid_ptr); +} +/* Two handles share an FQN but back it with tensors of different sizes (1 vs 2 floats); validate_coverage and create_session are expected to fail with InvalidProgram and the larger handle sees no update call. */ +TEST( + CudaMutableStateTest, + ValidateCoverageRejectsLargerDescriptorForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* small_ptr = nullptr; + void* large_ptr = nullptr; + auto small_tensor = make_device_tensor({1.0f}, &small_ptr); + auto large_tensor =
make_device_tensor({1.0f, 2.0f}, &large_ptr); + ASSERT_NE(small_tensor, nullptr); + ASSERT_NE(large_tensor, nullptr); + ASSERT_NE(small_ptr, nullptr); + ASSERT_NE(large_ptr, nullptr); + + FakeContainer small_container; + small_container.internal_names = {"small_internal"}; + small_container.fqns = {"model.state"}; + small_container.extracted["model.state"] = + reinterpret_cast(small_tensor.get()); + cu::CudaDelegateHandle small_handle = fake_container_handle(&small_container); + + FakeContainer large_container; + large_container.internal_names = {"large_internal"}; + large_container.fqns = {"model.state"}; + large_container.extracted["model.state"] = + reinterpret_cast(large_tensor.get()); + cu::CudaDelegateHandle large_handle = fake_container_handle(&large_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&small_handle); + cu::mutable_state_note_handle(&large_handle); + }); + + ASSERT_TRUE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidProgram); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidProgram); + EXPECT_EQ(large_container.update_calls, 0u); + + cudaFree(small_ptr); + cudaFree(large_ptr); +} + +TEST( + CudaMutableStateTest, + ValidateCoverageNormalizesUnspecifiedCudaDeviceForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + int current_device = 0; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + + void* unspecified_ptr = nullptr; + void* explicit_ptr = nullptr; + auto unspecified_tensor = + make_device_tensor({1.0f}, &unspecified_ptr, /*device_index=*/-1); + auto explicit_tensor = + make_device_tensor({2.0f}, &explicit_ptr, current_device); + ASSERT_NE(unspecified_tensor, nullptr); + ASSERT_NE(explicit_tensor, nullptr); + ASSERT_NE(unspecified_ptr, nullptr); + ASSERT_NE(explicit_ptr, nullptr); + + FakeContainer unspecified_container; + 
unspecified_container.internal_names = {"unspecified_internal"}; + unspecified_container.fqns = {"model.state"}; + unspecified_container.extracted["model.state"] = + reinterpret_cast(unspecified_tensor.get()); + cu::CudaDelegateHandle unspecified_handle = + fake_container_handle(&unspecified_container); + + FakeContainer explicit_container; + explicit_container.internal_names = {"explicit_internal"}; + explicit_container.fqns = {"model.state"}; + explicit_container.extracted["model.state"] = + reinterpret_cast(explicit_tensor.get()); + cu::CudaDelegateHandle explicit_handle = + fake_container_handle(&explicit_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&unspecified_handle); + cu::mutable_state_note_handle(&explicit_handle); + }); + + ASSERT_TRUE(c.available()); + ASSERT_EQ(c.validate_coverage(), Error::Ok); + + auto token = c.create_session(); + ASSERT_TRUE(token.ok()); + ASSERT_EQ( + c.with_active_session( + token.get(), + [&] { + Error e = cu::mutable_state_rebind_for_execute(&unspecified_handle); + if (e != Error::Ok) { + return e; + } + return cu::mutable_state_rebind_for_execute(&explicit_handle); + }), + Error::Ok); + /* Both the "last bound" snapshot and the per-name record must show the current device for each handle. */ + EXPECT_EQ(unspecified_container.last_bound_device_index, current_device); + EXPECT_EQ(explicit_container.last_bound_device_index, current_device); + EXPECT_EQ( + unspecified_container.bound_device_by_name["unspecified_internal"], + current_device); + EXPECT_EQ( + explicit_container.bound_device_by_name["explicit_internal"], + current_device); + + c.destroy_session(token.get()); + cudaFree(unspecified_ptr); + cudaFree(explicit_ptr); +} +/* Mixing a CUDA tensor and a CPU tensor under one FQN is expected to leave the context unavailable, with InvalidArgument from validate_coverage and create_session and no update call on the CPU handle. */ +TEST(CudaMutableStateTest, BuildRejectsNonCudaDescriptorForSharedFqn) { + if (!cuda_device_available()) { + GTEST_SKIP() << "CUDA device unavailable"; + } + + void* cuda_ptr = nullptr; + auto cuda_tensor = make_device_tensor({1.0f}, &cuda_ptr); + ASSERT_NE(cuda_tensor, nullptr); + ASSERT_NE(cuda_ptr, nullptr); + +
std::vector cpu_values = {1.0f}; + auto cpu_tensor = make_cpu_tensor(cpu_values); + ASSERT_NE(cpu_tensor, nullptr); + + FakeContainer cuda_container; + cuda_container.internal_names = {"cuda_internal"}; + cuda_container.fqns = {"model.state"}; + cuda_container.extracted["model.state"] = + reinterpret_cast(cuda_tensor.get()); + cu::CudaDelegateHandle cuda_handle = fake_container_handle(&cuda_container); + + FakeContainer cpu_container; + cpu_container.internal_names = {"cpu_internal"}; + cpu_container.fqns = {"model.state"}; + cpu_container.extracted["model.state"] = + reinterpret_cast(cpu_tensor.get()); + cu::CudaDelegateHandle cpu_handle = fake_container_handle(&cpu_container); + + cu::MutableStateContextOwner c; + c.register_fqns({"model.state"}); + c.with_load_scope([&] { + cu::mutable_state_note_handle(&cuda_handle); + cu::mutable_state_note_handle(&cpu_handle); + }); + /* The context is expected to be unavailable as soon as the load scope ends, before validate_coverage is called. */ + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.validate_coverage(), Error::InvalidArgument); + EXPECT_FALSE(c.available()); + EXPECT_EQ(c.create_session().error(), Error::InvalidArgument); + EXPECT_EQ(cpu_container.update_calls, 0u); + + cudaFree(cuda_ptr); +}