From 058b22a6f688d3e3e6ff8b4424f0e2ff86c103bf Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Tue, 5 May 2026 21:52:12 +0000 Subject: [PATCH] feat: add optional batched text encoder and diffusion loop with torch.compile --- src/maxdiffusion/configs/base_wan_14b.yml | 29 +++++--- src/maxdiffusion/configs/base_wan_1_3b.yml | 7 ++ src/maxdiffusion/configs/base_wan_27b.yml | 39 +++++++---- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 27 ++++++-- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 29 +++++--- src/maxdiffusion/generate_wan.py | 2 +- .../pipelines/wan/wan_pipeline.py | 54 +++++++++++---- .../pipelines/wan/wan_pipeline_2_1.py | 67 +++++++++++++++++- .../pipelines/wan/wan_pipeline_2_2.py | 63 +++++++++++++++++ .../pipelines/wan/wan_pipeline_i2v_2p1.py | 69 ++++++++++++++++++- .../pipelines/wan/wan_pipeline_i2v_2p2.py | 55 ++++++++++++++- 11 files changed, 379 insertions(+), 62 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 7ffb659c8..c2c83c9f7 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -53,6 +53,9 @@ vae_spatial: -1 # default to total_device * 2 // (dp) precision: "DEFAULT" # Use jax.lax.scan for transformer layers scan_layers: True +# Use jax.lax.scan for the diffusion loop (non-cache path only). +# Note: Enabling this will disable per-step profiling. +scan_diffusion_loop: False # if False state is not jitted and instead replicate is called. This is good for debugging on single host # It must be True for multi-host. @@ -61,21 +64,21 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom use_base2_exp: True use_experimental_scheduler: True -flash_min_seq_length: 0 +flash_min_seq_length: 4096 +dropout: 0.0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. # However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True +mask_padding_tokens: True # Maxdiffusion has 2 types of attention sharding strategies: # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded # in cross attention q. -attention_sharding_uniform: True -dropout: 0.0 +attention_sharding_uniform: True flash_block_sizes: { "block_q" : 512, @@ -202,9 +205,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. 
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_data_parallelism: 1 dcn_fsdp_parallelism: 1 -dcn_context_parallelism: -1 +dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_tensor_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: 1 @@ -338,16 +341,20 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 guidance_scale: 5.0 -flow_shift: 3.0 +flow_shift: 5.0 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) # Skips the unconditional forward pass on ~35% of steps via residual compensation. # See: FasterCache (Lv et al. 2024), WAN 2.1 paper ยง4.4.2 use_cfg_cache: False + +# Batch positive and negative prompts in text encoder to save compute. +use_batched_text_encoder: False + use_magcache: False magcache_thresh: 0.12 magcache_K: 2 @@ -356,7 +363,7 @@ mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 +num_inference_steps: 50 fps: 16 save_final_checkpoint: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 9e59ba9ce..1fd384eb1 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -302,6 +302,13 @@ flow_shift: 3.0 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) use_cfg_cache: False +# Batch positive and negative prompts in text encoder to save compute. +use_batched_text_encoder: False + +# Use jax.lax.scan for the diffusion loop (non-cache path only). +# Note: Enabling this will disable per-step profiling. +scan_diffusion_loop: False + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index f80c15515..1ce67a3cf 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -44,7 +44,7 @@ activations_dtype: 'bfloat16' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False -vae_spatial: 1 +vae_spatial: -1 # default to total_device * 2 // (dp) # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" @@ -53,6 +53,9 @@ vae_spatial: 1 precision: "DEFAULT" # Use jax.lax.scan for transformer layers scan_layers: True +# Use jax.lax.scan for the diffusion loop (non-cache path only). +# Note: Enabling this will disable per-step profiling. +scan_diffusion_loop: False # if False state is not jitted and instead replicate is called. 
This is good for debugging on single host # It must be True for multi-host. @@ -61,20 +64,21 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 +dropout: 0.0 + # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. # However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True +mask_padding_tokens: True # Maxdiffusion has 2 types of attention sharding strategies: # 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded # in cross attention q. -attention_sharding_uniform: True -dropout: 0.0 +attention_sharding_uniform: True flash_block_sizes: { "block_q" : 512, @@ -159,7 +163,7 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor'] logical_axis_rules: [ ['batch', ['data', 'fsdp']], ['activation_batch', ['data', 'fsdp']], - ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_self_attn_heads', ['context', 'tensor']], ['activation_cross_attn_q_length', ['context', 'tensor']], ['activation_length', 'context'], ['activation_heads', 'tensor'], @@ -190,9 +194,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. -dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_data_parallelism: 1 dcn_fsdp_parallelism: 1 -dcn_context_parallelism: -1 +dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_tensor_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: 1 @@ -304,17 +308,17 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 -flow_shift: 3.0 +flow_shift: 5.0 # Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py # guidance scale factor for low noise transformer -guidance_scale_low: 3.0 +guidance_scale_low: 3.0 # guidance scale factor for high noise transformer -guidance_scale_high: 4.0 +guidance_scale_high: 4.0 # The timestep threshold. 
If `t` is at or above this value, # the `high_noise_model` is considered as the required model. @@ -323,14 +327,19 @@ boundary_ratio: 0.875 # Diffusion CFG cache (FasterCache-style) use_cfg_cache: False + +# Batch positive and negative prompts in text encoder to save compute. +use_batched_text_encoder: False + + # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) โ€” skip forward pass # when predicted output change (based on accumulated latent/timestep drift) is small use_sen_cache: False # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 -fps: 24 +num_inference_steps: 40 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index b136c7a9e..214cf5ce4 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False +vae_spatial: -1 # default to total_device * 2 // (dp) # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" @@ -52,6 +53,9 @@ replicate_vae: False precision: "DEFAULT" # Use jax.lax.scan for transformer layers scan_layers: True +# Use jax.lax.scan for the diffusion loop (non-cache path only). +# Note: Enabling this will disable per-step profiling. +scan_diffusion_loop: False # if False state is not jitted and instead replicate is called. This is good for debugging on single host # It must be True for multi-host. @@ -60,7 +64,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 @@ -69,7 +73,11 @@ dropout: 0.0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. # However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. attention_sharding_uniform: True flash_block_sizes: { @@ -184,13 +192,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. 
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 -dcn_context_parallelism: 1 +dcn_data_parallelism: 1 +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_tensor_parallelism: 1 ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded -ici_context_parallelism: 1 +ici_fsdp_parallelism: 1 +ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -306,6 +314,11 @@ flow_shift: 5.0 # Diffusion CFG cache (FasterCache-style) use_cfg_cache: False + +# Batch positive and negative prompts in text encoder to save compute. +use_batched_text_encoder: False + + # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False use_magcache: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 4af011879..d2eb451d4 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False +vae_spatial: -1 # default to total_device * 2 // (dp) # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" @@ -52,6 +53,9 @@ replicate_vae: False precision: "DEFAULT" # Use jax.lax.scan for transformer layers scan_layers: True +# Use jax.lax.scan for the diffusion loop (non-cache path only). +# Note: Enabling this will disable per-step profiling. +scan_diffusion_loop: False # if False state is not jitted and instead replicate is called. This is good for debugging on single host # It must be True for multi-host. @@ -60,7 +64,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 @@ -69,7 +73,11 @@ dropout: 0.0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. # However, when padding tokens are significant, this will lead to worse quality and should be set to True. -mask_padding_tokens: True +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. attention_sharding_uniform: True flash_block_sizes: { @@ -185,13 +193,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']] # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. 
-dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
-dcn_fsdp_parallelism: -1
-dcn_context_parallelism: 1
+dcn_data_parallelism: 1
+dcn_fsdp_parallelism: 1
+dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
-ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
-ici_context_parallelism: 1
+ici_fsdp_parallelism: 1
+ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 allow_split_physical_axes: False
@@ -318,12 +326,17 @@ boundary_ratio: 0.875

 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+
+# Batch positive and negative prompts in text encoder to save compute.
+use_batched_text_encoder: False
+
+
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False

 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
-num_inference_steps: 50
+num_inference_steps: 40
 fps: 16
 save_final_checkpoint: False

diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index 2a3fae518..c0f71c84a 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -304,11 +304,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
       f"{'=' * 50}",
       f" Load (checkpoint): {load_time:>7.1f}s",
       f" Compile: {compile_time:>7.1f}s",
-      f" {'─' * 40}",
       f" Inference: {generation_time:>7.1f}s",
   ]
   if trace:
     summary.extend([
+        f" {'─' * 40}",
         f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
         f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
         f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 50a82607b..608f7282d 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -269,10 +269,13 @@ def __init__(

   @classmethod
   def load_text_encoder(cls, config: HyperParameters):
+    torch_dtype = getattr(torch, str(config.weights_dtype), torch.float32)
     text_encoder = UMT5EncoderModel.from_pretrained(
         config.pretrained_model_name_or_path,
         subfolder="text_encoder",
+        torch_dtype=torch_dtype,
     )
+    text_encoder = torch.compile(text_encoder)
     return text_encoder

   @classmethod
@@ -501,24 +504,45 @@ def encode_prompt(
       negative_prompt_embeds: jax.Array = None,
   ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    if prompt_embeds is None:
-      prompt_embeds = self._get_t5_prompt_embeds(
-          prompt=prompt,
-          num_videos_per_prompt=num_videos_per_prompt,
-          max_sequence_length=max_sequence_length,
-      )
-      prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32)
-
-    if negative_prompt_embeds is None:
-      batch_size = len(prompt_embeds)
-      negative_prompt = negative_prompt or ""
-      negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-      negative_prompt_embeds = self._get_t5_prompt_embeds(
-          prompt=negative_prompt,
+    batch_size = len(prompt)
+
+    if negative_prompt is None:
+      negative_prompt = [""] * batch_size
+    elif isinstance(negative_prompt, str):
+      negative_prompt = [negative_prompt] * batch_size
+
+    use_batched_text_encoder = getattr(self.config, "use_batched_text_encoder", False)
+    if use_batched_text_encoder and prompt_embeds is None and negative_prompt_embeds is None:
+      # Batch both together
+      combined_prompts = prompt + negative_prompt
+      combined_embeds = self._get_t5_prompt_embeds(
+          prompt=combined_prompts,
num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, ) - negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32) + combined_embeds = jnp.array(combined_embeds.detach().float().numpy(), dtype=jnp.float32) + + # Split back + prompt_embeds = combined_embeds[: batch_size * num_videos_per_prompt] + negative_prompt_embeds = combined_embeds[batch_size * num_videos_per_prompt :] + + else: + # Fallback to separate encoding if one of them is already provided + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = jnp.array(prompt_embeds.detach().float().numpy(), dtype=jnp.float32) + + if negative_prompt_embeds is None: + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().float().numpy(), dtype=jnp.float32) return prompt_embeds, negative_prompt_embeds diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 8e859e8b8..e0a2f05e6 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler import numpy as np +import time from ... import max_utils @@ -114,6 +115,9 @@ def __call__( "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0." ) + trace = {} + t_cond_start = time.perf_counter() + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, negative_prompt, @@ -128,6 +132,9 @@ def __call__( negative_prompt_embeds, vae_only, ) + latents.block_until_ready() + prompt_embeds.block_until_ready() + trace["conditioning"] = time.perf_counter() - t_cond_start graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) 
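The encode_prompt rewrite above is the core of use_batched_text_encoder: rather than running the UMT5 encoder once for the positive prompts and once for the negatives, it concatenates the two lists, encodes them in a single forward pass, and splits the embeddings back apart. A minimal sketch of that pattern, assuming an `encode` callable that maps a list of strings to a torch tensor (a stand-in for the pipeline's _get_t5_prompt_embeds; the function and its name are illustrative, not part of this patch):

import jax.numpy as jnp

def encode_prompts_batched(encode, prompts, negative_prompts, num_videos_per_prompt=1):
  # One encoder forward over the concatenated batch instead of two calls.
  combined = encode(prompts + negative_prompts)
  # Convert the torch output to a float32 JAX array, mirroring the pipeline.
  combined = jnp.asarray(combined.detach().float().numpy(), dtype=jnp.float32)
  # Rows for the positive prompts come first, negatives second, in the same
  # order the two lists were concatenated above.
  split = len(prompts) * num_videos_per_prompt
  return combined[:split], combined[split:]

Doubling the batch of one call is generally cheaper than paying launch and dispatch overhead twice (especially with torch.compile on the encoder, as load_text_encoder now applies), and the split indices stay correct as long as the encoder preserves input order.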
@@ -147,6 +154,7 @@ def __call__( config=self.config, ) + t_denoise_start = time.perf_counter() with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( graphdef=graphdef, @@ -157,7 +165,16 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) latents = self._denormalize_latents(latents) - return self._decode_latents_to_video(latents) + latents.block_until_ready() + trace["denoise_total"] = time.perf_counter() - t_denoise_start + + t_decode_start = time.perf_counter() + video = self._decode_latents_to_video(latents) + if hasattr(video, "block_until_ready"): + video.block_until_ready() + trace["vae_decode"] = time.perf_counter() - t_decode_start + + return video, trace def run_inference_2_1( @@ -261,6 +278,54 @@ def run_inference_2_1( profiler_steps = config.profiler_steps if config else 0 last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + + if scan_diffusion_loop and not use_magcache and not use_cfg_cache: + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + + scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) + + def scan_body(carry, t): + current_latents, current_scheduler_state = carry + + if do_cfg: + latents_doubled = jnp.concatenate([current_latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + sharded_state, + rest_of_state, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + else: + timestep = jnp.broadcast_to(t, bsz) + noise_pred, _ = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + current_latents, + timestep, + prompt_cond_embeds, + do_classifier_free_guidance=False, + guidance_scale=guidance_scale, + ) + + new_latents, new_scheduler_state = scheduler.step( + current_scheduler_state, noise_pred, t, current_latents, return_dict=False + ) + + return (new_latents, new_scheduler_state), None + + initial_carry = (latents, scheduler_state) + + final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps) + + final_latents, _ = final_carry + return final_latents + profiler = None for step in range(num_inference_steps): if config and max_utils.profiler_enabled(config) and step == first_profiling_step: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 95912e436..77331d66d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -151,6 +151,8 @@ def __call__( negative_prompt_embeds, vae_only, ) + latents.block_until_ready() + prompt_embeds.block_until_ready() trace["conditioning"] = time.perf_counter() - t_cond_start low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) 
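The scan_diffusion_loop branch added above replaces the Python step loop with a single jax.lax.scan, so the whole denoising trajectory lowers to one compiled XLA loop; that is also why the config comment warns that per-step profiling is disabled, since no host code runs between steps. A minimal sketch of the carry pattern, with a stubbed step_fn standing in for the transformer forward plus scheduler.step (names here are illustrative, not this file's signatures):

import jax
import jax.numpy as jnp

def scan_denoise(step_fn, latents, scheduler_state, timesteps):
  # step_fn(latents, scheduler_state, t) -> (latents, scheduler_state) must be
  # pure and shape-stable: scan traces the body once and reuses it every step.
  def body(carry, t):
    lat, sched = carry
    lat, sched = step_fn(lat, sched, t)
    return (lat, sched), None  # nothing is stacked as a per-step output

  (final_latents, _), _ = jax.lax.scan(body, (latents, scheduler_state), timesteps)
  return final_latents

# Toy check: 50 steps of a trivial "denoiser" that decays the latents.
out = scan_denoise(lambda lat, s, t: (lat * 0.99, s), jnp.ones((1, 4, 8, 8)), None, jnp.arange(50))

This is also why the scan path is gated on the caching flags being off: magcache and the CFG cache branch on the step index in Python, which cannot happen inside a traced scan body without reworking them as lax primitives.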
@@ -191,6 +193,8 @@ def __call__( t_decode_start = time.perf_counter() video = self._decode_latents_to_video(latents) + if hasattr(video, "block_until_ready"): + video.block_until_ready() trace["vae_decode"] = time.perf_counter() - t_decode_start return video, trace @@ -471,6 +475,65 @@ def run_inference_2_2( profiler_steps = config.profiler_steps if config else 0 last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + + def high_noise_branch(ops): + model_latents_in, timestep_in = ops + return transformer_forward_pass( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + model_latents_in, + timestep_in, + prompt_embeds_combined, + do_classifier_free_guidance, + guidance_scale_high, + ) + + def low_noise_branch(ops): + model_latents_in, timestep_in = ops + return transformer_forward_pass( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + model_latents_in, + timestep_in, + prompt_embeds_combined, + do_classifier_free_guidance, + guidance_scale_low, + ) + + if scan_diffusion_loop: + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + + scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) + + def scan_body(carry, t): + current_latents, current_scheduler_state = carry + + if do_classifier_free_guidance: + model_latents = jnp.concatenate([current_latents] * 2) + else: + model_latents = current_latents + + timestep = jnp.broadcast_to(t, model_latents.shape[0]) + use_high_noise = jnp.greater_equal(t, boundary) + + noise_pred, latents_out = jax.lax.cond(use_high_noise, high_noise_branch, low_noise_branch, (model_latents, timestep)) + + new_latents, new_scheduler_state = scheduler.step( + current_scheduler_state, noise_pred, t, latents_out, return_dict=False + ) + + return (new_latents, new_scheduler_state), None + + initial_carry = (latents, scheduler_state) + + final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps) + + final_latents, _ = final_carry + return final_latents + profiler = None for step in range(num_inference_steps): if config and max_utils.profiler_enabled(config) and step == first_profiling_step: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 787f22957..0abe4fa5b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -26,6 +26,7 @@ from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler import numpy as np +import time from ... 
import max_utils @@ -180,6 +181,9 @@ def __call__( max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) + trace = {} + t_cond_start = time.perf_counter() + prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( prompt, image, @@ -222,6 +226,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): last_image=last_image_tensor, num_videos_per_prompt=num_videos_per_prompt, ) + latents.block_until_ready() + condition.block_until_ready() + trace["conditioning"] = time.perf_counter() - t_cond_start scheduler_state = self.scheduler.set_timesteps( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape @@ -257,6 +264,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): config=self.config, ) + t_denoise_start = time.perf_counter() with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( latents=latents, @@ -268,10 +276,19 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): ) latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) latents = self._denormalize_latents(latents) + latents.block_until_ready() + trace["denoise_total"] = time.perf_counter() - t_denoise_start if output_type == "latent": - return latents - return self._decode_latents_to_video(latents) + return latents, trace + + t_decode_start = time.perf_counter() + video = self._decode_latents_to_video(latents) + if hasattr(video, "block_until_ready"): + video.block_until_ready() + trace["vae_decode"] = time.perf_counter() - t_decode_start + + return video, trace def run_inference_2_1_i2v( @@ -317,6 +334,54 @@ def run_inference_2_1_i2v( profiler_steps = config.profiler_steps if config else 0 last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + + if scan_diffusion_loop and not use_magcache: + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + + scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) + + def scan_body(carry, t): + current_latents, current_scheduler_state = carry + + latents_input = current_latents + if do_cfg: + latents_input = jnp.concatenate([current_latents, current_latents], axis=0) + + latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + + outputs = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latent_model_input, + timestep, + prompt_embeds_combined, + do_classifier_free_guidance=do_cfg, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds_combined, + skip_blocks=None, + cached_residual=None, + return_residual=False, + ) + noise_pred, _ = outputs + + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + new_latents, new_scheduler_state = scheduler.step( + current_scheduler_state, noise_pred, t, current_latents, return_dict=False + ) + + return (new_latents, new_scheduler_state), None + + initial_carry = (latents, scheduler_state) + + final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps) + + final_latents, _ = final_carry + return final_latents + profiler = None for step in range(num_inference_steps): if config and 
max_utils.profiler_enabled(config) and step == first_profiling_step: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index d8398f58f..f466ec574 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -24,6 +24,7 @@ import jax import jax.numpy as jnp import numpy as np +import time from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler from ... import max_utils @@ -202,6 +203,9 @@ def __call__( max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) + trace = {} + t_cond_start = time.perf_counter() + prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( prompt, image, @@ -247,6 +251,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): latents=latents, last_image=last_image_tensor, ) + latents.block_until_ready() + condition.block_until_ready() + trace["conditioning"] = time.perf_counter() - t_cond_start scheduler_state = self.scheduler.set_timesteps( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape @@ -283,6 +290,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): config=self.config, ) + t_denoise_start = time.perf_counter() with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( low_noise_graphdef=low_noise_graphdef, @@ -299,10 +307,19 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): ) latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) latents = self._denormalize_latents(latents) + latents.block_until_ready() + trace["denoise_total"] = time.perf_counter() - t_denoise_start if output_type == "latent": - return latents - return self._decode_latents_to_video(latents) + return latents, trace + + t_decode_start = time.perf_counter() + video = self._decode_latents_to_video(latents) + if hasattr(video, "block_until_ready"): + video.block_until_ready() + trace["vae_decode"] = time.perf_counter() - t_decode_start + + return video, trace def run_inference_2_2_i2v( @@ -609,6 +626,40 @@ def low_noise_branch(operands): profiler_steps = config.profiler_steps if config else 0 last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False + + if scan_diffusion_loop: + timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32) + + scheduler_state = scheduler_state.replace(last_sample=jnp.zeros_like(latents), step_index=jnp.array(0, dtype=jnp.int32)) + + def scan_body(carry, t): + current_latents, current_scheduler_state = carry + + latents_input = current_latents + if do_classifier_free_guidance: + latents_input = jnp.concatenate([current_latents, current_latents], axis=0) + latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred, _ = jax.lax.cond( + use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds) + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + new_latents, new_scheduler_state = scheduler.step( + current_scheduler_state, noise_pred, t, current_latents, 
return_dict=False + ) + + return (new_latents, new_scheduler_state), None + + initial_carry = (latents, scheduler_state) + + final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps) + + final_latents, _ = final_carry + return final_latents + profiler = None for step in range(num_inference_steps): if config and max_utils.profiler_enabled(config) and step == first_profiling_step:
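A final pattern worth isolating from the 2.2 scan bodies above: inside a traced loop the timestep t is a tracer, so the choice between the high- and low-noise transformers must go through jax.lax.cond rather than a Python if. A minimal sketch, with lambdas standing in for the transformer_forward_pass branch closures and a boundary of 875 assuming boundary_ratio 0.875 over a 1000-step training schedule (both stand-ins, not this file's API):

import jax
import jax.numpy as jnp

def select_noise_expert(t, boundary, high_noise_fn, low_noise_fn, operands):
  # lax.cond traces both branches, so they must accept the same operands and
  # return identically shaped/typed outputs; only one executes per step.
  return jax.lax.cond(jnp.greater_equal(t, boundary), high_noise_fn, low_noise_fn, operands)

# Toy usage: t=900 with boundary 875 routes to the high-noise "expert".
high = lambda ops: ops[0] * 2.0
low = lambda ops: ops[0] * 0.5
pred = select_noise_expert(jnp.array(900), jnp.array(875), high, low, (jnp.ones((1, 4)),))

Note that cond saves only the unselected branch's compute, not its memory: both transformers' weights stay resident on device across the whole scan.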