[flux.2 LoRA] make lora training compatible with flux.2 klein kv #13325

Open
linoytsaban wants to merge 21 commits into huggingface:main from linoytsaban:klein-kv-train

Conversation

@linoytsaban
Collaborator

@linoytsaban linoytsaban commented Mar 24, 2026

Adds a Flux.2 Klein KV training script, with additional general changes that can be propagated to the other LoRA scripts as well:

  • Default aspect ratio buckets - uses preset aspect ratio buckets so they no longer need to be passed manually.

  • --caption_dropout_rate - randomly replaces prompts with empty strings at the given rate (default 0.05), forcing the model to learn from visual signal alone on some steps to improve robustness.

  • --shift_timesteps - resolution-adaptive timestep sampling. Samples from a sigmoid distribution, then warps with t' = (t·μ)/(1+(μ-1)·t), where μ scales with the latent sequence length, so higher resolutions get more high-noise training. This is the default behaviour in popular trainers like ai-toolkit and kohya.

  • Cache keying fix - fixes a bug in how the latent/prompt caches were keyed.

  • Multiple conditions support - allows multiple image conditions per example.
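The timestep warp described above can be sketched as follows. This is a minimal illustration of the stated formula only, not the PR's actual implementation; the helper names and the constants used to scale μ with sequence length are assumptions:

```python
import math
import random

def shift_timestep(t: float, mu: float) -> float:
    # Warp a timestep t in [0, 1] with t' = (t * mu) / (1 + (mu - 1) * t).
    # mu > 1 pushes samples toward t = 1 (high noise); mu = 1 is the identity.
    return (t * mu) / (1.0 + (mu - 1.0) * t)

def sample_shifted_timestep(seq_len: int, base_len: int = 256, max_len: int = 4096,
                            base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    # Sample t from a sigmoid distribution, then warp it. Interpolating the
    # shift linearly with sequence length mirrors common resolution-adaptive
    # schedules; the exact constants here are illustrative assumptions.
    m = (max_shift - base_shift) / (max_len - base_len)
    mu = math.exp(m * (seq_len - base_len) + base_shift)
    t = 1.0 / (1.0 + math.exp(-random.gauss(0.0, 1.0)))  # sigmoid of a normal sample
    return shift_timestep(t, mu)
```

Since μ grows with `seq_len`, larger latents see proportionally more high-noise timesteps, which is the behaviour the flag is meant to reproduce.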

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linoytsaban linoytsaban marked this pull request as ready for review March 31, 2026 10:41
@github-actions
Contributor

github-actions bot commented Mar 31, 2026

Style bot fixed some files and pushed the changes.

@linoytsaban linoytsaban requested a review from sayakpaul March 31, 2026 18:07
Member

@sayakpaul sayakpaul left a comment


Thanks for getting started on this. Could we see some results across different setups?

return images


def module_filter_fn(mod: torch.nn.Module, fqn: str):
Member


Maybe for a different PR: we could move it to training_utils.py and name it module_filter_fn_torchao?

Collaborator Author


maybe do that in a separate refactor PR, since this also appears in the other lora scripts?

return batch


class BucketBatchSampler(BatchSampler):
Member


Can probably be moved to training_utils.py?

Collaborator Author


maybe do that in a separate refactor PR, since this also appears in the other lora scripts?

# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
    "to_qkv_mlp_proj",
    *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
Member


⚡️

# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
# we encode them in advance as well.
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
Member


What was the fix needed to ensure the keying fix for precomputed latents?

sigma = sigma.unsqueeze(-1)
return sigma

def calculate_shift(
Member


Could use it from the pipeline itself?

# When caption dropout triggers, we replace the real prompt embedding with this.
# Note: empty_prompt_embeds and empty_text_ids are computed above before the text encoder is freed.
if args.caption_dropout_rate > 0.0:
    if empty_prompt_embeds is None:
Member


This will probably not be the case since we're already doing:

if args.caption_dropout_rate > 0.0:
    logger.info("Pre-computing empty prompt embeddings for caption dropout...")
    with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
        empty_prompt_embeds, empty_text_ids = compute_text_embeddings("", text_encoding_pipeline)
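As a side note on what the per-step dropout decision itself typically looks like, here is a minimal, self-contained sketch; the helper name and structure are assumptions for illustration, not the PR's actual code:

```python
import random

def maybe_drop_caption(prompt: str, dropout_rate: float, rng: random.Random) -> str:
    # With probability dropout_rate, replace the prompt with an empty string so
    # the model trains on visual signal alone for that step.
    return "" if rng.random() < dropout_rate else prompt

# Rough sanity check: at a 0.05 rate, roughly 5% of prompts get dropped.
rng = random.Random(0)
prompts = ["a photo of sks dog"] * 1000
dropped = sum(1 for p in prompts if maybe_drop_caption(p, 0.05, rng) == "")
```

Precomputing the empty-prompt embedding once, as the quoted code does, means the dropped steps can reuse it instead of re-running the text encoder.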

Comment on lines +1734 to +1740
# Clone when caption dropout is active to avoid mutating the cache.
if args.caption_dropout_rate > 0.0:
    prompt_embeds = prompt_embeds_cache[cache_key].clone()
    text_ids = text_ids_cache[cache_key].clone()
else:
    prompt_embeds = prompt_embeds_cache[cache_key]
    text_ids = text_ids_cache[cache_key]
Member


Why is this needed?
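The comment in the quoted diff ("avoid mutating the cache") can be illustrated with a minimal sketch. Plain Python lists stand in for tensors here, and all names are hypothetical:

```python
# Hypothetical cache of precomputed "embeddings" (plain lists stand in for tensors).
cache = {"a photo of sks dog": [0.1, 0.2, 0.3]}

def fetch(key, clone=False):
    value = cache[key]
    return list(value) if clone else value  # list() plays the role of tensor.clone()

# Without cloning, an in-place overwrite (as caption dropout does) corrupts the cache:
embeds = fetch("a photo of sks dog")
embeds[:] = [0.0, 0.0, 0.0]
assert cache["a photo of sks dog"] == [0.0, 0.0, 0.0]  # the cached value was mutated

# Fetch a copy instead and the cache stays intact for later steps:
cache["a photo of sks dog"] = [0.1, 0.2, 0.3]
embeds = fetch("a photo of sks dog", clone=True)
embeds[:] = [0.0, 0.0, 0.0]
assert cache["a photo of sks dog"] == [0.1, 0.2, 0.3]
```

The same hazard applies to cached torch tensors: returning them by reference and then overwriting them in place would poison the cache for every subsequent step, which is presumably what the `.clone()` guards against.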

@github-actions github-actions bot added pipelines examples size/L PR with diff > 200 LOC labels Apr 9, 2026