diff --git a/.github/workflows/hol_light.yml b/.github/workflows/hol_light.yml index c60985c5a..fcc7df43d 100644 --- a/.github/workflows/hol_light.yml +++ b/.github/workflows/hol_light.yml @@ -12,6 +12,7 @@ on: - 'proofs/hol_light/aarch64/Makefile' - 'proofs/hol_light/aarch64/**/*.S' - 'proofs/hol_light/aarch64/**/*.ml' + - 'proofs/hol_light/common/**/*.ml' - 'proofs/hol_light/x86_64/Makefile' - 'proofs/hol_light/x86_64/**/*.S' - 'proofs/hol_light/x86_64/**/*.ml' @@ -26,6 +27,7 @@ on: - 'proofs/hol_light/aarch64/Makefile' - 'proofs/hol_light/aarch64/**/*.S' - 'proofs/hol_light/aarch64/**/*.ml' + - 'proofs/hol_light/common/**/*.ml' - 'proofs/hol_light/x86_64/Makefile' - 'proofs/hol_light/x86_64/**/*.S' - 'proofs/hol_light/x86_64/**/*.ml' @@ -81,6 +83,16 @@ jobs: # Dependencies on {name}.{S,ml} are implicit - name: mldsa_poly_caddq needs: ["aarch64_utils.ml"] + - name: mldsa_poly_chknorm + needs: ["aarch64_utils.ml"] + - name: mldsa_polyz_unpack_17 + needs: ["aarch64_utils.ml", "mldsa_polyz_unpack_consts.ml", "mldsa_specs.ml"] + - name: mldsa_polyz_unpack_19 + needs: ["aarch64_utils.ml", "mldsa_polyz_unpack_consts.ml", "mldsa_specs.ml"] + - name: mldsa_poly_decompose_32 + needs: ["aarch64_utils.ml", "mldsa_specs.ml"] + - name: mldsa_poly_decompose_88 + needs: ["aarch64_utils.ml", "mldsa_specs.ml"] name: AArch64 HOL Light proof for ${{ matrix.proof.name }}.S runs-on: pqcp-arm64 if: github.repository_owner == 'pq-code-package' && !github.event.pull_request.head.repo.fork diff --git a/dev/aarch64_clean/src/arith_native_aarch64.h b/dev/aarch64_clean/src/arith_native_aarch64.h index f78a1487d..e46b989dd 100644 --- a/dev/aarch64_clean/src/arith_native_aarch64.h +++ b/dev/aarch64_clean/src/arith_native_aarch64.h @@ -71,10 +71,36 @@ uint64_t mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t *buf, unsigned buflen, const uint8_t *table); #define mld_poly_decompose_32_asm MLD_NAMESPACE(poly_decompose_32_asm) -void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0); +void 
mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 16 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 32)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 16)) + /* check-magic: 261889 == (MLDSA_Q - 1) / 32 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 261889)) +); #define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm) -void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0); +void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 44 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 88)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 44)) + /* check-magic: 95233 == (MLDSA_Q - 1) / 88 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 95233)) +); #define mld_poly_caddq_asm MLD_NAMESPACE(poly_caddq_asm) void mld_poly_caddq_asm(int32_t *a) @@ -95,15 +121,42 @@ void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) MLD_MUST_CHECK_RETURN_VALUE -int mld_poly_chknorm_asm(const int32_t *a, int32_t B); +int mld_poly_chknorm_asm(const int32_t *a, int32_t B) +/* This must be kept in sync with the HOL-Light specification + 
* in proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + /* HOL Light precondition: abs(ival(x i)) < 2^31, i.e., a[i] != INT32_MIN */ + requires(forall(k0, 0, MLDSA_N, a[k0] > INT32_MIN)) + ensures(return_value == 0 || return_value == 1) + ensures((return_value == 0) == array_abs_bound(a, 0, MLDSA_N, B)) +); #define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm) void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 576)) + requires(indices == mld_polyz_unpack_17_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 17) - 1), (1 << 17) + 1)) +); #define mld_polyz_unpack_19_asm MLD_NAMESPACE(polyz_unpack_19_asm) void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 640)) + requires(indices == mld_polyz_unpack_19_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 19) - 1), (1 << 19) + 1)) +); #define mld_poly_pointwise_montgomery_asm \ MLD_NAMESPACE(poly_pointwise_montgomery_asm) diff --git a/dev/aarch64_opt/src/arith_native_aarch64.h b/dev/aarch64_opt/src/arith_native_aarch64.h index f78a1487d..e46b989dd 100644 --- a/dev/aarch64_opt/src/arith_native_aarch64.h +++ b/dev/aarch64_opt/src/arith_native_aarch64.h @@ -71,10 +71,36 @@ uint64_t mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t 
*buf, unsigned buflen, const uint8_t *table); #define mld_poly_decompose_32_asm MLD_NAMESPACE(poly_decompose_32_asm) -void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0); +void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 16 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 32)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 16)) + /* check-magic: 261889 == (MLDSA_Q - 1) / 32 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 261889)) +); #define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm) -void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0); +void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 44 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 88)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 44)) + /* check-magic: 95233 == (MLDSA_Q - 1) / 88 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 95233)) +); #define mld_poly_caddq_asm MLD_NAMESPACE(poly_caddq_asm) void mld_poly_caddq_asm(int32_t *a) @@ -95,15 +121,42 @@ void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) 
MLD_MUST_CHECK_RETURN_VALUE -int mld_poly_chknorm_asm(const int32_t *a, int32_t B); +int mld_poly_chknorm_asm(const int32_t *a, int32_t B) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + /* HOL Light precondition: abs(ival(x i)) < 2^31, i.e., a[i] != INT32_MIN */ + requires(forall(k0, 0, MLDSA_N, a[k0] > INT32_MIN)) + ensures(return_value == 0 || return_value == 1) + ensures((return_value == 0) == array_abs_bound(a, 0, MLDSA_N, B)) +); #define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm) void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 576)) + requires(indices == mld_polyz_unpack_17_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 17) - 1), (1 << 17) + 1)) +); #define mld_polyz_unpack_19_asm MLD_NAMESPACE(polyz_unpack_19_asm) void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 640)) + requires(indices == mld_polyz_unpack_19_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 19) - 1), (1 << 19) + 1)) +); #define mld_poly_pointwise_montgomery_asm \ MLD_NAMESPACE(poly_pointwise_montgomery_asm) diff --git a/dev/aarch64_opt/src/poly_chknorm_asm.S b/dev/aarch64_opt/src/poly_chknorm_asm.S index 
064d50f8e..5832526dd 100644 --- a/dev/aarch64_opt/src/poly_chknorm_asm.S +++ b/dev/aarch64_opt/src/poly_chknorm_asm.S @@ -30,7 +30,7 @@ MLD_ASM_FN_SYMBOL(poly_chknorm_asm) // Load constants dup bound.4s, B - movi flags.4s, 0 + eor flags.16b, flags.16b, flags.16b mov count, #(64/4) diff --git a/mldsa/src/native/aarch64/src/arith_native_aarch64.h b/mldsa/src/native/aarch64/src/arith_native_aarch64.h index f78a1487d..e46b989dd 100644 --- a/mldsa/src/native/aarch64/src/arith_native_aarch64.h +++ b/mldsa/src/native/aarch64/src/arith_native_aarch64.h @@ -71,10 +71,36 @@ uint64_t mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t *buf, unsigned buflen, const uint8_t *table); #define mld_poly_decompose_32_asm MLD_NAMESPACE(poly_decompose_32_asm) -void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0); +void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 16 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 32)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 16)) + /* check-magic: 261889 == (MLDSA_Q - 1) / 32 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 261889)) +); #define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm) -void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0); +void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml */ +__contract__( + requires(memory_no_alias(a1, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a0, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a0, 0, 
MLDSA_N, 0, MLDSA_Q)) + assigns(memory_slice(a1, sizeof(int32_t) * MLDSA_N)) + assigns(memory_slice(a0, sizeof(int32_t) * MLDSA_N)) + /* check-magic: 44 == (MLDSA_Q - 1) / (2 * ((MLDSA_Q - 1) / 88)) */ + ensures(array_bound(a1, 0, MLDSA_N, 0, 44)) + /* check-magic: 95233 == (MLDSA_Q - 1) / 88 + 1 */ + ensures(array_abs_bound(a0, 0, MLDSA_N, 95233)) +); #define mld_poly_caddq_asm MLD_NAMESPACE(poly_caddq_asm) void mld_poly_caddq_asm(int32_t *a) @@ -95,15 +121,42 @@ void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) MLD_MUST_CHECK_RETURN_VALUE -int mld_poly_chknorm_asm(const int32_t *a, int32_t B); +int mld_poly_chknorm_asm(const int32_t *a, int32_t B) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + /* HOL Light precondition: abs(ival(x i)) < 2^31, i.e., a[i] != INT32_MIN */ + requires(forall(k0, 0, MLDSA_N, a[k0] > INT32_MIN)) + ensures(return_value == 0 || return_value == 1) + ensures((return_value == 0) == array_abs_bound(a, 0, MLDSA_N, B)) +); #define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm) void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 576)) + requires(indices == mld_polyz_unpack_17_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 17) - 1), (1 << 17) + 1)) +); #define mld_polyz_unpack_19_asm MLD_NAMESPACE(polyz_unpack_19_asm) void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf, - const uint8_t *indices); + const uint8_t *indices) +/* This must 
be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml */ +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, 640)) + requires(indices == mld_polyz_unpack_19_indices) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(r, 0, MLDSA_N, -((1 << 19) - 1), (1 << 19) + 1)) +); #define mld_poly_pointwise_montgomery_asm \ MLD_NAMESPACE(poly_pointwise_montgomery_asm) diff --git a/mldsa/src/native/aarch64/src/poly_chknorm_asm.S b/mldsa/src/native/aarch64/src/poly_chknorm_asm.S index 8d5d0d267..aa0952894 100644 --- a/mldsa/src/native/aarch64/src/poly_chknorm_asm.S +++ b/mldsa/src/native/aarch64/src/poly_chknorm_asm.S @@ -22,7 +22,7 @@ MLD_ASM_FN_SYMBOL(poly_chknorm_asm) .cfi_startproc dup v20.4s, w1 - movi v21.4s, #0x0 + eor v21.16b, v21.16b, v21.16b mov x2, #0x10 // =16 Lpoly_chknorm_loop: diff --git a/mldsa/src/native/api.h b/mldsa/src/native/api.h index 409337fcc..140d67f9e 100644 --- a/mldsa/src/native/api.h +++ b/mldsa/src/native/api.h @@ -440,7 +440,7 @@ __contract__( requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) requires(0 <= B && B <= MLDSA_Q - REDUCE32_RANGE_MAX) requires(array_bound(a, 0, MLDSA_N, -REDUCE32_RANGE_MAX, REDUCE32_RANGE_MAX)) - ensures(return_value == MLD_NATIVE_FUNC_FALLBACK || return_value == MLD_NATIVE_FUNC_SUCCESS) + ensures(return_value == MLD_NATIVE_FUNC_FALLBACK || return_value == 0 || return_value == 1) ensures((return_value == 0) == array_abs_bound(a, 0, MLDSA_N, B)) ); #endif /* MLD_USE_NATIVE_POLY_CHKNORM */ diff --git a/mldsa/src/poly.c b/mldsa/src/poly.c index 0026e145d..ac6bbbe5b 100644 --- a/mldsa/src/poly.c +++ b/mldsa/src/poly.c @@ -956,7 +956,7 @@ uint32_t mld_poly_chknorm(const mld_poly *a, int32_t B) if (success) { /* Convert 0 / 1 to 0 / 0xFFFFFFFF here */ - return 0U - (uint32_t)ret; + return mld_ct_cmask_nonzero_u32((uint32_t)ret); } #endif /* MLD_USE_NATIVE_POLY_CHKNORM */ return 
mld_poly_chknorm_c(a, B); diff --git a/nix/s2n_bignum/default.nix b/nix/s2n_bignum/default.nix index 0aed575aa..3a6df170d 100644 --- a/nix/s2n_bignum/default.nix +++ b/nix/s2n_bignum/default.nix @@ -4,12 +4,12 @@ { stdenv, fetchFromGitHub, writeText, ... }: stdenv.mkDerivation rec { pname = "s2n_bignum"; - version = "113a146ab49c19281388881b3650b63a6a67e8dc"; + version = "f9ebd40af24087c503ecaca008be22edd166afab"; src = fetchFromGitHub { - owner = "awslabs"; + owner = "mkannwischer"; repo = "s2n-bignum"; rev = "${version}"; - hash = "sha256-Ub+Nrlo8DEmz3H5SdgcH9iLbNJnZmLvGk3dGgWar2kY="; + hash = "sha256-gXwGRS4TLFESnSbLRiFyMKqCiBA0hCIQibWzuDfpLeM="; }; setupHook = writeText "setup-hook.sh" '' export S2N_BIGNUM_DIR="$1" diff --git a/proofs/cbmc/poly_chknorm_native/Makefile b/proofs/cbmc/poly_chknorm_native/Makefile index 9b6516cc8..e8f5f4e27 100644 --- a/proofs/cbmc/poly_chknorm_native/Makefile +++ b/proofs/cbmc/poly_chknorm_native/Makefile @@ -22,6 +22,7 @@ PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly.c CHECK_FUNCTION_CONTRACTS=mld_poly_chknorm USE_FUNCTION_CONTRACTS=mld_poly_chknorm_native USE_FUNCTION_CONTRACTS+=mld_poly_chknorm_c +USE_FUNCTION_CONTRACTS+=mld_ct_cmask_nonzero_u32 APPLY_LOOP_CONTRACTS=on USE_DYNAMIC_FRAMES=1 diff --git a/proofs/cbmc/poly_chknorm_native_aarch64/Makefile b/proofs/cbmc/poly_chknorm_native_aarch64/Makefile new file mode 100644 index 000000000..df4a902f8 --- /dev/null +++ b/proofs/cbmc/poly_chknorm_native_aarch64/Makefile @@ -0,0 +1,58 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = poly_chknorm_native_aarch64_harness + +# This should be a unique identifier for this proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. 
+PROOF_UID = poly_chknorm_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly.c + +CHECK_FUNCTION_CONTRACTS=mld_poly_chknorm_native +USE_FUNCTION_CONTRACTS = mld_poly_chknorm_asm +USE_FUNCTION_CONTRACTS += mld_sys_check_capability +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = poly_chknorm_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +# If you require access to a file-local ("static") function or object to conduct +# your proof, set the following (and do not include the original source file +# ("mldsa/src/poly.c") in PROJECT_SOURCES). +# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i +# include ../Makefile.common +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/src/poly.c +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz +# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must +# be set before including Makefile.common, but any use of variables on the +# left-hand side requires those variables to be defined. 
Hence, _SOURCE, +# _FUNCTIONS, _OBJECTS is set after including Makefile.common. + +include ../Makefile.common diff --git a/proofs/cbmc/poly_chknorm_native_aarch64/poly_chknorm_native_aarch64_harness.c b/proofs/cbmc/poly_chknorm_native_aarch64/poly_chknorm_native_aarch64_harness.c new file mode 100644 index 000000000..e7f99663d --- /dev/null +++ b/proofs/cbmc/poly_chknorm_native_aarch64/poly_chknorm_native_aarch64_harness.c @@ -0,0 +1,16 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include <stdint.h> +#include "cbmc.h" +#include "params.h" + +int mld_poly_chknorm_native(const int32_t a[MLDSA_N], int32_t B); + +void harness(void) +{ + const int32_t *a; + int32_t B; + int t; + t = mld_poly_chknorm_native(a, B); +} diff --git a/proofs/cbmc/poly_decompose_32_native_aarch64/Makefile b/proofs/cbmc/poly_decompose_32_native_aarch64/Makefile new file mode 100644 index 000000000..2487cb1cd --- /dev/null +++ b/proofs/cbmc/poly_decompose_32_native_aarch64/Makefile @@ -0,0 +1,66 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = poly_decompose_32_native_aarch64_harness + +# This should be a unique identifier for this proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = poly_decompose_32_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications.
+DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c + +# poly_decompose_32 is only used with ML-DSA-65 and ML-DSA-87 +ifeq ($(MLD_CONFIG_PARAMETER_SET),65) +CHECK_FUNCTION_CONTRACTS=mld_poly_decompose_32_native +USE_FUNCTION_CONTRACTS = mld_poly_decompose_32_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),87) +CHECK_FUNCTION_CONTRACTS=mld_poly_decompose_32_native +USE_FUNCTION_CONTRACTS = mld_poly_decompose_32_asm +else +CHECK_FUNCTION_CONTRACTS= +USE_FUNCTION_CONTRACTS = +endif +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = poly_decompose_32_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +# If you require access to a file-local ("static") function or object to conduct +# your proof, set the following (and do not include the original source file +# ("mldsa/src/poly_kl.c") in PROJECT_SOURCES).
+# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i +# include ../Makefile.common +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/src/poly_kl.c +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz +# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must +# be set before including Makefile.common, but any use of variables on the +# left-hand side requires those variables to be defined. Hence, _SOURCE, +# _FUNCTIONS, _OBJECTS is set after including Makefile.common. + +include ../Makefile.common diff --git a/proofs/cbmc/poly_decompose_32_native_aarch64/poly_decompose_32_native_aarch64_harness.c b/proofs/cbmc/poly_decompose_32_native_aarch64/poly_decompose_32_native_aarch64_harness.c new file mode 100644 index 000000000..bfbe8f589 --- /dev/null +++ b/proofs/cbmc/poly_decompose_32_native_aarch64/poly_decompose_32_native_aarch64_harness.c @@ -0,0 +1,19 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include <stdint.h> +#include "cbmc.h" +#include "params.h" + +int mld_poly_decompose_32_native(int32_t *a1, int32_t *a0); + +void harness(void) +{ + /* mld_poly_decompose_32_native is only defined for ML-DSA-65 and ML-DSA-87 */ +#if MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87 + int32_t *a1; + int32_t *a0; + int t; + t = mld_poly_decompose_32_native(a1, a0); +#endif /* MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87 */ +} diff --git a/proofs/cbmc/poly_decompose_88_native_aarch64/Makefile b/proofs/cbmc/poly_decompose_88_native_aarch64/Makefile new file mode 100644 index 000000000..31ee47924 --- /dev/null +++ b/proofs/cbmc/poly_decompose_88_native_aarch64/Makefile @@ -0,0 +1,63 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE =
poly_decompose_88_native_aarch64_harness + +# This should be a unique identifier for this proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = poly_decompose_88_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c + +# poly_decompose_88 is only used with ML-DSA-44 +ifeq ($(MLD_CONFIG_PARAMETER_SET),44) +CHECK_FUNCTION_CONTRACTS=mld_poly_decompose_88_native +USE_FUNCTION_CONTRACTS = mld_poly_decompose_88_asm +else +CHECK_FUNCTION_CONTRACTS= +USE_FUNCTION_CONTRACTS = +endif +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = poly_decompose_88_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +# If you require access to a file-local ("static") function or object to conduct +# your proof, set the following (and do not include the original source file +# ("mldsa/src/poly_kl.c") in PROJECT_SOURCES).
+# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i +# include ../Makefile.common +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/src/poly_kl.c +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz +# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must +# be set before including Makefile.common, but any use of variables on the +# left-hand side requires those variables to be defined. Hence, _SOURCE, +# _FUNCTIONS, _OBJECTS is set after including Makefile.common. + +include ../Makefile.common diff --git a/proofs/cbmc/poly_decompose_88_native_aarch64/poly_decompose_88_native_aarch64_harness.c b/proofs/cbmc/poly_decompose_88_native_aarch64/poly_decompose_88_native_aarch64_harness.c new file mode 100644 index 000000000..36951d079 --- /dev/null +++ b/proofs/cbmc/poly_decompose_88_native_aarch64/poly_decompose_88_native_aarch64_harness.c @@ -0,0 +1,19 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include <stdint.h> +#include "cbmc.h" +#include "params.h" + +int mld_poly_decompose_88_native(int32_t *a1, int32_t *a0); + +void harness(void) +{ + /* mld_poly_decompose_88_native is only defined for ML-DSA-44 */ +#if MLD_CONFIG_PARAMETER_SET == 44 + int32_t *a1; + int32_t *a0; + int t; + t = mld_poly_decompose_88_native(a1, a0); +#endif /* MLD_CONFIG_PARAMETER_SET == 44 */ +} diff --git a/proofs/cbmc/polyz_unpack_17_native_aarch64/Makefile b/proofs/cbmc/polyz_unpack_17_native_aarch64/Makefile new file mode 100644 index 000000000..f6ac6ca51 --- /dev/null +++ b/proofs/cbmc/polyz_unpack_17_native_aarch64/Makefile @@ -0,0 +1,63 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = polyz_unpack_17_native_aarch64_harness + +# This should be a unique identifier for this
proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = polyz_unpack_17_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c $(SRCDIR)/mldsa/src/native/aarch64/src/polyz_unpack_table.c + +# polyz_unpack_17 is only used with ML-DSA-44 +ifeq ($(MLD_CONFIG_PARAMETER_SET),44) +CHECK_FUNCTION_CONTRACTS=mld_polyz_unpack_17_native +USE_FUNCTION_CONTRACTS = mld_polyz_unpack_17_asm +else +CHECK_FUNCTION_CONTRACTS= +USE_FUNCTION_CONTRACTS = +endif +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = polyz_unpack_17_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +# If you require access to a file-local ("static") function or object to conduct +# your proof, set the following (and do not include the original source file +# ("mldsa/src/poly_kl.c") in PROJECT_SOURCES).
+# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i +# include ../Makefile.common +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/src/poly_kl.c +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz +# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must +# be set before including Makefile.common, but any use of variables on the +# left-hand side requires those variables to be defined. Hence, _SOURCE, +# _FUNCTIONS, _OBJECTS is set after including Makefile.common. + +include ../Makefile.common diff --git a/proofs/cbmc/polyz_unpack_17_native_aarch64/polyz_unpack_17_native_aarch64_harness.c b/proofs/cbmc/polyz_unpack_17_native_aarch64/polyz_unpack_17_native_aarch64_harness.c new file mode 100644 index 000000000..ff2983c44 --- /dev/null +++ b/proofs/cbmc/polyz_unpack_17_native_aarch64/polyz_unpack_17_native_aarch64_harness.c @@ -0,0 +1,19 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include <stdint.h> +#include "cbmc.h" +#include "params.h" + +int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a); + +void harness(void) +{ + /* mld_polyz_unpack_17_native is only defined for ML-DSA-44 */ +#if MLD_CONFIG_PARAMETER_SET == 44 + int32_t *r; + const uint8_t *a; + int t; + t = mld_polyz_unpack_17_native(r, a); +#endif /* MLD_CONFIG_PARAMETER_SET == 44 */ +} diff --git a/proofs/cbmc/polyz_unpack_19_native_aarch64/Makefile b/proofs/cbmc/polyz_unpack_19_native_aarch64/Makefile new file mode 100644 index 000000000..1a5e73d64 --- /dev/null +++ b/proofs/cbmc/polyz_unpack_19_native_aarch64/Makefile @@ -0,0 +1,66 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = polyz_unpack_19_native_aarch64_harness + +# This should be a unique identifier for this proof, and will
appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = polyz_unpack_19_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c $(SRCDIR)/mldsa/src/native/aarch64/src/polyz_unpack_table.c + +# polyz_unpack_19 is only used with ML-DSA-65 and ML-DSA-87 +ifeq ($(MLD_CONFIG_PARAMETER_SET),65) +CHECK_FUNCTION_CONTRACTS=mld_polyz_unpack_19_native +USE_FUNCTION_CONTRACTS = mld_polyz_unpack_19_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),87) +CHECK_FUNCTION_CONTRACTS=mld_polyz_unpack_19_native +USE_FUNCTION_CONTRACTS = mld_polyz_unpack_19_asm +else +CHECK_FUNCTION_CONTRACTS= +USE_FUNCTION_CONTRACTS = +endif +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = polyz_unpack_19_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +# If you require access to a file-local ("static") function or object to conduct +# your proof, set the following (and do not include the original source file +# ("mldsa/poly_kl.c") in PROJECT_SOURCES). 
+# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i +# include ../Makefile.common +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/src/poly_kl.c +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar +# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz +# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must +# be set before including Makefile.common, but any use of variables on the +# left-hand side requires those variables to be defined. Hence, _SOURCE, +# _FUNCTIONS, _OBJECTS is set after including Makefile.common. + +include ../Makefile.common diff --git a/proofs/cbmc/polyz_unpack_19_native_aarch64/polyz_unpack_19_native_aarch64_harness.c b/proofs/cbmc/polyz_unpack_19_native_aarch64/polyz_unpack_19_native_aarch64_harness.c new file mode 100644 index 000000000..f4cfecced --- /dev/null +++ b/proofs/cbmc/polyz_unpack_19_native_aarch64/polyz_unpack_19_native_aarch64_harness.c @@ -0,0 +1,19 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include +#include "cbmc.h" +#include "params.h" + +int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a); + +void harness(void) +{ + /* mld_polyz_unpack_19_native is only defined for ML-DSA-65 and ML-DSA-87 */ +#if MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87 + int32_t *r; + const uint8_t *a; + int t; + t = mld_polyz_unpack_19_native(r, a); +#endif /* MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87 */ +} diff --git a/proofs/hol_light/README.md b/proofs/hol_light/README.md index e7df5f400..0fdee7898 100644 --- a/proofs/hol_light/README.md +++ b/proofs/hol_light/README.md @@ -52,5 +52,6 @@ echo '1+1;;' | nc -w 5 127.0.0.1 2012 ## What is covered? 
- AArch64 poly_caddq: [mldsa_poly_caddq.S](aarch64/mldsa/mldsa_poly_caddq.S) +- AArch64 poly_chknorm: [mldsa_poly_chknorm.S](aarch64/mldsa/mldsa_poly_chknorm.S) - x86_64 forward NTT: [mldsa_ntt.S](x86_64/mldsa/mldsa_ntt.S) - x86_64 inverse NTT: [mldsa_intt.S](x86_64/mldsa/mldsa_intt.S) diff --git a/proofs/hol_light/aarch64/Makefile b/proofs/hol_light/aarch64/Makefile index 91c8bc96d..ad6e7ee9f 100644 --- a/proofs/hol_light/aarch64/Makefile +++ b/proofs/hol_light/aarch64/Makefile @@ -53,7 +53,12 @@ endif SPLIT=tr ';' '\n' -OBJ = mldsa/mldsa_poly_caddq.o +OBJ = mldsa/mldsa_poly_caddq.o \ + mldsa/mldsa_poly_chknorm.o \ + mldsa/mldsa_poly_decompose_32.o \ + mldsa/mldsa_poly_decompose_88.o \ + mldsa/mldsa_polyz_unpack_17.o \ + mldsa/mldsa_polyz_unpack_19.o # According to # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms, diff --git a/proofs/hol_light/aarch64/mldsa/mldsa_poly_chknorm.S b/proofs/hol_light/aarch64/mldsa/mldsa_poly_chknorm.S new file mode 100644 index 000000000..dfe8a0aab --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/mldsa_poly_chknorm.S @@ -0,0 +1,54 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/poly_chknorm_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",@progbits +#endif + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_poly_chknorm_asm +_PQCP_MLDSA_NATIVE_MLDSA44_poly_chknorm_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_poly_chknorm_asm +PQCP_MLDSA_NATIVE_MLDSA44_poly_chknorm_asm: +#endif + + .cfi_startproc + dup v20.4s, w1 + eor v21.16b, v21.16b, v21.16b + mov x2, #0x10 // =16 + +Lpoly_chknorm_loop: + ldr q1, [x0, #0x10] + ldr q2, [x0, #0x20] + ldr q3, [x0, #0x30] + ldr q0, [x0], #0x40 + abs v1.4s, v1.4s + cmge v1.4s, v1.4s, v20.4s + orr v21.16b, v21.16b, v1.16b + abs v2.4s, v2.4s + cmge v2.4s, v2.4s, v20.4s + orr v21.16b, v21.16b, v2.16b + abs v3.4s, v3.4s + cmge v3.4s, v3.4s, v20.4s + orr v21.16b, v21.16b, v3.16b + abs v0.4s, v0.4s + cmge v0.4s, v0.4s, v20.4s + orr v21.16b, v21.16b, v0.16b + subs x2, x2, #0x1 + b.ne Lpoly_chknorm_loop + umaxv s21, v21.4s + fmov w0, s21 + and w0, w0, #0x1 + ret + .cfi_endproc diff --git a/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_32.S b/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_32.S new file mode 100644 index 000000000..f60e253c4 --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_32.S @@ -0,0 +1,81 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/poly_decompose_32_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",@progbits +#endif + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_32_asm +_PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_32_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_32_asm +PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_32_asm: +#endif + + .cfi_startproc + mov w4, #0xe001 // =57345 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + mov w5, #0xe100 // =57600 + movk w5, #0x7b, lsl #16 + dup v21.4s, w5 + mov w7, #0xfe00 // =65024 + movk w7, #0x7, lsl #16 + dup v22.4s, w7 + mov w11, #0x401 // =1025 + movk w11, #0x4010, lsl #16 + dup v23.4s, w11 + mov x3, #0x10 // =16 + +Lpoly_decompose_32_loop: + ldr q0, [x1] + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + sqdmulh v5.4s, v1.4s, v23.4s + srshr v5.4s, v5.4s, #0x12 + cmgt v24.4s, v1.4s, v21.4s + mls v1.4s, v5.4s, v22.4s + bic v5.16b, v5.16b, v24.16b + add v1.4s, v1.4s, v24.4s + sqdmulh v6.4s, v2.4s, v23.4s + srshr v6.4s, v6.4s, #0x12 + cmgt v24.4s, v2.4s, v21.4s + mls v2.4s, v6.4s, v22.4s + bic v6.16b, v6.16b, v24.16b + add v2.4s, v2.4s, v24.4s + sqdmulh v7.4s, v3.4s, v23.4s + srshr v7.4s, v7.4s, #0x12 + cmgt v24.4s, v3.4s, v21.4s + mls v3.4s, v7.4s, v22.4s + bic v7.16b, v7.16b, v24.16b + add v3.4s, v3.4s, v24.4s + sqdmulh v4.4s, v0.4s, v23.4s + srshr v4.4s, v4.4s, #0x12 + cmgt v24.4s, v0.4s, v21.4s + mls v0.4s, v4.4s, v22.4s + bic v4.16b, v4.16b, v24.16b + add v0.4s, v0.4s, v24.4s + str q5, [x0, #0x10] + str q6, [x0, #0x20] + str q7, [x0, #0x30] + str q4, [x0], #0x40 + str q1, [x1, #0x10] + str q2, [x1, #0x20] + str q3, [x1, #0x30] + str q0, [x1], #0x40 + subs x3, x3, #0x1 + b.ne Lpoly_decompose_32_loop + ret + .cfi_endproc diff --git a/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_88.S b/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_88.S new file mode 100644 index 000000000..aef78985c --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/mldsa_poly_decompose_88.S @@ -0,0 +1,81 @@ +/* + * 
Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/poly_decompose_88_asm.S using scripts/simpasm. Do not modify it directly. + */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",@progbits +#endif + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_88_asm +_PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_88_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_88_asm +PQCP_MLDSA_NATIVE_MLDSA44_poly_decompose_88_asm: +#endif + + .cfi_startproc + mov w4, #0xe001 // =57345 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + mov w5, #0x6c00 // =27648 + movk w5, #0x7e, lsl #16 + dup v21.4s, w5 + mov w7, #0xe800 // =59392 + movk w7, #0x2, lsl #16 + dup v22.4s, w7 + mov w11, #0x581 // =1409 + movk w11, #0x5816, lsl #16 + dup v23.4s, w11 + mov x3, #0x10 // =16 + +Lpoly_decompose_88_loop: + ldr q0, [x1] + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + sqdmulh v5.4s, v1.4s, v23.4s + srshr v5.4s, v5.4s, #0x11 + cmgt v24.4s, v1.4s, v21.4s + mls v1.4s, v5.4s, v22.4s + bic v5.16b, v5.16b, v24.16b + add v1.4s, v1.4s, v24.4s + sqdmulh v6.4s, v2.4s, v23.4s + srshr v6.4s, v6.4s, #0x11 + cmgt v24.4s, v2.4s, v21.4s + mls v2.4s, v6.4s, v22.4s + bic v6.16b, v6.16b, v24.16b + add v2.4s, v2.4s, v24.4s + sqdmulh v7.4s, v3.4s, v23.4s + srshr v7.4s, v7.4s, #0x11 + cmgt v24.4s, v3.4s, v21.4s + mls v3.4s, v7.4s, v22.4s + bic v7.16b, v7.16b, v24.16b + add v3.4s, v3.4s, v24.4s + sqdmulh v4.4s, v0.4s, v23.4s + srshr v4.4s, v4.4s, #0x11 + cmgt v24.4s, v0.4s, v21.4s + mls v0.4s, v4.4s, v22.4s + bic v4.16b, v4.16b, v24.16b + add v0.4s, v0.4s, v24.4s + str q5, [x0, #0x10] + str q6, [x0, #0x20] + str q7, [x0, #0x30] + str q4, [x0], #0x40 + str q1, [x1, #0x10] + str q2, [x1, #0x20] + str q3, [x1, #0x30] + str q0, [x1], #0x40 + subs x3, x3, #0x1 + b.ne Lpoly_decompose_88_loop + ret + 
.cfi_endproc diff --git a/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_17.S b/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_17.S new file mode 100644 index 000000000..df793b22a --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_17.S @@ -0,0 +1,67 @@ +/* + * Copyright (c) The mldsa-native project authors + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/polyz_unpack_17_asm.S using scripts/simpasm. Do not modify it directly. + */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",@progbits +#endif + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_17_asm +_PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_17_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_17_asm +PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_17_asm: +#endif + + .cfi_startproc + ldr q24, [x2] + ldr q25, [x2, #0x10] + ldr q26, [x2, #0x20] + ldr q27, [x2, #0x30] + mov x3, #0xfe00000000 // =1090921693184 + mov v28.d[0], x3 + mov x3, #0xfc // =252 + movk x3, #0xfa, lsl #32 + mov v28.d[1], x3 + movi v29.4s, #0x3, msl #16 + movi v30.4s, #0x2, lsl #16 + mov x9, #0x10 // =16 + +Lpolyz_unpack_17_loop: + ld1 { v0.4s, v1.4s }, [x1], #32 + ldr s2, [x1], #0x4 + tbl v4.16b, { v0.16b }, v24.16b + tbl v5.16b, { v0.16b, v1.16b }, v25.16b + tbl v6.16b, { v1.16b }, v26.16b + tbl v7.16b, { v1.16b, v2.16b }, v27.16b + ushl v4.4s, v4.4s, v28.4s + and v4.16b, v4.16b, v29.16b + sub v4.4s, v30.4s, v4.4s + ushl v5.4s, v5.4s, v28.4s + and v5.16b, v5.16b, v29.16b + sub v5.4s, v30.4s, v5.4s + ushl v6.4s, v6.4s, v28.4s + and v6.16b, v6.16b, v29.16b + sub v6.4s, v30.4s, v6.4s + ushl v7.4s, v7.4s, v28.4s + and v7.16b, v7.16b, v29.16b + sub v7.4s, v30.4s, v7.4s + str q5, [x0, #0x10] + str q6, [x0, #0x20] + str q7, [x0, #0x30] + str q4, [x0], #0x40 + subs x9, x9, #0x1 + b.ne Lpolyz_unpack_17_loop + ret + 
.cfi_endproc diff --git a/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_19.S b/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_19.S new file mode 100644 index 000000000..d83423d26 --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/mldsa_polyz_unpack_19.S @@ -0,0 +1,64 @@ +/* + * Copyright (c) The mldsa-native project authors + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/polyz_unpack_19_asm.S using scripts/simpasm. Do not modify it directly. + */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",@progbits +#endif + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_19_asm +_PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_19_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_19_asm +PQCP_MLDSA_NATIVE_MLDSA44_polyz_unpack_19_asm: +#endif + + .cfi_startproc + ldr q24, [x2] + ldr q25, [x2, #0x10] + ldr q26, [x2, #0x20] + ldr q27, [x2, #0x30] + mov x3, #0xfc00000000 // =1082331758592 + dup v28.2d, x3 + movi v29.4s, #0xf, msl #16 + movi v30.4s, #0x8, lsl #16 + mov x9, #0x10 // =16 + +Lpolyz_unpack_19_loop: + ld1 { v0.2d, v1.2d }, [x1], #32 + ldr d2, [x1], #0x8 + tbl v4.16b, { v0.16b }, v24.16b + tbl v5.16b, { v0.16b, v1.16b }, v25.16b + tbl v6.16b, { v1.16b }, v26.16b + tbl v7.16b, { v1.16b, v2.16b }, v27.16b + ushl v4.4s, v4.4s, v28.4s + and v4.16b, v4.16b, v29.16b + sub v4.4s, v30.4s, v4.4s + ushl v5.4s, v5.4s, v28.4s + and v5.16b, v5.16b, v29.16b + sub v5.4s, v30.4s, v5.4s + ushl v6.4s, v6.4s, v28.4s + and v6.16b, v6.16b, v29.16b + sub v6.4s, v30.4s, v6.4s + ushl v7.4s, v7.4s, v28.4s + and v7.16b, v7.16b, v29.16b + sub v7.4s, v30.4s, v7.4s + str q5, [x0, #0x10] + str q6, [x0, #0x20] + str q7, [x0, #0x30] + str q4, [x0], #0x40 + subs x9, x9, #0x1 + b.ne Lpolyz_unpack_19_loop + ret + .cfi_endproc diff --git a/proofs/hol_light/aarch64/proofs/aarch64_utils.ml 
b/proofs/hol_light/aarch64/proofs/aarch64_utils.ml index e63fa0409..37890af1f 100644 --- a/proofs/hol_light/aarch64/proofs/aarch64_utils.ml +++ b/proofs/hol_light/aarch64/proofs/aarch64_utils.ml @@ -3,7 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT *) -(* Utility for executing until target PC is reached *) +needs "common/mldsa_specs.ml";; + +(* ------------------------------------------------------------------------- *) +(* Symbolic execution until target PC is reached. *) +(* ------------------------------------------------------------------------- *) + let MAP_UNTIL_TARGET_PC f n = fun (asl, w) -> let is_pc_condition = can (term_match [] `read PC some_state = some_value`) in let extract_target_pc_from_goal goal = @@ -22,3 +27,153 @@ let MAP_UNTIL_TARGET_PC f n = fun (asl, w) -> let rec core n (asl, w) = (TARGET_PC_REACHED_TAC target_pc ORELSE (f n THEN core (n + 1))) (asl, w) in core n (asl, w);; + +(* ========================================================================= *) +(* SIMD simplification: subword extraction + numeric reduction + folding. *) +(* ========================================================================= *) + +let SIMD_SIMPLIFY_CONV unfold_defs = + TOP_DEPTH_CONV + (REWR_CONV WORD_SUBWORD_AND ORELSEC WORD_SIMPLE_SUBWORD_CONV) THENC + DEPTH_CONV WORD_NUM_RED_CONV THENC + REWRITE_CONV (map GSYM unfold_defs);; + +let SIMD_SIMPLIFY_TAC unfold_defs = + let simdable = can (term_match [] `read X (s:armstate):int128 = whatever`) in + TRY(FIRST_X_ASSUM + (ASSUME_TAC o + CONV_RULE(RAND_CONV (SIMD_SIMPLIFY_CONV unfold_defs)) o + check (simdable o concl)));; + +(* ========================================================================= *) +(* Parametric infrastructure for d-bit packed coefficients (SIMD). *) +(* Supports d=18 (GAMMA1=2^17) and d=20 (GAMMA1=2^19). 
*) +(* ========================================================================= *) + +(* Convert MOD/DIV expressions to word_subword of (16*d)-bit word *) +let mk_base_simps d = + let total = 16 * d in + let rem = total - 256 in + let total_ty = mk_finty (Num.num_of_int total) in + let rem_ty = mk_finty (Num.num_of_int rem) in + let mod_128 = CONV_RULE NUM_REDUCE_CONV (prove( + inst [total_ty, `:N`] + `word (t MOD 2 EXP 128) : 128 word = + word_subword (word t : N word) (0, 128)`, + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_SUBWORD; VAL_WORD; DIMINDEX_128] THEN + REWRITE_TAC[EXP; DIV_1; MOD_MOD_REFL; MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + MP_TAC (SPECL [`t:num`; `2`; mk_small_numeral total; `128`] MOD_MOD_EXP_MIN) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN (SUBST1_TAC o SYM) THEN REFL_TAC)) in + let div_128_mod_128 = CONV_RULE NUM_REDUCE_CONV (prove( + inst [total_ty, `:N`] + `word ((t DIV 2 EXP 128) MOD 2 EXP 128) : 128 word = + word_subword (word t : N word) (128, 128)`, + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_SUBWORD; VAL_WORD; DIMINDEX_128] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN + REWRITE_TAC[ARITH_RULE `MIN 128 128 = 128`; MOD_MOD_REFL] THEN + REWRITE_TAC[DIV_MOD; GSYM EXP_ADD; MOD_MOD_EXP_MIN] THEN + CONV_TAC NUM_REDUCE_CONV)) in + let div_256 = CONV_RULE NUM_REDUCE_CONV (prove( + inst [total_ty, `:N`; rem_ty, `:M`] + `word (t DIV 2 EXP 256) : M word = + word_subword (word t : N word) (256, dimindex(:M))`, + REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_SUBWORD; VAL_WORD] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[DIV_MOD; GSYM EXP_ADD; MOD_MOD_EXP_MIN] THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[MOD_MOD_REFL])) in + [mod_128; div_128_mod_128; div_256];; + +(* Split ncoeffs d-bit coefficients into chunks of chunk_size *) +let mk_split_theorem d ncoeffs chunk_size = + let total = d * chunk_size in + let nchunks = ncoeffs / chunk_size in + let d_ty = mk_finty (Num.num_of_int 
d) in + let total_ty = mk_finty (Num.num_of_int total) in + prove( + subst [mk_small_numeral ncoeffs, `ncoeffs:num`; + mk_small_numeral chunk_size, `cs:num`; + mk_small_numeral nchunks, `nc:num`] + (inst [d_ty, `:D`; total_ty, `:T`] + `!(l: (D word) list). LENGTH l = ncoeffs ==> + num_of_wordlist l = num_of_wordlist (MAP ((word:num->T word) o num_of_wordlist) + (list_of_seq (\i. SUB_LIST (cs * i, cs) l) nc))`), + REPEAT STRIP_TAC THEN + UNDISCH_THEN (subst [mk_small_numeral ncoeffs, `n:num`] + (inst [d_ty, `:D`] `LENGTH (l : (D word) list) = n`)) (fun th -> + GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) + [MATCH_MP (CONV_RULE NUM_REDUCE_CONV + (ISPECL [mk_small_numeral chunk_size; mk_small_numeral nchunks; + `l:'a list`] SUBLIST_PARTITION)) th] + THEN ASSUME_TAC th) THEN + IMP_REWRITE_TAC [CONV_RULE (ONCE_DEPTH_CONV DIMINDEX_CONV THENC NUM_REDUCE_CONV) + (ISPECL [inst [d_ty, `:D`] `ll: ((D word) list) list`; + mk_small_numeral chunk_size] + (INST_TYPE [d_ty, `:N`; total_ty, `:M`] NUM_OF_WORDLIST_FLATTEN))] THEN + CONV_TAC(ONCE_DEPTH_CONV LIST_OF_SEQ_CONV) THEN + ASM_REWRITE_TAC[ALL; LENGTH_SUB_LIST] THEN + ARITH_TAC);; + +(* Extract individual d-bit coefficients from (d*chunk_size)-bit word *) +let mk_subword_cases d chunk_size = + let total = d * chunk_size in + let d_ty = mk_finty (Num.num_of_int d) in + let total_ty = mk_finty (Num.num_of_int total) in + let arith_simp = + let lhs = mk_eq(mk_small_numeral total, + mk_comb(mk_comb(`( * ):num->num->num`, + mk_small_numeral d), `n:num`)) in + let rhs = mk_eq(`n:num`, mk_small_numeral chunk_size) in + ARITH_RULE (mk_eq(lhs, rhs)) in + let meson_simp = + let n_eq = mk_eq(`n:num`, mk_small_numeral chunk_size) in + let k_lt_n = mk_comb(mk_comb(`(<):num->num->bool`, `k:num`), `n:num`) in + let k_lt_cs = mk_comb(mk_comb(`(<):num->num->bool`, `k:num`), + mk_small_numeral chunk_size) in + MESON[] (mk_eq(mk_conj(n_eq, k_lt_n), mk_conj(n_eq, k_lt_cs))) in + let base = + let th = INST_TYPE [total_ty, `:KL`; d_ty, `:L`] 
WORD_SUBWORD_NUM_OF_WORDLIST in + let th = CONV_RULE(DEPTH_CONV DIMINDEX_CONV) th in + REWRITE_RULE[arith_simp; meson_simp] th in + let mk k = + let th = SPEC (mk_small_numeral k) + (SPEC (inst [d_ty, `:L`] `ls:(L word)list`) base) in + CONV_RULE NUM_REDUCE_CONV (REWRITE_RULE[ARITH] th) in + map mk (0 -- (chunk_size - 1));; + +(* ========================================================================= *) +(* zunpack lane conversion for TBL + USHL + AND + SUB pipeline. *) +(* ========================================================================= *) + +let ZUNPACK_LANE_CONV d i tm = + let gamma1 = 1 lsl (d - 1) in + let word_bits = 16 * d in + match find_word_subterm_n word_bits tm with + | Some t_var -> + let d_ty = mk_finty (Num.num_of_int d) in + let t_ty = mk_finty (Num.num_of_int word_bits) in + let goal = mk_eq(tm, + subst [mk_small_numeral (d*i), `pos:num`; + mk_small_numeral d, `bw:num`; + mk_small_numeral gamma1, `g:num`; + t_var, mk_var("t", mk_type("word",[t_ty]))] + (inst [d_ty, `:B`; t_ty, `:T`] + `word_sub (word g : 32 word) + (word_zx (word_subword (t : T word) (pos,bw) : B word))`)) in + WORD_BLAST goal + | None -> failwith ("no " ^ string_of_int word_bits ^ "-bit word found");; + +let ZUNPACK_128_CONV d tm = + tryfind (fun base_i -> + RAND_CONV (BINOP_CONV_N 2 (fun j -> ZUNPACK_LANE_CONV d (base_i + j))) tm + ) [0; 4; 8; 12];; + +let SIMP_ZUNPACK_TAC d zunpack_correct = + let zunpack_const = + fst(strip_comb(rhs(snd(strip_forall(concl zunpack_correct))))) in + let already_processed tm = + can (find_term ((=) zunpack_const)) tm in + RULE_ASSUM_TAC (fun th -> + if already_processed (concl th) then th + else CONV_RULE (TRY_CONV (ZUNPACK_128_CONV d) THENC + TRY_CONV (ONCE_REWRITE_CONV [zunpack_correct])) th);; diff --git a/proofs/hol_light/aarch64/proofs/dump_bytecode.ml b/proofs/hol_light/aarch64/proofs/dump_bytecode.ml index d64561e32..1236001f0 100644 --- a/proofs/hol_light/aarch64/proofs/dump_bytecode.ml +++ 
b/proofs/hol_light/aarch64/proofs/dump_bytecode.ml @@ -9,3 +9,23 @@ needs "arm/proofs/base.ml";; print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_caddq.o ===\n";; print_literal_from_elf "aarch64/mldsa/mldsa_poly_caddq.o";; print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_chknorm.o ===\n";; +print_literal_from_elf "aarch64/mldsa/mldsa_poly_chknorm.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_decompose_32.o ===\n";; +print_literal_from_elf "aarch64/mldsa/mldsa_poly_decompose_32.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_decompose_88.o ===\n";; +print_literal_from_elf "aarch64/mldsa/mldsa_poly_decompose_88.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/mldsa_polyz_unpack_17.o ===\n";; +print_literal_from_elf "aarch64/mldsa/mldsa_polyz_unpack_17.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/mldsa_polyz_unpack_19.o ===\n";; +print_literal_from_elf "aarch64/mldsa/mldsa_polyz_unpack_19.o";; +print_string "==== bytecode end =====================================\n\n";; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml b/proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml new file mode 100644 index 000000000..e57abc17a --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_poly_chknorm.ml @@ -0,0 +1,213 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Functional correctness of poly_chknorm: *) +(* Check if any polynomial coefficient has absolute value >= bound *) +(* Returns 1 if norm check fails (|coeff| >= bound), 0 otherwise *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; + +(**** print_literal_from_elf "aarch64/mldsa/mldsa_poly_chknorm.o";; + ****) + +let mldsa_poly_chknorm_mc = define_assert_from_elf "mldsa_poly_chknorm_mc" "aarch64/mldsa/mldsa_poly_chknorm.o" +(*** BYTECODE START ***) +[ + 0x4e040c34; (* arm_DUP_GEN Q20 X1 32 128 *) + 0x6e351eb5; (* arm_EOR_VEC Q21 Q21 Q21 128 *) + 0xd2800202; (* arm_MOV X2 (rvalue (word 16)) *) + 0x3dc00401; (* arm_LDR Q1 X0 (Immediate_Offset (word 16)) *) + 0x3dc00802; (* arm_LDR Q2 X0 (Immediate_Offset (word 32)) *) + 0x3dc00c03; (* arm_LDR Q3 X0 (Immediate_Offset (word 48)) *) + 0x3cc40400; (* arm_LDR Q0 X0 (Postimmediate_Offset (word 64)) *) + 0x4ea0b821; (* arm_ABS_VEC Q1 Q1 32 128 *) + 0x4eb43c21; (* arm_CMGE_VEC Q1 Q1 Q20 32 128 *) + 0x4ea11eb5; (* arm_ORR_VEC Q21 Q21 Q1 128 *) + 0x4ea0b842; (* arm_ABS_VEC Q2 Q2 32 128 *) + 0x4eb43c42; (* arm_CMGE_VEC Q2 Q2 Q20 32 128 *) + 0x4ea21eb5; (* arm_ORR_VEC Q21 Q21 Q2 128 *) + 0x4ea0b863; (* arm_ABS_VEC Q3 Q3 32 128 *) + 0x4eb43c63; (* arm_CMGE_VEC Q3 Q3 Q20 32 128 *) + 0x4ea31eb5; (* arm_ORR_VEC Q21 Q21 Q3 128 *) + 0x4ea0b800; (* arm_ABS_VEC Q0 Q0 32 128 *) + 0x4eb43c00; (* arm_CMGE_VEC Q0 Q0 Q20 32 128 *) + 0x4ea01eb5; (* arm_ORR_VEC Q21 Q21 Q0 128 *) + 0xf1000442; (* arm_SUBS X2 X2 (rvalue (word 1)) *) + 0x54fffde1; (* arm_BNE (word 2097084) *) + 0x6eb0aab5; (* arm_UMAXV Q21 Q21 4 32 *) + 0x1e2602a0; (* arm_FMOV_FtoI W0 Q21 0 32 *) + 0x12000000; (* arm_AND W0 W0 (rvalue (word 1)) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let 
MLDSA_POLY_CHKNORM_EXEC = ARM_MK_EXEC_RULE mldsa_poly_chknorm_mc;; + +(* ------------------------------------------------------------------------- *) +(* Code length constants *) +(* ------------------------------------------------------------------------- *) + +let LENGTH_MLDSA_POLY_CHKNORM_MC = + REWRITE_CONV[mldsa_poly_chknorm_mc] `LENGTH mldsa_poly_chknorm_mc` + |> CONV_RULE (RAND_CONV LENGTH_CONV);; + +let MLDSA_POLY_CHKNORM_PREAMBLE_LENGTH = new_definition + `MLDSA_POLY_CHKNORM_PREAMBLE_LENGTH = 0`;; + +let MLDSA_POLY_CHKNORM_POSTAMBLE_LENGTH = new_definition + `MLDSA_POLY_CHKNORM_POSTAMBLE_LENGTH = 4`;; + +let MLDSA_POLY_CHKNORM_CORE_START = new_definition + `MLDSA_POLY_CHKNORM_CORE_START = MLDSA_POLY_CHKNORM_PREAMBLE_LENGTH`;; + +let MLDSA_POLY_CHKNORM_CORE_END = new_definition + `MLDSA_POLY_CHKNORM_CORE_END = LENGTH mldsa_poly_chknorm_mc - MLDSA_POLY_CHKNORM_POSTAMBLE_LENGTH`;; + +let CHKNORM_LENGTH_SIMPLIFY_CONV = + REWRITE_CONV[LENGTH_MLDSA_POLY_CHKNORM_MC; + MLDSA_POLY_CHKNORM_CORE_START; MLDSA_POLY_CHKNORM_CORE_END; + MLDSA_POLY_CHKNORM_PREAMBLE_LENGTH; MLDSA_POLY_CHKNORM_POSTAMBLE_LENGTH] THENC + NUM_REDUCE_CONV THENC REWRITE_CONV [ADD_0];; + +(* ------------------------------------------------------------------------- *) +(* Helper lemmas *) +(* ------------------------------------------------------------------------- *) + +(* Expression emerging from the AArch64 SIMD code converting bit to 32-bit mask *) +let bit_to_mask32 = new_definition `bit_to_mask32 (b : bool) : 32 word = word_neg (word (bitval b) : 32 word)`;; + +(* Expression used for bounds check itself *) +let bd = new_definition `bd (v : int32) (b: int32) : bool = + (ival (iword (abs (ival v)) : 32 word) >= ival (word_zx (word_zx b : 64 word) : 32 word))`;; + +let MAX_VAL_BIT_TO_MASK32 = prove( + `MAX (val (bit_to_mask32 b0)) (val (bit_to_mask32 b1)) = val (bit_to_mask32 (b0 \/ b1))`, + REWRITE_TAC[bit_to_mask32] THEN + BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN +
REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN ARITH_TAC);; + +let BD_SIMP = prove( + `abs(ival(x : int32)) < &2 pow 31 ==> (bd x b <=> abs (ival x) >= ival b)`, + REWRITE_TAC[bd] THEN DISCH_TAC THEN + SUBGOAL_THEN `ival(iword(abs(ival(x:32 word))) : 32 word) = abs(ival x)` SUBST1_TAC THENL + [MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + FIRST_X_ASSUM MP_TAC THEN REWRITE_TAC[INT_ABS_POS] THEN INT_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(word_zx:64 word -> 32 word) ((word_zx:32 word -> 64 word) b) = b` SUBST1_TAC THENL + [MATCH_MP_TAC WORD_ZX_ZX THEN REWRITE_TAC[DIMINDEX_32; DIMINDEX_64] THEN ARITH_TAC; + REFL_TAC]);; + +let BIT_TO_MASK32_OR = prove( + `word_or (bit_to_mask32 b0) (bit_to_mask32 b1) = bit_to_mask32 (b0 \/ b1)`, + REWRITE_TAC[bit_to_mask32] THEN + BOOL_CASES_TAC `b0:bool` THEN BOOL_CASES_TAC `b1:bool` THEN + REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV);; + +let MASK32_TO_BIT = prove( + `(word_zx:32 word -> 64 word) (word_and ((word_zx:64 word -> 32 word) + ((word_zx:32 word -> 64 word) (word_subword (word (val (bit_to_mask32 b)) : 128 word) (0,32)))) + (word 1)) = word (bitval b) : 64 word`, + REWRITE_TAC[bit_to_mask32] THEN + BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[BITVAL_CLAUSES] THEN + CONV_TAC WORD_REDUCE_CONV);; + +let WORD_JOIN_OR_TYBIT0 = prove( + `word_or (word_join (a:N word) (b:N word) : (N tybit0) word) (word_join (c:N word) (d:N word)) = + word_join (word_or a c) (word_or b d)`, + REWRITE_TAC[WORD_EQ_BITS_ALT; BIT_WORD_OR; BIT_WORD_JOIN; DIMINDEX_TYBIT0] THEN + X_GEN_TAC `i:num` THEN + ASM_CASES_TAC `i < 2 * dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `i < dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN + MATCH_MP_TAC(TAUT `p ==> (q <=> p /\ q)`) THEN ASM_ARITH_TAC);; + +(* ------------------------------------------------------------------------- *) +(* Core correctness theorem *) +(* 
------------------------------------------------------------------------- *) + +let MLDSA_POLY_CHKNORM_CORRECT = prove( + `!a (x:num->int32) (bound:int32) pc. + nonoverlapping (word pc, LENGTH mldsa_poly_chknorm_mc) (a, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_chknorm_mc /\ + read PC s = word(pc + MLDSA_POLY_CHKNORM_CORE_START) /\ + C_ARGUMENTS [a; word_zx bound] s /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> abs(ival(x i)) < &2 pow 31)) + (\s. read PC s = word(pc + MLDSA_POLY_CHKNORM_CORE_END) /\ + read X0 s = word(bitval(?i. i < 256 /\ abs(ival(x i)) >= ival bound))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI)`, + CONV_TAC CHKNORM_LENGTH_SIMPLIFY_CONV THEN + MAP_EVERY X_GEN_TAC [`a:int64`; `x:num->int32`; `bound:int32`; `pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; + NONOVERLAPPING_CLAUSES] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + (* Expand bounded foralls in precondition to 256 explicit cases *) + ENSURES_INIT_TAC "s0" THEN + UNDISCH_TAC `forall i. 
i < 256 ==> read (memory :> bytes32 (word_add a (word (4 * i)))) s0 = x i` THEN + CONV_TAC(ONCE_DEPTH_CONV (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)) THEN REPEAT STRIP_TAC THEN + (* Merge bytes32 reads into bytes128 reads (64 merges for 256 coefficients) *) + MP_TAC(end_itlist CONJ (map (fun n -> READ_MEMORY_MERGE_CONV 2 + (subst[mk_small_numeral(16*n),`n:num`] + `read (memory :> bytes128(word_add a (word n))) s0`)) + (0--63))) THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + STRIP_TAC THEN + (* Symbolically execute all instructions until target PC *) + MAP_UNTIL_TARGET_PC (fun n -> + ARM_STEPS_TAC MLDSA_POLY_CHKNORM_EXEC [n] THEN + RULE_ASSUM_TAC(CONV_RULE(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV)) THEN + RULE_ASSUM_TAC(REWRITE_RULE[WORD_SUBWORD_OR; GSYM bit_to_mask32; WORD_JOIN_OR_TYBIT0; SYM (SPEC_ALL bd); BIT_TO_MASK32_OR; + MAX_VAL_BIT_TO_MASK32; MASK32_TO_BIT])) 1 THEN + + (* Close the state relation *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read t s = x`] THEN + + RULE_ASSUM_TAC (CONV_RULE (ONCE_DEPTH_CONV EXPAND_CASES_CONV)) THEN + REPEAT(FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC)) THEN + IMP_REWRITE_TAC [BD_SIMP] THEN + POP_ASSUM_LIST (K ALL_TAC) THEN + + (* Convert to ! instead of ? and split *) + GEN_REWRITE_TAC (BINOP_CONV o ONCE_DEPTH_CONV) [prove (`b = ~ (~ (b : bool))`, REWRITE_TAC[])] THEN + GEN_REWRITE_TAC TOP_SWEEP_CONV [MESON[] `~(?i. i < n /\ P i) <=> (!i. 
i < n ==> ~P i)`; DE_MORGAN_THM] THEN + CONV_TAC (ONCE_DEPTH_CONV EXPAND_CASES_CONV) THEN + REPEAT AP_TERM_TAC THEN EQ_TAC THEN SIMP_TAC[]);; + +(* ------------------------------------------------------------------------- *) +(* Subroutine correctness theorem (includes return) *) +(* ------------------------------------------------------------------------- *) + +(* NOTE: This must be kept in sync with the CBMC specification + * in mldsa/src/native/aarch64/src/arith_native_aarch64.h *) + +let MLDSA_POLY_CHKNORM_SUBROUTINE_CORRECT = prove( + `!a (x:num->int32) (bound:int32) pc returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_chknorm_mc) (a, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_chknorm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [a; word_zx bound] s /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> abs(ival(x i)) < &2 pow 31)) + (\s. read PC s = returnaddress /\ + read X0 s = word(bitval(?i. 
i < 256 /\ abs(ival(x i)) >= ival bound))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI)`, + CONV_TAC CHKNORM_LENGTH_SIMPLIFY_CONV THEN + let TWEAK_CONV = + ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC + ONCE_DEPTH_CONV NUM_MULT_CONV THENC + PURE_REWRITE_CONV [WORD_ADD_0] in + CONV_TAC TWEAK_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC MLDSA_POLY_CHKNORM_EXEC + (CONV_RULE TWEAK_CONV + (CONV_RULE CHKNORM_LENGTH_SIMPLIFY_CONV MLDSA_POLY_CHKNORM_CORRECT)));; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml b/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml new file mode 100644 index 000000000..4696b208e --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_32.ml @@ -0,0 +1,611 @@ +(* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + *) + +(* ========================================================================= *) +(* Functional correctness of poly_decompose_32: *) +(* Decompose polynomial coefficients into (a1, a0) where a = a1*2*GAMMA2+a0 *) +(* for GAMMA2 = (Q-1)/32 = 261888 (ML-DSA-65/87) *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; +needs "common/mldsa_specs.ml";; + +(**** print_literal_from_elf "aarch64/mldsa/mldsa_poly_decompose_32.o";; + ****) + +let mldsa_poly_decompose_32_mc = define_assert_from_elf "mldsa_poly_decompose_32_mc" "aarch64/mldsa/mldsa_poly_decompose_32.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x529c2005; (* arm_MOV W5 (rvalue (word 57600)) *) + 0x72a00f65; (* arm_MOVK W5 (word 123) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529fc007; (* arm_MOV W7 (rvalue (word 65024)) *) + 0x72a000e7; (* arm_MOVK W7 (word 7) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280802b; (* arm_MOV 
W11 (rvalue (word 1025)) *) + 0x72a8020b; (* arm_MOVK W11 (word 16400) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00020; (* arm_LDR Q0 X1 (Immediate_Offset (word 0)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x4eb7b425; (* arm_SQDMULH_VEC Q5 Q1 Q23 32 128 *) + 0x4f2e24a5; (* arm_SRSHR_VEC Q5 Q5 18 32 128 *) + 0x4eb53438; (* arm_CMGT_VEC Q24 Q1 Q21 32 128 *) + 0x6eb694a1; (* arm_MLS_VEC Q1 Q5 Q22 32 128 *) + 0x4e781ca5; (* arm_BIC_VEC Q5 Q5 Q24 128 *) + 0x4eb88421; (* arm_ADD_VEC Q1 Q1 Q24 32 128 *) + 0x4eb7b446; (* arm_SQDMULH_VEC Q6 Q2 Q23 32 128 *) + 0x4f2e24c6; (* arm_SRSHR_VEC Q6 Q6 18 32 128 *) + 0x4eb53458; (* arm_CMGT_VEC Q24 Q2 Q21 32 128 *) + 0x6eb694c2; (* arm_MLS_VEC Q2 Q6 Q22 32 128 *) + 0x4e781cc6; (* arm_BIC_VEC Q6 Q6 Q24 128 *) + 0x4eb88442; (* arm_ADD_VEC Q2 Q2 Q24 32 128 *) + 0x4eb7b467; (* arm_SQDMULH_VEC Q7 Q3 Q23 32 128 *) + 0x4f2e24e7; (* arm_SRSHR_VEC Q7 Q7 18 32 128 *) + 0x4eb53478; (* arm_CMGT_VEC Q24 Q3 Q21 32 128 *) + 0x6eb694e3; (* arm_MLS_VEC Q3 Q7 Q22 32 128 *) + 0x4e781ce7; (* arm_BIC_VEC Q7 Q7 Q24 128 *) + 0x4eb88463; (* arm_ADD_VEC Q3 Q3 Q24 32 128 *) + 0x4eb7b404; (* arm_SQDMULH_VEC Q4 Q0 Q23 32 128 *) + 0x4f2e2484; (* arm_SRSHR_VEC Q4 Q4 18 32 128 *) + 0x4eb53418; (* arm_CMGT_VEC Q24 Q0 Q21 32 128 *) + 0x6eb69480; (* arm_MLS_VEC Q0 Q4 Q22 32 128 *) + 0x4e781c84; (* arm_BIC_VEC Q4 Q4 Q24 128 *) + 0x4eb88400; (* arm_ADD_VEC Q0 Q0 Q24 32 128 *) + 0x3d800405; (* arm_STR Q5 X0 (Immediate_Offset (word 16)) *) + 0x3d800806; (* arm_STR Q6 X0 (Immediate_Offset (word 32)) *) + 0x3d800c07; (* arm_STR Q7 X0 (Immediate_Offset (word 48)) *) + 0x3c840404; (* arm_STR Q4 X0 (Postimmediate_Offset (word 64)) *) + 0x3d800421; (* arm_STR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3d800822; (* arm_STR Q2 X1 (Immediate_Offset (word 32)) *) 
+ 0x3d800c23; (* arm_STR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3c840420; (* arm_STR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fffb61; (* arm_BNE (word 2097004) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_POLY_DECOMPOSE_32_EXEC = ARM_MK_EXEC_RULE mldsa_poly_decompose_32_mc;; + +(* ========================================================================= *) +(* Constants *) +(* ========================================================================= *) + +let LENGTH_MLDSA_POLY_DECOMPOSE_32_MC = + REWRITE_CONV[mldsa_poly_decompose_32_mc] `LENGTH mldsa_poly_decompose_32_mc` + |> CONV_RULE (RAND_CONV LENGTH_CONV);; + +let MLDSA_POLY_DECOMPOSE_32_CORE_START = new_definition + `MLDSA_POLY_DECOMPOSE_32_CORE_START = 0`;; + +let MLDSA_POLY_DECOMPOSE_32_POSTAMBLE_LENGTH = new_definition + `MLDSA_POLY_DECOMPOSE_32_POSTAMBLE_LENGTH = 4`;; + +let MLDSA_POLY_DECOMPOSE_32_CORE_END = new_definition + `MLDSA_POLY_DECOMPOSE_32_CORE_END = + LENGTH mldsa_poly_decompose_32_mc - MLDSA_POLY_DECOMPOSE_32_POSTAMBLE_LENGTH`;; + +let LENGTH_SIMPLIFY_CONV = + REWRITE_CONV[LENGTH_MLDSA_POLY_DECOMPOSE_32_MC; + MLDSA_POLY_DECOMPOSE_32_CORE_START; MLDSA_POLY_DECOMPOSE_32_CORE_END; + MLDSA_POLY_DECOMPOSE_32_POSTAMBLE_LENGTH] THENC + NUM_REDUCE_CONV THENC REWRITE_CONV [ADD_0];; + +(* ========================================================================= *) +(* Word-level helper functions *) +(* Per-lane operations matching the assembly's SQDMULH+SRSHR, BIC, MLS+ADD *) +(* ========================================================================= *) + +(* h32: quotient = srshr(sqdmulh(x, magic), 18) ≈ round(x / 523776) *) +let h32 = define + `h32 (x:int32) : int32 = + iword((ival((iword_saturate:int->int32) + ((&2 * ival x * &1074791425) div &4294967296)) + + &131072) div &262144)`;; + +(* decompose32_a1: a1 = h AND (NOT mask) where mask = -1 if x > threshold *) +let decompose32_a1 = define + 
`decompose32_a1 (x:int32) : int32 = + word_and (h32 x) + (word_not(word_neg(word(bitval(ival x > &8118528)))))`;; + +(* decompose32_a0: a0 = (x - h*2*GAMMA2) + mask *) +let decompose32_a0 = define + `decompose32_a0 (x:int32) : int32 = + word_add (word_sub x (word_mul (h32 x) (word 523776))) + (word_neg(word(bitval(ival x > &8118528))))`;; + +(* ========================================================================= *) +(* Distribution lemmas for word_and/word_not over word_join *) +(* Needed because BIC (arm_BIC_VEC) operates at 128-bit level *) +(* ========================================================================= *) + +let WORD_AND_JOIN_64 = WORD_BLAST + `!a b c d : int32. + word_and ((word_join:int32->int32->int64) a b) + ((word_join:int32->int32->int64) c d) = + word_join (word_and a c) (word_and b d)`;; + +let WORD_AND_JOIN_128 = WORD_BLAST + `!a b c d : int64. + word_and ((word_join:int64->int64->int128) a b) + ((word_join:int64->int64->int128) c d) = + word_join (word_and a c) (word_and b d)`;; + +let WORD_NOT_JOIN_64 = WORD_BLAST + `!a b : int32. + word_not ((word_join:int32->int32->int64) a b) = + word_join (word_not a) (word_not b)`;; + +let WORD_NOT_JOIN_128 = WORD_BLAST + `!a b : int64. + word_not ((word_join:int64->int64->int128) a b) = + word_join (word_not a) (word_not b)`;; + +(* ========================================================================= *) +(* Mathematical correctness lemmas *) +(* Connect word-level decompose32_a1/a0 to spec-level decompose32 *) +(* ========================================================================= *) + +(* Case split: a1 is either h32 or 0 depending on the threshold *) +let DECOMPOSE32_A1_CASES = prove( + `!x:int32. decompose32_a1 x = + if ival x > &8118528 then word 0 else h32 x`, + REWRITE_TAC[decompose32_a1] THEN BITBLAST_TAC);; + +(* Case split: a0 subtracts 1 in the special case *) +let DECOMPOSE32_A0_CASES = prove( + `!x:int32. 
decompose32_a0 x = + if ival x > &8118528 + then word_sub (word_sub x (word_mul (h32 x) (word 523776))) (word 1) + else word_sub x (word_mul (h32 x) (word 523776))`, + GEN_TAC THEN REWRITE_TAC[decompose32_a0] THEN + COND_CASES_TAC THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_RULE);; + +(* ival equals val for values in the positive int32 range *) +let IVAL_EQ_VAL = prove( + `!x:int32. val x < 2 EXP 31 ==> ival x = &(val x)`, + GEN_TAC THEN REWRITE_TAC[IVAL_VAL; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + DISCH_TAC THEN + SUBGOAL_THEN `bit (32 - 1) (x:int32) = F` ASSUME_TAC THENL [ + REWRITE_TAC[BIT_VAL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_ARITH_TAC; + ASM_REWRITE_TAC[bitval] THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Barrett reduction correctness for h32 *) +(* Shows that SQDMULH+SRSHR computes round(x / 523776) correctly *) +(* ========================================================================= *) + +(* Algebraic expansion: n*K + q*E = q*D*P + r*K + where K=2149582850, M=523776, D=262144, P=4294967296, E=1024 *) +let BARRETT32_EXPAND = prove( + `!n. n * 2149582850 + (n DIV 523776) * 1024 = + (n DIV 523776) * 262144 * 4294967296 + (n MOD 523776) * 2149582850`, + GEN_TAC THEN + MP_TAC(SPECL [`n:num`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + DISCH_THEN(CONJUNCTS_THEN2 (fun th -> GEN_REWRITE_TAC (LAND_CONV o LAND_CONV o LAND_CONV) [th]) ASSUME_TAC) THEN + CONV_TAC NUM_REDUCE_CONV THEN + CONV_TAC NUM_RING);; + +(* DIV_BOUNDS_EQ: if q*d <= b < (q+1)*d then b DIV d = q *) +let DIV_BOUNDS_EQ = prove( + `!b d q. 
~(d = 0) /\ q * d <= b /\ b < (q + 1) * d ==> b DIV d = q`, + REPEAT STRIP_TAC THEN MATCH_MP_TAC(ARITH_RULE `q <= r /\ r < q + 1 ==> r = q`) THEN + CONJ_TAC THENL [ + ASM_SIMP_TAC[LE_RDIV_EQ] THEN ASM_ARITH_TAC; + ASM_SIMP_TAC[RDIV_LT_EQ] THEN ASM_ARITH_TAC]);; + +(* Barrett reduction: (n*K) DIV P with rounding = round(n / M) *) +let BARRETT32_CORRECT = prove( + `!n. n < 8380417 ==> + ((n * 2149582850) DIV 4294967296 + 131072) DIV 262144 = + (if n MOD 523776 * 2 <= 523776 + then n DIV 523776 + else n DIV 523776 + 1)`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `n = 8380416` THENL [ + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV; ALL_TAC] THEN + ABBREV_TAC `q = n DIV 523776` THEN + ABBREV_TAC `r = n MOD 523776` THEN + SUBGOAL_THEN `n = q * 523776 + r` ASSUME_TAC THENL [ + EXPAND_TAC "q" THEN EXPAND_TAC "r" THEN + MESON_TAC[DIVISION; ARITH_RULE `~(523776 = 0)`]; ALL_TAC] THEN + SUBGOAL_THEN `r < 523776` ASSUME_TAC THENL [ + EXPAND_TAC "r" THEN SIMP_TAC[MOD_LT_EQ] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q <= 15` ASSUME_TAC THENL [ + EXPAND_TAC "q" THEN ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(523776 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `n:num` BARRETT32_EXPAND) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + COND_CASES_TAC THENL [ + (* Round-down case: r * 2 <= 523776, so r <= 261888 *) + ABBREV_TAC `d = ((q * 523776 + r) * 2149582850) DIV 4294967296` THEN + MP_TAC(SPECL [`(q * 523776 + r) * 2149582850`; `4294967296`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + SUBGOAL_THEN `d * 4294967296 + q * 1024 <= q * 262144 * 4294967296 + r * 2149582850` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 262144 * 4294967296 + r * 2149582850 < (d + 1) * 4294967296 + q * 1024` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `r * 2149582850 <= 261888 * 2149582850` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + 
MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN CONJ_TAC THENL [ + MP_TAC(ARITH_RULE `261888 * 2149582850 < 131072 * 4294967296`) THEN ASM_ARITH_TAC; + MP_TAC(ARITH_RULE `261888 * 2149582850 < 131072 * 4294967296`) THEN ASM_ARITH_TAC]; + (* Round-up case: ~(r * 2 <= 523776), so r >= 261889 *) + ABBREV_TAC `d = ((q * 523776 + r) * 2149582850) DIV 4294967296` THEN + MP_TAC(SPECL [`(q * 523776 + r) * 2149582850`; `4294967296`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + SUBGOAL_THEN `d * 4294967296 + q * 1024 <= q * 262144 * 4294967296 + r * 2149582850` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 262144 * 4294967296 + r * 2149582850 < (d + 1) * 4294967296 + q * 1024` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `261889 * 2149582850 <= r * 2149582850` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `r * 2149582850 < 523776 * 2149582850` ASSUME_TAC THENL [ + REWRITE_TAC[LT_MULT_RCANCEL] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 1024 <= 15 * 1024` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN CONJ_TAC THENL [ + MP_TAC(ARITH_RULE `131072 * 4294967296 + 15 * 1024 <= 261889 * 2149582850`) THEN + ASM_ARITH_TAC; + MP_TAC(ARITH_RULE `523776 * 2149582850 < 262144 * 4294967296`) THEN + ASM_ARITH_TAC]]);; + +(* ========================================================================= *) +(* Word-level to natural number connection for h32 *) +(* ========================================================================= *) + +(* h32 computes the correct rounding quotient: round(val x / 523776) + Eliminates iword_saturate by inlining its definition and using BOUNDER_TAC + to discharge the impossible saturation cases (consistent with mlkem-native). *) +let H32_CORRECT = prove( + `!x:int32. 
val x < 8380417 ==> + val(h32 x) = (if val x MOD 523776 * 2 <= 523776 + then val x DIV 523776 + else val x DIV 523776 + 1)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[h32; iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REPEAT(COND_CASES_TAC THENL + [FIRST_X_ASSUM(MATCH_MP_TAC o MATCH_MP (MESON[] `p ==> ~p ==> q`)) THEN + REWRITE_TAC[INT_GT; INT_NOT_LT] THEN BOUNDER_TAC[]; + ASM_REWRITE_TAC[]]) THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[INT_OF_NUM_MUL] THEN + SUBGOAL_THEN `2 * val(x:int32) * 1074791425 = val x * 2149582850` SUBST1_TAC THENL [ + ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV] THEN + SUBGOAL_THEN `(val(x:int32) * 2149582850) DIV 4294967296 < 2147483648` ASSUME_TAC THENL [ + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(4294967296 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `ival(iword(&((val(x:int32) * 2149582850) DIV 4294967296)):int32) = + &((val x * 2149582850) DIV 4294967296)` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + REWRITE_TAC[INT_OF_NUM_LT; INT_LE_NEG2; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_ADD; INT_OF_NUM_DIV] THEN + SUBGOAL_THEN `((val(x:int32) * 2149582850) DIV 4294967296 + 131072) DIV 262144 < 2147483648` ASSUME_TAC THENL [ + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(262144 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[GSYM WORD_IWORD; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `n < 2147483648 ==> n < 4294967296`] THEN + MATCH_MP_TAC BARRETT32_CORRECT THEN ASM_REWRITE_TAC[]);; + +(* Special case: rounding quotient = 16 when val x > 8118528 *) +let ROUND32_SPECIAL = prove( + `!n. 
8118528 < n /\ n < 8380417 ==> + (if n MOD 523776 * 2 <= 523776 then n DIV 523776 else n DIV 523776 + 1) = 16`, + REPEAT STRIP_TAC THEN + ASM_CASES_TAC `n < 8380416` THENL [ + SUBGOAL_THEN `n DIV 523776 = 15` ASSUME_TAC THENL [ + MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THENL [ + MP_TAC(SPECL [`n:num`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[]]; + SUBGOAL_THEN `n = 8380416` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV]);; + +(* ========================================================================= *) +(* Main correctness lemmas: connect word-level to spec-level *) +(* ========================================================================= *) + +(* decompose32_a1 computes FST(decompose32(val x)) *) +let DECOMPOSE32_A1_CORRECT = prove( + `!x:int32. 
val x < 8380417 + ==> val(decompose32_a1 x) = FST(decompose32(val x))`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE32_A1_CASES; DECOMPOSE32_EXPAND; LET_DEF; LET_END_DEF; FST] THEN + COND_CASES_TAC THENL [ + (* ival x > &8118528: a1 = word 0, h = 16, FST = 0 *) + REWRITE_TAC[VAL_WORD_0; FST] THEN + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `&(val(x:int32)):int > &8118528` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_GT; GT] THEN DISCH_TAC THEN + MP_TAC(SPEC `val(x:int32)` ROUND32_SPECIAL) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[FST]; + (* ~(ival x > &8118528): a1 = h32 x, h < 16 *) + MP_TAC(SPEC `x:int32` H32_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + COND_CASES_TAC THENL [ + (* Round-down case *) + SUBGOAL_THEN `val(x:int32) <= 8118528` ASSUME_TAC THENL [ + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `~(&(val(x:int32)):int > &8118528)` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_GT; INT_NOT_LT; INT_OF_NUM_LE]; ALL_TAC] THEN + SUBGOAL_THEN `~(val(x:int32) DIV 523776 = 16)` ASSUME_TAC THENL [ + DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[FST]; + (* Round-up case *) + SUBGOAL_THEN `val(x:int32) <= 8118528` ASSUME_TAC THENL [ + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `~(&(val(x:int32)):int > &8118528)` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_GT; 
INT_NOT_LT; INT_OF_NUM_LE]; ALL_TAC] THEN + SUBGOAL_THEN `~(val(x:int32) DIV 523776 + 1 = 16)` ASSUME_TAC THENL [ + REWRITE_TAC[ARITH_RULE `n + 1 = 16 <=> n = 15`] THEN DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[FST]]]);; + +(* cmod n 523776 is bounded by [-261888, 261888], well within int32 range *) +let CMOD_ABS_BOUND_523776 = prove( + `!n. abs(cmod n 523776) <= &261888`, + GEN_TAC THEN REWRITE_TAC[cmod] THEN + SUBGOAL_THEN `n MOD 523776 < 523776` MP_TAC THENL [ + SIMP_TAC[MOD_LT_EQ; ARITH_RULE `~(523776 = 0)`]; ALL_TAC] THEN + SPEC_TAC(`n MOD 523776`, `m:num`) THEN GEN_TAC THEN DISCH_TAC THEN + COND_CASES_TAC THEN + REWRITE_TAC[INT_ABS; INT_POS; INT_OF_NUM_LE; + INT_OF_NUM_SUB; INT_SUB_LE; INT_NEG_SUB] THEN + ASM_ARITH_TAC);; + +(* decompose32_a0 computes SND(decompose32(val x)) *) +let DECOMPOSE32_A0_CORRECT = prove( + `!x:int32. val x < 8380417 + ==> ival(decompose32_a0 x) = SND(decompose32(val x))`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE32_A0_CASES; DECOMPOSE32_EXPAND; LET_DEF; LET_END_DEF; SND] THEN + (* Express word_sub x (word_mul (h32 x) (word 523776)) as iword(...) 
*) + SUBGOAL_THEN `word_sub x (word_mul (h32 x) (word 523776)) : int32 = + iword(ival x - ival(h32 x) * &523776)` SUBST1_TAC THENL [ + CONV_TAC WORD_RULE; ALL_TAC] THEN + (* Convert ival x and ival(h32 x) to val-based expressions *) + SUBGOAL_THEN `ival(x:int32) = &(val x)` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_EQ_VAL THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `ival(h32 x:int32) = &(val(h32 x))` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_EQ_VAL THEN + MP_TAC(SPEC `x:int32` H32_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(523776 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + (* Substitute h32 value using H32_CORRECT *) + MP_TAC(SPEC `x:int32` H32_CORRECT) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[INT_OF_NUM_GT] THEN + ABBREV_TAC `h = (if val(x:int32) MOD 523776 * 2 <= 523776 + then val x DIV 523776 else val x DIV 523776 + 1)` THEN + (* Establish DIVISION identity in int form *) + SUBGOAL_THEN `&(val(x:int32)):int = + &(val x DIV 523776) * &523776 + &(val x MOD 523776)` ASSUME_TAC THENL [ + MP_TAC(SPECL [`val(x:int32)`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + DISCH_THEN(MP_TAC o AP_TERM `int_of_num` o CONJUNCT1) THEN + REWRITE_TAC[INT_OF_NUM_MUL; INT_OF_NUM_ADD]; ALL_TAC] THEN + (* Prove key identity: val x - h * 523776 = cmod(val x) 523776 *) + SUBGOAL_THEN `&(val(x:int32)) - &h * &523776 = cmod (val x) 523776` + ASSUME_TAC THENL [ + REWRITE_TAC[cmod] THEN + FIRST_X_ASSUM(MP_TAC o SYM o check (fun th -> + fst(dest_cond(fst(dest_eq(concl th)))) = + `val (x:int32) MOD 523776 * 2 <= 523776`)) THEN + COND_CASES_TAC THENL [ + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN INT_ARITH_TAC; + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[GSYM INT_OF_NUM_ADD; + GSYM INT_OF_NUM_MUL] THEN INT_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + (* Case split on val x > 8118528 *) + COND_CASES_TAC THENL [ + (* Special case: val x > 8118528, h 
= 16 *) + SUBGOAL_THEN `h = 16` SUBST1_TAC THENL [ + MP_TAC(SPEC `val(x:int32)` ROUND32_SPECIAL) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; + ALL_TAC] THEN + REWRITE_TAC[SND] THEN + SUBGOAL_THEN `word_sub (iword(cmod (val(x:int32)) 523776)) (word 1) : int32 = + iword(cmod (val x) 523776 - &1)` SUBST1_TAC THENL [ + REWRITE_TAC[GSYM IWORD_INT_SUB; WORD_IWORD]; ALL_TAC] THEN + MATCH_MP_TAC(INST_TYPE [`:32`,`:N`] IVAL_IWORD) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `val(x:int32)` CMOD_ABS_BOUND_523776) THEN INT_ARITH_TAC; + (* Normal case: ~(val x > 8118528), h != 16 *) + SUBGOAL_THEN `~(h = 16)` ASSUME_TAC THENL [ + DISCH_TAC THEN + SUBGOAL_THEN `val(x:int32) <= 8118528` ASSUME_TAC THENL [ + ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(if val(x:int32) MOD 523776 * 2 <= 523776 + then val x DIV 523776 else val x DIV 523776 + 1) = 16` MP_TAC THENL [ + ASM_MESON_TAC[]; ALL_TAC] THEN + COND_CASES_TAC THENL [ + DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[ARITH_RULE `n + 1 = 16 <=> n = 15`] THEN DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[SND] THEN + MATCH_MP_TAC(INST_TYPE [`:32`,`:N`] IVAL_IWORD) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `val(x:int32)` CMOD_ABS_BOUND_523776) THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Specification *) +(* ========================================================================= *) + +let MLDSA_POLY_DECOMPOSE_32_CORRECT = prove( + `!pc a r1 x. 
+ nonoverlapping (word pc, LENGTH mldsa_poly_decompose_32_mc) + (r1, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_32_mc) + (a, 1024) /\ + nonoverlapping (r1, 1024) (a, 1024) + ==> (!i. i < 256 ==> val(x i:int32) < 8380417) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_decompose_32_mc /\ + read PC s = word(pc + MLDSA_POLY_DECOMPOSE_32_CORE_START) /\ + C_ARGUMENTS [r1; a] s /\ + (!i. i < 256 + ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = + x i)) + (\s. read PC s = word(pc + MLDSA_POLY_DECOMPOSE_32_CORE_END) /\ + (!i. i < 256 + ==> val(read(memory :> bytes32 + (word_add r1 (word(4 * i)))) s) = + FST(decompose32(val(x i)))) /\ + (!i. i < 256 + ==> ival(read(memory :> bytes32 + (word_add a (word(4 * i)))) s) = + SND(decompose32(val(x i))))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r1, 1024)] ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + MAP_EVERY X_GEN_TAC [`pc:num`; `a:int64`; `r1:int64`; `x:num->int32`] THEN + REWRITE_TAC[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + fst MLDSA_POLY_DECOMPOSE_32_EXEC; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] THEN + STRIP_TAC THEN STRIP_TAC THEN + + (* Expand the quantified input condition to individual coefficients *) + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV + (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)))) THEN + + ENSURES_INIT_TAC "s0" THEN + + (* Merge 4x32-bit coefficient reads into 128-bit vector reads *) + MP_TAC(end_itlist CONJ (map (fun n -> READ_MEMORY_MERGE_CONV 2 + (subst[mk_small_numeral(16*n),`n:num`] + `read (memory :> bytes128(word_add a (word n))) s0`)) + (0--63))) THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + STRIP_TAC THEN + + RULE_ASSUM_TAC(REWRITE_RULE[ADD_CLAUSES]) THEN + + (* Symbolic execution with folding to decompose32_a1/a0 *) + MAP_UNTIL_TARGET_PC (fun n -> + ARM_STEPS_TAC 
MLDSA_POLY_DECOMPOSE_32_EXEC [n] THEN + RULE_ASSUM_TAC(CONV_RULE( + TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV THENC + ONCE_REWRITE_CONV [GSYM h32] THENC + REWRITE_CONV [WORD_NOT_JOIN_128; WORD_NOT_JOIN_64; + WORD_AND_JOIN_128; WORD_AND_JOIN_64] THENC + ONCE_REWRITE_CONV [GSYM decompose32_a1] THENC + ONCE_REWRITE_CONV [GSYM decompose32_a0]))) 1 THEN + + (* Establish postcondition *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (* Split bytes128 results back into bytes32 *) + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + RULE_ASSUM_TAC(CONV_RULE(RAND_CONV( + TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV))) THEN + + CONV_TAC(ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC + ONCE_DEPTH_CONV NUM_MULT_CONV) THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN + + (* Apply mathematical correctness lemmas *) + REPEAT CONJ_TAC THEN + (MATCH_MP_TAC DECOMPOSE32_A1_CORRECT ORELSE + MATCH_MP_TAC DECOMPOSE32_A0_CORRECT) THEN + FIRST_ASSUM MATCH_MP_TAC THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* ========================================================================= *) +(* Subroutine form: wraps CORRECT with RET handling *) +(* ========================================================================= *) + +let MLDSA_POLY_DECOMPOSE_32_SUBROUTINE_CORRECT = prove( + `!pc a r1 x returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_32_mc) + (r1, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_32_mc) + (a, 1024) /\ + nonoverlapping (r1, 1024) (a, 1024) + ==> (!i. i < 256 ==> val(x i:int32) < 8380417) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_decompose_32_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [r1; a] s /\ + (!i. i < 256 + ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = + x i)) + (\s. read PC s = returnaddress /\ + (!i. 
i < 256 + ==> val(read(memory :> bytes32 + (word_add r1 (word(4 * i)))) s) = + FST(decompose32(val(x i)))) /\ + (!i. i < 256 + ==> ival(read(memory :> bytes32 + (word_add a (word(4 * i)))) s) = + SND(decompose32(val(x i)))) /\ + (!i. i < 256 + ==> FST(decompose32(val(x i))) <= 15) /\ + (!i. i < 256 + ==> --(&261888) <= SND(decompose32(val(x i))) /\ + SND(decompose32(val(x i))) <= &261888)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r1, 1024)] ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + REWRITE_TAC[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + fst MLDSA_POLY_DECOMPOSE_32_EXEC; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC(!simulation_precanon_thms) THEN + ENSURES_INIT_TAC "s0" THEN + MP_TAC(REWRITE_RULE[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] + (SPECL [`pc:num`; `a:int64`; `r1:int64`; `x:num->int32`] + (CONV_RULE LENGTH_SIMPLIFY_CONV MLDSA_POLY_DECOMPOSE_32_CORRECT))) THEN + ANTS_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ANTS_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ARM_BIGSTEP_TAC MLDSA_POLY_DECOMPOSE_32_EXEC "s1" THEN + ARM_STEPS_TAC MLDSA_POLY_DECOMPOSE_32_EXEC [2] THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONJ_TAC THENL [ + REPEAT STRIP_TAC THEN MATCH_MP_TAC DECOMPOSE32_A1_BOUND THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]; + GEN_TAC THEN DISCH_TAC THEN MATCH_MP_TAC DECOMPOSE32_A0_BOUND THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml b/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml new file mode 100644 index 000000000..d278dea70 --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_poly_decompose_88.ml @@ -0,0 +1,611 @@ +(* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + *) + +(* 
========================================================================= *) +(* Functional correctness of poly_decompose_88: *) +(* Decompose polynomial coefficients into (a1, a0) where a = a1*2*GAMMA2+a0 *) +(* for GAMMA2 = (Q-1)/88 = 95232 (ML-DSA-44) *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; +needs "common/mldsa_specs.ml";; + +(**** print_literal_from_elf "aarch64/mldsa/mldsa_poly_decompose_88.o";; + ****) + +let mldsa_poly_decompose_88_mc = define_assert_from_elf "mldsa_poly_decompose_88_mc" "aarch64/mldsa/mldsa_poly_decompose_88.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x528d8005; (* arm_MOV W5 (rvalue (word 27648)) *) + 0x72a00fc5; (* arm_MOVK W5 (word 126) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529d0007; (* arm_MOV W7 (rvalue (word 59392)) *) + 0x72a00047; (* arm_MOVK W7 (word 2) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280b02b; (* arm_MOV W11 (rvalue (word 1409)) *) + 0x72ab02cb; (* arm_MOVK W11 (word 22550) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00020; (* arm_LDR Q0 X1 (Immediate_Offset (word 0)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x4eb7b425; (* arm_SQDMULH_VEC Q5 Q1 Q23 32 128 *) + 0x4f2f24a5; (* arm_SRSHR_VEC Q5 Q5 17 32 128 *) + 0x4eb53438; (* arm_CMGT_VEC Q24 Q1 Q21 32 128 *) + 0x6eb694a1; (* arm_MLS_VEC Q1 Q5 Q22 32 128 *) + 0x4e781ca5; (* arm_BIC_VEC Q5 Q5 Q24 128 *) + 0x4eb88421; (* arm_ADD_VEC Q1 Q1 Q24 32 128 *) + 0x4eb7b446; (* arm_SQDMULH_VEC Q6 Q2 Q23 32 128 *) + 0x4f2f24c6; (* arm_SRSHR_VEC Q6 Q6 17 32 128 *) + 0x4eb53458; (* 
arm_CMGT_VEC Q24 Q2 Q21 32 128 *) + 0x6eb694c2; (* arm_MLS_VEC Q2 Q6 Q22 32 128 *) + 0x4e781cc6; (* arm_BIC_VEC Q6 Q6 Q24 128 *) + 0x4eb88442; (* arm_ADD_VEC Q2 Q2 Q24 32 128 *) + 0x4eb7b467; (* arm_SQDMULH_VEC Q7 Q3 Q23 32 128 *) + 0x4f2f24e7; (* arm_SRSHR_VEC Q7 Q7 17 32 128 *) + 0x4eb53478; (* arm_CMGT_VEC Q24 Q3 Q21 32 128 *) + 0x6eb694e3; (* arm_MLS_VEC Q3 Q7 Q22 32 128 *) + 0x4e781ce7; (* arm_BIC_VEC Q7 Q7 Q24 128 *) + 0x4eb88463; (* arm_ADD_VEC Q3 Q3 Q24 32 128 *) + 0x4eb7b404; (* arm_SQDMULH_VEC Q4 Q0 Q23 32 128 *) + 0x4f2f2484; (* arm_SRSHR_VEC Q4 Q4 17 32 128 *) + 0x4eb53418; (* arm_CMGT_VEC Q24 Q0 Q21 32 128 *) + 0x6eb69480; (* arm_MLS_VEC Q0 Q4 Q22 32 128 *) + 0x4e781c84; (* arm_BIC_VEC Q4 Q4 Q24 128 *) + 0x4eb88400; (* arm_ADD_VEC Q0 Q0 Q24 32 128 *) + 0x3d800405; (* arm_STR Q5 X0 (Immediate_Offset (word 16)) *) + 0x3d800806; (* arm_STR Q6 X0 (Immediate_Offset (word 32)) *) + 0x3d800c07; (* arm_STR Q7 X0 (Immediate_Offset (word 48)) *) + 0x3c840404; (* arm_STR Q4 X0 (Postimmediate_Offset (word 64)) *) + 0x3d800421; (* arm_STR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3d800822; (* arm_STR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3d800c23; (* arm_STR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3c840420; (* arm_STR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fffb61; (* arm_BNE (word 2097004) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_POLY_DECOMPOSE_88_EXEC = ARM_MK_EXEC_RULE mldsa_poly_decompose_88_mc;; + +(* ========================================================================= *) +(* Constants *) +(* ========================================================================= *) + +let LENGTH_MLDSA_POLY_DECOMPOSE_88_MC = + REWRITE_CONV[mldsa_poly_decompose_88_mc] `LENGTH mldsa_poly_decompose_88_mc` + |> CONV_RULE (RAND_CONV LENGTH_CONV);; + +let MLDSA_POLY_DECOMPOSE_88_CORE_START = new_definition + `MLDSA_POLY_DECOMPOSE_88_CORE_START = 0`;; + +let 
MLDSA_POLY_DECOMPOSE_88_POSTAMBLE_LENGTH = new_definition + `MLDSA_POLY_DECOMPOSE_88_POSTAMBLE_LENGTH = 4`;; + +let MLDSA_POLY_DECOMPOSE_88_CORE_END = new_definition + `MLDSA_POLY_DECOMPOSE_88_CORE_END = + LENGTH mldsa_poly_decompose_88_mc - MLDSA_POLY_DECOMPOSE_88_POSTAMBLE_LENGTH`;; + +let LENGTH_SIMPLIFY_CONV = + REWRITE_CONV[LENGTH_MLDSA_POLY_DECOMPOSE_88_MC; + MLDSA_POLY_DECOMPOSE_88_CORE_START; MLDSA_POLY_DECOMPOSE_88_CORE_END; + MLDSA_POLY_DECOMPOSE_88_POSTAMBLE_LENGTH] THENC + NUM_REDUCE_CONV THENC REWRITE_CONV [ADD_0];; + +(* ========================================================================= *) +(* Word-level helper functions *) +(* Per-lane operations matching the assembly's SQDMULH+SRSHR, BIC, MLS+ADD *) +(* ========================================================================= *) + +(* h88: quotient = srshr(sqdmulh(x, magic), 17) ≈ round(x / 190464) *) +let h88 = define + `h88 (x:int32) : int32 = + iword((ival((iword_saturate:int->int32) + ((&2 * ival x * &1477838209) div &4294967296)) + + &65536) div &131072)`;; + +(* decompose88_a1: a1 = h AND (NOT mask) where mask = -1 if x > threshold *) +let decompose88_a1 = define + `decompose88_a1 (x:int32) : int32 = + word_and (h88 x) + (word_not(word_neg(word(bitval(ival x > &8285184)))))`;; + +(* decompose88_a0: a0 = (x - h*2*GAMMA2) + mask *) +let decompose88_a0 = define + `decompose88_a0 (x:int32) : int32 = + word_add (word_sub x (word_mul (h88 x) (word 190464))) + (word_neg(word(bitval(ival x > &8285184))))`;; + +(* ========================================================================= *) +(* Distribution lemmas for word_and/word_not over word_join *) +(* Needed because BIC (arm_BIC_VEC) operates at 128-bit level *) +(* ========================================================================= *) + +let WORD_AND_JOIN_64 = WORD_BLAST + `!a b c d : int32. 
+ word_and ((word_join:int32->int32->int64) a b) + ((word_join:int32->int32->int64) c d) = + word_join (word_and a c) (word_and b d)`;; + +let WORD_AND_JOIN_128 = WORD_BLAST + `!a b c d : int64. + word_and ((word_join:int64->int64->int128) a b) + ((word_join:int64->int64->int128) c d) = + word_join (word_and a c) (word_and b d)`;; + +let WORD_NOT_JOIN_64 = WORD_BLAST + `!a b : int32. + word_not ((word_join:int32->int32->int64) a b) = + word_join (word_not a) (word_not b)`;; + +let WORD_NOT_JOIN_128 = WORD_BLAST + `!a b : int64. + word_not ((word_join:int64->int64->int128) a b) = + word_join (word_not a) (word_not b)`;; + +(* ========================================================================= *) +(* Mathematical correctness lemmas *) +(* Connect word-level decompose88_a1/a0 to spec-level decompose88 *) +(* ========================================================================= *) + +(* Case split: a1 is either h88 or 0 depending on the threshold *) +let DECOMPOSE88_A1_CASES = prove( + `!x:int32. decompose88_a1 x = + if ival x > &8285184 then word 0 else h88 x`, + REWRITE_TAC[decompose88_a1] THEN BITBLAST_TAC);; + +(* Case split: a0 subtracts 1 in the special case *) +let DECOMPOSE88_A0_CASES = prove( + `!x:int32. decompose88_a0 x = + if ival x > &8285184 + then word_sub (word_sub x (word_mul (h88 x) (word 190464))) (word 1) + else word_sub x (word_mul (h88 x) (word 190464))`, + GEN_TAC THEN REWRITE_TAC[decompose88_a0] THEN + COND_CASES_TAC THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_RULE);; + +(* ival equals val for values in the positive int32 range *) +let IVAL_EQ_VAL = prove( + `!x:int32. 
val x < 2 EXP 31 ==> ival x = &(val x)`, + GEN_TAC THEN REWRITE_TAC[IVAL_VAL; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + DISCH_TAC THEN + SUBGOAL_THEN `bit (32 - 1) (x:int32) = F` ASSUME_TAC THENL [ + REWRITE_TAC[BIT_VAL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_ARITH_TAC; + ASM_REWRITE_TAC[bitval] THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Barrett reduction correctness for h88 *) +(* Shows that SQDMULH+SRSHR computes round(x / 190464) correctly *) +(* ========================================================================= *) + +(* Algebraic expansion: n*K + q*E = q*D*P + r*K where M = 190464, + q = n DIV M, r = n MOD M, K = 2955676418, D = 2^17, P = 2^32, E = D*P - M*K *) +let BARRETT88_EXPAND = prove( + `!n. n * 2955676418 + (n DIV 190464) * 143360 = + (n DIV 190464) * 131072 * 4294967296 + (n MOD 190464) * 2955676418`, + GEN_TAC THEN + MP_TAC(SPECL [`n:num`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + DISCH_THEN(CONJUNCTS_THEN2 (fun th -> GEN_REWRITE_TAC (LAND_CONV o LAND_CONV o LAND_CONV) [th]) ASSUME_TAC) THEN + CONV_TAC NUM_REDUCE_CONV THEN + CONV_TAC NUM_RING);; + +(* DIV_BOUNDS_EQ: if q*d <= b < (q+1)*d then b DIV d = q *) +let DIV_BOUNDS_EQ = prove( + `!b d q. ~(d = 0) /\ q * d <= b /\ b < (q + 1) * d ==> b DIV d = q`, + REPEAT STRIP_TAC THEN MATCH_MP_TAC(ARITH_RULE `q <= r /\ r < q + 1 ==> r = q`) THEN + CONJ_TAC THENL [ + ASM_SIMP_TAC[LE_RDIV_EQ] THEN ASM_ARITH_TAC; + ASM_SIMP_TAC[RDIV_LT_EQ] THEN ASM_ARITH_TAC]);; + +(* Barrett: ((n*K) DIV P + D/2) DIV D = round(n / 190464), ties down, n < 8380417 *) +let BARRETT88_CORRECT = prove( + `!n. 
n < 8380417 ==> + ((n * 2955676418) DIV 4294967296 + 65536) DIV 131072 = + (if n MOD 190464 * 2 <= 190464 + then n DIV 190464 + else n DIV 190464 + 1)`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `n = 8380416` THENL [ + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV; ALL_TAC] THEN + ABBREV_TAC `q = n DIV 190464` THEN + ABBREV_TAC `r = n MOD 190464` THEN + SUBGOAL_THEN `n = q * 190464 + r` ASSUME_TAC THENL [ + EXPAND_TAC "q" THEN EXPAND_TAC "r" THEN + MESON_TAC[DIVISION; ARITH_RULE `~(190464 = 0)`]; ALL_TAC] THEN + SUBGOAL_THEN `r < 190464` ASSUME_TAC THENL [ + EXPAND_TAC "r" THEN SIMP_TAC[MOD_LT_EQ] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q <= 43` ASSUME_TAC THENL [ + EXPAND_TAC "q" THEN ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(190464 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `n:num` BARRETT88_EXPAND) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + COND_CASES_TAC THENL [ + (* Round-down case: r * 2 <= 190464, so r <= 95232 *) + ABBREV_TAC `d = ((q * 190464 + r) * 2955676418) DIV 4294967296` THEN + MP_TAC(SPECL [`(q * 190464 + r) * 2955676418`; `4294967296`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + SUBGOAL_THEN `d * 4294967296 + q * 143360 <= q * 131072 * 4294967296 + r * 2955676418` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 131072 * 4294967296 + r * 2955676418 < (d + 1) * 4294967296 + q * 143360` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `r * 2955676418 <= 95232 * 2955676418` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN CONJ_TAC THENL [ + MP_TAC(ARITH_RULE `95232 * 2955676418 < 65536 * 4294967296`) THEN ASM_ARITH_TAC; + MP_TAC(ARITH_RULE `95232 * 2955676418 < 65536 * 4294967296`) THEN ASM_ARITH_TAC]; + (* Round-up case: ~(r * 2 <= 190464), so r >= 95233 *) + ABBREV_TAC `d = ((q * 190464 + r) * 2955676418) DIV 4294967296` 
THEN + MP_TAC(SPECL [`(q * 190464 + r) * 2955676418`; `4294967296`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN + SUBGOAL_THEN `d * 4294967296 + q * 143360 <= q * 131072 * 4294967296 + r * 2955676418` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 131072 * 4294967296 + r * 2955676418 < (d + 1) * 4294967296 + q * 143360` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `95233 * 2955676418 <= r * 2955676418` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `r * 2955676418 < 190464 * 2955676418` ASSUME_TAC THENL [ + REWRITE_TAC[LT_MULT_RCANCEL] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `q * 143360 <= 43 * 143360` ASSUME_TAC THENL [ + MATCH_MP_TAC LE_MULT2 THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN CONJ_TAC THENL [ + MP_TAC(ARITH_RULE `65536 * 4294967296 + 43 * 143360 <= 95233 * 2955676418`) THEN + ASM_ARITH_TAC; + MP_TAC(ARITH_RULE `190464 * 2955676418 < 131072 * 4294967296`) THEN + ASM_ARITH_TAC]]);; + +(* ========================================================================= *) +(* Word-level to natural number connection for h88 *) +(* ========================================================================= *) + +(* h88 computes the correct rounding quotient: round(val x / 190464) + Eliminates iword_saturate by inlining its definition and using BOUNDER_TAC + to discharge the impossible saturation cases (consistent with mlkem-native). *) +let H88_CORRECT = prove( + `!x:int32. 
val x < 8380417 ==> + val(h88 x) = (if val x MOD 190464 * 2 <= 190464 + then val x DIV 190464 + else val x DIV 190464 + 1)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[h88; iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REPEAT(COND_CASES_TAC THENL + [FIRST_X_ASSUM(MATCH_MP_TAC o MATCH_MP (MESON[] `p ==> ~p ==> q`)) THEN + REWRITE_TAC[INT_GT; INT_NOT_LT] THEN BOUNDER_TAC[]; + ASM_REWRITE_TAC[]]) THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[INT_OF_NUM_MUL] THEN + SUBGOAL_THEN `2 * val(x:int32) * 1477838209 = val x * 2955676418` SUBST1_TAC THENL [ + ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV] THEN + SUBGOAL_THEN `(val(x:int32) * 2955676418) DIV 4294967296 < 2147483648` ASSUME_TAC THENL [ + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(4294967296 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `ival(iword(&((val(x:int32) * 2955676418) DIV 4294967296)):int32) = + &((val x * 2955676418) DIV 4294967296)` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + REWRITE_TAC[INT_OF_NUM_LT; INT_LE_NEG2; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_ADD; INT_OF_NUM_DIV] THEN + SUBGOAL_THEN `((val(x:int32) * 2955676418) DIV 4294967296 + 65536) DIV 131072 < 2147483648` ASSUME_TAC THENL [ + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(131072 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[GSYM WORD_IWORD; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_EXP_CONV) THEN + ASM_SIMP_TAC[MOD_LT; ARITH_RULE `n < 2147483648 ==> n < 4294967296`] THEN + MATCH_MP_TAC BARRETT88_CORRECT THEN ASM_REWRITE_TAC[]);; + +(* Special case: rounding quotient = 44 when val x > 8285184 *) +let ROUND88_SPECIAL = prove( + `!n. 
8285184 < n /\ n < 8380417 ==> + (if n MOD 190464 * 2 <= 190464 then n DIV 190464 else n DIV 190464 + 1) = 44`, + REPEAT STRIP_TAC THEN + ASM_CASES_TAC `n < 8380416` THENL [ + SUBGOAL_THEN `n DIV 190464 = 43` ASSUME_TAC THENL [ + MATCH_MP_TAC DIV_BOUNDS_EQ THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THENL [ + MP_TAC(SPECL [`n:num`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[]]; + SUBGOAL_THEN `n = 8380416` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV]);; + +(* ========================================================================= *) +(* Main correctness lemmas: connect word-level to spec-level *) +(* ========================================================================= *) + +(* decompose88_a1 computes FST(decompose88(val x)) *) +let DECOMPOSE88_A1_CORRECT = prove( + `!x:int32. 
val x < 8380417 + ==> val(decompose88_a1 x) = FST(decompose88(val x))`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE88_A1_CASES; DECOMPOSE88_EXPAND; LET_DEF; LET_END_DEF; FST] THEN + COND_CASES_TAC THENL [ + (* ival x > &8285184: a1 = word 0, h = 44, FST = 0 *) + REWRITE_TAC[VAL_WORD_0; FST] THEN + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `&(val(x:int32)):int > &8285184` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_GT; GT] THEN DISCH_TAC THEN + MP_TAC(SPEC `val(x:int32)` ROUND88_SPECIAL) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[FST]; + (* ~(ival x > &8285184): a1 = h88 x, h < 44 *) + MP_TAC(SPEC `x:int32` H88_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + COND_CASES_TAC THENL [ + (* Round-down case *) + SUBGOAL_THEN `val(x:int32) <= 8285184` ASSUME_TAC THENL [ + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `~(&(val(x:int32)):int > &8285184)` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_GT; INT_NOT_LT; INT_OF_NUM_LE]; ALL_TAC] THEN + SUBGOAL_THEN `~(val(x:int32) DIV 190464 = 44)` ASSUME_TAC THENL [ + DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[FST]; + (* Round-up case *) + SUBGOAL_THEN `val(x:int32) <= 8285184` ASSUME_TAC THENL [ + SUBGOAL_THEN `val(x:int32) < 2 EXP 31` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `x:int32` IVAL_EQ_VAL) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `~(&(val(x:int32)):int > &8285184)` MP_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + REWRITE_TAC[INT_GT; 
INT_NOT_LT; INT_OF_NUM_LE]; ALL_TAC] THEN + SUBGOAL_THEN `~(val(x:int32) DIV 190464 + 1 = 44)` ASSUME_TAC THENL [ + REWRITE_TAC[ARITH_RULE `n + 1 = 44 <=> n = 43`] THEN DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[FST]]]);; + +(* cmod n 190464 is bounded by [-95232, 95232], well within int32 range *) +let CMOD_ABS_BOUND_190464 = prove( + `!n. abs(cmod n 190464) <= &95232`, + GEN_TAC THEN REWRITE_TAC[cmod] THEN + SUBGOAL_THEN `n MOD 190464 < 190464` MP_TAC THENL [ + SIMP_TAC[MOD_LT_EQ; ARITH_RULE `~(190464 = 0)`]; ALL_TAC] THEN + SPEC_TAC(`n MOD 190464`, `m:num`) THEN GEN_TAC THEN DISCH_TAC THEN + COND_CASES_TAC THEN + REWRITE_TAC[INT_ABS; INT_POS; INT_OF_NUM_LE; + INT_OF_NUM_SUB; INT_SUB_LE; INT_NEG_SUB] THEN + ASM_ARITH_TAC);; + +(* decompose88_a0 computes SND(decompose88(val x)) *) +let DECOMPOSE88_A0_CORRECT = prove( + `!x:int32. val x < 8380417 + ==> ival(decompose88_a0 x) = SND(decompose88(val x))`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE88_A0_CASES; DECOMPOSE88_EXPAND; LET_DEF; LET_END_DEF; SND] THEN + (* Express word_sub x (word_mul (h88 x) (word 190464)) as iword(...) 
*) + SUBGOAL_THEN `word_sub x (word_mul (h88 x) (word 190464)) : int32 = + iword(ival x - ival(h88 x) * &190464)` SUBST1_TAC THENL [ + CONV_TAC WORD_RULE; ALL_TAC] THEN + (* Convert ival x and ival(h88 x) to val-based expressions *) + SUBGOAL_THEN `ival(x:int32) = &(val x)` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_EQ_VAL THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `ival(h88 x:int32) = &(val(h88 x))` SUBST1_TAC THENL [ + MATCH_MP_TAC IVAL_EQ_VAL THEN + MP_TAC(SPEC `x:int32` H88_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(190464 = 0)`] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + (* Substitute h88 value using H88_CORRECT *) + MP_TAC(SPEC `x:int32` H88_CORRECT) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[INT_OF_NUM_GT] THEN + ABBREV_TAC `h = (if val(x:int32) MOD 190464 * 2 <= 190464 + then val x DIV 190464 else val x DIV 190464 + 1)` THEN + (* Establish DIVISION identity in int form *) + SUBGOAL_THEN `&(val(x:int32)):int = + &(val x DIV 190464) * &190464 + &(val x MOD 190464)` ASSUME_TAC THENL [ + MP_TAC(SPECL [`val(x:int32)`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + DISCH_THEN(MP_TAC o AP_TERM `int_of_num` o CONJUNCT1) THEN + REWRITE_TAC[INT_OF_NUM_MUL; INT_OF_NUM_ADD]; ALL_TAC] THEN + (* Prove key identity: val x - h * 190464 = cmod(val x) 190464 *) + SUBGOAL_THEN `&(val(x:int32)) - &h * &190464 = cmod (val x) 190464` + ASSUME_TAC THENL [ + REWRITE_TAC[cmod] THEN + FIRST_X_ASSUM(MP_TAC o SYM o check (fun th -> + fst(dest_cond(fst(dest_eq(concl th)))) = + `val (x:int32) MOD 190464 * 2 <= 190464`)) THEN + COND_CASES_TAC THENL [ + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN INT_ARITH_TAC; + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[GSYM INT_OF_NUM_ADD; + GSYM INT_OF_NUM_MUL] THEN INT_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + (* Case split on val x > 8285184 *) + COND_CASES_TAC THENL [ + (* Special case: val x > 8285184, h 
= 44 *) + SUBGOAL_THEN `h = 44` SUBST1_TAC THENL [ + MP_TAC(SPEC `val(x:int32)` ROUND88_SPECIAL) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; + ALL_TAC] THEN + REWRITE_TAC[SND] THEN + SUBGOAL_THEN `word_sub (iword(cmod (val(x:int32)) 190464)) (word 1) : int32 = + iword(cmod (val x) 190464 - &1)` SUBST1_TAC THENL [ + REWRITE_TAC[GSYM IWORD_INT_SUB; WORD_IWORD]; ALL_TAC] THEN + MATCH_MP_TAC(INST_TYPE [`:32`,`:N`] IVAL_IWORD) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `val(x:int32)` CMOD_ABS_BOUND_190464) THEN INT_ARITH_TAC; + (* Normal case: ~(val x > 8285184), h != 44 *) + SUBGOAL_THEN `~(h = 44)` ASSUME_TAC THENL [ + DISCH_TAC THEN + SUBGOAL_THEN `val(x:int32) <= 8285184` ASSUME_TAC THENL [ + ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(if val(x:int32) MOD 190464 * 2 <= 190464 + then val x DIV 190464 else val x DIV 190464 + 1) = 44` MP_TAC THENL [ + ASM_MESON_TAC[]; ALL_TAC] THEN + COND_CASES_TAC THENL [ + DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[ARITH_RULE `n + 1 = 44 <=> n = 43`] THEN DISCH_TAC THEN + MP_TAC(SPECL [`val(x:int32)`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + STRIP_TAC THEN ASM_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[SND] THEN + MATCH_MP_TAC(INST_TYPE [`:32`,`:N`] IVAL_IWORD) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `val(x:int32)` CMOD_ABS_BOUND_190464) THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Specification *) +(* ========================================================================= *) + +let MLDSA_POLY_DECOMPOSE_88_CORRECT = prove( + `!pc a r1 x. 
+ nonoverlapping (word pc, LENGTH mldsa_poly_decompose_88_mc) + (r1, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_88_mc) + (a, 1024) /\ + nonoverlapping (r1, 1024) (a, 1024) + ==> (!i. i < 256 ==> val(x i:int32) < 8380417) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_decompose_88_mc /\ + read PC s = word(pc + MLDSA_POLY_DECOMPOSE_88_CORE_START) /\ + C_ARGUMENTS [r1; a] s /\ + (!i. i < 256 + ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = + x i)) + (\s. read PC s = word(pc + MLDSA_POLY_DECOMPOSE_88_CORE_END) /\ + (!i. i < 256 + ==> val(read(memory :> bytes32 + (word_add r1 (word(4 * i)))) s) = + FST(decompose88(val(x i)))) /\ + (!i. i < 256 + ==> ival(read(memory :> bytes32 + (word_add a (word(4 * i)))) s) = + SND(decompose88(val(x i))))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r1, 1024)] ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + MAP_EVERY X_GEN_TAC [`pc:num`; `a:int64`; `r1:int64`; `x:num->int32`] THEN + REWRITE_TAC[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + fst MLDSA_POLY_DECOMPOSE_88_EXEC; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] THEN + STRIP_TAC THEN STRIP_TAC THEN + + (* Expand the quantified input condition to individual coefficients *) + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV + (EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV)))) THEN + + ENSURES_INIT_TAC "s0" THEN + + (* Merge 4x32-bit coefficient reads into 128-bit vector reads *) + MP_TAC(end_itlist CONJ (map (fun n -> READ_MEMORY_MERGE_CONV 2 + (subst[mk_small_numeral(16*n),`n:num`] + `read (memory :> bytes128(word_add a (word n))) s0`)) + (0--63))) THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + STRIP_TAC THEN + + RULE_ASSUM_TAC(REWRITE_RULE[ADD_CLAUSES]) THEN + + (* Symbolic execution with folding to decompose88_a1/a0 *) + MAP_UNTIL_TARGET_PC (fun n -> + ARM_STEPS_TAC 
MLDSA_POLY_DECOMPOSE_88_EXEC [n] THEN + RULE_ASSUM_TAC(CONV_RULE( + TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV THENC + ONCE_REWRITE_CONV [GSYM h88] THENC + REWRITE_CONV [WORD_NOT_JOIN_128; WORD_NOT_JOIN_64; + WORD_AND_JOIN_128; WORD_AND_JOIN_64] THENC + ONCE_REWRITE_CONV [GSYM decompose88_a1] THENC + ONCE_REWRITE_CONV [GSYM decompose88_a0]))) 1 THEN + + (* Establish postcondition *) + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (* Split bytes128 results back into bytes32 *) + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + RULE_ASSUM_TAC(CONV_RULE(RAND_CONV( + TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV))) THEN + + CONV_TAC(ONCE_DEPTH_CONV EXPAND_CASES_CONV THENC + ONCE_DEPTH_CONV NUM_MULT_CONV) THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN + + (* Apply mathematical correctness lemmas *) + REPEAT CONJ_TAC THEN + (MATCH_MP_TAC DECOMPOSE88_A1_CORRECT ORELSE + MATCH_MP_TAC DECOMPOSE88_A0_CORRECT) THEN + FIRST_ASSUM MATCH_MP_TAC THEN + CONV_TAC NUM_REDUCE_CONV);; + +(* ========================================================================= *) +(* Subroutine form: wraps CORRECT with RET handling *) +(* ========================================================================= *) + +let MLDSA_POLY_DECOMPOSE_88_SUBROUTINE_CORRECT = prove( + `!pc a r1 x returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_88_mc) + (r1, 1024) /\ + nonoverlapping (word pc, LENGTH mldsa_poly_decompose_88_mc) + (a, 1024) /\ + nonoverlapping (r1, 1024) (a, 1024) + ==> (!i. i < 256 ==> val(x i:int32) < 8380417) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_decompose_88_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [r1; a] s /\ + (!i. i < 256 + ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = + x i)) + (\s. read PC s = returnaddress /\ + (!i. 
i < 256 + ==> val(read(memory :> bytes32 + (word_add r1 (word(4 * i)))) s) = + FST(decompose88(val(x i)))) /\ + (!i. i < 256 + ==> ival(read(memory :> bytes32 + (word_add a (word(4 * i)))) s) = + SND(decompose88(val(x i)))) /\ + (!i. i < 256 + ==> FST(decompose88(val(x i))) <= 43) /\ + (!i. i < 256 + ==> --(&95232) <= SND(decompose88(val(x i))) /\ + SND(decompose88(val(x i))) <= &95232)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r1, 1024)] ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + REWRITE_TAC[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + fst MLDSA_POLY_DECOMPOSE_88_EXEC; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC(!simulation_precanon_thms) THEN + ENSURES_INIT_TAC "s0" THEN + MP_TAC(REWRITE_RULE[NONOVERLAPPING_CLAUSES; C_ARGUMENTS; SOME_FLAGS; + MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI] + (SPECL [`pc:num`; `a:int64`; `r1:int64`; `x:num->int32`] + (CONV_RULE LENGTH_SIMPLIFY_CONV MLDSA_POLY_DECOMPOSE_88_CORRECT))) THEN + ANTS_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ANTS_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + ARM_BIGSTEP_TAC MLDSA_POLY_DECOMPOSE_88_EXEC "s1" THEN + ARM_STEPS_TAC MLDSA_POLY_DECOMPOSE_88_EXEC [2] THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + CONJ_TAC THENL [ + REPEAT STRIP_TAC THEN MATCH_MP_TAC DECOMPOSE88_A1_BOUND THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]; + GEN_TAC THEN DISCH_TAC THEN MATCH_MP_TAC DECOMPOSE88_A0_BOUND THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml new file mode 100644 index 000000000..289cfe4ca --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_17.ml @@ -0,0 +1,257 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Functional correctness of polyz_unpack_17: *) +(* Unpack polynomial z with 18-bit packed coefficients (GAMMA1 = 2^17) *) +(* Maps packed [0, 2^18-1] to signed [-(2^17-1), 2^17] via GAMMA1 - x *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; +needs "aarch64/proofs/mldsa_polyz_unpack_consts.ml";; + +(**** print_literal_from_elf "aarch64/mldsa/mldsa_polyz_unpack_17.o";; + ****) + +let mldsa_polyz_unpack_17_mc = define_assert_from_elf + "mldsa_polyz_unpack_17_mc" "aarch64/mldsa/mldsa_polyz_unpack_17.o" +(*** BYTECODE START ***) +[ + 0x3dc00058; (* arm_LDR Q24 X2 (Immediate_Offset (word 0)) *) + 0x3dc00459; (* arm_LDR Q25 X2 (Immediate_Offset (word 16)) *) + 0x3dc0085a; (* arm_LDR Q26 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c5b; (* arm_LDR Q27 X2 (Immediate_Offset (word 48)) *) + 0xd2c01fc3; (* arm_MOVZ X3 (word 254) 32 *) + 0x4e081c7c; (* arm_INS_GEN Q28 X3 0 64 *) + 0xd2801f83; (* arm_MOV X3 (rvalue (word 252)) *) + 0xf2c01f43; (* arm_MOVK X3 (word 250) 32 *) + 0x4e181c7c; (* arm_INS_GEN Q28 X3 64 64 *) + 0x4f00d47d; (* arm_MOVI Q29 (word 1125895612137471) *) + 0x4f00445e; (* arm_MOVI Q30 (word 562949953552384) *) + 0xd2800209; (* arm_MOV X9 (rvalue (word 16)) *) + 0x4cdfa820; (* arm_LDP Q0 Q1 X1 (Postimmediate_Offset (word 32)) *) + 0xbc404422; (* arm_LDR S2 X1 (Postimmediate_Offset (word 4)) *) + 0x4e180004; (* arm_TBL Q4 [Q0] Q24 128 *) + 0x4e192005; (* arm_TBL2 Q5 Q0 Q1 Q25 128 *) + 0x4e1a0026; (* arm_TBL Q6 [Q1] Q26 128 *) + 0x4e1b2027; (* arm_TBL2 Q7 Q1 Q2 Q27 128 *) + 0x6ebc4484; (* arm_USHL_VEC Q4 Q4 Q28 32 128 *) + 0x4e3d1c84; (* arm_AND_VEC Q4 Q4 Q29 128 *) + 0x6ea487c4; (* arm_SUB_VEC Q4 Q30 Q4 32 128 *) + 0x6ebc44a5; (* arm_USHL_VEC Q5 Q5 Q28 32 128 *) + 0x4e3d1ca5; (* arm_AND_VEC Q5 Q5 Q29 128 *) + 
0x6ea587c5; (* arm_SUB_VEC Q5 Q30 Q5 32 128 *) + 0x6ebc44c6; (* arm_USHL_VEC Q6 Q6 Q28 32 128 *) + 0x4e3d1cc6; (* arm_AND_VEC Q6 Q6 Q29 128 *) + 0x6ea687c6; (* arm_SUB_VEC Q6 Q30 Q6 32 128 *) + 0x6ebc44e7; (* arm_USHL_VEC Q7 Q7 Q28 32 128 *) + 0x4e3d1ce7; (* arm_AND_VEC Q7 Q7 Q29 128 *) + 0x6ea787c7; (* arm_SUB_VEC Q7 Q30 Q7 32 128 *) + 0x3d800405; (* arm_STR Q5 X0 (Immediate_Offset (word 16)) *) + 0x3d800806; (* arm_STR Q6 X0 (Immediate_Offset (word 32)) *) + 0x3d800c07; (* arm_STR Q7 X0 (Immediate_Offset (word 48)) *) + 0x3c840404; (* arm_STR Q4 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000529; (* arm_SUBS X9 X9 (rvalue (word 1)) *) + 0x54fffd21; (* arm_BNE (word 2097060) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_POLYZ_UNPACK_17_EXEC = ARM_MK_EXEC_RULE mldsa_polyz_unpack_17_mc;; + +(* ------------------------------------------------------------------------- *) +(* Code length constants *) +(* ------------------------------------------------------------------------- *) + +let LENGTH_MLDSA_POLYZ_UNPACK_17_MC = + REWRITE_CONV[mldsa_polyz_unpack_17_mc] `LENGTH mldsa_polyz_unpack_17_mc` + |> CONV_RULE (RAND_CONV LENGTH_CONV);; + +let MLDSA_POLYZ_UNPACK_17_PREAMBLE_LENGTH = new_definition + `MLDSA_POLYZ_UNPACK_17_PREAMBLE_LENGTH = 0`;; + +let MLDSA_POLYZ_UNPACK_17_POSTAMBLE_LENGTH = new_definition + `MLDSA_POLYZ_UNPACK_17_POSTAMBLE_LENGTH = 4`;; + +let MLDSA_POLYZ_UNPACK_17_CORE_START = new_definition + `MLDSA_POLYZ_UNPACK_17_CORE_START = MLDSA_POLYZ_UNPACK_17_PREAMBLE_LENGTH`;; + +let MLDSA_POLYZ_UNPACK_17_CORE_END = new_definition + `MLDSA_POLYZ_UNPACK_17_CORE_END = + LENGTH mldsa_polyz_unpack_17_mc - MLDSA_POLYZ_UNPACK_17_POSTAMBLE_LENGTH`;; + +let LENGTH_SIMPLIFY_CONV = + REWRITE_CONV[LENGTH_MLDSA_POLYZ_UNPACK_17_MC; + MLDSA_POLYZ_UNPACK_17_CORE_START; MLDSA_POLYZ_UNPACK_17_CORE_END; + MLDSA_POLYZ_UNPACK_17_PREAMBLE_LENGTH; + MLDSA_POLYZ_UNPACK_17_POSTAMBLE_LENGTH] THENC + NUM_REDUCE_CONV THENC REWRITE_CONV [ADD_0];; + +(* 
------------------------------------------------------------------------- *) +(* D=18 instantiations for SIMD infrastructure *) +(* ------------------------------------------------------------------------- *) + +let BASE_SIMPS_D18 = mk_base_simps 18;; +let NUM_OF_WORDLIST_SPLIT_18_256 = mk_split_theorem 18 256 16;; +let READ_MEMORY_WBYTES_SPLIT_128_128_32 = prove + (`t < 2 EXP 288 + ==> (read (memory :> wbytes a) (s:armstate) = (word t : 288 word) <=> + read (memory :> bytes128 a) s = (word (t MOD 2 EXP 128) : int128) /\ + read (memory :> bytes128 (word_add a (word 16))) s = + (word ((t DIV 2 EXP 128) MOD 2 EXP 128) : int128) /\ + read (memory :> bytes32 (word_add a (word 32))) s = + (word (t DIV 2 EXP 256) : int32))`, + let split_16_20 = CONV_RULE (ONCE_DEPTH_CONV NUM_ADD_CONV THENC + DEPTH_CONV NUM_MULT_CONV) + (INST [`16`,`k:num`; `20`,`l:num`] READ_BYTES_SPLIT_ANY) in + let split_16_4 = CONV_RULE (ONCE_DEPTH_CONV NUM_ADD_CONV THENC + DEPTH_CONV NUM_MULT_CONV) + (INST [`16`,`k:num`; `4`,`l:num`] READ_BYTES_SPLIT_ANY) in + STRIP_TAC THEN + REWRITE_TAC[BYTES128_WBYTES; BYTES32_WBYTES; GSYM VAL_EQ; + VAL_READ_WBYTES; READ_COMPONENT_COMPOSE] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[split_16_20] THEN REWRITE_TAC[split_16_4] THEN + REWRITE_TAC[WORD_ADD_ASSOC_CONSTS] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[DIV_DIV; GSYM EXP_ADD] THEN CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[VAL_WORD_EXACT] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN ASM_ARITH_TAC);; +let WORD_SUBWORD_NUM_OF_WORDLIST_CASES_D18 = mk_subword_cases 18 16;; + +(* ------------------------------------------------------------------------- *) +(* Core correctness theorem *) +(* ------------------------------------------------------------------------- *) + +let MLDSA_POLYZ_UNPACK_17_CORRECT = prove + (`!r b t l pc. 
+ LENGTH l = 256 /\ + nonoverlapping (word pc,LENGTH mldsa_polyz_unpack_17_mc) (r,1024) /\ + nonoverlapping (b,576) (r,1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_polyz_unpack_17_mc /\ + read PC s = word (pc + MLDSA_POLYZ_UNPACK_17_CORE_START) /\ + C_ARGUMENTS [r; b; t] s /\ + read(memory :> bytes(t,64)) s = + num_of_wordlist mldsa_polyz_unpack_17_indices /\ + read(memory :> bytes(b,576)) s = num_of_wordlist l) + (\s. read PC s = word(pc + MLDSA_POLYZ_UNPACK_17_CORE_END) /\ + read(memory :> bytes(r,1024)) s = + num_of_wordlist (MAP zunpack17 l)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r,1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + MAP_EVERY X_GEN_TAC [`r:int64`; `b:int64`; `t:int64`; + `l:(18 word) list`; `pc:num`] THEN + REWRITE_TAC[C_ARGUMENTS; MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; + NONOVERLAPPING_CLAUSES] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + + (*** Ghost Q28: constructed via two INS_GEN partial writes ***) + GHOST_INTRO_TAC `q28_init:int128` `read Q28` THEN + + ENSURES_INIT_TAC "s0" THEN + + (*** Expand table precondition into 4 x bytes128 reads ***) + FIRST_X_ASSUM(MP_TAC o check (can (term_match [] + `read(memory :> bytes(t:int64,64)) s = x`) o concl)) THEN + REWRITE_TAC[mldsa_polyz_unpack_17_indices] THEN + REPLICATE_TAC 4 + (GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) + [GSYM NUM_OF_PAIR_WORDLIST]) THEN + REWRITE_TAC[pair_wordlist] THEN + CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(LAND_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + REWRITE_TAC[GSYM BYTES128_WBYTES] THEN + STRIP_TAC THEN + + (*** Split 256 18-bit coefficients into 16 chunks of 16 as 288-bit words ***) + UNDISCH_TAC `read(memory :> bytes(b,576)) s0 = num_of_wordlist(l:(18 word) list)` THEN + IMP_REWRITE_TAC [NUM_OF_WORDLIST_SPLIT_18_256] THEN + CONV_TAC (ONCE_DEPTH_CONV LIST_OF_SEQ_CONV) THEN + REWRITE_TAC [MAP; o_DEF] THEN + CONV_TAC(LAND_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + + 
(*** Split each 288-bit wbytes into bytes128 + bytes128 + bytes32 ***) + IMP_REWRITE_TAC [READ_MEMORY_WBYTES_SPLIT_128_128_32] THEN + MAP_EVERY (fun n -> SUBGOAL_THEN (subst[mk_small_numeral n,`k:num`] + `num_of_wordlist (SUB_LIST (16 * k,16) (l : (18 word) list)) < 2 EXP 288`) + (fun th -> REWRITE_TAC[th]) THENL [ + TRANS_TAC LTE_TRANS (subst[mk_small_numeral n,`k:num`] + `2 EXP (dimindex(:18) * LENGTH(SUB_LIST(16*k,16) (l : (18 word) list)))`) THEN + REWRITE_TAC[NUM_OF_WORDLIST_BOUND] THEN + REWRITE_TAC[LENGTH_SUB_LIST; DIMINDEX_CONV `dimindex (:18)`] THEN + ASM_SIMP_TAC [] THEN NUM_REDUCE_TAC; + ALL_TAC]) (0--15) THEN + REWRITE_TAC [WORD_ADD_ASSOC_CONSTS] THEN CONV_TAC (TOP_SWEEP_CONV NUM_ADD_CONV) THEN + STRIP_TAC THEN + + (*** Gather LENGTH assumptions for sublists ***) + MAP_EVERY (fun i -> SUBGOAL_THEN + (subst [mk_small_numeral (16 * i), `i: num`] + `LENGTH (SUB_LIST (i, 16) (l : (18 word) list)) = 16`) ASSUME_TAC + THENL [ASM_REWRITE_TAC [LENGTH_SUB_LIST] THEN NUM_REDUCE_TAC; ALL_TAC]) + (0 -- 15) THEN + + (*** Symbolic execution with per-step simplification ***) + MAP_UNTIL_TARGET_PC (fun n -> + ARM_STEPS_TAC MLDSA_POLYZ_UNPACK_17_EXEC [n] THEN + SIMD_SIMPLIFY_TAC (map GSYM BASE_SIMPS_D18) THEN + SIMP_ZUNPACK_TAC 18 ZUNPACK17_CORRECT) 1 THEN + + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (*** Fold output back to MAP zunpack17 l ***) + REPEAT (FIRST_X_ASSUM(MP_TAC o check + (can (term_match [] `read (memory :> bytes128 r) s0 = xxx`) o concl))) THEN + TRY (IMP_REWRITE_TAC WORD_SUBWORD_NUM_OF_WORDLIST_CASES_D18) THEN + UNDISCH_THEN `LENGTH (l : (18 word) list) = 256` + (fun th -> CONV_TAC (TOP_SWEEP_CONV (EL_SUB_LIST_CONV th)) THEN ASSUME_TAC th) THEN + REPEAT DISCH_TAC THEN + GEN_REWRITE_TAC (RAND_CONV o RAND_CONV o RAND_CONV) [GSYM LIST_OF_SEQ_EQ_SELF] THEN + ASM_REWRITE_TAC[LENGTH_MAP] THEN + CONV_TAC (TOP_SWEEP_CONV LIST_OF_SEQ_CONV) THEN + ASM_REWRITE_TAC [MAP] THEN + REPLICATE_TAC 2 (CONV_TAC (ONCE_REWRITE_CONV [GSYM 
NUM_OF_PAIR_WORDLIST])) THEN + REWRITE_TAC[pair_wordlist] THEN + CONV_TAC (ONCE_DEPTH_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + ASM_REWRITE_TAC[GSYM BYTES128_WBYTES]);; + +(* ------------------------------------------------------------------------- *) +(* Subroutine correctness *) +(* ------------------------------------------------------------------------- *) + +(* NOTE: This must be kept in sync with the CBMC specification + * in mldsa/src/native/aarch64/src/arith_native_aarch64.h *) + +let MLDSA_POLYZ_UNPACK_17_SUBROUTINE_CORRECT = prove + (`!r b t l pc returnaddress. + LENGTH l = 256 /\ + nonoverlapping (word pc,LENGTH mldsa_polyz_unpack_17_mc) (r,1024) /\ + nonoverlapping (b,576) (r,1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_polyz_unpack_17_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [r; b; t] s /\ + read(memory :> bytes(t,64)) s = + num_of_wordlist mldsa_polyz_unpack_17_indices /\ + read(memory :> bytes(b,576)) s = num_of_wordlist l) + (\s. read PC s = returnaddress /\ + read(memory :> bytes(r,1024)) s = + num_of_wordlist (MAP zunpack17 l) /\ + (!i. 
i < 256 ==> + --(&(2 EXP 17) - &1) <= ival(EL i (MAP zunpack17 l)) /\ + ival(EL i (MAP zunpack17 l)) <= &(2 EXP 17))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r,1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC MLDSA_POLYZ_UNPACK_17_EXEC + (CONV_RULE LENGTH_SIMPLIFY_CONV MLDSA_POLYZ_UNPACK_17_CORRECT) THEN + REPEAT STRIP_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (ISPECL [`l:(18 word) list`; `i:num`] ZUNPACK17_MAP_BOUND)) THEN + ASM_REWRITE_TAC[] THEN SIMP_TAC[]);; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml new file mode 100644 index 000000000..cef19dce8 --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_19.ml @@ -0,0 +1,251 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Functional correctness of polyz_unpack_19: *) +(* Unpack polynomial z with 20-bit packed coefficients (GAMMA1 = 2^19) *) +(* Maps packed [0, 2^20-1] to signed [-(2^19-1), 2^19] via GAMMA1 - x *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; +needs "aarch64/proofs/mldsa_polyz_unpack_consts.ml";; + +(**** print_literal_from_elf "aarch64/mldsa/mldsa_polyz_unpack_19.o";; + ****) + +let mldsa_polyz_unpack_19_mc = define_assert_from_elf + "mldsa_polyz_unpack_19_mc" "aarch64/mldsa/mldsa_polyz_unpack_19.o" +(*** BYTECODE START ***) +[ + 0x3dc00058; (* arm_LDR Q24 X2 (Immediate_Offset (word 0)) *) + 0x3dc00459; (* arm_LDR Q25 X2 (Immediate_Offset (word 16)) *) + 0x3dc0085a; (* arm_LDR Q26 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c5b; (* arm_LDR Q27 X2 (Immediate_Offset (word 48)) *) + 0xd2c01f83; (* 
arm_MOVZ X3 (word 252) 32 *) + 0x4e080c7c; (* arm_DUP_GEN Q28 X3 64 128 *) + 0x4f00d5fd; (* arm_MOVI Q29 (word 4503595333451775) *) + 0x4f00451e; (* arm_MOVI Q30 (word 2251799814209536) *) + 0xd2800209; (* arm_MOV X9 (rvalue (word 16)) *) + 0x4cdfac20; (* arm_LDP Q0 Q1 X1 (Postimmediate_Offset (word 32)) *) + 0xfc408422; (* arm_LDR D2 X1 (Postimmediate_Offset (word 8)) *) + 0x4e180004; (* arm_TBL Q4 [Q0] Q24 128 *) + 0x4e192005; (* arm_TBL2 Q5 Q0 Q1 Q25 128 *) + 0x4e1a0026; (* arm_TBL Q6 [Q1] Q26 128 *) + 0x4e1b2027; (* arm_TBL2 Q7 Q1 Q2 Q27 128 *) + 0x6ebc4484; (* arm_USHL_VEC Q4 Q4 Q28 32 128 *) + 0x4e3d1c84; (* arm_AND_VEC Q4 Q4 Q29 128 *) + 0x6ea487c4; (* arm_SUB_VEC Q4 Q30 Q4 32 128 *) + 0x6ebc44a5; (* arm_USHL_VEC Q5 Q5 Q28 32 128 *) + 0x4e3d1ca5; (* arm_AND_VEC Q5 Q5 Q29 128 *) + 0x6ea587c5; (* arm_SUB_VEC Q5 Q30 Q5 32 128 *) + 0x6ebc44c6; (* arm_USHL_VEC Q6 Q6 Q28 32 128 *) + 0x4e3d1cc6; (* arm_AND_VEC Q6 Q6 Q29 128 *) + 0x6ea687c6; (* arm_SUB_VEC Q6 Q30 Q6 32 128 *) + 0x6ebc44e7; (* arm_USHL_VEC Q7 Q7 Q28 32 128 *) + 0x4e3d1ce7; (* arm_AND_VEC Q7 Q7 Q29 128 *) + 0x6ea787c7; (* arm_SUB_VEC Q7 Q30 Q7 32 128 *) + 0x3d800405; (* arm_STR Q5 X0 (Immediate_Offset (word 16)) *) + 0x3d800806; (* arm_STR Q6 X0 (Immediate_Offset (word 32)) *) + 0x3d800c07; (* arm_STR Q7 X0 (Immediate_Offset (word 48)) *) + 0x3c840404; (* arm_STR Q4 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000529; (* arm_SUBS X9 X9 (rvalue (word 1)) *) + 0x54fffd21; (* arm_BNE (word 2097060) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_POLYZ_UNPACK_19_EXEC = ARM_MK_EXEC_RULE mldsa_polyz_unpack_19_mc;; + +(* ------------------------------------------------------------------------- *) +(* Code length constants *) +(* ------------------------------------------------------------------------- *) + +let LENGTH_MLDSA_POLYZ_UNPACK_19_MC = + REWRITE_CONV[mldsa_polyz_unpack_19_mc] `LENGTH mldsa_polyz_unpack_19_mc` + |> CONV_RULE (RAND_CONV LENGTH_CONV);; + +let 
MLDSA_POLYZ_UNPACK_19_PREAMBLE_LENGTH = new_definition + `MLDSA_POLYZ_UNPACK_19_PREAMBLE_LENGTH = 0`;; + +let MLDSA_POLYZ_UNPACK_19_POSTAMBLE_LENGTH = new_definition + `MLDSA_POLYZ_UNPACK_19_POSTAMBLE_LENGTH = 4`;; + +let MLDSA_POLYZ_UNPACK_19_CORE_START = new_definition + `MLDSA_POLYZ_UNPACK_19_CORE_START = MLDSA_POLYZ_UNPACK_19_PREAMBLE_LENGTH`;; + +let MLDSA_POLYZ_UNPACK_19_CORE_END = new_definition + `MLDSA_POLYZ_UNPACK_19_CORE_END = + LENGTH mldsa_polyz_unpack_19_mc - MLDSA_POLYZ_UNPACK_19_POSTAMBLE_LENGTH`;; + +let LENGTH_SIMPLIFY_CONV_19 = + REWRITE_CONV[LENGTH_MLDSA_POLYZ_UNPACK_19_MC; + MLDSA_POLYZ_UNPACK_19_CORE_START; MLDSA_POLYZ_UNPACK_19_CORE_END; + MLDSA_POLYZ_UNPACK_19_PREAMBLE_LENGTH; + MLDSA_POLYZ_UNPACK_19_POSTAMBLE_LENGTH] THENC + NUM_REDUCE_CONV THENC REWRITE_CONV [ADD_0];; + +(* ------------------------------------------------------------------------- *) +(* D=20 instantiations for SIMD infrastructure *) +(* ------------------------------------------------------------------------- *) + +let BASE_SIMPS_D20 = mk_base_simps 20;; +let NUM_OF_WORDLIST_SPLIT_20_256 = mk_split_theorem 20 256 16;; +let READ_MEMORY_WBYTES_SPLIT_128_128_64 = prove + (`t < 2 EXP 320 + ==> (read (memory :> wbytes a) (s:armstate) = (word t : 320 word) <=> + read (memory :> bytes128 a) s = (word (t MOD 2 EXP 128) : int128) /\ + read (memory :> bytes128 (word_add a (word 16))) s = + (word ((t DIV 2 EXP 128) MOD 2 EXP 128) : int128) /\ + read (memory :> bytes64 (word_add a (word 32))) s = + (word (t DIV 2 EXP 256) : int64))`, + let split_16_24 = CONV_RULE (ONCE_DEPTH_CONV NUM_ADD_CONV THENC + DEPTH_CONV NUM_MULT_CONV) + (INST [`16`,`k:num`; `24`,`l:num`] READ_BYTES_SPLIT_ANY) in + let split_16_8 = CONV_RULE (ONCE_DEPTH_CONV NUM_ADD_CONV THENC + DEPTH_CONV NUM_MULT_CONV) + (INST [`16`,`k:num`; `8`,`l:num`] READ_BYTES_SPLIT_ANY) in + STRIP_TAC THEN + REWRITE_TAC[BYTES128_WBYTES; BYTES64_WBYTES; GSYM VAL_EQ; + VAL_READ_WBYTES; READ_COMPONENT_COMPOSE] THEN + CONV_TAC(DEPTH_CONV 
DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[split_16_24] THEN REWRITE_TAC[split_16_8] THEN + REWRITE_TAC[WORD_ADD_ASSOC_CONSTS] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[DIV_DIV; GSYM EXP_ADD] THEN CONV_TAC NUM_REDUCE_CONV THEN + IMP_REWRITE_TAC[VAL_WORD_EXACT] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN ASM_ARITH_TAC);; +let WORD_SUBWORD_NUM_OF_WORDLIST_CASES_D20 = mk_subword_cases 20 16;; + +(* ------------------------------------------------------------------------- *) +(* Core correctness theorem *) +(* ------------------------------------------------------------------------- *) + +let MLDSA_POLYZ_UNPACK_19_CORRECT = prove + (`!r b t l pc. + LENGTH l = 256 /\ + nonoverlapping (word pc,LENGTH mldsa_polyz_unpack_19_mc) (r,1024) /\ + nonoverlapping (b,640) (r,1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_polyz_unpack_19_mc /\ + read PC s = word (pc + MLDSA_POLYZ_UNPACK_19_CORE_START) /\ + C_ARGUMENTS [r; b; t] s /\ + read(memory :> bytes(t,64)) s = + num_of_wordlist mldsa_polyz_unpack_19_indices /\ + read(memory :> bytes(b,640)) s = num_of_wordlist l) + (\s. 
read PC s = word(pc + MLDSA_POLYZ_UNPACK_19_CORE_END) /\ + read(memory :> bytes(r,1024)) s = + num_of_wordlist (MAP zunpack19 l)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r,1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV_19 THEN + MAP_EVERY X_GEN_TAC [`r:int64`; `b:int64`; `t:int64`; + `l:(20 word) list`; `pc:num`] THEN + REWRITE_TAC[C_ARGUMENTS; MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; + NONOVERLAPPING_CLAUSES] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + + ENSURES_INIT_TAC "s0" THEN + + (*** Expand table precondition into 4 x bytes128 reads ***) + FIRST_X_ASSUM(MP_TAC o check (can (term_match [] + `read(memory :> bytes(t:int64,64)) s = x`) o concl)) THEN + REWRITE_TAC[mldsa_polyz_unpack_19_indices] THEN + REPLICATE_TAC 4 + (GEN_REWRITE_TAC (LAND_CONV o ONCE_DEPTH_CONV) + [GSYM NUM_OF_PAIR_WORDLIST]) THEN + REWRITE_TAC[pair_wordlist] THEN + CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(LAND_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + REWRITE_TAC[GSYM BYTES128_WBYTES] THEN + STRIP_TAC THEN + + (*** Split 256 20-bit coefficients into 16 chunks of 16 as 320-bit words ***) + UNDISCH_TAC `read(memory :> bytes(b,640)) s0 = num_of_wordlist(l:(20 word) list)` THEN + IMP_REWRITE_TAC [NUM_OF_WORDLIST_SPLIT_20_256] THEN + CONV_TAC (ONCE_DEPTH_CONV LIST_OF_SEQ_CONV) THEN + REWRITE_TAC [MAP; o_DEF] THEN + CONV_TAC(LAND_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + + (*** Split each 320-bit wbytes into bytes128 + bytes128 + bytes64 ***) + IMP_REWRITE_TAC [READ_MEMORY_WBYTES_SPLIT_128_128_64] THEN + MAP_EVERY (fun n -> SUBGOAL_THEN (subst[mk_small_numeral n,`k:num`] + `num_of_wordlist (SUB_LIST (16 * k,16) (l : (20 word) list)) < 2 EXP 320`) + (fun th -> REWRITE_TAC[th]) THENL [ + TRANS_TAC LTE_TRANS (subst[mk_small_numeral n,`k:num`] + `2 EXP (dimindex(:20) * LENGTH(SUB_LIST(16*k,16) (l : (20 word) list)))`) THEN + REWRITE_TAC[NUM_OF_WORDLIST_BOUND] THEN + REWRITE_TAC[LENGTH_SUB_LIST; DIMINDEX_CONV `dimindex (:20)`] THEN + 
ASM_SIMP_TAC [] THEN NUM_REDUCE_TAC; + ALL_TAC]) (0--15) THEN + REWRITE_TAC [WORD_ADD_ASSOC_CONSTS] THEN CONV_TAC (TOP_SWEEP_CONV NUM_ADD_CONV) THEN + STRIP_TAC THEN + + (*** Gather LENGTH assumptions for sublists ***) + MAP_EVERY (fun i -> SUBGOAL_THEN + (subst [mk_small_numeral (16 * i), `i: num`] + `LENGTH (SUB_LIST (i, 16) (l : (20 word) list)) = 16`) ASSUME_TAC + THENL [ASM_REWRITE_TAC [LENGTH_SUB_LIST] THEN NUM_REDUCE_TAC; ALL_TAC]) + (0 -- 15) THEN + + (*** Symbolic execution with per-step simplification ***) + MAP_UNTIL_TARGET_PC (fun n -> + ARM_STEPS_TAC MLDSA_POLYZ_UNPACK_19_EXEC [n] THEN + SIMD_SIMPLIFY_TAC (map GSYM BASE_SIMPS_D20) THEN + SIMP_ZUNPACK_TAC 20 ZUNPACK19_CORRECT) 1 THEN + + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (*** Fold output back to MAP zunpack19 l ***) + REPEAT (FIRST_X_ASSUM(MP_TAC o check + (can (term_match [] `read (memory :> bytes128 r) s0 = xxx`) o concl))) THEN + TRY (IMP_REWRITE_TAC WORD_SUBWORD_NUM_OF_WORDLIST_CASES_D20) THEN + UNDISCH_THEN `LENGTH (l : (20 word) list) = 256` + (fun th -> CONV_TAC (TOP_SWEEP_CONV (EL_SUB_LIST_CONV th)) THEN ASSUME_TAC th) THEN + REPEAT DISCH_TAC THEN + GEN_REWRITE_TAC (RAND_CONV o RAND_CONV o RAND_CONV) [GSYM LIST_OF_SEQ_EQ_SELF] THEN + ASM_REWRITE_TAC[LENGTH_MAP] THEN + CONV_TAC (TOP_SWEEP_CONV LIST_OF_SEQ_CONV) THEN + ASM_REWRITE_TAC [MAP] THEN + REPLICATE_TAC 2 (CONV_TAC (ONCE_REWRITE_CONV [GSYM NUM_OF_PAIR_WORDLIST])) THEN + REWRITE_TAC[pair_wordlist] THEN + CONV_TAC (ONCE_DEPTH_CONV BYTES_EQ_NUM_OF_WORDLIST_EXPAND_CONV) THEN + ASM_REWRITE_TAC[GSYM BYTES128_WBYTES]);; + +(* ------------------------------------------------------------------------- *) +(* Subroutine correctness *) +(* ------------------------------------------------------------------------- *) + +(* NOTE: This must be kept in sync with the CBMC specification + * in mldsa/src/native/aarch64/src/arith_native_aarch64.h *) + +let MLDSA_POLYZ_UNPACK_19_SUBROUTINE_CORRECT = prove + (`!r b t l pc returnaddress. 
+ LENGTH l = 256 /\ + nonoverlapping (word pc,LENGTH mldsa_polyz_unpack_19_mc) (r,1024) /\ + nonoverlapping (b,640) (r,1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_polyz_unpack_19_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [r; b; t] s /\ + read(memory :> bytes(t,64)) s = + num_of_wordlist mldsa_polyz_unpack_19_indices /\ + read(memory :> bytes(b,640)) s = num_of_wordlist l) + (\s. read PC s = returnaddress /\ + read(memory :> bytes(r,1024)) s = + num_of_wordlist (MAP zunpack19 l) /\ + (!i. i < 256 ==> + --(&(2 EXP 19) - &1) <= ival(EL i (MAP zunpack19 l)) /\ + ival(EL i (MAP zunpack19 l)) <= &(2 EXP 19))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(r,1024)])`, + CONV_TAC LENGTH_SIMPLIFY_CONV_19 THEN + ARM_ADD_RETURN_NOSTACK_TAC MLDSA_POLYZ_UNPACK_19_EXEC + (CONV_RULE LENGTH_SIMPLIFY_CONV_19 MLDSA_POLYZ_UNPACK_19_CORRECT) THEN + REPEAT STRIP_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (ISPECL [`l:(20 word) list`; `i:num`] ZUNPACK19_MAP_BOUND)) THEN + ASM_REWRITE_TAC[] THEN SIMP_TAC[]);; diff --git a/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_consts.ml b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_consts.ml new file mode 100644 index 000000000..6817cbc33 --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_consts.ml @@ -0,0 +1,39 @@ +(* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* + * WARNING: This file is auto-generated from scripts/autogen + * in the mldsa-native repository. + * Do not modify it directly. + *) + +(* + * TBL index tables for polyz_unpack_{17,19} + * See autogen for details. 
+ *) + +let mldsa_polyz_unpack_17_indices = (REWRITE_RULE[MAP] o define) + `mldsa_polyz_unpack_17_indices:byte list = MAP word [ + 0; 1; 2; 255; 2; 3; 4; 255; + 4; 5; 6; 255; 6; 7; 8; 255; + 9; 10; 11; 255; 11; 12; 13; 255; + 13; 14; 15; 255; 15; 16; 17; 255; + 2; 3; 4; 255; 4; 5; 6; 255; + 6; 7; 8; 255; 8; 9; 10; 255; + 11; 12; 13; 255; 13; 14; 15; 255; + 15; 16; 17; 255; 17; 18; 19; 255 +]`;; + +let mldsa_polyz_unpack_19_indices = (REWRITE_RULE[MAP] o define) + `mldsa_polyz_unpack_19_indices:byte list = MAP word [ + 0; 1; 2; 255; 2; 3; 4; 255; + 5; 6; 7; 255; 7; 8; 9; 255; + 10; 11; 12; 255; 12; 13; 14; 255; + 15; 16; 17; 255; 17; 18; 19; 255; + 4; 5; 6; 255; 6; 7; 8; 255; + 9; 10; 11; 255; 11; 12; 13; 255; + 14; 15; 16; 255; 16; 17; 18; 255; + 19; 20; 21; 255; 21; 22; 23; 255 +]`;; diff --git a/proofs/hol_light/common/mldsa_specs.ml b/proofs/hol_light/common/mldsa_specs.ml new file mode 100644 index 000000000..d91f3ae17 --- /dev/null +++ b/proofs/hol_light/common/mldsa_specs.ml @@ -0,0 +1,581 @@ +(* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + *) + +(* ========================================================================= *) +(* Architecture-independent specifications for ML-DSA *) +(* ========================================================================= *) + +(* ========================================================================= *) +(* decompose: Decompose r into high bits r1 and low bits r0 *) +(* *) +(* FIPS 204 Algorithm 36 (Decompose): *) +(* Input: r in Z_q *) +(* Output: integers r1, r0 *) +(* (1: r+ <- r mod q) [skipped: r already reduced]*) +(* 2: r0 <- r mod+/- (2*gamma2) [centered mod, see cmod] *) +(* 3: if r - r0 = q - 1 then *) +(* 4: r1 <- 0 *) +(* 5: r0 <- r0 - 1 *) +(* 6: else r1 <- (r - r0) / (2*gamma2) *) +(* 7: end if *) +(* 8: return (r1, r0) *) +(* *) +(* decompose32: gamma2 = 261888, 2*gamma2 = 523776, (q-1)/(2*gamma2) = 16. 
*) +(* decompose88: gamma2 = 95232, 2*gamma2 = 190464, (q-1)/(2*gamma2) = 44 *) +(* ========================================================================= *) + +(* --- Centered modular reduction (line 2: mod+/-) --- *) +(* cmod r m returns r mod m, as an int, centered in (-m/2, m/2]. *) +(* The condition r MOD m * 2 <= m is equivalent to r MOD m <= m/2, *) +(* but avoids truncation from natural number division by 2. *) + +let cmod = new_definition + `cmod (r:num) (m:num) : int = + if r MOD m * 2 <= m then &(r MOD m) else &(r MOD m) - &m`;; + +(* --- decompose32: GAMMA2 = (Q-1)/32 = 261888; 523776 = 2*GAMMA2, 8380416 = Q-1 --- *) + +let decompose32 = new_definition + `decompose32 (r:num) : num # int = + let r0 = cmod r 523776 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int(&r - r0) DIV 523776, r0)`;; + +(* --- decompose88: GAMMA2 = (Q-1)/88 = 95232; 190464 = 2*GAMMA2, 8380416 = Q-1 --- *) + +let decompose88 = new_definition + `decompose88 (r:num) : num # int = + let r0 = cmod r 190464 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int(&r - r0) DIV 190464, r0)`;; + +(* --- Helpers: closed forms for num_of_int(&r - cmod r m) and for its quotient by m (the high bits), assuming m is nonzero --- *) + +let CMOD_SUB = prove( + `!r m.
~(m = 0) ==> + num_of_int(&r - cmod r m) = + if r MOD m * 2 <= m then r DIV m * m + else (r DIV m + 1) * m`, + REPEAT STRIP_TAC THEN REWRITE_TAC[cmod] THEN + MP_TAC(SPECL [`r:num`; `m:num`] DIVISION) THEN ASM_REWRITE_TAC[] THEN + STRIP_TAC THEN + COND_CASES_TAC THEN REWRITE_TAC[] THENL + [SUBGOAL_THEN `r MOD m <= r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `&r - &(r MOD m) = &(r - r MOD m) : int` + (fun th -> REWRITE_TAC[th; NUM_OF_INT_OF_NUM]) THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB]; ALL_TAC] THEN + ASM_ARITH_TAC; + SUBGOAL_THEN `r MOD m <= r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `&r - (&(r MOD m) - &m) = &(r - r MOD m + m) : int` + (fun th -> REWRITE_TAC[th; NUM_OF_INT_OF_NUM]) THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN INT_ARITH_TAC; + ASM_ARITH_TAC]]);; + +let CMOD_HIGHBITS = prove( + `!r m. ~(m = 0) ==> + num_of_int(&r - cmod r m) DIV m = + (if r MOD m * 2 <= m then r DIV m else r DIV m + 1)`, + REPEAT STRIP_TAC THEN ASM_SIMP_TAC[CMOD_SUB] THEN + COND_CASES_TAC THEN REWRITE_TAC[MULT_SYM] THEN + ASM_SIMP_TAC[DIV_MULT]);; + +(* --- decompose32 lemmas --- *) + +(* Equivalence to MOD/DIV form, used in bound proofs *) +let DECOMPOSE32_EXPAND = prove( + `!r. 
decompose32 r = + let r0 = cmod r 523776 in + let h = if r MOD 523776 * 2 <= 523776 + then r DIV 523776 + else r DIV 523776 + 1 in + if h = 16 then (0, r0 - &1) + else (h, r0)`, + GEN_TAC THEN REWRITE_TAC[decompose32; LET_DEF; LET_END_DEF] THEN + MP_TAC(SPECL [`r:num`; `523776`] CMOD_HIGHBITS) THEN + ANTS_TAC THENL [ARITH_TAC; DISCH_TAC] THEN + MP_TAC(SPECL [`r:num`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THENL + [REWRITE_TAC[cmod] THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `r DIV 523776 = 16` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `&r - &(r MOD 523776) = &8380416 : int` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC; + SUBGOAL_THEN `~(&r - &(r MOD 523776) = &8380416 : int)` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + REWRITE_TAC[cmod] THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `r DIV 523776 + 1 = 16` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `&r - (&(r MOD 523776) - &523776) = &8380416 : int` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC; + SUBGOAL_THEN `~(&r - (&(r MOD 523776) - &523776) = &8380416 : int)` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]]);; + +let DECOMPOSE32_A1_BOUND = prove( + `!r. r < 8380417 ==> FST(decompose32 r) <= 15`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE32_EXPAND; cmod; LET_DEF; LET_END_DEF; FST] THEN + MP_TAC(SPECL [`r:num`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN + ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_ARITH_TAC);; + +let DECOMPOSE32_A0_BOUND = prove( + `!r. 
r < 8380417 ==> + -- &261888 <= SND(decompose32 r) /\ SND(decompose32 r) <= &261888`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE32_EXPAND; cmod; LET_DEF; LET_END_DEF] THEN + MP_TAC(SPECL [`r:num`; `523776`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THENL + [(* Case 1: MOD*2 <= 523776 *) + ASM_CASES_TAC `r DIV 523776 = 16` THEN ASM_REWRITE_TAC[SND] THENL + [(* 1a: wrap *) + SUBGOAL_THEN `r MOD 523776 = 0` SUBST1_TAC THENL + [ASM_ARITH_TAC; CONV_TAC INT_REDUCE_CONV]; + (* 1b: no wrap *) + MP_TAC(SPEC `r MOD 523776` INT_POS) THEN + ASM_REWRITE_TAC[INT_OF_NUM_LE] THEN ASM_ARITH_TAC]; + (* Case 2: MOD*2 > 523776 *) + ASM_CASES_TAC `r DIV 523776 + 1 = 16` THEN ASM_REWRITE_TAC[SND] THENL + [(* 2a: wrap *) + SUBGOAL_THEN `&261888 < &(r MOD 523776) : int /\ &(r MOD 523776) < &523776 : int` MP_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; INT_ARITH_TAC]; + (* 2b: no wrap *) + SUBGOAL_THEN `&261888 < &(r MOD 523776) : int /\ &(r MOD 523776) < &523776 : int` MP_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; INT_ARITH_TAC]]]);; + +let DECOMPOSE32_A1_MAP_BOUND = prove( + `!l. ALL (\x. x < 8380417) l + ==> ALL (\x. x <= 15) (MAP (FST o decompose32) l)`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ALL; MAP; o_THM] THEN + STRIP_TAC THEN CONJ_TAC THENL + [MATCH_MP_TAC DECOMPOSE32_A1_BOUND THEN ASM_REWRITE_TAC[]; + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; + +let DECOMPOSE32_A0_MAP_BOUND = prove( + `!l. ALL (\x. x < 8380417) l + ==> ALL (\x. -- &261888 <= x /\ x <= &261888) (MAP (SND o decompose32) l)`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ALL; MAP; o_THM] THEN + STRIP_TAC THEN CONJ_TAC THENL + [MATCH_MP_TAC DECOMPOSE32_A0_BOUND THEN ASM_REWRITE_TAC[]; + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; + +(* --- decompose88 lemmas --- *) + +(* Equivalence to MOD/DIV form, used in bound proofs *) +let DECOMPOSE88_EXPAND = prove( + `!r. 
decompose88 r = + let r0 = cmod r 190464 in + let h = if r MOD 190464 * 2 <= 190464 + then r DIV 190464 + else r DIV 190464 + 1 in + if h = 44 then (0, r0 - &1) + else (h, r0)`, + GEN_TAC THEN REWRITE_TAC[decompose88; LET_DEF; LET_END_DEF] THEN + MP_TAC(SPECL [`r:num`; `190464`] CMOD_HIGHBITS) THEN + ANTS_TAC THENL [ARITH_TAC; DISCH_TAC] THEN + MP_TAC(SPECL [`r:num`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THENL + [REWRITE_TAC[cmod] THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `r DIV 190464 = 44` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `&r - &(r MOD 190464) = &8380416 : int` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC; + SUBGOAL_THEN `~(&r - &(r MOD 190464) = &8380416 : int)` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + REWRITE_TAC[cmod] THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `r DIV 190464 + 1 = 44` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `&r - (&(r MOD 190464) - &190464) = &8380416 : int` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC; + SUBGOAL_THEN `~(&r - (&(r MOD 190464) - &190464) = &8380416 : int)` (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]]);; + +let DECOMPOSE88_A1_BOUND = prove( + `!r. r < 8380417 ==> FST(decompose88 r) <= 43`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE88_EXPAND; cmod; LET_DEF; LET_END_DEF; FST] THEN + MP_TAC(SPECL [`r:num`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN + ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_ARITH_TAC);; + +let DECOMPOSE88_A0_BOUND = prove( + `!r. 
r < 8380417 ==> + -- &95232 <= SND(decompose88 r) /\ SND(decompose88 r) <= &95232`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[DECOMPOSE88_EXPAND; cmod; LET_DEF; LET_END_DEF] THEN + MP_TAC(SPECL [`r:num`; `190464`] DIVISION) THEN + ANTS_TAC THENL [ARITH_TAC; STRIP_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THENL + [(* Case 1: MOD*2 <= 190464 *) + ASM_CASES_TAC `r DIV 190464 = 44` THEN ASM_REWRITE_TAC[SND] THENL + [(* 1a: wrap *) + SUBGOAL_THEN `r MOD 190464 = 0` SUBST1_TAC THENL + [ASM_ARITH_TAC; CONV_TAC INT_REDUCE_CONV]; + (* 1b: no wrap *) + MP_TAC(SPEC `r MOD 190464` INT_POS) THEN + ASM_REWRITE_TAC[INT_OF_NUM_LE] THEN ASM_ARITH_TAC]; + (* Case 2: MOD*2 > 190464 *) + ASM_CASES_TAC `r DIV 190464 + 1 = 44` THEN ASM_REWRITE_TAC[SND] THENL + [(* 2a: wrap *) + SUBGOAL_THEN `&95232 < &(r MOD 190464) : int /\ &(r MOD 190464) < &190464 : int` MP_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; INT_ARITH_TAC]; + (* 2b: no wrap *) + SUBGOAL_THEN `&95232 < &(r MOD 190464) : int /\ &(r MOD 190464) < &190464 : int` MP_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; INT_ARITH_TAC]]]);; + +let DECOMPOSE88_A1_MAP_BOUND = prove( + `!l. ALL (\x. x < 8380417) l + ==> ALL (\x. x <= 43) (MAP (FST o decompose88) l)`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ALL; MAP; o_THM] THEN + STRIP_TAC THEN CONJ_TAC THENL + [MATCH_MP_TAC DECOMPOSE88_A1_BOUND THEN ASM_REWRITE_TAC[]; + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; + +let DECOMPOSE88_A0_MAP_BOUND = prove( + `!l. ALL (\x. x < 8380417) l + ==> ALL (\x. 
-- &95232 <= x /\ x <= &95232) (MAP (SND o decompose88) l)`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ALL; MAP; o_THM] THEN + STRIP_TAC THEN CONJ_TAC THENL + [MATCH_MP_TAC DECOMPOSE88_A0_BOUND THEN ASM_REWRITE_TAC[]; + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]]);; + +(* ========================================================================= *) +(* zunpack: gamma1 - x unpacking for polyz *) +(* *) +(* zunpack<k> maps a d-bit packed coefficient x in [0, 2^d - 1], d = k+1, *) +(* to gamma1 - x in [-(gamma1-1), gamma1], where gamma1 = 2^k = 2^(d-1). *) +(* ========================================================================= *) + +(* --- zunpack17: GAMMA1 = 2^17, 18-bit packed coefficients, 32-bit result --- *) + +let zunpack17 = new_definition + `zunpack17 (x:(18)word) : (32)word = + word_sub (word(2 EXP 17)) (word_zx x)`;; + +let ZUNPACK17_CORRECT = prove( + `!x:(18)word. + word_sub (word 131072 : 32 word) + (word_zx (x : 18 word) : 32 word) = zunpack17 x`, + REWRITE_TAC[zunpack17] THEN CONV_TAC NUM_REDUCE_CONV);; + +let ZUNPACK17_IVAL = prove( + `!x:(18)word. ival(zunpack17 x) = &(2 EXP 17) - &(val x)`, + GEN_TAC THEN REWRITE_TAC[zunpack17] THEN + SUBGOAL_THEN `word_zx (x:18 word) : 32 word = word(val x)` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:18 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN SIMP_TAC[MOD_LT]; + ALL_TAC] THEN + ONCE_REWRITE_TAC[WORD_IWORD] THEN + REWRITE_TAC[GSYM IWORD_INT_SUB] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:18 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV THENC NUM_REDUCE_CONV) THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN INT_ARITH_TAC);; + +let ZUNPACK17_BOUND = prove( + `!x:(18)word.
--(&(2 EXP 17) - &1) <= ival(zunpack17 x) /\ + ival(zunpack17 x) <= &(2 EXP 17)`, + GEN_TAC THEN REWRITE_TAC[ZUNPACK17_IVAL] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:18 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV THENC NUM_REDUCE_CONV) THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN INT_ARITH_TAC);; + +let ZUNPACK17_MAP_BOUND = prove( + `!l:(18 word) list. !i. i < LENGTH l ==> + --(&(2 EXP 17) - &1) <= ival(EL i (MAP zunpack17 l)) /\ + ival(EL i (MAP zunpack17 l)) <= &(2 EXP 17)`, + REPEAT STRIP_TAC THEN ASM_SIMP_TAC[EL_MAP] THEN + REWRITE_TAC[ZUNPACK17_BOUND]);; + +(* --- zunpack19: GAMMA1 = 2^19, 20-bit packed coefficients --- *) + +let zunpack19 = new_definition + `zunpack19 (x:(20)word) : (32)word = + word_sub (word(2 EXP 19)) (word_zx x)`;; + +let ZUNPACK19_CORRECT = prove( + `!x:(20)word. + word_sub (word 524288 : 32 word) + (word_zx (x : 20 word) : 32 word) = zunpack19 x`, + REWRITE_TAC[zunpack19] THEN CONV_TAC NUM_REDUCE_CONV);; + +let ZUNPACK19_IVAL = prove( + `!x:(20)word. ival(zunpack19 x) = &(2 EXP 19) - &(val x)`, + GEN_TAC THEN REWRITE_TAC[zunpack19] THEN + SUBGOAL_THEN `word_zx (x:20 word) : 32 word = word(val x)` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_ZX_GEN; VAL_WORD] THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:20 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV) THEN SIMP_TAC[MOD_LT]; + ALL_TAC] THEN + ONCE_REWRITE_TAC[WORD_IWORD] THEN + REWRITE_TAC[GSYM IWORD_INT_SUB] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC IVAL_IWORD THEN REWRITE_TAC[DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:20 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV THENC NUM_REDUCE_CONV) THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN INT_ARITH_TAC);; + +let ZUNPACK19_BOUND = prove( + `!x:(20)word. 
--(&(2 EXP 19) - &1) <= ival(zunpack19 x) /\ + ival(zunpack19 x) <= &(2 EXP 19)`, + GEN_TAC THEN REWRITE_TAC[ZUNPACK19_IVAL] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(ISPEC `x:20 word` VAL_BOUND) THEN + CONV_TAC(DEPTH_CONV DIMINDEX_CONV THENC NUM_REDUCE_CONV) THEN + REWRITE_TAC[GSYM INT_OF_NUM_LT] THEN INT_ARITH_TAC);; + +let ZUNPACK19_MAP_BOUND = prove( + `!l:(20 word) list. !i. i < LENGTH l ==> + --(&(2 EXP 19) - &1) <= ival(EL i (MAP zunpack19 l)) /\ + ival(EL i (MAP zunpack19 l)) <= &(2 EXP 19)`, + REPEAT STRIP_TAC THEN ASM_SIMP_TAC[EL_MAP] THEN + REWRITE_TAC[ZUNPACK19_BOUND]);; + +(* ========================================================================= *) +(* Helper lemmas: list operations *) +(* ========================================================================= *) + +let EL_SUB_LIST = prove( + `!l:'a list. !i k n. i < n /\ k + n <= LENGTH l + ==> EL i (SUB_LIST (k, n) l) = EL (k + i) l`, + LIST_INDUCT_TAC THENL [ + REWRITE_TAC[LENGTH; LE; ADD_EQ_0] THEN ARITH_TAC; + REWRITE_TAC[LENGTH] THEN REPEAT GEN_TAC THEN + STRUCT_CASES_TAC (SPEC `k:num` num_CASES) THEN + STRUCT_CASES_TAC (SPEC `n:num` num_CASES) THEN + REWRITE_TAC[LT; SUB_LIST_CLAUSES; ADD_CLAUSES] THENL [ + STRUCT_CASES_TAC (SPEC `i:num` num_CASES) THEN + REWRITE_TAC[EL; HD; TL; ADD_CLAUSES] THEN STRIP_TAC THEN + FIRST_X_ASSUM (MP_TAC o SPECL [`n:num`; `0`; `n':num`]) THEN + REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN MATCH_MP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[EL; TL] THEN STRIP_TAC THEN + FIRST_X_ASSUM (MP_TAC o SPECL [`i:num`; `n':num`; `SUC n''`]) THEN + ASM_REWRITE_TAC[LT_SUC] THEN DISCH_THEN MATCH_MP_TAC THEN ASM_ARITH_TAC]]);; + +let EL_SUB_LIST_CONV len_thm tm = + let i_tm,sublist_tm = dest_comb tm in + let el_const,i = dest_comb i_tm in + let sublist_pair,ls = dest_comb sublist_tm in + let sublist_const,pair_tm = dest_comb sublist_pair in + let base,len = dest_pair pair_tm in + let i_num = dest_numeral i and + len_num = dest_numeral len in + if i_num >= len_num then failwith 
"EL_SUB_LIST_CONV: index out of bounds" else + let th1 = ISPECL [ls; i; base; len] EL_SUB_LIST in + let th2 = REWRITE_RULE[len_thm] th1 in + let th3 = MP th2 (EQT_ELIM(NUM_REDUCE_CONV (fst(dest_imp(concl th2))))) in + CONV_RULE (RAND_CONV (LAND_CONV NUM_ADD_CONV)) th3;; + +let LENGTH_SUB_LIST_0 = prove + (`!n (l:'a list). n <= LENGTH l ==> LENGTH (SUB_LIST (0, n) l) = n`, + REPEAT STRIP_TAC THEN REWRITE_TAC[LENGTH_SUB_LIST; SUB_0] THEN ASM_ARITH_TAC);; + +let SUB_LIST_SUB_LIST_0 = prove( + `!k n m (l:'a list). k + n <= m /\ m <= LENGTH l + ==> SUB_LIST (k, n) (SUB_LIST (0, m) l) = SUB_LIST (k, n) l`, + REPEAT STRIP_TAC THEN REWRITE_TAC[LIST_EQ; LENGTH_SUB_LIST; SUB_0] THEN + CONJ_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN REPEAT STRIP_TAC THEN + SUBGOAL_THEN `n' < n` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC (ISPECL [`SUB_LIST (0,m) l:'a list`; `n':num`; `k:num`; `n:num`] EL_SUB_LIST) THEN + MP_TAC (ISPECL [`l:'a list`; `n':num`; `k:num`; `n:num`] EL_SUB_LIST) THEN + ASM_REWRITE_TAC[LENGTH_SUB_LIST; SUB_0] THEN + SUBGOAL_THEN `k + n <= LENGTH (l:'a list)` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `k + n <= MIN m (LENGTH (l:'a list))` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[] THEN REPEAT DISCH_TAC THEN + MP_TAC (ISPECL [`l:'a list`; `k + n':num`; `0`; `m:num`] EL_SUB_LIST) THEN + ASM_REWRITE_TAC[ADD_CLAUSES] THEN DISCH_THEN MATCH_MP_TAC THEN ASM_ARITH_TAC);; + +let SUB_LIST_SPLIT_EQ = prove + (`!n r (l:'a list). n + r = LENGTH l + ==> APPEND (SUB_LIST (0, n) l) (SUB_LIST (n, r) l) = l`, + REPEAT STRIP_TAC THEN + MP_TAC (ISPECL [`l:'a list`; `n:num`] SUB_LIST_TOPSPLIT) THEN + FIRST_X_ASSUM (SUBST1_TAC o SYM) THEN REWRITE_TAC[ADD_SUB2]);; + +let APPEND_ITLIST_APPEND_NIL = prove + (`!(l:('a list) list) (x:'a list). 
APPEND (ITLIST APPEND l []) x = ITLIST APPEND l x`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ITLIST; APPEND] THEN + GEN_TAC THEN REWRITE_TAC[GSYM APPEND_ASSOC] THEN ASM_REWRITE_TAC[]);; + +let LIST_OF_SEQ_EQ = prove + (`!(f:num->'a) g n. (!i. i < n ==> f i = g i) ==> list_of_seq f n = list_of_seq g n`, + GEN_TAC THEN GEN_TAC THEN INDUCT_TAC THEN REWRITE_TAC[list_of_seq] THEN + DISCH_TAC THEN BINOP_TAC THENL [ + FIRST_X_ASSUM MATCH_MP_TAC THEN GEN_TAC THEN DISCH_TAC THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_ARITH_TAC; + REWRITE_TAC[CONS_11] THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ARITH_TAC + ]);; + +let SUBLIST_PARTITION = prove + (`!r s (l:'a list). LENGTH l = r * s ==> + l = ITLIST APPEND (list_of_seq (\i. SUB_LIST (r * i, r) l) s) []`, + GEN_TAC THEN INDUCT_TAC THENL [ + REWRITE_TAC[MULT_CLAUSES; list_of_seq; ITLIST; LENGTH_EQ_NIL]; + REWRITE_TAC[list_of_seq; ITLIST_EXTRA; APPEND_NIL] THEN + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN + `SUB_LIST (0, r * s) l = + ITLIST APPEND (list_of_seq (\i. SUB_LIST (r * i, r) (SUB_LIST (0, r * s) l)) s) []:'a list` + ASSUME_TAC THENL [ + FIRST_X_ASSUM MATCH_MP_TAC THEN + MATCH_MP_TAC LENGTH_SUB_LIST_0 THEN ASM_ARITH_TAC; + ALL_TAC + ] THEN + SUBGOAL_THEN + `list_of_seq (\i. SUB_LIST (r * i, r) (SUB_LIST (0, r * s) l):'a list) s = + list_of_seq (\i. SUB_LIST (r * i, r) l) s` + ASSUME_TAC THENL [ + MATCH_MP_TAC LIST_OF_SEQ_EQ THEN REPEAT STRIP_TAC THEN REWRITE_TAC[] THEN + MATCH_MP_TAC SUB_LIST_SUB_LIST_0 THEN CONJ_TAC THENL [ + REWRITE_TAC[ARITH_RULE `r * i + r = r * (i + 1)`] THEN + REWRITE_TAC[LE_MULT_LCANCEL] THEN ASM_ARITH_TAC; + ASM_ARITH_TAC + ]; + ALL_TAC + ] THEN + SUBGOAL_THEN + `APPEND (SUB_LIST (0, r * s) l) (SUB_LIST (r * s, r) l) = l:'a list` + ASSUME_TAC THENL [ + MATCH_MP_TAC SUB_LIST_SPLIT_EQ THEN ASM_REWRITE_TAC[MULT_SUC] THEN ARITH_TAC; + ALL_TAC + ] THEN + SUBGOAL_THEN + `SUB_LIST (0, r * s) l = + ITLIST APPEND (list_of_seq (\i. 
SUB_LIST (r * i, r) l) s) []:'a list` + ASSUME_TAC THENL [ASM_MESON_TAC[]; ALL_TAC] THEN + UNDISCH_TAC `APPEND (SUB_LIST (0,r * s) l) (SUB_LIST (r * s,r) l) = l:'a list` THEN + UNDISCH_TAC `SUB_LIST (0,r * s) l = ITLIST APPEND (list_of_seq (\i. SUB_LIST (r * i,r) l) s) []:'a list` THEN + SIMP_TAC[APPEND_ITLIST_APPEND_NIL] + ]);; + +(* ========================================================================= *) +(* Helper lemmas: word arithmetic *) +(* ========================================================================= *) + +let VAL_WORD_EXACT = prove( + `!n. n < 2 EXP dimindex(:N) ==> val(word n : N word) = n`, + REWRITE_TAC[VAL_WORD] THEN SIMP_TAC[MOD_LT]);; + +let WORD_PACKED_EQ = prove( + `!(x:N word) (y:N word). + dimindex(:N) = l * k /\ 0 < l /\ l <= dimindex(:M) + ==> (x = y <=> + !i. i < k + ==> word_subword x (l*i, l) : (M) word = + word_subword y (l*i, l))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN EQ_TAC THENL + [DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[]; + DISCH_TAC THEN + GEN_REWRITE_TAC I [WORD_EQ_BITS_ALT] THEN + X_GEN_TAC `j:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(MP_TAC o SPEC `j DIV l`) THEN + ANTS_TAC THENL + [UNDISCH_TAC `j < dimindex(:N)` THEN ASM_REWRITE_TAC[] THEN + ASM_SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `0 < l ==> ~(l = 0)`; MULT_SYM]; + DISCH_THEN(fun th -> + MP_TAC(AP_TERM `\(w:M word). bit (j MOD l) w` th)) THEN + REWRITE_TAC[BIT_WORD_SUBWORD] THEN + SUBGOAL_THEN `j MOD l < MIN l (dimindex(:M))` + (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[ARITH_RULE `l <= m ==> MIN l m = l`; + MOD_LT_EQ; ARITH_RULE `0 < l ==> ~(l = 0)`]; + ASM_SIMP_TAC[DIVISION_SIMP; ARITH_RULE `0 < l ==> ~(l = 0)`]]]]);; + +let WORD_SUBWORD_NUM_OF_WORDLIST = prove + (`!(ls:(L word)list) k. 
+ dimindex(:KL) = dimindex(:L) * LENGTH ls /\ + k < LENGTH ls + ==> word_subword (word (num_of_wordlist ls) : KL word) (dimindex(:L)*k, dimindex(:L)) : L word = EL k ls`, + REPEAT STRIP_TAC THEN REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_SUBWORD] THEN + REWRITE_TAC[ARITH_RULE `MIN n n = n`] THEN + SUBGOAL_THEN `val (word (num_of_wordlist (ls:(L word)list)) : KL word) = num_of_wordlist ls` SUBST1_TAC THENL + [W(MP_TAC o PART_MATCH (lhand o rand) VAL_WORD_EQ o lhand o snd) THEN + ANTS_TAC THENL + [TRANS_TAC LTE_TRANS `2 EXP (dimindex(:L) * LENGTH (ls:(L word)list))` THEN + REWRITE_TAC[NUM_OF_WORDLIST_BOUND; LE_EXP; LE_REFL] THEN ASM_ARITH_TAC; + SIMP_TAC[]]; + MP_TAC(ISPECL [`ls:(L word)list`; `k:num`] NUM_OF_WORDLIST_EL) THEN + ASM_REWRITE_TAC[]]);; + +let NUM_OF_WORDLIST_FLATTEN = prove + (`!(ll:((N word) list) list) k. + ALL (\l. LENGTH l = k) ll /\ + dimindex(:N) * k = dimindex(:M) + ==> num_of_wordlist (ITLIST APPEND ll []) = + num_of_wordlist (MAP ((word:num->M word) o num_of_wordlist) ll)`, + LIST_INDUCT_TAC THEN REWRITE_TAC[ITLIST; MAP; num_of_wordlist; ALL] THEN + X_GEN_TAC `k:num` THEN STRIP_TAC THEN + FIRST_X_ASSUM(MP_TAC o SPEC `k:num`) THEN + ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + REWRITE_TAC[NUM_OF_WORDLIST_APPEND; num_of_wordlist; o_THM] THEN + ASM_REWRITE_TAC[] THEN + AP_THM_TAC THEN AP_TERM_TAC THEN + IMP_REWRITE_TAC [VAL_WORD_EXACT] THEN + TRANS_TAC LTE_TRANS `2 EXP (dimindex(:N) * LENGTH(h:(N word)list))` THEN + REWRITE_TAC[NUM_OF_WORDLIST_BOUND_LENGTH] THEN + ASM_REWRITE_TAC[LE_REFL]);; + +(* ========================================================================= *) +(* Helper lemmas: byte splitting *) +(* ========================================================================= *) + +let NUM_BIT_DECOMPOSE_UNIQ = prove( + `!a b t k. 
a < 2 EXP k + ==> (a + 2 EXP k * b = t <=> (a = t MOD 2 EXP k /\ b = t DIV 2 EXP k))`, + REPEAT STRIP_TAC THEN EQ_TAC THENL [ + DISCH_THEN (SUBST1_TAC o SYM) THEN + SIMP_TAC[MOD_MULT_ADD; DIV_MULT_ADD; EXP_EQ_0; ARITH_EQ] THEN + ASM_SIMP_TAC[MOD_LT; DIV_LT; ADD_CLAUSES]; + STRIP_TAC THEN + MP_TAC (SPECL [`t:num`; `2 EXP k`] DIVISION) THEN + SIMP_TAC[EXP_EQ_0; ARITH_EQ] THEN ASM_REWRITE_TAC[] THEN ARITH_TAC]);; + +let READ_BYTES_SPLIT_ANY = prove( + `read (bytes(a : int64,k+l)) s = t <=> + read (bytes(a,k)) s = t MOD 2 EXP (8*k) /\ + read (bytes(word_add a (word k), l)) s = t DIV 2 EXP (8*k)`, + let bound = prove(`read (bytes (a : int64,k)) s < 2 EXP (8*k)`, + REWRITE_TAC[READ_BYTES_BOUND]) in + REWRITE_TAC[GSYM VAL_EQ; VAL_READ_WBYTES; READ_COMPONENT_COMPOSE] THEN + REWRITE_TAC[READ_BYTES_COMBINE] THEN + REWRITE_TAC[MATCH_MP NUM_BIT_DECOMPOSE_UNIQ bound]);; + +(* ========================================================================= *) +(* Helper utilities: word subterm search and binary tree conversion *) +(* ========================================================================= *) + +let is_word_type_n n ty = + is_type ty && + let name, args = dest_type ty in + name = "word" && length args = 1 && + Num.int_of_num (dest_finty (hd args)) = n;; + +let rec find_word_subterm_n n tm = + if is_word_type_n n (type_of tm) then Some tm + else if is_comb tm then + match find_word_subterm_n n (rator tm) with + | Some t -> Some t + | None -> find_word_subterm_n n (rand tm) + else if is_abs tm then find_word_subterm_n n (body tm) + else None;; + +let BINOP_CONV_N n cv = + let rec go depth i tm = + if depth <= 0 then cv i tm + else + let half = 1 lsl (depth - 1) in + COMB2_CONV (RAND_CONV (go (depth-1) (half + i))) (go (depth-1) i) tm in + go n 0;; diff --git a/scripts/autogen b/scripts/autogen index b341f6c82..414bd72ea 100755 --- a/scripts/autogen +++ b/scripts/autogen @@ -1123,6 +1123,30 @@ def gen_aarch64_polyz_unpack_table(): 
update_file("dev/aarch64_clean/src/polyz_unpack_table.c", "\n".join(gen())) +def gen_hol_light_polyz_unpack_table(): + def gen(): + yield from gen_hol_light_header() + yield "(*" + yield " * TBL index tables for polyz_unpack_{17,19}" + yield " * See autogen for details." + yield " *)" + yield "" + for gamma1_bits in [17, 19]: + bit_width = gamma1_bits + 1 + name = f"mldsa_polyz_unpack_{gamma1_bits}_indices" + indices = list(gen_aarch64_polyz_unpack_indices(bit_width)) + yield f"let {name} = (REWRITE_RULE[MAP] o define)" + yield f" `{name}:byte list = MAP word [" + yield from print_hol_light_array(indices, as_int=False, pad=3) + yield "]`;;" + yield "" + + update_file( + "proofs/hol_light/aarch64/proofs/mldsa_polyz_unpack_consts.ml", + "\n".join(gen()), + ) + + def gen_avx2_rej_uniform_table_rows(): # The index into the lookup table is an 8-bit bitmap, i.e. a number 0..255. # Conceptually, the table entry at index i is a vector of 8 16-bit values, of @@ -2276,6 +2300,41 @@ def gen_hol_light_asm(): f"-Imldsa/src/native/aarch64/src {aarch64_flags}", "aarch64", ), + ( + "poly_chknorm_asm.S", + "mldsa_poly_chknorm.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), + ( + "poly_decompose_32_asm.S", + "mldsa_poly_decompose_32.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), + ( + "poly_decompose_88_asm.S", + "mldsa_poly_decompose_88.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), + ( + "polyz_unpack_17_asm.S", + "mldsa_polyz_unpack_17.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), + ( + "polyz_unpack_19_asm.S", + "mldsa_polyz_unpack_19.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), ] x86_64_flags = "-mavx2 -mbmi2 -msse4 -fcf-protection=full" @@ -3309,6 +3368,7 @@ def _main(): gen_aarch64_rej_uniform_table() 
gen_aarch64_rej_uniform_eta_table() gen_aarch64_polyz_unpack_table() + gen_hol_light_polyz_unpack_table() gen_avx2_hol_light_zeta_file() gen_avx2_zeta_file() gen_avx2_rej_uniform_table()