Add Apple Silicon (MPS) device support #184
Open
Mayank Gupta (techfreakworm) wants to merge 2 commits into Lightricks:main from
Conversation
The codebase currently assumes CUDA for device synchronization, memory cleanup, RNG forking, and device detection. This causes crashes on Apple Silicon Macs (MPS backend), where torch.cuda.* APIs are unavailable.

Changes:
- layer_streaming.py: guard torch.cuda.synchronize with an availability check; add an MPS synchronize fallback
- fuse_loras.py: add MPS device detection in _get_device()
- base_encoder.py: skip the CUDA-only device list in torch.random.fork_rng on non-CUDA devices (MPS does not support fork_rng device pinning)
- blocks.py: guard CUDA synchronize and host cache cleanup with availability checks; add an MPS synchronize fallback
- gpu_model.py: guard torch.cuda.synchronize in the model cleanup context manager; add an MPS fallback
- helpers.py: add MPS device detection in get_device() and MPS cache cleanup in cleanup_memory()

Tested on an M4 Max (128 GB) with both the two-stage and single-stage A2V pipelines, generating 640x960 @ 25 fps video from image+audio input.
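The guard pattern the commit describes can be sketched as a pair of helpers. The names `sync_device` and `forked_rng` are hypothetical (the actual diff edits each call site in place); the sketch only illustrates the availability checks and the MPS fallbacks:

```python
import torch


def sync_device(device: torch.device) -> None:
    """Synchronize the given device without assuming CUDA.

    Hypothetical helper illustrating the guard pattern from the PR.
    """
    if device.type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(device)
    elif device.type == "mps" and torch.backends.mps.is_available():
        torch.mps.synchronize()
    # CPU needs no explicit synchronization.


def forked_rng(device: torch.device):
    """Fork RNG state, pinning device RNGs only on CUDA.

    fork_rng's `devices` argument is CUDA-style device pinning; on MPS
    (or CPU) we pass an empty list so only the CPU RNG state is forked.
    """
    if device.type == "cuda":
        return torch.random.fork_rng(devices=[device])
    return torch.random.fork_rng(devices=[])
```

Used as a context manager, `forked_rng` restores the global RNG state on exit, so seeding inside the block does not leak into subsequent sampling.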
Apple Silicon MPS compiles Metal kernels with fast math enabled (pytorch/pytorch#84936), causing ~1e-7 errors per operation that compound across 30 transformer blocks. The STG delta (cond - ptb) and CFG rescaling (std ratio) are particularly sensitive: the accumulated errors dominate the signal in bfloat16.

Fix: upcast all guidance operands to float32 before computing the guidance formula (CFG + STG + modality deltas) and rescaling, then cast back to the original dtype. This eliminates catastrophic cancellation in the subtraction and precision loss in the std ratio. Zero performance impact: the guidance calculation operates on single tensors per step, not per block.

Tested on M5 Max 128GB with the LTX-2.3 A2V pipeline:
- STG=1.0 stg_blocks=[28] now produces correct output on MPS
- Lip sync quality improved vs STG=0
- No regression in output quality
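A minimal sketch of the upcast-then-downcast guidance path, assuming a generic CFG + STG formula with std-ratio rescaling (function name, signature, and the exact formula are illustrative, not the repository's `MultiModalGuider.calculate()`):

```python
import torch


def guided_prediction(
    cond: torch.Tensor,       # conditional model output
    uncond: torch.Tensor,     # unconditional output (CFG)
    perturbed: torch.Tensor,  # perturbed-attention output (STG)
    cfg_scale: float,
    stg_scale: float,
    rescale: float = 0.0,     # CFG-rescale mixing factor
) -> torch.Tensor:
    """Compute guidance entirely in float32, then cast back.

    The (cond - perturbed) delta is tiny relative to the activations,
    so catastrophic cancellation in bfloat16 corrupts it; the same
    applies to the cond.std() / pred.std() ratio used for rescaling.
    """
    orig_dtype = cond.dtype
    c, u, p = cond.float(), uncond.float(), perturbed.float()

    # CFG + STG in float32.
    pred = u + cfg_scale * (c - u) + stg_scale * (c - p)

    # Optional CFG rescaling: pull pred's std toward cond's std.
    if rescale > 0:
        ratio = c.std() / pred.std()
        pred = rescale * (pred * ratio) + (1.0 - rescale) * pred

    return pred.to(orig_dtype)
```

The upcast touches one set of latent tensors per denoising step, not per transformer block, which is why the commit claims zero measurable cost.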
Michael Kupchick (@michaellightricks), can you approve this? Or is there another way to contribute?
Summary
Makes the ltx-core and ltx-pipelines packages run on Apple Silicon: guards torch.cuda.* calls with availability checks and adds MPS fallbacks.

Changes
- ltx-core/layer_streaming.py: guard torch.cuda.synchronize in LayerStreamingWrapper cleanup; add MPS fallback
- ltx-core/loader/fuse_loras.py: add MPS device detection in _get_device()
- ltx-core/text_encoders/gemma/encoders/base_encoder.py: skip the CUDA device list in torch.random.fork_rng on non-CUDA devices (MPS does not support fork_rng device pinning)
- ltx-core/components/guiders.py: STG precision fix (detailed below)
- ltx-pipelines/utils/blocks.py: guard CUDA synchronize and cache cleanup in _streaming_model; add MPS fallback
- ltx-pipelines/utils/gpu_model.py: guard torch.cuda.synchronize in the gpu_model context manager cleanup; add MPS fallback
- ltx-pipelines/utils/helpers.py: add MPS device detection in get_device() and MPS cache cleanup in cleanup_memory()

STG Precision Fix Details
On MPS, Apple's MPSGraph compiles Metal kernels with fast math enabled, causing ~1e-7 errors per operation vs CUDA. These errors compound across 30 transformer blocks. The STG delta (cond - perturbed) and CFG rescaling (cond.std() / pred.std()) are particularly sensitive: accumulated fast-math errors dominate the bfloat16 signal, producing corrupted output.

Fix: upcast all guidance operands to float32 in MultiModalGuider.calculate() before the guidance formula, then cast back. Zero performance impact (operates on single tensors per denoising step).

Test plan
Context
Running LTX-2.3 A2V pipelines on Apple Silicon crashes at multiple points because the codebase unconditionally calls torch.cuda.synchronize() and torch.cuda.empty_cache(), and passes CUDA devices to torch.random.fork_rng. Additionally, STG guidance produces corrupted output on MPS due to bfloat16 precision loss from Apple's fast-math Metal kernels.
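The device-detection and cache-cleanup changes in helpers.py can be sketched as follows (a minimal illustration, assuming the CUDA-then-MPS-then-CPU preference order implied by the PR; the real helpers may differ):

```python
import gc

import torch


def get_device() -> torch.device:
    """Pick the best available backend: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def cleanup_memory() -> None:
    """Release cached allocator blocks on whichever backend is active."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # torch.mps mirrors torch.cuda's cache API on Apple Silicon.
        torch.mps.empty_cache()
```

Both functions degrade gracefully on a CPU-only machine, which is the property the unconditional torch.cuda.* calls were missing.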