Skip to content

[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX#398

Open
mbohlool wants to merge 1 commit intomainfrom
text_encoder_tpu3
Open

[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX#398
mbohlool wants to merge 1 commit intomainfrom
text_encoder_tpu3

Conversation

@mbohlool
Copy link
Copy Markdown
Collaborator

@mbohlool mbohlool commented May 4, 2026

Description

This PR transitions the LTX-2 pipeline's text encoding process to utilize TorchAX, bridging the Gemma-3 model natively into JAX and significantly optimizing memory usage to prevent TPU out-of-memory errors. Minor PyLint warnings across the pipeline were also resolved during the refactor.

Key changes include:

  • TorchAX Integration: Replaced the eager PyTorch-based text encoder execution with the JAX-native TorchaxGemma3TextEncoder. TPU sharding is now manually distributed across the batch dimension via jax.device_put to prevent Softmax OOM crashes.
  • VAE Memory Optimization: Updated the VAE decoding loop to conditionally apply sharding constraints. By disabling sequential slicing and dynamically adjusting batch sharding for batch_size > 2, HBM crashes during decoding are avoided.
  • Lint Cleanup: Addressed minor PyLint warnings in the pipeline and encoder wrapper to maintain code health.

Benchmarks

Performance comparison demonstrating latency improvements from TorchAX integration.

Configuration Text Encoding (CPU) Text Encoding (TorchAX) Text Encoding Impr. Total Time (TE on CPU) Total Time (TE on TorchAX) Generation Impr.
Batch Size 1 (Latency Optimized) 3.75s 2.52s 32.93% 13.19s 11.67s 11.47%
Batch Size 1 (w/ Upsampler) 3.57s 2.47s 30.72% 16.65s 15.61s 6.28%
Batch Size 8 (Throughput Optimized) 23.23s 5.86s 74.77% 80.14s 60.40s 24.64%
Batch Size 8 (w/ Upsampler) 23.36s 6.10s 73.87% 114.98s 86.74s 24.56%

@mbohlool mbohlool requested a review from entrpn as a code owner May 4, 2026 20:08
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 4, 2026

@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch from 7338ec7 to 13e195e Compare May 4, 2026 20:25
@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch from 13e195e to 7707c3d Compare May 4, 2026 20:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant