Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations #422

Open
Perseus14 wants to merge 1 commit into main from ltx2-improvements

Conversation

@Perseus14

Collaborator

Description

This PR is a comprehensive refactor and optimization sweep. It brings massive improvements to XLA compilation times, memory usage (HBM), and architectural hygiene by stripping out redundant compute, unifying duplicated logic, and optimizing JAX tracing.

🧹 Architectural Hygiene & Code Quality

  • Unified Transformer Block Application: Solved the "quadruple code duplication" hazard in LTX2VideoTransformer3DModel.__call__. The 4 separate block execution paths (scan vs. loop, perturbation vs. no-perturbation) have been consolidated using a new TransformerContext container and a single apply_block helper function.
  • Removed Dead Code in prepare_video_coords: Deleted the wasteful 5D latent_coords block in attention_ltx2.py that was computing an unused, wrongly-shaped tensor only to immediately overwrite it.
  • Simplified apply_split_rotary_emb: Cleaned up the convoluted reshape/broadcast logic for split RoPE. Removed the redundant expand_dims and squeeze operations, executing the rotation directly (first_x * cos - second_x * sin) to avoid allocating unnecessary intermediate 5D tensors.
  • Removed Redundant Guards: Dropped the unnecessary hasattr(self, "rope_type") check in LTX2Attention.
  • PRNG Fallback Warning: Added the missing max_logging.log warning that is emitted when noise generation falls back to the fixed seed jax.random.key(0).
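
To make the direct rotation concrete, here is a minimal NumPy sketch of a split-half rotary embedding. It assumes that "split" means the head dimension is divided into a first and a second half that are rotated against each other. The name `split_rotary_sketch`, the tensor shapes, and the second-half formula are illustrative assumptions, not the code in this PR. The same expressions should work unchanged with jax.numpy.

```python
import numpy as np


def split_rotary_sketch(x, cos, sin):
  """Rotate the two halves of the last axis against each other (hypothetical helper)."""
  half = x.shape[-1] // 2
  first_x, second_x = x[..., :half], x[..., half:]
  # Direct rotation: no expand_dims/squeeze and no intermediate 5D tensors.
  out_first = first_x * cos - second_x * sin
  out_second = second_x * cos + first_x * sin  # assumed counterpart of the first half
  return np.concatenate([out_first, out_second], axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 2, 3, 8))  # (batch, heads, seq, head_dim), illustrative shape
angles = rng.normal(size=(3, 4))   # (seq, head_dim // 2), broadcast over batch and heads
rotated = split_rotary_sketch(x, np.cos(angles), np.sin(angles))
```

Because each (first, second) pair undergoes a plain 2D rotation, the per-token vector norm is preserved, which is a convenient sanity check for any refactor of this function.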

⚡ XLA & JAX Compilation Optimizations

  • Dynamic Guidance Scales: Removed continuous hyperparameter floats (guidance_scale, stg_scale, audio_guidance_scale, etc.) from static_argnames in run_diffusion_loop(). Tweaking these generation scales will no longer trigger expensive 10-30 minute JAX recompilations!
  • Dynamic JAX Control Flow: Replaced the static Python if guidance_rescale > 0: check inside the compiled diffusion loop with jax.lax.cond. This enables the CFG rescaling logic to be fully dynamic, complementing the removal of the static scales and fixing formulation inconsistencies.
  • Standardized Scan Loop: Replaced nnx.scan with standard jax.lax.scan for the primary denoising timestep loop to ensure predictable compilation.
  • Fixed RuntimeProgramInputMismatch for scan_layers=False: Resolved an issue where XLA would fail during warmup compilation due to unrolled layer layout mismatches. Added explicit @jax.jit wrappers with jax.lax.with_sharding_constraint to enforce layout transpositions before crossing into run_diffusion_loop.
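
The first two bullets go together: once a scale is a traced (non-static) argument, a Python `if` on it raises a tracer-conversion error under `jax.jit`, so the branch has to move to `jax.lax.cond`. The sketch below illustrates both effects on a toy CFG step. The function names, shapes, and the trace counters are assumptions for illustration, and the rescale formula follows the common CFG-rescale formulation rather than this PR's exact code.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Counters are bumped only while JAX traces a function, so they count (re)traces.
trace_count = {"static": 0, "dynamic": 0}


@partial(jax.jit, static_argnames=("guidance_scale",))
def cfg_static(noise_uncond, noise_text, guidance_scale):
  trace_count["static"] += 1
  return noise_uncond + guidance_scale * (noise_text - noise_uncond)


def _rescale(noise_cfg, noise_text, guidance_rescale):
  axes = tuple(range(1, noise_cfg.ndim))
  std_text = jnp.std(noise_text, axis=axes, keepdims=True)
  std_cfg = jnp.maximum(jnp.std(noise_cfg, axis=axes, keepdims=True), 1e-5)
  rescaled = noise_cfg * (std_text / std_cfg)
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg


@jax.jit
def cfg_dynamic(noise_uncond, noise_text, guidance_scale, guidance_rescale):
  trace_count["dynamic"] += 1
  noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
  # The predicate is a traced value, so the branch is chosen at run time.
  return jax.lax.cond(
      guidance_rescale > 0,
      _rescale,
      lambda cfg, text, r: cfg,
      noise_cfg, noise_text, guidance_rescale,
  )


noise_uncond = jnp.zeros((2, 4))
noise_text = jnp.arange(8.0).reshape(2, 4)

for scale in (3.0, 5.0, 7.0):
  cfg_static(noise_uncond, noise_text, guidance_scale=scale)  # retraces per value
  cfg_dynamic(noise_uncond, noise_text, scale, 0.0)           # traced once

plain = cfg_dynamic(noise_uncond, noise_text, 3.0, 0.0)
rescaled = cfg_dynamic(noise_uncond, noise_text, 3.0, 1.0)
```

In this toy setup the static variant is traced once per distinct scale value, while the dynamic variant is traced once and reused for every scale and for both branches of the rescale check.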

🧠 Memory (HBM) Optimizations

  • Direct Dtype Typecasting: Streamlined text encoder state extraction by evaluating target_dtype upfront and mapping it directly, avoiding a redundant double-casting pipeline that was passing through bfloat16.

@Perseus14 Perseus14 requested a review from entrpn as a code owner June 19, 2026 21:51
@Perseus14 Perseus14 self-assigned this Jun 19, 2026
@Perseus14 Perseus14 requested a review from prishajain1 June 19, 2026 21:51
@Perseus14 Perseus14 changed the title from "LTX2.3 improvements and bug fixes" to "Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations" Jun 19, 2026
@Perseus14 Perseus14 force-pushed the ltx2-improvements branch 4 times, most recently from 3a13196 to 36e0f5d June 20, 2026 09:59
@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This pull request introduces a highly comprehensive set of optimizations and refactorings for the LTX2 model family, delivering significant enhancements to memory usage, model modularity, and JAX compilation behavior. By transforming static continuous hyperparameters into dynamic inputs, the sweep successfully eliminates expensive compilation timeouts while resolving underlying layout mismatches and unneeded computational overhead. The architectural quality of the codebase is substantially improved, particularly in the block consolidation and cleaner RoPE logic.

🔍 General Feedback

  • Exceptional Block Refactoring: The transition of LTX2VideoTransformer3DModel block application to a single unified apply_block helper and an immutable LTX2BlockContext is a masterclass in reducing hazard-prone code duplication.
  • RoPE & Coordinates Cleanup: Removing the duplicated and dead coordinate generation block in prepare_video_coords and simplifying split-RoPE dimension manipulation are excellent and highly effective hygiene improvements.
  • Robustness in Fallbacks & Dtypes: Incorporating proper warning logs for the noise generator fallback and upfront dtype casting for prompt embeddings ensures both runtime reliability and optimal HBM performance.

@@ -81,6 +81,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
std_text = jnp.std(noise_pred_text, axis=list(range(1, noise_pred_text.ndim)), keepdims=True)
std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True)

🟠 An epsilon of `1e-15` will underflow to `0.0` when the model is run in `float16` precision (the minimum subnormal of `float16` is `5.96e-8`). This underflow defeats the "Prevent division by zero" logic, resulting in potential `NaN` values. Using a standard deep learning epsilon like `1e-5` ensures robustness across both `float16` and `bfloat16`.
Suggested change
std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True)
std_cfg = jnp.maximum(std_cfg, 1e-5)
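
The underflow is easy to reproduce outside the model. The following NumPy sketch (jax.numpy follows the same float16 rules) clamps a zero standard deviation with both epsilons. The helper name `clamp_std` and the degenerate all-zero inputs are assumptions chosen to trigger the failure case.

```python
import numpy as np


def clamp_std(std, eps):
  # Cast eps to the array's dtype first, mirroring a model that runs in that precision.
  return np.maximum(std, np.asarray(eps, dtype=std.dtype))


# Degenerate case: both standard deviations are exactly zero in float16.
std_cfg = np.zeros((2, 1), dtype=np.float16)
std_text = np.zeros((2, 1), dtype=np.float16)

tiny_eps = clamp_std(std_cfg, 1e-15)  # 1e-15 is below float16's smallest subnormal
safe_eps = clamp_std(std_cfg, 1e-5)   # representable (as a subnormal) in float16

with np.errstate(divide="ignore", invalid="ignore"):
  ratio_tiny = std_text / tiny_eps    # 0 / 0
ratio_safe = std_text / safe_eps      # 0 / eps
```

With the `1e-15` clamp the ratio is `NaN`, while the `1e-5` clamp keeps it finite. Note that in float16 the quotient can still overflow when `std_text` is large relative to the epsilon, since the largest finite float16 value is 65504, so the clamp guards against division by zero but not against every overflow.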

@@ -237,7 +237,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):

# Export videos
for i in range(len(videos)):

🟡 If `config.model_name` is explicitly set to `None`/`null` in the configuration file, `getattr(config, "model_name", "ltx2")` will return `None`. Calling `.replace(".", "_")` on `None` will then raise an `AttributeError` at runtime. Adding a safe fallback `or "ltx2"` avoids potential crashes when handling explicitly null configuration entries.
Suggested change
for i in range(len(videos)):
model_name = getattr(config, "model_name", "ltx2") or "ltx2"
model_name_prefix = model_name.replace(".", "_")
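
The distinction matters because `getattr`'s default only applies when the attribute is missing, not when it is present and set to `None`. A small self-contained sketch, using `SimpleNamespace` objects as hypothetical stand-ins for the real config and a hypothetical helper name:

```python
from types import SimpleNamespace


def model_name_prefix(config, default="ltx2"):
  # `or default` also covers an attribute that exists but is None (or empty).
  name = getattr(config, "model_name", default) or default
  return name.replace(".", "_")


missing = SimpleNamespace()                       # attribute absent
explicit_none = SimpleNamespace(model_name=None)  # attribute present but null
versioned = SimpleNamespace(model_name="ltx2.3")  # illustrative value

prefixes = [model_name_prefix(c) for c in (missing, explicit_none, versioned)]
```

Without the `or default` fallback, the `explicit_none` case would raise `AttributeError` on `None.replace`.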

Comment on lines 58 to 71
@@ -72,9 +69,6 @@ def __call__(
output_hidden_states=output_hidden_states,
)
return interop.jax_view(output)

🟡 Using `unittest.mock.patch` in production model modules is a severe anti-pattern. `unittest.mock` is designed for testing and carries significant runtime and introspection overhead, which degrades prompt embedding generation performance in production training and inference runs. Reverting to a lightweight standard `try...finally` block completely eliminates this overhead.
Suggested change
import transformers.masking_utils
orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
try:
  with default_env():
    input_ids = interop.torch_view(input_ids)
    attention_mask = interop.torch_view(attention_mask)
    output = self.functional_call(
        self._forward_inner,
        params=self.params,
        buffers=self.buffers,
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
    )
    return interop.jax_view(output)
finally:
  transformers.masking_utils.sliding_window_overlay = orig_sliding_window_overlay

Comment on lines +1561 to +1572
@jax.jit
def enforce_layout_act(x):
  return jax.lax.with_sharding_constraint(x, activation_axes)

@jax.jit
def enforce_layout_act_audio(x):
  return jax.lax.with_sharding_constraint(x, activation_axes_audio)

latents_jax = enforce_layout_act(latents_jax)
audio_latents_jax = enforce_layout_act_audio(audio_latents_jax)
video_embeds_sharded = enforce_layout_act(video_embeds_sharded)
audio_embeds_sharded = enforce_layout_act_audio(audio_embeds_sharded)

🟠 JAX Performance Bottleneck: Inner @jax.jit Compilation

Defining @jax.jit decorated functions inside the pipeline's runtime __call__ method creates new Python function objects on every invocation. Since JAX's compilation cache relies on the function object's identity, this will trigger an expensive recompilation on every single pipeline call, leading to huge latency and CPU/TPU compilation overhead during inference.

To prevent recompilation, define a static, cached helper at the module-level:

from functools import partial

@partial(jax.jit, static_argnums=(1,))
def _enforce_layout(x, axes):
  return jax.lax.with_sharding_constraint(x, axes)

And then use it directly inside the __call__ method.

Suggested change
- @jax.jit
- def enforce_layout_act(x):
-   return jax.lax.with_sharding_constraint(x, activation_axes)
- @jax.jit
- def enforce_layout_act_audio(x):
-   return jax.lax.with_sharding_constraint(x, activation_axes_audio)
- latents_jax = enforce_layout_act(latents_jax)
- audio_latents_jax = enforce_layout_act_audio(audio_latents_jax)
- video_embeds_sharded = enforce_layout_act(video_embeds_sharded)
- audio_embeds_sharded = enforce_layout_act_audio(audio_embeds_sharded)
+ latents_jax = _enforce_layout(latents_jax, activation_axes)
+ audio_latents_jax = _enforce_layout(audio_latents_jax, activation_axes_audio)
+ video_embeds_sharded = _enforce_layout(video_embeds_sharded, activation_axes)
+ audio_embeds_sharded = _enforce_layout(audio_embeds_sharded, activation_axes_audio)
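
The retracing behaviour described in this comment can be observed with a toy function. The sketch below replaces the sharding constraint with a trivial doubling so it runs without a device mesh. The function names and the trace counters are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Counters are bumped only while JAX traces a function body.
traces = {"nested": 0, "module": 0}


def call_with_nested_jit(x):
  # A fresh function object is created and wrapped in jax.jit on every call.
  @jax.jit
  def double(v):
    traces["nested"] += 1
    return v * 2

  return double(x)


@jax.jit
def _double(v):
  traces["module"] += 1
  return v * 2


def call_with_module_jit(x):
  # Reuses the single module-level jitted function and its cache.
  return _double(x)


x = jnp.ones((4,))
for _ in range(3):
  call_with_nested_jit(x)
  call_with_module_jit(x)
```

Here the nested variant is traced on each of the three calls, while the module-level variant is traced once. How costly each retrace is in a real pipeline depends on the wrapped computation, so this only demonstrates the cache miss, not its runtime impact.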

import transformers.masking_utils

_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
from unittest import mock

🟡 Clean Code: Avoid using unittest.mock in production code

Using test-suite frameworks (unittest.mock.patch) in production runtime code is considered an anti-pattern. It introduces performance overhead and can fail in strict production/minimal packaging environments where tests are excluded.

Instead, define a lightweight and standard custom context manager using contextlib under the imports:

Suggested change
- from unittest import mock
+ import contextlib
+ import transformers.masking_utils
+ @contextlib.contextmanager
+ def patch_sliding_window_overlay():
+   orig = transformers.masking_utils.sliding_window_overlay
+   transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
+   try:
+     yield
+   finally:
+     transformers.masking_utils.sliding_window_overlay = orig
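
The restore-on-exit guarantee of this pattern can be checked in isolation. The sketch below is a generic version of the suggested context manager. `fake_module` is a hypothetical stand-in for `transformers.masking_utils`, and `patch_attr` is an illustrative name, so nothing here depends on the real library.

```python
import contextlib
from types import SimpleNamespace

# Hypothetical stand-in for a third-party module such as transformers.masking_utils.
fake_module = SimpleNamespace(sliding_window_overlay=lambda: "original")


def _patched_sliding_window_overlay():
  return "patched"


@contextlib.contextmanager
def patch_attr(target, name, replacement):
  """Temporarily replace target.name, restoring it even if the body raises."""
  orig = getattr(target, name)
  setattr(target, name, replacement)
  try:
    yield
  finally:
    setattr(target, name, orig)


with patch_attr(fake_module, "sliding_window_overlay", _patched_sliding_window_overlay):
  inside = fake_module.sliding_window_overlay()
after = fake_module.sliding_window_overlay()

# The original attribute is restored even when the body raises.
try:
  with patch_attr(fake_module, "sliding_window_overlay", _patched_sliding_window_overlay):
    raise ValueError("simulated failure in the forward pass")
except ValueError:
  after_error = fake_module.sliding_window_overlay()
```

Like `mock.patch`, this mutates module-level state for the duration of the block, so it is not safe if other threads call the patched function concurrently.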

# Dynamically patch transformers.masking_utils only during the duration of this call
transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
try:
with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay):

🟡 Clean Code: Use the custom context manager

Apply our custom context manager here to execute the forward pass under the monkeypatched sliding window overlay safely.

Suggested change
- with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay):
+ with patch_sliding_window_overlay():

Comment on lines +1560 to +1570

@jax.jit
def enforce_layout_act(x):
  return jax.lax.with_sharding_constraint(x, activation_axes)

@jax.jit
def enforce_layout_act_audio(x):
  return jax.lax.with_sharding_constraint(x, activation_axes_audio)

latents_jax = enforce_layout_act(latents_jax)
audio_latents_jax = enforce_layout_act_audio(audio_latents_jax)

🟠 **Performance/Compilation Bottleneck:** Defining `@jax.jit` wrappers nested inside the pipeline's `__call__` method will recreate the JIT-compiled function on every pipeline execution. Since JAX JIT caches binaries using the Python function object's ID, a newly created function on every invocation will bypass the compiler cache, forcing JAX to re-compile `enforce_layout_act` and `enforce_layout_act_audio` on every single generation.

To avoid this, define a single static jitted helper at the module level (e.g., using functools.partial(jax.jit, static_argnums=(1,))) and call it directly here.

Proposed module-level helper (to be added at the top/import level):

from functools import partial

@partial(jax.jit, static_argnums=(1,))
def _enforce_layout(x, axes):
  return jax.lax.with_sharding_constraint(x, axes)

And update the call sites as follows:

Suggested change
- @jax.jit
- def enforce_layout_act(x):
-   return jax.lax.with_sharding_constraint(x, activation_axes)
- @jax.jit
- def enforce_layout_act_audio(x):
-   return jax.lax.with_sharding_constraint(x, activation_axes_audio)
- latents_jax = enforce_layout_act(latents_jax)
- audio_latents_jax = enforce_layout_act_audio(audio_latents_jax)
+ # Enforce layout constraint using module-level jitted helper to avoid recompilation on every pipeline call
+ latents_jax = _enforce_layout(latents_jax, activation_axes)
+ audio_latents_jax = _enforce_layout(audio_latents_jax, activation_axes_audio)
+ video_embeds_sharded = _enforce_layout(video_embeds_sharded, activation_axes)
+ audio_embeds_sharded = _enforce_layout(audio_embeds_sharded, activation_axes_audio)

Comment thread src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py Outdated
@github-actions

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

@Perseus14 Perseus14 force-pushed the ltx2-improvements branch from 1bbd72c to 2bf389d June 21, 2026 09:27