From 7c7bc4a8bca7fd2e2ce22c26bd3195e2332b75d2 Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 22 May 2026 00:15:09 +0800 Subject: [PATCH] feat: stream LTX VAE temporal tile decoding --- src/ltx_vae.hpp | 155 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 145 insertions(+), 10 deletions(-) diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index 2f4086a8c..b7a462fc5 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -1158,6 +1158,27 @@ namespace LTXVAE { return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); } + /* Decodes a single temporal chunk: un-normalizes z, resets feat_idx to 0, runs decoder->forward_tiled_frame with feat_map, keeps [0, ne[2] - chunk_overlap) along dim 2 when chunk_overlap > 0 (assuming ggml_ext_slice takes dim, start, end - TODO confirm), and returns the unpatchified result. */ ggml_tensor* decode_tiled_chunk(GGMLRunnerContext* ctx, + ggml_tensor* z, + ggml_tensor* timestep, + std::vector& feat_map, + int chunk_idx, + int temporal_tile_overlap, + int& feat_idx) { + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + auto latents = processor->un_normalize(ctx, z); + + feat_idx = 0; + int chunk_overlap = temporal_tile_overlap; // modified by forward_tiled_frame temporal inflation + auto out_chunk = decoder->forward_tiled_frame(ctx, latents, timestep, + feat_map, feat_idx, chunk_idx, chunk_overlap); + if (chunk_overlap > 0) { + out_chunk = ggml_ext_slice(ctx->ggml_ctx, out_chunk, 2, 0, out_chunk->ne[2] - chunk_overlap); + } + return WAN::WanVAE::unpatchify(ctx->ggml_ctx, out_chunk, patch_size, 1); + } + ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) { GGML_ASSERT(!decode_only); @@ -1296,6 +1317,41 @@ struct LTXVideoVAE : public VAE { vae.get_param_tensors(tensors, prefix); } + /* Resolved tiling parameters: latent frames per tile, overlap between consecutive tiles, step between tile starts, and number of tiles. */ struct TemporalTilePlan { + int frames = 1; + int overlap = 0; + int stride = 1; + int num_tiles = 1; + }; + + /* Derives a TemporalTilePlan for total_frames: frames is at least 1, overlap is at least 0 and is lowered (with a warning) to stay below frames and, when total_frames > 1, below total_frames; stride = max(1, frames - overlap); num_tiles is the ceiling of max(1, total_frames - overlap) / stride, or 0 when total_frames <= 0. */ TemporalTilePlan resolve_temporal_tile_plan(int64_t total_frames) const { + TemporalTilePlan plan; + plan.frames = std::max(1, temporal_tile_frames); + plan.overlap = std::max(0, temporal_tile_overlap); + + if (plan.overlap >= plan.frames) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to 
temporal_tile_frames (%d), adjusting values to avoid empty decode windows", + plan.overlap, + plan.frames); + plan.overlap = plan.frames - 1; + } + if (total_frames > 1 && plan.overlap >= total_frames) { + LOG_WARN("temporal_tile_overlap (%d) is greater than or equal to total latent frames (%lld), adjusting values to decode at least one tile", + plan.overlap, + (long long)total_frames); + plan.overlap = static_cast(total_frames - 1); + } + + plan.stride = std::max(1, plan.frames - plan.overlap); + int64_t tiled_frames = std::max(1, total_frames - plan.overlap); + plan.num_tiles = total_frames > 0 ? static_cast((tiled_frames + plan.stride - 1) / plan.stride) : 0; + return plan; + } + + std::string temporal_feat_cache_name(size_t feat_idx) const { + return "ltx_vae_temporal_feat:" + std::to_string(feat_idx); + } + ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) { ggml_cgraph* gf = new_graph_custom(20480); ggml_tensor* z = make_input(z_tensor); @@ -1306,21 +1362,97 @@ struct LTXVideoVAE : public VAE { auto runner_ctx = get_context(); ggml_tensor* out; - bool use_tiled = decode_graph && temporal_tiling_enabled && - z_tensor.dim() == 5 && z_tensor.shape()[2] > 1; - if (use_tiled) { - LOG_DEBUG("Using LTX VAE temporal tiling params: temporal_tile_frames=%d, temporal_tile_overlap=%d", - temporal_tile_frames, - temporal_tile_overlap); - out = vae.decode_tiled(&runner_ctx, z, timestep, temporal_tile_frames, temporal_tile_overlap); - } else { - out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); - } + out = decode_graph ? 
vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); ggml_build_forward_expand(gf, out); return gf; } + /* Builds the compute graph for one tile: fills a 128-slot feat_map from get_cache_tensor_by_name, calls vae.decode_tiled_chunk, then passes each non-null feat_map entry below feat_count to cache() and expands it into the graph before expanding the decoded output. */ ggml_cgraph* build_temporal_tile_graph(const sd::Tensor& z_chunk_tensor, + int chunk_idx, + int chunk_overlap) { + ggml_cgraph* gf = new_graph_custom(20480); + ggml_tensor* z = make_input(z_chunk_tensor); + ggml_tensor* timestep = nullptr; + if (timestep_conditioning) { + timestep = make_input(decode_timestep_tensor); + } + + std::vector feat_map(128, nullptr); + for (size_t feat_idx = 0; feat_idx < feat_map.size(); ++feat_idx) { + feat_map[feat_idx] = get_cache_tensor_by_name(temporal_feat_cache_name(feat_idx)); + } + + auto runner_ctx = get_context(); + int feat_count = 0; + ggml_tensor* out = vae.decode_tiled_chunk(&runner_ctx, + z, + timestep, + feat_map, + chunk_idx, + chunk_overlap, + feat_count); + + for (int feat_idx = 0; feat_idx < feat_count && feat_idx < static_cast(feat_map.size()); ++feat_idx) { + ggml_tensor* feat_cache = feat_map[static_cast(feat_idx)]; + if (feat_cache != nullptr) { + cache(temporal_feat_cache_name(static_cast(feat_idx)), feat_cache); + ggml_build_forward_expand(gf, feat_cache); + } + } + + ggml_build_forward_expand(gf, out); + return gf; + } + + /* Decodes input in tiles along dim 2 according to resolve_temporal_tile_plan, computing one graph per tile and concatenating the results along dim 2. The cache context and cache_tensor_map are cleared before the loop, after it, and when a tile result is empty, in which case an empty tensor is returned. */ sd::Tensor decode_temporal_tiled_streaming(const int n_threads, + const sd::Tensor& input, + size_t expected_dim) { + const int64_t total_frames = input.shape()[2]; + TemporalTilePlan plan = resolve_temporal_tile_plan(total_frames); + + LOG_DEBUG("Using streaming temporal tiling: temporal_tile_frames=%d, temporal_tile_overlap=%d, total latent frames=%lld, resulting in %d tiles", + plan.frames, + plan.overlap, + (long long)total_frames, + plan.num_tiles); + + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + + sd::Tensor output; + for (int64_t start = 0; start < total_frames - plan.overlap; start += plan.stride) { + const int64_t end = std::min(total_frames, start + plan.frames); + const int chunk_overlap = end < total_frames ? 
plan.overlap : 0; /* a tile that reaches total_frames is the last one and gets overlap 0 */ + auto z_chunk = sd::ops::slice(input, 2, start, end); + + LOG_DEBUG("LTX VAE temporal tile %lld/%d: latent frames [%lld, %lld), overlap=%d", + (long long)(start / plan.stride + 1), + plan.num_tiles, + (long long)start, + (long long)end, + chunk_overlap); + + auto get_graph = [&]() -> ggml_cgraph* { + return build_temporal_tile_graph(z_chunk, + static_cast(start), /* NOTE(review): the tile start frame, not the tile ordinal, is passed as chunk_idx - confirm against forward_tiled_frame */ + chunk_overlap); + }; + auto chunk = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, true), + expected_dim); + if (chunk.empty()) { + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + return {}; + } + output = output.empty() ? std::move(chunk) : sd::ops::concat(output, chunk, 2); + } + + free_cache_ctx_and_buffer(); + cache_tensor_map.clear(); + return output; + } + ggml_cgraph* build_latent_statistics_graph(const sd::Tensor& z_tensor, bool normalize) { ggml_cgraph* gf = new_graph_custom(1024); ggml_tensor* z = make_input(z_tensor); @@ -1356,6 +1488,9 @@ struct LTXVideoVAE : public VAE { input = sd::ops::slice(input, 2, 0, cropped_t); } } + /* decode with tiling enabled on a 5-D input with more than one entry along dim 2 goes through the streaming path */ if (decode_graph && temporal_tiling_enabled && input.dim() == 5 && input.shape()[2] > 1) { + return decode_temporal_tiled_streaming(n_threads, input, expected_dim); + } auto get_graph = [&]() -> ggml_cgraph* { return build_graph(input, decode_graph); };