refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417
refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417prishajain1 wants to merge 1 commit into
Conversation
4696256 to
11ddfef
Compare
|
🤖 Hi @prishajain1, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @prishajain1, but I was unable to process your request. Please see the logs for more details. |
f58fb9e to
b53d7d2
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
entrpn
left a comment
There was a problem hiding this comment.
@prishajain1 did you update the readme with the best values based on remat?
…ng, and weight loading improvements
b53d7d2 to
c7e492a
Compare
Updated |
Overview
This PR refactors the Flux model architecture in MaxDiffusion to support scanned blocks (nn.scan) for double and single blocks, implements configurable gradient checkpointing (rematerialization) policies, and updates the weights loader to support loading pretrained checkpoints under the scanned format.
Key Changes
MlpAndOutputBlockwrapper) to eliminate redundant recomputation of attention and projection outputs.jnp.splitacross Flux transformer blocks for cleaner layout constraints.nn.scanto optimize compiler tracing and step execution speed on TPUs.FLUX_OPTIMIZEDtoGradientCheckpointTypeto allow configuring block-specific rematerialization policies dynamically via configuration files instead of being hardcoded.util.py) to slice, group, and stack PyTorch checkpoint weights along axis 0 to match the expected format ofnn.scanlayers.