Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1610,23 +1610,18 @@ class StableDiffusionGGML {
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
void* step_callback_data,
bool is_noisy) {
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
int channels = get_latent_channel();
auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels)
: sd::ops::slice(latents, 2, 0, channels)
: latents;
if (preview_mode == PREVIEW_PROJ) {
sd::Tensor<float> _latents = latents;
int patch_sz = 1;
const float(*latent_rgb_proj)[3] = nullptr;
float* latent_rgb_bias = nullptr;
bool is_video = preview_latent_tensor_is_video(latents);
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
if (version == VERSION_LTXAV) {
if (is_video) {
_latents = sd::ops::slice(_latents, 3, 0, 128);
} else {
_latents = sd::ops::slice(_latents, 2, 0, 128);
}
dim = 128;
}

if (dim == 128) {
if (channels == 128) {
if (sd_version_uses_flux2_vae(version)) {
latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias;
Expand All @@ -1638,15 +1633,15 @@ class StableDiffusionGGML {
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 48) {
} else if (channels == 48) {
if (sd_version_is_wan(version)) {
latent_rgb_proj = wan_22_latent_rgb_proj;
latent_rgb_bias = wan_22_latent_rgb_bias;
} else {
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 16) {
} else if (channels == 16) {
if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias;
Expand All @@ -1660,7 +1655,7 @@ class StableDiffusionGGML {
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 4) {
} else if (channels == 4) {
if (sd_version_is_sdxl(version)) {
latent_rgb_proj = sdxl_latent_rgb_proj;
latent_rgb_bias = sdxl_latent_rgb_bias;
Expand All @@ -1671,8 +1666,8 @@ class StableDiffusionGGML {
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim != 3) {
LOG_WARN("No latent to RGB projection known for this model");
} else if (channels != 3) {
LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim);
return;
}

Expand All @@ -1697,14 +1692,13 @@ class StableDiffusionGGML {
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
sd::Tensor<float> vae_latents;
sd::Tensor<float> decoded;
bool is_video = preview_latent_tensor_is_video(latents);
if (preview_vae) {
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
vae_latents = preview_vae->diffusion_to_vae_latents(_latents);
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
} else {
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
vae_latents = first_stage_model->diffusion_to_vae_latents(_latents);
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
}
if (decoded.empty()) {
Expand Down
Loading