29 changes: 18 additions & 11 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -53,6 +53,9 @@ vae_spatial: -1 # default to total_device * 2 // (dp)
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True
# Use jax.lax.scan for the diffusion loop (non-cache path only).
# Note: Enabling this will disable per-step profiling.
scan_diffusion_loop: False
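As a rough illustration of what `scan_diffusion_loop` controls (the step function and argument names below are placeholders, not the pipeline's actual API): with `jax.lax.scan` the whole denoising loop compiles into a single XLA computation, which is why per-step profiling is no longer available.

```python
import jax

def run_denoising(denoise_step, latents, timesteps, scan_diffusion_loop):
  """Sketch of a scanned vs. unrolled diffusion loop (illustrative names only)."""
  if scan_diffusion_loop:
    # One fused scan: fastest dispatch, but individual steps cannot be traced.
    def body(carry, t):
      return denoise_step(carry, t), None
    latents, _ = jax.lax.scan(body, latents, timesteps)
  else:
    # Plain Python loop: each step dispatches separately and can be profiled.
    for t in timesteps:
      latents = denoise_step(latents, t)
  return latents
```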

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
@@ -61,21 +64,21 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 0
flash_min_seq_length: 4096
dropout: 0.0

# If mask_padding_tokens is True, we pass segment ids to splash attention so it does not attend to padding tokens.
# If False, segment ids are omitted, which is faster on VPU-bound hardware such as Trillium.
# However, when padding tokens are significant, omitting the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.0
attention_sharding_uniform: True

flash_block_sizes: {
"block_q" : 512,
@@ -202,9 +205,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: 1
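For intuition on the `-1` convention described above, a small sketch (hypothetical helper, not the actual config parser): the single `-1` axis is set to whatever factor makes the DCN product equal the number of slices, and the ICI product equal the number of devices per slice.

```python
def resolve_auto_axis(axes, target):
  """Replace a single -1 entry so the product of the axes equals `target`."""
  fixed = 1
  for a in axes:
    if a != -1:
      fixed *= a
  assert target % fixed == 0, "remaining axes must divide the slice/device count"
  return [target // fixed if a == -1 else a for a in axes]

# e.g. with 2 slices: [dcn_data, dcn_fsdp, dcn_context, dcn_tensor]
print(resolve_auto_axis([1, 1, -1, 1], target=2))  # -> [1, 1, 2, 1]
```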
@@ -338,16 +341,20 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 480
width: 832
height: 720
width: 1280
num_frames: 81
guidance_scale: 5.0
flow_shift: 3.0
flow_shift: 5.0

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
use_cfg_cache: False
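A conceptual sketch of the residual-compensation idea behind `use_cfg_cache` (the skip schedule, residual update, and names here are assumptions for illustration, not the pipeline's implementation): on a skipped step, the unconditional prediction is approximated from the conditional one plus a cached residual, so only one transformer forward is needed.

```python
def cfg_step(model, latents, t, cond_emb, uncond_emb, guidance_scale, residual, skip_uncond):
  """One guidance step with an optional cached-residual skip of the uncond pass."""
  noise_cond = model(latents, t, cond_emb)
  if skip_uncond and residual is not None:
    noise_uncond = noise_cond + residual            # approximate: reuse cached residual
  else:
    noise_uncond = model(latents, t, uncond_emb)    # full step: second forward pass
    residual = noise_uncond - noise_cond            # refresh the cache
  noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
  return noise_pred, residual
```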

# Batch positive and negative prompts in text encoder to save compute.
use_batched_text_encoder: False
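The batched-encoder flag amounts to something like the following (placeholder names; the real tokenizer and encoder APIs differ): the negative and positive prompts share one padded batch and one forward pass instead of two.

```python
def encode_prompts(tokenizer, text_encoder, prompt, negative_prompt, batched):
  """Encode prompt embeddings with one or two text-encoder calls (sketch)."""
  if batched:
    ids = tokenizer([negative_prompt, prompt])   # one padded batch of two
    embeds = text_encoder(ids)                   # single forward pass
    neg_embeds, pos_embeds = embeds[0], embeds[1]
  else:
    pos_embeds = text_encoder(tokenizer([prompt]))[0]
    neg_embeds = text_encoder(tokenizer([negative_prompt]))[0]
  return pos_embeds, neg_embeds
```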

use_magcache: False
magcache_thresh: 0.12
magcache_K: 2
@@ -356,7 +363,7 @@ mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962,

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
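Section 3.4 of that paper rescales the guided prediction so its per-sample standard deviation matches that of the conditional prediction, then blends by `guidance_rescale`; a short sketch of the formula (illustrative function name):

```python
import jax.numpy as jnp

def rescale_noise_cfg(noise_cfg, noise_cond, guidance_rescale):
  """Rescale the CFG output per Sec. 3.4 of arXiv:2305.08891 to reduce overexposure."""
  axes = tuple(range(1, noise_cfg.ndim))                     # all non-batch axes
  std_cond = jnp.std(noise_cond, axis=axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=axes, keepdims=True)
  rescaled = noise_cfg * (std_cond / std_cfg)
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```

With the default `guidance_rescale: 0.0` this reduces to the unmodified CFG prediction.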
num_inference_steps: 30
num_inference_steps: 50
fps: 16
save_final_checkpoint: False

7 changes: 7 additions & 0 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -302,6 +302,13 @@ flow_shift: 3.0
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
use_cfg_cache: False

# Batch positive and negative prompts in text encoder to save compute.
use_batched_text_encoder: False

# Use jax.lax.scan for the diffusion loop (non-cache path only).
# Note: Enabling this will disable per-step profiling.
scan_diffusion_loop: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
39 changes: 24 additions & 15 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -44,7 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: 1
vae_spatial: -1 # default to total_device * 2 // (dp)

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -53,6 +53,9 @@ vae_spatial: 1
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True
# Use jax.lax.scan for the diffusion loop (non-cache path only).
# Note: Enabling this will disable per-step profiling.
scan_diffusion_loop: False

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
@@ -61,20 +64,21 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
dropout: 0.0

# If mask_padding_tokens is True, we pass segment ids to splash attention so it does not attend to padding tokens.
# If False, segment ids are omitted, which is faster on VPU-bound hardware such as Trillium.
# However, when padding tokens are significant, omitting the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True
dropout: 0.0
attention_sharding_uniform: True

flash_block_sizes: {
"block_q" : 512,
@@ -159,7 +163,7 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_length', 'context'],
['activation_heads', 'tensor'],
@@ -190,9 +194,9 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: 1
@@ -304,17 +308,17 @@ prompt: "A cat and a dog baking a cake together in a kitchen. The cat is careful
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
do_classifier_free_guidance: True
height: 480
width: 832
height: 720
width: 1280
num_frames: 81
flow_shift: 3.0
flow_shift: 5.0

# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
# guidance scale factor for low noise transformer
guidance_scale_low: 3.0

# guidance scale factor for high noise transformer
guidance_scale_high: 4.0

# The timestep threshold. If `t` is at or above this value,
# the `high_noise_model` is used; otherwise the low-noise transformer handles the step.
@@ -323,14 +327,19 @@ boundary_ratio: 0.875
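Putting the two guidance scales and `boundary_ratio` together, the per-step expert selection is roughly the following (the exact timestep scale and comparison direction are assumptions, not the pipeline's code):

```python
def select_expert(t, num_train_timesteps, boundary_ratio,
                  high_noise_model, low_noise_model,
                  guidance_scale_high, guidance_scale_low):
  """Pick the transformer and CFG scale for timestep t (sketch)."""
  boundary = boundary_ratio * num_train_timesteps   # e.g. 0.875 * 1000 = 875
  if t >= boundary:
    return high_noise_model, guidance_scale_high    # early, high-noise steps
  return low_noise_model, guidance_scale_low        # later, low-noise steps
```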

# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False

# Batch positive and negative prompts in text encoder to save compute.
use_batched_text_encoder: False


# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
# when predicted output change (based on accumulated latent/timestep drift) is small
use_sen_cache: False
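A heavily simplified sketch of the skip-or-run control flow the comment describes (the drift metric, state layout, and threshold handling here are placeholders, not the paper's actual sensitivity criterion, and the timestep component of the drift is omitted):

```python
import jax.numpy as jnp

def sencache_forward(model, latents, t, state, sen_thresh):
  """Reuse the cached prediction while accumulated latent drift stays small (sketch)."""
  if state["prev_latents"] is not None:
    state["drift"] += float(jnp.mean(jnp.abs(latents - state["prev_latents"])))
  state["prev_latents"] = latents

  if state["cached_out"] is not None and state["drift"] < sen_thresh:
    return state["cached_out"], state        # drift small: skip the forward pass
  out = model(latents, t)                    # drift large (or no cache): run model
  state["cached_out"], state["drift"] = out, 0.0
  return out, state
```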

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
num_inference_steps: 40
fps: 16
save_final_checkpoint: False

# SDXL Lightning parameters
27 changes: 20 additions & 7 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -52,6 +53,9 @@ replicate_vae: False
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True
# Use jax.lax.scan for the diffusion loop (non-cache path only).
# Note: Enabling this will disable per-step profiling.
scan_diffusion_loop: False

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
@@ -60,7 +64,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
@@ -69,7 +73,11 @@ dropout: 0.0
# If mask_padding_tokens is True, we pass segment ids to splash attention so it does not attend to padding tokens.
# If False, segment ids are omitted, which is faster on VPU-bound hardware such as Trillium.
# However, when padding tokens are significant, omitting the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {
@@ -184,13 +192,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
@@ -306,6 +314,11 @@ flow_shift: 5.0

# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False

# Batch positive and negative prompts in text encoder to save compute.
use_batched_text_encoder: False


# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
use_sen_cache: False
use_magcache: False
29 changes: 21 additions & 8 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -52,6 +53,9 @@ replicate_vae: False
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True
# Use jax.lax.scan for the diffusion loop (non-cache path only).
# Note: Enabling this will disable per-step profiling.
scan_diffusion_loop: False

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
@@ -60,7 +64,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
@@ -69,7 +73,11 @@ dropout: 0.0
# If mask_padding_tokens is True, we pass segment ids to splash attention so it does not attend to padding tokens.
# If False, segment ids are omitted, which is faster on VPU-bound hardware such as Trillium.
# However, when padding tokens are significant, omitting the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has 2 types of attention sharding strategies:
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {
@@ -185,13 +193,13 @@ data_sharding: [['data', 'fsdp', 'context', 'tensor']]
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_data_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
@@ -318,12 +326,17 @@ boundary_ratio: 0.875

# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False

# Batch positive and negative prompts in text encoder to save compute.
use_batched_text_encoder: False


# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
use_sen_cache: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50
num_inference_steps: 40
fps: 16
save_final_checkpoint: False

2 changes: 1 addition & 1 deletion src/maxdiffusion/generate_wan.py
@@ -304,11 +304,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
f"{'=' * 50}",
f" Load (checkpoint): {load_time:>7.1f}s",
f" Compile: {compile_time:>7.1f}s",
f" {'─' * 40}",
f" Inference: {generation_time:>7.1f}s",
]
if trace:
summary.extend([
f" {'─' * 40}",
f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",