diff --git a/exp/README.md b/exp/README.md new file mode 100644 index 000000000000..b6d337e914d8 --- /dev/null +++ b/exp/README.md @@ -0,0 +1,149 @@ +# Wan-AI/Wan2.2-I2V-A14B-Diffusers Recipe + +1. Export the environment of GCP project +* Fill the PROJECT_ID and TPU_NAME +``` +### 1. export env of gcp ### + +export PROJECT_ID= +export TPU_NAME= +export ZONE= +export ACCELERATOR_TYPE=v6e-16 +export RUNTIME_VERSION=v2-alpha-tpuv6e +``` + +2. Create the v6e-16 tpu vms on GCP +``` +gcloud compute tpus tpu-vm create ${TPU_NAME} \ + --zone=${ZONE} \ + --project=${PROJECT_ID} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --version=${RUNTIME_VERSION} +``` + +3. Prepare the python env on each tpu vms +``` +### 3. prepare env on each host ### + +run() +{ + local command=$1 + local worker=${2:-all} + gcloud compute tpus tpu-vm ssh --zone "${ZONE}" "${ACCOUNT}@${TPU_NAME}" --project "${PROJECT_ID}" --worker=${worker} --command="$command" +} + +BRANCH_NAME=wan2.2-main + +SETUP_COMMAND="\ +set -x && \ +curl -LsSf https://astral.sh/uv/install.sh | sh && \ +source ~/.local/bin/env && \ +uv venv -p 3.12 && \ +source .venv/bin/activate && \ +git clone -b ${BRANCH_NAME} https://github.com/yuyanpeng-google/diffusers.git || true && \ +cd diffusers && \ +uv pip install -e . && \ +uv pip install transformers accelerate && \ +uv pip install torch --index-url https://download.pytorch.org/whl/cpu && \ +uv pip install -U jax[tpu] && \ +uv pip install torchax && \ +uv pip install flax && \ +uv pip install ftfy imageio imageio-ffmpeg && \ +true +" + +run "${SETUP_COMMAND}" +``` + +4. Run wan2.2 pipeline to generate the videos +``` +### 4. 
run wan2.2 pipeline ### + +run() +{ + local command=$1 + local worker=${2:-all} + gcloud compute tpus tpu-vm ssh --zone "${ZONE}" "${ACCOUNT}@${TPU_NAME}" --project "${PROJECT_ID}" --worker=${worker} --command="$command" +} + +BRANCH_NAME=wan2.2-main + +RUN_COMMAND="\ +set -x && \ +source .venv/bin/activate && \ +killall -9 python || true && \ +sleep 10 && \ +export JAX_COMPILATION_CACHE_DIR="/dev/shm/jax_cache" && \ +export JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES=-1 && \ +export JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=0 && \ +export JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES='xla_gpu_per_fusion_autotune_cache_dir' && \ +export HF_HUB_CACHE=/dev/shm/hf_cache && \ +cd diffusers && \ +git fetch && git reset --hard origin/${BRANCH_NAME} && \ +cd exp && \ +nohup python wan2p2_benchmark.py > $(date +%Y-%m-%d_%H-%M-%S).log 2>&1 & +true +" +run "${RUN_COMMAND}" +``` + +5. See the results in stdout +``` +... +output video done. 20251029_093753.mp4 +Warmup and output video: 1961.571311s +... +Benchmark: 103.959559s +Done +``` +Notice that the first time warmup need to compile the graph which is time consuming. + +6. Use scp download generated videos +``` +VIDEO_NAME=20251029_093753.mp4 # from the 5 stdout + +gcloud compute tpus tpu-vm scp --zone "${ZONE}" "${TPU_NAME}:~/diffusers/exp/${VIDEO_NAME}" . --project "${PROJECT_ID}" --worker=0 +``` + + +# Install + +Install dependencies, setup virtual env first if required. + +Test use python 3.12 + +```sh +# install uv, python 3.12 and activate +curl -LsSf https://astral.sh/uv/install.sh | sh && \ +source ~/.local/bin/env && \ +uv venv -p 3.12 && \ +source .venv/bin/activate && \ +``` + +```sh +# install dependency +# pwd=. +uv pip install -e . 
&& \ +uv pip install transformers accelerate && \ +uv pip install torch --index-url https://download.pytorch.org/whl/cpu && \ +uv pip install -U jax[tpu] && \ +uv pip install torchax && \ +uv pip install flax && \ +uv pip install ftfy imageio imageio-ffmpeg +``` + +To run: + +```sh +# cwd=exp +python wan2p2_benchmark.py +``` + +### Result + +``` +# python wan2p2_benchmark.py +Benchmark: 103.959559s +Done +``` + diff --git a/exp/benchmark_splash_attention_kernel.py b/exp/benchmark_splash_attention_kernel.py new file mode 100644 index 000000000000..21984a3c40ee --- /dev/null +++ b/exp/benchmark_splash_attention_kernel.py @@ -0,0 +1,240 @@ +import functools +from jax.experimental.pallas.ops.tpu import splash_attention +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from jax.sharding import Mesh +from jax.experimental import mesh_utils + +import jax +import jax.numpy as jnp +import math +import time + +# import ringattention_pallas_tpu_splash +import custom_splash_attention + + +# Copy from wan_tx_splash_attn.py +@functools.partial( + jax.jit, + static_argnames=("mesh", "bqsize", "bkvsize", "bkvcomputesize", "bkvcomputesinize"), +) +def _tpu_splash_attention( + query, + key, + value, + mesh, + bqsize, + bkvsize, + bkvcomputesize, + bkvcomputesinize, + scale=None, + is_causal=False, + window_size=None, +): + num_heads = query.shape[1] + + # The function that will be sharded across devices. + def _attention_on_slices(q, k, v): + + # Scale the query tensor. This happens on each device with its slice of data. 
+ scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale + q = q * scale_factor + + def pad_to_multiple2(x, multiple, axis): + # For try pad outside + return x, x.shape[axis] + + # Helper to pad to next multiple + def pad_to_multiple(x, multiple, axis): + seq_len = x.shape[axis] + pad_len = (multiple - seq_len % multiple) % multiple + if pad_len == 0: + return x, seq_len + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = (0, pad_len) + return jnp.pad(x, pad_width), seq_len + + # This function operates on a single item from the batch. + def kernel_3d(q_3d, k_3d, v_3d): + q_seq_len = q_3d.shape[1] + kv_seq_len = k_3d.shape[1] + num_heads_on_device = q_3d.shape[0] + + # Pad q, k, v to next multiple of BQSIZE/BKVSIZE + q_3d_padded, q_orig_len = pad_to_multiple(q_3d, bqsize, axis=1) + k_3d_padded, k_orig_len = pad_to_multiple(k_3d, bkvsize, axis=1) + v_3d_padded, v_orig_len = pad_to_multiple(v_3d, bkvsize, axis=1) + + padded_q_seq_len = q_3d_padded.shape[1] + padded_kv_seq_len = k_3d_padded.shape[1] + + block_sizes = splash_attention.BlockSizes( + block_q=min(bqsize, padded_q_seq_len), + block_kv=min(bkvsize, padded_kv_seq_len), + block_kv_compute=min(bkvcomputesize, padded_kv_seq_len), + ) + splash_kernel = custom_splash_attention.make_splash_mha( + block_sizes=block_sizes, bkv_compute_in=bkvcomputesinize + ) + out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded) + # Remove padding if any + out = jnp.swapaxes(out, 1, 2) + return out[:, :q_orig_len, ...] + + # Map the kernel over the batch dimension. + vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0) + return vmapped_kernel(q, k, v) + + # Determine the partitioning spec based on the number of heads. + if num_heads < mesh.size: + # Replicated case for VAE. All devices get the full tensor. + q_partition_spec = P() + kv_partition_spec = P() + else: + # Sharded case for Transformer. Split along the heads axis. + # Attn1 self attention, key length is long. 
+ if key.shape[2] > 10000: + q_partition_spec = P("dp", "axis", "sp", None) + kv_partition_spec = P("dp", "axis", None, None) + else: + # Attn2 which is cross attention, kv sequence is shorter. All gather the key value cost less. + q_partition_spec = P("dp", None, ("axis", "sp"), None) + kv_partition_spec = P("dp", None, None, None) + + # ALWAYS use shard_map. The partition_spec will control the behavior. + sharded_fn = shard_map( + _attention_on_slices, + mesh=mesh, + in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec), + out_specs=q_partition_spec, + check_rep=False, + ) + out = sharded_fn(query, key, value) + out = jax.lax.with_sharding_constraint(out, P("dp", None, ("axis", "sp"), None)) + return out + + +def main(): + query = jnp.ones((1, 40, 75600, 128)) + key = jnp.ones((1, 40, 75600, 128)) + value = jnp.ones((1, 40, 75600, 128)) + + bqsizes = (1512,) + + # bqsizes = (600, 630, 675, 700, 720, 756, 840, 900, 945, 1008, 1050, 1080, 1200, 1260, 1350, 1400, 1512, 1575, 1680, 1800, 1890, 2100, 2160, 2520, 2700, 2800, 3024, 3150, 3600, 3780, 4200) + bqsizes = range(2560, 4096, 256) + bkvsizes = range(2560, 4096, 256) + bkvcomputesizes = range(256, 4096, 256) + # bkvcomputesinizes = range(64, 4096, 64) + bkvcomputesinizes = range(256, 4096, 256) + + # bqsizes = list(range(512, 4096, 128)) + # bkvsizes = (3072,) + # bkvcomputesizes = (1024,) + + # BQSIZE = 2816 # 2240 # 3024 #2520 + # BKVSIZE = 3840 + # BKVCOMPUTESIZE = 256 + + # bqsizes = (512,) + # bkvsizes = (2048,) + # bkvcomputesizes = (256,) + + tp_dim = jax.device_count() + dp_dim = 1 + sp_dim = 1 + print("sp, bqsize, bkvsize, bkvcomputesize, time (s), padded_key_size") + while tp_dim >= 1: + mesh_devices = mesh_utils.create_device_mesh( + (tp_dim, dp_dim, sp_dim), allow_split_physical_axes=True + ) + mesh = Mesh(mesh_devices, ("axis", "dp", "sp")) + + query = jax.device_put( + query, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None)) + ) + key = jax.device_put( + key, NamedSharding(mesh, 
P("dp", None, ("axis", "sp"), None)) + ) + value = jax.device_put( + value, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None)) + ) + with mesh: + for bqsize in bqsizes: + for bkvsize in bkvsizes: + for bkvcomputesize in bkvcomputesizes: + for bkvcomputesinize in bkvcomputesinizes: + if ( + bkvsize < bkvcomputesize + or bkvsize % bkvcomputesize != 0 + ): + continue + + if ( + bkvcomputesize < bkvcomputesinize + or bkvcomputesize % bkvcomputesinize != 0 + ): + continue + + try: + # pad key value + def pad_to_multiple(x, multiple, axis): + # Pad in kernel + return x + seq_len = x.shape[axis] + pad_len = (multiple - seq_len % multiple) % multiple + if pad_len == 0: + return x + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = (0, pad_len) + return jnp.pad(x, pad_width) + + padded_query = pad_to_multiple(query, bqsize, axis=2) + padded_key = pad_to_multiple(key, bkvsize, axis=2) + padded_value = pad_to_multiple(value, bkvsize, axis=2) + + jax.block_until_ready( + _tpu_splash_attention( + padded_query, + padded_key, + padded_value, + mesh, + bqsize, + bkvsize, + bkvcomputesize, + bkvcomputesinize, + ) + ) + + start = time.perf_counter() + jax.block_until_ready( + _tpu_splash_attention( + padded_query, + padded_key, + padded_value, + mesh, + bqsize, + bkvsize, + bkvcomputesize, + bkvcomputesinize, + ) + ) + end = time.perf_counter() + print( + f"{sp_dim=}, {bqsize}, {bkvsize}, {bkvcomputesize}, {bkvcomputesinize}, {end - start}, {padded_key.shape[2]}" + ) + except KeyboardInterrupt: + raise + except Exception: + # raise + continue + break + # smaller sp_dim better + tp_dim //= 2 + sp_dim *= 2 + + +if __name__ == "__main__": + main() diff --git a/exp/custom_splash_attention.py b/exp/custom_splash_attention.py new file mode 100644 index 000000000000..c3fad1e42865 --- /dev/null +++ b/exp/custom_splash_attention.py @@ -0,0 +1,790 @@ +import functools +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P +import 
jax +import math +import jax.numpy as jnp +import numpy as np +import dataclasses +import enum +from typing import Any +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax import lax + +partial = functools.partial +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) +NUM_LANES = 128 +NUM_SUBLANES = 8 +NN_DIM_NUMBERS = (((1,), (0,)), ((), ())) +NT_DIM_NUMBERS = (((1,), (1,)), ((), ())) + +_LOG2_E = 1.44269504 +_LOG2_E_INV = 1 / _LOG2_E + + +class _QKVLayout(enum.IntEnum): + HEAD_DIM_MINOR = enum.auto() + SEQ_MINOR = enum.auto() + + +def _from_head_minor(vals: tuple[Any, ...], layout: _QKVLayout): + if layout == _QKVLayout.HEAD_DIM_MINOR: + return vals + return (*vals[:-2], vals[-1], vals[-2]) + + +def exp2(x: jax.Array) -> jax.Array: + return jnp.power(2.0, x) + + +@dataclasses.dataclass(frozen=True, slots=True) +class _BlockSizes: + block_q: int + block_kv: int + block_kv_compute: int | None = None + q_layout: _QKVLayout = _QKVLayout.HEAD_DIM_MINOR + k_layout: _QKVLayout = _QKVLayout.HEAD_DIM_MINOR + v_layout: _QKVLayout = _QKVLayout.HEAD_DIM_MINOR + + def __post_init__(self): + if self.block_kv_compute is None: + object.__setattr__(self, "block_kv_compute", self.block_kv) + + +def _flash_attention_kernel( + q_ref, + k_ref, + v_ref, + m_scratch_ref, + l_scratch_ref, + o_scratch_ref, + o_ref, + *, + mask_value: float, + grid_width: int, + bq: int, + bkv: int, + bkv_compute: int, + bkv_compute_in: int, + head_dim_v: int, +): + float32 = jnp.float32 + head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) + if rem != 0: + raise NotImplementedError( + f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}" + ) + # head_dim_v_repeats, rem = divmod(head_dim_v, NUM_LANES) + # if rem != 0: + # raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_LANES}") + + h, i, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) + + @pl.when(j == 0) + def init(): + o_scratch_ref[...] 
= jnp.zeros_like(o_scratch_ref) + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + + ### + + # # with jax.named_scope("qk"): + # q = q_ref[...] + # k = k_ref[...] + + # qk_all = lax.dot_general(q, k, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk_all.shape == (bq, bkv) + + # step = bkv_compute + # assert step % NUM_LANES == 0 + # assert bkv % step == 0 + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + # for i in range(0, bkv, step): + # qk = qk_all[:,i:i+step] + # # qk = lax.dot_general(k[i:i+step], q, NT_DIM_NUMBERS, preferred_element_type=float32) + # # assert qk.shape == (step, bq) + # # with jax.named_scope("qk"): + # assert m_prev.shape == (bq, NUM_LANES) + # assert l_prev.shape == (bq, NUM_LANES) + + # # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=1)[:, None] + # assert m_curr.shape == (bq, 1) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (bq, NUM_LANES) + + # bkv_repeats, rem = divmod(bkv_compute, NUM_LANES) + # if rem != 0: + # raise NotImplementedError( + # f"{bkv_compute=} should be a multiple of {NUM_LANES}" + # ) + + # s_curr = jnp.exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1)) + # # assert s_curr.shape == (bq, bkv_compute) + # # # with jax.named_scope("qk_exp"): + # # s_diff = qk - m_next[:,0:1] + # # s_curr = jnp.exp(s_diff) + # assert s_curr.shape == (bq, step) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=1, keepdims=True) + # assert l_curr.shape == (bq, 1) + + # # with jax.named_scope("qk_alpha"): + # m_diff = m_prev - m_next + # alpha = jnp.exp(m_diff) + + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # # with jax.named_scope("qkv"): + # v = v_ref[i:i+step].astype(float32) + # sv_dims = (((1,), (0,)), ((), ())) + # o_curr = lax.dot_general(s_curr, v, sv_dims) + # # alpha_o = alpha[:, 
0:1] + # alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1) + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + ### + + # with jax.named_scope("qk"): + # q = q_ref[...] + # k = k_ref[...] + # qk_all = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk_all.shape == (bkv, bq) + # # qk_all = lax.dot_general(q, k, NT_DIM_NUMBERS, preferred_element_type=float32) + # # assert qk_all.shape == (bq, bkv) + + # step = bkv_compute + # assert step % NUM_SUBLANES == 0 + # assert bkv % step == 0 + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + # for i in range(0, bkv, step): + # qk = qk_all[i:i+step] + # # qk = lax.dot_general(k[i:i+step], q, NT_DIM_NUMBERS, preferred_element_type=float32) + # # assert qk.shape == (step, bq) + # # with jax.named_scope("qk"): + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # assert m_curr.shape == (1, bq) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_diff = qk - m_next[0:1] + # s_curr = jnp.exp(s_diff) + # assert s_curr.shape == (step, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # m_diff = m_prev - m_next + # alpha = jnp.exp(m_diff) + + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # # with jax.named_scope("qkv"): + # v = v_ref[i:i+step].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) # (head_dim, bk) @ (bk, bq) + # alpha_o = alpha[0:1, ...] 
+ # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + ### + + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + # for kv_compute_index in range(0, (bkv // bkv_compute)): + # # with jax.named_scope("qk"): + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + + # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # assert m_curr.shape == (1, bq) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_curr = jnp.exp(qk - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # with jax.named_scope("qkv"): + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + ### + + # assert bkv % bkv_compute == 0 + # qk_next = None + # m_curr_next = None + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
+ # for kv_compute_index in range(0, (bkv // bkv_compute) + 1): + # # nonlocal qk_pre + # # with jax.named_scope("qk"): + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + # if kv_compute_index < (bkv // bkv_compute): + # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # assert m_curr.shape == (1, bq) + + # m_curr, m_curr_next = m_curr_next, m_curr + # qk_next, qk = qk, qk_next + # if kv_compute_index == 0: + # continue + + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_curr = jnp.exp(qk - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # slice_k = pl.ds((kv_compute_index - 1) * bkv_compute, bkv_compute) + + # with jax.named_scope("qkv"): + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + ### + + # assert bkv % bkv_compute == 0 + # qk_next = None + # # def body(kv_compute_index, _): + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
+ # for kv_compute_index in range(0, (bkv // bkv_compute) + 1): + # # nonlocal qk_pre + # # with jax.named_scope("qk"): + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + # if kv_compute_index < (bkv // bkv_compute): + # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # qk_next, qk = qk, qk_next + # if kv_compute_index == 0: + # continue + + # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # assert m_curr.shape == (1, bq) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_curr = jnp.exp(qk - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # slice_k = pl.ds((kv_compute_index - 1) * bkv_compute, bkv_compute) + + # with jax.named_scope("qkv"): + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + ### + + # assert bkv % bkv_compute == 0 + # qk_next = None + # s_curr_next = None + # alpha_next = None + # # def body(kv_compute_index, _): + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
+ # for kv_compute_index in range(0, (bkv // bkv_compute) + 2): + # # nonlocal qk_pre + # # with jax.named_scope("qk"): + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + # if kv_compute_index < (bkv // bkv_compute): + # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # qk_next, qk = qk, qk_next + # if kv_compute_index == 0: + # continue + # if kv_compute_index < (bkv // bkv_compute) + 1: + # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # assert m_curr.shape == (1, bq) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_curr = jnp.exp(qk - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_prev, l_prev = m_next, l_next + + # s_curr, s_curr_next = s_curr_next, s_curr + # alpha, alpha_next = alpha_next, alpha + # if kv_compute_index == 1: + # continue + + # slice_k = pl.ds((kv_compute_index - 2) * bkv_compute, bkv_compute) + + # with jax.named_scope("qkv"): + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + # m_scratch_ref[...], l_scratch_ref[...] 
= m_next, l_next + + ### + + def body(kv_compute_index, _): + + # # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_max"): + # m_curr_list = [] + # s_curr_list = [] + # step = bkv_compute_in + # assert qk.shape[0] % step == 0 + # for i in range(0, qk.shape[0], step): + # m_curr = qk[i:i+step].max(axis=0)[None, :] + # # m_curr = qk[0:1] + # assert m_curr.shape == (1, bq) + # m_curr_list.append(m_curr) + + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # s_curr = jnp.exp(qk[i:i+step] - m_curr[0:1]) + # # assert s_curr.shape == (bkv_compute, bq) + # s_curr_list.append(s_curr) + + # m_curr = jnp.concatenate(m_curr_list, axis=0) + # m_curr = jnp.exp(m_curr - m_next[0:1]) + + # for i in range(len(s_curr_list)): + # s_curr_list[i] = s_curr_list[i] * m_curr[i:i+1] + + # s_curr = jnp.concatenate(s_curr_list, axis=0) + # assert s_curr.shape == (bkv_compute, bq) + + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + ### + + # # with jax.named_scope("qk"): + # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
+ # assert m_prev.shape == (NUM_SUBLANES, bq) + # assert l_prev.shape == (NUM_SUBLANES, bq) + + # q = q_ref[...] + # k = k_ref[slice_k, :] + # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + # assert qk.shape == (bkv_compute, bq) + + # # with jax.named_scope("softmax"): + # # with jax.named_scope("qk_max"): + # m_curr = qk.max(axis=0)[None, :] + # # m_curr = qk[-1:, :] + # # m_curr = qk[0:1, :] + # # m_ub = jnp.zeros((1,bq), dtype=float32) + # assert m_curr.shape == (1, bq) + # # with jax.named_scope("qk_maximum"): + # m_next = jnp.maximum(m_prev, m_curr) + # assert m_next.shape == (NUM_SUBLANES, bq) + + # # with jax.named_scope("qk_exp"): + # s_curr = jnp.exp(qk - m_next[0:1]) + # # s_curr = jnp.exp(qk - m_prev[0:1]) + # # s_curr = s_curr * jnp.exp(m_prev - m_next)[0:1] + # assert s_curr.shape == (bkv_compute, bq) + + # # with jax.named_scope("qk_sum"): + # l_curr = s_curr.sum(axis=0, keepdims=True) + # assert l_curr.shape == (1, bq) + + # # with jax.named_scope("qk_alpha"): + # alpha = jnp.exp(m_prev - m_next) + # l_next = l_curr + alpha * l_prev + # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + # # with jax.named_scope("qkv"): + # v = v_ref[slice_k, :].astype(float32) + # sv_dims = (((0,), (0,)), ((), ())) + # o_curr = lax.dot_general(v, s_curr, sv_dims) + # alpha_o = alpha[0:1, ...] + # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr + + ### + + # # with jax.named_scope("qk"): + slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + assert m_prev.shape == (NUM_SUBLANES, bq) + assert l_prev.shape == (NUM_SUBLANES, bq) + + q = q_ref[...] 
+ k = k_ref[slice_k, :] + qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + assert qk.shape == (bkv_compute, bq) + + # with jax.named_scope("softmax_qkv"): + o_prev = o_scratch_ref[:] + + v = v_ref[slice_k, :].astype(float32) + step = bkv_compute_in + assert qk.shape[0] % step == 0 + for i in range(0, qk.shape[0], step): + m_curr = qk[i : i + step].max(axis=0)[None, :] + assert m_curr.shape == (1, bq) + + m_next = jnp.maximum(m_prev, m_curr) + assert m_next.shape == (NUM_SUBLANES, bq) + + # the exp two ops: vmul and vpow. Fuse the vmul outside of kernel. + s_curr = exp2(qk[i : i + step] - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + l_curr = s_curr.sum(axis=0, keepdims=True) + assert l_curr.shape == (1, bq) + + alpha = jnp.exp2(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general(v[i : i + step], s_curr, sv_dims) + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev = m_next + l_prev = l_next + + m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + o_scratch_ref[:] = o_prev + + ### + + lax.fori_loop(0, (bkv // bkv_compute), body, None, unroll=True) + + @pl.when(j == grid_width - 1) + def end(): + l = l_scratch_ref[...] + l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=0) + # l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=1) + o_ref[...] = (o_scratch_ref[...] 
def __splash_attention_forward(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    block_sizes: _BlockSizes,
    bkv_compute_in: int,
    interpret: bool = False,
):
    """Single-device splash-attention forward pass via a Pallas TPU kernel.

    Args:
        q: queries, shaped (num_q_heads, q_seq_len, head_dim_qk).
        k: keys, shaped (num_kv_heads, kv_seq_len, head_dim_qk).
        v: values, shaped (num_kv_heads, kv_seq_len, head_dim_v).
        block_sizes: tiling configuration (block_q / block_kv / block_kv_compute).
        bkv_compute_in: inner kv tile size forwarded to `_flash_attention_kernel`.
        interpret: run the kernel in Pallas interpret mode (debugging only).

    Returns:
        Attention output shaped (num_q_heads, head_dim_v, q_seq_len).
        NOTE: the last two axes are transposed relative to `q`; callers
        (see `_tpu_splash_attention`) swap them back with `jnp.swapaxes`.
    """
    num_q_heads, q_seq_len, head_dim_qk = q.shape
    head_dim_v = v.shape[-1]
    bq, bkv = block_sizes.block_q, block_sizes.block_kv
    bkv_compute = block_sizes.block_kv_compute
    num_kv_heads = k.shape[0]
    kv_seq_len = k.shape[1]
    # GQA/MQA: several query heads may share a single kv head.
    q_heads_per_kv_head = num_q_heads // num_kv_heads

    def q_index_map(h, i, j, *_):
        return (h, i, 0)

    def out_index_map(h, i, j, *_):
        # Output tiles are written transposed: (head, head_dim, q_block).
        return h, 0, i

    def k_index_map(h, i, j, *_):
        return (h // q_heads_per_kv_head, j, 0)

    def v_index_map(h, i, j, *_):
        return (h // q_heads_per_kv_head, j, 0)

    in_specs = [
        pl.BlockSpec((None, bq, head_dim_qk), q_index_map),
        pl.BlockSpec((None, bkv, head_dim_qk), k_index_map),
        pl.BlockSpec((None, bkv, head_dim_v), v_index_map),
    ]
    # First three outputs are per-block softmax scratch (running max, running
    # sum, unnormalized accumulator); only the last output is returned.
    out_shapes = [
        jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32),
        jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32),
        jax.ShapeDtypeStruct((head_dim_v, bq), jnp.float32),
        jax.ShapeDtypeStruct((num_q_heads, head_dim_v, q_seq_len), q.dtype),
    ]
    out_specs = [
        pl.BlockSpec((NUM_SUBLANES, bq), lambda *_: (0, 0)),
        pl.BlockSpec((NUM_SUBLANES, bq), lambda *_: (0, 0)),
        pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)),
        pl.BlockSpec((None, head_dim_v, bq), out_index_map),
    ]
    grid_width = kv_seq_len // bkv
    grid = (num_q_heads, q_seq_len // bq, grid_width)

    all_out = pl.pallas_call(
        partial(
            _flash_attention_kernel,
            mask_value=DEFAULT_MASK_VALUE,
            grid_width=grid_width,
            bq=bq,
            bkv=bkv,
            bkv_compute=bkv_compute,
            bkv_compute_in=bkv_compute_in,
            head_dim_v=head_dim_v,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=in_specs,
            out_specs=out_specs,
            grid=grid,
        ),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=("parallel", "arbitrary", "arbitrary"),
            flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True},
        ),
        out_shape=out_shapes,
        interpret=interpret,
    )(q, k, v)
    # Discard the scratch outputs; return only the attention result.
    return all_out[-1]


def make_splash_mha(
    block_sizes: _BlockSizes,
    bkv_compute_in: int,
    interpret: bool = False,
):
    """Bind kernel configuration and return a `(q, k, v) -> out` attention fn."""

    def _splash_attention(q: jax.Array, k: jax.Array, v: jax.Array):
        return __splash_attention_forward(
            q, k, v, block_sizes, bkv_compute_in, interpret
        )

    return _splash_attention


# Default tile sizes for the benchmark below.
# (The original assigned BKVCOMPUTESIZE = 1024 and immediately overwrote it
# with 256; the dead assignment is removed.)
BQSIZE = BKVSIZE = 1024
BKVCOMPUTESIZE = 256

# Lazily-built, cached jitted shard_map'ed attention function.
sharded_fn = None


def _tpu_splash_attention(query, key, value, mesh):
    """Sharded splash attention over a ("dp", "axis", ...) mesh.

    Args:
        query/key/value: (batch, num_heads, seq_len, head_dim) arrays.
        mesh: the jax Mesh to shard over; heads are split along "axis".

    Returns:
        Attention output with the same (batch, num_heads, seq_len, head_dim)
        layout as `query`, sharding-constrained to P("dp", None, "axis", None).
    """
    global sharded_fn

    def _attention_on_slices(q, k, v):
        # Fold the 1/sqrt(d) softmax scale into q before the kernel.
        scale_factor = 1.0 / math.sqrt(q.shape[-1])
        q = q * scale_factor

        def pad_to_multiple(x, multiple, axis):
            # Zero-pad `axis` of x up to the next multiple; also return the
            # original length so the padding can be stripped afterwards.
            seq_len = x.shape[axis]
            pad_len = (multiple - seq_len % multiple) % multiple
            if pad_len == 0:
                return x, seq_len
            pad_width = [(0, 0)] * x.ndim
            pad_width[axis] = (0, pad_len)
            return jnp.pad(x, pad_width), seq_len

        def kernel_3d(q_3d, k_3d, v_3d):
            # Pad sequence lengths so they tile evenly into kernel blocks.
            q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1)
            k_3d_padded, _ = pad_to_multiple(k_3d, BKVSIZE, axis=1)
            v_3d_padded, _ = pad_to_multiple(v_3d, BKVSIZE, axis=1)
            padded_q_seq_len = q_3d_padded.shape[1]
            padded_kv_seq_len = k_3d_padded.shape[1]

            block_sizes = _BlockSizes(
                block_q=min(BQSIZE, padded_q_seq_len),
                block_kv=min(BKVSIZE, padded_kv_seq_len),
                block_kv_compute=min(BKVCOMPUTESIZE, padded_kv_seq_len),
            )
            # BUG FIX: make_splash_mha requires `bkv_compute_in`; the original
            # call omitted it and raised TypeError the first time this traced.
            splash_kernel = make_splash_mha(
                block_sizes=block_sizes, bkv_compute_in=BKVCOMPUTESIZE
            )
            out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded)
            # Kernel emits (heads, head_dim, seq); restore (heads, seq, head_dim)
            # and strip the query padding.
            out = jnp.swapaxes(out, 1, 2)
            return out[:, :q_orig_len, ...]

        vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0)
        return vmapped_kernel(q, k, v)

    if sharded_fn is None:
        q_partition_spec = P("dp", "axis", None, None)
        kv_partition_spec = P("dp", "axis", None, None)
        sharded_fn = jax.jit(
            shard_map(
                _attention_on_slices,
                mesh=mesh,
                in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec),
                out_specs=q_partition_spec,
                check_rep=False,
            )
        )
    out = sharded_fn(query, key, value)
    return jax.lax.with_sharding_constraint(out, P("dp", None, "axis", None))


if __name__ == "__main__":
    # Micro-benchmark: 75600-token self-attention across all local devices.
    shape = (1, 40, 75600, 128)
    q = jnp.arange(np.prod(shape), dtype=jnp.bfloat16).reshape(*shape)
    k = jnp.arange(np.prod(shape), dtype=jnp.bfloat16).reshape(*shape)
    v = jnp.arange(np.prod(shape), dtype=jnp.bfloat16).reshape(*shape)

    mesh = jax.make_mesh((len(jax.devices()), 1, 1), ("axis", "dp", "sp"))
    q = jax.device_put(q, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None)))
    k = jax.device_put(k, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None)))
    v = jax.device_put(v, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None)))

    # Warmup pass (pays compilation cost).
    with mesh:
        output = _tpu_splash_attention(q, k, v, mesh)
        output.block_until_ready()

    # One traced pass for the profiler.
    with mesh:
        with jax.profiler.trace("/dev/shm/tensorboard"):
            output = _tpu_splash_attention(q, k, v, mesh)
            output.block_until_ready()

    import time

    # Timed benchmark: report mean seconds per call over num_time runs.
    with mesh:
        num_time = 50
        start_time = time.time()
        for _ in range(num_time):
            output = _tpu_splash_attention(q, k, v, mesh)
            output.block_until_ready()
        end_time = time.time()
        print(f"{(end_time - start_time) / num_time}")
("720*1280", "1280*720", "480*832", "832*480"), + "ti2v-5B": ("704*1280", "1280*704"), + "s2v-14B": ( + "720*1280", + "1280*720", + "480*832", + "832*480", + "1024*704", + "704*1024", + "704*1280", + "1280*704", + ), + "animate-14B": ("720*1280", "1280*720"), +} + +DEFAULT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." +DEFAULT_NEG_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +DEFAULT_IMAGE_PATH = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG" +DEFAULT_PROFILE_OUT_PATH = "/tmp/wan_prof" + + +# fmt: off +TEXT_ENCODER_SHARDINGS = { +'shared.weight': (('dp','tp'),), # (torch.Size([256384, 4096]), torch.bfloat16) +'encoder.block.*.layer.*.SelfAttention.q.weight': (('dp','tp'),), # (torch.Size([4096, 4096]), torch.bfloat16) +'encoder.block.*.layer.*.SelfAttention.k.weight': (('dp','tp'),), # (torch.Size([4096, 4096]), torch.bfloat16) +'encoder.block.*.layer.*.SelfAttention.v.weight': (('dp','tp'),), # (torch.Size([4096, 4096]), torch.bfloat16) +'encoder.block.*.layer.*.SelfAttention.o.weight': (None, ('dp','tp'),), # (torch.Size([4096, 4096]), torch.bfloat16) +# 'encoder.block.*.layer.*.SelfAttention.relative_attention_bias.weight': (), # (torch.Size([32, 64]), torch.bfloat16) +# 'encoder.block.*.layer.*.layer_norm.weight': (), # (torch.Size([4096]), torch.bfloat16) +'encoder.block.*.layer.*.DenseReluDense.wi_0.weight': (('dp','tp'),), # (torch.Size([10240, 4096]), 
torch.bfloat16) +'encoder.block.*.layer.*.DenseReluDense.wi_1.weight': (('dp','tp'),), # (torch.Size([10240, 4096]), torch.bfloat16) +'encoder.block.*.layer.*.DenseReluDense.wo.weight': (None, ('dp','tp'),), # (torch.Size([4096, 10240]), torch.bfloat16) +# 'encoder.final_layer_norm.weight': (), # (torch.Size([4096]), torch.bfloat16) +} + +TRANSFORMER_SHARDINGS = { +# 'scale_shift_table': (), # (torch.Size([1, 2, 5120]), torch.float32) +# 'patch_embedding.weight': (), # (torch.Size([5120, 36, 1, 2, 2]), torch.bfloat16) +# 'patch_embedding.bias': (), # (torch.Size([5120]), torch.bfloat16) +'condition_embedder.time_embedder.linear_1.weight': ('tp',), # (torch.Size([5120, 256]), torch.float32) +'condition_embedder.time_embedder.linear_1.bias': ('tp',), # (torch.Size([5120]), torch.float32) +'condition_embedder.time_embedder.linear_2.weight': (None, 'tp',), # (torch.Size([5120, 5120]), torch.float32) +# 'condition_embedder.time_embedder.linear_2.bias': (), # (torch.Size([5120]), torch.float32) +# 'condition_embedder.time_proj.weight': (), # (torch.Size([30720, 5120]), torch.bfloat16) +# 'condition_embedder.time_proj.bias': (), # (torch.Size([30720]), torch.bfloat16) +'condition_embedder.text_embedder.linear_1.weight': ('tp',), # (torch.Size([5120, 4096]), torch.bfloat16) +'condition_embedder.text_embedder.linear_1.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'condition_embedder.text_embedder.linear_2.weight': (None, 'tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +# 'condition_embedder.text_embedder.linear_2.bias': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.scale_shift_table': (), # (torch.Size([1, 6, 5120]), torch.float32) +'blocks.*.attn1.to_q.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn1.to_q.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn1.to_k.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn1.to_k.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) 
+'blocks.*.attn1.to_v.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn1.to_v.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn1.to_out.*.weight': (None, 'tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +# 'blocks.*.attn1.to_out.*.bias': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.attn1.norm_q.weight': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.attn1.norm_k.weight': (), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn2.to_q.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn2.to_q.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn2.to_k.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn2.to_k.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn2.to_v.weight': ('tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +'blocks.*.attn2.to_v.bias': ('tp',), # (torch.Size([5120]), torch.bfloat16) +'blocks.*.attn2.to_out.*.weight': (None, 'tp',), # (torch.Size([5120, 5120]), torch.bfloat16) +# 'blocks.*.attn2.to_out.*.bias': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.attn2.norm_q.weight': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.attn2.norm_k.weight': (), # (torch.Size([5120]), torch.bfloat16) +# 'blocks.*.norm2.weight': (), # (torch.Size([5120]), torch.float32) +# 'blocks.*.norm2.bias': (), # (torch.Size([5120]), torch.float32) +'blocks.*.ffn.net.*.proj.weight': ('tp',), # (torch.Size([13824, 5120]), torch.bfloat16) +'blocks.*.ffn.net.*.proj.bias': ('tp',), # (torch.Size([13824]), torch.bfloat16) +'blocks.*.ffn.net.*.weight': (None, 'tp',), # (torch.Size([5120, 13824]), torch.bfloat16) +# 'blocks.*.ffn.net.*.bias': (), # (torch.Size([5120]), torch.bfloat16) +# 'proj_out.weight': (), # (torch.Size([64, 5120]), torch.bfloat16) +# 'proj_out.bias': (), # (torch.Size([64]), torch.bfloat16) +# 'rope.freqs_cos': (), # (torch.Size([1024, 128]), torch.float32) +# 'rope.freqs_sin': 
(), # (torch.Size([1024, 128]), torch.float32) +} + +VAE_SHARDINGS = { +'encoder.conv_in.weight': (('dp','tp'),), # (torch.Size([96, 3, 3, 3, 3]), torch.bfloat16) +'encoder.conv_in.bias': (('dp','tp'),), # (torch.Size([96]), torch.bfloat16) +# 'encoder.down_blocks.*.norm1.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'encoder.down_blocks.*.conv1.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'encoder.down_blocks.*.conv1.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'encoder.down_blocks.*.norm2.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'encoder.down_blocks.*.conv2.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'encoder.down_blocks.*.conv2.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +'encoder.down_blocks.*.resample.*.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3]), torch.bfloat16) +'encoder.down_blocks.*.resample.*.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +'encoder.down_blocks.*.conv_shortcut.weight': (('dp','tp'),), # (torch.Size([384, 192, 1, 1, 1]), torch.bfloat16) +'encoder.down_blocks.*.conv_shortcut.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +'encoder.down_blocks.*.time_conv.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 1, 1]), torch.bfloat16) +'encoder.down_blocks.*.time_conv.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'encoder.mid_block.attentions.*.norm.gamma': (), # (torch.Size([384, 1, 1]), torch.bfloat16) +'encoder.mid_block.attentions.*.to_qkv.weight': (('dp','tp'),), # (torch.Size([1152, 384, 1, 1]), torch.bfloat16) +'encoder.mid_block.attentions.*.to_qkv.bias': (('dp','tp'),), # (torch.Size([1152]), torch.bfloat16) +'encoder.mid_block.attentions.*.proj.weight': (None, ('dp','tp'),), # (torch.Size([384, 384, 1, 1]), torch.bfloat16) +# 'encoder.mid_block.attentions.*.proj.bias': (), # (torch.Size([384]), torch.bfloat16) +# 
'encoder.mid_block.resnets.*.norm1.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'encoder.mid_block.resnets.*.conv1.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'encoder.mid_block.resnets.*.conv1.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'encoder.mid_block.resnets.*.norm2.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'encoder.mid_block.resnets.*.conv2.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'encoder.mid_block.resnets.*.conv2.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'encoder.norm_out.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'encoder.conv_out.weight': (None, ('dp','tp'),), # (torch.Size([32, 384, 3, 3, 3]), torch.bfloat16) +# 'encoder.conv_out.bias': (), # (torch.Size([32]), torch.bfloat16) +# 'quant_conv.weight': (), # (torch.Size([32, 32, 1, 1, 1]), torch.bfloat16) +# 'quant_conv.bias': (), # (torch.Size([32]), torch.bfloat16) +# 'post_quant_conv.weight': (), # (torch.Size([16, 16, 1, 1, 1]), torch.bfloat16) +# 'post_quant_conv.bias': (), # (torch.Size([16]), torch.bfloat16) +'decoder.conv_in.weight': (('dp','tp'),), # (torch.Size([384, 16, 3, 3, 3]), torch.bfloat16) +'decoder.conv_in.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'decoder.mid_block.attentions.*.norm.gamma': (), # (torch.Size([384, 1, 1]), torch.bfloat16) +'decoder.mid_block.attentions.*.to_qkv.weight': (('dp','tp'),), # (torch.Size([1152, 384, 1, 1]), torch.bfloat16) +'decoder.mid_block.attentions.*.to_qkv.bias': (('dp','tp'),), # (torch.Size([1152]), torch.bfloat16) +'decoder.mid_block.attentions.*.proj.weight': (None, ('dp','tp'),), # (torch.Size([384, 384, 1, 1]), torch.bfloat16) +# 'decoder.mid_block.attentions.*.proj.bias': (), # (torch.Size([384]), torch.bfloat16) +# 'decoder.mid_block.resnets.*.norm1.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'decoder.mid_block.resnets.*.conv1.weight': 
(('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'decoder.mid_block.resnets.*.conv1.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'decoder.mid_block.resnets.*.norm2.gamma': (), # (torch.Size([384, 1, 1, 1]), torch.bfloat16) +'decoder.mid_block.resnets.*.conv2.weight': (('dp','tp'),), # (torch.Size([384, 384, 3, 3, 3]), torch.bfloat16) +'decoder.mid_block.resnets.*.conv2.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'decoder.up_blocks.*.resnets.*.norm1.gamma': (), # (torch.Size([96, 1, 1, 1]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv1.weight': (('dp','tp'),), # (torch.Size([96, 96, 3, 3, 3]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv1.bias': (('dp','tp'),), # (torch.Size([96]), torch.bfloat16) +# 'decoder.up_blocks.*.resnets.*.norm2.gamma': (), # (torch.Size([96, 1, 1, 1]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv2.weight': (('dp','tp'),), # (torch.Size([96, 96, 3, 3, 3]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv2.bias': (('dp','tp'),), # (torch.Size([96]), torch.bfloat16) +'decoder.up_blocks.*.upsamplers.*.resample.*.weight': (('dp','tp'),), # (torch.Size([96, 192, 3, 3]), torch.bfloat16) +'decoder.up_blocks.*.upsamplers.*.resample.*.bias': (('dp','tp'),), # (torch.Size([96]), torch.bfloat16) +'decoder.up_blocks.*.upsamplers.*.time_conv.weight': (('dp','tp'),), # (torch.Size([768, 384, 3, 1, 1]), torch.bfloat16) +'decoder.up_blocks.*.upsamplers.*.time_conv.bias': (('dp','tp'),), # (torch.Size([768]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv_shortcut.weight': (('dp','tp'),), # (torch.Size([384, 192, 1, 1, 1]), torch.bfloat16) +'decoder.up_blocks.*.resnets.*.conv_shortcut.bias': (('dp','tp'),), # (torch.Size([384]), torch.bfloat16) +# 'decoder.norm_out.gamma': (), # (torch.Size([96, 1, 1, 1]), torch.bfloat16) +'decoder.conv_out.weight': (None, ('dp','tp')), # (torch.Size([3, 96, 3, 3, 3]), torch.bfloat16) +# 'decoder.conv_out.bias': (), # 
(torch.Size([3]), torch.bfloat16) +} +# fmt: on + +BQSIZE = 3328 +BKVSIZE = 2816 +BKVCOMPUTESIZE = 256 +BKVCOMPUTEINSIZE = 256 + + +@contextmanager +def perf_time(name: str): + print(f"{name} start") + start = time.perf_counter() + yield + end = time.perf_counter() + print(f"{name}: {end - start: .6f}s") + + +def _print_weights(module): + def make_key(name): + return re.sub(r"\.\d+\.", ".*.", name) + + all_buffers = dict(module.named_parameters()) + all_buffers.update(module.named_buffers()) + result = {} + for k, v in all_buffers.items(): + result[make_key(k)] = (v.shape, v.dtype) + print("{") + for k, v in result.items(): + print(f"'{k}': (), # {v}") + print("}") + + +def _torch_conv2d( + input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, *, env +): + jinput, jweight, jbias = env.t2j_iso((input, weight, bias)) + res = jaten._aten_conv2d(jinput, jweight, jbias, stride, padding, dilation, groups) + return env.j2t_iso(res) + + +def _overide_op_definition(env, op_to_override, op_impl): + # Workaround for the function lack is_view_op argument + # env.override_op_definition(op_to_override, op_impl) + env._ops[op_to_override] = ops_registry.Operator( + op_to_override, + op_impl, + is_jax_function=False, + is_user_defined=True, + needs_env=False, + is_view_op=False, + ) + + +def _shard_weight_dict(weight_dict, sharding_dict, mesh): + result = {} + for k, v in weight_dict.items(): + if isinstance(v, torch.Tensor): + v = v.to("jax") + for target, sharding in sharding_dict.items(): + if re.fullmatch(target, k) is not None: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P(*sharding))) + break + else: + # replicate + v.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + + result[k] = v + return result + + +def _move_module(env, module): + with jax.default_device("cpu"): + state_dict = module.state_dict() + state_dict = env.to_xla(state_dict) + module.load_state_dict(state_dict, assign=True) + + +### Flash Attention + + +# Helper to pad to next 
multiple +def pad_to_multiple(x, multiple, axis): + seq_len = x.shape[axis] + pad_len = (multiple - seq_len % multiple) % multiple + if pad_len == 0: + return x, seq_len + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = (0, pad_len) + return jnp.pad(x, pad_width), seq_len + + +def _tpu_custom_attention(query, key, value, mesh, scale=None): + # The function that will be sharded across devices. + def _attention_on_slices(q, k, v): + # Scale the query tensor. This happens on each device with its slice of data. + scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale + # fuse the ops of exp in softmax here + _LOG2_E = 1.44269504 + q = q * scale_factor * _LOG2_E + + # This function operates on a single item from the batch. + def kernel_3d(q_3d, k_3d, v_3d): + q_seq_len = q_3d.shape[1] + kv_seq_len = k_3d.shape[1] + num_heads_on_device = q_3d.shape[0] + + # self attention + if k_3d.shape[1] > 10000: + # Pad q, k, v to next multiple of BQSIZE/BKVSIZE + q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1) + k_3d_padded, k_orig_len = pad_to_multiple(k_3d, BKVSIZE, axis=1) + v_3d_padded, v_orig_len = pad_to_multiple(v_3d, BKVSIZE, axis=1) + else: + # do not padding on kv in cross attention. kv length is 512 + q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1) + k_3d_padded, k_orig_len = k_3d, k_3d.shape[1] + v_3d_padded, v_orig_len = v_3d, v_3d.shape[1] + + padded_q_seq_len = q_3d_padded.shape[1] + padded_kv_seq_len = k_3d_padded.shape[1] + + block_sizes = splash_attention.BlockSizes( + block_q=min(BQSIZE, padded_q_seq_len), + block_kv=min(BKVSIZE, padded_kv_seq_len), + block_kv_compute=min(BKVCOMPUTESIZE, padded_kv_seq_len), + ) + splash_kernel = custom_splash_attention.make_splash_mha( + block_sizes=block_sizes, bkv_compute_in=BKVCOMPUTEINSIZE + ) + out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded).astype( + q_3d_padded.dtype + ) + # Remove padding if any + out = jnp.swapaxes(out, 1, 2) + return out[:, :q_orig_len, ...] 
+ + # Map the kernel over the batch dimension. + vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0) + return vmapped_kernel(q, k, v) + + print(f"[DEBUG] {query.shape=}, {key.shape=}") + if key.shape[0] > 1: + dp_mesh_key = "dp" + remain_mesh_key = ("tp",) + else: + dp_mesh_key = None + remain_mesh_key = ("dp", "tp") + print(f"[DEBUG] {dp_mesh_key=}, {remain_mesh_key=}") + remain_devices_prod = 1 + for d in remain_mesh_key: + remain_devices_prod *= mesh.axis_sizes[mesh.axis_names.index(d)] + + q_num_head = query.shape[1] + q_seq_len = query.shape[2] + kv_num_head = key.shape[1] + kv_seq_len = key.shape[2] + # Sharded case for Transformer. Split along the heads axis. + # Attn1 self attention, key length is long. + if ( + kv_seq_len > 10000 + and kv_num_head % remain_devices_prod == 0 + and q_num_head % remain_devices_prod == 0 + ): + print("[DEBUG] cp") + q_partition_spec = P(dp_mesh_key, remain_mesh_key, None, None) + kv_partition_spec = P(dp_mesh_key, remain_mesh_key, None, None) + else: + print("[DEBUG] sp") + if q_seq_len % remain_devices_prod != 0: + print( + f"[DEBUG] padding query for sp to be divided by {remain_devices_prod}" + ) + query, _ = pad_to_multiple(query, remain_devices_prod, axis=2) + + # Attn2 which is cross attention, kv sequence is shorter. All gather the key value cost less. + q_partition_spec = P(dp_mesh_key, None, remain_mesh_key, None) + kv_partition_spec = P(dp_mesh_key, None, None, None) + + # ALWAYS use shard_map. The partition_spec will control the behavior. 
+ sharded_fn = jax.shard_map( + _attention_on_slices, + mesh=mesh, + in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec), + out_specs=q_partition_spec, + check_vma=False, + ) + query = jax.lax.with_sharding_constraint( + query, P(dp_mesh_key, None, remain_mesh_key, None) + ) + key = jax.lax.with_sharding_constraint( + key, P(dp_mesh_key, None, remain_mesh_key, None) + ) + value = jax.lax.with_sharding_constraint( + value, P(dp_mesh_key, None, remain_mesh_key, None) + ) + out = sharded_fn(query, key, value) + # Remove the potential padding for sp + out = out[:, :, :q_seq_len, :] + out = jax.lax.with_sharding_constraint( + out, P(dp_mesh_key, None, remain_mesh_key, None) + ) + return out + + +def _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + *, + env, + mesh, +) -> torch.Tensor: + # if env.config.use_tpu_splash_attention: + if True: + assert attn_mask is None + assert dropout_p == 0.0 + assert is_causal is False + assert enable_gqa is False + assert scale is None + jquery, jkey, jvalue = env.t2j_iso((query, key, value)) + res = _tpu_custom_attention( + jquery, + jkey, + jvalue, + mesh, + scale=scale, + ) + return env.j2t_iso(res) + + return jtorch._sdpa_reference( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa + ) + + +# register non-jax type +def _flatten_model_output(obj): + return obj.to_tuple(), type(obj) + + +def _unflatten_model_output(aux, children): + return aux(*children) + + +# For text_embedding +jax.tree_util.register_pytree_node( + modeling_outputs.BaseModelOutputWithPastAndCrossAttentions, + _flatten_model_output, + _unflatten_model_output, +) + +# For vae decode +jax.tree_util.register_pytree_node( + diffusers_vae.DecoderOutput, + _flatten_model_output, + _unflatten_model_output, +) + +# For vae encode +jax.tree_util.register_pytree_node( + diffusers_modeling_outputs.AutoencoderKLOutput, + _flatten_model_output, + 
_unflatten_model_output, +) + + +def _flatten_diagonal_gaussian_distribution( + obj: diffusers_vae.DiagonalGaussianDistribution, +): + return ( + obj.parameters, + obj.mean, + obj.logvar, + obj.deterministic, + obj.std, + obj.var, + ), None + + +def _unflatten_diagonal_gaussian_distribution( + aux, children +) -> diffusers_vae.DiagonalGaussianDistribution: + obj = object.__new__(diffusers_vae.DiagonalGaussianDistribution) + obj.parameters = children[0] + obj.mean = children[1] + obj.logvar = children[2] + obj.deterministic = children[3] + obj.std = children[4] + obj.var = children[5] + return obj + + +jax.tree_util.register_pytree_node( + diffusers_vae.DiagonalGaussianDistribution, + _flatten_diagonal_gaussian_distribution, + _unflatten_diagonal_gaussian_distribution, +) + + +class Args(argparse.Namespace): + size: str + frame_num: int + prompt: str + base_seed: int + image: str + sample_steps: int + print_weights: bool + profile: str + profile_output_path: str + + +def parse_args(): + # Copy args and modify from wan2.2 repo + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--size", + type=str, + default="720*1280", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.", + ) + parser.add_argument( + "--frame_num", + type=int, + default=81, + help="How many frames of video are generated. The number should be 4n+1", + ) + parser.add_argument( + "--prompt", + type=str, + default=DEFAULT_PROMPT, + help="The prompt to generate the video from.", + ) + parser.add_argument( + "--base_seed", + type=int, + default=0, + help="The seed to use for generating the video. 
Need to specify for multi-host sync.", + ) + parser.add_argument( + "--image", + type=str, + default=DEFAULT_IMAGE_PATH, + help="The image to generate the video from.", + ) + parser.add_argument( + "--sample_steps", type=int, default=40, help="The sampling steps." + ) + parser.add_argument( + "--print_weights", action="store_true", help="print weights in models" + ) + parser.add_argument( + "--profile", + type=str, + default="no", + choices=["no", "dit", "all"], + help="no for no profile, dit for dit only 3 steps, all including vae", + ) + parser.add_argument( + "--profile_output_path", + type=str, + default=DEFAULT_PROFILE_OUT_PATH, + help="path to save profile output", + ) + parser.add_argument( + "--dp", + type=int, + default=2, + help="Data parallelism for positive prompt and negative prompt.", + ) + return parser.parse_args(namespace=Args()) + + +def main(args: Args): + torch.set_default_dtype(torch.bfloat16) + + model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" + dtype = torch.bfloat16 + + with perf_time("load pipe"): + pipe = WanImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype) + + # print weights map for fill sharding + if args.print_weights: + print("text_encoder_shardings = ", end="") + _print_weights(pipe.text_encoder) + print() + print("transformer_shardings = ", end="") + _print_weights(pipe.transformer) + print() + print("vae_shardings = ", end="") + _print_weights(pipe.vae) + print() + + # enable torchax wrap jax array into torch array + torchax.enable_globally() + env = torchax.default_env() + assert isinstance(env, torchax.tensor.Environment) + + # mesh = jax.make_mesh((len(jax.devices()),), ("tp",)) + dp_dim = args.dp + assert len(jax.devices()) % dp_dim == 0 + tp_dim = len(jax.devices()) // dp_dim + mesh_devices = mesh_utils.create_device_mesh( + (dp_dim, tp_dim), allow_split_physical_axes=True + ) + mesh = Mesh(mesh_devices, ("dp", "tp")) + print(f"{mesh=}") + + # Workaround override function to use tpu. 
Better handle it in torchax + _overide_op_definition( + env, torch.nn.functional.conv2d, functools.partial(_torch_conv2d, env=env) + ) + _overide_op_definition( + env, + torch.nn.functional.scaled_dot_product_attention, + functools.partial(_scaled_dot_product_attention, env=env, mesh=mesh), + ) + + # Put weights into tpu + + with perf_time("Move model to tpu"): + with perf_time(" Move text encoder"): + _move_module(env, pipe.text_encoder) + pipe.text_encoder = torchax.compile(pipe.text_encoder) + pipe.text_encoder.params = _shard_weight_dict( + pipe.text_encoder.params, TEXT_ENCODER_SHARDINGS, mesh + ) + pipe.text_encoder.buffers = _shard_weight_dict( + pipe.text_encoder.buffers, TEXT_ENCODER_SHARDINGS, mesh + ) + + transformer_options = torchax.CompileOptions( + jax_jit_kwargs={"static_argnames": ("return_dict",)} + ) + with perf_time(" Move transformer"): + _move_module(env, pipe.transformer) + pipe.transformer = torchax.compile(pipe.transformer, transformer_options) + pipe.transformer.params = _shard_weight_dict( + pipe.transformer.params, TRANSFORMER_SHARDINGS, mesh + ) + pipe.transformer.buffers = _shard_weight_dict( + pipe.transformer.buffers, TRANSFORMER_SHARDINGS, mesh + ) + + with perf_time(" Move transformer2"): + _move_module(env, pipe.transformer_2) + pipe.transformer_2 = torchax.compile( + pipe.transformer_2, transformer_options + ) + pipe.transformer_2.params = _shard_weight_dict( + pipe.transformer_2.params, TRANSFORMER_SHARDINGS, mesh + ) + pipe.transformer_2.buffers = _shard_weight_dict( + pipe.transformer_2.buffers, TRANSFORMER_SHARDINGS, mesh + ) + + vae_options = torchax.CompileOptions( + methods_to_compile=["encode", "decode"], + jax_jit_kwargs={"static_argnames": ("return_dict",)}, + ) + with perf_time(" Move vae"): + _move_module(env, pipe.vae) + pipe.vae = torchax.compile(pipe.vae, vae_options) + pipe.vae.params = _shard_weight_dict(pipe.vae.params, VAE_SHARDINGS, mesh) + pipe.vae.buffers = _shard_weight_dict(pipe.vae.buffers, VAE_SHARDINGS, 
mesh) + + image = load_image(args.image) + max_area = MAX_AREA_CONFIGS[args.size] + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + prompt = args.prompt + negative_prompt = DEFAULT_NEG_PROMPT + generator = torch.Generator().manual_seed(args.base_seed) + with mesh: + with perf_time("Warmup and output video"): + output = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=args.frame_num, + guidance_scale=3.5, + num_inference_steps=args.sample_steps, + generator=generator, + ).frames[0] + current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S") + file_name = f"{current_datetime}.mp4" + export_to_video(output, file_name, fps=16) + print(f"output video done. 
{file_name}") + + if args.profile != "no": + with perf_time("Profile"): + if args.profile == "dit": + output_type = "latent" + else: + output_type = "np" + with jax.profiler.trace(args.profile_output_path): + output = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=args.frame_num, + guidance_scale=3.5, + num_inference_steps=3, + generator=generator, + output_type=output_type, + ).frames[0] + + with perf_time("Benchmark"): + output = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=args.frame_num, + guidance_scale=3.5, + num_inference_steps=args.sample_steps, + generator=generator, + ).frames[0] + + print("Done") + + +if __name__ == "__main__": + args = parse_args() + print(args) + main(args) diff --git a/exp/wan_i2v_input.JPG b/exp/wan_i2v_input.JPG new file mode 100644 index 000000000000..6f6dcced590c Binary files /dev/null and b/exp/wan_i2v_input.JPG differ diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index f95c4cf37475..bc0d123e4e55 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -27,6 +27,9 @@ from ..modeling_utils import ModelMixin from .vae import DecoderOutput, DiagonalGaussianDistribution +import jax +from torchax import interop + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1163,6 +1166,10 @@ def _encode(self, x: torch.Tensor): feat_idx=self._enc_conv_idx, ) out = torch.cat([out, out_], 2) + # Prevent jit optmization run multi-step loops simultaneous and cause OOM. 
+ # Add the dependency next x to current out + x, out = jax.lax.optimization_barrier(interop.jax_view((x, out))) + x, out = interop.torch_view((x, out)) enc = self.quant_conv(out) self.clear_cache() @@ -1184,6 +1191,7 @@ def encode( The latent representations of the encoded videos. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ + print(f"[DEBUG] {x.shape=}, {x.dtype=}") if self.use_slicing and x.shape[0] > 1: encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] h = torch.cat(encoded_slices) @@ -1214,6 +1222,10 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) + # Prevent jit optmization run multi-step loops simultaneous and cause OOM. + # Add the dependency next x to current out + x, out = jax.lax.optimization_barrier(interop.jax_view((x, out))) + x, out = interop.torch_view((x, out)) if self.config.patch_size is not None: out = unpatchify(out, patch_size=self.config.patch_size) @@ -1241,6 +1253,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. 
""" + print(f"[DEBUG] {z.shape=}, {z.dtype=}") if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index dd75fb124f1a..9ab3194ea033 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -35,6 +35,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +import jax +from torchax import interop +from jax.sharding import PartitionSpec as P +mark_sharding = interop.torch_view(jax.lax.with_sharding_constraint) def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): # encoder_hidden_states is only passed for cross-attention @@ -623,6 +627,9 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + hidden_states = mark_sharding(hidden_states, P("dp")) + encoder_hidden_states = mark_sharding(encoder_hidden_states, P("dp")) + if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index b7fd0b05980f..721492da7caa 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -724,6 +724,7 @@ def __call__( else: boundary_timestep = None + print("[WARNING] cache_context is not supported") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -752,27 +753,51 @@ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - with 
current_model.cache_context("cond"): - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + # with current_model.cache_context("cond"): + # noise_pred = current_model( + # hidden_states=latent_model_input, + # timestep=timestep, + # encoder_hidden_states=prompt_embeds, + # encoder_hidden_states_image=image_embeds, + # attention_kwargs=attention_kwargs, + # return_dict=False, + # )[0] + + # if self.do_classifier_free_guidance: + # with current_model.cache_context("uncond"): + # noise_uncond = current_model( + # hidden_states=latent_model_input, + # timestep=timestep, + # encoder_hidden_states=negative_prompt_embeds, + # encoder_hidden_states_image=image_embeds, + # attention_kwargs=attention_kwargs, + # return_dict=False, + # )[0] + # noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) if self.do_classifier_free_guidance: - with current_model.cache_context("uncond"): - noise_uncond = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + batch_latent_model_input = torch.cat([latent_model_input, latent_model_input]) + batch_timestep = torch.cat([timestep, timestep]) + batch_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds]) + batch_encoder_hidden_states_image = torch.cat([image_embeds, image_embeds]) if image_embeds is not None else None + else: + batch_latent_model_input = latent_model_input + batch_timestep = timestep + batch_encoder_hidden_states = prompt_embeds + batch_encoder_hidden_states_image = image_embeds + + batch_noise = current_model( + 
hidden_states=batch_latent_model_input, + timestep=batch_timestep, + encoder_hidden_states=batch_encoder_hidden_states, + encoder_hidden_states_image=batch_encoder_hidden_states_image, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] # return is tuple + noise_pred = batch_noise[0:1] + if self.do_classifier_free_guidance: + noise_uncond = batch_noise[1:2] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]