diff --git a/examples/cli/README.md b/examples/cli/README.md index 2b23d1198..388c16d49 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -107,6 +107,8 @@ Generation Options: --extra-sample-args extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma + --extra-tiling-args extra VAE tiling args, key=value list. LTX video VAE supports + temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1) -H, --height image height, in pixel space (default: 512) -W, --width image width, in pixel space (default: 512) --steps number of sample steps (default: 20) diff --git a/examples/common/common.cpp b/examples/common/common.cpp index f32d0c6ff..519e8aae6 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -835,6 +835,10 @@ ArgOptions SDGenerationParams::get_options() { "--extra-sample-args", "extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma", &extra_sample_args}, + {"", + "--extra-tiling-args", + "extra VAE tiling args, key=value list. 
LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)", + &extra_tiling_args}, }; options.int_options = { @@ -1780,6 +1784,9 @@ bool SDGenerationParams::from_json_str( if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) { vae_tiling_params.rel_size_y = tiling_json["rel_size_y"]; } + if (tiling_json.contains("extra_tiling_args") && tiling_json["extra_tiling_args"].is_string()) { + extra_tiling_args = tiling_json["extra_tiling_args"].get(); + } } if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) { @@ -2002,6 +2009,8 @@ bool SDGenerationParams::initialize_cache_params() { } bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) { + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); + if (high_noise_sample_params.sample_steps <= 0) { high_noise_sample_params.sample_steps = -1; } @@ -2188,6 +2197,7 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() { sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str(); high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? nullptr : high_noise_extra_sample_args.c_str(); + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); sd_pm_params_t pm_params = { @@ -2261,6 +2271,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); sample_params.extra_sample_args = extra_sample_args.empty() ? nullptr : extra_sample_args.c_str(); high_noise_sample_params.extra_sample_args = high_noise_extra_sample_args.empty() ? 
nullptr : high_noise_extra_sample_args.c_str(); + vae_tiling_params.extra_tiling_args = extra_tiling_args.empty() ? nullptr : extra_tiling_args.c_str(); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); params.loras = lora_vec.empty() ? nullptr : lora_vec.data(); @@ -2386,7 +2397,8 @@ std::string SDGenerationParams::to_string() const { << vae_tiling_params.tile_size_y << ", " << vae_tiling_params.target_overlap << ", " << vae_tiling_params.rel_size_x << ", " - << vae_tiling_params.rel_size_y << " },\n" + << vae_tiling_params.rel_size_y << ", " + << "\"" << extra_tiling_args << "\" },\n" << "}"; return oss.str(); } @@ -2565,14 +2577,18 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, }; } - if (gen_params.vae_tiling_params.enabled) { + if (gen_params.vae_tiling_params.enabled || + gen_params.vae_tiling_params.temporal_tiling || + !gen_params.extra_tiling_args.empty()) { root["vae_tiling"] = { {"enabled", gen_params.vae_tiling_params.enabled}, + {"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling}, {"tile_size_x", gen_params.vae_tiling_params.tile_size_x}, {"tile_size_y", gen_params.vae_tiling_params.tile_size_y}, {"target_overlap", gen_params.vae_tiling_params.target_overlap}, {"rel_size_x", gen_params.vae_tiling_params.rel_size_x}, {"rel_size_y", gen_params.vae_tiling_params.rel_size_y}, + {"extra_tiling_args", gen_params.extra_tiling_args}, }; } diff --git a/examples/common/common.h b/examples/common/common.h index d526ca3a5..ca367f7ee 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -189,7 +189,8 @@ struct SDGenerationParams { int video_frames = 1; int fps = 16; float vace_strength = 1.f; - sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; + std::string extra_tiling_args; std::string pm_id_images_dir; std::string pm_id_embed_path; diff --git 
a/examples/server/README.md b/examples/server/README.md index 108785622..82f1c4778 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -209,6 +209,8 @@ Default Generation Options: --extra-sample-args extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma + --extra-tiling-args extra VAE tiling args, key=value list. LTX video VAE supports + temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1) -H, --height image height, in pixel space (default: 512) -W, --width image width, in pixel space (default: 512) --steps number of sample steps (default: 20) @@ -264,6 +266,7 @@ Default Generation Options: --disable-auto-resize-ref-image disable auto resize of ref images --disable-image-metadata do not embed generation metadata on image files --vae-tiling process vae in tiles to reduce memory usage + --temporal-tiling enable temporal tiling for LTX video VAE decode --hires enable highres fix -s, --seed RNG seed (default: 42, use random seed for < 0) --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, diff --git a/examples/server/api.md b/examples/server/api.md index 9abb74078..adcec26ff 100644 --- a/examples/server/api.md +++ b/examples/server/api.md @@ -504,11 +504,13 @@ Shared default fields used by both `img_gen` and `vid_gen`: | `sample_params.guidance.slg.scale` | `number` | | `vae_tiling_params` | `object` | | `vae_tiling_params.enabled` | `boolean` | +| `vae_tiling_params.temporal_tiling` | `boolean` | | `vae_tiling_params.tile_size_x` | `integer` | | `vae_tiling_params.tile_size_y` | `integer` | | `vae_tiling_params.target_overlap` | `number` | | `vae_tiling_params.rel_size_x` | `number` | | `vae_tiling_params.rel_size_y` | `number` | +| `vae_tiling_params.extra_tiling_args` | `string` | | `cache_mode` | `string` | | `cache_option` | `string` | | `scm_mask` | 
`string` | @@ -516,6 +518,8 @@ Shared default fields used by both `img_gen` and `vid_gen`: | `output_format` | `string` | | `output_compression` | `integer` | +`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`. + `img_gen`-specific default fields: | Field | Type | @@ -692,11 +696,13 @@ Example: "vae_tiling_params": { "enabled": false, + "temporal_tiling": false, "tile_size_x": 0, "tile_size_y": 0, "target_overlap": 0.5, "rel_size_x": 0.0, - "rel_size_y": 0.0 + "rel_size_y": 0.0, + "extra_tiling_args": "" }, "cache_mode": "disabled", @@ -804,6 +810,14 @@ Other native fields: | `hires.custom_sigmas` | `array` | | `hires.upscale_tile_size` | `integer` | | `vae_tiling_params` | `object` | +| `vae_tiling_params.enabled` | `boolean` | +| `vae_tiling_params.temporal_tiling` | `boolean` | +| `vae_tiling_params.tile_size_x` | `integer` | +| `vae_tiling_params.tile_size_y` | `integer` | +| `vae_tiling_params.target_overlap` | `number` | +| `vae_tiling_params.rel_size_x` | `number` | +| `vae_tiling_params.rel_size_y` | `number` | +| `vae_tiling_params.extra_tiling_args` | `string` | | `cache_mode` | `string` | | `cache_option` | `string` | | `scm_mask` | `string` | @@ -1012,11 +1026,13 @@ Example: "vae_tiling_params": { "enabled": false, + "temporal_tiling": false, "tile_size_x": 0, "tile_size_y": 0, "target_overlap": 0.5, "rel_size_x": 0.0, - "rel_size_y": 0.0 + "rel_size_y": 0.0, + "extra_tiling_args": "" }, "cache_mode": "disabled", @@ -1134,6 +1150,14 @@ Other native fields: | Field | Type | | --- | --- | | `vae_tiling_params` | `object` | +| `vae_tiling_params.enabled` | `boolean` | +| `vae_tiling_params.temporal_tiling` | `boolean` | +| `vae_tiling_params.tile_size_x` | `integer` | +| `vae_tiling_params.tile_size_y` | `integer` | +| `vae_tiling_params.target_overlap` | `number` | +| `vae_tiling_params.rel_size_x` | `number` | +| 
`vae_tiling_params.rel_size_y` | `number` | +| `vae_tiling_params.extra_tiling_args` | `string` | | `cache_mode` | `string` | | `cache_option` | `string` | | `scm_mask` | `string` | diff --git a/examples/server/routes_sdcpp.cpp b/examples/server/routes_sdcpp.cpp index 265ff5739..c60e9da65 100644 --- a/examples/server/routes_sdcpp.cpp +++ b/examples/server/routes_sdcpp.cpp @@ -56,11 +56,13 @@ static const char* capability_sample_method_name(enum sample_method_t sample_met static json make_vae_tiling_json(const sd_tiling_params_t& params) { return { {"enabled", params.enabled}, + {"temporal_tiling", params.temporal_tiling}, {"tile_size_x", params.tile_size_x}, {"tile_size_y", params.tile_size_y}, {"target_overlap", params.target_overlap}, {"rel_size_x", params.rel_size_x}, {"rel_size_y", params.rel_size_y}, + {"extra_tiling_args", params.extra_tiling_args ? params.extra_tiling_args : ""}, }; } diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 3ae44addf..f8b2c2f59 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -160,6 +160,7 @@ typedef struct { float target_overlap; float rel_size_x; float rel_size_y; + const char* extra_tiling_args; } sd_tiling_params_t; typedef struct { diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index cc149fef2..28df9f1bf 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -3172,7 +3172,7 @@ class Conv2d_grouped : public UnaryBlock { void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override { this->prefix = prefix; enum ggml_type wtype = GGML_TYPE_F16; - params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels); + params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels); if (bias) { enum ggml_type wtype = GGML_TYPE_F32; params["bias"] = ggml_new_tensor_1d(ctx, wtype, 
out_channels); diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index 751995860..2f4086a8c 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -1,6 +1,7 @@ #ifndef __SD_LTX_VAE_HPP__ #define __SD_LTX_VAE_HPP__ +#include #include #include #include @@ -143,16 +144,25 @@ namespace LTXVAE { std::vector& feat_map, int& feat_idx, int chunk_idx, - bool causal = true) { + bool causal = true, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); const int pad = causal ? (time_kernel_size - 1) : (time_kernel_size - 1) / 2; ggml_tensor* prev = (feat_idx < (int)feat_map.size()) ? feat_map[feat_idx] : nullptr; + GGML_ASSERT(x->ne[2] >= temporal_pad); + + int end_idx = x->ne[2] - temporal_pad; + int start_idx = std::max(end_idx - pad, 0); + // Save a contiguous copy of the last `pad` frames so the large `x` // tensor is not kept alive across iterations by a dangling view. - if (feat_idx < (int)feat_map.size() && pad > 0 && x->ne[2] >= pad) { - auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - pad, x->ne[2]); + if (feat_idx < (int)feat_map.size() && end_idx - start_idx > 0) { + GGML_ASSERT(start_idx >= 0); + GGML_ASSERT(end_idx > 0); + + auto slice = ggml_ext_slice(ctx->ggml_ctx, x, 2, start_idx, end_idx); feat_map[feat_idx] = ggml_cont(ctx->ggml_ctx, slice); } feat_idx++; @@ -284,7 +294,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); @@ -311,14 +322,14 @@ namespace LTXVAE { h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv1->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); h = norm2->forward(ctx, h); if (timestep_conditioning) { h = 
apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2); } h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal); + h = conv2->forward(ctx, h, feat_map, feat_idx, chunk_idx, causal, temporal_pad); return ggml_add(ctx->ggml_ctx, h, x); } @@ -367,7 +378,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { ggml_tensor* timestep_embed = nullptr; if (timestep_conditioning) { GGML_ASSERT(timestep != nullptr); @@ -376,7 +388,7 @@ namespace LTXVAE { } for (int i = 0; i < num_layers; i++) { auto resnet = std::dynamic_pointer_cast(blocks["res_blocks." + std::to_string(i)]); - x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx); + x = resnet->forward(ctx, x, timestep_embed, causal, feat_map, feat_idx, chunk_idx, temporal_pad); } return x; } @@ -437,7 +449,8 @@ namespace LTXVAE { bool causal, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int temporal_pad = 0) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); bool drop_first = (chunk_idx == 0) && (factor_t > 1); @@ -453,7 +466,7 @@ namespace LTXVAE { x_in = res; } - x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal); + x = conv->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal, temporal_pad); x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, drop_first); if (residual) { x = ggml_add(ctx->ggml_ctx, x, x_in); @@ -986,7 +999,8 @@ namespace LTXVAE { ggml_tensor* timestep, std::vector& feat_map, int& feat_idx, - int chunk_idx) { + int chunk_idx, + int& temporal_pad) { auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); auto conv_norm_out = std::dynamic_pointer_cast(blocks["conv_norm_out"]); auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); @@ -998,7 +1012,7 @@ namespace LTXVAE { } // conv_in with feat_map for left temporal context - x = conv_in->forward(ctx, x, 
feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_in->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); // up_blocks int block_idx = 0; @@ -1006,12 +1020,13 @@ namespace LTXVAE { auto mid_block = std::dynamic_pointer_cast(blocks["up_blocks." + std::to_string(block_idx)]); if (mid_block) { x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); } else { auto upsample = std::dynamic_pointer_cast( blocks["up_blocks." + std::to_string(block_idx)]); x = upsample->forward(ctx, x, causal_decoder, - feat_map, feat_idx, chunk_idx); + feat_map, feat_idx, chunk_idx, temporal_pad); + temporal_pad *= upsample->factor_t; } block_idx++; } @@ -1028,7 +1043,7 @@ namespace LTXVAE { x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift); } x = ggml_silu_inplace(ctx->ggml_ctx, x); - x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder); + x = conv_out->forward(ctx, x, feat_map, feat_idx, chunk_idx, causal_decoder, temporal_pad); return x; } }; @@ -1084,7 +1099,9 @@ namespace LTXVAE { // tensors can be freed by GGML before the next iteration starts. ggml_tensor* decode_tiled(GGMLRunnerContext* ctx, ggml_tensor* z, - ggml_tensor* timestep) { + ggml_tensor* timestep, + int temporal_window_size = 1, + int temporal_tile_overlap = 0) { auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); auto latents = processor->un_normalize(ctx, z); @@ -1099,13 +1116,43 @@ namespace LTXVAE { // 128 slots is generous enough for any supported decoder configuration. 
std::vector feat_map(128, nullptr); + // Ensure window size is at least 1 + int window = std::max(1, temporal_window_size); + int overlap = std::max(0, temporal_tile_overlap); + + if (overlap >= window) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to temporal_tile_frames (%d), adjusting values to avoid empty decode windows", + overlap, window); + overlap = window - 1; + } + // Keep overlap below the frame count: with overlap >= T the loop bound (T - overlap) is <= 0, no chunk is decoded and `out` stays null. + overlap = std::min(overlap, std::max(0, (int)T - 1)); + // Tile count mirrors the loop below, ceil((T - overlap) / (window - overlap)), computed in int to match %d. + LOG_DEBUG("Using temporal tiling: temporal_tile_frames = %d, temporal_tile_overlap = %d, total frames = %d, resulting in %d tiles", + window, + overlap, + (int)T, + ((int)T - overlap + window - overlap - 1) / (window - overlap)); ggml_tensor* out = nullptr; - for (int i = 0; i < (int)T; i++) { + for (int i = 0; i < (int)T - overlap; i += (window - overlap)) { int feat_idx = 0; - auto z_i = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, i + 1); - auto out_i = decoder->forward_tiled_frame(ctx, z_i, timestep, - feat_map, feat_idx, i); - out = (out == nullptr) ? out_i : ggml_concat(ctx->ggml_ctx, out, out_i, 2); + + // Calculate the end index for the current temporal chunk + int end_i = std::min((int)T, i + window); + if (end_i >= (int)T) { + overlap = 0; // avoid overlap issues in the last chunk + } + + int chunk_overlap = overlap; // modified by forward_tiled_frame temporal inflation + + auto z_chunk = ggml_ext_slice(ctx->ggml_ctx, latents, 2, i, end_i); + + auto out_chunk = decoder->forward_tiled_frame(ctx, z_chunk, timestep, + feat_map, feat_idx, i, chunk_overlap); + + // discard overlap frames if it's not the final chunk + if (overlap > 0 && end_i < (int)T) { + out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap); + } + + out = (out == nullptr) ? 
out_chunk : ggml_concat(ctx->ggml_ctx, out, out_chunk, 2); } return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); @@ -1140,8 +1187,13 @@ namespace LTXVAE { } // namespace LTXVAE struct LTXVideoVAE : public VAE { + static constexpr int DEFAULT_TEMPORAL_TILE_FRAMES = 4; + static constexpr int DEFAULT_TEMPORAL_TILE_OVERLAP = 1; + bool decode_only; bool temporal_tiling_enabled = false; + int temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; + int temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; int ltx_vae_version; bool timestep_conditioning; int patch_size; @@ -1178,6 +1230,68 @@ struct LTXVideoVAE : public VAE { temporal_tiling_enabled = enabled; } + static std::string trim_tiling_arg(std::string value) { + const char* whitespace = " \t\r\n"; + size_t begin = value.find_first_not_of(whitespace); + if (begin == std::string::npos) { + return ""; + } + size_t end = value.find_last_not_of(whitespace); + return value.substr(begin, end - begin + 1); + } + + static bool parse_tiling_int(const std::string& value, int& parsed) { + try { + size_t consumed = 0; + parsed = std::stoi(value, &consumed); + return trim_tiling_arg(value.substr(consumed)).empty(); + } catch (...) 
{ + return false; + } + } + + void set_tiling_params(const sd_tiling_params_t& params) override { + temporal_tiling_enabled = params.temporal_tiling; + temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; + temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; + + const char* extra_tiling_args = params.extra_tiling_args; + if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') { + return; + } + + std::string raw(extra_tiling_args); + size_t start = 0; + for (size_t pos = 0; pos <= raw.size(); ++pos) { + if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') { + continue; + } + + std::string token = trim_tiling_arg(raw.substr(start, pos - start)); + if (!token.empty()) { + size_t eq = token.find('='); + if (eq == std::string::npos) { + LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str()); + } else { + std::string key = trim_tiling_arg(token.substr(0, eq)); + std::string value = trim_tiling_arg(token.substr(eq + 1)); + int parsed = 0; + if (!parse_tiling_int(value, parsed)) { + LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str()); + } else if (key == "temporal_tile_frames") { + temporal_tile_frames = std::max(1, parsed); + } else if (key == "temporal_tile_overlap") { + temporal_tile_overlap = std::max(0, parsed); + } else { + LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str()); + } + } + } + + start = pos + 1; + } + } + void get_param_tensors(std::map& tensors, const std::string prefix) override { vae.get_param_tensors(tensors, prefix); } @@ -1195,7 +1309,10 @@ struct LTXVideoVAE : public VAE { bool use_tiled = decode_graph && temporal_tiling_enabled && z_tensor.dim() == 5 && z_tensor.shape()[2] > 1; if (use_tiled) { - out = vae.decode_tiled(&runner_ctx, z, timestep); + LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d", + temporal_tile_frames, + temporal_tile_overlap); + out = vae.decode_tiled(&runner_ctx, z, timestep, 
temporal_tile_frames, temporal_tile_overlap); } else { out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 8d6806228..eb6845b46 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -151,7 +151,7 @@ class StableDiffusionGGML { bool apply_lora_immediately = false; std::string taesd_path; - sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0}; + sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr}; bool offload_params_to_cpu = false; float max_vram = 0.f; bool use_pmid = false; @@ -2679,7 +2679,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->batch_count = 1; sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; - sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; sd_cache_params_init(&sd_img_gen_params->cache); sd_hires_params_init(&sd_img_gen_params->hires); } @@ -2708,7 +2708,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "increase_ref_index: %s\n" "control_strength: %.2f\n" "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" - "VAE tiling: %s (temporal=%s)\n" + "VAE tiling: %s (temporal=%s, extra_tiling_args=%s)\n" "hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), @@ -2728,6 +2728,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), BOOL_STR(sd_img_gen_params->vae_tiling_params.temporal_tiling), + 
SAFE_STR(sd_img_gen_params->vae_tiling_params.extra_tiling_args), BOOL_STR(sd_img_gen_params->hires.enabled), sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler), SAFE_STR(sd_img_gen_params->hires.model_path), @@ -2765,7 +2766,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->fps = 16; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; - sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f}; + sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; sd_vid_gen_params->hires.enabled = false; sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT; sd_vid_gen_params->hires.scale = 2.f; diff --git a/src/tae.hpp b/src/tae.hpp index e091c163d..62f14c9ee 100644 --- a/src/tae.hpp +++ b/src/tae.hpp @@ -265,7 +265,7 @@ class WideMemBlock : public GGMLBlock { public: WideMemBlock(int channels, int out_channels) : has_skip_conv(channels != out_channels) { - int groups = std::max(1, out_channels / 64); + int groups = std::max(1, out_channels / 64); blocks["conv.0"] = std::shared_ptr(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1})); @@ -479,12 +479,12 @@ class TinyVideoDecoder : public UnaryBlock { int index = 3; for (int i = 0; i < num_layers; i++) { for (int j = 0; j < num_blocks; j++) { - auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); - mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); + auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); + mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); if (is_wide) { auto block = 
std::dynamic_pointer_cast(blocks[std::to_string(index++)]); h = block->forward(ctx, h, mem); - } else{ + } else { auto block = std::dynamic_pointer_cast(blocks[std::to_string(index++)]); h = block->forward(ctx, h, mem); } @@ -683,8 +683,8 @@ struct TinyImageAutoEncoder : public VAE { struct TinyVideoAutoEncoder : public VAE { TAEHV taehv; bool decode_only = false; - bool is_wide = false; - + bool is_wide = false; + TinyVideoAutoEncoder(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map, @@ -699,7 +699,7 @@ struct TinyVideoAutoEncoder : public VAE { break; } } - taehv = TAEHV(decoder_only, version, is_wide); + taehv = TAEHV(decoder_only, version, is_wide); scale_input = false; taehv.init(params_ctx, tensor_storage_map, prefix); } diff --git a/src/vae.hpp b/src/vae.hpp index d7e0fdee1..cc4cd967f 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -167,6 +167,7 @@ struct VAE : public GGMLRunner { int64_t t0 = ggml_time_ms(); sd::Tensor input = x; sd::Tensor output; + set_tiling_params(tiling_params); if (tiling_params.enabled) { const int scale_factor = get_scale_factor(); @@ -216,6 +217,9 @@ struct VAE : public GGMLRunner { virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; virtual void set_temporal_tiling_enabled(bool enabled) { SD_UNUSED(enabled); }; + virtual void set_tiling_params(const sd_tiling_params_t& params) { + set_temporal_tiling_enabled(params.temporal_tiling); + }; }; struct FakeVAE : public VAE {