Skip to content
16 changes: 14 additions & 2 deletions src/diffusers/hooks/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ class TransformerBlockMetadata:
_cls: Type = None
_cached_parameter_indices: dict[str, int] = None

def _register(self, cls):
"""Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`.

Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that opt into the decorator
pattern (e.g. Flux). Writes directly to the registry dict instead of going through
``TransformerBlockRegistry.register`` so we don't trigger the lazy bulk-init while the decorated class's module
is mid-import (the bulk-init imports from the same module → circular).
"""
self._cls = cls
cls._block_metadata = self
TransformerBlockRegistry._registry[cls] = self

def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
Expand Down Expand Up @@ -107,8 +119,8 @@ def _register(cls):

def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.flux import FluxAttnProcessor
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
Expand Down Expand Up @@ -172,9 +184,9 @@ def _register_attention_processors_metadata():
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
Expand Down
Loading
Loading