diff --git a/jaxite/jaxite_lib/bootstrap.py b/jaxite/jaxite_lib/bootstrap.py index f7e678b..e0ffcf9 100644 --- a/jaxite/jaxite_lib/bootstrap.py +++ b/jaxite/jaxite_lib/bootstrap.py @@ -379,7 +379,7 @@ def external_product( message=output, ) - +@jax.named_call @functools.partial(jax.jit, static_argnums=(2, 3)) def jit_external_product( rgsw_ct: jnp.ndarray, @@ -537,55 +537,67 @@ def jit_blind_rotate( # Using the improved blind rotate from Bourse-Minelli-Minihold-Paillier # (BMMP17: https://eprint.iacr.org/2017/1114), a trick uses a larger # bootstrapping key to reduce the number of external products required by 1/2. + unroll_factor = None if use_bmmp: num_loop_terms = (coefficient_index.shape[0] - 1) // 2 + if num_loop_terms % 8 == 0: + unroll_factor = 8 else: num_loop_terms = coefficient_index.shape[0] - 1 - def one_external_product(j, c_prime_accum): + def _unrolled_loop_body(k, c_prime_accum): if use_bmmp: # Doing this computation inside the external product loop improves cache # locality, resulting in reduced data copying. - power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1] - power2 = coefficient_index[2 * j] - power3 = coefficient_index[2 * j + 1] + power1 = coefficient_index[2 * k] + coefficient_index[2 * k + 1] + power2 = coefficient_index[2 * k] + power3 = coefficient_index[2 * k + 1] bmmp_factor = ( matrix_utils.scale_by_x_power_n_minus_1( # Rotation. 
- power1, bsk[3 * j], log_modulus=log_coefficient_modulus + power1, bsk[3 * k], log_modulus=log_coefficient_modulus ) + matrix_utils.scale_by_x_power_n_minus_1( - power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus + power2, bsk[3 * k + 1], log_modulus=log_coefficient_modulus ) + matrix_utils.scale_by_x_power_n_minus_1( - power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus + power3, bsk[3 * k + 2], log_modulus=log_coefficient_modulus ) ).astype(jnp.uint32) - return c_prime_accum + jit_external_product( + # The external product is equivalent to a CMUX where the `else` branch is + # zero. + c_prime_accum = c_prime_accum + jit_external_product( rgsw_ct=bmmp_factor, rlwe_ct=c_prime_accum, decomposition_params=decomposition_params, use_bat=False, ) # c'_mul = c' * X^{a_j^tilde} (for each entry in c') + return c_prime_accum else: - # where a_j^tilde = coefficient_index[j] #Disabled BMMP + # where a_j^tilde = coefficient_index[k] #Disabled BMMP c_prime_mul = matrix_utils.monomial_mul_list( c_prime_accum, - coefficient_index[j], + coefficient_index[k], log_coefficient_modulus, ).astype(jnp.uint32) # Update c_prime with the output of the CMUX operation, where either # `c_prime` or `c_prime * X^{a_j^tilde}` is chosen by `bsk` at index j. 
return jit_cmux( - control=bsk[j], + control=bsk[k], eq_zero=c_prime_accum, neq_zero=c_prime_mul, decomposition_params=decomposition_params, use_bat=use_bat, ) - return jax.lax.fori_loop(0, num_loop_terms, one_external_product, c_prime) + return jax.lax.fori_loop( + 0, + num_loop_terms, + _unrolled_loop_body, + c_prime, + unroll=unroll_factor, + ) def sample_extract(ciphertext: rlwe.RlweCiphertext) -> types.LweCiphertext: diff --git a/jaxite/jaxite_lib/polymul_kernel.py b/jaxite/jaxite_lib/polymul_kernel.py index 53d69da..878a586 100644 --- a/jaxite/jaxite_lib/polymul_kernel.py +++ b/jaxite/jaxite_lib/polymul_kernel.py @@ -82,125 +82,129 @@ def bat_matmul(lhs: jax.Array, y: jax.Array): def _i32_matmul_unreduced_CGGI(lhs, rhs): - """Modified from i32_matmul_unreduced to incorporate CGGI tricks for better + """Performs a 32-bit integer matrix multiplication with CGGI optimization. - efficiency. + This function is an optimized version of i32_matmul_unreduced, incorporating + tricks from the CGGI paper to improve efficiency. + + Args: + lhs: The left-hand side matrix of shape (m, k); m must be even. + rhs: The right-hand side matrix of shape (k, n); cast directly to bfloat16. + + Returns: + An int32 jnp.ndarray of shape (m // 2, n). """ lax = jax.lax m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1] - lhs_i8 = jnp.broadcast_to(lhs, (2, *lhs.shape)).reshape((4, m//2, k)) - lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8 - lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift) - lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape)) - lhs_i8 = lhs_i8.reshape((2 * m, k)) + # Optimized byte extraction for the CGGI trick. This splits the 32-bit + # integers in the `lhs` matrix into their 8-bit components. 
+ byte0 = lax.bitwise_and(lhs, 0xFF) + byte1 = lax.bitwise_and(lax.shift_right_logical(lhs, 8), 0xFF) + byte2 = lax.bitwise_and(lax.shift_right_logical(lhs, 16), 0xFF) + byte3 = lax.bitwise_and(lax.shift_right_logical(lhs, 24), 0xFF) + + # Concatenate the bytes to form a new matrix. + lhs_i8 = jnp.concatenate( + [byte0[: m // 2, :], byte1[m // 2 :, :], byte2[: m // 2, :], byte3[m // 2 :, :]], + axis=0, + ) + + # rhs is the Toeplitz matrix of the decomposed vector. Its values are small + # (e.g. < 64) because of the decomposition base, so rhs is cast to bfloat16 + # directly without splitting into bytes, reducing the number of matmuls by 4x. + # + # Both operands are exactly representable in bfloat16: lhs_i8 holds byte + # values (0-255) and rhs holds the small values described above. + # + # The inputs are bfloat16 to leverage the TPU's optimized bfloat16 hardware, + # while preferred_element_type=jnp.float32 keeps the accumulation of + # intermediate sums in float32 for precision. + raw_out = lax.dot( + lhs_i8.astype(jnp.bfloat16), + rhs.astype(jnp.bfloat16), + preferred_element_type=jnp.float32, # Ensures accumulation in float32 + ).astype(jnp.int32) + raw_out = raw_out.reshape((4, m // 2, n)) + + # Reconstruct the 32-bit integers from the 8-bit components. This is done by + # shifting the bytes back to their original positions and summing them up. out_shift_base = lax.mul( - lax.broadcasted_iota(jnp.int32, (4, m//2, n), dimension=0), 8 + lax.broadcasted_iota(jnp.int32, (4, m // 2, n), dimension=0), 8 ) - acc = jnp.zeros((m//2, n), dtype=jnp.int32) - for rhs_shift in range(0, 32, 8): - # TODO(b/201562458): Don't multiply lhs rows with large shift. 
- rhs_i8 = lax.shift_right_logical( - rhs, jnp.broadcast_to(rhs_shift, rhs.shape) - ) - rhs_i8 = lax.bitwise_and(rhs_i8, jnp.broadcast_to(0xFF, rhs_i8.shape)) - # TODO(b/201562458): Use int8 matmuls once properly supported - raw_out = lax.dot( - lhs_i8.astype(jnp.bfloat16), - rhs_i8.astype(jnp.bfloat16), - preferred_element_type=jnp.float32, - ).astype(jnp.int32).reshape((4, m//2, n)) - raw_out = jnp.left_shift(raw_out, out_shift_base + rhs_shift) - acc += raw_out[0] + raw_out[1] + raw_out[2] + raw_out[3] - return acc + + shifted_out = jnp.left_shift(raw_out, out_shift_base) + return shifted_out[0] + shifted_out[1] + shifted_out[2] + shifted_out[3] def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray): - # b is the product of the RLWE dimension (e.g., 3) and the number of - # decomposition levels in the decomposition parameters (e.g., 6). - # n is the degree of the RLWE polynomials. + """Computes the polynomial multiplication of a vector and a matrix. + + This function is optimized for TPU execution using Pallas. + + Args: + poly_vec1: A vector of polynomials. + poly_mat2: A matrix of polynomials. + + Returns: + The result of the polynomial multiplication. + """ b, n = poly_vec1.shape - # m is the number of polynomials in the RLWE dimension (e.g., 3) b2, m, n2 = poly_mat2.shape assert b == b2 and n == n2 - # We must pad m to 8 because the TPU register sublane has size 8, and more - # importantly, many of the pallas instructions like pltpu.roll will fail - # if the sublane size is not a multiple of 8. This further adds the assumption - # that the value of m is < 8. We are unlikely to need m > 8 for the - # foreseeable future, but if we did, we would need to round up to the next - # multiple of 8. + # We optimize m to be 2 * real_m to fully utilize the CGGI split trick without + # unnecessary padding. 
real_m = m - m = 8 + m = 2 * real_m + if m % 4 != 0: + m += 4 - (m % 4) + poly_mat2 = jnp.pad( poly_mat2, ((0, 0), (0, (m // 2) - real_m), (0, 0)), - mode="constant", + mode='constant', constant_values=(0,), ) poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=(1)) if n % 128 != 0: - raise ValueError(f"Input size {n} is not a multiple of 128") - dtype = poly_vec1.dtype - - def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref): - chunk = jnp.broadcast_to(vec_ref[...], (128, n)) - chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0) - chunk_row_indices = jax.lax.broadcasted_iota( - dtype=jnp.int32, shape=(128, n), dimension=0 - ) - chunk_col_indices = jax.lax.broadcasted_iota( - dtype=jnp.int32, shape=(128, n), dimension=1 - ) - toeplitz_chunks = [] - for _ in range(0, n, 128): - toeplitz_chunks.append( - jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk) - ) - # Because the vector registers are aligned to size 128, this roll - # operation lowers to telling the TPU to refer to a different register, - # rather than actually applying any rolling operation. Hence, the op - # produces no hardware instructions. - chunk = pltpu.roll(chunk, 128, 1) - chunk_row_indices = chunk_row_indices + 128 - vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0) - - assert vec_toeplitz.shape == (n, n) - result = _i32_matmul_unreduced_CGGI(mat_ref[...], vec_toeplitz) - assert result.shape == (m // 2, n), result.shape - out_ref[...] 
= result + raise ValueError(f'Input size {n} is not a multiple of 128') def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref): - for b in range(vec_ref.shape[0]): - vec_mat_polymul_kernel_single_batch( - vec_ref.at[b], mat_ref.at[b], out_ref.at[b] + """Pallas kernel for polynomial multiplication.""" + for b_i in range(vec_ref.shape[0]): + chunk = jnp.broadcast_to(vec_ref[b_i, ...], (128, n)) + chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0) + chunk_row_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(128, n), dimension=0 ) + chunk_col_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(128, n), dimension=1 + ) + toeplitz_chunks = [] + for _ in range(0, n, 128): + toeplitz_chunks.append( + jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk) + ) + chunk = pltpu.roll(chunk, 128, 1) + chunk_row_indices = chunk_row_indices + 128 + vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0) - block_b = 2 - steps_b, rem_b = divmod(b, block_b) - if rem_b: - raise ValueError(f"b={b} is not a multiple of block_b={block_b}") + result = _i32_matmul_unreduced_CGGI(mat_ref[b_i, ...], vec_toeplitz) + out_ref[b_i, ...] = result return jnp.sum( pl.pallas_call( vec_mat_polymul_kernel, in_specs=( - pl.BlockSpec((block_b, 1, n), lambda b: (b, 0, 0)), - pl.BlockSpec((block_b, m, n), lambda b: (b, 0, 0)), + pl.BlockSpec((b, 1, n), lambda i: (i, 0, 0)), + pl.BlockSpec((b, m, n), lambda i: (i, 0, 0)), ), - out_specs=pl.BlockSpec((block_b, m // 2, n), lambda b: (b, 0, 0)), + out_specs=pl.BlockSpec((b, m // 2, n), lambda i: (i, 0, 0)), out_shape=jax.ShapeDtypeStruct((b, m // 2, n), jnp.int32), - grid=(steps_b,), - compiler_params=pltpu.CompilerParams( - # Set the vem limit to 32 MiB, it could be up to 128 MiB. 
- vmem_limit_bytes=int(2**10 * 10**15) - ), - )( - poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32) - ).reshape( - b, m // 2, n - ), - axis=(0,), + grid=(1,), + )(poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)), + axis=0, ).astype(jnp.uint32)[:real_m] diff --git a/jaxite/jaxite_lib/polymul_kernel_test.py b/jaxite/jaxite_lib/polymul_kernel_test.py index 6d6f95a..b300dcf 100644 --- a/jaxite/jaxite_lib/polymul_kernel_test.py +++ b/jaxite/jaxite_lib/polymul_kernel_test.py @@ -8,9 +8,9 @@ _SEEDS = list(range(3)) -def random(shape, dtype=np.int32): +def random(shape, dtype=np.int32, high=2**31 - 1): return jnp.array( - np.random.randint(low=0, high=2**31 - 1, size=shape, dtype=dtype) + np.random.randint(low=0, high=high, size=shape, dtype=dtype) ) @@ -26,7 +26,7 @@ def test_i32_matmul_vs_reference(self, seed: int): np.testing.assert_array_equal(expected, actual) def test_vector_matrix_vs_reference(self): - vector = random(shape=(18, 512)) + vector = random(shape=(18, 512), high=2**8 - 1).astype(jnp.uint32) matrix = random(shape=(18, 3, 512)) expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix) actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix) @@ -37,7 +37,7 @@ def test_vector_matrix_vs_reference(self): ) def test_many_seeds(self, seed: int): np.random.seed(seed) - vector = random(shape=(18, 512), dtype=jnp.uint32) + vector = random(shape=(18, 512), high=2**8 - 1).astype(jnp.uint32) matrix = random(shape=(18, 3, 512), dtype=jnp.uint32) expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix) actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix)