Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def external_product(
message=output,
)


@jax.named_call
@functools.partial(jax.jit, static_argnums=(2, 3))
def jit_external_product(
rgsw_ct: jnp.ndarray,
Expand Down Expand Up @@ -537,55 +537,67 @@ def jit_blind_rotate(
# The improved blind rotate from Bourse-Minelli-Minihold-Paillier
# (BMMP17: https://eprint.iacr.org/2017/1114) uses a larger bootstrapping
# key to halve the number of external products required.
unroll_factor = None
if use_bmmp:
num_loop_terms = (coefficient_index.shape[0] - 1) // 2
if num_loop_terms % 8 == 0:
unroll_factor = 8
else:
num_loop_terms = coefficient_index.shape[0] - 1

def one_external_product(j, c_prime_accum):
def _unrolled_loop_body(k, c_prime_accum):
if use_bmmp:
# Doing this computation inside the external product loop improves cache
# locality, resulting in reduced data copying.
power1 = coefficient_index[2 * j] + coefficient_index[2 * j + 1]
power2 = coefficient_index[2 * j]
power3 = coefficient_index[2 * j + 1]
power1 = coefficient_index[2 * k] + coefficient_index[2 * k + 1]
power2 = coefficient_index[2 * k]
power3 = coefficient_index[2 * k + 1]
bmmp_factor = (
matrix_utils.scale_by_x_power_n_minus_1( # Rotation.
power1, bsk[3 * j], log_modulus=log_coefficient_modulus
power1, bsk[3 * k], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power2, bsk[3 * j + 1], log_modulus=log_coefficient_modulus
power2, bsk[3 * k + 1], log_modulus=log_coefficient_modulus
)
+ matrix_utils.scale_by_x_power_n_minus_1(
power3, bsk[3 * j + 2], log_modulus=log_coefficient_modulus
power3, bsk[3 * k + 2], log_modulus=log_coefficient_modulus
)
).astype(jnp.uint32)
return c_prime_accum + jit_external_product(
# The external product is equivalent to a CMUX where the `else` branch is
# zero.
c_prime_accum = c_prime_accum + jit_external_product(
rgsw_ct=bmmp_factor,
rlwe_ct=c_prime_accum,
decomposition_params=decomposition_params,
use_bat=False,
)
# c'_mul = c' * X^{a_j^tilde} (for each entry in c')
return c_prime_accum
else:
# where a_j^tilde = coefficient_index[j] #Disabled BMMP
# where a_j^tilde = coefficient_index[k] #Disabled BMMP
c_prime_mul = matrix_utils.monomial_mul_list(
c_prime_accum,
coefficient_index[j],
coefficient_index[k],
log_coefficient_modulus,
).astype(jnp.uint32)

# Update c_prime with the output of the CMUX operation, where either
# `c_prime` or `c_prime * X^{a_j^tilde}` is chosen by `bsk` at index j.
return jit_cmux(
control=bsk[j],
control=bsk[k],
eq_zero=c_prime_accum,
neq_zero=c_prime_mul,
decomposition_params=decomposition_params,
use_bat=use_bat,
)

return jax.lax.fori_loop(0, num_loop_terms, one_external_product, c_prime)
return jax.lax.fori_loop(
0,
num_loop_terms,
_unrolled_loop_body,
c_prime,
unroll=unroll_factor,
)


def sample_extract(ciphertext: rlwe.RlweCiphertext) -> types.LweCiphertext:
Expand Down
176 changes: 90 additions & 86 deletions jaxite/jaxite_lib/polymul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,125 +82,129 @@ def bat_matmul(lhs: jax.Array, y: jax.Array):


def _i32_matmul_unreduced_CGGI(lhs, rhs):
"""Modified from i32_matmul_unreduced to incorporate CGGI tricks for better
"""Performs a 32-bit integer matrix multiplication with CGGI optimization.

efficiency.
This function is an optimized version of i32_matmul_unreduced, incorporating
tricks from the CGGI paper to improve efficiency.

Args:
lhs: The left-hand side matrix of shape (m, k).
rhs: The right-hand side matrix of shape (k, n).

Returns:
The result of the matrix multiplication as a jnp.ndarray.
"""
lax = jax.lax
m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1]
lhs_i8 = jnp.broadcast_to(lhs, (2, *lhs.shape)).reshape((4, m//2, k))
lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8
lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift)
lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape))
lhs_i8 = lhs_i8.reshape((2 * m, k))

# Optimized byte extraction for the CGGI trick. This splits the 32-bit
# integers in the `lhs` matrix into their 8-bit components.
byte0 = lax.bitwise_and(lhs, 0xFF)
byte1 = lax.bitwise_and(lax.shift_right_logical(lhs, 8), 0xFF)
byte2 = lax.bitwise_and(lax.shift_right_logical(lhs, 16), 0xFF)
byte3 = lax.bitwise_and(lax.shift_right_logical(lhs, 24), 0xFF)

# Concatenate the bytes to form a new matrix.
lhs_i8 = jnp.concatenate(
[byte0[: m // 2, :], byte1[m // 2 :, :], byte2[: m // 2, :], byte3[m // 2 :, :]],
axis=0,
)

# rhs is the Toeplitz matrix of the decomposed vector. The values are small
# (e.g. < 64) because of the decomposition base. We can cast to bfloat16
# directly without splitting into bytes, reducing the number of matmuls by 4x.
# Perform the matrix multiplication with bfloat16 inputs to leverage the
# TPU's optimized bfloat16 hardware, and accumulate in float32 to avoid
# correctness issues with bfloat16 intermediate sums, especially given the
# range of input values and dimensions in FHE. The rhs values are small
# (< 64) and the lhs_i8 values (0-255) are exactly representable in bfloat16.
raw_out = lax.dot(
lhs_i8.astype(jnp.bfloat16),
rhs.astype(jnp.bfloat16),
preferred_element_type=jnp.float32, # Ensures accumulation in float32
).astype(jnp.int32)
raw_out = raw_out.reshape((4, m // 2, n))

# Reconstruct the 32-bit integers from the 8-bit components. This is done by
# shifting the bytes back to their original positions and summing them up.
out_shift_base = lax.mul(
lax.broadcasted_iota(jnp.int32, (4, m//2, n), dimension=0), 8
lax.broadcasted_iota(jnp.int32, (4, m // 2, n), dimension=0), 8
)
acc = jnp.zeros((m//2, n), dtype=jnp.int32)
for rhs_shift in range(0, 32, 8):
# TODO(b/201562458): Don't multiply lhs rows with large shift.
rhs_i8 = lax.shift_right_logical(
rhs, jnp.broadcast_to(rhs_shift, rhs.shape)
)
rhs_i8 = lax.bitwise_and(rhs_i8, jnp.broadcast_to(0xFF, rhs_i8.shape))
# TODO(b/201562458): Use int8 matmuls once properly supported
raw_out = lax.dot(
lhs_i8.astype(jnp.bfloat16),
rhs_i8.astype(jnp.bfloat16),
preferred_element_type=jnp.float32,
).astype(jnp.int32).reshape((4, m//2, n))
raw_out = jnp.left_shift(raw_out, out_shift_base + rhs_shift)
acc += raw_out[0] + raw_out[1] + raw_out[2] + raw_out[3]
return acc

shifted_out = jnp.left_shift(raw_out, out_shift_base)
return shifted_out[0] + shifted_out[1] + shifted_out[2] + shifted_out[3]


def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
# b is the product of the RLWE dimension (e.g., 3) and the number of
# decomposition levels in the decomposition parameters (e.g., 6).
# n is the degree of the RLWE polynomials.
"""Computes the polynomial multiplication of a vector and a matrix.

This function is optimized for TPU execution using Pallas.

Args:
poly_vec1: A vector of polynomials.
poly_mat2: A matrix of polynomials.

Returns:
The result of the polynomial multiplication.
"""
b, n = poly_vec1.shape
# m is the number of polynomials in the RLWE dimension (e.g., 3)
b2, m, n2 = poly_mat2.shape
assert b == b2 and n == n2

# We must pad m to 8 because the TPU register sublane has size 8, and more
# importantly, many of the pallas instructions like pltpu.roll will fail
# if the sublane size is not a multiple of 8. This further adds the assumption
# that the value of m is < 8. We are unlikely to need m > 8 for the
# foreseeable future, but if we did, we would need to round up to the next
# multiple of 8.
# We optimize m to be 2 * real_m to fully utilize the CGGI split trick without
# unnecessary padding.
real_m = m
m = 8
m = 2 * real_m
if m % 4 != 0:
m += 4 - (m % 4)

poly_mat2 = jnp.pad(
poly_mat2,
((0, 0), (0, (m // 2) - real_m), (0, 0)),
mode="constant",
mode='constant',
constant_values=(0,),
)
poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=(1))
if n % 128 != 0:
raise ValueError(f"Input size {n} is not a multiple of 128")
dtype = poly_vec1.dtype

def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref):
chunk = jnp.broadcast_to(vec_ref[...], (128, n))
chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
chunk_row_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=0
)
chunk_col_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=1
)
toeplitz_chunks = []
for _ in range(0, n, 128):
toeplitz_chunks.append(
jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk)
)
# Because the vector registers are aligned to size 128, this roll
# operation lowers to telling the TPU to refer to a different register,
# rather than actually applying any rolling operation. Hence, the op
# produces no hardware instructions.
chunk = pltpu.roll(chunk, 128, 1)
chunk_row_indices = chunk_row_indices + 128
vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0)

assert vec_toeplitz.shape == (n, n)
result = _i32_matmul_unreduced_CGGI(mat_ref[...], vec_toeplitz)
assert result.shape == (m // 2, n), result.shape
out_ref[...] = result
raise ValueError(f'Input size {n} is not a multiple of 128')

def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref):
for b in range(vec_ref.shape[0]):
vec_mat_polymul_kernel_single_batch(
vec_ref.at[b], mat_ref.at[b], out_ref.at[b]
"""Pallas kernel for polynomial multiplication."""
for b_i in range(vec_ref.shape[0]):
chunk = jnp.broadcast_to(vec_ref[b_i, ...], (128, n))
chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
chunk_row_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=0
)
chunk_col_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=1
)
toeplitz_chunks = []
for _ in range(0, n, 128):
toeplitz_chunks.append(
jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk)
)
chunk = pltpu.roll(chunk, 128, 1)
chunk_row_indices = chunk_row_indices + 128
vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0)

block_b = 2
steps_b, rem_b = divmod(b, block_b)
if rem_b:
raise ValueError(f"b={b} is not a multiple of block_b={block_b}")
result = _i32_matmul_unreduced_CGGI(mat_ref[b_i, ...], vec_toeplitz)
out_ref[b_i, ...] = result

return jnp.sum(
pl.pallas_call(
vec_mat_polymul_kernel,
in_specs=(
pl.BlockSpec((block_b, 1, n), lambda b: (b, 0, 0)),
pl.BlockSpec((block_b, m, n), lambda b: (b, 0, 0)),
pl.BlockSpec((b, 1, n), lambda i: (i, 0, 0)),
pl.BlockSpec((b, m, n), lambda i: (i, 0, 0)),
),
out_specs=pl.BlockSpec((block_b, m // 2, n), lambda b: (b, 0, 0)),
out_specs=pl.BlockSpec((b, m // 2, n), lambda i: (i, 0, 0)),
out_shape=jax.ShapeDtypeStruct((b, m // 2, n), jnp.int32),
grid=(steps_b,),
compiler_params=pltpu.CompilerParams(
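# Set the VMEM limit to 32 MiB; it could be up to 128 MiB.
# NOTE(review): the expression below evaluates to 2**10 * 10**15 bytes
# (about 1e18), not 32 MiB (2**25) -- confirm the intended value.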
vmem_limit_bytes=int(2**10 * 10**15)
),
)(
poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)
).reshape(
b, m // 2, n
),
axis=(0,),
grid=(1,),
)(poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)),
axis=0,
).astype(jnp.uint32)[:real_m]


Expand Down
8 changes: 4 additions & 4 deletions jaxite/jaxite_lib/polymul_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
_SEEDS = list(range(3))


def random(shape, dtype=np.int32):
def random(shape, dtype=np.int32, high=2**31 - 1):
return jnp.array(
np.random.randint(low=0, high=2**31 - 1, size=shape, dtype=dtype)
np.random.randint(low=0, high=high, size=shape, dtype=dtype)
)


Expand All @@ -26,7 +26,7 @@ def test_i32_matmul_vs_reference(self, seed: int):
np.testing.assert_array_equal(expected, actual)

def test_vector_matrix_vs_reference(self):
vector = random(shape=(18, 512))
vector = random(shape=(18, 512), high=2**8 - 1).astype(jnp.uint32)
matrix = random(shape=(18, 3, 512))
expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix)
actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix)
Expand All @@ -37,7 +37,7 @@ def test_vector_matrix_vs_reference(self):
)
def test_many_seeds(self, seed: int):
np.random.seed(seed)
vector = random(shape=(18, 512), dtype=jnp.uint32)
vector = random(shape=(18, 512), high=2**8 - 1).astype(jnp.uint32)
matrix = random(shape=(18, 3, 512), dtype=jnp.uint32)
expected = polymul_kernel.fallback_vector_matrix_polymul(vector, matrix)
actual = polymul_kernel.negacyclic_vector_matrix_polymul(vector, matrix)
Expand Down
Loading