perf: optimize LTX2 inference latency and implement granular TPU profiling#389

Open
mbohlool wants to merge 1 commit into main from mehdy_perf

Conversation

@mbohlool
Collaborator

@mbohlool mbohlool commented Apr 23, 2026

Description:
This PR introduces better timing and profiling capabilities to the LTX2 generation pipeline to help identify performance bottlenecks.

Key Changes:

Detailed Timing: Added time.perf_counter() blocks and jax.block_until_ready() calls across the pipeline to accurately measure text encoding, connector passes, denoising steps, VAE decoding, and post-processing.

Multi-Pass Execution: Updated generate_ltx2.py to support a three-stage execution flow:

Warmup Pass: For JIT compilation.

Generation Pass: For actual output and standard timing.

Profiling Pass: (Optional) Captured via max_utils.Profiler for a subset of steps.

Enhanced Logging: Added a summary table for Load, Compile, and Inference times.

Config Updates: Added skip_first_n_steps_for_profiler and profiler_steps to the LTX2 configuration.

Memory Management: Explicitly deletes large tensors (out, videos, audios) before the profiling run to prevent OOM.
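The warmup-then-generation timing pattern described above can be sketched roughly as follows. `timed`, `fake_pipeline`, and the stage labels are illustrative stand-ins, not code from this PR; in the real pipeline each stage's output would be passed through `jax.block_until_ready()` before the clock stops, since JAX dispatches work asynchronously.

```python
import time

def timed(label, fn, *args):
    # Measure wall-clock time for one stage. With JAX the result must be
    # passed through jax.block_until_ready() before stopping the clock,
    # because dispatch is asynchronous; a plain call stands in for that here.
    start = time.perf_counter()
    out = fn(*args)
    elapsed = time.perf_counter() - start
    print(f"  {label:<20}{elapsed:>7.3f}s")
    return out, elapsed

def fake_pipeline(prompt):
    # Stand-in for the LTX2 pipeline: the first call would normally pay
    # the JIT compilation cost; later calls reuse the compiled program.
    return f"video for {prompt!r}"

# Warmup pass: absorbs JIT compilation, output discarded.
_, compile_time = timed("warmup (compile)", fake_pipeline, "a red fox")

# Generation pass: the run whose latency is actually reported.
out, generation_time = timed("generation", fake_pipeline, "a red fox")

# A profiling pass would optionally follow, wrapped in a profiler trace
# for a configured subset of denoising steps.
```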

@mbohlool mbohlool requested a review from entrpn as a code owner April 23, 2026 00:20

@mbohlool mbohlool force-pushed the mehdy_perf branch 2 times, most recently from c2eae2f to 6bd35bf Compare April 23, 2026 00:51
@Perseus14
Collaborator

@mbohlool Could you add a table with the latency gain (single video and amortized throughput) of this change with the baseline (main)?

Thanks!

@mbohlool
Collaborator Author

mbohlool commented May 1, 2026

@Perseus14 I changed the PR to focus only on the timing and profiling part; I'll explore the performance tweaks separately. PTAL.

```python
spec = NamedSharding(self.mesh, P(*activation_axes))
video_embeds_sharded = jax.device_put(video_embeds, spec)
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
audio_embeds_sharded = audio_embeds
```
Collaborator


@prishajain1 Could you check whether this will cause issues?
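For context, the sharding pattern in the hunk above can be reproduced in a minimal runnable form. The single-axis mesh, the `"data"` axis name, and the array shapes below are illustrative assumptions, not the pipeline's actual configuration (which builds `self.mesh` and `activation_axes` from its config).

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative one-axis mesh over all local devices; the real pipeline
# derives its mesh and activation axes from configuration instead.
mesh = Mesh(jax.devices(), axis_names=("data",))
spec = NamedSharding(mesh, P("data"))

n = len(jax.devices())
video_embeds = jnp.ones((n * 2, 8))
audio_embeds = jnp.ones((n * 2, 8))

video_embeds_sharded = jax.device_put(video_embeds, spec)
audio_embeds_sharded = jax.device_put(audio_embeds, spec)

# The diff then rebinds audio_embeds_sharded to the unsharded
# audio_embeds, discarding the device_put result above -- which is
# what the review comment is asking about.
audio_embeds_sharded = audio_embeds
```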

```python
f" Load (checkpoint): {load_time:>7.1f}s\n"
f" Compile:           {compile_time:>7.1f}s\n"
f" {'─' * 40}\n"
f" Inference:         {generation_time:>7.1f}s\n"
```
Collaborator


Is it possible to print a component-wise split here for quick analysis, now that we are timing all the components?

Elisa has done something like this for the WAN pipelines here
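A component-wise split could look roughly like this; `format_summary` and the stage names are illustrative, not code from this PR or from the WAN pipelines.

```python
def format_summary(timings):
    # timings: ordered mapping of stage name -> seconds, e.g. the
    # per-component values collected during the generation pass.
    lines = [f"  {name:<22}{secs:>7.1f}s" for name, secs in timings.items()]
    lines.append("  " + "─" * 31)
    lines.append(f"  {'Total':<22}{sum(timings.values()):>7.1f}s")
    return "\n".join(lines)

print(format_summary({
    "Text encoding": 1.2,
    "Denoising steps": 38.5,
    "VAE decode": 4.1,
}))
```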

@github-actions

github-actions Bot commented May 2, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

github-actions Bot commented May 2, 2026

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
