Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions examples/dreambooth/test_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,29 @@ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)

def test_dreambooth_lora_sdxl_snr_gamma_with_prior_preservation(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 2
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--snr_gamma 5.0
--with_prior_preservation
--class_data_dir docs/source/en/imgs
--class_prompt photo
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# Verify training completed and produced a valid LoRA weights file.
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
6 changes: 6 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()

if args.with_prior_preservation:
# Apply the same SNR weighting to the prior loss for consistency.
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
prior_loss = prior_loss.mean(dim=list(range(1, len(prior_loss.shape)))) * mse_loss_weights
prior_loss = prior_loss.mean()

if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
Expand Down
Loading