perf: optimize LTX2 inference latency and implement granular TPU profiling #389
@mbohlool Could you add a table comparing the latency gain (single video and amortized throughput) of this change against the baseline (main)? Thanks!
@Perseus14 I changed the PR to focus only on the timing and profiling part; I will explore the performance tweaking later. PTAL.
spec = NamedSharding(self.mesh, P(*activation_axes))
video_embeds_sharded = jax.device_put(video_embeds, spec)
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
audio_embeds_sharded = audio_embeds
@prishajain1 Could you check whether this will cause issues?
f" Load (checkpoint): {load_time:>7.1f}s\n"
f" Compile: {compile_time:>7.1f}s\n"
f" {'─' * 40}\n"
f" Inference: {generation_time:>7.1f}s\n"
Is it possible to print a component-wise split here for quick analysis, now that we are timing all the components?
Elisa has done something like this for the WAN pipelines here
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
Description:
This PR adds granular timing and profiling to the LTX2 generation pipeline to help identify performance bottlenecks.
Key Changes:
Detailed Timing: Added time.perf_counter() blocks and jax.block_until_ready() calls across the pipeline to accurately measure text encoding, connector passes, denoising steps, VAE decoding, and post-processing.
Multi-Pass Execution: Updated generate_ltx2.py to support a three-stage execution flow:
  Warmup Pass: For JIT compilation.
  Generation Pass: For actual output and standard timing.
  Profiling Pass (optional): Captured via max_utils.Profiler for a subset of steps.
Enhanced Logging: Added a summary table for Load, Compile, and Inference times.
Config Updates: Added skip_first_n_steps_for_profiler and profiler_steps to the LTX2 configuration.
Memory Management: Explicitly deleted large tensors (out, videos, audios) before the profiling run to prevent OOM.
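For readers who want the shape of the changes above without opening the diff, here is a minimal sketch of the timing and multi-pass pattern. It is not the code from this PR. The names `timed`, `run_passes`, `format_summary`, `save`, `profile_pass`, and `components` are hypothetical, and the sketch assumes the Compile figure is simply the warmup pass duration. `sync` defaults to a no-op so the block runs without JAX installed; in the real pipeline it would be `jax.block_until_ready`, and the profiling pass would be wrapped by `max_utils.Profiler`, whose API is not shown here.

```python
import time


def timed(fn, *args, sync=lambda x: x, **kwargs):
    """Call fn, wait for its result via sync, and return (result, seconds)."""
    start = time.perf_counter()
    # With JAX, pass sync=jax.block_until_ready so that asynchronous device
    # work has finished before the clock stops.
    result = sync(fn(*args, **kwargs))
    return result, time.perf_counter() - start


def run_passes(generate, save, sync=lambda x: x, profile_pass=None):
    """Warmup pass, timed generation pass, then an optional profiling pass."""
    # Warmup pass: the first call pays for JIT compilation; output is discarded.
    compile_time = timed(generate, sync=sync)[1]

    # Generation pass: produces the real output and the reported latency.
    out, generation_time = timed(generate, sync=sync)
    save(out)

    # Drop the reference to the large output before profiling to limit memory use.
    del out
    if profile_pass is not None:
        profile_pass()  # hypothetical hook; stands in for the profiled run

    return {"compile_time": compile_time, "generation_time": generation_time}


def format_summary(load_time, compile_time, generation_time, components=None):
    """Build the Load / Compile / Inference summary, with optional per-component rows."""
    lines = [
        f"  Load (checkpoint): {load_time:>7.1f}s",
        f"  Compile:           {compile_time:>7.1f}s",
        f"  {'─' * 40}",
        f"  Inference:         {generation_time:>7.1f}s",
    ]
    # Assumption: components maps a stage name (e.g. "vae_decode") to seconds.
    for name, seconds in (components or {}).items():
        lines.append(f"    {name:<17}{seconds:>7.1f}s")
    return "\n".join(lines)
```

Calling `run_passes(pipeline_fn, save=write_video, sync=jax.block_until_ready)` and passing the returned timings to `format_summary` would print a table like the one in the review snippet above. The optional `components` argument is one way to get the component-wise split requested in review.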