From 1deb0b4cba2fed3dc2d9c986149d369a91c3277c Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sun, 24 May 2026 18:54:32 +0200 Subject: [PATCH] fix(fp8): route ModelMixin through hook-based path to survive partial load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Diffusers' enable_layerwise_casting() installs a LayerwiseCastingHook that (a) only casts dtype in pre_forward, not device, and (b) replaces Linear.forward with an instance-level wrapper that calls the original Linear.forward captured before the hook was installed. ModelCache.put() later runs apply_custom_layers_to_model, which constructs a new CustomLinear sharing the original Linear's __dict__ — so the diffusers wrapper carries over and routes calls to the captured original forward, silently bypassing CustomLinear.forward and its cast_to_device autocast. With partial loading (e.g. FLUX.2 Klein 9B on a constrained GPU), some Linear weights stay on CPU. The diffusers pre_forward only casts dtype, so F.linear then sees input on cuda:0 and weight on cpu and raises "Expected all tensors to be on the same device". Route every nn.Module — including ModelMixin — through _apply_fp8_to_nn_module, which uses register_forward_pre_hook / register_forward_hook(always_call=True). nn.Module._call_impl dispatches these around forward without replacing it, so CustomLinear.forward is still reached and cast_to_device moves the weight to the input device. Lose diffusers' _disable_peft_input_autocast in the process, which is irrelevant — InvokeAI patches LoRAs through CustomLinear's _patches_and_weights, not PEFT BaseTunerLayer. Add regression test that asserts the ModelMixin branch calls _apply_fp8_to_nn_module and not enable_layerwise_casting. --- .../model_manager/load/load_default.py | 20 ++++---- .../load/test_load_default_fp8.py | 46 +++++++++++++++++++ 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 542a6774d5a..040b55cb6ec 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -218,14 +218,18 @@ def _apply_fp8_layerwise_casting( if first_param is not None: compute_dtype = first_param.dtype - from diffusers.models.modeling_utils import ModelMixin - - if isinstance(model, ModelMixin): - model.enable_layerwise_casting( - storage_dtype=storage_dtype, - compute_dtype=compute_dtype, - ) - elif isinstance(model, torch.nn.Module): + # We use our own hook-based path for every nn.Module — including diffusers ModelMixin — + # rather than `model.enable_layerwise_casting()`. Diffusers' LayerwiseCastingHook installs + # an instance-level `forward` attribute that captures the original `Linear.forward` in a + # closure. `ModelCache.put()` later runs `apply_custom_layers_to_model`, which constructs a + # new `CustomLinear` sharing the original Linear's `__dict__` — so the diffusers wrapper + # carries over and routes calls back to the captured original forward, silently bypassing + # `CustomLinear.forward` and its `cast_to_device` autocast. With partial loading (e.g. FLUX.2 + # Klein 9B) some weights stay on CPU, the diffusers pre_forward only casts dtype, and + # `F.linear` then sees input on cuda and weight on cpu. Our `register_forward_pre_hook` / + # `register_forward_hook` path fires around `nn.Module._call_impl` without replacing + # `forward`, so `CustomLinear.forward` is still reached. + if isinstance(model, torch.nn.Module): self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype) else: return model diff --git a/tests/backend/model_manager/load/test_load_default_fp8.py b/tests/backend/model_manager/load/test_load_default_fp8.py index 4a4491c6a07..97a1cfb2e30 100644 --- a/tests/backend/model_manager/load/test_load_default_fp8.py +++ b/tests/backend/model_manager/load/test_load_default_fp8.py @@ -323,3 +323,49 @@ def tracked_patch_branch(self, input): "FP8-wrapped forward did not reach CustomLinear.forward — LoRA/ControlLoRA patches " "would be silently bypassed on FP8 checkpoint models." ) + + +def test_apply_fp8_layerwise_casting_uses_hook_path_for_model_mixin(): + """Regression test for the FLUX.2 Klein 9B partial-load device-mismatch crash. + + Diffusers' `enable_layerwise_casting()` registers a `LayerwiseCastingHook` whose + `pre_forward` only casts dtype (not device) and whose hook system replaces + `Linear.forward` with a wrapper that calls the *original* `Linear.forward` captured + before the hook was installed. `ModelCache.put()` later wraps Linear as CustomLinear + sharing `__dict__`, so the diffusers wrapper is carried into the new CustomLinear and + routes calls to the captured original Linear.forward — bypassing + `CustomLinear.forward`'s `cast_to_device`. On partial load (some weights on CPU, + input on cuda), this raises a device-mismatch error. + + The fix routes ModelMixin through `_apply_fp8_to_nn_module` (hook-based, + `forward`-preserving). This test asserts that path is taken even when the model + inherits from ModelMixin. + """ + from diffusers.models.modeling_utils import ModelMixin + + class _FakeModelMixin(ModelMixin): + # ModelMixin requires a config_name class attribute and a config dict for serialization. + # We never serialize, so we only need to satisfy isinstance() checks. + config_name = "config.json" + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x): + return self.linear(x) + + loader = _make_loader(device="cuda") + config = _make_config(ModelType.Main, fp8=True) + + model = _FakeModelMixin() + + with ( + patch.object(ModelLoader, "_should_use_fp8", return_value=True), + patch.object(ModelLoader, "_apply_fp8_to_nn_module") as mock_to_nn, + patch.object(_FakeModelMixin, "enable_layerwise_casting") as mock_enable, + ): + loader._apply_fp8_layerwise_casting(model, config) + + mock_to_nn.assert_called_once() + mock_enable.assert_not_called()