diff --git a/examples/dreambooth/test_dreambooth_lora.py b/examples/dreambooth/test_dreambooth_lora.py index e950807d372d..b646d47f4171 100644 --- a/examples/dreambooth/test_dreambooth_lora.py +++ b/examples/dreambooth/test_dreambooth_lora.py @@ -377,3 +377,29 @@ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit # checkpoint-2 should have been deleted {"checkpoint-4", "checkpoint-6"}, ) + + def test_dreambooth_lora_sdxl_snr_gamma_with_prior_preservation(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 2 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --snr_gamma 5.0 + --with_prior_preservation + --class_data_dir docs/source/en/imgs + --class_prompt photo + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # Verify training completed and produced a valid LoRA weights file. + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ac8dd9243df6..e3efbd2a3163 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1853,6 +1853,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() + if args.with_prior_preservation: + # Apply the same SNR weighting to the prior loss for consistency. + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") + prior_loss = prior_loss.mean(dim=list(range(1, len(prior_loss.shape)))) * mse_loss_weights + prior_loss = prior_loss.mean() + if args.with_prior_preservation: # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss