Skip to content

Add Apple Silicon (MPS) device support#184

Open
Mayank Gupta (techfreakworm) wants to merge 2 commits intoLightricks:mainfrom
techfreakworm:fix/mps-apple-silicon-compatibility
Open

Add Apple Silicon (MPS) device support#184
Mayank Gupta (techfreakworm) wants to merge 2 commits intoLightricks:mainfrom
techfreakworm:fix/mps-apple-silicon-compatibility

Conversation

@techfreakworm
Copy link
Copy Markdown

@techfreakworm Mayank Gupta (techfreakworm) commented Apr 5, 2026

Summary

Changes

File Fix
ltx-core/layer_streaming.py Guard torch.cuda.synchronize in LayerStreamingWrapper cleanup, add MPS fallback
ltx-core/loader/fuse_loras.py Add MPS device detection in _get_device()
ltx-core/text_encoders/gemma/encoders/base_encoder.py Skip CUDA device list in torch.random.fork_rng on non-CUDA (MPS does not support fork_rng device pinning)
ltx-core/components/guiders.py Compute guidance formula (CFG + STG + modality + rescaling) in float32 to fix STG precision on MPS
ltx-pipelines/utils/blocks.py Guard CUDA synchronize + host cache cleanup in _streaming_model, add MPS fallback
ltx-pipelines/utils/gpu_model.py Guard torch.cuda.synchronize in gpu_model context manager cleanup, add MPS fallback
ltx-pipelines/utils/helpers.py Add MPS device detection in get_device() and MPS cache cleanup in cleanup_memory()

STG Precision Fix Details

On MPS, Apple's MPSGraph compiles Metal kernels with fast math enabled, causing ~1e-7 errors per operation vs CUDA. These errors compound across 30 transformer blocks. The STG delta (cond - perturbed) and CFG rescaling (cond.std() / pred.std()) are particularly sensitive — accumulated fast-math errors dominate the bfloat16 signal, producing corrupted output.

Fix: upcast all guidance operands to float32 in MultiModalGuider.calculate() before the guidance formula, then cast back. Zero performance impact (operates on single tensors per denoising step).

Test plan

  • Tested on Apple M5 Max (128GB unified memory) with MPS backend
  • Two-stage A2V pipeline: 448×832 @ 25fps, 10 steps — completes successfully
  • Single-stage A2V pipeline: same resolution — completes successfully
  • STG=1.0 with stg_blocks=[28] produces correct output on MPS (was previously corrupted)
  • Lip sync quality improved with STG enabled vs STG=0
  • Model load/unload cycle (sequential streaming) works without memory leaks
  • No regressions on CUDA path (all changes are guarded or use float32 upcast which is safe everywhere)

Context

Running LTX-2.3 A2V pipelines on Apple Silicon crashes at multiple points because the codebase unconditionally calls torch.cuda.synchronize(), torch.cuda.empty_cache(), and passes CUDA devices to torch.random.fork_rng. Additionally, STG guidance produces corrupted output on MPS due to bfloat16 precision loss from Apple's fast-math Metal kernels.

The codebase currently assumes CUDA for device synchronization, memory
cleanup, RNG forking, and device detection. This causes crashes on Apple
Silicon Macs (MPS backend) where torch.cuda.* APIs are unavailable.

Changes:
- layer_streaming.py: Guard torch.cuda.synchronize with availability
  check, add MPS synchronize fallback
- fuse_loras.py: Add MPS device detection in _get_device()
- base_encoder.py: Skip CUDA-only device list in torch.random.fork_rng
  on non-CUDA devices (MPS does not support fork_rng device pinning)
- blocks.py: Guard CUDA synchronize and host cache cleanup with
  availability checks, add MPS synchronize fallback
- gpu_model.py: Guard torch.cuda.synchronize in model cleanup context
  manager, add MPS fallback
- helpers.py: Add MPS device detection in get_device() and MPS cache
  cleanup in cleanup_memory()

Tested on M4 Max (128GB) with both two-stage and single-stage A2V
pipelines generating 640x960 @ 25fps video from image+audio input.
Apple Silicon MPS compiles Metal kernels with fast math enabled
(pytorch/pytorch#84936), causing ~1e-7 errors per operation that
compound across 30 transformer blocks. The STG delta (cond - ptb)
and CFG rescaling (std ratio) are particularly sensitive — the
accumulated errors dominate the signal in bfloat16.

Fix: upcast all guidance operands to float32 before computing the
guidance formula (CFG + STG + modality deltas) and rescaling, then
cast back to the original dtype. This eliminates catastrophic
cancellation in the subtraction and precision loss in the std ratio.

Zero performance impact — the guidance calculation operates on
single tensors per step, not per block.

Tested on M5 Max 128GB with LTX-2.3 A2V pipeline:
- STG=1.0 stg_blocks=[28] now produces correct output on MPS
- Lip sync quality improved vs STG=0
- No regression on output quality
@darjeeling
Copy link
Copy Markdown

Michael Kupchick (@michaellightricks) can you approve this? or is there another way to contribute?
I'm waiting for this PR and preparing training related code to generate lora in my mac

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants