Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/lean_compiler/tests/test_data/program_179.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ def main():
output = Array(5)
input[0] = 1
input[4] = 5
copy_5(input, output)
copy_ef(input, output)
assert output[0] == 1
assert output[4] == 5
return


@inline
def copy_5(a, b):
def copy_ef(a, b):
dot_product_ee(a, ONE_EF_PTR, b)
return

Expand Down
2 changes: 1 addition & 1 deletion crates/lean_vm/src/tables/extension_op/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn solve_unknowns(
let c = memory.get_ef_element(addr_res);

if op == ExtensionOp::DotProduct && !flag_be {
// detect "copy_5"
// detect "copy_ef" (single EF-element copy: dot_product_ee against EF::ONE)
if b == Ok(EF::ONE) {
memory.make_slices_equal_and_defined(ptr_a.to_usize(), ptr_res.to_usize(), DIMENSION)?;
return Ok(());
Expand Down
22 changes: 6 additions & 16 deletions crates/rec_aggregation/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree
let mut one_buses_data_offsets = vec![];
let mut one_buses_new_cols = vec![];
let mut num_cols_air = vec![];
let mut air_degrees = vec![];
let mut n_air_columns = vec![];
let mut n_air_shift_columns = vec![];
let mut n_air_constraints = vec![];
Expand Down Expand Up @@ -334,7 +333,6 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree
));

num_cols_air.push(table.n_columns().to_string());
air_degrees.push(table.degree_air().to_string());
n_air_columns.push(table.n_columns().to_string());
n_air_shift_columns.push(table.n_shift_columns().to_string());
n_air_constraints.push(table.n_constraints().to_string());
Expand Down Expand Up @@ -387,10 +385,6 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree
"AIR_ALPHA_OFFSETS_PLACEHOLDER".to_string(),
format!("[{}]", air_alpha_offsets.join(", ")),
);
replacements.insert(
"AIR_DEGREES_PLACEHOLDER".to_string(),
format!("[{}]", air_degrees.join(", ")),
);
replacements.insert(
"MAX_AIR_FULL_DEGREE_PLACEHOLDER".to_string(),
(ALL_TABLES.iter().map(|t| t.degree_air()).max().unwrap() + 1).to_string(),
Expand All @@ -411,10 +405,6 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree
"N_INSTRUCTION_COLUMNS_PLACEHOLDER".to_string(),
N_INSTRUCTION_COLUMNS.to_string(),
);
replacements.insert(
"N_COMMITTED_EXEC_COLUMNS_PLACEHOLDER".to_string(),
N_RUNTIME_COLUMNS.to_string(),
);
replacements.insert(
"TOTAL_WHIR_STATEMENTS_PLACEHOLDER".to_string(),
total_whir_statements().to_string(),
Expand Down Expand Up @@ -532,7 +522,7 @@ where
let mut ctx = AirCodegenCtx::new();

let mut res = format!(
"def evaluate_air_constraints_table_{}({}, air_alpha_powers, logup_alphas_eq_poly):\n",
"def evaluate_air_constraints_table_{}({}, air_alpha_powers, logup_beta_eq_poly):\n",
table.table().index(),
AIR_INNER_VALUES_VAR
);
Expand All @@ -549,17 +539,17 @@ where
res += &format!("\n buff = Array(DIM * {})", bus_real_data.len());
for (i, data) in bus_real_data.iter().enumerate() {
let data_str = eval_air_constraint(*data, None, &mut ctx, &mut res);
res += &format!("\n copy_5({}, buff + DIM * {})", data_str, i);
res += &format!("\n copy_ef({}, buff + DIM * {})", data_str, i);
}
let domainsep_str = eval_air_constraint(*bus_domainsep, None, &mut ctx, &mut res);
// bus_res = sum(buff[i] * logup_alphas_eq_poly[i]) + disc * logup_alphas_eq_poly.last()
// bus_res = sum(buff[i] * logup_beta_eq_poly[i]) + disc * logup_beta_eq_poly.last()
res += "\n bus_res_init = Array(DIM)";
res += &format!(
"\n dot_product_ee(buff, logup_alphas_eq_poly, bus_res_init, {})",
"\n dot_product_ee(buff, logup_beta_eq_poly, bus_res_init, {})",
bus_real_data.len()
);
res += &format!(
"\n bus_res: Mut = add_extension_ret(mul_extension_ret({}, logup_alphas_eq_poly + {} * DIM), bus_res_init)",
"\n bus_res: Mut = add_extension_ret(mul_extension_ret({}, logup_beta_eq_poly + {} * DIM), bus_res_init)",
domainsep_str,
(1 << LOG_MAX_BUS_WIDTH) - 1
);
Expand Down Expand Up @@ -618,7 +608,7 @@ fn eval_air_constraint(
if let Some(d) = dest
&& v != d
{
res.push_str(&format!("\n copy_5({}, {})", v, d));
res.push_str(&format!("\n copy_ef({}, {})", v, d));
}
v
}
Expand Down
17 changes: 0 additions & 17 deletions crates/rec_aggregation/zkdsl_implem/fiat_shamir.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,6 @@ def fs_sample_many_ef(fs, n):
return new_fs, sampled


@inline
def fs_hint(fs, n):
# Hint = read `n` cells from the transcript without absorbing them. Just advance the
# transcript pointer; the sponge state is unchanged.
new_fs = Array(17)
copy_8(new_fs, fs)
copy_8(new_fs + 8, fs + 8)
new_fs[16] = fs[16] + n
return new_fs, fs[16]


def fs_receive_chunks(fs, n_chunks: Const):
# Read n_chunks * 8 cells from the transcript and absorb them. Returns the new fs
# and a pointer to the just-consumed transcript region.
Expand Down Expand Up @@ -190,12 +179,6 @@ def fs_receive_ef(fs, n: Const):
return new_fs, ef_ptr


def fs_print_state(fs_state):
for i in unroll(0, 17):
print(i, fs_state[i])
return


@inline
def fs_sample_queries(fs, n_samples):
# Sample `n_samples` query bit-strings. Each chunk yields 8 base field elements that
Expand Down
106 changes: 3 additions & 103 deletions crates/rec_aggregation/zkdsl_implem/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
DIGEST_LEN = 8

# memory layout: [public_input (PUBLIC_INPUT_LEN)] [preamble_memory (PREAMBLE_MEMORY_LEN)] [runtime ...]
# `preamble_memory` is a region that is filled by the guest program, with usefull constants [0000...][1000...]...
# `preamble_memory` is a region that is filled by the guest program, with useful constants [0000...][1000...]...
PUBLIC_INPUT_LEN = DIGEST_LEN
PARTIAL_UNROLL_BATCH = 64
ZERO_VEC_PTR = PUBLIC_INPUT_LEN
Expand All @@ -17,41 +17,6 @@
PREAMBLE_MEMORY_LEN = PREAMBLE_MEMORY_END - PUBLIC_INPUT_LEN


def batch_hash_slice_rtl_with_iv(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks):
if num_chunks == DIM * 2:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2)
return
if num_chunks == 16:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 16)
return
if num_chunks == 8:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 8)
return
if num_chunks == 20:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 20)
return
if num_chunks == 1:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 1)
return
if num_chunks == 4:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 4)
return
if num_chunks == 5:
batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 5)
return
print(num_chunks)
assert False, "batch_hash_slice called with unsupported len"


def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const):
iv = build_iv(num_chunks * DIGEST_LEN)
for i in range(0, num_queries):
data = all_data_to_hash[i]
res = slice_hash_rtl(data, num_chunks, iv)
all_resulting_hashes[i] = res
return


# IV for the sponge: [slice length in field elements, 0, 0, ..., 0]
@inline
def build_iv(length):
Expand Down Expand Up @@ -129,7 +94,7 @@ def slice_hash_continue(running, data, num_chunks):


@inline
def euclidian_div_runtime(a, b):
def euclidean_div_runtime(a, b):
# Returns (q, r) with q = floor(a / b) and r = a mod b.
# Requires:
# 1 <= b < 2^14
Expand Down Expand Up @@ -166,7 +131,7 @@ def slice_hash_runtime(data, num_chunks):
poseidon16_permute_half(iv, data, states)
n_iters = num_chunks - 2

n_chunks_outer, remainder = euclidian_div_runtime(n_iters, PARTIAL_UNROLL_BATCH)
n_chunks_outer, remainder = euclidean_div_runtime(n_iters, PARTIAL_UNROLL_BATCH)
carry = Array((n_chunks_outer + 1) * 2)
carry[0] = states
carry[1] = data + DIGEST_LEN
Expand Down Expand Up @@ -282,68 +247,3 @@ def whir_do_1_merkle_level(b, state_in, path_chunk, state_out):
else:
poseidon16_compress_half(path_chunk, state_in, state_out)
return


def merkle_verif_batch(merkle_paths, leaves_digests, leave_positions, root, height, num_queries):
match_range(
height,
range(10, 26),
lambda h: merkle_verif_batch_const(
num_queries,
merkle_paths,
leaves_digests,
leave_positions,
root,
h,
),
)
return


def merkle_verif_batch_const(n_paths, merkle_paths, leaves_digests, leave_positions, root, height: Const):
# n_paths: F
# leaves_digests: pointer to a slice of n_paths pointers, each pointing to 1 chunk of 8 field elements
# leave_positions: pointer to a slice of n_paths field elements (each < 2^height)
# root: pointer to 1 chunk of 8 field elements
# height: F

for i in range(0, n_paths):
merkle_verify(
leaves_digests[i],
merkle_paths + (i * height) * DIGEST_LEN,
leave_positions[i],
root,
height,
)

return


def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Const):
states = Array(height * DIGEST_LEN)

# First merkle round
match leaf_position_bits[0]:
case 0:
poseidon16_compress_half(leaf_digest, merkle_path, states)
case 1:
poseidon16_compress_half(merkle_path, leaf_digest, states)

# Remaining merkle rounds
for j in unroll(1, height):
# Warning: this works only if leaf_position_bits[i] is known to be boolean:
match leaf_position_bits[j]:
case 0:
poseidon16_compress_half(
states + (j - 1) * DIGEST_LEN,
merkle_path + j * DIGEST_LEN,
states + j * DIGEST_LEN,
)
case 1:
poseidon16_compress_half(
merkle_path + j * DIGEST_LEN,
states + (j - 1) * DIGEST_LEN,
states + j * DIGEST_LEN,
)
copy_8(states + (height - 1) * DIGEST_LEN, root)
return
6 changes: 3 additions & 3 deletions crates/rec_aggregation/zkdsl_implem/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def main():

running_hash: Mut = build_iv(n_sub * PUB_KEY_SIZE)
n_first = n_sub - 1
n_chunks, remainder = euclidian_div_runtime(n_first, PARTIAL_UNROLL_BATCH)
n_chunks, remainder = euclidean_div_runtime(n_first, PARTIAL_UNROLL_BATCH)
pubkey_idx: Mut = 0
inner_carry = Array((n_chunks + 1) * 3)
inner_carry[0] = counter
Expand Down Expand Up @@ -299,7 +299,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou
all_values = Array(n_bytecode_claims * DIM)
for i in range(0, n_bytecode_claims):
claim_ptr = bytecode_claims[i]
copy_5(claim_ptr + BYTECODE_POINT_N_VARS * DIM, all_values + i * DIM)
copy_ef(claim_ptr + BYTECODE_POINT_N_VARS * DIM, all_values + i * DIM)

claimed_sum = Array(DIM)
dot_product_ee_dynamic(all_values, alpha_powers, claimed_sum, n_bytecode_claims)
Expand All @@ -316,7 +316,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou
bytecode_value_at_r = div_extension_ret(final_eval, w_r)

copy_many_ef(challenges, bytecode_claim_output, BYTECODE_POINT_N_VARS)
copy_5(bytecode_value_at_r, bytecode_claim_output + BYTECODE_POINT_N_VARS * DIM)
copy_ef(bytecode_value_at_r, bytecode_claim_output + BYTECODE_POINT_N_VARS * DIM)
return


Expand Down
Loading
Loading