From f012802437b21af6bc4cb88927e521f3f78455fc Mon Sep 17 00:00:00 2001 From: Shruthi Gorantala Date: Thu, 23 Oct 2025 23:35:26 -0700 Subject: [PATCH] Improvements to the vector_polymul_kernel PiperOrigin-RevId: 823374879 --- jaxite/jaxite_lib/polymul_kernel.py | 171 +++++++++++++++------------- 1 file changed, 92 insertions(+), 79 deletions(-) diff --git a/jaxite/jaxite_lib/polymul_kernel.py b/jaxite/jaxite_lib/polymul_kernel.py index 53d69da..44f4fa4 100644 --- a/jaxite/jaxite_lib/polymul_kernel.py +++ b/jaxite/jaxite_lib/polymul_kernel.py @@ -88,31 +88,39 @@ def _i32_matmul_unreduced_CGGI(lhs, rhs): """ lax = jax.lax m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1] - lhs_i8 = jnp.broadcast_to(lhs, (2, *lhs.shape)).reshape((4, m//2, k)) - lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8 - lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift) - lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape)) - lhs_i8 = lhs_i8.reshape((2 * m, k)) - out_shift_base = lax.mul( - lax.broadcasted_iota(jnp.int32, (4, m//2, n), dimension=0), 8 - ) - acc = jnp.zeros((m//2, n), dtype=jnp.int32) - for rhs_shift in range(0, 32, 8): - # TODO(b/201562458): Don't multiply lhs rows with large shift. - rhs_i8 = lax.shift_right_logical( - rhs, jnp.broadcast_to(rhs_shift, rhs.shape) - ) + m_actual = m // 2 + lhs_effective = lhs[:m_actual] + + lhs_parts = [ + (lax.bitwise_and(lax.shift_right_logical(lhs_effective, s), 0xFF)).astype( + jnp.bfloat16 + ) + for s in [0, 8, 16, 24] + ] + lhs_stacked = jnp.stack(lhs_parts, axis=0).reshape(-1, k) + + acc = jnp.zeros((m_actual, n), dtype=jnp.int64) + for rhs_shift_amount in range(0, 32, 8): + rhs_i8 = lax.shift_right_logical(rhs, rhs_shift_amount) rhs_i8 = lax.bitwise_and(rhs_i8, jnp.broadcast_to(0xFF, rhs_i8.shape)) - # TODO(b/201562458): Use int8 matmuls once properly supported - raw_out = lax.dot( - lhs_i8.astype(jnp.bfloat16), - rhs_i8.astype(jnp.bfloat16), - preferred_element_type=jnp.float32, - ).astype(jnp.int32).reshape((4, m//2, n)) - raw_out = jnp.left_shift(raw_out, out_shift_base + rhs_shift) - acc += raw_out[0] + raw_out[1] + raw_out[2] + raw_out[3] - return acc + + raw_products_stacked = lax.dot( + lhs_stacked, rhs_i8.astype(jnp.bfloat16), preferred_element_type=jnp.float32 + ).astype(jnp.int64) # (4 * m_actual, n) + + # Unstack to (4, m_actual, n) + raw_products_unstacked = raw_products_stacked.reshape(4, m_actual, n) + + # Combine with corresponding lhs_byte_shift_amount + for lhs_byte_shift_idx in range(4): + lhs_byte_shift_amount = lhs_byte_shift_idx * 8 + acc += jnp.left_shift( + raw_products_unstacked[lhs_byte_shift_idx], + lhs_byte_shift_amount + rhs_shift_amount, + ) + + return acc.astype(jnp.int32) def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray): @@ -121,87 +129,92 @@ def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray): # n is the degree of the RLWE polynomials. b, n = poly_vec1.shape # m is the number of polynomials in the RLWE dimension (e.g., 3) - b2, m, n2 = poly_mat2.shape + b2, real_m, n2 = poly_mat2.shape assert b == b2 and n == n2 - - # We must pad m to 8 because the TPU register sublane has size 8, and more - # importantly, many of the pallas instructions like pltpu.roll will fail - # if the sublane size is not a multiple of 8. This further adds the assumption - # that the value of m is < 8. We are unlikely to need m > 8 for the - # foreseeable future, but if we did, we would need to round up to the next - # multiple of 8. - real_m = m - m = 8 - poly_mat2 = jnp.pad( - poly_mat2, - ((0, 0), (0, (m // 2) - real_m), (0, 0)), - mode="constant", - constant_values=(0,), - ) - poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=(1)) + # We must pad real_m up to 8 because the TPU register sublane has size 8, + # and pallas instructions like pltpu.roll may fail if the sublane size is + # not a multiple of 8. + m_padded = 8 if n % 128 != 0: raise ValueError(f"Input size {n} is not a multiple of 128") dtype = poly_vec1.dtype def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref): chunk = jnp.broadcast_to(vec_ref[...], (128, n)) + # The first roll prepares the initial 128xN block for the toeplitz matrix. chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0) - chunk_row_indices = jax.lax.broadcasted_iota( - dtype=jnp.int32, shape=(128, n), dimension=0 + + toeplitz_chunks = [] + + base_local_row_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(128, 1), dimension=0 ) - chunk_col_indices = jax.lax.broadcasted_iota( - dtype=jnp.int32, shape=(128, n), dimension=1 + # `col_indices` are 0 to N-1, broadcasted to (1, n). + col_indices = jax.lax.broadcasted_iota( + dtype=jnp.int32, shape=(1, n), dimension=1 ) - toeplitz_chunks = [] - for _ in range(0, n, 128): - toeplitz_chunks.append( - jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk) + + # Initialize the mask for the first chunk (row_start_idx = 0) + initial_condition_val = (base_local_row_indices > col_indices).astype(dtype) + local_negacyclic_sign_mask = 1 - 2 * initial_condition_val + + for row_start_idx_offset in range(0, n, 128): + + current_global_row_offset = jnp.full( + (128, 1), row_start_idx_offset, dtype=jnp.int32 ) - # Because the vector registers are aligned to size 128, this roll - # operation lowers to telling the TPU to refer to a different register, - # rather than actually applying any rolling operation. Hence, the op - # produces no hardware instructions. + adjusted_col_indices = col_indices - current_global_row_offset + + # Generate the mask using jnp.where for explicit conditional selection. + condition = base_local_row_indices > adjusted_col_indices + local_negacyclic_sign_mask = jnp.where(condition, -1, 1).astype(dtype) + + # Apply the negacyclic sign flip for the current 128xN chunk. + toeplitz_chunks.append(chunk * local_negacyclic_sign_mask) + + # This roll operation is optimized on TPU. chunk = pltpu.roll(chunk, 128, 1) - chunk_row_indices = chunk_row_indices + 128 vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0) assert vec_toeplitz.shape == (n, n) - result = _i32_matmul_unreduced_CGGI(mat_ref[...], vec_toeplitz) - assert result.shape == (m // 2, n), result.shape - out_ref[...] = result + + # Pad mat_ref to m_padded rows for the CGGI matmul + padded_mat_ref = jnp.pad( + mat_ref[...], + ((0, m_padded - real_m), (0, 0)), + mode="constant", + constant_values=(0,), + ) + # Duplicate rows for CGGI trick + padded_mat_ref = jnp.concatenate((padded_mat_ref, padded_mat_ref), axis=(0)) + + result = _i32_matmul_unreduced_CGGI(padded_mat_ref, vec_toeplitz) + assert result.shape == (m_padded, n), result.shape + out_ref[...] = result[:real_m] # Store only the unpadded result def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref): - for b in range(vec_ref.shape[0]): + # Process all batches in a single kernel launch to reduce overhead. + for i in range(b): vec_mat_polymul_kernel_single_batch( - vec_ref.at[b], mat_ref.at[b], out_ref.at[b] + vec_ref.at[i, 0], mat_ref.at[i], out_ref.at[i] ) - block_b = 2 - steps_b, rem_b = divmod(b, block_b) - if rem_b: - raise ValueError(f"b={b} is not a multiple of block_b={block_b}") - return jnp.sum( pl.pallas_call( vec_mat_polymul_kernel, - in_specs=( - pl.BlockSpec((block_b, 1, n), lambda b: (b, 0, 0)), - pl.BlockSpec((block_b, m, n), lambda b: (b, 0, 0)), - ), - out_specs=pl.BlockSpec((block_b, m // 2, n), lambda b: (b, 0, 0)), - out_shape=jax.ShapeDtypeStruct((b, m // 2, n), jnp.int32), - grid=(steps_b,), + out_shape=jax.ShapeDtypeStruct((b, real_m, n), jnp.int32), + grid=(1,), # Single kernel launch for all batches. + in_specs=[ + pl.BlockSpec((b, 1, n), lambda _: (0, 0, 0)), + pl.BlockSpec((b, real_m, n), lambda _: (0, 0, 0)), + ], + out_specs=pl.BlockSpec((b, real_m, n), lambda _: (0, 0, 0)), compiler_params=pltpu.CompilerParams( - # Set the vem limit to 32 MiB, it could be up to 128 MiB. - vmem_limit_bytes=int(2**10 * 10**15) + vmem_limit_bytes=64 * 1024 * 1024 ), - )( - poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32) - ).reshape( - b, m // 2, n - ), - axis=(0,), - ).astype(jnp.uint32)[:real_m] + )(poly_vec1[:, None], poly_mat2), + axis=0, + ).astype(jnp.uint32) @jax.named_call