Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions invokeai/backend/model_manager/load/load_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,18 @@ def _apply_fp8_layerwise_casting(
if first_param is not None:
compute_dtype = first_param.dtype

from diffusers.models.modeling_utils import ModelMixin

if isinstance(model, ModelMixin):
model.enable_layerwise_casting(
storage_dtype=storage_dtype,
compute_dtype=compute_dtype,
)
elif isinstance(model, torch.nn.Module):
# We use our own hook-based path for every nn.Module — including diffusers ModelMixin —
# rather than `model.enable_layerwise_casting()`. Diffusers' LayerwiseCastingHook installs
# an instance-level `forward` attribute that captures the original `Linear.forward` in a
# closure. `ModelCache.put()` later runs `apply_custom_layers_to_model`, which constructs a
# new `CustomLinear` sharing the original Linear's `__dict__` — so the diffusers wrapper
# carries over and routes calls back to the captured original forward, silently bypassing
# `CustomLinear.forward` and its `cast_to_device` autocast. With partial loading (e.g. FLUX.2
# Klein 9B) some weights stay on CPU, the diffusers pre_forward only casts dtype, and
# `F.linear` then sees input on cuda and weight on cpu. Our `register_forward_pre_hook` /
# `register_forward_hook` path fires around `nn.Module._call_impl` without replacing
# `forward`, so `CustomLinear.forward` is still reached.
if isinstance(model, torch.nn.Module):
self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype)
else:
return model
Expand Down
46 changes: 46 additions & 0 deletions tests/backend/model_manager/load/test_load_default_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,49 @@ def tracked_patch_branch(self, input):
"FP8-wrapped forward did not reach CustomLinear.forward — LoRA/ControlLoRA patches "
"would be silently bypassed on FP8 checkpoint models."
)


def test_apply_fp8_layerwise_casting_uses_hook_path_for_model_mixin():
"""Regression test for the FLUX.2 Klein 9B partial-load device-mismatch crash.

Diffusers' `enable_layerwise_casting()` registers a `LayerwiseCastingHook` whose
`pre_forward` only casts dtype (not device) and whose hook system replaces
`Linear.forward` with a wrapper that calls the *original* `Linear.forward` captured
before the hook was installed. `ModelCache.put()` later wraps Linear as CustomLinear
sharing `__dict__`, so the diffusers wrapper is carried into the new CustomLinear and
routes calls to the captured original Linear.forward — bypassing
`CustomLinear.forward`'s `cast_to_device`. On partial load (some weights on CPU,
input on cuda), this raises a device-mismatch error.

The fix routes ModelMixin through `_apply_fp8_to_nn_module` (hook-based,
`forward`-preserving). This test asserts that path is taken even when the model
inherits from ModelMixin.
"""
from diffusers.models.modeling_utils import ModelMixin

class _FakeModelMixin(ModelMixin):
# ModelMixin requires a config_name class attribute and a config dict for serialization.
# We never serialize, so we only need to satisfy isinstance() checks.
config_name = "config.json"

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4, bias=False)

def forward(self, x):
return self.linear(x)

loader = _make_loader(device="cuda")
config = _make_config(ModelType.Main, fp8=True)

model = _FakeModelMixin()

with (
patch.object(ModelLoader, "_should_use_fp8", return_value=True),
patch.object(ModelLoader, "_apply_fp8_to_nn_module") as mock_to_nn,
patch.object(_FakeModelMixin, "enable_layerwise_casting") as mock_enable,
):
loader._apply_fp8_layerwise_casting(model, config)

mock_to_nn.assert_called_once()
mock_enable.assert_not_called()
Loading