diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 36d0893734c7..ca0af8a1c711 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -159,13 +159,19 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce
         return self.processor
 
     def set_attention_backend(self, backend: str):
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )
 
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend.lower())
+        _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)
         self.processor._attention_backend = backend
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index 8b45c2148504..affb4df29c7d 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -1,6 +1,7 @@
 import importlib.metadata
 import tempfile
 import unittest
+import unittest.mock as mock
 
 import numpy as np
 import pytest
@@ -8,6 +9,7 @@
 from packaging import version
 
 from diffusers import DiffusionPipeline
+from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
 
 from ..testing_utils import torch_device
@@ -133,3 +135,42 @@ def test_conversion_when_using_device_map(self):
 
         self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
         self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
+
+
+class AttentionModuleMixinSetBackendTests(unittest.TestCase):
+    """Regression tests for `AttentionModuleMixin.set_attention_backend` (issue #13284).
+
+    When called on an individual submodule, the per-module setter must trigger the
+    same hub kernel download path that the model-level setter on `ModelMixin` does.
+    Otherwise hub-only backends (e.g. `sage_hub`) are silently configured without
+    their kernel ever being loaded and inference fails later inside
+    `dispatch_attention_fn`.
+    """
+
+    class _DummyProcessor:
+        _attention_backend = None
+
+    class _DummyAttention(AttentionModuleMixin):
+        def __init__(self):
+            self.processor = AttentionModuleMixinSetBackendTests._DummyProcessor()
+
+    def test_set_attention_backend_invokes_kernel_download_for_hub_backend(self):
+        module = self._DummyAttention()
+
+        with (
+            mock.patch("diffusers.models.attention_dispatch._check_attention_backend_requirements") as mocked_check,
+            mock.patch("diffusers.models.attention_dispatch._maybe_download_kernel_for_backend") as mocked_download,
+        ):
+            module.set_attention_backend("sage_hub")
+
+        from diffusers.models.attention_dispatch import AttentionBackendName
+
+        mocked_check.assert_called_once_with(AttentionBackendName.SAGE_HUB)
+        mocked_download.assert_called_once_with(AttentionBackendName.SAGE_HUB)
+        self.assertEqual(module.processor._attention_backend, AttentionBackendName.SAGE_HUB)
+
+    def test_set_attention_backend_rejects_unknown_backend(self):
+        module = self._DummyAttention()
+
+        with self.assertRaises(ValueError):
+            module.set_attention_backend("not_a_real_backend")