From f75bc9d5522ded44e6991da1ba25009073d1fc1e Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 May 2026 00:17:54 +0800 Subject: [PATCH] feat: add LTX rational latent upscaler --- src/ltx_latent_upscaler.hpp | 234 ++++++++++++++++++++++++++++++------ src/stable-diffusion.cpp | 2 +- 2 files changed, 198 insertions(+), 38 deletions(-) diff --git a/src/ltx_latent_upscaler.hpp b/src/ltx_latent_upscaler.hpp index 93254454d..1cdc02282 100644 --- a/src/ltx_latent_upscaler.hpp +++ b/src/ltx_latent_upscaler.hpp @@ -6,8 +6,10 @@ #include #include #include +#include #include #include +#include #include "common_dit.hpp" #include "ggml_extend.hpp" @@ -26,6 +28,9 @@ namespace LTXVUpsampler { bool spatial_upsample = true; bool temporal_upsample = false; bool rational_resampler = false; + float spatial_scale = 2.f; + int spatial_up_num = 2; + int spatial_down_den = 1; }; static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, @@ -33,14 +38,21 @@ namespace LTXVUpsampler { return tensor_storage_map.find(name) != tensor_storage_map.end(); } - static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map, - const std::string& name, - int64_t fallback) { + static inline int64_t get_tensor_ne(const String2TensorStorage& tensor_storage_map, + const std::string& name, + int axis, + int64_t fallback) { auto it = tensor_storage_map.find(name); - if (it == tensor_storage_map.end()) { + if (it == tensor_storage_map.end() || axis < 0 || axis >= GGML_MAX_DIMS) { return fallback; } - return it->second.ne[0]; + return it->second.ne[axis]; + } + + static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map, + const std::string& name, + int64_t fallback) { + return get_tensor_ne(tensor_storage_map, name, 0, fallback); } static inline int count_module_blocks(const String2TensorStorage& tensor_storage_map, @@ -71,8 +83,32 @@ namespace LTXVUpsampler { if (detected_blocks > 0) { config.num_blocks_per_stage = detected_blocks; } - config.spatial_upsample = has_tensor(tensor_storage_map, "upsampler.0.weight"); - config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight"); + config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight"); + config.spatial_upsample = config.rational_resampler || has_tensor(tensor_storage_map, "upsampler.0.weight"); + config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight"); + if (config.rational_resampler) { + int64_t out_channels = get_tensor_ne(tensor_storage_map, + "upsampler.conv.weight", + 3, + config.mid_channels * 9); + if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) { + int64_t ratio = out_channels / config.mid_channels; + int num = static_cast(std::round(std::sqrt(static_cast(ratio)))); + if (num > 0 && static_cast(num) * num == ratio) { + config.spatial_up_num = num; + } + } + if (config.spatial_up_num == 3) { + config.spatial_down_den = 2; + config.spatial_scale = 1.5f; + } else if (config.spatial_up_num == 4) { + config.spatial_down_den = 1; + config.spatial_scale = 4.f; + } else { + config.spatial_down_den = 1; + config.spatial_scale = static_cast(config.spatial_up_num); + } + } return config; } @@ -160,16 +196,111 @@ namespace LTXVUpsampler { : upscale_factor(upscale_factor) {} ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { - GGML_ASSERT(upscale_factor == 2); + GGML_ASSERT(upscale_factor > 0); int64_t h = x->ne[1]; int64_t w = x->ne[0]; - // x: [b*f, c*4, h, w] -> [b*f, c, h*2, w*2] - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*4] - x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*4] + GGML_ASSERT(x->ne[2] % (upscale_factor * upscale_factor) == 0); + // x: [b*f, c*p1*p2, h, w] -> [b*f, c, h*p1, w*p2] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*p1*p2] + x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*p1*p2] return DiT::unpatchify(ctx->ggml_ctx, x, h, w, upscale_factor, upscale_factor, true); } }; + class BlurDownsample : public GGMLBlock { + protected: + int64_t channels; + int stride; + ggml_tensor* kernel = nullptr; + std::vector kernel_data; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + SD_UNUSED(tensor_storage_map); + if (stride == 1) { + return; + } + kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 5, 1, channels); + std::string name = prefix + "kernel"; + ggml_set_name(kernel, name.c_str()); + + static const float binomial[5] = {1.f, 4.f, 6.f, 4.f, 1.f}; + kernel_data.resize(static_cast(5 * 5 * channels)); + for (int64_t c = 0; c < channels; ++c) { + for (int y = 0; y < 5; ++y) { + for (int x = 0; x < 5; ++x) { + kernel_data[static_cast(x + 5 * (y + 5 * c))] = + binomial[y] * binomial[x] / 256.f; + } + } + } + } + + public: + BlurDownsample(int64_t channels, int stride) + : channels(channels), + stride(stride) { + GGML_ASSERT(stride >= 1); + } + + void load_fixed_tensors() { + if (kernel == nullptr || kernel_data.empty()) { + return; + } + ggml_backend_tensor_set(kernel, kernel_data.data(), 0, kernel_data.size() * sizeof(float)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + if (stride == 1) { + return x; + } + GGML_ASSERT(kernel != nullptr); + GGML_ASSERT(x->ne[2] == channels); + if (ctx->conv2d_direct_enabled) { + return ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1); + } + return ggml_conv_2d_dw(ctx->ggml_ctx, kernel, x, stride, stride, 2, 2, 1, 1); + } + }; + + class SpatialRationalResampler : public GGMLBlock { + protected: + int64_t mid_channels; + int num; + int den; + + public: + SpatialRationalResampler(int64_t mid_channels, int num, int den) + : mid_channels(mid_channels), + num(num), + den(den) { + GGML_ASSERT(num >= 1); + GGML_ASSERT(den >= 1); + blocks["conv"] = std::shared_ptr(new Conv2d(mid_channels, num * num * mid_channels, {3, 3}, {1, 1}, {1, 1})); + blocks["pixel_shuffle"] = std::shared_ptr(new PixelShuffleND(num)); + blocks["blur_down"] = std::shared_ptr(new BlurDownsample(mid_channels, den)); + } + + void load_fixed_tensors() { + auto blur_down = std::dynamic_pointer_cast(blocks["blur_down"]); + blur_down->load_fixed_tensors(); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["pixel_shuffle"]); + auto blur_down = std::dynamic_pointer_cast(blocks["blur_down"]); + + // rearrange(x, "b c f h w -> (b f) c h w") + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + x = conv->forward(ctx, x); + x = pixel_shuffle->forward(ctx, x); + x = blur_down->forward(ctx, x); + return ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + } + }; + class LatentUpsampler : public GGMLBlock { public: LatentUpsamplerConfig config; @@ -179,7 +310,6 @@ namespace LTXVUpsampler { GGML_ASSERT(this->config.dims == 3); GGML_ASSERT(this->config.spatial_upsample); GGML_ASSERT(!this->config.temporal_upsample); - GGML_ASSERT(!this->config.rational_resampler); blocks["initial_conv"] = std::shared_ptr(new Conv3d(this->config.in_channels, this->config.mid_channels, @@ -190,12 +320,18 @@ namespace LTXVUpsampler { for (int i = 0; i < this->config.num_blocks_per_stage; ++i) { blocks["res_blocks." + std::to_string(i)] = std::shared_ptr(new ResBlock(this->config.mid_channels, this->config.dims)); } - blocks["upsampler.0"] = std::shared_ptr(new Conv2d(this->config.mid_channels, - 4 * this->config.mid_channels, - {3, 3}, - {1, 1}, - {1, 1})); - blocks["upsampler.1"] = std::shared_ptr(new PixelShuffleND(2)); + if (this->config.rational_resampler) { + blocks["upsampler"] = std::shared_ptr(new SpatialRationalResampler(this->config.mid_channels, + this->config.spatial_up_num, + this->config.spatial_down_den)); + } else { + blocks["upsampler.0"] = std::shared_ptr(new Conv2d(this->config.mid_channels, + 4 * this->config.mid_channels, + {3, 3}, + {1, 1}, + {1, 1})); + blocks["upsampler.1"] = std::shared_ptr(new PixelShuffleND(2)); + } for (int i = 0; i < this->config.num_blocks_per_stage; ++i) { blocks["post_upsample_res_blocks." + std::to_string(i)] = std::shared_ptr(new ResBlock(this->config.mid_channels, this->config.dims)); } @@ -207,13 +343,11 @@ namespace LTXVUpsampler { } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { - // x: [b*c, f, h, w] - // return: [b*c, f, h*2, w*2] - auto initial_conv = std::dynamic_pointer_cast(blocks["initial_conv"]); - auto initial_norm = std::dynamic_pointer_cast(blocks["initial_norm"]); - auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); - auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); - auto final_conv = std::dynamic_pointer_cast(blocks["final_conv"]); + // x: [b, c, f, h, w] + // return: [b, c, f, scaled_h, scaled_w] + auto initial_conv = std::dynamic_pointer_cast(blocks["initial_conv"]); + auto initial_norm = std::dynamic_pointer_cast(blocks["initial_norm"]); + auto final_conv = std::dynamic_pointer_cast(blocks["final_conv"]); x = initial_conv->forward(ctx, x); x = initial_norm->forward(ctx, x); @@ -226,11 +360,19 @@ namespace LTXVUpsampler { sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.res_blocks." + std::to_string(i), "x"); } - // rearrange(x, "b c f h w -> (b f) c h w"), - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w] - x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w] - x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2] - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w] + if (config.rational_resampler) { + auto upsampler = std::dynamic_pointer_cast(blocks["upsampler"]); + x = upsampler->forward(ctx, x); + } else { + auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); + + // rearrange(x, "b c f h w -> (b f) c h w"), + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w] + x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w] + x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w] + } sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.spatial_up", "x"); for (int i = 0; i < config.num_blocks_per_stage; ++i) { @@ -243,6 +385,14 @@ namespace LTXVUpsampler { sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x"); return x; } + + void load_fixed_tensors() { + if (!config.rational_resampler) { + return; + } + auto upsampler = std::dynamic_pointer_cast(blocks["upsampler"]); + upsampler->load_fixed_tensors(); + } }; struct LatentUpsamplerRunner : public GGMLRunner { @@ -265,20 +415,23 @@ namespace LTXVUpsampler { } const auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + bool has_regular_spatial = has_tensor(tensor_storage_map, "upsampler.0.weight"); + bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight"); if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") || - !has_tensor(tensor_storage_map, "upsampler.0.weight")) { + (!has_regular_spatial && !has_rational_spatial)) { LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors"); return false; } LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map); if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample || - config.rational_resampler) { - LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d", + config.spatial_up_num < 1 || config.spatial_down_den < 1) { + LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f", config.dims, config.spatial_upsample, config.temporal_upsample, - config.rational_resampler); + config.rational_resampler, + config.spatial_scale); return false; } @@ -291,15 +444,22 @@ namespace LTXVUpsampler { std::map tensors; model->get_param_tensors(tensors); - if (!model_loader.load_tensors(tensors, {}, n_threads)) { + std::set ignore_tensors; + if (config.rational_resampler) { + ignore_tensors.insert("upsampler.blur_down.kernel"); + } + if (!model_loader.load_tensors(tensors, ignore_tensors, n_threads)) { LOG_ERROR("load LTX latent upsampler tensors failed"); return false; } + model->load_fixed_tensors(); - LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d", + LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, rational=%d", config.in_channels, config.mid_channels, - config.num_blocks_per_stage); + config.num_blocks_per_stage, + config.spatial_scale, + config.rational_resampler); return true; } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index eb6845b46..705f9f12a 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -4798,7 +4798,7 @@ static sd::Tensor upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx, audio_latent = unpack_ltxav_audio_latent(packed_latent, audio_length, latent_channels); } - LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> x2", + LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> model output", (int)video_latent.shape()[0], (int)video_latent.shape()[1], (int)video_latent.shape()[2],