[flux.2 LoRA] make lora training compatible with flux.2 klein kv #13325
linoytsaban wants to merge 21 commits into huggingface:main from
Conversation
…transformer is None (e.g. when initializing the pipeline as a text encoding pipeline)
Style bot fixed some files and pushed the changes.
sayakpaul
left a comment
Thanks for getting started on this. Could we see some results across different setups?
examples/dreambooth/train_dreambooth_lora_flux2_klein_kv_img2img.py
    return images

...

def module_filter_fn(mod: torch.nn.Module, fqn: str):
Maybe for a different PR: we could move it to training_utils.py and name it module_filter_fn_torchao?
maybe do that in a separate refactor PR, since this also persists in other lora scripts?
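For reference, a minimal sketch of what a shared helper in training_utils.py could look like. The name module_filter_fn_torchao follows the suggestion above; the skip heuristics and threshold are illustrative assumptions, not the script's actual logic:

```python
import torch

def module_filter_fn_torchao(mod: torch.nn.Module, fqn: str) -> bool:
    # torchao's quantize_() accepts a filter of this shape and only
    # quantizes modules for which it returns True.
    if not isinstance(mod, torch.nn.Linear):
        return False
    # Skip very small projections; low-precision kernels rarely help there.
    # (The threshold of 16 is illustrative, not taken from the PR.)
    if mod.in_features < 16 or mod.out_features < 16:
        return False
    return True
```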
    return batch

...

class BucketBatchSampler(BatchSampler):
Can probably be moved to training_utils.py?
maybe do that in a separate refactor PR, since this also persists in other lora scripts?
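For readers following along, the idea behind a bucket batch sampler is that every batch draws only from one bucket (e.g. one aspect ratio), so samples in a batch share a shape. A framework-free sketch of the pattern (an illustration, not the script's actual class):

```python
import random
from collections import defaultdict

class BucketBatchSampler:
    """Group sample indices by bucket key (e.g. aspect ratio) so each
    batch contains only samples from a single bucket."""

    def __init__(self, bucket_keys, batch_size, drop_last=False, seed=0):
        self.buckets = defaultdict(list)
        for idx, key in enumerate(bucket_keys):
            self.buckets[key].append(idx)
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.rng = random.Random(seed)

    def __iter__(self):
        batches = []
        for indices in self.buckets.values():
            self.rng.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        self.rng.shuffle(batches)  # mix buckets across the epoch
        yield from batches

    def __len__(self):
        total = 0
        for indices in self.buckets.values():
            full, rem = divmod(len(indices), self.batch_size)
            total += full + (1 if rem and not self.drop_last else 0)
        return total
```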
# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
    "to_qkv_mlp_proj",
    *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
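For context, lists like the one above expand to per-block module names that a LoRA config (e.g. peft's LoraConfig(target_modules=...)) matches against. A small sketch of the expansion, using the block count of 24 from the snippet (the helper name is hypothetical):

```python
def build_target_modules(num_single_blocks: int = 24) -> list[str]:
    # Attention projections matched across all transformer blocks,
    # plus the fused qkv/mlp projection and the per-block attention
    # output of each single-stream block.
    target_modules = ["to_k", "to_q", "to_v", "to_out.0", "to_qkv_mlp_proj"]
    target_modules += [
        f"single_transformer_blocks.{i}.attn.to_out"
        for i in range(num_single_blocks)
    ]
    return target_modules
```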
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
What fix was needed for the cache keying of the precomputed latents?
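For context, precomputation like this is usually a build-the-cache-once pattern; the keying question matters because keying by a non-unique value (e.g. a repeated prompt) silently collides. A torch-free sketch of the pattern, with illustrative names:

```python
def build_latent_cache(images, encode_fn, key_fn):
    # key_fn must uniquely identify each image (e.g. its file path);
    # keying by prompt alone breaks when prompts repeat across images.
    cache = {}
    for image in images:
        cache[key_fn(image)] = encode_fn(image)
    return cache
```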
    sigma = sigma.unsqueeze(-1)
    return sigma

...

def calculate_shift(
Could use it from the pipeline itself?
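For reference, the helper as it appears in diffusers' Flux pipelines computes a sequence-length-dependent shift mu by linear interpolation; the sketch below is from memory and the default values are assumptions that may differ per pipeline:

```python
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Linearly interpolate the shift between base_shift (short sequences)
    # and max_shift (long sequences, i.e. higher resolutions).
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
```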
# When caption dropout triggers, we replace the real prompt embedding with this.
# Note: empty_prompt_embeds and empty_text_ids are computed above before the text encoder is freed.
if args.caption_dropout_rate > 0.0:
    if empty_prompt_embeds is None:
This will probably not be the case since we're already doing:

    if args.caption_dropout_rate > 0.0:
        logger.info("Pre-computing empty prompt embeddings for caption dropout...")
        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
            empty_prompt_embeds, empty_text_ids = compute_text_embeddings("", text_encoding_pipeline)

# Clone when caption dropout is active to avoid mutating the cache.
if args.caption_dropout_rate > 0.0:
    prompt_embeds = prompt_embeds_cache[cache_key].clone()
    text_ids = text_ids_cache[cache_key].clone()
else:
    prompt_embeds = prompt_embeds_cache[cache_key]
    text_ids = text_ids_cache[cache_key]
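The clone guard matters because the cache stores references: without it, an in-place caption-dropout replacement would corrupt the cached embedding for every later step. A torch-free illustration of the failure mode it prevents (a list stands in for a tensor; names are hypothetical):

```python
cache = {"a photo of sks dog": [0.1, 0.2, 0.3]}  # toy "embedding" cache

def fetch(key, caption_dropout_active):
    emb = cache[key]
    if caption_dropout_active:
        emb = emb.copy()  # analogue of tensor.clone(); protects the cache
    return emb

emb = fetch("a photo of sks dog", caption_dropout_active=True)
emb[:] = [0.0, 0.0, 0.0]  # dropout overwrites the prompt embedding in place
# The cached entry is untouched because we cloned first.
```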
…mg.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Adds a flux.2 klein kv training script, with additional general changes that can be propagated to other LoRA scripts as well:
Default aspect ratio buckets - uses preset aspect ratio buckets to avoid needing to pass them manually.
--caption_dropout_rate - Randomly replaces prompts with empty strings at the given rate (default 0.05). Forces the model to learn from the visual signal alone on some steps, improving robustness.
--shift_timesteps - Resolution-adaptive timestep sampling. Samples t from a sigmoid distribution, then warps it with t' = (t·μ)/(1 + (μ−1)·t), where μ scales with the latent sequence length, so higher resolutions get more high-noise training. This is the default behaviour in popular trainers like ai-toolkit and kohya.
Cache keying fix - fixes a bug in the latent/prompt cache keying.
Multiple conditions support - allows for multiple image conditions per example.
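The --shift_timesteps warp described above can be sketched as follows; this is a minimal illustration built from the formula in the description, not the script's exact code, and the sampling helper's names are assumptions:

```python
import math

def shift_timestep(t: float, mu: float) -> float:
    # Warp t' = (t * mu) / (1 + (mu - 1) * t). With mu > 1 mass shifts
    # toward t = 1 (high noise); mu = 1 is the identity. The endpoints
    # t = 0 and t = 1 are fixed points for any mu.
    return (t * mu) / (1.0 + (mu - 1.0) * t)

def sample_sigmoid_timestep(rng) -> float:
    # Sample from a logit-normal distribution: sigmoid of a standard
    # Gaussian draw. rng only needs a .gauss(mu, sigma) method.
    return 1.0 / (1.0 + math.exp(-rng.gauss(0.0, 1.0)))
```

Combined, a training step would draw t with sample_sigmoid_timestep and warp it with shift_timestep, where mu is derived from the latent sequence length of the current resolution.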