diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 593e3f8d..905609fc 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -1151,6 +1151,112 @@ fn packed_eq_poly>(eval: &[EF], scalar: EF) -> E EF::ExtensionPacking::from_ext_slice(&buffer) } +/// Tensor-tail variant of [`eval_eq`]: returns the table +/// `out[(hi << log2(tail.len())) | lo] = eq(eval)[hi] * tail[lo]`, +/// i.e. the eq expansion of `eval` tensored with an arbitrary vector `tail` +/// occupying the lowest `log2(tail.len())` variables. `tail.len()` must be a +/// power of two. With `tail = eval_eq(extra)` this equals +/// `eval_eq(concat(eval, extra))` exactly. +pub fn eval_eq_with_tail>>(eval: &[F], tail: &[F]) -> ArenaVec { + let log_tail = log2_strict_usize(tail.len()); + let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() + log_tail)) }; + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= 1 + log_chunks { + eval_eq_tail_kernel(eval, &mut out, F::ONE, tail); + return out; + } + let mut buffer = F::zero_vec(n_chunks); + buffer[0] = F::ONE; + fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer); + let middle = &eval[log_chunks..]; + let out_chunk_size = out.len() / n_chunks; + par_chunks_zip(&mut out, out_chunk_size, &buffer, |out_chunk, &b| { + eval_eq_tail_kernel(middle, out_chunk, b, tail); + }); + out +} + +/// Packed-output variant of [`eval_eq_with_tail`]; same table, packed like +/// [`eval_eq_packed`]. Requires `eval.len() + log2(tail.len()) >= packing_log_width`. +pub fn eval_eq_packed_with_tail>>(eval: &[F], tail: &[F]) -> ArenaVec> { + let w = packing_log_width::(); + let k = log2_strict_usize(tail.len()); + assert!(eval.len() + k >= w); + if k < w { + // Absorb the (w − k) trailing `eval` variables into the tail so the + // packed lanes are fully covered by the (extended) tail. 
+ let absorb = w - k; + let eq_absorb = eval_eq(&eval[eval.len() - absorb..]); + let mut extended_tail = F::zero_vec(1 << w); + for (a, &ea) in eq_absorb.iter().enumerate() { + for (lo, &t) in tail.iter().enumerate() { + extended_tail[(a << k) | lo] = ea * t; + } + } + return eval_eq_packed_with_tail(&eval[..eval.len() - absorb], &extended_tail); + } + let tail_packed: Vec> = pack_extension(tail); + let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() + k - w)) }; + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= 1 + log_chunks { + eval_eq_tail_kernel_packed::(eval, &mut out, F::ONE, &tail_packed); + return out; + } + let mut buffer = F::zero_vec(n_chunks); + buffer[0] = F::ONE; + fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer); + let middle = &eval[log_chunks..]; + let out_chunk_size = out.len() / n_chunks; + par_chunks_zip(&mut out, out_chunk_size, &buffer, |out_chunk, &b| { + eval_eq_tail_kernel_packed::(middle, out_chunk, b, &tail_packed); + }); + out +} + +/// Recursive kernel for [`eval_eq_with_tail`]: standard eq split on `eval`, +/// with the leaf writing `scalar * tail` instead of a single scalar. +#[inline] +fn eval_eq_tail_kernel(eval: &[F], out: &mut [F], scalar: F, tail: &[F]) { + debug_assert_eq!(out.len(), tail.len() << eval.len()); + match eval.split_first() { + None => { + out.iter_mut().zip(tail).for_each(|(o, &t)| *o = t * scalar); + } + Some((&x, rest)) => { + let (low, high) = out.split_at_mut(out.len() / 2); + let s1 = scalar * x; + let s0 = scalar - s1; + eval_eq_tail_kernel(rest, low, s0, tail); + eval_eq_tail_kernel(rest, high, s1, tail); + } + } +} + +/// Recursive kernel for [`eval_eq_packed_with_tail`] (requires the tail to +/// cover at least the packing width, guaranteed by the absorption step). 
+#[inline] +fn eval_eq_tail_kernel_packed>>( + eval: &[F], + out: &mut [EFPacking], + scalar: F, + tail_packed: &[EFPacking], +) { + debug_assert_eq!(out.len(), tail_packed.len() << eval.len()); + match eval.split_first() { + None => { + let b = EFPacking::::from(scalar); + out.iter_mut().zip(tail_packed).for_each(|(o, &t)| *o = t * b); + } + Some((&x, rest)) => { + let (low, high) = out.split_at_mut(out.len() / 2); + let s1 = scalar * x; + let s0 = scalar - s1; + eval_eq_tail_kernel_packed::(rest, low, s0, tail_packed); + eval_eq_tail_kernel_packed::(rest, high, s1, tail_packed); + } + } +} + #[cfg(test)] mod tests { use std::time::Instant; @@ -1322,4 +1428,55 @@ mod tests { assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars); } } + + #[test] + fn test_eval_eq_with_tail_matches_point_append() { + let mut rng = StdRng::seed_from_u64(7); + for n_vars in [2usize, 5, 9, 12] { + for k in [2usize, 3, 4] { + let eval: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let extra: Vec = (0..k).map(|_| rng.random()).collect(); + let tail = eval_eq(&extra); + let with_tail = eval_eq_with_tail(&eval, &tail); + let appended = eval_eq(&[eval.clone(), extra.clone()].concat()); + assert_eq!(with_tail.as_slice(), appended.as_slice(), "n={n_vars} k={k}"); + } + } + } + + #[test] + fn test_eval_eq_with_tail_random_tail_brute_force() { + let mut rng = StdRng::seed_from_u64(8); + let n_vars = 6; + let k = 3; + let eval: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let tail: Vec = (0..1 << k).map(|_| rng.random()).collect(); + let with_tail = eval_eq_with_tail(&eval, &tail); + let eq_hi = eval_eq(&eval); + for hi in 0..1usize << n_vars { + for lo in 0..1usize << k { + assert_eq!(with_tail[(hi << k) | lo], eq_hi[hi] * tail[lo]); + } + } + } + + #[test] + fn test_eval_eq_packed_with_tail_matches_unpacked() { + let mut rng = StdRng::seed_from_u64(9); + let w = packing_log_width::(); + for n_vars in [2usize, 5, 9, 12] { + // Cover both k < w and k >= w paths 
regardless of the platform width. + for k in [2usize, 3, 4, 5] { + if n_vars + k < w { + continue; + } + let eval: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let tail: Vec = (0..1 << k).map(|_| rng.random()).collect(); + let unpacked = eval_eq_with_tail(&eval, &tail); + let packed = eval_eq_packed_with_tail(&eval, &tail); + let unpacked_from_packed: Vec = unpack_extension(&packed); + assert_eq!(unpacked.as_slice(), &unpacked_from_packed[..], "n={n_vars} k={k}"); + } + } + } } diff --git a/crates/backend/poly/src/next_mle.rs b/crates/backend/poly/src/next_mle.rs index af8c08e9..8a9fbdcc 100644 --- a/crates/backend/poly/src/next_mle.rs +++ b/crates/backend/poly/src/next_mle.rs @@ -1,7 +1,8 @@ +use ::utils::log2_strict_usize; use field::{ExtensionField, Field, PrimeCharacteristicRing}; use zk_alloc::ArenaVec; -use crate::{PF, eval_eq_scaled}; +use crate::{PF, eval_eq_scaled, eval_eq_with_tail, to_big_endian_in_field}; /// Evaluates the "next" multilinear polynomial at two n-variable points (x, y). /// @@ -54,6 +55,39 @@ where res } +/// Tensor-tail variant of [`next_mle`]: +/// `Σ_x tail[x] · next_mle(concat(prefix, bits(x)), y)`, where `bits(x)` is +/// big-endian over `log2(tail.len())` variables (matching `eval_eq` indexing). +/// Verifier-side: the loop over the `2^k` cube points is intentional. +pub fn next_mle_with_tail(prefix: &[F], tail: &[F], y: &[F]) -> F { + let k = log2_strict_usize(tail.len()); + debug_assert_eq!(prefix.len() + k, y.len()); + let mut sum = F::ZERO; + for (x, &t) in tail.iter().enumerate() { + let mut point = prefix.to_vec(); + point.extend(to_big_endian_in_field::(x, k)); + sum += next_mle(&point, y) * t; + } + sum +} + +/// Tensor-tail variant of [`matrix_next_mle_folded`]: the dense vector +/// `w[y] = Σ_x tail[x] · next_mle(concat(prefix, bits(x)), y)`. 
+/// +/// Since `next_mle` is multilinear in its first argument, +/// `w[y] = Σ_j v[j] · next_mle(j, y)` with `v = eval_eq_with_tail(prefix, tail)`, +/// and `next_mle(j, y) = 1` iff `y = j + 1`, plus the wrap-around +/// `next_mle(2^n − 1, 2^n − 1) = 1` (see [`next_mle`]). Hence `w` is the +/// shift-by-one of `v`, with `w[last] += v[last]`. +pub fn matrix_next_mle_folded_with_tail>>(prefix: &[F], tail: &[F]) -> ArenaVec { + let v = eval_eq_with_tail(prefix, tail); + let n = v.len(); + let mut res = unsafe { ArenaVec::::zeroed(n) }; + res[1..].copy_from_slice(&v[..n - 1]); + res[n - 1] += v[n - 1]; + res +} + #[cfg(test)] mod tests { use field::PrimeCharacteristicRing; @@ -81,4 +115,68 @@ mod tests { } } } + + #[test] + fn test_next_mle_with_tail_brute_force() { + use koala_bear::QuinticExtensionFieldKB; + use rand::{RngExt, SeedableRng, rngs::StdRng}; + + use crate::next_mle_with_tail; + type EF = QuinticExtensionFieldKB; + + let mut rng = StdRng::seed_from_u64(11); + for k in [2usize, 3] { + let n_prefix = 5 - k; + let prefix: Vec = (0..n_prefix).map(|_| rng.random()).collect(); + let tail: Vec = (0..1 << k).map(|_| rng.random()).collect(); + let y: Vec = (0..5).map(|_| rng.random()).collect(); + let direct = next_mle_with_tail(&prefix, &tail, &y); + let mut brute = EF::ZERO; + for (x, &t) in tail.iter().enumerate() { + let mut point = prefix.clone(); + point.extend(to_big_endian_in_field::(x, k)); + brute += next_mle(&point, &y) * t; + } + assert_eq!(direct, brute); + } + } + + #[test] + fn test_matrix_next_mle_folded_with_tail_matches_sum() { + use koala_bear::QuinticExtensionFieldKB; + use rand::{RngExt, SeedableRng, rngs::StdRng}; + + use crate::{matrix_next_mle_folded_with_tail, next_mle_with_tail}; + type EF = QuinticExtensionFieldKB; + + let mut rng = StdRng::seed_from_u64(12); + for k in [2usize, 3] { + let n_prefix = 5 - k; + let prefix: Vec = (0..n_prefix).map(|_| rng.random()).collect(); + let tail: Vec = (0..1 << k).map(|_| rng.random()).collect(); + 
+ let folded = matrix_next_mle_folded_with_tail(&prefix, &tail); + + // Elementwise against the sum of per-cube-point folded matrices. + let mut expected = EF::zero_vec(1 << 5); + for (x, &t) in tail.iter().enumerate() { + let mut point = prefix.clone(); + point.extend(to_big_endian_in_field::(x, k)); + for (e, &m) in expected.iter_mut().zip(matrix_next_mle_folded(&point).iter()) { + *e += m * t; + } + } + assert_eq!(folded.as_slice(), &expected[..]); + + // Consistency with the pointwise variant: the folded vector's MLE at a + // boolean point y equals next_mle_with_tail(prefix, tail, y). + for y in 0..1usize << 5 { + let y_bools = to_big_endian_in_field::(y, 5); + assert_eq!( + folded.evaluate(&MultilinearPoint(y_bools.clone())), + next_mle_with_tail(&prefix, &tail, &y_bools) + ); + } + } + } } diff --git a/crates/backend/sumcheck/src/lib.rs b/crates/backend/sumcheck/src/lib.rs index 5ec3c06b..830edf65 100644 --- a/crates/backend/sumcheck/src/lib.rs +++ b/crates/backend/sumcheck/src/lib.rs @@ -14,3 +14,6 @@ pub use sc_computation::*; mod product_computation; pub use product_computation::*; + +mod univariate_skip; +pub use univariate_skip::*; diff --git a/crates/backend/sumcheck/src/univariate_skip.rs b/crates/backend/sumcheck/src/univariate_skip.rs new file mode 100644 index 00000000..9b358040 --- /dev/null +++ b/crates/backend/sumcheck/src/univariate_skip.rs @@ -0,0 +1,154 @@ +//! Primitives for a univariate-skip round of a batched sumcheck +//! (Gruen, eprint 2024/108 §5-6). +//! +//! Domain convention: the base-window node for cube point `x ∈ {0,1}^k` +//! (big-endian bit order, matching `eval_eq` indexing) is the small integer +//! `x` itself: `node_x = F::from_usize(x)`. Extended targets continue at +//! `2^k, 2^k + 1, …` ascending. This integer window keeps all skip-kernel +//! evaluation points small (cheap base-field multiplications) and makes the +//! python-verifier / zkDSL mirrors trivial (`bits(i) ↔ node i`), at no cost +//! 
compared to a multiplicative subgroup: the skipped sum is non-zero here, +//! so no "free zero" evaluations exist either way. + +use field::*; +use poly::{PF, eval_eq, lagrange_basis_evals}; +use zk_alloc::ArenaVec; + +/// The `2^k` base-window nodes, indexed by cube point `x ∈ 0..2^k`. +pub fn skip_domain_points(k: usize) -> Vec { + (0..1usize << k).map(F::from_usize).collect() +} + +/// All evaluation nodes of the skip-round polynomial: the `2^k` window nodes +/// followed by the extended targets, `(2^k − 1)·air_degree + 1` nodes total +/// (enough to determine a polynomial of degree `(2^k − 1)·air_degree`). +pub fn skip_all_nodes(k: usize, air_degree: usize) -> Vec { + let n_nodes = ((1usize << k) - 1) * air_degree + 1; + debug_assert!(n_nodes >= 1 << k); + (0..n_nodes).map(F::from_usize).collect() +} + +/// For each target `z` (typically `skip_all_nodes[2^k..]`), the `2^k` Lagrange +/// basis coefficients `L_x(z)` over the base window. Runs once per prove. +pub fn lagrange_coeffs_for_targets(k: usize, targets: &[F]) -> Vec> { + lagrange_basis_evals(&skip_domain_points::(k), targets) +} + +/// `L_x(r0)` for all `2^k` window nodes, at an extension-field point `r0`. +/// Used for the `2^k → 1` column fold after the skip challenge and as the +/// tensor tail of the WHIR opening weights. +pub fn lagrange_weights_at>(k: usize, r0: EF) -> Vec { + let nodes = skip_domain_points::(k); + let n = nodes.len(); + let den_invs: Vec = (0..n) + .map(|i| { + (0..n) + .filter(|&j| j != i) + .map(|j| nodes[i] - nodes[j]) + .fold(F::ONE, |acc, d| acc * d) + .inverse() + }) + .collect(); + (0..n) + .map(|i| { + let num = (0..n) + .filter(|&j| j != i) + .map(|j| r0 - EF::from(nodes[j])) + .fold(EF::ONE, |acc, d| acc * d); + num * den_invs[i] + }) + .collect() +} + +/// `eq(eq_top, bits(x))` for `x ∈ 0..2^k` — the window values of the eq kernel +/// `ê` (the part of the zerocheck eq factor carried by the skipped variables). 
+pub fn e_hat_on_window>>(eq_top: &[EF]) -> ArenaVec { + eval_eq(eq_top) +} + +/// `ê(r0) = Σ_x eq(eq_top, bits(x)) · L_x(r0)`: the degree-`(2^k − 1)` +/// univariate extension of the eq kernel over the window, at `r0`. +pub fn e_hat_at>>(eq_top: &[EF], r0: EF) -> EF { + let weights = lagrange_weights_at::, EF>(eq_top.len(), r0); + e_hat_on_window(eq_top) + .iter() + .zip(&weights) + .map(|(&e, &w)| e * w) + .fold(EF::ZERO, |acc, t| acc + t) +} + +#[cfg(test)] +mod tests { + use koala_bear::{KoalaBear, QuinticExtensionFieldKB}; + + use super::*; + + type F = KoalaBear; + type EF = QuinticExtensionFieldKB; + + /// Deterministic scattered field elements (the assertions below are + /// polynomial identities — any distinct values exercise them). + fn test_scalar_f(i: usize) -> F { + F::from_usize(3).exp_u64(7 * i as u64 + 5) + } + fn test_scalar_ef(i: usize) -> EF { + EF::from_basis_coefficients_fn(|j| test_scalar_f(13 * i + j)) + } + + #[test] + fn test_lagrange_weights_delta_on_nodes() { + for k in [3, 4] { + let nodes = skip_domain_points::(k); + for (y, &node_y) in nodes.iter().enumerate() { + let weights = lagrange_weights_at::(k, EF::from(node_y)); + for (x, &w) in weights.iter().enumerate() { + let expected = if x == y { EF::ONE } else { EF::ZERO }; + assert_eq!(w, expected, "k={k} x={x} y={y}"); + } + } + } + } + + #[test] + fn test_lagrange_coeffs_for_targets_reconstruct() { + // A degree-(2^k − 1) polynomial is reconstructed exactly at the targets + // from its window values. 
+ let k = 3; + let coeffs: Vec = (0..1 << k).map(test_scalar_f).collect(); + let poly_eval = |x: F| coeffs.iter().rfold(F::ZERO, |acc, &c| acc * x + c); + let all_nodes = skip_all_nodes::(k, 5); + let window = &all_nodes[..1 << k]; + let targets = &all_nodes[1 << k..]; + assert_eq!(all_nodes.len(), 7 * 5 + 1); + let lags = lagrange_coeffs_for_targets::(k, targets); + for (t, &z) in targets.iter().enumerate() { + let interp = lags[t] + .iter() + .zip(window) + .map(|(&l, &w)| l * poly_eval(w)) + .fold(F::ZERO, |a, b| a + b); + assert_eq!(interp, poly_eval(z)); + } + } + + #[test] + fn test_e_hat_matches_eq_on_window() { + for k in [3, 4] { + let eq_top: Vec = (0..k).map(test_scalar_ef).collect(); + let window = e_hat_on_window(&eq_top); + for x in 0..1usize << k { + let at_node = e_hat_at(&eq_top, EF::from_usize(x)); + assert_eq!(at_node, window[x], "k={k} x={x}"); + } + // And at a random point, ê is consistent with the Lagrange weights. + let r0: EF = test_scalar_ef(40 + k); + let weights = lagrange_weights_at::(k, r0); + let direct = window + .iter() + .zip(&weights) + .map(|(&e, &w)| e * w) + .fold(EF::ZERO, |a, b| a + b); + assert_eq!(e_hat_at(&eq_top, r0), direct); + } + } +} diff --git a/crates/lean_prover/python-verifier/verifier.py b/crates/lean_prover/python-verifier/verifier.py index f412238b..4c73db23 100644 --- a/crates/lean_prover/python-verifier/verifier.py +++ b/crates/lean_prover/python-verifier/verifier.py @@ -26,6 +26,7 @@ MIN_LOG_MEMORY_SIZE, MAX_LOG_MEMORY_SIZE = 16, 26 MIN_LOG_HEIGHT_PER_TABLE, MIN_BYTECODE_LOG_SIZE, MAX_BYTECODE_LOG_SIZE = 8, 8, 22 N_VARS_TO_SEND_GKR_COEFFS = 5 +SKIP_K = 4 # univariate-skip width of the batched AIR sumcheck; must equal sub_protocols::UNIVARIATE_SKIP_K (Rust) N_RUNTIME_COLUMNS, N_INSTRUCTION_COLUMNS = 8, 12 @@ -244,6 +245,13 @@ def next_mle(x: Sequence[EF], y: Sequence[EF]) -> EF: return s + math.prod([*x, *y]) +def next_mle_with_tail(prefix: Sequence[EF], tail: Sequence[EF], y: Sequence[EF]) -> EF: + 
"""Tensor-tail shifted MLE: Σ_x tail[x] · next_mle(prefix ++ bits_be(x), y).""" + k = log2_strict(len(tail)) + bits = lambda x: [ONE if (x >> (k - 1 - j)) & 1 else ZERO for j in range(k)] + return sum(t * next_mle([*prefix, *bits(x)], y) for x, t in enumerate(tail)) + + def eval_multilinear_by_evals(evals: Sequence[Fp | EF], point: Sequence[EF]) -> EF: """Evaluate a multilinear in evaluation form at `point`.""" assert len(evals) == 1 << len(point) @@ -304,10 +312,15 @@ class SparseStatements: point: list[EF] # low-bits variables (suffix), shared by every entry in `values` values: list[tuple[int, EF]] # (selector_index, eval): poly(high bits = selector_index, low bits = point) == eval is_next: bool = False # if set, the low-variable part uses the shifted "next-row" MLE instead of plain eq + tail: list[EF] | None = None # if set, weight = eq(point) ⊗ MLE(tail) with the tail on the LOWEST log2(len(tail)) inner variables + + @property + def inner_num_variables(self) -> int: + return len(self.point) + (log2_strict(len(self.tail)) if self.tail is not None else 0) @property def selector_num_variables(self) -> int: - return self.total_num_variables - len(self.point) # count of high/selector bits that selector_index spans + return self.total_num_variables - self.inner_num_variables # count of high/selector bits that selector_index spans def whir_folding_factor_at_round(round: int) -> int: @@ -350,6 +363,42 @@ def verify_sumcheck(fiat_shamir: FiatShamir, target: EF, n_rounds: int, degree: return point, target +def verify_air_sumcheck_with_skip( + fiat_shamir: FiatShamir, + table_sums: list[EF], + table_n_vars: list[int], + table_degrees: list[int], + eq_top: list[EF], + max_full_degree: int, + n_max: int, +) -> tuple[EF, list[EF], list[EF], EF]: + """Univariate-skip batched AIR sumcheck (Gruen eprint 2024/108 §5-6); mirrors + sub_protocols::verify_batched_air_sumcheck_uniskip. 
Round 0 binds the SKIP_K lowest + row bits of EVERY table at once: the prover sends ONE combined polynomial + P(X) = Σ_t w_t·v'_t(X) (w_t = 2^(n_max−n_t), front-loaded batching) in coefficient + form over the integer window D = {0..2^SKIP_K−1} (node for cube point x is x itself). + Round-0 identity: Σ_{z∈D} ê(z)·P(z) == Σ_t w_t·s_t, where ê = eq(eq_top, bits(·)) on D. + Then target = ê(r0)·P(r0) and the remaining n_max−SKIP_K rounds are standard.""" + window = 1 << SKIP_K + n_coeffs = (window - 1) * max(table_degrees) + 1 + coeffs = fiat_shamir.next_extension_scalars_vec(n_coeffs) + e_hat = eval_eq(eq_top) + window_sum = sum(e_hat[z] * eval_univariate_polynomial(coeffs, EF(z)) for z in range(window)) + claimed = sum(EF(1 << (n_max - n_t)) * s_t for n_t, s_t in zip(table_n_vars, table_sums)) + assert window_sum == claimed, "AIR skip round: weighted window identity failed" + r0 = fiat_shamir.sample_ef() + # Lagrange basis L_x(r0) over the window nodes + nodes = [EF(x) for x in range(window)] + lagrange_weights = [ + math.prod(r0 - nodes[j] for j in range(window) if j != i) + * math.prod(nodes[i] - nodes[j] for j in range(window) if j != i).inv() + for i in range(window) + ] + target = dot_product(e_hat, lagrange_weights) * eval_univariate_polynomial(coeffs, r0) # = ê(r0)·P(r0) + linear_challenges, final_value = verify_sumcheck(fiat_shamir, target, n_max - SKIP_K, max_full_degree) + return r0, lagrange_weights, linear_challenges, final_value + + def verify_whir( fiat_shamir: FiatShamir, cfg: dict, @@ -419,8 +468,14 @@ def verify_whir( folding_challenges = folding_challenges[whir_folding_factor_at_round(round - 1) :] gamma_power = ONE for smt in smts: - point_suffix = folding_challenges[len(folding_challenges) - len(smt.point) :] # dense part of the point - eval_suffix = next_mle(smt.point, point_suffix) if smt.is_next else eq_poly(smt.point, point_suffix) + point_suffix = folding_challenges[len(folding_challenges) - smt.inner_num_variables :] # dense part of the point 
+ if smt.tail is None: + eval_suffix = next_mle(smt.point, point_suffix) if smt.is_next else eq_poly(smt.point, point_suffix) + elif smt.is_next: + eval_suffix = next_mle_with_tail(smt.point, smt.tail, point_suffix) + else: # weight = eq(point, prefix) · MLE(tail) at the lowest log2(len(tail)) inner coords + prefix, low = point_suffix[: len(smt.point)], point_suffix[len(smt.point) :] + eval_suffix = eq_poly(smt.point, prefix) * eval_multilinear_by_evals(smt.tail, low) sel_n = smt.selector_num_variables for v in smt.values: eval_prefix = eq_at_index(folding_challenges, v[0], sel_n) # sparse part of the point @@ -866,16 +921,28 @@ def verify_execution( alpha = fiat_shamir.sample_ef() alpha_powers = ef_powers(alpha, sum(t.n_constraints for t in TABLES)) - initial_sum, offset = ZERO, 0 + table_sums, offset = [], 0 for table in TABLES: - initial_sum += alpha_powers[offset] * (precompile_nums[table.name] * table.precompile_bus_interaction_sign) - initial_sum += alpha_powers[offset + 1] * (logup_gamma - precompile_dens[table.name]) + table_sums.append( + alpha_powers[offset] * (precompile_nums[table.name] * table.precompile_bus_interaction_sign) + + alpha_powers[offset + 1] * (logup_gamma - precompile_dens[table.name]) + ) offset += table.n_constraints - # 3] verify batched AIR sumcheck - sc_point, sc_value = verify_sumcheck(fiat_shamir, initial_sum, n_max, max(t.air_degree + 1 for t in TABLES)) + # 3] verify batched AIR sumcheck (univariate-skip round 0 + standard linear rounds) + eq_top = gkr_point[-SKIP_K:] # shared by all tables: every eq factor is a suffix of gkr_point + r0, lagrange_weights, linear_challenges, sc_value = verify_air_sumcheck_with_skip( + fiat_shamir, + table_sums, + [table_log_heights[t.name] for t in TABLES], + [t.air_degree for t in TABLES], + eq_top, + max(t.air_degree + 1 for t in TABLES), + n_max, + ) + e_hat_r0 = dot_product(eval_eq(eq_top), lagrange_weights) - committed_column_evals = {t.name: [(gkr_point[-table_log_heights[t.name] :], 
columns_evals[t.name], {})] for t in TABLES} + committed_column_evals = {t.name: [(gkr_point[-table_log_heights[t.name] :], columns_evals[t.name], {}, None)] for t in TABLES} air_final_value, offset = ZERO, 0 for table in TABLES: log_height = table_log_heights[table.name] @@ -883,11 +950,13 @@ def verify_execution( alphas = alpha_powers[offset : offset + table.n_constraints] offset += table.n_constraints constraint_eval = table.eval_air(col_evals, alphas, logup_beta_eq) - natural_point = list(reversed(sc_point[-log_height:])) - air_final_value += math.prod(sc_point[:-log_height]) * eq_poly(gkr_point[-log_height:], natural_point) * constraint_eval + # Final identity: target == Σ_t ê(r0) · eq(eq_factor_t[..n_t−K], natural_prefix_t) · C_t(col_evals_t); + # the SKIP_K lowest row bits of every table's opening live in the Lagrange tail. + natural_prefix = list(reversed(linear_challenges[: log_height - SKIP_K])) + air_final_value += e_hat_r0 * eq_poly(gkr_point[-log_height:][: log_height - SKIP_K], natural_prefix) * constraint_eval eq_vals = {i: col_evals[i] for i in range(table.n_columns)} next_vals = {j: col_evals[table.n_columns + j] for j in range(table.n_shift)} - committed_column_evals[table.name].append((natural_point, eq_vals, next_vals)) + committed_column_evals[table.name].append((natural_prefix, eq_vals, next_vals, lagrange_weights)) assert air_final_value == sc_value, "AIR sumcheck: claimed value mismatch" public_memory_point = fiat_shamir.sample_many_ef(log2_strict(PUBLIC_INPUT_SIZE)) @@ -917,10 +986,10 @@ def values_at(d: dict[int, EF], col_base: int) -> list[tuple[int, EF]]: offset = table_offsets[table.name] col_base = offset >> log_height pcs_statements.extend(table.boundary_statements(stacked_n_vars, offset, log_height, ending_pc)) - for point, eq_values, next_values in committed_column_evals[table.name]: + for point, eq_values, next_values, tail in committed_column_evals[table.name]: if next_values: - 
pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(next_values, col_base), True)) - pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(eq_values, col_base))) + pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(next_values, col_base), True, tail)) + pcs_statements.append(SparseStatements(stacked_n_vars, point, values_at(eq_values, col_base), tail=tail)) # 4] Open the PCS verify_whir(fiat_shamir, cfg, parsed_commitment, pcs_statements) diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index dc46d2df..7aeba2c4 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -140,11 +140,12 @@ pub fn prove_execution( let log_n_rows = traces[&table].log_n_rows; committed_statements.insert( table, - vec![( - MultilinearPoint(from_end(gkr_point, log_n_rows).to_vec()), - logup_statements.columns_values[&table].clone(), - BTreeMap::new(), - )], + vec![CommittedClaim { + point: MultilinearPoint(from_end(gkr_point, log_n_rows).to_vec()), + tail: None, + eq_values: logup_statements.columns_values[&table].clone(), + next_values: BTreeMap::new(), + }], ); } @@ -202,27 +203,31 @@ pub fn prove_execution( macro_rules! 
make_session { ($t:expr) => {{ let session = AirSumcheckSession::new(packed, eq_suffix, bus_final_value, *$t, extra_data, non_padded); - Box::new(session) as Box + '_> + Box::new(session) as Box + '_> }}; } sessions.push(delegate_to_inner!(table => make_session)); alpha_offset += n_constraints; } - let sumcheck_air_point = - info_span!("batched AIR sumcheck").in_scope(|| prove_batched_air_sumcheck(&mut prover_state, &mut sessions)); + let uniskip_point = info_span!("batched AIR sumcheck") + .in_scope(|| prove_batched_air_sumcheck_uniskip(&mut prover_state, &mut sessions, UNIVARIATE_SKIP_K)); for (idx, table) in ALL_TABLES.iter().enumerate() { let col_evals = sessions[idx].final_column_evals(); prover_state.add_extension_scalars(&col_evals); - let natural_ordering_point = - natural_ordering_point_for_session(&sumcheck_air_point.0, traces[table].log_n_rows); + let natural_prefix = natural_prefix_for_session(&uniskip_point, traces[table].log_n_rows); macro_rules! split { - ($t:expr) => {{ columns_evals_flat_and_shift($t, &col_evals, &natural_ordering_point) }}; + ($t:expr) => {{ columns_evals_flat_and_shift($t, &col_evals, &natural_prefix) }}; } - let claim = delegate_to_inner!(table => split); - committed_statements.get_mut(table).unwrap().push(claim); + let (point, eq_values, next_values) = delegate_to_inner!(table => split); + committed_statements.get_mut(table).unwrap().push(CommittedClaim { + point, + tail: Some(uniskip_point.lagrange_weights.clone()), + eq_values, + next_values, + }); } let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(PUBLIC_INPUT_LEN))); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 3145edd3..7d65da8f 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -91,11 +91,12 @@ pub fn verify_execution( let log_n = table_n_vars[&table]; committed_statements.insert( table, - vec![( - 
MultilinearPoint(from_end(gkr_point, log_n).to_vec()), - logup_statements.columns_values[&table].clone(), - BTreeMap::new(), - )], + vec![CommittedClaim { + point: MultilinearPoint(from_end(gkr_point, log_n).to_vec()), + tail: None, + eq_values: logup_statements.columns_values[&table].clone(), + next_values: BTreeMap::new(), + }], ); } @@ -107,7 +108,7 @@ pub fn verify_execution( extra_data: ExtraDataForBuses, } let mut verify_data: Vec = Vec::new(); - let mut initial_sum = EF::ZERO; + let mut table_sums: Vec = Vec::new(); let mut alpha_offset = 0; for table in ALL_TABLES { @@ -119,8 +120,10 @@ pub fn verify_execution( BusDirection::Pull => EF::NEG_ONE, BusDirection::Push => EF::ONE, }; - initial_sum += air_alpha_powers[alpha_offset] * signed_numerator - + air_alpha_powers[alpha_offset + 1] * (logup_c - bus_denominator_value); + table_sums.push( + air_alpha_powers[alpha_offset] * signed_numerator + + air_alpha_powers[alpha_offset + 1] * (logup_c - bus_denominator_value), + ); let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); verify_data.push(TableVerifyData { @@ -133,14 +136,28 @@ pub fn verify_execution( let max_full_degree = ALL_TABLES.iter().map(|t| t.degree_air() + 1).max().unwrap(); - let n_max = *table_n_vars.values().max().unwrap(); - let Evaluation { - point: sumcheck_air_point, - value: claimed_air_final_value, - } = sumcheck_verify(&mut verifier_state, n_max, max_full_degree, initial_sum, None)?; + // Univariate-skip batched AIR sumcheck (see sub_protocols::air_sumcheck_skip): + // round 0 binds the K lowest row bits of every table with one univariate + // round whose identity is the w_t-weighted window sum of the per-table + // claims; the remaining n_max − K rounds are the legacy combined rounds. 
+ let table_n_vars_ordered: Vec = ALL_TABLES.iter().map(|t| table_n_vars[t]).collect(); + let table_degrees: Vec = ALL_TABLES.iter().map(|t| t.degree_air()).collect(); + let eq_top = from_end(gkr_point, UNIVARIATE_SKIP_K); + let (uniskip_point, claimed_air_final_value) = verify_batched_air_sumcheck_uniskip( + &mut verifier_state, + UNIVARIATE_SKIP_K, + &table_n_vars_ordered, + &table_degrees, + &table_sums, + eq_top, + max_full_degree, + )?; + // Final identity: target == Σ_t ê(r0) · eq(eq_factor_t[..n_t−K], natural_prefix_t) · C_t(col_evals_t). + let e_hat_r0 = e_hat_at(eq_top, uniskip_point.r0); let mut my_air_final_value = EF::ZERO; for vd in &verify_data { + let n_t = table_n_vars[&vd.table]; let n_cols_total = vd.table.n_columns() + vd.table.n_shift_columns(); let col_evals = verifier_state.next_extension_scalars_vec(n_cols_total)?; @@ -149,21 +166,23 @@ pub fn verify_execution( } let constraint_eval = delegate_to_inner!(&vd.table => eval_constraint); - let bus_point = from_end(gkr_point, table_n_vars[&vd.table]); - let natural_ordering_point = natural_ordering_point_for_session(&sumcheck_air_point.0, table_n_vars[&vd.table]); - my_air_final_value += back_loaded_table_contribution( - bus_point, - &sumcheck_air_point.0, - &natural_ordering_point, - constraint_eval, - ); + let bus_point = from_end(gkr_point, n_t); + let natural_prefix = natural_prefix_for_session(&uniskip_point, n_t); + let eq_val = MultilinearPoint(bus_point[..n_t - UNIVARIATE_SKIP_K].to_vec()) + .eq_poly_outside(&MultilinearPoint(natural_prefix.clone())); + my_air_final_value += e_hat_r0 * eq_val * constraint_eval; macro_rules! 
split { - ($t:expr) => {{ columns_evals_flat_and_shift($t, &col_evals, &natural_ordering_point) }}; + ($t:expr) => {{ columns_evals_flat_and_shift($t, &col_evals, &natural_prefix) }}; } - let claim = delegate_to_inner!(&vd.table => split); + let (point, eq_values, next_values) = delegate_to_inner!(&vd.table => split); - committed_statements.get_mut(&vd.table).unwrap().push(claim); + committed_statements.get_mut(&vd.table).unwrap().push(CommittedClaim { + point, + tail: Some(uniskip_point.lagrange_weights.clone()), + eq_values, + next_values, + }); } if my_air_final_value != claimed_air_final_value { @@ -230,19 +249,3 @@ pub fn verify_execution( verifier_state.into_raw_proof(), )) } - -fn back_loaded_table_contribution>>( - bus_point: &[EF], - sumcheck_air_point: &[EF], - natural_ordering_point: &[EF], - constraint_eval: EF, -) -> EF { - let n_t = bus_point.len(); - let n_max = sumcheck_air_point.len(); - let suffix_start = n_max - n_t; - assert_eq!(natural_ordering_point.len(), n_t); - let eq_val = - MultilinearPoint(bus_point.to_vec()).eq_poly_outside(&MultilinearPoint(natural_ordering_point.to_vec())); - let k_t: EF = sumcheck_air_point[..suffix_start].iter().copied().product(); - k_t * eq_val * constraint_eval -} diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 54d931db..d865423c 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -6,9 +6,19 @@ use std::{any::TypeId, cmp::Reverse, collections::BTreeMap, mem::transmute}; pub type ColIndex = usize; -/// Each entry: (point, eval, eval at 'shifted-down' column). -pub type CommittedStatements = - BTreeMap, BTreeMap, BTreeMap)>>; +/// One batch of column-opening claims at a shared point: `eq_values[c] = f_c(point)` +/// and `next_values[c]` = the shifted-down column at the same point. 
For +/// univariate-skip openings the weight carries a tensor `tail` on the lowest +/// `log2(tail.len())` row bits: weight = eq(point, ·) ⊗ MLE(tail)(·). +#[derive(Debug, Clone)] +pub struct CommittedClaim { + pub point: MultilinearPoint, + pub tail: Option>, + pub eq_values: BTreeMap, + pub next_values: BTreeMap, +} + +pub type CommittedStatements = BTreeMap>; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BusDirection { diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 7f9bf76e..6ba178b7 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -7,7 +7,7 @@ use lean_prover::{ use lean_vm::*; use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::OnceLock; -use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements}; +use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, UNIVARIATE_SKIP_K, min_stacked_n_vars, total_whir_statements}; use tracing::instrument; use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, XMSS_DIGEST_LEN}; @@ -389,6 +389,45 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree "MAX_AIR_FULL_DEGREE_PLACEHOLDER".to_string(), (ALL_TABLES.iter().map(|t| t.degree_air()).max().unwrap() + 1).to_string(), ); + // Univariate skip (pw13 h1): window {0..2^K−1}, K from sub_protocols::UNIVARIATE_SKIP_K. + replacements.insert("SKIP_K_PLACEHOLDER".to_string(), UNIVARIATE_SKIP_K.to_string()); + let skip_window = 1usize << UNIVARIATE_SKIP_K; + let max_air_degree = ALL_TABLES.iter().map(|t| t.degree_air()).max().unwrap(); + let n_skip_coeffs = (skip_window - 1) * max_air_degree + 1; + // SKIP_Z_POWERS[z][j] = z^j (canonical), so the in-circuit window evals are dot products. 
+ let z_power_rows: Vec = (0..skip_window) + .map(|z| { + let zf = F::from_usize(z); + let mut acc = F::ONE; + let mut row = Vec::with_capacity(n_skip_coeffs); + for _ in 0..n_skip_coeffs { + row.push(acc.as_canonical_u64().to_string()); + acc *= zf; + } + format!("[{}]", row.join(", ")) + }) + .collect(); + replacements.insert( + "SKIP_Z_POWERS_PLACEHOLDER".to_string(), + format!("[{}]", z_power_rows.join(", ")), + ); + // SKIP_LAGRANGE_C[x] = (Π_{y≠x}(x−y))^{-1}: the constant Lagrange denominators, inverted at + // build time so the circuit computes L_x(r0) without any in-circuit inversion. + let lagrange_c: Vec = (0..skip_window) + .map(|x| { + let mut p = F::ONE; + for y in 0..skip_window { + if y != x { + p *= F::from_usize(x) - F::from_usize(y); + } + } + p.inverse().as_canonical_u64().to_string() + }) + .collect(); + replacements.insert( + "SKIP_LAGRANGE_C_PLACEHOLDER".to_string(), + format!("[{}]", lagrange_c.join(", ")), + ); replacements.insert( "N_AIR_COLUMNS_PLACEHOLDER".to_string(), format!("[{}]", n_air_columns.join(", ")), diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 14cc04e4..3f660837 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -29,6 +29,17 @@ ONE_BUSES_ALL_COLS = ONE_BUSES_ALL_COLS_PLACEHOLDER # [[col, ...], _; N_TABLES] — sorted union of cols across all Multiplicity::One buses per table MAX_AIR_FULL_DEGREE = MAX_AIR_FULL_DEGREE_PLACEHOLDER + +# Univariate skip (pw13 h1, Gruen eprint 2024/108 §5-6): round 0 of the batched AIR sumcheck +# binds the SKIP_K lowest row bits of EVERY table via one univariate round over the integer +# window D = {0..2^SKIP_K−1} (node for cube point x is x itself). Must equal +# sub_protocols::UNIVARIATE_SKIP_K and python-verifier SKIP_K. 
+SKIP_K = SKIP_K_PLACEHOLDER +SKIP_WINDOW = 2**SKIP_K +N_SKIP_COEFFS = (SKIP_WINDOW - 1) * (MAX_AIR_FULL_DEGREE - 1) + 1 +SKIP_Z_POWERS = SKIP_Z_POWERS_PLACEHOLDER # [[z^j for j in 0..N_SKIP_COEFFS]; SKIP_WINDOW] (canonical) +SKIP_LAGRANGE_C = SKIP_LAGRANGE_C_PLACEHOLDER # [(Π_{y≠x}(x−y))^{-1}; SKIP_WINDOW] (canonical) + N_AIR_COLUMNS = N_AIR_COLUMNS_PLACEHOLDER # [_; N_TABLES] N_AIR_SHIFT_COLUMNS = N_AIR_SHIFT_COLUMNS_PLACEHOLDER # [_; N_TABLES] — by convention, shift column j of table t is column j AIR_ALPHA_OFFSETS = AIR_ALPHA_OFFSETS_PLACEHOLDER # [_; N_TABLES], # AIR_ALPHA_OFFSETS[t] = sum(N_AIR_CONSTRAINTS[k] for k in range(t)) @@ -295,12 +306,15 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): # END OF LOGUP - # VERIFY BUS AND AIR — back-loaded batched sumcheck + # VERIFY BUS AND AIR — univariate-skip round 0 + front-loaded batched sumcheck (pw13 h1) fs, air_alpha = fs_sample_ef(fs) air_alpha_powers = powers_const(air_alpha, TOTAL_NUM_AIR_CONSTRAINTS) - initial_sum: Mut = ZERO_VEC_PTR + n_max = log_max_table_height + + # Per-table claims s_t, front-loaded: claimed = Σ_t w_t·s_t with w_t = 2^(n_max−n_t). 
+ claimed_weighted_sum: Mut = ZERO_VEC_PTR for table_index in unroll(0, N_TABLES): alpha_offset = AIR_ALPHA_OFFSETS[table_index] bus_numerator_value = bus_numerators_values + table_index * DIM @@ -317,12 +331,64 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): sub_extension_ret(logup_gamma, bus_denominator_value), ), ) - initial_sum = add_extension_ret(initial_sum, bus_final_value) + w_t = two_exp(n_max - table_log_heights[table_index]) + claimed_weighted_sum = add_extension_ret( + claimed_weighted_sum, mul_base_extension_ret(w_t, bus_final_value) + ) - n_max = log_max_table_height - # Batched AIR sumcheck: - fs, all_challenges, batched_air_final_value = sumcheck_verify_reversed(fs, n_max, initial_sum, MAX_AIR_FULL_DEGREE) + # Round 0 (skip): the prover sends ONE combined polynomial P(X) = Σ_t w_t·v'_t(X) in + # coefficient form (read BEFORE sampling r0). Identity: Σ_{z∈D} ê(z)·P(z) == Σ_t w_t·s_t, + # where ê = eq(eq_top, bits(·)) on D and eq_top = the last SKIP_K coords of the gkr point + # (shared by all tables: every table's eq factor is a suffix of the gkr point). + fs, skip_coeffs = fs_receive_ef_inlined(fs, N_SKIP_COEFFS) + eq_top = point_gkr + (n_vars_logup_gkr - SKIP_K) * DIM + e_hat = compute_eq_mle_extension(eq_top, SKIP_K) + window_evals = Array(SKIP_WINDOW * DIM) + for z in unroll(0, SKIP_WINDOW): + z_pows = Array(N_SKIP_COEFFS) + for j in unroll(0, N_SKIP_COEFFS): + z_pows[j] = SKIP_Z_POWERS[z][j] + dot_product_be(z_pows, skip_coeffs, window_evals + z * DIM, N_SKIP_COEFFS) + window_sum = dot_product_ee_ret(e_hat, window_evals, SKIP_WINDOW) + copy_ef(window_sum, claimed_weighted_sum) # assert the round-0 weighted window identity + + fs, skip_r0 = fs_sample_ef(fs) + + # Lagrange weights L_x(r0) = c_x · Π_{y≠x}(r0 − y), c_x compile-time inverse constants + # (no in-circuit inversion). Also the tail of every AIR opening claim. 
+ skip_diffs = Array(SKIP_WINDOW) + for x in unroll(0, SKIP_WINDOW): + skip_diffs[x] = sub_extension_base_ret(skip_r0, x) + skip_pre = Array(SKIP_WINDOW) + skip_suf = Array(SKIP_WINDOW) + skip_pre[0] = ONE_EF_PTR + skip_suf[SKIP_WINDOW - 1] = ONE_EF_PTR + for x in unroll(1, SKIP_WINDOW): + skip_pre[x] = mul_extension_ret(skip_pre[x - 1], skip_diffs[x - 1]) + rev = SKIP_WINDOW - 1 - x + skip_suf[rev] = mul_extension_ret(skip_suf[rev + 1], skip_diffs[rev + 1]) + lagrange_weights = Array(SKIP_WINDOW * DIM) + for x in unroll(0, SKIP_WINDOW): + mul_extension( + mul_base_extension_ret(SKIP_LAGRANGE_C[x], skip_pre[x]), + skip_suf[x], + lagrange_weights + x * DIM, + ) + e_hat_r0 = dot_product_ee_ret(e_hat, lagrange_weights, SKIP_WINDOW) + # target = ê(r0)·P(r0) + r0_powers = powers_const(skip_r0, N_SKIP_COEFFS) + p_at_r0 = dot_product_ee_ret(skip_coeffs, r0_powers, N_SKIP_COEFFS) + target = mul_extension_ret(e_hat_r0, p_at_r0) + + # Remaining linear rounds (unchanged round mechanics). sumcheck_verify_reversed stores the + # round-r challenge at slot n_linear−1−r, so `all_challenges` is in NATURAL slot order and + # the python spec's natural_prefix_t = reversed(linear_challenges[..log_h−K]) is the + # CONTIGUOUS slice all_challenges[n_linear−(log_h−K) ..] of length log_h−K. 
+ n_linear = n_max - SKIP_K + fs, all_challenges, batched_air_final_value = sumcheck_verify_reversed(fs, n_linear, target, MAX_AIR_FULL_DEGREE) + + pcs_natural_prefixes = Array(N_TABLES) check_sum: Mut = ZERO_VEC_PTR for table_index in unroll(0, N_TABLES): log_n_rows = table_log_heights[table_index] @@ -337,11 +403,12 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): ) bus_point = pcs_inner_points[table_index] - eq_val = poly_eq_extension_dynamic_ret(bus_point, all_challenges, log_n_rows) - - k_t = product_first_n(all_challenges + log_n_rows * DIM, n_max - log_n_rows) + natural_prefix = all_challenges + (n_linear - (log_n_rows - SKIP_K)) * DIM + pcs_natural_prefixes[table_index] = natural_prefix + # Final identity term: ê(r0) · eq(bus_point[..log_h−K], natural_prefix) · C_t(col_evals) + eq_val = poly_eq_extension_dynamic_ret(bus_point, natural_prefix, log_n_rows - SKIP_K) - contribution = mul_extension_ret(k_t, mul_extension_ret(eq_val, air_constraints_eval)) + contribution = mul_extension_ret(e_hat_r0, mul_extension_ret(eq_val, air_constraints_eval)) check_sum = add_extension_ret(check_sum, contribution) # AIR block (i=1): all flat cols 0..n_flat_columns populated; shifts 0..n_shift_columns populated. 
@@ -523,13 +590,22 @@ def recursion(inner_public_memory, initial_fiat_shamir_cap): curr_randomness += DIM eval_weights = add_extension_ret(eval_weights, mul_extension_ret(logup_acc, eq_factor_logup)) - # AIR + # AIR (uniskip claims: point = natural_prefix, tail = Lagrange weights L_x(r0) on the + # lowest SKIP_K inner coords — weight = eq(prefix) ⊗ MLE(tail), next = shifted variant) + n_air_prefix = log_n_rows - SKIP_K + natural_prefix = pcs_natural_prefixes[table_index] + eq_low_table = compute_eq_mle_extension(inner_folding + n_air_prefix * DIM, SKIP_K) + tail_mle = dot_product_ee_ret(lagrange_weights, eq_low_table, SKIP_WINDOW) if n_shift_columns != 0: - next_factor = next_mle(all_challenges, inner_folding, log_n_rows) + next_factor = next_mle_with_tail( + natural_prefix, lagrange_weights, inner_folding, eq_low_table, n_air_prefix, SKIP_K + ) shift_sum = dot_product_ee_ret(curr_randomness, column_prefixes, n_shift_columns) eval_weights = add_extension_ret(eval_weights, mul_extension_ret(shift_sum, next_factor)) curr_randomness += n_shift_columns * DIM - eq_factor_air = poly_eq_extension_dynamic_ret(all_challenges, inner_folding, log_n_rows) + eq_factor_air = mul_extension_ret( + poly_eq_extension_dynamic_ret(natural_prefix, inner_folding, n_air_prefix), tail_mle + ) air_sum = dot_product_ee_ret(curr_randomness, column_prefixes, N_AIR_COLUMNS[table_index]) eval_weights = add_extension_ret(eval_weights, mul_extension_ret(air_sum, eq_factor_air)) curr_randomness += N_AIR_COLUMNS[table_index] * DIM diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index 2aeb7853..fa3e75bc 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -757,6 +757,29 @@ def next_mle_const(x, y, n: Const): return result +def next_mle_with_tail(prefix, tail, y, eq_low_table, n_prefix, k: Const): + # Scalar evaluation of the tensor-tail SHIFTED weight over n_prefix + k variables: + # 
Σ_x tail[x] · next_mle(prefix ++ bits_be(x), y) + # where y has n_prefix + k coords and eq_low_table = compute_eq_mle_extension(y[n_prefix..], k) + # (shared with the eq-statement tail evaluation). Mirrors python-verifier + # `next_mle_with_tail` / backend poly::matrix_next_mle_folded_with_tail via the identity + # (shift-by-one of the tail-seeded eq vector, with the saturating "next" corner): + # eq(prefix, y_pref) · Σ_{x≥1} tail[x−1]·eq_low[x] + # + tail[2^k−1] · ( eq_low[0]·next_mle(prefix, y_pref) + # + Πprefix·Πy_pref·(Πy_low − eq_low[0]) ) + shifted_dot = dot_product_ee_ret(tail, eq_low_table + DIM, 2**k - 1) + term1 = mul_extension_ret(poly_eq_extension_dynamic_ret(prefix, y, n_prefix), shifted_dot) + y_low = y + n_prefix * DIM + prod_corner = mul_extension_ret(product_first_n(prefix, n_prefix), product_first_n(y, n_prefix)) + prod_y_low = product_first_n_const(y_low, k) + corner = add_extension_ret( + mul_extension_ret(eq_low_table, next_mle(prefix, y, n_prefix)), + mul_extension_ret(prod_corner, sub_extension_ret(prod_y_low, eq_low_table)), + ) + res = add_extension_ret(term1, mul_extension_ret(tail + (2**k - 1) * DIM, corner)) + return res + + def _verify_log2_ceil(n, log2: Const): # log2 == ceil(log2(n)) <=> 2^(log2-1) < n <= 2^log2 <=> r := n - 2^(log2-1) - 1 is a (log2-1)-bit value # (in [0, 2^(log2-1))), which checked_decompose_bits_small_value_const checks. 
A wrong log2 makes r too big: diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index eb9ca0a0..9021f4b1 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -46,18 +46,18 @@ pub struct AirSumcheckSession<'a, EF: ExtensionField>, A: Air> where A::ExtraData: AlphaPowers, { - multilinears: MleGroup<'a, EF>, - eq_factor: Vec, // The last element is removed at each round + pub(crate) multilinears: MleGroup<'a, EF>, + pub(crate) eq_factor: Vec, // The last element is removed at each round /// Active element count in the current storage. Always a multiple of /// `2^{P - r}` while r < P (chunk-aligned), then ceil-halves afterward. - current_unpadded_len: usize, - sum: EF, - missing_mul_factor: EF, - computation: A, - extra_data: A::ExtraData, - initial_n_vars: usize, - constraints_eval_at_padding: EF, - rounds_done: usize, + pub(crate) current_unpadded_len: usize, + pub(crate) sum: EF, + pub(crate) missing_mul_factor: EF, + pub(crate) computation: A, + pub(crate) extra_data: A::ExtraData, + pub(crate) initial_n_vars: usize, + pub(crate) constraints_eval_at_padding: EF, + pub(crate) rounds_done: usize, } impl<'a, EF: ExtensionField>, A: Air> AirSumcheckSession<'a, EF, A> @@ -129,7 +129,7 @@ where A: Air + 'static, A::ExtraData: AlphaPowers, { - fn pivot(&self) -> usize { + pub(crate) fn pivot(&self) -> usize { ENDIANNESS_PIVOT_AIR.min(self.initial_n_vars) } @@ -163,7 +163,7 @@ where } } - fn in_phase_1(&self) -> bool { + pub(crate) fn in_phase_1(&self) -> bool { let w = packing_log_width::(); // (a) the variable being bound sits above the lane bits, and // (b) `SplitEq` can still run in packed mode (`n - r - 1 > w`). 
@@ -180,7 +180,7 @@ where /// `eq_factor` permuted to match our storage convention: entries in /// `[0, n-P)` unchanged, entries in `[n-P, len)` reversed - fn permuted_alphas(&self, len: usize) -> Vec { + pub(crate) fn permuted_alphas(&self, len: usize) -> Vec { let head_len = (self.initial_n_vars - self.pivot()).min(len); let base = &self.eq_factor[..len]; let mut out = Vec::with_capacity(len); diff --git a/crates/sub_protocols/src/air_sumcheck_skip.rs b/crates/sub_protocols/src/air_sumcheck_skip.rs new file mode 100644 index 00000000..dd1fec3d --- /dev/null +++ b/crates/sub_protocols/src/air_sumcheck_skip.rs @@ -0,0 +1,1195 @@ +use std::fmt::Debug; +use std::ops::{Add, Mul}; + +use backend::*; + +use crate::{AirSumcheckSession, OuterSumcheckSession}; + +// --------------------------------------------------------------------------- +// Front-loaded batched orchestration (see plan_spec.md "Protocol spec") +// +// All tables join at round 0. A table with `n_t` variables is embedded in the +// `n_max`-variable combined sum as a function of its FIRST `n_t` variables, +// constant in the trailing `n_max − n_t` ones, so its claim `s_t` enters the +// combined target with the static weight `w_t = 2^{n_max − n_t}`: +// +// target₀ = Σ_t w_t · s_t . +// +// Round 0 is the univariate skip: ONE combined coefficient vector +// P(X) = Σ_t w_t · v'_t(X) of degree ≤ (2^K − 1)·d_max is sent IN FULL — the +// verifier's round-0 identity is the weighted window sum +// +// Σ_{z ∈ D} ê(z) · P(z) == Σ_t w_t · s_t , +// +// not `h(0) + h(1) = target`, so no coefficient can be elided (this full-vector +// convention is the executable spec mirrored by the python verifier and the +// recursion circuit). The next target is ê(r0) · P(r0). +// +// Linear rounds r = 0 .. 
n_max − K − 1 then bind one variable each with the +// legacy `h(0) + h(1) = target` identity (c0 elided, reconstructed by the +// verifier): +// • ACTIVE table (r < n_t − K): contributes w_t · eq-expanded bare poly, +// exactly the legacy per-round mechanism. +// • FINISHED table (r ≥ n_t − K): its remaining function is the constant +// c_t = session.sum() over the m = n_max − K − r unbound variables; its +// round polynomial is the constant c_t · 2^{m−1} (folded into coeffs[0]), +// since h(0) + h(1) = c_t · 2^m matches its target share, which halves +// each round and ends at exactly c_t. Hence the final combined value is +// +// target_final = Σ_t s_t^final (NO challenge products), +// +// with s_t^final = ê(r0) · eq(eq_factor_t[..n_t−K], natural_prefix_t) · +// C_t(col_evals_t) and natural_prefix_t = reverse(linear_challenges[..n_t−K]). +// --------------------------------------------------------------------------- + +/// The univariate-skip analogue of the batched AIR sumcheck point: the skip +/// challenge `r0` (binding the K lowest row bits of every table), the Lagrange +/// window weights `L_x(r0)` (the tensor tail of the WHIR opening weights), and +/// the linear-round challenges in round order. +#[derive(Debug, Clone)] +pub struct UniskipAirPoint { + pub r0: EF, + /// `L_x(r0)` for `x ∈ 0..2^K`, in row-bit (window-node) order. + pub lagrange_weights: Vec, + /// Challenges of the linear rounds, in round order (round r binds, for each + /// table still active, its highest remaining row bit). + pub linear_challenges: Vec, +} + +impl UniskipAirPoint { + pub fn k(&self) -> usize { + log2_strict_usize(self.lagrange_weights.len()) + } +} + +/// The "natural ordering" opening point of a table's remaining (non-skipped) +/// variables: the reverse of the first `log_n_rows − K` linear challenges +/// (round r binds eq coordinate `n_t − K − 1 − r`, so index-wise pairing with +/// `eq_factor_t[..n_t − K]` requires the reversed prefix). 
+pub fn natural_prefix_for_session(point: &UniskipAirPoint, log_n_rows: usize) -> Vec { + point.linear_challenges[..log_n_rows - point.k()] + .iter() + .rev() + .copied() + .collect() +} + +/// Front-loaded batched AIR sumcheck with a univariate skip round. +/// `sessions` must all be fresh (`rounds_done == 0`) and share the same last-`k` +/// eq coordinates (suffixes of one gkr point). Returns the binding point. +pub fn prove_batched_air_sumcheck_uniskip<'a, EF: ExtensionField>>( + prover_state: &mut impl FSProver, + sessions: &mut [Box + 'a>], + k: usize, +) -> UniskipAirPoint { + let n_max = sessions.iter().map(|s| s.initial_n_vars()).max().unwrap(); + let max_full_degree = sessions.iter().map(|s| s.bare_degree() + 1).max().unwrap(); + let d_max = sessions.iter().map(|s| s.bare_degree()).max().unwrap(); + let n_skip_coeffs = ((1usize << k) - 1) * d_max + 1; + + let weights: Vec = sessions + .iter() + .map(|s| EF::from_usize(1 << (n_max - s.initial_n_vars()))) + .collect(); + + // The skipped eq coordinates are shared: every session's eq factor is a + // suffix of the same gkr point. + let eq_top = sessions[0].skip_eq_top(k); + for s in sessions.iter().skip(1) { + debug_assert_eq!(s.skip_eq_top(k), eq_top, "sessions must share the skip eq coordinates"); + } + + // Round 0 (skip): combined polynomial P = Σ_t w_t · v'_t, full coefficients. 
+ let skip_polys: Vec> = sessions.iter_mut().map(|s| s.compute_skip_poly(k)).collect(); + let mut combined_skip = EF::zero_vec(n_skip_coeffs); + for (poly, &w_t) in skip_polys.iter().zip(&weights) { + debug_assert!(poly.coeffs.len() <= n_skip_coeffs); + for (acc, &c) in combined_skip.iter_mut().zip(&poly.coeffs) { + *acc += w_t * c; + } + } + prover_state.add_extension_scalars(&combined_skip); + let r0: EF = prover_state.sample(); + + let lagrange_weights = lagrange_weights_at::, EF>(k, r0); + let e_hat_r0 = e_hat_at(&eq_top, r0); + for (session, poly) in sessions.iter_mut().zip(&skip_polys) { + session.process_skip_challenge(k, r0, &lagrange_weights, e_hat_r0, poly); + } + + // Invariant: the next target is ê(r0)·P(r0) = Σ_t w_t · session.sum(). + let mut running_target = e_hat_r0 * DensePolynomial::new(combined_skip).evaluate(r0); + debug_assert_eq!( + running_target, + sessions + .iter() + .zip(&weights) + .map(|(s, &w)| w * s.sum()) + .fold(EF::ZERO, |a, b| a + b) + ); + + // Linear rounds. + let n_linear = n_max - k; + let mut linear_challenges = Vec::with_capacity(n_linear); + for r in 0..n_linear { + let mut combined_coeffs = EF::zero_vec(max_full_degree + 1); + let mut bare_polys: Vec>> = vec![None; sessions.len()]; + + for (idx, session) in sessions.iter_mut().enumerate() { + let n_own = session.initial_n_vars() - k; + if r < n_own { + let bare = session.compute_bare_round_poly(); + let full = expand_bare_to_full(&bare.coeffs, session.eq_alpha()); + for (acc, &c) in combined_coeffs.iter_mut().zip(&full) { + *acc += weights[idx] * c; + } + bare_polys[idx] = Some(bare); + } else { + // Finished: constant c_t over the m = n_linear − r remaining + // variables; round poly = c_t · 2^{m−1} (constant in X). + combined_coeffs[0] += session.sum() * EF::from_usize(1 << (n_linear - r - 1)); + } + } + + // h(0) + h(1) = 2·c0 + Σ_{i≥1} c_i must equal the running target. 
+ debug_assert_eq!( + combined_coeffs[0].double() + combined_coeffs[1..].iter().copied().sum::(), + running_target, + "front-loading bookkeeping broke at linear round {r}" + ); + + prover_state.add_sumcheck_polynomial(&combined_coeffs, None); + let challenge = prover_state.sample(); + linear_challenges.push(challenge); + + for (idx, session) in sessions.iter_mut().enumerate() { + if let Some(bare) = &bare_polys[idx] { + session.process_challenge(challenge, bare); + } + } + running_target = DensePolynomial::new(combined_coeffs).evaluate(challenge); + } + + UniskipAirPoint { + r0, + lagrange_weights, + linear_challenges, + } +} + +/// Verifier half of [`prove_batched_air_sumcheck_uniskip`]. Table arrays are in +/// session order; `table_sums` are the per-table claims `s_t`; `eq_top` is the +/// shared last-`k` slice of the gkr point; `table_degrees` are the bare AIR +/// degrees (`Air::degree`), `max_full_degree = max(degree_air) + 1`. +/// Returns the binding point and the final sumcheck target (the caller checks +/// it against `Σ_t ê(r0) · eq(eq_factor_t[..n_t−k], natural_prefix_t) · C_t`). +#[allow(clippy::too_many_arguments)] +pub fn verify_batched_air_sumcheck_uniskip>>( + verifier_state: &mut impl FSVerifier, + k: usize, + table_n_vars: &[usize], + table_degrees: &[usize], + table_sums: &[EF], + eq_top: &[EF], + max_full_degree: usize, +) -> Result<(UniskipAirPoint, EF), ProofError> { + assert_eq!(table_n_vars.len(), table_sums.len()); + assert_eq!(table_n_vars.len(), table_degrees.len()); + assert_eq!(eq_top.len(), k); + let n_max = *table_n_vars.iter().max().unwrap(); + let d_max = *table_degrees.iter().max().unwrap(); + let n_skip_coeffs = ((1usize << k) - 1) * d_max + 1; + + let coeffs = verifier_state.next_extension_scalars_vec(n_skip_coeffs)?; + let skip_poly = DensePolynomial::new(coeffs); + + // Round-0 identity: Σ_{z ∈ D} ê(z) · P(z) == Σ_t w_t · s_t. 
+ let e_hat_window = e_hat_on_window(eq_top); + let window_sum = (0..1usize << k) + .map(|z| e_hat_window[z] * skip_poly.evaluate(EF::from_usize(z))) + .fold(EF::ZERO, |a, b| a + b); + let claimed: EF = table_n_vars + .iter() + .zip(table_sums) + .map(|(&n_t, &s_t)| EF::from_usize(1 << (n_max - n_t)) * s_t) + .fold(EF::ZERO, |a, b| a + b); + if window_sum != claimed { + return Err(ProofError::InvalidProof); + } + + let r0: EF = verifier_state.sample(); + let target = e_hat_at(eq_top, r0) * skip_poly.evaluate(r0); + + let Evaluation { point, value } = sumcheck_verify(verifier_state, n_max - k, max_full_degree, target, None)?; + + Ok(( + UniskipAirPoint { + r0, + lagrange_weights: lagrange_weights_at::, EF>(k, r0), + linear_challenges: point.0, + }, + value, + )) +} + +// Univariate skip round for the batched AIR sumcheck (Gruen, eprint 2024/108 §5-6). +// +// The first `K` sumcheck rounds — which bind the K lowest row-index bits of every +// table — are replaced by ONE univariate round over the integer window +// D = {0, …, 2^K − 1} (see `backend::univariate_skip` for the domain convention: +// the window node for cube point `x` is the integer `x` itself, in row-bit order, +// so the committed column values at the 2^K rows of a block ARE the polynomial +// values on the window). +// +// Per session (table) the prover computes +// +// v'_t(z) = Σ_{rest} eqᵣ(rest) · C_t( c̃ols(z, rest) ) + padding term, +// +// where `c̃ols(z, ·)` is the degree-(2^K − 1) univariate extension of the 2^K +// window values of each column, and eqᵣ is the eq weight over the remaining +// n_t − K variables (the LAST K entries of the session's `eq_factor` are +// excluded — they form the kernel ê known to the verifier: +// ê(z) interpolates eq(eq_factor[n−K..], bits(x)) over the window). +// +// All constraint evaluations stay in the BASE field (packed SIMD): the window +// values are committed base-field data and the extended-target values are +// base-field Lagrange combinations of them. 
This replaces rounds 1..K−1 of the +// legacy schedule, whose folded columns are extension-field (~5× the packed +// base-field throughput on this workload). +// +// Identities (see plan_spec.md): +// round 0: Σ_{z ∈ D} ê(z) · v'_t(z) == s_t (the session's claim) +// challenge: sum ← ê(r0) · v'_t(r0), missing_mul_factor ← ê(r0) +// after which the session state is EXACTLY what K legacy rounds would have +// produced (same storage layout, same eq bookkeeping), so the remaining rounds +// run unchanged. +// +// Storage mapping (chunk-bit-reversed columns, see air_sumcheck.rs:10-32): +// rounds 0..K−1 fold storage bits [pivot−K, pivot) inside each 2^pivot chunk, +// so the window value for cube point `x` of rest-position `j = (chunk, o)` +// lives at storage block `bitreverse_K(x)` of that chunk: +// storage_index(x, j) = (chunk << pivot) | (rev_K(x) << (pivot−K)) | o. +// Collapsing the K block bits (in any z-combination) preserves the legacy +// post-round-K layout: chunks of size 2^{pivot−K} in the same `o` order. + +/// Compile-time skip width. The kernels below take `k` as a runtime parameter +/// (so tests can sweep 3..=5); the orchestration layer (T3) uses this constant. +pub const UNIVARIATE_SKIP_K: usize = 4; +pub const SKIP_DOMAIN: usize = 1 << UNIVARIATE_SKIP_K; + +/// Storage block (within a 2^pivot chunk) holding the window values of cube +/// point `x`: the K-bit bit-reversal of `x`. +#[inline] +pub const fn skip_block_of_x(x: usize, k: usize) -> usize { + let mut b = 0; + let mut i = 0; + while i < k { + b |= ((x >> i) & 1) << (k - 1 - i); + i += 1; + } + b +} + +pub trait SkipSession>>: OuterSumcheckSession { + /// The univariate restriction `v'_t` of the session's claim to the skip + /// window (eq over the REST variables only; excludes ê), in coefficient + /// form. Requires a fresh session (`rounds_done == 0`). 
+ fn compute_skip_poly(&mut self, k: usize) -> DensePolynomial { + self.compute_skip_poly_forced(k, false) + } + + /// Test/bench hook: `force_lagrange = true` selects the reference + /// Lagrange-dot extension kernels instead of the default finite-difference + /// ones. Both produce bit-identical polynomials (exact field arithmetic); + /// the toggle exists so the equality tests and the h4 timing gate can + /// compare them. Not part of the protocol surface. + #[doc(hidden)] + fn compute_skip_poly_forced(&mut self, k: usize, force_lagrange: bool) -> DensePolynomial; + + /// Bind the K skipped variables to `r0`: fold every column `2^k → 1` with + /// the Lagrange weights `L_x(r0)`, and fast-forward the session state to + /// `rounds_done = k` (sum, eq factor, missing_mul_factor, padding count). + fn process_skip_challenge( + &mut self, + k: usize, + r0: EF, + lagrange_at_r0: &[EF], + e_hat_r0: EF, + skip_poly: &DensePolynomial, + ); + + /// The eq coordinates carried by the skipped variables (shared ê input): + /// the last `k` entries of the session's eq factor, in `eval_eq` bit order + /// (entry 0 ↔ the MSB of the cube point / window index). + fn skip_eq_top(&self, k: usize) -> Vec; +} + +impl<'a, EF, A> SkipSession for AirSumcheckSession<'a, EF, A> +where + EF: ExtensionField>, + A: Air + Debug + 'static, + A::ExtraData: AlphaPowers + AlphaPowersMut + Debug, +{ + fn compute_skip_poly_forced(&mut self, k: usize, force_lagrange: bool) -> DensePolynomial { + assert_eq!(self.rounds_done, 0, "skip round must come first"); + assert_eq!(self.missing_mul_factor, EF::ONE); + let n = self.initial_n_vars; + let w = packing_log_width::(); + let pivot = self.pivot(); + assert!(k >= 1 && k < n && k <= pivot); + // current_unpadded_len is chunk-aligned at round 0, so whole blocks are + // either fully active or fully padding. 
+ assert_eq!(self.current_unpadded_len % (1 << pivot.min(n)), 0); + + let degree = self.computation.degree(); + let nodes = skip_all_nodes::>(k, degree); + let n_nodes = nodes.len(); + let window = 1usize << k; + + // eq weights over the REST variables, in storage order (same machinery + // as the legacy per-round SplitEq, with the K skipped coordinates and + // no fold coordinate excluded). + let rest_alphas = self.permuted_alphas(n - k); + let split_eq = SplitEq::new(&rest_alphas); + + // Lagrange extension matrix: window -> extended targets (reference + // path only; the default finite-difference path needs no per-node + // coefficients — the nodes are consecutive integers). + let lagrange_targets = if force_lagrange { + lagrange_coeffs_for_targets::>(k, &nodes[window..]) + } else { + Vec::new() + }; + + let active_rest = self.current_unpadded_len >> k; + let total_rest = 1usize << (n - k); + + let raw: Vec = if n - k > w { + let group = self.multilinears.by_ref(); + let cols = group + .as_packed_base() + .expect("skip round expects base-packed columns") + .clone(); + debug_assert!(pivot - k >= w); + match (self.computation.low_degree_air(), force_lagrange) { + (Some((low_degree, low_n_constraints)), false) => compute_skip_evals_degree_split::( + &cols, + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest >> w, + &nodes, + low_degree, + low_n_constraints, + ), + (Some((low_degree, low_n_constraints)), true) => compute_skip_evals_degree_split_lagrange::( + &cols, + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest >> w, + &nodes, + &lagrange_targets, + low_degree, + low_n_constraints, + ), + (None, false) => compute_skip_evals_generic::( + &cols, + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest >> w, + &nodes, + ), + (None, true) => compute_skip_evals_generic_lagrange::( + &cols, + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest >> w, + 
&nodes, + &lagrange_targets, + ), + } + } else if force_lagrange { + compute_skip_evals_unpacked_lagrange::( + &self.multilinears.by_ref(), + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest, + &nodes, + &lagrange_targets, + ) + } else { + compute_skip_evals_unpacked::( + &self.multilinears.by_ref(), + &self.computation, + &self.extra_data, + &split_eq, + k, + pivot, + active_rest, + &nodes, + ) + }; + + // Padding blocks repeat the (constraint-constant) last row, so their + // contribution is z-independent: C_pad · Σ_{padded rest j} eqᵣ(j). + let padding_contribution = if active_rest < total_rest { + self.constraints_eval_at_padding * mle_of_zeros_then_ones(active_rest, &rest_alphas) + } else { + EF::ZERO + }; + + let values: Vec<(PF, EF)> = nodes + .iter() + .zip(&raw) + .map(|(&node, &v)| (node, v + padding_contribution)) + .collect(); + debug_assert_eq!(values.len(), n_nodes); + DensePolynomial::lagrange_interpolation(&values).unwrap() + } + + fn process_skip_challenge( + &mut self, + k: usize, + r0: EF, + lagrange_at_r0: &[EF], + e_hat_r0: EF, + skip_poly: &DensePolynomial, + ) { + assert_eq!(self.rounds_done, 0); + assert_eq!(lagrange_at_r0.len(), 1 << k); + let n = self.initial_n_vars; + let w = packing_log_width::(); + let pivot = self.pivot(); + + let xb: Vec = (0..1 << k).map(|x| skip_block_of_x(x, k)).collect(); + + if n - k > w { + let group = self.multilinears.by_ref(); + let cols = group.as_packed_base().expect("skip round expects base-packed columns"); + let log_block_packed = pivot - k - w; + let log_chunk_packed = pivot - w; + let block_mask = (1usize << log_block_packed) - 1; + let lw_packed: Vec> = lagrange_at_r0.iter().map(|&l| EFPacking::::from(l)).collect(); + + let mut folded: Vec>> = vec![ArenaVec::new(); cols.len()]; + parallel::par_chunks_mut(&mut folded, 1, |c, slot| { + let src = cols[c]; + let out_len = src.len() >> k; + let mut out: ArenaVec> = unsafe { ArenaVec::uninitialized(out_len) }; + for j_p in 
0..out_len { + let chunk = j_p >> log_block_packed; + let o = j_p & block_mask; + let base = (chunk << log_chunk_packed) | o; + let mut acc = lw_packed[0] * src[base | (xb[0] << log_block_packed)]; + for (x, &lw) in lw_packed.iter().enumerate().skip(1) { + acc += lw * src[base | (xb[x] << log_block_packed)]; + } + out[j_p] = acc; + } + slot[0] = out; + }); + self.multilinears = MleGroup::Owned(MleGroupOwned::ExtensionPacked(folded)); + } else { + let group = self.multilinears.by_ref(); + let unpacked = group.unpack(); + let unpacked_ref = unpacked.by_ref(); + let cols = unpacked_ref.as_base().expect("skip round expects base columns"); + let log_block = pivot - k; + let log_chunk = pivot; + let block_mask = (1usize << log_block) - 1; + + let mut folded: Vec> = vec![ArenaVec::new(); cols.len()]; + parallel::par_chunks_mut(&mut folded, 1, |c, slot| { + let src = cols[c]; + let out_len = src.len() >> k; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(out_len) }; + for j in 0..out_len { + let chunk = j >> log_block; + let o = j & block_mask; + let base = (chunk << log_chunk) | o; + let mut acc = lagrange_at_r0[0] * src[base | (xb[0] << log_block)]; + for (x, &lw) in lagrange_at_r0.iter().enumerate().skip(1) { + acc += lw * src[base | (xb[x] << log_block)]; + } + out[j] = acc; + } + slot[0] = out; + }); + self.multilinears = MleGroup::Owned(MleGroupOwned::Extension(folded)); + } + + self.sum = e_hat_r0 * skip_poly.evaluate(r0); + self.missing_mul_factor = e_hat_r0; + self.rounds_done = k; + let new_eq_len = self.eq_factor.len() - k; + self.eq_factor.truncate(new_eq_len); + debug_assert_eq!(self.current_unpadded_len % (1 << k), 0); + self.current_unpadded_len >>= k; + + // Mirror the legacy phase-1 → phase-2 transition: if K legacy rounds + // would have left packed mode, unpack now. 
+ if self.multilinears.by_ref().is_packed() && !self.in_phase_1() { + self.multilinears = self.multilinears.by_ref().unpack().as_owned_or_clone().into(); + } + } + + fn skip_eq_top(&self, k: usize) -> Vec { + self.eq_factor[self.eq_factor.len() - k..].to_vec() + } +} + +// --------------------------------------------------------------------------- +// Finite-difference extension (h4, iteration 2). +// +// All three interpolation sites of the skip kernels — column values +// (degree ≤ 2^k − 1, sampled on the window), the degree-split cached state +// (same degree, same nodes), and the low-part accumulator (degree ≤ +// low_degree·(2^k − 1), sampled on the first n_low nodes) — are polynomials +// sampled on CONSECUTIVE integer nodes and evaluated at the remaining +// CONSECUTIVE integer nodes (see `skip_all_nodes`: 0, 1, 2, …). Newton forward +// differences evaluate such a polynomial at each next node with `n_rows − 1` +// field ADDS per value row, replacing the `n_rows` Montgomery MULS of a +// Lagrange-coefficient dot. Field adds are exact, so the results are +// BIT-IDENTICAL to the Lagrange path (same unique polynomial, same field +// elements) — pinned by `test_fd_extension_matches_lagrange` and by the entire +// iter-1 test suite, which the FD path must satisfy unchanged. +// +// State convention (right-edge anchored, verified in `fd_tests`): +// init: for j in 1..n_rows { for i in 0..n_rows−j { row_i ← row_{i+1} − row_i } } +// after which row_{n_rows−1} = value at the LAST sampled node and +// row_{n_rows−1−j} holds the j-th forward difference Δʲ anchored so +// one advance yields the next node; +// advance: for i in 1..n_rows { row_i += row_{i−1} } — the value row at the +// next consecutive node is then row_{n_rows−1}, readable in place. +// Both passes are forward-sequential over the flattened row-major buffer. 
+// --------------------------------------------------------------------------- + +/// In-place right-edge forward-difference triangle over `n_rows` rows of +/// `width` values (`rows[i * width + c]` = value row at the i-th consecutive +/// node). Cost: width · n_rows(n_rows−1)/2 subs, once per group. +#[inline] +fn fd_init_in_place(rows: &mut [T], n_rows: usize, width: usize) { + debug_assert!(rows.len() >= n_rows * width); + for j in 1..n_rows { + for idx in 0..(n_rows - j) * width { + rows[idx] = rows[idx + width] - rows[idx]; + } + } +} + +/// Advances the FD state one node: `width · (n_rows − 1)` adds. The value row +/// at the new node is `rows[(n_rows − 1) * width ..]`. +#[inline] +fn fd_advance(rows: &mut [T], n_rows: usize, width: usize) { + debug_assert!(rows.len() >= n_rows * width); + for idx in width..n_rows * width { + let prev = rows[idx - width]; + rows[idx] += prev; + } +} + +/// Gathers, for one packed rest-position `j_p`, the `2^k` window values of all +/// columns into `win` (layout `win[x * n_cols + c]`, contiguous per window +/// node so constraint evals can borrow `&win[x * n_cols..]` directly). +#[inline(always)] +fn gather_window(cols: &[&[T]], win: &mut [T], j_p: usize, xb: &[usize], log_block: usize, log_chunk: usize) { + let n_cols = cols.len(); + let block_mask = (1usize << log_block) - 1; + let chunk = j_p >> log_block; + let o = j_p & block_mask; + let base = (chunk << log_chunk) | o; + for (c, col) in cols.iter().enumerate() { + for (x, &b) in xb.iter().enumerate() { + win[x * n_cols + c] = col[base | (b << log_block)]; + } + } +} + +/// Assembles the column point at extended node `e` (0-based beyond the window): +/// `point[c] = Σ_x L[e][x] · win[x * n_cols + c]`. 
+#[inline(always)] +fn extend_point>>( + win: &[PFPacking], + point: &mut [PFPacking], + lag_packed: &[PFPacking], + n_cols: usize, +) { + point[..n_cols].fill(PFPacking::::ZERO); + for (x, &lw) in lag_packed.iter().enumerate() { + let row = &win[x * n_cols..(x + 1) * n_cols]; + for (p, &v) in point[..n_cols].iter_mut().zip(row) { + *p += v * lw; + } + } +} + +/// Default generic kernel: finite-difference extension (h4). The gathered +/// window buffer doubles as the FD state after the window evals — zero extra +/// per-thread memory vs the Lagrange path, and the value row at each extended +/// node is read in place as the constraint-eval point. +#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_generic( + cols: &[&[PFPacking]], + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_packed: usize, + nodes: &[PF], +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, +{ + let w = packing_log_width::(); + let n_cols = cols.len(); + let window = 1usize << k; + let n_nodes = nodes.len(); + let log_block = pivot - k - w; + let log_chunk = pivot - w; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + + let acc = parallel::map_reduce_with_state( + active_packed, + || vec![PFPacking::::ZERO; window * n_cols], // win, then FD state + || vec![EFPacking::::ZERO; n_nodes], + |win, acc, j_p| { + let partial_eq = split_eq.get_packed(j_p); + gather_window(cols, win, j_p, &xb, log_block, log_chunk); + for x in 0..window { + let v = computation.eval_packed_base(&win[x * n_cols..(x + 1) * n_cols], extra_data); + acc[x] += v * partial_eq; + } + fd_init_in_place(win, window, n_cols); + for node_acc in acc[window..n_nodes].iter_mut() { + fd_advance(win, window, n_cols); + let v = computation.eval_packed_base(&win[(window - 1) * n_cols..window * n_cols], extra_data); + *node_acc += v * partial_eq; + } + }, + |mut a, b| { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + 
} + a + }, + ); + + acc.into_iter() + .map(|s| EFPacking::::to_ext_iter([s]).sum::()) + .collect() +} + +/// Reference generic kernel: Lagrange-coefficient dots (iter-1 path). Kept for +/// the h4 timing gate and the bit-identity test. +#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_generic_lagrange( + cols: &[&[PFPacking]], + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_packed: usize, + nodes: &[PF], + lagrange_targets: &[Vec>], +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, +{ + let w = packing_log_width::(); + let n_cols = cols.len(); + let window = 1usize << k; + let n_nodes = nodes.len(); + let log_block = pivot - k - w; + let log_chunk = pivot - w; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + // Per extended node, the 2^k Lagrange coefficients lifted to packed form. + let lag_packed: Vec>> = lagrange_targets + .iter() + .map(|row| row.iter().map(|&l| PFPacking::::from(l)).collect()) + .collect(); + + let acc = parallel::map_reduce_with_state( + active_packed, + || { + ( + vec![PFPacking::::ZERO; window * n_cols], // win + vec![PFPacking::::ZERO; n_cols], // point + ) + }, + || vec![EFPacking::::ZERO; n_nodes], + |(win, point), acc, j_p| { + let partial_eq = split_eq.get_packed(j_p); + gather_window(cols, win, j_p, &xb, log_block, log_chunk); + for x in 0..window { + let v = computation.eval_packed_base(&win[x * n_cols..(x + 1) * n_cols], extra_data); + acc[x] += v * partial_eq; + } + for (e, lag) in lag_packed.iter().enumerate() { + extend_point::(win, point, lag, n_cols); + let v = computation.eval_packed_base(point, extra_data); + acc[window + e] += v * partial_eq; + } + }, + |mut a, b| { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a + }, + ); + + acc.into_iter() + .map(|s| EFPacking::::to_ext_iter([s]).sum::()) + .collect() +} + +/// Default degree-split kernel: finite-difference extension (h4) at all 
three +/// interpolation sites — column values (window FD on `win`, in place), the +/// skipped low-block's cached state (degree ≤ 2^k − 1 in z: affine ops on +/// degree-(2^k − 1) column extensions, so its own window-anchored FD cascade +/// advanced in lockstep), and the low-part accumulator (degree ≤ +/// low_degree·(2^k − 1), FD over its first n_low values once captured). +#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_degree_split( + cols: &[&[PFPacking]], + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_packed: usize, + nodes: &[PF], + low_degree: usize, + low_n_constraints: usize, +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, + EFPacking: PrimeCharacteristicRing + + Mul, Output = EFPacking> + + Add, Output = EFPacking>, +{ + let w = packing_log_width::(); + let n_cols = cols.len(); + let n_flat = computation.n_columns(); + let window = 1usize << k; + let n_nodes = nodes.len(); + let n_low = low_degree * (window - 1) + 1; + debug_assert!(n_low >= window && n_low <= n_nodes); + let log_block = pivot - k - w; + let log_chunk = pivot - w; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + + let acc = parallel::map_reduce_with_state( + active_packed, + || { + ( + vec![PFPacking::::ZERO; window * n_cols], // win, then column FD state + vec![Vec::>::new(); window], // captured post-block states + Vec::>::new(), // s_fd: flattened state FD cascade + Vec::>::new(), // interpolated-state scratch for the folder + vec![EFPacking::::ZERO; n_low], // low evals + Vec::>::new(), // low_fd: low-part FD cascade (width 1) + ) + }, + || vec![EFPacking::::ZERO; n_nodes], + |(win, states, s_fd, scratch, low_evals, low_fd), acc, j_p| { + let partial_eq = split_eq.get_packed(j_p); + gather_window(cols, win, j_p, &xb, log_block, log_chunk); + + // Full evals at the window nodes; capture the post-block state. 
+ for x in 0..window { + let pt = &win[x * n_cols..(x + 1) * n_cols]; + let mut folder = ConstraintFolderPacked::new(&pt[..n_flat], &pt[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(&mut states[x])); + Air::eval(computation, &mut folder, extra_data); + acc[x] += folder.accumulator * partial_eq; + low_evals[x] = folder.accumulator_low; + states[x] = folder.cached_state.unwrap(); + } + + // FD cascades anchored on the window: columns (in place on `win`) + // and the captured post-block states (advanced in lockstep so the + // anchoring stays consistent; only read beyond n_low). + fd_init_in_place(win, window, n_cols); + let state_len = states[0].len(); + s_fd.clear(); + for st in states.iter() { + debug_assert_eq!(st.len(), state_len); + s_fd.extend_from_slice(st); + } + fd_init_in_place(s_fd, window, state_len); + + // Full evals at the extended nodes that still determine the low part. + for z in window..n_low { + fd_advance(win, window, n_cols); + fd_advance(s_fd, window, state_len); + let pt = &win[(window - 1) * n_cols..window * n_cols]; + let mut folder = ConstraintFolderPacked::new(&pt[..n_flat], &pt[n_flat..], extra_data); + Air::eval(computation, &mut folder, extra_data); + acc[z] += folder.accumulator * partial_eq; + low_evals[z] = folder.accumulator_low; + } + + // Low-part FD cascade over its n_low captured values (width 1). + low_fd.clear(); + low_fd.extend_from_slice(&low_evals[..n_low]); + fd_init_in_place(low_fd, n_low, 1); + + // High-only evals beyond: skip the low block with the FD-advanced + // state, and add the FD-advanced low contribution. 
+ for node_acc in acc[n_low..n_nodes].iter_mut() { + fd_advance(win, window, n_cols); + fd_advance(s_fd, window, state_len); + fd_advance(low_fd, n_low, 1); + let pt = &win[(window - 1) * n_cols..window * n_cols]; + + scratch.clear(); + scratch.extend_from_slice(&s_fd[(window - 1) * state_len..window * state_len]); + + let mut folder = ConstraintFolderPacked::new(&pt[..n_flat], &pt[n_flat..], extra_data); + folder.skip_low = true; + folder.cached_state = Some(std::mem::take(scratch)); + folder.low_ci_count = low_n_constraints; + Air::eval(computation, &mut folder, extra_data); + *scratch = folder.cached_state.unwrap(); + + *node_acc += (folder.accumulator + low_fd[n_low - 1]) * partial_eq; + } + }, + |mut a, b| { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a + }, + ); + + acc.into_iter() + .map(|s| EFPacking::::to_ext_iter([s]).sum::()) + .collect() +} + +/// Reference degree-split kernel: Lagrange-coefficient dots (iter-1 path). +#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_degree_split_lagrange( + cols: &[&[PFPacking]], + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_packed: usize, + nodes: &[PF], + lagrange_targets: &[Vec>], + low_degree: usize, + low_n_constraints: usize, +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, + EFPacking: PrimeCharacteristicRing + + Mul, Output = EFPacking> + + Add, Output = EFPacking>, +{ + let w = packing_log_width::(); + let n_cols = cols.len(); + let n_flat = computation.n_columns(); + let window = 1usize << k; + let n_nodes = nodes.len(); + // The low-degree block's constraints have z-degree ≤ low_degree·(2^k − 1), + // determined by the first `n_low` nodes (full evals there); beyond, the + // block is skipped: its post-state — degree ≤ 2^k − 1 in z without the low + // constraints (affine ops on degree-(2^k − 1) column extensions) — is + // interpolated from the 2^k window states, and the low contribution 
from + // the `n_low` captured values. + let n_low = low_degree * (window - 1) + 1; + debug_assert!(n_low >= window && n_low <= n_nodes); + let log_block = pivot - k - w; + let log_chunk = pivot - w; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + let lag_packed: Vec>> = lagrange_targets + .iter() + .map(|row| row.iter().map(|&l| PFPacking::::from(l)).collect()) + .collect(); + // Lagrange rows for the low part: first n_low nodes -> remaining nodes. + let lag_low_packed: Vec>> = lagrange_basis_evals(&nodes[..n_low], &nodes[n_low..]) + .into_iter() + .map(|row| row.into_iter().map(PFPacking::::from).collect()) + .collect(); + + let acc = parallel::map_reduce_with_state( + active_packed, + || { + ( + vec![PFPacking::::ZERO; window * n_cols], // win + vec![PFPacking::::ZERO; n_cols], // point + vec![Vec::>::new(); window], // captured post-block states + Vec::>::new(), // interpolated state scratch + vec![EFPacking::::ZERO; n_low], // low evals + ) + }, + || vec![EFPacking::::ZERO; n_nodes], + |(win, point, states, scratch, low_evals), acc, j_p| { + let partial_eq = split_eq.get_packed(j_p); + gather_window(cols, win, j_p, &xb, log_block, log_chunk); + + // Full evals at the window nodes; capture the post-block state. + for x in 0..window { + let pt = &win[x * n_cols..(x + 1) * n_cols]; + let mut folder = ConstraintFolderPacked::new(&pt[..n_flat], &pt[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(&mut states[x])); + Air::eval(computation, &mut folder, extra_data); + acc[x] += folder.accumulator * partial_eq; + low_evals[x] = folder.accumulator_low; + states[x] = folder.cached_state.unwrap(); + } + // Full evals at the extended nodes that still determine the low part. 
+ for e in 0..n_low - window { + extend_point::(win, point, &lag_packed[e], n_cols); + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + Air::eval(computation, &mut folder, extra_data); + acc[window + e] += folder.accumulator * partial_eq; + low_evals[window + e] = folder.accumulator_low; + } + // High-only evals beyond: skip the low block with interpolated state, + // and add the Lagrange-extended low contribution. + for z in n_low..n_nodes { + let e = z - window; + extend_point::(win, point, &lag_packed[e], n_cols); + + let lag = &lag_packed[e]; + scratch.clear(); + let state_len = states[0].len(); + for i in 0..state_len { + let mut s = states[0][i] * lag[0]; + for (x, st) in states.iter().enumerate().skip(1) { + s += st[i] * lag[x]; + } + scratch.push(s); + } + + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.skip_low = true; + folder.cached_state = Some(std::mem::take(scratch)); + folder.low_ci_count = low_n_constraints; + Air::eval(computation, &mut folder, extra_data); + *scratch = folder.cached_state.unwrap(); + + let mut low_interpolated = EFPacking::::ZERO; + for (i, &lc) in lag_low_packed[z - n_low].iter().enumerate() { + low_interpolated += low_evals[i] * lc; + } + acc[z] += (folder.accumulator + low_interpolated) * partial_eq; + } + }, + |mut a, b| { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a + }, + ); + + acc.into_iter() + .map(|s| EFPacking::::to_ext_iter([s]).sum::()) + .collect() +} + +/// Scalar fallback for tables too small for the packed kernel +/// (`n − k ≤ packing_log_width`): full evals at every node, no degree split, +/// finite-difference extension. These tables have at most `2^{w + k}` rows — +/// the cost is negligible. 
+#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_unpacked( + group: &MleGroupRef<'_, EF>, + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_rest: usize, + nodes: &[PF], +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, +{ + let window = 1usize << k; + let n_nodes = nodes.len(); + let log_block = pivot - k; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + let block_mask = (1usize << log_block) - 1; + + let unpacked = group.unpack(); + let unpacked_ref = unpacked.by_ref(); + let cols = unpacked_ref.as_base().expect("skip round expects base columns"); + let n_cols = cols.len(); + + let mut acc = vec![EF::ZERO; n_nodes]; + let mut win = vec![PF::::ZERO; window * n_cols]; + for j in 0..active_rest { + let partial_eq = split_eq.get_unpacked(j); + let chunk = j >> log_block; + let o = j & block_mask; + let base = (chunk << pivot) | o; + for (c, col) in cols.iter().enumerate() { + for (x, &b) in xb.iter().enumerate() { + win[x * n_cols + c] = col[base | (b << log_block)]; + } + } + for x in 0..window { + let v = computation.eval_base(&win[x * n_cols..(x + 1) * n_cols], extra_data); + acc[x] += partial_eq * v; + } + fd_init_in_place(&mut win, window, n_cols); + for node_acc in acc[window..n_nodes].iter_mut() { + fd_advance(&mut win, window, n_cols); + let v = computation.eval_base(&win[(window - 1) * n_cols..window * n_cols], extra_data); + *node_acc += partial_eq * v; + } + } + acc +} + +/// Reference scalar fallback: Lagrange-coefficient dots (iter-1 path). 
+#[allow(clippy::too_many_arguments)] +fn compute_skip_evals_unpacked_lagrange( + group: &MleGroupRef<'_, EF>, + computation: &A, + extra_data: &A::ExtraData, + split_eq: &SplitEq, + k: usize, + pivot: usize, + active_rest: usize, + nodes: &[PF], + lagrange_targets: &[Vec>], +) -> Vec +where + EF: ExtensionField>, + A: Air + 'static, + A::ExtraData: AlphaPowers, +{ + let window = 1usize << k; + let n_nodes = nodes.len(); + let log_block = pivot - k; + let xb: Vec = (0..window).map(|x| skip_block_of_x(x, k)).collect(); + let block_mask = (1usize << log_block) - 1; + + let unpacked = group.unpack(); + let unpacked_ref = unpacked.by_ref(); + let cols = unpacked_ref.as_base().expect("skip round expects base columns"); + let n_cols = cols.len(); + + let mut acc = vec![EF::ZERO; n_nodes]; + let mut win = vec![PF::::ZERO; window * n_cols]; + let mut point = vec![PF::::ZERO; n_cols]; + for j in 0..active_rest { + let partial_eq = split_eq.get_unpacked(j); + let chunk = j >> log_block; + let o = j & block_mask; + let base = (chunk << pivot) | o; + for (c, col) in cols.iter().enumerate() { + for (x, &b) in xb.iter().enumerate() { + win[x * n_cols + c] = col[base | (b << log_block)]; + } + } + for x in 0..window { + let v = computation.eval_base(&win[x * n_cols..(x + 1) * n_cols], extra_data); + acc[x] += partial_eq * v; + } + for (e, lag) in lagrange_targets.iter().enumerate() { + point.fill(PF::::ZERO); + for (x, &lw) in lag.iter().enumerate() { + for (p, &v) in point.iter_mut().zip(&win[x * n_cols..(x + 1) * n_cols]) { + *p += v * lw; + } + } + let v = computation.eval_base(&point, extra_data); + acc[window + e] += partial_eq * v; + } + } + acc +} + +#[cfg(test)] +mod fd_tests { + use super::{fd_advance, fd_init_in_place}; + use backend::*; + + fn horner(coeffs: &[KoalaBear], x: usize) -> KoalaBear { + let xf = KoalaBear::from_usize(x); + let mut acc = KoalaBear::ZERO; + for &c in coeffs.iter().rev() { + acc = acc * xf + c; + } + acc + } + + /// The FD recurrence 
reproduces every consecutive node value of a degree-d + /// polynomial exactly (the classic cascade-direction off-by-one trap). + #[test] + fn fd_matches_direct_evaluation() { + for d in [1usize, 2, 3, 7, 15, 31, 45] { + let n_rows = d + 1; + let coeffs: Vec = (0..=d).map(|i| KoalaBear::from_usize(7 * i * i + 3 * i + 1)).collect(); + let mut rows: Vec = (0..n_rows).map(|i| horner(&coeffs, i)).collect(); + fd_init_in_place(&mut rows, n_rows, 1); + for next in n_rows..n_rows + 40 { + fd_advance(&mut rows, n_rows, 1); + assert_eq!(rows[n_rows - 1], horner(&coeffs, next), "d={d}, node={next}"); + } + } + } + + /// width > 1: independent polynomials per lane advance in lockstep. + #[test] + fn fd_matches_direct_evaluation_wide() { + let d = 15usize; + let width = 3usize; + let n_rows = d + 1; + let polys: Vec> = (0..width) + .map(|c| { + (0..=d) + .map(|i| KoalaBear::from_usize(11 * c * c + 5 * i * i * i + i + 2)) + .collect() + }) + .collect(); + let mut rows = vec![KoalaBear::ZERO; n_rows * width]; + for i in 0..n_rows { + for (c, p) in polys.iter().enumerate() { + rows[i * width + c] = horner(p, i); + } + } + fd_init_in_place(&mut rows, n_rows, width); + for next in n_rows..n_rows + 25 { + fd_advance(&mut rows, n_rows, width); + for (c, p) in polys.iter().enumerate() { + assert_eq!(rows[(n_rows - 1) * width + c], horner(p, next), "lane {c}, node {next}"); + } + } + } +} diff --git a/crates/sub_protocols/src/lib.rs b/crates/sub_protocols/src/lib.rs index 356bbba8..78009fcb 100644 --- a/crates/sub_protocols/src/lib.rs +++ b/crates/sub_protocols/src/lib.rs @@ -2,6 +2,9 @@ mod air_sumcheck; pub use air_sumcheck::*; +mod air_sumcheck_skip; +pub use air_sumcheck_skip::*; + mod logup; pub use logup::*; diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 07fcabca..33fce2f5 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -74,25 +74,29 @@ pub fn stacked_pcs_global_statements( 
EF::from_usize(ending_pc), )); } - for (point, eq_values, next_values) in &committed_statements[&table] { - if !next_values.is_empty() { - global_statements.push(SparseStatement::new_next( - stacked_n_vars, - point.clone(), - next_values - .iter() - .map(|(&col_index, &value)| SparseValue::new((offset >> n_vars) + col_index, value)) - .collect(), - )); - } - global_statements.push(SparseStatement::new( - stacked_n_vars, - point.clone(), - eq_values + for claim in &committed_statements[&table] { + if !claim.next_values.is_empty() { + let values = claim + .next_values .iter() .map(|(&col_index, &value)| SparseValue::new((offset >> n_vars) + col_index, value)) - .collect(), - )); + .collect(); + global_statements.push(match &claim.tail { + Some(tail) => { + SparseStatement::new_next_with_tail(stacked_n_vars, claim.point.clone(), tail.clone(), values) + } + None => SparseStatement::new_next(stacked_n_vars, claim.point.clone(), values), + }); + } + let values = claim + .eq_values + .iter() + .map(|(&col_index, &value)| SparseValue::new((offset >> n_vars) + col_index, value)) + .collect(); + global_statements.push(match &claim.tail { + Some(tail) => SparseStatement::new_with_tail(stacked_n_vars, claim.point.clone(), tail.clone(), values), + None => SparseStatement::new(stacked_n_vars, claim.point.clone(), values), + }); } } global_statements diff --git a/crates/sub_protocols/tests/air_sumcheck_skip_kernel.rs b/crates/sub_protocols/tests/air_sumcheck_skip_kernel.rs new file mode 100644 index 00000000..ca879be6 --- /dev/null +++ b/crates/sub_protocols/tests/air_sumcheck_skip_kernel.rs @@ -0,0 +1,435 @@ +//! Correctness + go/no-go timing tests for the univariate skip-round kernel +//! (`SkipSession` on `AirSumcheckSession`). See plan_spec.md (pw13, h1) T2. 
+ +use std::time::Instant; + +use backend::*; +use lean_vm::{ + EF, ExecutionTable, ExtensionOpPrecompile, ExtraDataForBuses, F, LOG_MAX_BUS_WIDTH, Poseidon16Precompile, +}; +use rand::{RngExt, SeedableRng, rngs::StdRng}; +use sub_protocols::{AirSumcheckSession, OuterSumcheckSession, SkipSession, UNIVARIATE_SKIP_K, skip_block_of_x}; + +fn random_cols(rng: &mut StdRng, n_cols: usize, n_rows: usize) -> Vec> { + (0..n_cols) + .map(|_| ArenaVec::from_iter((0..n_rows).map(|_| rng.random()))) + .collect() +} + +/// Pads rows `>= non_padded` with copies of one fixed row (the production +/// padding shape: identical rows, so the padded blocks are constraint-constant). +fn pad_cols(cols: &mut [ArenaVec], non_padded: usize) { + let n_rows = cols[0].len(); + for col in cols.iter_mut() { + let pad_value = col[non_padded - 1]; + for i in non_padded..n_rows { + col[i] = pad_value; + } + } +} + +fn brute_force_sum(air: &A, extra: &A::ExtraData, cols: &[ArenaVec], eq_factor: &[EF]) -> EF +where + A: Air, + A::ExtraData: AlphaPowers, +{ + let eq = eval_eq(eq_factor); + let n_rows = cols[0].len(); + let mut point = vec![F::ZERO; cols.len()]; + let mut sum = EF::ZERO; + for row in 0..n_rows { + for (c, col) in cols.iter().enumerate() { + point[c] = col[row]; + } + sum += eq[row] * SumcheckComputation::::eval_base(air, &point, extra); + } + sum +} + +/// Runs the full per-x / extended-node / aggregate identity battery for one AIR. 
+fn check_skip_identities(air: A, n: usize, k: usize, n_cols_total: usize, non_padded: Option, seed: u64) +where + A: Air + Copy + std::fmt::Debug + Air>, +{ + let mut rng = StdRng::seed_from_u64(seed); + let n_rows = 1usize << n; + let mut cols = random_cols(&mut rng, n_cols_total, n_rows); + if let Some(np) = non_padded { + pad_cols(&mut cols, np); + } + let eq_factor: Vec = (0..n).map(|_| rng.random()).collect(); + let alpha: EF = rng.random(); + let alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints()); + let logup_alphas: Vec = (0..LOG_MAX_BUS_WIDTH).map(|_| rng.random()).collect(); + let extra = ExtraDataForBuses::new(&eval_eq(&logup_alphas), alpha_powers.clone()); + + let sum = brute_force_sum(&air, &extra, &cols, &eq_factor); + + let col_refs: Vec<&[F]> = cols.iter().map(|c| c.as_slice()).collect(); + let packed = MleGroupRef::::Base(col_refs.clone()).pack(); + let extra_session = ExtraDataForBuses::new(&eval_eq(&logup_alphas), alpha_powers.clone()); + let mut session = AirSumcheckSession::new( + packed, + eq_factor.clone(), + sum, + air, + extra_session, + non_padded.unwrap_or(n_rows), + ); + + let skip_poly = session.compute_skip_poly(k); + let degree = SumcheckComputation::::degree(&air); + assert!(skip_poly.coeffs.len() <= ((1 << k) - 1) * degree + 1, "degree bound"); + + let window = 1usize << k; + let rest = 1usize << (n - k); + let eq_rest = eval_eq(&eq_factor[..n - k]); + let e_hat = eval_eq(&eq_factor[n - k..]); + + // (a) per-window-node identity: v'(node_x) == Σ_j eq_rest[j] · C(row = (j << k) | x). 
+ let mut point = vec![F::ZERO; n_cols_total]; + let mut aggregate = EF::ZERO; + for x in 0..window { + let mut direct = EF::ZERO; + for j in 0..rest { + let row = (j << k) | x; + for (c, col) in cols.iter().enumerate() { + point[c] = col[row]; + } + direct += eq_rest[j] * SumcheckComputation::::eval_base(&air, &point, &extra); + } + let from_poly = skip_poly.evaluate(EF::from_usize(x)); + assert_eq!(from_poly, direct, "window node x={x}"); + aggregate += e_hat[x] * from_poly; + } + assert_eq!(aggregate, session.sum(), "Σ ê·v' == claim"); + + // extended-node spot checks (validates the Lagrange extension + degree-split path + // independently of the interpolation): nodes window, window+1, and the last one. + let nodes = skip_all_nodes::(k, degree); + let lags = lagrange_coeffs_for_targets::(k, &nodes[window..]); + for &z_idx in &[window, window + 1, nodes.len() - 1] { + let lag = &lags[z_idx - window]; + let mut direct = EF::ZERO; + for j in 0..rest { + for (c, col) in cols.iter().enumerate() { + let mut v = F::ZERO; + for (x, &l) in lag.iter().enumerate() { + v += col[(j << k) | x] * l; + } + point[c] = v; + } + direct += eq_rest[j] * SumcheckComputation::::eval_base(&air, &point, &extra); + } + assert_eq!( + skip_poly.evaluate(EF::from(nodes[z_idx])), + direct, + "extended node {z_idx}" + ); + } + + // (b) bind r0 and compare the post-skip session against brute force. + let r0: EF = rng.random(); + let lagrange_at_r0 = lagrange_weights_at::(k, r0); + let e_hat_r0 = e_hat_at(&session.skip_eq_top(k), r0); + { + // ê(r0) must interpolate the window values of ê. 
+ let direct: EF = e_hat + .iter() + .zip(&lagrange_at_r0) + .map(|(&e, &l)| e * l) + .fold(EF::ZERO, |a, b| a + b); + assert_eq!(e_hat_r0, direct); + } + session.process_skip_challenge(k, r0, &lagrange_at_r0, e_hat_r0, &skip_poly); + + let folded: Vec> = cols + .iter() + .map(|col| { + (0..rest) + .map(|j| { + lagrange_at_r0 + .iter() + .enumerate() + .map(|(x, &l)| l * col[(j << k) | x]) + .fold(EF::ZERO, |a, b| a + b) + }) + .collect() + }) + .collect(); + + // sum invariant: sum == missing · Σ_j eq(eq_factor[..n−k], j) · C(folded(j)). + let mut point_ef = vec![EF::ZERO; n_cols_total]; + let mut expected_sum = EF::ZERO; + for j in 0..rest { + for (c, fc) in folded.iter().enumerate() { + point_ef[c] = fc[j]; + } + expected_sum += eq_rest[j] * SumcheckComputation::::eval_extension(&air, &point_ef, &extra); + } + assert_eq!(session.sum(), e_hat_r0 * expected_sum, "post-skip sum invariant"); + + // next-round bare poly vs brute force (validates layout + eq bookkeeping): + // bare(z) = missing · Σ_{j_hi} eq(eq_factor[..n−k−1], j_hi) · C(lerp(folded(2j_hi), folded(2j_hi+1), z)). + let bare = session.compute_bare_round_poly(); + let eq_hi = eval_eq(&eq_factor[..n - k - 1]); + for z in 0..=degree { + let z_ef = EF::from_usize(z); + let mut direct = EF::ZERO; + for j_hi in 0..rest / 2 { + for (c, fc) in folded.iter().enumerate() { + let v0 = fc[2 * j_hi]; + let v1 = fc[2 * j_hi + 1]; + point_ef[c] = v0 + (v1 - v0) * z_ef; + } + direct += eq_hi[j_hi] * SumcheckComputation::::eval_extension(&air, &point_ef, &extra); + } + assert_eq!( + bare.evaluate(z_ef), + e_hat_r0 * direct, + "post-skip bare round poly at z={z}" + ); + } + + // (c) full pipeline: run all remaining rounds, then check final_column_evals + // against the direct tensor-weighted MLE evaluation of the original columns. 
+ let mut challenges = Vec::new(); + let mut bare_poly = bare; + loop { + let c: EF = rng.random(); + session.process_challenge(c, &bare_poly); + challenges.push(c); + if challenges.len() == n - k { + break; + } + bare_poly = session.compute_bare_round_poly(); + } + let final_evals = session.final_column_evals(); + for (c, col) in cols.iter().enumerate() { + let mut direct = EF::ZERO; + for (row, &v) in col.iter().enumerate() { + let mut weight = lagrange_at_r0[row & (window - 1)]; + for (r, &ch) in challenges.iter().enumerate() { + let bit = (row >> (k + r)) & 1; + weight *= if bit == 1 { ch } else { EF::ONE - ch }; + } + direct += weight * v; + } + assert_eq!(final_evals[c], direct, "final column eval col={c}"); + } +} + +#[test] +fn test_skip_block_map() { + assert_eq!(skip_block_of_x(0b0011, 4), 0b1100); + assert_eq!(skip_block_of_x(0b0001, 4), 0b1000); + assert_eq!(skip_block_of_x(0b101, 3), 0b101); + assert_eq!(skip_block_of_x(0b110, 3), 0b011); + for k in 1..=5 { + for x in 0..1usize << k { + assert_eq!(skip_block_of_x(skip_block_of_x(x, k), k), x); + } + } +} + +#[test] +fn test_skip_poseidon_degree_split() { + // Degree-split path (low_degree_air = (3, 20)), K=3 and K=4. + let air = Poseidon16Precompile::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 11, 4, n_cols, None, 1); + check_skip_identities(air, 11, 3, n_cols, None, 2); +} + +#[test] +fn test_skip_poseidon_with_bus() { + // BUS=true exercises the eval_bus_virtual constraints in the AIR. + let air = Poseidon16Precompile::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 10, 4, n_cols, None, 3); +} + +#[test] +fn test_skip_execution_with_shift_cols() { + // Generic (non-degree-split) path with shift columns, K=3 and K=4. 
+ let air = ExecutionTable::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 10, 4, n_cols, None, 4); + check_skip_identities(air, 10, 3, n_cols, None, 5); +} + +#[test] +fn test_skip_extension_op() { + let air = ExtensionOpPrecompile::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 9, 4, n_cols, None, 6); +} + +#[test] +fn test_skip_with_padding() { + // n > pivot so that padded_n_rows < 2^n and the analytic padding term is active: + // n = 13, pivot = 12, non_padded = 2^12 → half the blocks are pure padding. + let air = Poseidon16Precompile::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 13, 4, n_cols, Some(1 << 12), 7); +} + +#[test] +fn test_skip_unpacked_fallback() { + // n − K ≤ packing_log_width → scalar fallback path. + // AVX-512: w = 4, so n=8, K=4 (boundary) and n=9, K=5 both fall back. + let air = ExtensionOpPrecompile::; + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + check_skip_identities(air, 8, 4, n_cols, None, 8); + check_skip_identities(air, 9, 5, n_cols, None, 9); +} + +/// h4 (iteration 2): the finite-difference extension must produce the SAME +/// polynomial, bit for bit, as the reference Lagrange-dot extension — exact +/// field arithmetic evaluating the same unique degree-bounded polynomial. +/// Covers all three kernels (generic, degree-split, unpacked fallback), the +/// padding path, and K ∈ {3, 4, 5}. 
+#[test] +fn test_fd_extension_matches_lagrange() { + fn check(air: A, n: usize, k: usize, non_padded: Option, seed: u64) + where + A: Air + Copy + std::fmt::Debug + Air>, + { + let mut rng = StdRng::seed_from_u64(seed); + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + let n_rows = 1usize << n; + let mut cols = random_cols(&mut rng, n_cols, n_rows); + if let Some(np) = non_padded { + pad_cols(&mut cols, np); + } + let eq_factor: Vec = (0..n).map(|_| rng.random()).collect(); + let alpha: EF = rng.random(); + let logup_alphas: Vec = (0..LOG_MAX_BUS_WIDTH).map(|_| rng.random()).collect(); + let col_refs: Vec<&[F]> = cols.iter().map(|c| c.as_slice()).collect(); + let packed = MleGroupRef::::Base(col_refs).pack(); + let extra = ExtraDataForBuses::new( + &eval_eq(&logup_alphas), + alpha.powers().collect_n(Air::n_constraints(&air)), + ); + let mut session = + AirSumcheckSession::new(packed, eq_factor, EF::ZERO, air, extra, non_padded.unwrap_or(n_rows)); + // compute_skip_poly does not advance the session: both paths run on + // identical state. 
+ let lagrange = session.compute_skip_poly_forced(k, true); + let fd = session.compute_skip_poly_forced(k, false); + assert_eq!(fd.coeffs, lagrange.coeffs, "{air:?} n={n} k={k} pad={non_padded:?}"); + } + + // degree-split kernel (poseidon), K ∈ {3, 4, 5}: + check(Poseidon16Precompile::, 11, 3, None, 21); + check(Poseidon16Precompile::, 11, 4, None, 22); + check(Poseidon16Precompile::, 11, 5, None, 23); + check(Poseidon16Precompile::, 10, 4, None, 24); + // generic kernel (execution, with shift cols), K ∈ {3, 4}: + check(ExecutionTable::, 10, 3, None, 25); + check(ExecutionTable::, 10, 4, None, 26); + // extension_op: + check(ExtensionOpPrecompile::, 9, 4, None, 27); + // padding path (n > pivot): + check(Poseidon16Precompile::, 13, 4, Some(1 << 12), 28); + // unpacked fallback (n − k ≤ packing_log_width): + check(ExtensionOpPrecompile::, 8, 4, None, 29); + check(ExtensionOpPrecompile::, 9, 5, None, 30); +} + +/// GO/NO-GO timing gate (plan_spec T2, kill condition a): the skip round must +/// be cheaper than the legacy rounds 0..K−1 it replaces, on production-shaped +/// data. Extended for h4 (plan_spec iteration 2, U1): also times the +/// finite-difference extension against the reference Lagrange-dot extension — +/// FD must win on the combined workload. 
Run with: +/// cargo test --release -p sub_protocols --test air_sumcheck_skip_kernel -- --ignored --nocapture +#[test] +#[ignore] +fn skip_kernel_timing() { + let k = UNIVARIATE_SKIP_K; + + fn time_table(air: A, n: usize, k: usize, label: &str) -> (f64, f64, f64, f64) + where + A: Air + Copy + std::fmt::Debug + Air>, + { + let mut rng = StdRng::seed_from_u64(42); + let n_cols = Air::n_columns(&air) + Air::n_shift_columns(&air); + let n_rows = 1usize << n; + let cols = random_cols(&mut rng, n_cols, n_rows); + let eq_factor: Vec = (0..n).map(|_| rng.random()).collect(); + let alpha: EF = rng.random(); + let logup_alphas: Vec = (0..LOG_MAX_BUS_WIDTH).map(|_| rng.random()).collect(); + let col_refs: Vec<&[F]> = cols.iter().map(|c| c.as_slice()).collect(); + + fn make<'a, A>( + cols_refs: Vec<&'a [F]>, + air: A, + eq_factor: &[EF], + logup_alphas: &[EF], + alpha: EF, + n_rows: usize, + ) -> AirSumcheckSession<'a, EF, A> + where + A: Air + Copy + std::fmt::Debug + Air>, + { + let packed = MleGroupRef::::Base(cols_refs).pack(); + let extra = ExtraDataForBuses::new( + &eval_eq(logup_alphas), + alpha.powers().collect_n(Air::n_constraints(&air)), + ); + AirSumcheckSession::new(packed, eq_factor.to_vec(), EF::ZERO, air, extra, n_rows) + } + + // h4: FD vs Lagrange extension, compute_skip_poly only (the delta is + // confined to it). Lagrange first (warms the gather paths for FD too). + let mut s_skip = make(col_refs.clone(), air, &eq_factor, &logup_alphas, alpha, n_rows); + let t_lag0 = Instant::now(); + let _lag_poly = s_skip.compute_skip_poly_forced(k, true); + let t_lagrange = t_lag0.elapsed().as_secs_f64() * 1e3; + + let t0 = Instant::now(); + let skip_poly = s_skip.compute_skip_poly_forced(k, false); + let t_fd = t0.elapsed().as_secs_f64() * 1e3; + + // Full skip path (FD poly + fold), for the original skip-vs-legacy gate. 
+ let r0: EF = rng.random(); + let lw = lagrange_weights_at::(k, r0); + let e_hat_r0 = e_hat_at(&s_skip.skip_eq_top(k), r0); + let t_fold0 = Instant::now(); + s_skip.process_skip_challenge(k, r0, &lw, e_hat_r0, &skip_poly); + let t_skip = t_fd + t_fold0.elapsed().as_secs_f64() * 1e3; + + // Legacy path: K rounds. + let mut s_legacy = make(col_refs.clone(), air, &eq_factor, &logup_alphas, alpha, n_rows); + let t1 = Instant::now(); + for round in 0..k { + let poly = s_legacy.compute_bare_round_poly(); + s_legacy.process_challenge(EF::from_usize(5 + round), &poly); + } + let t_legacy = t1.elapsed().as_secs_f64() * 1e3; + + println!( + "{label}: skip-poly FD {t_fd:8.2} ms vs Lagrange {t_lagrange:8.2} ms | skip(FD) {t_skip:8.2} ms vs legacy rounds 0..{k} {t_legacy:8.2} ms" + ); + (t_skip, t_legacy, t_fd, t_lagrange) + } + + let (p_skip, p_legacy, p_fd, p_lag) = time_table(Poseidon16Precompile::, 18, k, "poseidon16 2^18x110"); + let (e_skip, e_legacy, e_fd, e_lag) = time_table(ExecutionTable::, 20, k, "execution 2^20x22 "); + + let fd_total = p_fd + e_fd; + let lag_total = p_lag + e_lag; + println!( + "h4 gate: FD {fd_total:.2} ms vs Lagrange {lag_total:.2} ms -> {}", + if fd_total < lag_total { "PASS" } else { "FAIL" } + ); + + let skip_total = p_skip + e_skip; + let legacy_total = p_legacy + e_legacy; + println!( + "combined: skip {skip_total:.2} ms vs legacy {legacy_total:.2} ms -> {}", + if skip_total < legacy_total { "PASS" } else { "FAIL" } + ); + assert!(fd_total < lag_total, "h4 FD-vs-Lagrange timing gate FAILED"); + assert!(skip_total < legacy_total, "GO/NO-GO timing gate FAILED"); +} diff --git a/crates/sub_protocols/tests/air_sumcheck_uniskip_e2e.rs b/crates/sub_protocols/tests/air_sumcheck_uniskip_e2e.rs new file mode 100644 index 00000000..2bd7e2e3 --- /dev/null +++ b/crates/sub_protocols/tests/air_sumcheck_uniskip_e2e.rs @@ -0,0 +1,326 @@ +//! End-to-end prove → verify roundtrip of the front-loaded batched AIR +//! 
sumcheck with a univariate skip round (plan_spec.md "Protocol spec"). +//! +//! Three real tables of unequal heights run through a shared transcript: +//! execution n = 13 (2 shift columns, analytic padding: non_padded = 2^12) +//! poseidon16 n = 10 (degree-split AIR) +//! extension_op n = 9 +//! The per-table claims s_t are brute-forced independently (naive loop over +//! every row with `eval_extension`), and the verifier's final identity +//! +//! final_target == Σ_t ê(r0) · eq(eq_factor_t[..n_t−K], natural_prefix_t) +//! · C_t(col_evals_t) +//! +//! is checked in full — this test pins the convention T5 copies into +//! verify_execution.rs. Adversarial variants tamper one transcript coefficient +//! of the skip round (must be rejected at the round-0 window identity) and one +//! linear-round coefficient (must be rejected by the final identity; with c0 +//! elision the per-round checks absorb wire tampering by construction). + +use backend::*; +use lean_vm::{ALL_TABLES, EF, ExtraDataForBuses, F, LOG_MAX_BUS_WIDTH, Table, delegate_to_inner}; +use rand::{RngExt, SeedableRng, rngs::StdRng}; +use sub_protocols::{ + AirSumcheckSession, SkipSession, UniskipAirPoint, compute_shifted_columns, natural_prefix_for_session, + prove_batched_air_sumcheck_uniskip, verify_batched_air_sumcheck_uniskip, +}; + +struct TableData { + table: Table, + log_n: usize, + non_padded: usize, + /// flat columns followed by shift columns, materialized over all 2^log_n rows + cols: Vec>, +} + +fn build_table_data(table: Table, log_n: usize, non_padded: usize, rng: &mut StdRng) -> TableData { + let n_rows = 1usize << log_n; + let n_flat = table.n_columns(); + let mut flat: Vec> = Vec::with_capacity(n_flat); + for _ in 0..n_flat { + let mut col: Vec = (0..non_padded).map(|_| rng.random()).collect(); + // Padding rows all repeat one fixed row (the analytic-padding model). 
+ let pad: F = rng.random(); + col.resize(n_rows, pad); + flat.push(col.into_iter().collect()); + } + let refs: Vec<&[F]> = flat.iter().map(|c| c.as_slice()).collect(); + let shift = compute_shifted_columns(table.n_shift_columns(), &refs); + let mut cols = flat; + cols.extend(shift); + TableData { + table, + log_n, + non_padded, + cols, + } +} + +/// Independent reference: s_t = Σ_x eq(eq_factor_t, x) · C_t(row x), naive. +fn brute_force_sum(td: &TableData, eq_factor: &[EF], extra: &ExtraDataForBuses) -> EF { + let eq_table = eval_eq(eq_factor); + let mut sum = EF::ZERO; + for x in 0..1usize << td.log_n { + let row: Vec = td.cols.iter().map(|c| EF::from(c[x])).collect(); + macro_rules! eval_row { + ($t:expr) => {{ <_ as SumcheckComputation>::eval_extension($t, &row, extra) }}; + } + let c = delegate_to_inner!(&td.table => eval_row); + sum += eq_table[x] * c; + } + sum +} + +/// What the cheating prover perturbs on the wire. +#[derive(Clone, Copy)] +enum Tamper { + None, + /// Add 1 to skip-poly coefficient `i` before sending. + SkipCoeff(usize), + /// Add 1 to combined coefficient `i` (i ≥ 1; c0 is elided) of linear round `r`. + LinearCoeff(usize, usize), +} + +/// Builds sessions, runs the (optionally tampered) prover, then verifies the +/// transcript including the full final identity. Returns Ok(()) iff accepted. 
+fn roundtrip(k: usize, tamper: Tamper) -> Result<(), String> { + let mut rng = StdRng::seed_from_u64(7 * k as u64 + 1); + let tables: Vec = vec![ + build_table_data(Table::execution(), 13, 1 << 12, &mut rng), + build_table_data(Table::extension_op(), 9, 1 << 9, &mut rng), + build_table_data(Table::poseidon16(), 10, 1 << 10, &mut rng), + ]; + assert_eq!( + tables.iter().map(|t| t.table).collect::>(), + ALL_TABLES.to_vec(), + "session order convention" + ); + let n_max = tables.iter().map(|t| t.log_n).max().unwrap(); + let total_constraints: usize = tables.iter().map(|t| t.table.n_constraints()).sum(); + let max_full_degree = tables.iter().map(|t| t.table.degree_air() + 1).max().unwrap(); + + // ---------------- prover ---------------- + let mut prover_state = ProverState::::new(get_poseidon16().clone(), Default::default()); + prover_state.duplex(); // fresh state has a stale rate; production absorbs a commitment first + let air_alpha: EF = prover_state.sample(); + let air_alpha_powers: Vec = air_alpha.powers().collect_n(total_constraints); + prover_state.duplex(); + let logup_alphas: Vec = prover_state.sample_vec(LOG_MAX_BUS_WIDTH); + let logup_alphas_eq_poly = eval_eq(&logup_alphas); + prover_state.duplex(); + let gkr_point: Vec = prover_state.sample_vec(n_max); + let eq_top = gkr_point[n_max - k..].to_vec(); + + // Per-table eq factors (suffixes of the shared gkr point), claims, extra data. 
+ let mut sums = Vec::new(); + let mut alpha_offset = 0; + let mut extras_for_final = Vec::new(); + let mut sessions: Vec + '_>> = Vec::new(); + for td in &tables { + let eq_factor = gkr_point[n_max - td.log_n..].to_vec(); + let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + td.table.n_constraints()].to_vec(); + alpha_offset += td.table.n_constraints(); + let extra = ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice.clone()); + let s_t = brute_force_sum(td, &eq_factor, &extra); + sums.push(s_t); + extras_for_final.push(ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice)); + + let col_refs: Vec<&[F]> = td.cols.iter().map(|c| c.as_slice()).collect(); + let packed = MleGroupRef::::Base(col_refs).pack(); + macro_rules! make_session { + ($t:expr) => {{ + let s = AirSumcheckSession::new(packed, eq_factor.clone(), s_t, *$t, extra, td.non_padded); + Box::new(s) as Box + '_> + }}; + } + sessions.push(delegate_to_inner!(&td.table => make_session)); + } + + let point = match tamper { + Tamper::None => prove_batched_air_sumcheck_uniskip(&mut prover_state, &mut sessions, k), + _ => prove_tampered(&mut prover_state, &mut sessions, k, tamper), + }; + + // Per-table column openings (transcript-sent, as in prove_execution.rs). + for session in &sessions { + prover_state.add_extension_scalars(&session.final_column_evals()); + } + + // Prover-side sanity on the honest path: the final identity holds locally. 
+ if matches!(tamper, Tamper::None) { + let final_sum: EF = sessions.iter().map(|s| s.sum()).fold(EF::ZERO, |a, b| a + b); + let e_hat_r0 = e_hat_at(&eq_top, point.r0); + let mut check = EF::ZERO; + for (idx, (td, session)) in tables.iter().zip(&sessions).enumerate() { + let eq_factor = &gkr_point[n_max - td.log_n..]; + let prefix = natural_prefix_for_session(&point, td.log_n); + let eq_val = + MultilinearPoint(eq_factor[..td.log_n - k].to_vec()).eq_poly_outside(&MultilinearPoint(prefix.clone())); + let col_evals = session.final_column_evals(); + macro_rules! eval_c { + ($t:expr) => {{ <_ as SumcheckComputation>::eval_extension($t, &col_evals, &extras_for_final[idx]) }}; + } + let c_eval = delegate_to_inner!(&td.table => eval_c); + check += e_hat_r0 * eq_val * c_eval; + } + assert_eq!(check, final_sum, "prover-side final identity"); + } + + // ---------------- verifier ---------------- + let mut verifier_state = + VerifierState::::new(prover_state.into_proof(), get_poseidon16().clone(), Default::default()) + .map_err(|e| format!("{e:?}"))?; + verifier_state.duplex(); + let air_alpha_v: EF = verifier_state.sample(); + let air_alpha_powers_v: Vec = air_alpha_v.powers().collect_n(total_constraints); + verifier_state.duplex(); + let logup_alphas_v: Vec = verifier_state.sample_vec(LOG_MAX_BUS_WIDTH); + let logup_alphas_eq_poly_v = eval_eq(&logup_alphas_v); + verifier_state.duplex(); + let gkr_point_v: Vec = verifier_state.sample_vec(n_max); + assert_eq!(gkr_point_v, gkr_point); + let eq_top_v = gkr_point_v[n_max - k..].to_vec(); + + let table_n_vars: Vec = tables.iter().map(|t| t.log_n).collect(); + let table_degrees: Vec = tables.iter().map(|t| t.table.degree_air()).collect(); + + let (point_v, final_target): (UniskipAirPoint, EF) = verify_batched_air_sumcheck_uniskip( + &mut verifier_state, + k, + &table_n_vars, + &table_degrees, + &sums, + &eq_top_v, + max_full_degree, + ) + .map_err(|e| format!("uniskip verify: {e:?}"))?; + + // Final identity (the formula T5 
installs in verify_execution.rs). + let e_hat_r0 = e_hat_at(&eq_top_v, point_v.r0); + let mut alpha_offset = 0; + let mut my_final = EF::ZERO; + for td in &tables { + let n_cols_total = td.table.n_columns() + td.table.n_shift_columns(); + let col_evals = verifier_state + .next_extension_scalars_vec(n_cols_total) + .map_err(|e| format!("{e:?}"))?; + let alpha_slice = air_alpha_powers_v[alpha_offset..alpha_offset + td.table.n_constraints()].to_vec(); + alpha_offset += td.table.n_constraints(); + let extra = ExtraDataForBuses::new(&logup_alphas_eq_poly_v, alpha_slice); + macro_rules! eval_c { + ($t:expr) => {{ <_ as SumcheckComputation>::eval_extension($t, &col_evals, &extra) }}; + } + let c_eval = delegate_to_inner!(&td.table => eval_c); + + let eq_factor = &gkr_point_v[n_max - td.log_n..]; + let prefix = natural_prefix_for_session(&point_v, td.log_n); + let eq_val = MultilinearPoint(eq_factor[..td.log_n - k].to_vec()).eq_poly_outside(&MultilinearPoint(prefix)); + my_final += e_hat_r0 * eq_val * c_eval; + } + if my_final != final_target { + return Err("final identity mismatch".to_string()); + } + Ok(()) +} + +/// Cheating-prover replica of `prove_batched_air_sumcheck_uniskip`: identical +/// schedule, but perturbs one wire coefficient. Sessions consume the challenges +/// of the tampered transcript (the natural cheating model). 
+fn prove_tampered<'a>( + prover_state: &mut impl FSProver, + sessions: &mut [Box + 'a>], + k: usize, + tamper: Tamper, +) -> UniskipAirPoint { + let n_max = sessions.iter().map(|s| s.initial_n_vars()).max().unwrap(); + let max_full_degree = sessions.iter().map(|s| s.bare_degree() + 1).max().unwrap(); + let d_max = sessions.iter().map(|s| s.bare_degree()).max().unwrap(); + let n_skip_coeffs = ((1usize << k) - 1) * d_max + 1; + let weights: Vec = sessions + .iter() + .map(|s| EF::from_usize(1 << (n_max - s.initial_n_vars()))) + .collect(); + let eq_top = sessions[0].skip_eq_top(k); + + let skip_polys: Vec> = sessions.iter_mut().map(|s| s.compute_skip_poly(k)).collect(); + let mut combined_skip = EF::zero_vec(n_skip_coeffs); + for (poly, &w_t) in skip_polys.iter().zip(&weights) { + for (acc, &c) in combined_skip.iter_mut().zip(&poly.coeffs) { + *acc += w_t * c; + } + } + if let Tamper::SkipCoeff(i) = tamper { + combined_skip[i] += EF::ONE; + } + prover_state.add_extension_scalars(&combined_skip); + let r0: EF = prover_state.sample(); + let lagrange_weights = lagrange_weights_at::(k, r0); + let e_hat_r0 = e_hat_at(&eq_top, r0); + for (session, poly) in sessions.iter_mut().zip(&skip_polys) { + session.process_skip_challenge(k, r0, &lagrange_weights, e_hat_r0, poly); + } + + let n_linear = n_max - k; + let mut linear_challenges = Vec::with_capacity(n_linear); + for r in 0..n_linear { + let mut combined_coeffs = EF::zero_vec(max_full_degree + 1); + let mut bare_polys: Vec>> = vec![None; sessions.len()]; + for (idx, session) in sessions.iter_mut().enumerate() { + let n_own = session.initial_n_vars() - k; + if r < n_own { + let bare = session.compute_bare_round_poly(); + let full = expand_bare_to_full(&bare.coeffs, session.eq_alpha()); + for (acc, &c) in combined_coeffs.iter_mut().zip(&full) { + *acc += weights[idx] * c; + } + bare_polys[idx] = Some(bare); + } else { + combined_coeffs[0] += session.sum() * EF::from_usize(1 << (n_linear - r - 1)); + } + } + if let 
Tamper::LinearCoeff(tr, i) = tamper + && tr == r + { + assert!(i >= 1, "c0 is elided; tamper a sent coefficient"); + combined_coeffs[i] += EF::ONE; + } + prover_state.add_sumcheck_polynomial(&combined_coeffs, None); + let challenge = prover_state.sample(); + linear_challenges.push(challenge); + for (idx, session) in sessions.iter_mut().enumerate() { + if let Some(bare) = &bare_polys[idx] { + session.process_challenge(challenge, bare); + } + } + } + UniskipAirPoint { + r0, + lagrange_weights, + linear_challenges, + } +} + +#[test] +fn test_uniskip_e2e_roundtrip() { + for k in [3, 4] { + roundtrip(k, Tamper::None).unwrap_or_else(|e| panic!("k={k}: {e}")); + } +} + +#[test] +fn test_uniskip_e2e_rejects_tampered_skip_coeff() { + for i in [0, 7] { + let err = roundtrip(4, Tamper::SkipCoeff(i)).expect_err("tampered skip coeff must be rejected"); + assert!(err.contains("uniskip verify"), "expected round-0 rejection, got: {err}"); + } +} + +#[test] +fn test_uniskip_e2e_rejects_tampered_linear_coeff() { + // With c0 elision the per-round identity absorbs wire perturbations; the + // corruption must surface at the final identity. + for (r, i) in [(0, 1), (4, 3)] { + let err = roundtrip(4, Tamper::LinearCoeff(r, i)).expect_err("tampered linear coeff must be rejected"); + assert!(!err.is_empty()); + } +} diff --git a/crates/whir/src/lib.rs b/crates/whir/src/lib.rs index a3e84f9c..2ec6073f 100644 --- a/crates/whir/src/lib.rs +++ b/crates/whir/src/lib.rs @@ -34,6 +34,11 @@ pub struct SparseStatement { pub values: Vec>, /// When true, the weight polynomial is `next_mle(point, .)` instead of `eq(point, .)`. pub is_next: bool, + /// Optional tensor tail occupying the lowest `log2(tail.len())` inner + /// variables: the inner weight becomes `eq(point, hi) * MLE(tail)(lo)` + /// where `lo` indexes the fastest-varying coordinates (for `is_next`, + /// the shift-by-one of that weight). `tail.len()` must be a power of two. 
+ pub tail: Option>, } impl SparseStatement { @@ -49,6 +54,7 @@ impl SparseStatement { point, values, is_next: false, + tail: None, } } @@ -64,15 +70,52 @@ impl SparseStatement { point, values, is_next: true, + tail: None, } } + pub fn new_with_tail( + total_num_variables: usize, + point: MultilinearPoint, + tail: Vec, + values: Vec>, + ) -> Self { + let mut smt = Self::new(total_num_variables, point, values); + smt.set_tail(tail); + smt + } + + pub fn new_next_with_tail( + total_num_variables: usize, + point: MultilinearPoint, + tail: Vec, + values: Vec>, + ) -> Self { + let mut smt = Self::new_next(total_num_variables, point, values); + smt.set_tail(tail); + smt + } + + fn set_tail(&mut self, tail: Vec) { + assert!(tail.len().is_power_of_two(), "tail length must be a power of two"); + let tail_log = tail.len().trailing_zeros() as usize; + assert!( + self.total_num_variables >= self.point.len() + tail_log, + "total_num_variables ({}) must be >= point.len() ({}) + tail_log ({})", + self.total_num_variables, + self.point.len(), + tail_log + ); + self.tail = Some(tail); + } + pub fn unique_value(total_num_variables: usize, index: usize, value: EF) -> Self { Self { total_num_variables, point: MultilinearPoint(vec![]), values: vec![SparseValue { selector: index, value }], is_next: false, + tail: None, } } @@ -82,6 +125,7 @@ impl SparseStatement { point, values: vec![SparseValue { selector: 0, value }], is_next: false, + tail: None, } } @@ -91,8 +135,12 @@ impl SparseStatement { .expect("invariant violated: total_num_variables < point.len()") } + pub fn tail_num_variables(&self) -> usize { + self.tail.as_ref().map_or(0, |t| t.len().trailing_zeros() as usize) + } + pub fn inner_num_variables(&self) -> usize { - self.point.len() + self.point.len() + self.tail_num_variables() } } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 358d2381..eb305839 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -524,7 +524,11 @@ where let out_len = 
1 << (num_variables - packing_log_width::()); let is_full = |s: &SparseStatement| { - !s.is_next && s.values.len() == 1 && s.values[0].selector == 0 && s.inner_num_variables() == num_variables + s.tail.is_none() + && !s.is_next + && s.values.len() == 1 + && s.values[0].selector == 0 + && s.inner_num_variables() == num_variables }; let mut combined_weights: ArenaVec>; @@ -556,18 +560,21 @@ where }; for smt in &statements[start_idx..] { - if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { + if smt.tail.is_none() + && !smt.is_next + && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) + { for evaluation in &smt.values { compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow); combined_sum += evaluation.value * gamma_pow; gamma_pow *= gamma; } } else { - let inner_poly: ArenaVec> = if smt.is_next { - let next = matrix_next_mle_folded(&smt.point.0); - pack_extension(&next) - } else { - eval_eq_packed(&smt.point) + let inner_poly: ArenaVec> = match (&smt.tail, smt.is_next) { + (Some(tail), true) => pack_extension(&matrix_next_mle_folded_with_tail(&smt.point.0, tail)), + (Some(tail), false) => eval_eq_packed_with_tail(&smt.point.0, tail), + (None, true) => pack_extension(&matrix_next_mle_folded(&smt.point.0)), + (None, false) => eval_eq_packed(&smt.point), }; let shift = smt.inner_num_variables() - packing_log_width::(); let mut indexed_smt_values = smt.values.iter().enumerate().collect::>(); diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 9ec784ac..14bffda0 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -362,10 +362,15 @@ where let mut i = 0; for smt in constraints { let inner_point = &point[point.len() - smt.inner_num_variables()..]; - let common_weight = if smt.is_next { - next_mle(&smt.point.0, inner_point) - } else { - smt.point.eq_poly_outside(&MultilinearPoint(inner_point.to_vec())) + let 
common_weight = match (&smt.tail, smt.is_next) { + (Some(tail), true) => next_mle_with_tail(&smt.point.0, tail, inner_point), + (Some(tail), false) => { + let (prefix, low) = inner_point.split_at(smt.point.len()); + smt.point.eq_poly_outside(&MultilinearPoint(prefix.to_vec())) + * tail.evaluate(&MultilinearPoint(low.to_vec())) + } + (None, true) => next_mle(&smt.point.0, inner_point), + (None, false) => smt.point.eq_poly_outside(&MultilinearPoint(inner_point.to_vec())), }; for e in &smt.values { let eval = (0..smt.selector_num_variables()) @@ -389,6 +394,10 @@ where } fn verify_constraint_coeffs(constraint: &SparseStatement, coeffs: &[EF]) -> bool { + debug_assert!( + constraint.tail.is_none(), + "tailed statements are never STIR constraints" + ); assert_eq!(constraint.selector_num_variables(), 0); let alpha = constraint.point[0]; // Verify the point is expand_from_univariate(alpha, n): [alpha, alpha^2, alpha^4, ...] diff --git a/crates/whir/tests/tensor_tail.rs b/crates/whir/tests/tensor_tail.rs new file mode 100644 index 00000000..4c605170 --- /dev/null +++ b/crates/whir/tests/tensor_tail.rs @@ -0,0 +1,209 @@ +//! T4 (pw13 h1): tensor-tail SparseStatements — weight = eq(point) ⊗ MLE(tail) +//! on the lowest log2(tail.len()) inner variables (next variant: shift-by-one). 
+ +use fiat_shamir::{ProverState, VerifierState}; +use field::{PrimeCharacteristicRing, TwoAdicField}; +use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; +use poly::*; +use rand::{RngExt, SeedableRng, rngs::StdRng}; +use whir::*; +use zk_alloc::ArenaVec; + +type F = KoalaBear; +type EF = QuinticExtensionFieldKB; + +const NUM_VARIABLES: usize = 18; + +fn whir_config() -> WhirConfig { + let params = WhirConfigBuilder { + security_level: 124, + max_num_variables_to_send_coeffs: 9, + pow_bits: 16, + folding_factor: FoldingFactor::new(7, 4), + soundness_type: SecurityAssumption::JohnsonBound, + starting_log_inv_rate: 2, + rs_domain_initial_reduction_factor: 5, + }; + WhirConfig::new(&params, NUM_VARIABLES) +} + +fn random_poly(seed: u64) -> Vec { + let mut rng = StdRng::seed_from_u64(seed); + (0..1 << NUM_VARIABLES).map(|_| rng.random::()).collect() +} + +/// Full commit+open+verify roundtrip; returns Ok(()) iff the verifier accepts +/// `verify_statements` against a proof generated for `prove_statements`.
+fn roundtrip( + polynomial: &[F], + prove_statements: Vec>, + verify_statements: Vec>, +) -> Result<(), fiat_shamir::ProofError> { + let poseidon16 = default_koalabear_poseidon1_16(); + let params = whir_config(); + precompute_dft_twiddles::(1 << F::TWO_ADICITY); + + let mut prover_state = ProverState::new(poseidon16.clone(), Default::default()); + let mle: MleOwned = MleOwned::Base(ArenaVec::from_iter(polynomial.to_vec())); + let witness = params.commit(&mut prover_state, &mle, 1 << NUM_VARIABLES); + params.prove(&mut prover_state, prove_statements, witness, &mle.by_ref()); + + let mut verifier_state = + VerifierState::::new(prover_state.into_proof(), poseidon16, Default::default()).unwrap(); + let parsed_commitment = params.parse_commitment::(&mut verifier_state)?; + params + .verify::(&mut verifier_state, &parsed_commitment, verify_statements) + .map(|_| ()) +} + +/// Naive reference: Σ_{hi,lo} eq(point)[hi]·tail[lo]·f[(selector << inner) | (hi << k) | lo]. +fn naive_tailed_value(polynomial: &[F], selector: usize, point: &[EF], tail: &[EF]) -> EF { + let k = tail.len().trailing_zeros() as usize; + let inner = point.len() + k; + let chunk = &polynomial[selector << inner..][..1 << inner]; + let eq = eval_eq(point); + let mut sum = EF::ZERO; + for (hi, &w_hi) in eq.iter().enumerate() { + for (lo, &w_lo) in tail.iter().enumerate() { + sum += w_hi * w_lo * chunk[(hi << k) | lo]; + } + } + sum +} + +#[test] +fn test_tail_equivalent_to_point_append() { + let polynomial = random_poly(1); + let mut rng = StdRng::seed_from_u64(2); + let p: Vec = (0..12).map(|_| rng.random()).collect(); + let c: Vec = (0..4).map(|_| rng.random()).collect(); + let selector = 3usize; // 2 selector variables + let appended = MultilinearPoint([p.clone(), c.clone()].concat()); + let value = polynomial.evaluate_sparse(selector, &appended); + + let tail: Vec = eval_eq(&c).to_vec(); + assert_eq!( + naive_tailed_value(&polynomial, selector, &p, &tail), + value, + "naive reference disagrees with 
evaluate_sparse" + ); + + let plain = SparseStatement::new(NUM_VARIABLES, appended, vec![SparseValue::new(selector, value)]); + let tailed = SparseStatement::new_with_tail( + NUM_VARIABLES, + MultilinearPoint(p), + tail, + vec![SparseValue::new(selector, value)], + ); + assert_eq!(plain.inner_num_variables(), tailed.inner_num_variables()); + assert_eq!(plain.selector_num_variables(), tailed.selector_num_variables()); + + let statements = vec![plain, tailed]; + roundtrip(&polynomial, statements.clone(), statements).expect("equivalence roundtrip must accept"); +} + +#[test] +fn test_random_tail_cube_sum() { + let polynomial = random_poly(3); + let mut rng = StdRng::seed_from_u64(4); + let p: Vec = (0..13).map(|_| rng.random()).collect(); + let tail: Vec = (0..8).map(|_| rng.random()).collect(); // k = 3, NOT an eq expansion + let selector = 2usize; // 2 selector variables + let value = naive_tailed_value(&polynomial, selector, &p, &tail); + + let make = |v: EF| { + vec![SparseStatement::new_with_tail( + NUM_VARIABLES, + MultilinearPoint(p.clone()), + tail.clone(), + vec![SparseValue::new(selector, v)], + )] + }; + roundtrip(&polynomial, make(value), make(value)).expect("random-tail roundtrip must accept"); + assert!( + roundtrip(&polynomial, make(value), make(value + EF::ONE)).is_err(), + "corrupted tailed value must be rejected" + ); +} + +#[test] +fn test_next_tail_equivalent_to_point_append() { + let polynomial = random_poly(5); + let mut rng = StdRng::seed_from_u64(6); + let p: Vec = (0..12).map(|_| rng.random()).collect(); + let c: Vec = (0..4).map(|_| rng.random()).collect(); + let selector = 1usize; // 2 selector variables + let appended = [p.clone(), c.clone()].concat(); + let inner = appended.len(); + + // Reference: dot(matrix_next_mle_folded(point), chunk). 
+ let chunk = &polynomial[selector << inner..][..1 << inner]; + let weights = matrix_next_mle_folded(&appended); + let value: EF = weights.iter().zip(chunk).map(|(&w, &f)| w * f).sum(); + + // T1 identity cross-check on the tailed weights. + let tail: Vec = eval_eq(&c).to_vec(); + let tailed_weights = matrix_next_mle_folded_with_tail(&p, &tail); + let tailed_value: EF = tailed_weights.iter().zip(chunk).map(|(&w, &f)| w * f).sum(); + assert_eq!(tailed_value, value, "next-with-tail weights disagree with point-append"); + + let plain = SparseStatement::new_next( + NUM_VARIABLES, + MultilinearPoint(appended), + vec![SparseValue::new(selector, value)], + ); + let tailed = SparseStatement::new_next_with_tail( + NUM_VARIABLES, + MultilinearPoint(p), + tail, + vec![SparseValue::new(selector, value)], + ); + let statements = vec![plain, tailed]; + roundtrip(&polynomial, statements.clone(), statements).expect("next equivalence roundtrip must accept"); +} + +#[test] +fn test_tailed_mixed_with_ordinary_statements() { + let polynomial = random_poly(7); + let mut rng = StdRng::seed_from_u64(8); + + // Dense full-width statement (exercises the is_full fast path alongside tails). + let full_point: Vec = (0..NUM_VARIABLES).map(|_| rng.random()).collect(); + let dense = SparseStatement::dense( + MultilinearPoint(full_point.clone()), + polynomial.evaluate(&MultilinearPoint(full_point)), + ); + + // Plain sparse statement. + let sp: Vec = (0..15).map(|_| rng.random()).collect(); + let plain = SparseStatement::new( + NUM_VARIABLES, + MultilinearPoint(sp.clone()), + vec![SparseValue::new( + 5, + polynomial.evaluate_sparse(5, &MultilinearPoint(sp)), + )], + ); + + // Tailed eq statement (k = 4). 
+ let p1: Vec = (0..10).map(|_| rng.random()).collect(); + let t1: Vec = (0..16).map(|_| rng.random()).collect(); + let tailed_eq = SparseStatement::new_with_tail( + NUM_VARIABLES, + MultilinearPoint(p1.clone()), + t1.clone(), + vec![SparseValue::new(2, naive_tailed_value(&polynomial, 2, &p1, &t1))], + ); + + // Tailed next statement (k = 3). + let p2: Vec = (0..13).map(|_| rng.random()).collect(); + let t2: Vec = (0..8).map(|_| rng.random()).collect(); + let w2 = matrix_next_mle_folded_with_tail(&p2, &t2); + let chunk2 = &polynomial[1 << 16..][..1 << 16]; + let v2: EF = w2.iter().zip(chunk2).map(|(&w, &f)| w * f).sum(); + let tailed_next = + SparseStatement::new_next_with_tail(NUM_VARIABLES, MultilinearPoint(p2), t2, vec![SparseValue::new(1, v2)]); + + let statements = vec![dense, plain, tailed_eq, tailed_next]; + roundtrip(&polynomial, statements.clone(), statements).expect("mixed roundtrip must accept"); +}