feat(wan): Add text encoder batching and optional scan loop for diffusion#397
Open
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR introduces two valuable performance optimizations for the WAN pipeline: batched text encoding and a jax.lax.scan-based diffusion loop. These changes improve compute efficiency and reduce Python loop overhead during inference. The implementation is clean and integrates well with the existing architecture.
🔍 General Feedback
- Optimization Consistency: The batched text encoder logic correctly handles the partitioning of embeddings back into positive and negative sets, ensuring compatibility with the existing API.
- Robustness: I've identified one potentially unsafe access to the `config` object in the scan loop path, which could lead to a crash if `config` is `None`. A simple fix has been suggested.
- Performance: The use of `jax.lax.scan` for the non-cache path is a great addition for performance-sensitive workloads on TPU/GPU.
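
To make the robustness and performance points above concrete, here is a minimal sketch of a scan-based denoise loop with a `None`-safe flag check. It is not the code from this PR: `scan_enabled`, `high_noise_model`, `low_noise_model`, `denoise_step`, `run_loop`, and the `boundary` value are hypothetical stand-ins, and only the flag name `scan_diffusion_loop` and the use of `jax.lax.scan` / `jax.lax.cond` come from the PR description.

```python
import jax
import jax.numpy as jnp


def scan_enabled(config) -> bool:
    # getattr with a default also covers config being None.
    return bool(getattr(config, "scan_diffusion_loop", False))


def high_noise_model(x, t):  # hypothetical stand-in for the first transformer
    return x - 0.10 * t * x


def low_noise_model(x, t):  # hypothetical stand-in for the second transformer
    return x - 0.05 * t * x


def denoise_step(latents, t, boundary=0.5):
    # lax.cond selects one branch per step and stays traceable inside scan.
    return jax.lax.cond(t >= boundary, high_noise_model, low_noise_model, latents, t)


def run_loop(latents, timesteps, config=None):
    if scan_enabled(config):
        def body(carry, t):
            return denoise_step(carry, t), None  # no per-step outputs collected

        final, _ = jax.lax.scan(body, latents, timesteps)
        return final
    for t in timesteps:  # plain Python loop, one dispatch per step
        latents = denoise_step(latents, t)
    return latents
```

Under these assumptions both paths compute the same result; the scan path traces the step function once, which is also why per-step profiling from Python is no longer possible on that path.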
mbohlool
requested changes
May 5, 2026
Author (Collaborator): Done! PTAL @mbohlool
This PR introduces two main optimizations for the WAN pipelines (T2V 2.1/2.2 and I2V 2.1/2.2) to improve performance and resource utilization, and adds timing instrumentation:
- Text encoder batching: `encode_prompt` batches the positive and negative prompts when `use_batched_text_encoder` is enabled in the config.
- Optional scan loop: uses `jax.lax.scan` for the non-cache path of the diffusion process in all four main WAN pipelines. This avoids Python loop overhead while remaining compatible with `scan_layers: true` at the layer level. For WAN 2.2 pipelines, it uses `jax.lax.cond` to switch between the dual transformers at each step.
- Timing instrumentation (a `trace` dictionary) added to all pipelines to support the `TIMING SUMMARY` printout in `generate_wan.py`, providing visibility into Conditioning, Denoise Total, and VAE Decode times.

Changes
maxdiffusion/pipelines/wan
[MODIFY] wan_pipeline.py
- `encode_prompt`: batches positive and negative prompts when `use_batched_text_encoder` is enabled in the config.
[MODIFY] wan_pipeline_2_2.py, wan_pipeline_i2v_2p2.py
- `run_inference` methods: use `jax.lax.scan` and `jax.lax.cond`.
[MODIFY] wan_pipeline_2_1.py, wan_pipeline_i2v_2p1.py
- `jax.lax.scan` (without needing `lax.cond`, as they use a single transformer).
- `trace` dictionary returned from `__call__` to support the timing summary.
maxdiffusion/configs
[MODIFY] All 5 WAN config files (`base_wan_*.yml`)
- `use_batched_text_encoder: False` by default.
- `scan_diffusion_loop: False` by default, with a warning that enabling it will disable per-step profiling.

Generation Time
Environment & Configuration:
Command: https://paste.googleplex.com/6221970925551616
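
For readers unfamiliar with the text encoder batching listed under Changes, the idea can be sketched roughly as follows. This is an illustration only, not the PR's `encode_prompt`: `toy_text_encoder` and `encode_prompts` are hypothetical names, the encoder is a toy lookup table, and the sketch assumes positive and negative token ids are already padded to the same sequence length. Only the flag name `use_batched_text_encoder` comes from the description above.

```python
import jax.numpy as jnp


def toy_text_encoder(token_ids):
    # Hypothetical stand-in for the real text encoder:
    # (batch, seq) int ids -> (batch, seq, dim) embeddings.
    table = jnp.arange(32 * 4, dtype=jnp.float32).reshape(32, 4)
    return table[token_ids]


def encode_prompts(pos_ids, neg_ids, use_batched_text_encoder=False):
    if use_batched_text_encoder:
        # One encoder call on the concatenated batch, then split back
        # into positive and negative embeddings.
        both = toy_text_encoder(jnp.concatenate([pos_ids, neg_ids], axis=0))
        pos_emb, neg_emb = jnp.split(both, [pos_ids.shape[0]], axis=0)
        return pos_emb, neg_emb
    # Default path: two separate encoder calls.
    return toy_text_encoder(pos_ids), toy_text_encoder(neg_ids)
```

Both paths return the same pair of arrays, so callers see the same interface regardless of the flag; the batched path simply issues one encoder call instead of two.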