diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index a88ba102..d99b8b7c 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -328,6 +328,35 @@ impl LogUpSingleKeyTable { assert_eq_rational(builder, &v_table, &v_query); } + + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = combine_columns(builder, &self.table, &randomness); + let v_table = logup_poly_val(builder, &table_combined, query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } } pub struct LogUpRangeProofTable { @@ -455,6 +484,25 @@ impl LogUpRangeProofTable { ); assert_eq_rational(builder, &v_table, &v_query); } + + pub fn final_check_with_query_count>( + &mut self, + builder: &mut B, + query_count: &[Variable], + ) { + let alpha = builder.get_random_value(); + + let v_table = logup_poly_val(builder, &self.table_keys, query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } } pub fn query_count_hint(inputs: &[F], outputs: &mut [F]) -> Result<(), Error> { diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 40081f72..a929ffac 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -888,6 +888,17 @@ impl>> Context { ContextState::WitnessDone, "Please finish computation graph and witness solving before exporting device memories." ); + self.export_device_memories_impl() + } + + /// Export device memories without checking the context state. + /// Use this when you need to export memories outside the normal workflow, + /// e.g., for memory optimization where you want to export and then drop the context. + pub fn export_device_memories_unchecked(&self) -> Vec>> { + self.export_device_memories_impl() + } + + fn export_device_memories_impl(&self) -> Vec>> { self.device_memories .iter() .map(|dm| { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 7d7fed98..6a559fa1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -73,3 +73,15 @@ where wait_async(ClientHttpHelper::request_exit()) } } + +impl ExpanderNoOverSubscribe +where + as ExpanderPCS>>::Commitment: + AsRef< as ExpanderPCS>>::Commitment>, +{ + /// Lightweight prove that doesn't require computation_graph or prover_setup. + /// Use this after setup() to allow releasing those large data structures before proving. + pub fn prove_lightweight(device_memories: Vec>>) { + client_send_witness_and_prove::(device_memories); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 42315b39..64b39c03 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -112,7 +112,11 @@ where let mpi_size = if allow_oversubscribe { max_parallel_count } else { - let num_cpus = prev_power_of_two(num_cpus::get_physical()); + let num_cpus = std::env::var("ZKML_NUM_CPUS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or_else(num_cpus::get_physical); + let num_cpus = prev_power_of_two(num_cpus); if max_parallel_count > num_cpus { num_cpus } else { @@ -136,7 +140,11 @@ where setup_timer.stop(); - SharedMemoryEngine::read_pcs_setup_from_shared_memory() + // Prover setup not needed on client side (server does the proving). + // Verifier setup is required for verification, so read it from shared memory. + let (_prover_setup, verifier_setup) = + SharedMemoryEngine::read_pcs_setup_from_shared_memory::(); + (ExpanderProverSetup::default(), verifier_setup) } pub fn client_send_witness_and_prove( @@ -148,8 +156,39 @@ where { let timer = Timer::new("prove", true); + // Reset ack signal, then write witness + SharedMemoryEngine::reset_witness_ack(); SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); - wait_async(ClientHttpHelper::request_prove()); + + #[cfg(all(target_os = "linux", target_env = "gnu"))] + { + extern "C" { + fn malloc_trim(pad: usize) -> i32; + } + unsafe { + malloc_trim(0); + } + } + + // Async: send prove request + poll for witness ack to release shared memory early + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let prove_handle = tokio::spawn(async { + ClientHttpHelper::request_prove().await; + }); + + // Poll witness_ack; once server confirms read, release witness shared memory + tokio::task::spawn_blocking(|| { + SharedMemoryEngine::wait_for_witness_read_complete(); + unsafe { + super::shared_memory_utils::SHARED_MEMORY.witness = None; + } + }) + .await + .expect("Witness cleanup task failed"); + + prove_handle.await.expect("Prove task failed"); + }); let proof = SharedMemoryEngine::read_proof_from_shared_memory(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 27919a50..f51dd509 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -149,6 +149,9 @@ where let mut witness_win = state.wt_shared_memory_win.lock().await; S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win); + // Signal client: witness has been read, shared memory can be released + SharedMemoryEngine::signal_witness_read_complete(); + let prover_setup_guard = state.prover_setup.lock().await; let computation_graph = state.computation_graph.lock().await; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 648f33a8..b03aa639 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -18,12 +18,15 @@ pub struct SharedMemory { pub pcs_setup: Option, pub witness: Option, pub proof: Option, + /// 1-byte signal: 0 = witness not read, 1 = server finished reading witness + pub witness_ack: Option, } pub static mut SHARED_MEMORY: SharedMemory = SharedMemory { pcs_setup: None, witness: None, proof: None, + witness_ack: None, }; pub struct SharedMemoryEngine {} @@ -106,6 +109,56 @@ impl SharedMemoryEngine { Self::read_object_from_shared_memory("pcs_setup", 0) } + /// Client: reset witness_ack to 0 (call before writing witness) + pub fn reset_witness_ack() { + unsafe { + Self::allocate_shared_memory_if_necessary( + &mut SHARED_MEMORY.witness_ack, + "witness_ack", + 1, + ); + let ptr = SHARED_MEMORY.witness_ack.as_mut().unwrap().as_ptr(); + std::ptr::write_volatile(ptr, 0u8); + } + } + + /// Server: set witness_ack to 1 (call after reading witness) + pub fn signal_witness_read_complete() { + let shmem = ShmemConf::new() + .flink("witness_ack") + .open() + .expect("Failed to open witness_ack shared memory"); + unsafe { + std::ptr::write_volatile(shmem.as_ptr(), 1u8); + } + } + + /// Client: poll until witness_ack becomes 1, with a timeout to avoid hanging + /// if the server crashes. + pub fn wait_for_witness_read_complete() { + const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); + let start = std::time::Instant::now(); + unsafe { + let ptr = SHARED_MEMORY + .witness_ack + .as_ref() + .expect("witness_ack not initialized, call reset_witness_ack first") + .as_ptr() as *const u8; + loop { + if std::ptr::read_volatile(ptr) != 0 { + break; + } + if start.elapsed() > TIMEOUT { + panic!( + "Timed out waiting for server to read witness ({}s)", + TIMEOUT.as_secs() + ); + } + std::thread::sleep(std::time::Duration::from_millis(10)); + } + } + } + pub fn write_witness_to_shared_memory(values: Vec>) { let total_size = std::mem::size_of::() + values