Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 197 additions & 37 deletions src/ltx_latent_upscaler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#include <cstdlib>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "common_dit.hpp"
#include "ggml_extend.hpp"
Expand All @@ -26,21 +28,31 @@ namespace LTXVUpsampler {
bool spatial_upsample = true;
bool temporal_upsample = false;
bool rational_resampler = false;
float spatial_scale = 2.f;
int spatial_up_num = 2;
int spatial_down_den = 1;
};

// Reports whether `name` is a key of `tensor_storage_map`.
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
                              const std::string& name) {
    const auto entry = tensor_storage_map.find(name);
    return entry != tensor_storage_map.end();
}

static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
const std::string& name,
int64_t fallback) {
static inline int64_t get_tensor_ne(const String2TensorStorage& tensor_storage_map,
const std::string& name,
int axis,
int64_t fallback) {
auto it = tensor_storage_map.find(name);
if (it == tensor_storage_map.end()) {
if (it == tensor_storage_map.end() || axis < 0 || axis >= GGML_MAX_DIMS) {
return fallback;
}
return it->second.ne[0];
return it->second.ne[axis];
}

// Convenience wrapper: forwards to get_tensor_ne with the axis fixed to 0,
// passing `fallback` through unchanged.
static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
                                     const std::string& name,
                                     int64_t fallback) {
    constexpr int first_axis = 0;
    return get_tensor_ne(tensor_storage_map, name, first_axis, fallback);
}

static inline int count_module_blocks(const String2TensorStorage& tensor_storage_map,
Expand Down Expand Up @@ -71,8 +83,32 @@ namespace LTXVUpsampler {
if (detected_blocks > 0) {
config.num_blocks_per_stage = detected_blocks;
}
config.spatial_upsample = has_tensor(tensor_storage_map, "upsampler.0.weight");
config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight");
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
config.spatial_upsample = config.rational_resampler || has_tensor(tensor_storage_map, "upsampler.0.weight");
config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight");
if (config.rational_resampler) {
int64_t out_channels = get_tensor_ne(tensor_storage_map,
"upsampler.conv.weight",
3,
config.mid_channels * 9);
if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) {
int64_t ratio = out_channels / config.mid_channels;
int num = static_cast<int>(std::round(std::sqrt(static_cast<double>(ratio))));
if (num > 0 && static_cast<int64_t>(num) * num == ratio) {
config.spatial_up_num = num;
}
}
if (config.spatial_up_num == 3) {
config.spatial_down_den = 2;
config.spatial_scale = 1.5f;
} else if (config.spatial_up_num == 4) {
config.spatial_down_den = 1;
config.spatial_scale = 4.f;
} else {
config.spatial_down_den = 1;
config.spatial_scale = static_cast<float>(config.spatial_up_num);
}
}
return config;
}

Expand Down Expand Up @@ -160,16 +196,111 @@ namespace LTXVUpsampler {
: upscale_factor(upscale_factor) {}

ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
GGML_ASSERT(upscale_factor == 2);
GGML_ASSERT(upscale_factor > 0);
int64_t h = x->ne[1];
int64_t w = x->ne[0];
// x: [b*f, c*4, h, w] -> [b*f, c, h*2, w*2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*4]
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*4]
GGML_ASSERT(x->ne[2] % (upscale_factor * upscale_factor) == 0);
// x: [b*f, c*p1*p2, h, w] -> [b*f, c, h*p1, w*p2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [b*f, h, w, c*p1*p2]
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); // [b*f, h*w, c*p1*p2]
return DiT::unpatchify(ctx->ggml_ctx, x, h, w, upscale_factor, upscale_factor, true);
}
};

// Fixed 5x5 binomial blur applied as a strided depthwise convolution.
// The kernel values are generated in code (the tensor storage map is not
// consulted), and with stride == 1 the block is a pass-through that creates
// no tensor at all.
class BlurDownsample : public GGMLBlock {
protected:
    int64_t channels;
    int stride;
    ggml_tensor* kernel = nullptr;   // created in init_params when stride != 1
    std::vector<float> kernel_data;  // host-side staging copy of the kernel values

    // Creates the [5, 5, 1, channels] F32 kernel tensor named `prefix + "kernel"`
    // and fills the host staging buffer with the normalized binomial weights.
    void init_params(ggml_context* ctx,
                     const String2TensorStorage& tensor_storage_map = {},
                     const std::string prefix = "") override {
        SD_UNUSED(tensor_storage_map);
        if (stride == 1) {
            return;
        }

        kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 5, 1, channels);
        const std::string tensor_name = prefix + "kernel";
        ggml_set_name(kernel, tensor_name.c_str());

        // Outer product of [1 4 6 4 1] with itself, divided by 256 (= 16 * 16)
        // so the 25 taps sum to one. Every channel receives the same 5x5 patch.
        static const float binomial[5] = {1.f, 4.f, 6.f, 4.f, 1.f};
        kernel_data.resize(static_cast<size_t>(5 * 5 * channels));
        size_t cursor = 0;  // walks the buffer in (channel, row, col) order
        for (int64_t ch = 0; ch < channels; ++ch) {
            for (int row = 0; row < 5; ++row) {
                for (int col = 0; col < 5; ++col) {
                    kernel_data[cursor++] = binomial[row] * binomial[col] / 256.f;
                }
            }
        }
    }

public:
    BlurDownsample(int64_t channels, int stride)
        : channels(channels),
          stride(stride) {
        GGML_ASSERT(stride >= 1);
    }

    // Copies the staged kernel values into the kernel tensor. Does nothing when
    // no kernel was created (stride == 1) or nothing has been staged yet.
    void load_fixed_tensors() {
        const bool ready = (kernel != nullptr) && !kernel_data.empty();
        if (ready) {
            ggml_backend_tensor_set(kernel, kernel_data.data(), 0, kernel_data.size() * sizeof(float));
        }
    }

    // Returns `x` untouched for stride 1; otherwise builds the depthwise
    // convolution node (padding 2, dilation 1, stride `stride` on both axes),
    // using the direct variant when the runner context enables it.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        if (stride == 1) {
            return x;
        }
        GGML_ASSERT(kernel != nullptr);
        GGML_ASSERT(x->ne[2] == channels);

        const int pad      = 2;
        const int dilation = 1;
        return ctx->conv2d_direct_enabled
                   ? ggml_conv_2d_dw_direct(ctx->ggml_ctx, kernel, x, stride, stride, pad, pad, dilation, dilation)
                   : ggml_conv_2d_dw(ctx->ggml_ctx, kernel, x, stride, stride, pad, pad, dilation, dilation);
    }
};

// Spatial resampler built from three sub-blocks: a 3x3 conv that expands the
// channels by num * num, a pixel shuffle with factor `num`, and a
// BlurDownsample with stride `den`.
class SpatialRationalResampler : public GGMLBlock {
protected:
    int64_t mid_channels;
    int num;
    int den;

public:
    SpatialRationalResampler(int64_t mid_channels, int num, int den)
        : mid_channels(mid_channels),
          num(num),
          den(den) {
        GGML_ASSERT(num >= 1);
        GGML_ASSERT(den >= 1);

        const auto expanded_channels = num * num * mid_channels;
        blocks["conv"]          = std::shared_ptr<GGMLBlock>(new Conv2d(mid_channels, expanded_channels, {3, 3}, {1, 1}, {1, 1}));
        blocks["pixel_shuffle"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(num));
        blocks["blur_down"]     = std::shared_ptr<GGMLBlock>(new BlurDownsample(mid_channels, den));
    }

    // Forwards to the blur block so it can upload its code-generated kernel.
    void load_fixed_tensors() {
        std::dynamic_pointer_cast<BlurDownsample>(blocks["blur_down"])->load_fixed_tensors();
    }

    // Runs conv -> pixel shuffle -> blur/downsample, with the same axis
    // permutation applied before and after the three stages.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto gctx = ctx->ggml_ctx;
        // The identical permute is used on the way in and on the way out.
        const auto swap_axes = [gctx](ggml_tensor* t) {
            return ggml_ext_cont(gctx, ggml_ext_torch_permute(gctx, t, 0, 1, 3, 2));
        };

        auto conv_block    = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
        auto shuffle_block = std::dynamic_pointer_cast<PixelShuffleND>(blocks["pixel_shuffle"]);
        auto blur_block    = std::dynamic_pointer_cast<BlurDownsample>(blocks["blur_down"]);

        // rearrange(x, "b c f h w -> (b f) c h w")
        ggml_tensor* y = swap_axes(x);
        y              = conv_block->forward(ctx, y);
        y              = shuffle_block->forward(ctx, y);
        y              = blur_block->forward(ctx, y);
        return swap_axes(y);
    }
};

class LatentUpsampler : public GGMLBlock {
public:
LatentUpsamplerConfig config;
Expand All @@ -179,7 +310,6 @@ namespace LTXVUpsampler {
GGML_ASSERT(this->config.dims == 3);
GGML_ASSERT(this->config.spatial_upsample);
GGML_ASSERT(!this->config.temporal_upsample);
GGML_ASSERT(!this->config.rational_resampler);

blocks["initial_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.in_channels,
this->config.mid_channels,
Expand All @@ -190,12 +320,18 @@ namespace LTXVUpsampler {
for (int i = 0; i < this->config.num_blocks_per_stage; ++i) {
blocks["res_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResBlock(this->config.mid_channels, this->config.dims));
}
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
4 * this->config.mid_channels,
{3, 3},
{1, 1},
{1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(2));
if (this->config.rational_resampler) {
blocks["upsampler"] = std::shared_ptr<GGMLBlock>(new SpatialRationalResampler(this->config.mid_channels,
this->config.spatial_up_num,
this->config.spatial_down_den));
} else {
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
4 * this->config.mid_channels,
{3, 3},
{1, 1},
{1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new PixelShuffleND(2));
}
for (int i = 0; i < this->config.num_blocks_per_stage; ++i) {
blocks["post_upsample_res_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResBlock(this->config.mid_channels, this->config.dims));
}
Expand All @@ -207,13 +343,11 @@ namespace LTXVUpsampler {
}

ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
// x: [b*c, f, h, w]
// return: [b*c, f, h*2, w*2]
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
// x: [b, c, f, h, w]
// return: [b, c, f, scaled_h, scaled_w]
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);

x = initial_conv->forward(ctx, x);
x = initial_norm->forward(ctx, x);
Expand All @@ -226,11 +360,19 @@ namespace LTXVUpsampler {
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.res_blocks." + std::to_string(i), "x");
}

// rearrange(x, "b c f h w -> (b f) c h w"),
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w]
x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w]
x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w]
if (config.rational_resampler) {
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
x = upsampler->forward(ctx, x);
} else {
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);

// rearrange(x, "b c f h w -> (b f) c h w"),
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*f, c, h, w]
x = upsample_conv->forward(ctx, x); // [b*f, c*4, h, w]
x = pixel_shuffle->forward(ctx, x); // [b*f, c, h*2, w*2]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // [b*c, f, h, w]
}
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.spatial_up", "x");

for (int i = 0; i < config.num_blocks_per_stage; ++i) {
Expand All @@ -243,6 +385,14 @@ namespace LTXVUpsampler {
sd::ggml_graph_cut::mark_graph_cut(x, "ltx_latent_upsampler.final", "x");
return x;
}

// Uploads the rational resampler's fixed tensors. Does nothing when the
// config does not select the rational resampler.
void load_fixed_tensors() {
    if (config.rational_resampler) {
        auto resampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
        resampler->load_fixed_tensors();
    }
}
};

struct LatentUpsamplerRunner : public GGMLRunner {
Expand All @@ -265,20 +415,23 @@ namespace LTXVUpsampler {
}

const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
bool has_regular_spatial = has_tensor(tensor_storage_map, "upsampler.0.weight");
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
!has_tensor(tensor_storage_map, "upsampler.0.weight")) {
(!has_regular_spatial && !has_rational_spatial)) {
LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors");
return false;
}

LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample ||
config.rational_resampler) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d",
config.spatial_up_num < 1 || config.spatial_down_den < 1) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f",
config.dims,
config.spatial_upsample,
config.temporal_upsample,
config.rational_resampler);
config.rational_resampler,
config.spatial_scale);
return false;
}

Expand All @@ -291,15 +444,22 @@ namespace LTXVUpsampler {

std::map<std::string, ggml_tensor*> tensors;
model->get_param_tensors(tensors);
if (!model_loader.load_tensors(tensors, {}, n_threads)) {
std::set<std::string> ignore_tensors;
if (config.rational_resampler) {
ignore_tensors.insert("upsampler.blur_down.kernel");
}
if (!model_loader.load_tensors(tensors, ignore_tensors, n_threads)) {
LOG_ERROR("load LTX latent upsampler tensors failed");
return false;
}
model->load_fixed_tensors();

LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d",
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, rational=%d",
config.in_channels,
config.mid_channels,
config.num_blocks_per_stage);
config.num_blocks_per_stage,
config.spatial_scale,
config.rational_resampler);
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4798,7 +4798,7 @@ static sd::Tensor<float> upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx,
audio_latent = unpack_ltxav_audio_latent(packed_latent, audio_length, latent_channels);
}

LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> x2",
LOG_INFO("LTX latent spatial upscale: latent %dx%dx%dx%d -> model output",
(int)video_latent.shape()[0],
(int)video_latent.shape()[1],
(int)video_latent.shape()[2],
Expand Down
Loading