Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 92 additions & 79 deletions jaxite/jaxite_lib/polymul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,39 @@ def _i32_matmul_unreduced_CGGI(lhs, rhs):
"""
lax = jax.lax
m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1]
lhs_i8 = jnp.broadcast_to(lhs, (2, *lhs.shape)).reshape((4, m//2, k))
lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8
lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift)
lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape))
lhs_i8 = lhs_i8.reshape((2 * m, k))

out_shift_base = lax.mul(
lax.broadcasted_iota(jnp.int32, (4, m//2, n), dimension=0), 8
)
acc = jnp.zeros((m//2, n), dtype=jnp.int32)
for rhs_shift in range(0, 32, 8):
# TODO(b/201562458): Don't multiply lhs rows with large shift.
rhs_i8 = lax.shift_right_logical(
rhs, jnp.broadcast_to(rhs_shift, rhs.shape)
)
m_actual = m // 2
lhs_effective = lhs[:m_actual]

lhs_parts = [
(lax.bitwise_and(lax.shift_right_logical(lhs_effective, s), 0xFF)).astype(
jnp.bfloat16
)
for s in [0, 8, 16, 24]
]
lhs_stacked = jnp.stack(lhs_parts, axis=0).reshape(-1, k)

acc = jnp.zeros((m_actual, n), dtype=jnp.int64)
for rhs_shift_amount in range(0, 32, 8):
rhs_i8 = lax.shift_right_logical(rhs, rhs_shift_amount)
rhs_i8 = lax.bitwise_and(rhs_i8, jnp.broadcast_to(0xFF, rhs_i8.shape))
# TODO(b/201562458): Use int8 matmuls once properly supported
raw_out = lax.dot(
lhs_i8.astype(jnp.bfloat16),
rhs_i8.astype(jnp.bfloat16),
preferred_element_type=jnp.float32,
).astype(jnp.int32).reshape((4, m//2, n))
raw_out = jnp.left_shift(raw_out, out_shift_base + rhs_shift)
acc += raw_out[0] + raw_out[1] + raw_out[2] + raw_out[3]
return acc

raw_products_stacked = lax.dot(
lhs_stacked, rhs_i8.astype(jnp.bfloat16), preferred_element_type=jnp.float32
).astype(jnp.int64) # (4 * m_actual, n)

# Unstack to (4, m_actual, n)
raw_products_unstacked = raw_products_stacked.reshape(4, m_actual, n)

# Combine with corresponding lhs_byte_shift_amount
for lhs_byte_shift_idx in range(4):
lhs_byte_shift_amount = lhs_byte_shift_idx * 8
acc += jnp.left_shift(
raw_products_unstacked[lhs_byte_shift_idx],
lhs_byte_shift_amount + rhs_shift_amount,
)

return acc.astype(jnp.int32)


def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
Expand All @@ -121,87 +129,92 @@ def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
# n is the degree of the RLWE polynomials.
b, n = poly_vec1.shape
# m is the number of polynomials in the RLWE dimension (e.g., 3)
b2, m, n2 = poly_mat2.shape
b2, real_m, n2 = poly_mat2.shape
assert b == b2 and n == n2

# We must pad m to 8 because the TPU register sublane has size 8, and more
# importantly, many of the pallas instructions like pltpu.roll will fail
# if the sublane size is not a multiple of 8. This further adds the assumption
# that the value of m is < 8. We are unlikely to need m > 8 for the
# foreseeable future, but if we did, we would need to round up to the next
# multiple of 8.
real_m = m
m = 8
poly_mat2 = jnp.pad(
poly_mat2,
((0, 0), (0, (m // 2) - real_m), (0, 0)),
mode="constant",
constant_values=(0,),
)
poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=(1))
# We must pad real_m up to 8 because the TPU register sublane has size 8,
# and pallas instructions like pltpu.roll may fail if the sublane size is
# not a multiple of 8.
m_padded = 8
if n % 128 != 0:
raise ValueError(f"Input size {n} is not a multiple of 128")
dtype = poly_vec1.dtype

def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref):
chunk = jnp.broadcast_to(vec_ref[...], (128, n))
# The first roll prepares the initial 128xN block for the toeplitz matrix.
chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
chunk_row_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=0

toeplitz_chunks = []

base_local_row_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, 1), dimension=0
)
chunk_col_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(128, n), dimension=1
# `col_indices` are 0 to N-1, broadcasted to (1, n).
col_indices = jax.lax.broadcasted_iota(
dtype=jnp.int32, shape=(1, n), dimension=1
)
toeplitz_chunks = []
for _ in range(0, n, 128):
toeplitz_chunks.append(
jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk)

# Initialize the mask for the first chunk (row_start_idx = 0)
initial_condition_val = (base_local_row_indices > col_indices).astype(dtype)
local_negacyclic_sign_mask = 1 - 2 * initial_condition_val

for row_start_idx_offset in range(0, n, 128):

current_global_row_offset = jnp.full(
(128, 1), row_start_idx_offset, dtype=jnp.int32
)
# Because the vector registers are aligned to size 128, this roll
# operation lowers to telling the TPU to refer to a different register,
# rather than actually applying any rolling operation. Hence, the op
# produces no hardware instructions.
adjusted_col_indices = col_indices - current_global_row_offset

# Generate the mask using jnp.where for explicit conditional selection.
condition = base_local_row_indices > adjusted_col_indices
local_negacyclic_sign_mask = jnp.where(condition, -1, 1).astype(dtype)

# Apply the negacyclic sign flip for the current 128xN chunk.
toeplitz_chunks.append(chunk * local_negacyclic_sign_mask)

# This roll operation is optimized on TPU.
chunk = pltpu.roll(chunk, 128, 1)
chunk_row_indices = chunk_row_indices + 128
vec_toeplitz = jax.lax.concatenate(toeplitz_chunks, dimension=0)

assert vec_toeplitz.shape == (n, n)
result = _i32_matmul_unreduced_CGGI(mat_ref[...], vec_toeplitz)
assert result.shape == (m // 2, n), result.shape
out_ref[...] = result

# Pad mat_ref to m_padded rows for the CGGI matmul
padded_mat_ref = jnp.pad(
mat_ref[...],
((0, m_padded - real_m), (0, 0)),
mode="constant",
constant_values=(0,),
)
# Duplicate rows for CGGI trick
padded_mat_ref = jnp.concatenate((padded_mat_ref, padded_mat_ref), axis=(0))

result = _i32_matmul_unreduced_CGGI(padded_mat_ref, vec_toeplitz)
assert result.shape == (m_padded, n), result.shape
out_ref[...] = result[:real_m] # Store only the unpadded result

def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref):
for b in range(vec_ref.shape[0]):
# Process all batches in a single kernel launch to reduce overhead.
for i in range(b):
vec_mat_polymul_kernel_single_batch(
vec_ref.at[b], mat_ref.at[b], out_ref.at[b]
vec_ref.at[i, 0], mat_ref.at[i], out_ref.at[i]
)

block_b = 2
steps_b, rem_b = divmod(b, block_b)
if rem_b:
raise ValueError(f"b={b} is not a multiple of block_b={block_b}")

return jnp.sum(
pl.pallas_call(
vec_mat_polymul_kernel,
in_specs=(
pl.BlockSpec((block_b, 1, n), lambda b: (b, 0, 0)),
pl.BlockSpec((block_b, m, n), lambda b: (b, 0, 0)),
),
out_specs=pl.BlockSpec((block_b, m // 2, n), lambda b: (b, 0, 0)),
out_shape=jax.ShapeDtypeStruct((b, m // 2, n), jnp.int32),
grid=(steps_b,),
out_shape=jax.ShapeDtypeStruct((b, real_m, n), jnp.int32),
grid=(1,), # Single kernel launch for all batches.
in_specs=[
pl.BlockSpec((b, 1, n), lambda _: (0, 0, 0)),
pl.BlockSpec((b, real_m, n), lambda _: (0, 0, 0)),
],
out_specs=pl.BlockSpec((b, real_m, n), lambda _: (0, 0, 0)),
compiler_params=pltpu.CompilerParams(
# Set the VMEM limit to 64 MiB; it could be raised up to 128 MiB.
vmem_limit_bytes=int(2**10 * 10**15)
vmem_limit_bytes=64 * 1024 * 1024
),
)(
poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)
).reshape(
b, m // 2, n
),
axis=(0,),
).astype(jnp.uint32)[:real_m]
)(poly_vec1[:, None], poly_mat2),
axis=0,
).astype(jnp.uint32)


@jax.named_call
Expand Down