From 3915aa73cdcbe9f8ac0e4e4a7be534724a4fd9f3 Mon Sep 17 00:00:00 2001
From: Aditya Singh
Date: Thu, 21 May 2026 04:19:45 -0700
Subject: [PATCH] Fix AttentionModuleMixin.set_attention_backend skipping hub kernel download

The per-submodule set_attention_backend on AttentionModuleMixin only
validated the backend name and updated self.processor._attention_backend,
but skipped the _check_attention_backend_requirements and
_maybe_download_kernel_for_backend calls that
ModelMixin.set_attention_backend performs. As a result, hub-based
backends like sage_hub were silently set without the kernel ever being
downloaded, and inference failed later with TypeError: 'NoneType' object
is not callable inside dispatch_attention_fn.

This mirrors the requirement-check and kernel-download path from
ModelMixin into the submodule-level setter so per-block backend
overrides work for hub kernels.

Fixes #13284
---
 src/diffusers/models/attention.py        |  8 ++++-
 tests/models/test_attention_processor.py | 41 ++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 36d0893734c7..ca0af8a1c711 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -159,13 +159,19 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce
         return self.processor
 
     def set_attention_backend(self, backend: str):
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )
 
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
 
         backend = AttentionBackendName(backend.lower())
+        _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)
         self.processor._attention_backend = backend
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index 8b45c2148504..affb4df29c7d 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -1,6 +1,7 @@
 import importlib.metadata
 import tempfile
 import unittest
+import unittest.mock as mock
 
 import numpy as np
 import pytest
@@ -8,6 +9,7 @@
 from packaging import version
 
 from diffusers import DiffusionPipeline
+from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
 
 from ..testing_utils import torch_device
@@ -133,3 +135,42 @@ def test_conversion_when_using_device_map(self):
 
         self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
         self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
+
+
+class AttentionModuleMixinSetBackendTests(unittest.TestCase):
+    """Regression tests for `AttentionModuleMixin.set_attention_backend` (issue #13284).
+
+    When called on an individual submodule, the per-module setter must trigger the
+    same hub kernel download path that the model-level setter on `ModelMixin` does.
+    Otherwise hub-only backends (e.g. `sage_hub`) are silently configured without
+    their kernel ever being loaded and inference fails later inside
+    `dispatch_attention_fn`.
+    """
+
+    class _DummyProcessor:
+        _attention_backend = None
+
+    class _DummyAttention(AttentionModuleMixin):
+        def __init__(self):
+            self.processor = AttentionModuleMixinSetBackendTests._DummyProcessor()
+
+    def test_set_attention_backend_invokes_kernel_download_for_hub_backend(self):
+        module = self._DummyAttention()
+
+        with (
+            mock.patch("diffusers.models.attention_dispatch._check_attention_backend_requirements") as mocked_check,
+            mock.patch("diffusers.models.attention_dispatch._maybe_download_kernel_for_backend") as mocked_download,
+        ):
+            module.set_attention_backend("sage_hub")
+
+        from diffusers.models.attention_dispatch import AttentionBackendName
+
+        mocked_check.assert_called_once_with(AttentionBackendName.SAGE_HUB)
+        mocked_download.assert_called_once_with(AttentionBackendName.SAGE_HUB)
+        self.assertEqual(module.processor._attention_backend, AttentionBackendName.SAGE_HUB)
+
+    def test_set_attention_backend_rejects_unknown_backend(self):
+        module = self._DummyAttention()
+
+        with self.assertRaises(ValueError):
+            module.set_attention_backend("not_a_real_backend")