Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations#422
Perf(LTX2): Comprehensive XLA, Memory, and Transformer Code Quality Optimizations#422Perseus14 wants to merge 1 commit into
Conversation
3a13196 to
36e0f5d
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
36e0f5d to
1bbd72c
Compare
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This pull request introduces a highly comprehensive set of optimizations and refactorings for the LTX2 model family, delivering significant enhancements to memory usage, model modularity, and JAX compilation behavior. By transforming static continuous hyperparameters into dynamic inputs, the sweep successfully eliminates expensive compilation timeouts while resolving underlying layout mismatches and unneeded computational overhead. The architectural quality of the codebase is substantially improved, particularly in the block consolidation and cleaner RoPE logic.
🔍 General Feedback
- Exceptional Block Refactoring: The transition of
LTX2VideoTransformer3DModelblock application to a single unifiedapply_blockhelper and an immutableLTX2BlockContextis a masterclass in reducing hazard-prone code duplication. - RoPE & Coordinates Cleanup: Removing the duplicated and dead coordinate generation block in
prepare_video_coordsand simplifying split-RoPE dimension manipulation are excellent and highly effective hygiene improvements. - Robustness in Fallbacks & Dtypes: Incorporating proper warning logs for the noise generator fallback and upfront dtype casting for prompt embeddings ensures both runtime reliability and optimal HBM performance.
| @@ -81,6 +81,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | |||
| """ | |||
| std_text = jnp.std(noise_pred_text, axis=list(range(1, noise_pred_text.ndim)), keepdims=True) | |||
| std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True) | |||
There was a problem hiding this comment.
| std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True) | |
| std_cfg = jnp.maximum(std_cfg, 1e-5) |
| @@ -237,7 +237,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): | |||
|
|
|||
| # Export videos | |||
| for i in range(len(videos)): | |||
There was a problem hiding this comment.
| for i in range(len(videos)): | |
| model_name = getattr(config, "model_name", "ltx2") or "ltx2" | |
| model_name_prefix = model_name.replace(".", "_") |
| @@ -72,9 +69,6 @@ def __call__( | |||
| output_hidden_states=output_hidden_states, | |||
| ) | |||
| return interop.jax_view(output) | |||
There was a problem hiding this comment.
| import transformers.masking_utils | |
| orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay | |
| transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay | |
| try: | |
| with default_env(): | |
| input_ids = interop.torch_view(input_ids) | |
| attention_mask = interop.torch_view(attention_mask) | |
| output = self.functional_call( | |
| self._forward_inner, | |
| params=self.params, | |
| buffers=self.buffers, | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| output_hidden_states=output_hidden_states, | |
| ) | |
| return interop.jax_view(output) | |
| finally: | |
| transformers.masking_utils.sliding_window_overlay = orig_sliding_window_overlay |
| @jax.jit | ||
| def enforce_layout_act(x): | ||
| return jax.lax.with_sharding_constraint(x, activation_axes) | ||
|
|
||
| @jax.jit | ||
| def enforce_layout_act_audio(x): | ||
| return jax.lax.with_sharding_constraint(x, activation_axes_audio) | ||
|
|
||
| latents_jax = enforce_layout_act(latents_jax) | ||
| audio_latents_jax = enforce_layout_act_audio(audio_latents_jax) | ||
| video_embeds_sharded = enforce_layout_act(video_embeds_sharded) | ||
| audio_embeds_sharded = enforce_layout_act_audio(audio_embeds_sharded) |
There was a problem hiding this comment.
🟠 JAX Performance Bottleneck: Inner @jax.jit Compilation
Defining @jax.jit decorated functions inside the pipeline's runtime __call__ method creates new Python function objects on every invocation. Since JAX's compilation cache relies on the function object's identity, this will trigger an expensive recompilation on every single pipeline call, leading to huge latency and CPU/TPU compilation overhead during inference.
To prevent recompilation, define a static, cached helper at the module-level:
from functools import partial
@partial(jax.jit, static_argnums=(1,))
def _enforce_layout(x, axes):
return jax.lax.with_sharding_constraint(x, axes)And then use it directly inside the __call__ method.
| @jax.jit | |
| def enforce_layout_act(x): | |
| return jax.lax.with_sharding_constraint(x, activation_axes) | |
| @jax.jit | |
| def enforce_layout_act_audio(x): | |
| return jax.lax.with_sharding_constraint(x, activation_axes_audio) | |
| latents_jax = enforce_layout_act(latents_jax) | |
| audio_latents_jax = enforce_layout_act_audio(audio_latents_jax) | |
| video_embeds_sharded = enforce_layout_act(video_embeds_sharded) | |
| audio_embeds_sharded = enforce_layout_act_audio(audio_embeds_sharded) | |
| latents_jax = _enforce_layout(latents_jax, activation_axes) | |
| audio_latents_jax = _enforce_layout(audio_latents_jax, activation_axes_audio) | |
| video_embeds_sharded = _enforce_layout(video_embeds_sharded, activation_axes) | |
| audio_embeds_sharded = _enforce_layout(audio_embeds_sharded, activation_axes_audio) |
| import transformers.masking_utils | ||
|
|
||
| _orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay | ||
| from unittest import mock |
There was a problem hiding this comment.
🟡 Clean Code: Avoid using unittest.mock in production code
Using test-suite frameworks (unittest.mock.patch) in production runtime code is considered an anti-pattern. It introduces performance overhead and can fail in strict production/minimal packaging environments where tests are excluded.
Instead, define a lightweight and standard custom context manager using contextlib under the imports:
| from unittest import mock | |
| import contextlib | |
| import transformers.masking_utils | |
| @contextlib.contextmanager | |
| def patch_sliding_window_overlay(): | |
| orig = transformers.masking_utils.sliding_window_overlay | |
| transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay | |
| try: | |
| yield | |
| finally: | |
| transformers.masking_utils.sliding_window_overlay = orig |
| # Dynamically patch transformers.masking_utils only during the duration of this call | ||
| transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay | ||
| try: | ||
| with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay): |
There was a problem hiding this comment.
🟡 Clean Code: Use the custom context manager
Apply our custom context manager here to execute the forward pass under the monkeypatched sliding window overlay safely.
| with mock.patch("transformers.masking_utils.sliding_window_overlay", _patched_sliding_window_overlay): | |
| with patch_sliding_window_overlay(): |
|
|
||
| @jax.jit | ||
| def enforce_layout_act(x): | ||
| return jax.lax.with_sharding_constraint(x, activation_axes) | ||
|
|
||
| @jax.jit | ||
| def enforce_layout_act_audio(x): | ||
| return jax.lax.with_sharding_constraint(x, activation_axes_audio) | ||
|
|
||
| latents_jax = enforce_layout_act(latents_jax) | ||
| audio_latents_jax = enforce_layout_act_audio(audio_latents_jax) |
There was a problem hiding this comment.
To avoid this, define a single static jitted helper at the module level (e.g., using functools.partial(jax.jit, static_argnums=(1,))) and call it directly here.
Proposed module-level helper (to be added at the top/import level):
@partial(jax.jit, static_argnums=(1,))
def _enforce_layout(x, axes):
return jax.lax.with_sharding_constraint(x, axes)And update the call sites as follows:
| @jax.jit | |
| def enforce_layout_act(x): | |
| return jax.lax.with_sharding_constraint(x, activation_axes) | |
| @jax.jit | |
| def enforce_layout_act_audio(x): | |
| return jax.lax.with_sharding_constraint(x, activation_axes_audio) | |
| latents_jax = enforce_layout_act(latents_jax) | |
| audio_latents_jax = enforce_layout_act_audio(audio_latents_jax) | |
| # Enforce layout constraint using module-level jitted helper to avoid recompilation on every pipeline call | |
| latents_jax = _enforce_layout(latents_jax, activation_axes) | |
| audio_latents_jax = _enforce_layout(audio_latents_jax, activation_axes_audio) | |
| video_embeds_sharded = _enforce_layout(video_embeds_sharded, activation_axes) | |
| audio_embeds_sharded = _enforce_layout(audio_embeds_sharded, activation_axes_audio) |
|
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details. |
1bbd72c to
2bf389d
Compare
Description
This PR is a comprehensive refactor and optimization sweep. It brings massive improvements to XLA compilation times, memory usage (HBM), and architectural hygiene by stripping out redundant compute, unifying duplicated logic, and optimizing JAX tracing.
🧹 Architectural Hygiene & Code Quality
LTX2VideoTransformer3DModel.__call__. The 4 separate block execution paths (scan vs. loop, perturbation vs. no-perturbation) have been consolidated using a newTransformerContextcontainer and a singleapply_blockhelper function.prepare_video_coords: Deleted the wasteful 5Dlatent_coordsblock inattention_ltx2.pythat was computing an unused, wrongly-shaped tensor only to immediately overwrite it.apply_split_rotary_emb: Cleaned up the convoluted reshape/broadcast logic for split RoPE. Removed the redundantexpand_dimsandsqueezeoperations, executing the rotation directly (first_x * cos - second_x * sin) to avoid allocating unnecessary intermediate 5D tensors.hasattr(self, "rope_type")check inLTX2Attention.max_logging.logwarning when defaulting to a zero-seedjax.random.key(0)for noise generation.⚡ XLA & JAX Compilation Optimizations
guidance_scale,stg_scale,audio_guidance_scale, etc.) fromstatic_argnamesinrun_diffusion_loop(). Tweaking these generation scales will no longer trigger expensive 10-30 minute JAX recompilations!if guidance_rescale > 0:check inside the compiled diffusion loop withjax.lax.cond. This enables the CFG rescaling logic to be fully dynamic, complementing the removal of the static scales and fixing formulation inconsistencies.nnx.scanwith standardjax.lax.scanfor the primary denoising timestep loop to ensure predictable compilation.RuntimeProgramInputMismatchforscan_layers=False: Resolved an issue where XLA would fail during warmup compilation due to unrolled layer layout mismatches. Added explicit@jax.jitwrappers withjax.lax.with_sharding_constraintto enforce layout transpositions before crossing intorun_diffusion_loop.🧠 Memory (HBM) Optimizations
target_dtypeupfront and mapping it directly, avoiding a redundant double-casting pipeline that was passing throughbfloat16.