fix(fp8): route ModelMixin through hook-based path to survive partialload#9231
Open
Pfannkuchensack wants to merge 1 commit into
Open
fix(fp8): route ModelMixin through hook-based path to survive partialload#9231Pfannkuchensack wants to merge 1 commit into
Pfannkuchensack wants to merge 1 commit into
Conversation
… load Diffusers' enable_layerwise_casting() installs a LayerwiseCastingHook that (a) only casts dtype in pre_forward, not device, and (b) replaces Linear.forward with an instance-level wrapper that calls the original Linear.forward captured before the hook was installed. ModelCache.put() later runs apply_custom_layers_to_model, which constructs a new CustomLinear sharing the original Linear's __dict__ — so the diffusers wrapper carries over and routes calls to the captured original forward, silently bypassing CustomLinear.forward and its cast_to_device autocast. With partial loading (e.g. FLUX.2 Klein 9B on a constrained GPU), some Linear weights stay on CPU. The diffusers pre_forward only casts dtype, so F.linear then sees input on cuda:0 and weight on cpu and raises "Expected all tensors to be on the same device". Route every nn.Module — including ModelMixin — through _apply_fp8_to_nn_module, which uses register_forward_pre_hook / register_forward_hook(always_call=True). nn.Module._call_impl dispatches these around forward without replacing it, so CustomLinear.forward is still reached and cast_to_device moves the weight to the input device. Lose diffusers' _disable_peft_input_autocast in the process, which is irrelevant — InvokeAI patches LoRAs through CustomLinear's _patches_and_weights, not PEFT BaseTunerLayer. Add regression test that asserts the ModelMixin branch calls _apply_fp8_to_nn_module and not enable_layerwise_casting.
Collaborator
|
It looks like this needs a documentation update.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Diffusers'
enable_layerwise_casting()installs aLayerwiseCastingHookthat (a) only casts dtype inpre_forward, not device, and (b) replacesLinear.forwardwith an instance-level wrapper that calls the originalLinear.forwardcaptured before the hook was installed.ModelCache.put()later runsapply_custom_layers_to_model, which constructs a newCustomLinearsharing the original Linear's__dict__— so the diffusers wrapper carries over and routes calls to the captured original forward, silently bypassingCustomLinear.forwardand itscast_to_deviceautocast.With partial loading (e.g. FLUX.2 Klein 9B on a constrained GPU), some Linear weights stay on CPU. The diffusers
pre_forwardonly casts dtype, soF.linearthen sees input oncuda:0and weight oncpuand raises"Expected all tensors to be on the same device".Route every
nn.Module— includingModelMixin— through_apply_fp8_to_nn_module, which usesregister_forward_pre_hook/register_forward_hook(always_call=True).nn.Module._call_impldispatches these aroundforwardwithout replacing it, soCustomLinear.forwardis still reached andcast_to_devicemoves the weight to the input device. Lose diffusers'_disable_peft_input_autocastin the process, which is irrelevant — InvokeAI patches LoRAs throughCustomLinear's_patches_and_weights, not PEFTBaseTunerLayer.Add regression test that asserts the
ModelMixinbranch calls_apply_fp8_to_nn_moduleand notenable_layerwise_casting.Related Issues / Discussions
https://discord.com/channels/1020123559063990373/1508132779164962850
Reported on Discord: FP8 storage on FLUX.2 Klein 9B crashes with
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)at
Flux2FeedForward.linear_outinsideff_context.Stack trace points to the diffusers
LayerwiseCastingHookwrapper (diffusers/hooks/hooks.py:189→torch/nn/modules/linear.py:125).QA Instructions
Repro (pre-fix):
RuntimeErroratlinear_outin the first transformer block.Regression coverage:
pytest tests/backend/model_manager/load/test_load_default_fp8.py— 13 tests, all green. The newtest_apply_fp8_layerwise_casting_uses_hook_path_for_model_mixinfails on the pre-fix code (it would observeenable_layerwise_castingbeing called) and passes on the fix.Also verify the existing FP8 paths still work:
_apply_fp8_to_nn_module, behavior unchanged.CustomLinear._autocast_forward_with_patchesbranch should fire (covered bytest_wrap_forward_reaches_custom_linear_after_apply_custom_layers).Merge Plan
Straight merge. No DB or schema changes. No frontend changes. Cache invalidation on the FP8 toggle already exists (
drop_modelon settings change), so a user toggling FP8 off/on after pulling this PR will get the fixed loader on next load.Checklist
What's Newcopy (if doing a release after this PR)