Question about Stage 2 loss implementation #4

@xzy-ustc

Hi, thank you for releasing this project.

I am a bit confused about the Stage 2 training objective and would appreciate your clarification.

From my understanding of the paper, the Stage 2 / Diffusion-FT loss is described as an image-space reconstruction loss, using a combination of L1 and L2 losses between the predicted image and the target image.

However, in the current implementation, Stage 2 appears to still use the diffusion noise prediction MSE loss:

# src/diffvs/train_stage2_diffusion_ft.py

pred = unet(
    model_input,
    timesteps,
    encoder_hidden_states=marker_context,
    return_dict=False,
)[0]

loss = F.mse_loss(pred.float(), noise.float(), reduction="mean")

So the current Stage 2 objective seems to be:

MSE(predicted_noise, true_noise)

rather than:

L1(predicted_image, target_image) + L2(predicted_image, target_image)

I also noticed that the main difference from Stage 1 is that Stage 2 uses a fixed timestep:

timesteps = torch.full(
    (target_latents.shape[0],),
    int(args.single_step_timestep),
    ...
)

while Stage 1 samples random timesteps.
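
For reference, here is how I understand the relation between the two objectives. This assumes the standard DDPM forward process and epsilon-prediction, which I have not verified against the scheduler configuration used in this repository:

```latex
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon
\qquad\Longrightarrow\qquad
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \hat\epsilon}{\sqrt{\bar\alpha_t}}

\hat{x}_0 - x_0 = -\frac{\sqrt{1 - \bar\alpha_t}}{\sqrt{\bar\alpha_t}}\, (\hat\epsilon - \epsilon)
\qquad\Longrightarrow\qquad
\lVert \hat{x}_0 - x_0 \rVert_2^2 = \frac{1 - \bar\alpha_t}{\bar\alpha_t}\, \lVert \hat\epsilon - \epsilon \rVert_2^2
```

Under that assumption, with a fixed timestep the noise MSE is a constant multiple of a latent-space L2 loss on pred_x0. It still differs from the paper's description in two ways: there is no L1 term, and the loss is computed on latents rather than on images decoded by the VAE.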

Could you please clarify which objective is intended?

Specifically:

  1. Is the released Stage 2 implementation intended to use noise prediction MSE instead of the image-space L1 + L2 loss described in the paper?
  2. If the paper loss should be used, should the implementation first reconstruct pred_x0 from the predicted noise, decode it with the VAE, and then compute image-space L1/L2 against the target image?
  3. Is this repository implementing a simplified or modified version of the Stage 2 objective?
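
To make question 2 concrete, below is a minimal sketch of what I have in mind. It is not taken from the repository: `predict_x0`, `image_space_loss`, the `decode` callable, and the two loss weights are hypothetical names I made up for illustration, and the sketch assumes epsilon-prediction with the standard DDPM forward process. The toy check uses an identity function in place of the VAE decoder.

```python
import torch
import torch.nn.functional as F


def predict_x0(noisy_latents, pred_noise, alpha_bar_t):
    """Invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0 (epsilon-prediction assumed)."""
    a = alpha_bar_t.view(-1, 1, 1, 1)  # one cumulative alpha per batch item
    return (noisy_latents - (1.0 - a).sqrt() * pred_noise) / a.sqrt()


def image_space_loss(pred_x0_latents, target_images, decode, l1_weight=1.0, l2_weight=1.0):
    """L1 + L2 between the decoded prediction and the target image.

    `decode` maps latents to images; the weights are hypothetical, not from the paper.
    """
    pred_images = decode(pred_x0_latents)
    l1 = F.l1_loss(pred_images.float(), target_images.float(), reduction="mean")
    l2 = F.mse_loss(pred_images.float(), target_images.float(), reduction="mean")
    return l1_weight * l1 + l2_weight * l2


# Toy check: with a perfect noise prediction, pred_x0 recovers x0 and the loss vanishes.
torch.manual_seed(0)
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
alpha_bar = torch.tensor([0.5, 0.9])
a = alpha_bar.view(-1, 1, 1, 1)
x_t = a.sqrt() * x0 + (1.0 - a).sqrt() * eps  # forward diffusion

recovered = predict_x0(x_t, eps, alpha_bar)
loss = image_space_loss(recovered, x0, decode=lambda z: z)  # identity stands in for the VAE
```

In the actual training script I would expect `alpha_bar_t` to come from the noise scheduler at the fixed timestep and `decode` to wrap the VAE decoder (including any latent scaling factor), but I have not checked how those objects are named or configured here.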

Thank you!
