Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,19 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce
return self.processor

def set_attention_backend(self, backend: str):
    """
    Select the attention backend used by this module's processor.

    The name is matched case-insensitively against the values of `AttentionBackendName`. Before the backend is
    recorded on `self.processor`, it is passed to `_check_attention_backend_requirements` and
    `_maybe_download_kernel_for_backend` from `attention_dispatch`.

    Args:
        backend (`str`):
            Name of the attention backend to use, e.g. the string value of an `AttentionBackendName` member.

    Raises:
        ValueError: If `backend` (lower-cased) is not the value of any `AttentionBackendName` member.
    """
    # Imported lazily inside the method, as the original code does.
    from .attention_dispatch import (
        AttentionBackendName,
        _check_attention_backend_requirements,
        _maybe_download_kernel_for_backend,
    )

    # Normalize *before* validating so the membership test and the enum lookup see the same string;
    # validating the raw string first would reject e.g. "SAGE_HUB" even though it is lower-cased below.
    normalized_backend = backend.lower()
    available_backends = {x.value for x in AttentionBackendName.__members__.values()}
    if normalized_backend not in available_backends:
        # Sorted so the message is stable across runs (set iteration order is not).
        raise ValueError(
            f"`{backend=}` must be one of the following: " + ", ".join(sorted(available_backends))
        )

    backend = AttentionBackendName(normalized_backend)
    # NOTE(review): judging by their names these verify the backend's dependencies and fetch a hub kernel
    # when the backend needs one — confirm against `attention_dispatch`.
    _check_attention_backend_requirements(backend)
    _maybe_download_kernel_for_backend(backend)
    self.processor._attention_backend = backend

def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/models/test_attention_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import importlib.metadata
import tempfile
import unittest
import unittest.mock as mock

import numpy as np
import pytest
import torch
from packaging import version

from diffusers import DiffusionPipeline
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

from ..testing_utils import torch_device
Expand Down Expand Up @@ -133,3 +135,42 @@ def test_conversion_when_using_device_map(self):

self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))


class AttentionModuleMixinSetBackendTests(unittest.TestCase):
    """Guards `AttentionModuleMixin.set_attention_backend` against the regression from issue #13284.

    Setting a backend on a single submodule has to go through the same hub kernel
    download path as the model-level setter on `ModelMixin`. If it does not,
    hub-only backends (e.g. `sage_hub`) end up configured without their kernel
    being loaded, and the failure only surfaces later inside `dispatch_attention_fn`.
    """

    class _DummyProcessor:
        # Stand-in processor exposing only the attribute the setter assigns.
        _attention_backend = None

    class _DummyAttention(AttentionModuleMixin):
        def __init__(self):
            self.processor = AttentionModuleMixinSetBackendTests._DummyProcessor()

    def test_set_attention_backend_invokes_kernel_download_for_hub_backend(self):
        from diffusers.models.attention_dispatch import AttentionBackendName

        attn = self._DummyAttention()

        # Patch both helpers on the dispatch module so no real requirement check or download runs.
        with mock.patch("diffusers.models.attention_dispatch._check_attention_backend_requirements") as check_spy:
            with mock.patch("diffusers.models.attention_dispatch._maybe_download_kernel_for_backend") as download_spy:
                attn.set_attention_backend("sage_hub")

        expected_backend = AttentionBackendName.SAGE_HUB
        check_spy.assert_called_once_with(expected_backend)
        download_spy.assert_called_once_with(expected_backend)
        self.assertEqual(attn.processor._attention_backend, expected_backend)

    def test_set_attention_backend_rejects_unknown_backend(self):
        attn = self._DummyAttention()

        # A name that is not a backend value must be refused with a ValueError.
        self.assertRaises(ValueError, attn.set_attention_backend, "not_a_real_backend")
Loading