diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 542a6774d5a..040b55cb6ec 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -218,14 +218,18 @@ def _apply_fp8_layerwise_casting(
         if first_param is not None:
             compute_dtype = first_param.dtype
 
-        from diffusers.models.modeling_utils import ModelMixin
-
-        if isinstance(model, ModelMixin):
-            model.enable_layerwise_casting(
-                storage_dtype=storage_dtype,
-                compute_dtype=compute_dtype,
-            )
-        elif isinstance(model, torch.nn.Module):
+        # We use our own hook-based path for every nn.Module — including diffusers ModelMixin —
+        # rather than `model.enable_layerwise_casting()`. Diffusers' LayerwiseCastingHook installs
+        # an instance-level `forward` attribute that captures the original `Linear.forward` in a
+        # closure. `ModelCache.put()` later runs `apply_custom_layers_to_model`, which constructs a
+        # new `CustomLinear` sharing the original Linear's `__dict__` — so the diffusers wrapper
+        # carries over and routes calls back to the captured original forward, silently bypassing
+        # `CustomLinear.forward` and its `cast_to_device` autocast. With partial loading (e.g. FLUX.2
+        # Klein 9B) some weights stay on CPU, the diffusers pre_forward only casts dtype, and
+        # `F.linear` then sees input on cuda and weight on cpu. Our `register_forward_pre_hook` /
+        # `register_forward_hook` path fires around `nn.Module._call_impl` without replacing
+        # `forward`, so `CustomLinear.forward` is still reached.
+        if isinstance(model, torch.nn.Module):
             self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         else:
             return model
diff --git a/tests/backend/model_manager/load/test_load_default_fp8.py b/tests/backend/model_manager/load/test_load_default_fp8.py
index 4a4491c6a07..97a1cfb2e30 100644
--- a/tests/backend/model_manager/load/test_load_default_fp8.py
+++ b/tests/backend/model_manager/load/test_load_default_fp8.py
@@ -323,3 +323,49 @@ def tracked_patch_branch(self, input):
         "FP8-wrapped forward did not reach CustomLinear.forward — LoRA/ControlLoRA patches "
         "would be silently bypassed on FP8 checkpoint models."
     )
+
+
+def test_apply_fp8_layerwise_casting_uses_hook_path_for_model_mixin():
+    """Regression test for the FLUX.2 Klein 9B partial-load device-mismatch crash.
+
+    Diffusers' `enable_layerwise_casting()` registers a `LayerwiseCastingHook` whose
+    `pre_forward` only casts dtype (not device) and whose hook system replaces
+    `Linear.forward` with a wrapper that calls the *original* `Linear.forward` captured
+    before the hook was installed. `ModelCache.put()` later wraps Linear as CustomLinear
+    sharing `__dict__`, so the diffusers wrapper is carried into the new CustomLinear and
+    routes calls to the captured original Linear.forward — bypassing
+    `CustomLinear.forward`'s `cast_to_device`. On partial load (some weights on CPU,
+    input on cuda), this raises a device-mismatch error.
+
+    The fix routes ModelMixin through `_apply_fp8_to_nn_module` (hook-based,
+    `forward`-preserving). This test asserts that path is taken even when the model
+    inherits from ModelMixin.
+    """
+    from diffusers.models.modeling_utils import ModelMixin
+
+    class _FakeModelMixin(ModelMixin):
+        # ModelMixin requires a config_name class attribute and a config dict for serialization.
+        # We never serialize, so we only need to satisfy isinstance() checks.
+        config_name = "config.json"
+
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(4, 4, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    loader = _make_loader(device="cuda")
+    config = _make_config(ModelType.Main, fp8=True)
+
+    model = _FakeModelMixin()
+
+    with (
+        patch.object(ModelLoader, "_should_use_fp8", return_value=True),
+        patch.object(ModelLoader, "_apply_fp8_to_nn_module") as mock_to_nn,
+        patch.object(_FakeModelMixin, "enable_layerwise_casting") as mock_enable,
+    ):
+        loader._apply_fp8_layerwise_casting(model, config)
+
+    mock_to_nn.assert_called_once()
+    mock_enable.assert_not_called()