From 1d3c5c527cff30703d1595c2013e60418f934722 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 8 May 2026 20:20:52 +0530 Subject: [PATCH 01/21] update --- src/diffusers/loaders/lora_base.py | 252 ++---- src/diffusers/loaders/lora_pipeline.py | 332 ++++++-- src/diffusers/loaders/single_file_model.py | 5 - src/diffusers/loaders/single_file_utils.py | 195 ----- src/diffusers/loaders/unet.py | 105 ++- src/diffusers/models/modeling_utils.py | 349 +++++++- src/diffusers/models/transformers/__init__.py | 9 +- .../models/transformers/transformer_flux.py | 778 ------------------ src/diffusers/utils/__init__.py | 1 - src/diffusers/utils/peft_utils.py | 131 +-- 10 files changed, 777 insertions(+), 1380 deletions(-) delete mode 100644 src/diffusers/models/transformers/transformer_flux.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 5b5579664b55..1440bcbb5cd2 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -22,14 +22,10 @@ import safetensors import torch -import torch.nn as nn -from huggingface_hub import model_info -from huggingface_hub.constants import HF_HUB_OFFLINE -from ..models.modeling_utils import ModelMixin, load_state_dict +from ..models.modeling_utils import ModelMixin from ..utils import ( USE_PEFT_BACKEND, - _get_model_file, convert_state_dict_to_diffusers, convert_state_dict_to_peft, delete_adapter_layers, @@ -46,8 +42,6 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config -from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): @@ -57,13 +51,81 @@ from peft.tuners.tuners_utils import BaseTunerLayer if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + pass logger = logging.get_logger(__name__) -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" 
-LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" +# Constants, fetch helpers, and the offload-disable shim now live in `loaders.lora`. +# Re-exported here for back-compat. +from .lora import ( # noqa: F401 (back-compat re-exports) + LORA_ADAPTER_METADATA_KEY, + LORA_WEIGHT_NAME, + LORA_WEIGHT_NAME_SAFE, + _best_guess_weight_name, + _fetch_lora_metadata, + _fetch_state_dict, + _func_optionally_disable_offloading, +) + + +class LoRAMappingMixin: + """ + Base mixin providing utilities for LoRA weight mapping and conversion. + + Subclasses should define: + - _lora_format_keys: Dict mapping format names to identifying key patterns + - _lora_rename_patterns: Dict mapping format names to rename pattern dicts + - _map_lora_to_diffusers: Staticmethod for LoRA state dict conversion + """ + + _lora_format_keys: dict[str, set[str]] = {} + _lora_rename_patterns: dict[str, dict[str, str]] = {} + _map_lora_to_diffusers = None + + @staticmethod + def _rename_lora_key(key: str, patterns: dict[str, str]) -> str: + """Apply rename patterns to a LoRA key.""" + for old, new in patterns.items(): + key = key.replace(old, new) + return key + + @classmethod + def _detect_lora_format(cls, state_dict: dict) -> str | None: + """ + Detect the LoRA format from state dict keys. + + Returns format name (e.g., 'kohya') or None if unknown. 
+ """ + if not cls._lora_format_keys: + return None + + keys = set(state_dict.keys()) + for format_name, format_keys in cls._lora_format_keys.items(): + if any(any(fk in k for k in keys) for fk in format_keys): + return format_name + + return None + + @classmethod + def _normalize_lora_suffixes(cls, state_dict: dict) -> dict: + """Normalize LoRA suffixes to diffusers format (.lora_A.weight, .lora_B.weight).""" + normalized = {} + for key, value in state_dict.items(): + new_key = key + new_key = new_key.replace(".lora_down.weight", ".lora_A.weight") + new_key = new_key.replace(".lora_up.weight", ".lora_B.weight") + new_key = new_key.replace(".down.weight", ".lora_A.weight") + new_key = new_key.replace(".up.weight", ".lora_B.weight") + normalized[new_key] = value + return normalized + + @classmethod + def map_lora_to_diffusers(cls, state_dict: dict, **kwargs) -> dict: + """Normalize LoRA suffixes, then dispatch to the model-specific converter.""" + if cls._map_lora_to_diffusers is None: + raise NotImplementedError(f"{cls.__name__} does not define _map_lora_to_diffusers") + state_dict = cls._normalize_lora_suffixes(state_dict) + return cls._map_lora_to_diffusers(state_dict, **kwargs) def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): @@ -195,124 +257,6 @@ def _remove_text_encoder_monkey_patch(text_encoder): text_encoder._hf_peft_config_loaded = None -def _fetch_state_dict( - pretrained_model_name_or_path_or_dict, - weight_name, - use_safetensors, - local_files_only, - cache_dir, - force_download, - proxies, - token, - revision, - subfolder, - user_agent, - allow_pickle, - metadata=None, -): - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - # Here we're relaxing the loading check to enable more Inference API - 
# friendliness where sometimes, it's not at all possible to automatically - # determine `weight_name`. - if weight_name is None: - weight_name = _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, - file_extension=".safetensors", - local_files_only=local_files_only, - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - metadata = _load_sft_state_dict_metadata(model_file) - - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - metadata = None - pass - - if model_file is None: - if weight_name is None: - weight_name = _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - metadata = None - else: - state_dict = pretrained_model_name_or_path_or_dict - - return state_dict, metadata - - -def _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False -): - if local_files_only or HF_HUB_OFFLINE: - raise ValueError("When using the offline mode, you must specify a `weight_name`.") - - targeted_files = [] - - if os.path.isfile(pretrained_model_name_or_path_or_dict): - return - elif os.path.isdir(pretrained_model_name_or_path_or_dict): - 
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] - else: - files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings - targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] - if len(targeted_files) == 0: - return - - # "scheduler" does not correspond to a LoRA checkpoint. - # "optimizer" does not correspond to a LoRA checkpoint - # only top-level checkpoints are considered and not the other ones, hence "checkpoint". - unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} - targeted_files = list( - filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) - ) - - if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) - elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) - - if len(targeted_files) > 1: - logger.warning( - f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`." 
- ) - weight_name = targeted_files[0] - return weight_name - - def _pack_dict_with_prefix(state_dict, prefix): sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} return sd_with_prefix @@ -387,7 +331,9 @@ def _load_lora_into_text_encoder( network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} # create `LoraConfig` - lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False) + from .lora import _create_lora_config + + lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata) # adapter_name if adapter_name is None: @@ -432,50 +378,6 @@ def _load_lora_into_text_encoder( ) -def _func_optionally_disable_offloading(_pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. - - Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. - - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. - """ - from ..hooks.group_offloading import _is_group_offload_enabled - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - is_group_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if not isinstance(component, nn.Module): - continue - is_group_offload = is_group_offload or _is_group_offload_enabled(component) - if not hasattr(component, "_hf_hook"): - continue - is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) - is_sequential_cpu_offload = is_sequential_cpu_offload or ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - if is_sequential_cpu_offload or is_model_cpu_offload: - logger.info( - "Accelerate hooks detected. 
Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - for _, component in _pipeline.components.items(): - if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"): - continue - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) - - class LoraBaseMixin: """Utility class for handling LoRAs.""" @@ -530,7 +432,7 @@ def unload_lora_weights(self): model = getattr(self, component, None) if model is not None: if issubclass(model.__class__, ModelMixin): - model.unload_lora() + model.delete_adapters() elif issubclass(model.__class__, PreTrainedModel): _remove_text_encoder_monkey_patch(model) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 403e5a87db61..a08a9738a185 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -35,6 +35,7 @@ LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, + _fetch_lora_metadata, _fetch_state_dict, _load_lora_into_text_encoder, _pack_dict_with_prefix, @@ -223,7 +224,6 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -235,7 +235,6 @@ def load_lora_weights( else self.text_encoder, lora_scale=self.lora_scale, adapter_name=adapter_name, - _pipeline=self, metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -312,7 +311,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, 
use_safetensors=use_safetensors, @@ -326,6 +325,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." @@ -402,7 +412,7 @@ def load_lora_into_unet( # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( + unet.load_adapter( state_dict, prefix=cls.unet_name, network_alphas=network_alphas, @@ -650,7 +660,6 @@ def load_lora_weights( unet=self.unet, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -662,7 +671,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -674,7 +682,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -751,7 +758,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, 
weight_name=weight_name, use_safetensors=use_safetensors, @@ -765,6 +772,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." @@ -842,7 +860,7 @@ def load_lora_into_unet( # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. 
logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( + unet.load_adapter( state_dict, prefix=cls.unet_name, network_alphas=network_alphas, @@ -1029,7 +1047,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1043,6 +1061,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -1089,7 +1118,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1101,7 +1129,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1113,7 +1140,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1139,7 +1165,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -1324,7 +1350,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1338,6 +1364,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -1385,7 +1422,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1412,7 +1448,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -1529,7 +1565,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1543,6 +1579,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." @@ -1700,7 +1747,6 @@ def load_lora_weights( transformer=transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1720,7 +1766,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1747,7 +1792,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, @@ -2291,7 +2336,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, @@ -2454,7 +2499,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2468,6 +2513,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -2514,7 +2570,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2541,7 +2596,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -2650,7 +2705,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2664,6 +2719,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -2711,7 +2777,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2738,7 +2803,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -2849,7 +2914,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2863,6 +2928,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -2914,7 +2990,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2941,7 +3016,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3053,7 +3128,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3067,6 +3142,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -3127,7 +3213,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3139,7 +3224,6 @@ def load_lora_weights( else self.connectors, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, prefix=self.connectors_name, @@ -3167,7 +3251,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {prefix}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3280,7 +3364,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3294,6 +3378,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -3341,7 +3436,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3368,7 +3462,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3680,7 +3774,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3694,6 +3788,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -3745,7 +3850,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3772,7 +3876,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3883,7 +3987,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3897,6 +4001,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -3949,7 +4064,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3976,7 +4090,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4088,7 +4202,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4102,6 +4216,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4149,7 +4274,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4176,7 +4300,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4287,7 +4411,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4301,6 +4425,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) elif any(k.startswith("lora_unet_") for k in state_dict): @@ -4411,7 +4546,6 @@ def load_lora_weights( transformer=self.transformer_2, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4423,7 +4557,6 @@ def load_lora_weights( else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4450,7 +4583,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4562,7 +4695,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4576,6 +4709,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) elif any(k.startswith("lora_unet_") for k in state_dict): @@ -4688,7 +4832,6 @@ def load_lora_weights( transformer=self.transformer_2, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4700,7 +4843,6 @@ def load_lora_weights( else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4727,7 +4869,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4839,7 +4981,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4853,6 +4995,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4900,7 +5053,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4927,7 +5079,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5038,7 +5190,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5052,6 +5204,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -5103,7 +5266,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5130,7 +5292,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5241,7 +5403,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5255,6 +5417,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -5309,7 +5482,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5336,7 +5508,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5447,7 +5619,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5461,6 +5633,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -5515,7 +5698,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5542,7 +5724,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5653,7 +5835,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5667,6 +5849,17 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -5729,7 +5922,6 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, - _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5756,7 +5948,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( + transformer.load_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 43fc8d897fe6..5dbd8b1b60b8 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -38,7 +38,6 @@ convert_controlnet_checkpoint, convert_cosmos_transformer_checkpoint_to_diffusers, convert_flux2_transformer_checkpoint_to_diffusers, - convert_flux_transformer_checkpoint_to_diffusers, convert_hidream_transformer_to_diffusers, convert_hunyuan_video_transformer_to_diffusers, convert_ldm_unet_checkpoint, @@ -110,10 +109,6 @@ "SparseControlNetModel": { "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, }, - "FluxTransformer2DModel": { - "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, "ChromaTransformer2DModel": { "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 98b9e8266506..fff6be067f27 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2244,201 +2244,6 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): return converted_state_dict -def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 - num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." 
in k))[-1] + 1 # noqa: C401 - mlp_ratio = 4.0 - inner_dim = 3072 - - # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; - # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation - def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - ## time_text_embed.timestep_embedder <- time_in - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "time_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "time_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") - - ## time_text_embed.text_embedder <- vector_in - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( - "vector_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") - - # guidance - has_guidance = any("guidance" in k for k in checkpoint) - if has_guidance: - converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( - "guidance_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( - "guidance_in.in_layer.bias" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( - 
"guidance_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( - "guidance_in.out_layer.bias" - ) - - # context_embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") - - # x_embedder - converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") - converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - # norms. - ## norm1 - converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.bias" - ) - ## norm1_context - converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.bias" - ) - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = 
torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") - converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") - converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( - 
f"double_blocks.{i}.txt_mlp.0.bias" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.bias" - ) - # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.bias" - ) - - # single transformer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.bias" - ) - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) - q_bias, k_bias, v_bias, mlp_bias = torch.split( - checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) - 
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) - converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - # output projections. - converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") - converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight") - ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias") - ) - - return converted_state_dict def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9dab3bc667ea..e0dc3dcfc2fd 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -35,16 +35,17 @@ from ..utils import ( USE_PEFT_BACKEND, _get_model_file, + convert_sai_sd_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, deprecate, get_adapter_name, - get_peft_kwargs, is_accelerate_available, is_peft_version, is_torch_version, logging, ) from ..utils.torch_utils import empty_device_cache +from .lora import _create_lora_config from .lora_base import _func_optionally_disable_offloading from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils 
# NOTE(review): these two functions are methods of `UNet2DConditionLoadersMixin`. They were reconstructed
# from a whitespace-collapsed patch, so statement nesting was inferred from syntax — confirm before merging.
def _load_adapter_from_pretrained(self, pretrained_model_name_or_path_or_dict, **kwargs):
    """Load a LoRA adapter into the UNet, handling UNet-specific formats before the base loader runs.

    The input is resolved to a state dict up-front so its keys can be inspected:

    - When no key contains ``lora_A``, the dict is passed through ``convert_unet_state_dict_to_peft``
      (old non-PEFT UNet LoRA naming).
    - When the ``lora_controlnet`` marker key is present, the dict is treated as an SAI Control LoRA,
      converted with ``convert_sai_sd_control_lora_state_dict_to_peft`` and handed to
      ``_load_sai_control_lora``, because that format needs LoraConfig overrides applied after creation.
      See https://huggingface.co/stabilityai/control-lora and
      https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors.

    Args:
        pretrained_model_name_or_path_or_dict: Either an already-loaded state dict, or a location that
            ``_fetch_state_dict`` can resolve.
        **kwargs: Forwarded unchanged to whichever loader ends up handling the state dict. The keys named
            in ``_HUB_KWARGS`` plus ``weight_name`` and ``use_safetensors`` are also used for fetching.

    Returns:
        Whatever the selected loader returns.
    """
    from .lora import _HUB_KWARGS, _fetch_state_dict

    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        state_dict = pretrained_model_name_or_path_or_dict
    else:
        # Only the hub-related kwargs (with their defaults) are relevant for fetching the file.
        fetch_kwargs = {k: kwargs.get(k, default) for k, default in _HUB_KWARGS.items()}
        fetch_kwargs["weight_name"] = kwargs.get("weight_name")
        fetch_kwargs["use_safetensors"] = kwargs.get("use_safetensors")
        state_dict = _fetch_state_dict(pretrained_model_name_or_path_or_dict, **fetch_kwargs)

    if not any("lora_A" in k for k in state_dict):
        state_dict = convert_unet_state_dict_to_peft(state_dict)

    if "lora_controlnet" in state_dict:
        state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
        return self._load_sai_control_lora(state_dict, **kwargs)

    # The base loader receives a dict, so it does not fetch again and only runs the remaining steps.
    return super()._load_adapter_from_pretrained(state_dict, **kwargs)


def _load_sai_control_lora(self, state_dict, **kwargs):
    """Inject an SAI Control LoRA state dict into this model.

    Builds a LoraConfig from the state dict and then overrides it for the SAI Control LoRA format:
    ``lora_alpha`` follows ``r``, ``alpha_pattern`` follows ``rank_pattern``, ``bias`` is ``"all"``, and
    the value of ``exclude_modules`` is moved to ``modules_to_save``.

    Args:
        state_dict: Converted (PEFT-shaped) SAI Control LoRA state dict.
        **kwargs: Reads ``adapter_name``, ``network_alphas``, ``prefix`` (default ``"transformer"``),
            ``hotswap``, ``low_cpu_mem_usage`` and ``metadata``; other keys are ignored.
    """
    from .lora import _maybe_warn_for_unhandled_keys, _offloading_disabled

    adapter_name = kwargs.get("adapter_name") or get_adapter_name(self)
    network_alphas = kwargs.get("network_alphas")
    prefix = kwargs.get("prefix", "transformer")
    hotswap = kwargs.get("hotswap", False)
    low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage", False)
    metadata = kwargs.get("metadata")

    if prefix is not None:
        key_prefix = f"{prefix}."
        state_dict = {k.removeprefix(key_prefix): v for k, v in state_dict.items() if k.startswith(key_prefix)}
        # Alphas are filtered under the same guard as the weights. Previously `prefix=None` produced the
        # literal prefix "None." here and silently discarded every alpha entry.
        if network_alphas:
            network_alphas = {
                k.removeprefix(key_prefix): v for k, v in network_alphas.items() if k.startswith(key_prefix)
            }

    # Rank is read from dimension 1 of each `lora_B` tensor; tensors with fewer than 2 dims are skipped.
    rank = {f"^{key}": val.shape[1] for key, val in state_dict.items() if "lora_B" in key and val.ndim > 1}

    lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata)

    # SAI Control LoRA overrides: alpha follows rank; all biases are trained.
    lora_config.lora_alpha = lora_config.r
    lora_config.alpha_pattern = lora_config.rank_pattern
    lora_config.bias = "all"
    lora_config.modules_to_save = lora_config.exclude_modules
    lora_config.exclude_modules = None

    peft_kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage}
    with _offloading_disabled(self):
        if hotswap:
            self._hotswap_adapter(state_dict, lora_config, adapter_name)
            incompatible_keys = None
        else:
            incompatible_keys = self._inject_adapter(state_dict, lora_config, adapter_name, peft_kwargs)
        # NOTE(review): placement after the if/else (rather than inside `else`) is an assumption — confirm.
        self._maybe_apply_deferred_hotswap_prep(lora_config)
    _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
- ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) + lora_config = _create_lora_config(state_dict, network_alphas, rank) # adapter_name if adapter_name is None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..2d685c180644 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -24,7 +24,8 @@ import shutil import tempfile from collections import OrderedDict -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack, contextmanager, nullcontext +from dataclasses import dataclass, field from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -229,6 +230,52 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) +@dataclass +class ModelMetadata: + """ + Metadata describing model capabilities and configuration hints. + + This is NOT configuration (which is saved to config.json and defines architecture). + This is static metadata about the model class's capabilities and hints for + optimization features like gradient checkpointing, offloading, and parallelism. + + Attributes: + supports_gradient_checkpointing: Whether the model supports gradient checkpointing + for memory-efficient training. + no_split_modules: List of module class names that should NOT be split across + devices during model parallelism. + keep_in_fp32_modules: List of module names to keep in FP32 precision when using + lower precision dtypes for numerical stability. + skip_layerwise_casting_patterns: Tuple of module name patterns to exclude from + layerwise casting operations. + supports_group_offloading: Whether the model supports group offloading. + repeated_blocks: List of module class names that repeat throughout the model, + useful for optimization and pattern analysis. 
+ cp_plan: Context parallel configuration plan defining how to split model + components for context parallelism across devices. + keys_to_ignore_on_load_unexpected: List of keys to ignore when loading + unexpected keys from a checkpoint. + """ + + supports_gradient_checkpointing: bool = False + no_split_modules: Optional[List[str]] = None + keep_in_fp32_modules: Optional[List[str]] = None + skip_layerwise_casting_patterns: Optional[Tuple[str, ...]] = None + supports_group_offloading: bool = True + repeated_blocks: List[str] = field(default_factory=list) + cp_plan: Optional[Dict[str, Any]] = None + keys_to_ignore_on_load_unexpected: Optional[List[str]] = None + + +def _should_convert_checkpoint(model_state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> bool: + """Check if checkpoint needs conversion by comparing keys with model state dict.""" + model_state_dict_keys = set(model_state_dict.keys()) + checkpoint_state_dict_keys = set(checkpoint.keys()) + is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys) + is_match = model_state_dict_keys == checkpoint_state_dict_keys + return not (is_subset and is_match) + + class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models. @@ -251,6 +298,63 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _parallel_config = None _cp_plan = None _skip_keys = None + _model_metadata: Optional["ModelMetadata"] = None + + @classmethod + def get_metadata(cls) -> "ModelMetadata": + """ + Get the model's metadata for discovery and introspection. + + Returns a ModelMetadata instance describing the model's capabilities. + If `_model_metadata` is defined on the class, returns that directly. + Otherwise, constructs a ModelMetadata from the individual class attributes. 
+ """ + if cls._model_metadata is not None: + return cls._model_metadata + return ModelMetadata( + supports_gradient_checkpointing=cls._supports_gradient_checkpointing, + no_split_modules=cls._no_split_modules, + keep_in_fp32_modules=cls._keep_in_fp32_modules, + skip_layerwise_casting_patterns=cls._skip_layerwise_casting_patterns, + supports_group_offloading=cls._supports_group_offloading, + repeated_blocks=cls._repeated_blocks if cls._repeated_blocks else [], + cp_plan=cls._cp_plan, + keys_to_ignore_on_load_unexpected=cls._keys_to_ignore_on_load_unexpected, + ) + + @classmethod + def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert state dict from original format to diffusers format if needed. + + This method checks if the state dict keys match the model's expected keys. + If not, it applies normalization and conversion using the model's + `_normalize_checkpoint_keys` and `map_to_diffusers` methods. + + Args: + model: The model instance to compare against. + state_dict: The loaded state dict. + + Returns: + The state dict, potentially converted to diffusers format. + """ + model_state_dict = model.state_dict() + + if not _should_convert_checkpoint(model_state_dict, state_dict): + return state_dict + + normalize_fn = getattr(cls, "_normalize_checkpoint_keys", None) + if normalize_fn is not None: + state_dict = normalize_fn(state_dict) + + if not _should_convert_checkpoint(model_state_dict, state_dict): + return state_dict + + map_to_diffusers_fn = getattr(cls, "map_to_diffusers", None) + if map_to_diffusers_fn is None: + return state_dict + + return map_to_diffusers_fn(state_dict) def __init__(self): super().__init__() @@ -1362,6 +1466,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. 
model._fix_state_dict_keys_on_load(state_dict) + # Convert checkpoint if needed (e.g., original format to diffusers format) + state_dict = cls._maybe_convert_state_dict(model, state_dict) + if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] else: @@ -1450,6 +1557,246 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None return model + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self: + r""" + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. + The model is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path_or_dict (`str`, *optional*): + Can be either: + - A link to the `.safetensors` or `.ckpt` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A path to a local *file* containing the weights of the component model. + - A state dict containing the component model weights. + config (`str`, *optional*): + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline + component configs in Diffusers format. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, + overriding the cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the + standard cache is not used. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained weights and not initializing the weights. + disable_mmap (`bool`, *optional*, defaults to `False`): + Whether to disable mmap when loading a Safetensors model. + + Returns: + The instantiated model. + + Example: + ```python + >>> from diffusers import FluxTransformer2DModel + + >>> ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + >>> model = FluxTransformer2DModel.from_single_file(ckpt_path) + ``` + """ + from ..loaders.single_file_utils import ( + SingleFileComponentError, + load_single_file_checkpoint, + ) + + map_to_diffusers_fn = getattr(cls, "map_to_diffusers", None) + default_subfolder = getattr(cls, "_default_subfolder", None) + + if map_to_diffusers_fn is None: + raise ValueError( + f"{cls.__name__} does not support `from_single_file`. " + f"Please ensure the model class defines `map_to_diffusers`." 
+ ) + + pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) + if pretrained_model_link_or_path is not None: + deprecation_message = ( + "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes" + ) + deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) + pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path + + config = kwargs.pop("config", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + subfolder = kwargs.pop("subfolder", None) + revision = kwargs.pop("revision", None) + config_revision = kwargs.pop("config_revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + quantization_config = kwargs.pop("quantization_config", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + device = kwargs.pop("device", None) + disable_mmap = kwargs.pop("disable_mmap", False) + device_map = kwargs.pop("device_map", None) + + user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"} + if quantization_config is not None: + user_agent["quant"] = quantization_config.quant_method.value + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + torch_dtype = torch.float32 + logger.warning( + f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`." 
+ ) + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + checkpoint = pretrained_model_link_or_path_or_dict + else: + checkpoint = load_single_file_checkpoint( + pretrained_model_link_or_path_or_dict, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + disable_mmap=disable_mmap, + user_agent=user_agent, + ) + + # Normalize checkpoint keys (strip known prefixes) if the model defines a normalizer + normalize_fn = getattr(cls, "_normalize_checkpoint_keys", None) + if normalize_fn is not None: + checkpoint = normalize_fn(checkpoint) + + if quantization_config is not None: + hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) + hf_quantizer.validate_environment() + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + else: + hf_quantizer = None + + if config is not None: + if isinstance(config, str): + default_pretrained_model_config_name = config + else: + raise ValueError( + "Invalid `config` argument. Please provide a string representing a repo id " + "or path to a local Diffusers model repo." + ) + else: + get_model_config_fn = getattr(cls, "_get_model_config", None) + if get_model_config_fn is None: + raise ValueError( + f"{cls.__name__} does not support automatic config detection. " + f"Please provide a `config` argument or define `_get_model_config` on the model class." 
+ ) + default_pretrained_model_config_name = get_model_config_fn(checkpoint) + + if default_subfolder is not None: + subfolder = default_subfolder + + diffusers_model_config = cls.load_config( + pretrained_model_name_or_path=default_pretrained_model_config_name, + subfolder=subfolder, + local_files_only=local_files_only, + token=token, + revision=config_revision, + ) + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} + diffusers_model_config.update(model_kwargs) + + if is_accelerate_available(): + from accelerate import init_empty_weights + + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + else: + ctx = nullcontext + + with ctx(): + model = cls.from_config(diffusers_model_config) + + model_state_dict = model.state_dict() + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + else: + keep_in_fp32_modules = [] + + if _should_convert_checkpoint(model_state_dict, checkpoint): + checkpoint = map_to_diffusers_fn(checkpoint) + + if not checkpoint: + raise SingleFileComponentError( + f"Failed to load {cls.__name__}. Weights for this component appear to be missing in the checkpoint." 
+ ) + + loaded_keys = list(checkpoint.keys()) + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer) + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + checkpoint, + None, + None, + loaded_keys, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if device_map is not None: + from accelerate import dispatch_model + + device_map_kwargs = { + "device_map": device_map, + "offload_index": offload_index, + } + dispatch_model(model, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and hf_quantizer is None: + model.to(torch_dtype) + + model.eval() + + return model + # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 156b54e7f07d..1f45accb3d5e 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -1,13 +1,21 @@ +import sys + from ...utils import is_torch_available if is_torch_available(): + from . 
import flux + + # Register backwards compatibility alias so `from .transformer_flux import X` works + sys.modules[__name__ + ".transformer_flux"] = flux + from .ace_step_transformer import AceStepTransformer1DModel from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .cogvideox_transformer_3d import CogVideoXTransformer3DModel from .consisid_transformer_3d import ConsisIDTransformer3DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel + from .flux import FluxTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel from .latte_transformer_3d import LatteTransformer3DModel from .lumina_nextdit2d import LuminaNextDiT2DModel @@ -27,7 +35,6 @@ from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_ernie_image import ErnieImageTransformer2DModel - from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_helios import HeliosTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py deleted file mode 100644 index 78a77ebcfea9..000000000000 --- a/src/diffusers/models/transformers/transformer_flux.py +++ /dev/null @@ -1,778 +0,0 @@ -# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ..attention_dispatch import dispatch_attention_fn -from ..cache_utils import CacheMixin -from ..embeddings import ( - CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, - apply_rotary_emb, - get_1d_rotary_pos_embed, -) -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - - return query, key, value, encoder_query, encoder_key, encoder_value - - -def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - - encoder_query = encoder_key = encoder_value = (None,) - if 
encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - - return query, key, value, encoder_query, encoder_key, encoder_value - - -def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - if attn.fused_projections: - return _get_fused_projections(attn, hidden_states, encoder_hidden_states) - return _get_projections(attn, hidden_states, encoder_hidden_states) - - -class FluxAttnProcessor: - _attention_backend = None - _parallel_config = None - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - - def __call__( - self, - attn: "FluxAttention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: torch.Tensor | None = None, - image_rotary_emb: torch.Tensor | None = None, - ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states - ) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - query = attn.norm_q(query) - key = attn.norm_k(key) - - if attn.added_kv_proj_dim is not None: - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - - encoder_query = attn.norm_added_q(encoder_query) - encoder_key = attn.norm_added_k(encoder_key) - - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - key = apply_rotary_emb(key, image_rotary_emb, 
sequence_dim=1) - - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - hidden_states = attn.to_out[0](hidden_states.contiguous()) - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous()) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxIPAdapterAttnProcessor(torch.nn.Module): - """Flux Attention processor for IP-Adapter.""" - - _attention_backend = None - _parallel_config = None - - def __init__( - self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - self.to_v_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - - def __call__( - self, - attn: "FluxAttention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: torch.Tensor | None = None, - image_rotary_emb: torch.Tensor | None = None, - ip_hidden_states: list[torch.Tensor] | None = None, - ip_adapter_masks: torch.Tensor | None = None, - ) -> torch.Tensor: - batch_size = hidden_states.shape[0] - - query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( - attn, hidden_states, encoder_hidden_states - ) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - query = attn.norm_q(query) - key = attn.norm_k(key) - ip_query = query - - if encoder_hidden_states is not None: - encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - - encoder_query = attn.norm_added_q(encoder_query) - encoder_key = attn.norm_added_k(encoder_key) - - query = torch.cat([encoder_query, query], dim=1) - key = torch.cat([encoder_key, key], dim=1) - value = torch.cat([encoder_value, value], dim=1) - - if image_rotary_emb is not None: - query = 
apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IP-adapter - ip_attn_output = torch.zeros_like(hidden_states) - - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) - ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) - - current_ip_hidden_states = dispatch_attention_fn( - ip_query, - ip_key, - ip_value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states - - return hidden_states, encoder_hidden_states, ip_attn_output - else: - return hidden_states - - -class FluxAttention(torch.nn.Module, AttentionModuleMixin): - _default_processor_cls = FluxAttnProcessor - _available_processors = [ - FluxAttnProcessor, - FluxIPAdapterAttnProcessor, - ] - - def __init__( - 
self, - query_dim: int, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: int | None = None, - added_proj_bias: bool | None = True, - out_bias: bool = True, - eps: float = 1e-5, - out_dim: int = None, - context_pre_only: bool | None = None, - pre_only: bool = False, - elementwise_affine: bool = True, - processor=None, - ): - super().__init__() - - self.head_dim = dim_head - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.query_dim = query_dim - self.use_bias = bias - self.dropout = dropout - self.out_dim = out_dim if out_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.heads = out_dim // dim_head if out_dim is not None else heads - self.added_kv_proj_dim = added_kv_proj_dim - self.added_proj_bias = added_proj_bias - - self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.pre_only: - self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(torch.nn.Dropout(dropout)) - - if added_kv_proj_dim is not None: - self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) - self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) - self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) - - if processor is None: - 
processor = self._default_processor_cls() - self.set_processor(processor) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - image_rotary_emb: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} - unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] - if len(unused_kwargs) > 0: - logger.warning( - f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) - - -@maybe_allow_in_graph -class FluxSingleTransformerBlock(nn.Module): - def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): - super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) - - self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - - self.attn = FluxAttention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=FluxAttnProcessor(), - eps=1e-6, - pre_only=True, - ) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - joint_attention_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - text_seq_len = encoder_hidden_states.shape[1] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - residual 
= hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - joint_attention_kwargs = joint_attention_kwargs or {} - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - - hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - gate = gate.unsqueeze(1) - hidden_states = gate * self.proj_out(hidden_states) - hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - - encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] - return encoder_hidden_states, hidden_states - - -@maybe_allow_in_graph -class FluxTransformerBlock(nn.Module): - def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 - ): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - self.norm1_context = AdaLayerNormZero(dim) - - self.attn = FluxAttention( - query_dim=dim, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, - processor=FluxAttnProcessor(), - eps=eps, - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - joint_attention_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = 
self.norm1(hidden_states, emb=temb) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - joint_attention_kwargs = joint_attention_kwargs or {} - - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - if len(attention_outputs) == 3: - hidden_states = hidden_states + ip_attn_output - - # Process attention outputs for the `encoder_hidden_states`. 
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - -class FluxPosEmbed(nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: list[int]): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - cos_out = [] - sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], - pos[:, i], - theta=self.theta, - repeat_interleave_real=True, - use_real=True, - freqs_dtype=freqs_dtype, - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin - - -class FluxTransformer2DModel( - ModelMixin, - ConfigMixin, - PeftAdapterMixin, - FromOriginalModelMixin, - FluxTransformer2DLoadersMixin, - CacheMixin, - AttentionMixin, -): - """ - The Transformer model introduced in Flux. 
- - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - patch_size (`int`, defaults to `1`): - Patch size to turn the input data into small patches. - in_channels (`int`, defaults to `64`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `None`): - The number of channels in the output. If not specified, it defaults to `in_channels`. - num_layers (`int`, defaults to `19`): - The number of layers of dual stream DiT blocks to use. - num_single_layers (`int`, defaults to `38`): - The number of layers of single stream DiT blocks to use. - attention_head_dim (`int`, defaults to `128`): - The number of dimensions to use for each attention head. - num_attention_heads (`int`, defaults to `24`): - The number of attention heads to use. - joint_attention_dim (`int`, defaults to `4096`): - The number of dimensions to use for the joint attention (embedding/channel dimension of - `encoder_hidden_states`). - pooled_projection_dim (`int`, defaults to `768`): - The number of dimensions to use for the pooled projection. - guidance_embeds (`bool`, defaults to `False`): - Whether to use guidance embeddings for guidance-distilled variant of the model. - axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`): - The dimensions to use for the rotary positional embeddings. 
- """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _skip_layerwise_casting_patterns = ["pos_embed", "norm"] - _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _cp_plan = { - "": { - "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - }, - "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), - } - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - out_channels: int | None = None, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: tuple[int, int, int] = (16, 56, 56), - ): - super().__init__() - self.out_channels = out_channels or in_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = nn.Linear(in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for _ in range(num_layers) - ] - ) - - self.single_transformer_blocks = 
nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for _ in range(num_single_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False - - @apply_lora_scale("joint_attention_kwargs") - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: dict[str, Any] | None = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, - return_dict: bool = True, - controlnet_blocks_repeat: bool = False, - ) -> torch.Tensor | Transformer2DModelOutput: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): - Input `hidden_states`. - encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. 
- joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - - hidden_states = self.x_embedder(hidden_states) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: - ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") - ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) - joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) - - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - joint_attention_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) - - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. 
- if controlnet_blocks_repeat: - hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] - ) - else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - joint_attention_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] - - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 008426f5275e..ffb94c411b3e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -138,7 +138,6 @@ check_peft_version, delete_adapter_layers, get_adapter_name, - get_peft_kwargs, recurse_remove_peft_layers, scale_lora_layers, set_adapter_layers, diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 65bcfe631e97..058fa75c7d07 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -22,7 +22,7 @@ from packaging import version from . 
import logging -from .import_utils import is_peft_available, is_peft_version, is_torch_available +from .import_utils import is_peft_available, is_torch_available from .torch_utils import empty_device_cache @@ -150,56 +150,6 @@ def unscale_lora_layers(model, weight: float | None = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs( - rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None -): - rank_pattern = {} - alpha_pattern = {} - r = lora_alpha = list(rank_dict.values())[0] - - if len(set(rank_dict.values())) > 1: - # get the rank occurring the most number of times - r = collections.Counter(rank_dict.values()).most_common()[0][0] - - # for modules with rank different from the most occurring rank, add it to the `rank_pattern` - rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) - rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} - - if network_alpha_dict is not None and len(network_alpha_dict) > 0: - if len(set(network_alpha_dict.values())) > 1: - # get the alpha occurring the most number of times - lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] - - # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` - alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) - if is_unet: - alpha_pattern = { - ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v - for k, v in alpha_pattern.items() - } - else: - alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} - else: - lora_alpha = set(network_alpha_dict.values()).pop() - - target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) - use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) - # for now we know that the "bias" keys are only associated with `lora_B`. 
- lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) - - lora_config_kwargs = { - "r": r, - "lora_alpha": lora_alpha, - "rank_pattern": rank_pattern, - "alpha_pattern": alpha_pattern, - "target_modules": target_modules, - "use_dora": use_dora, - "lora_bias": lora_bias, - } - - return lora_config_kwargs - - def get_adapter_name(model): from peft.tuners.tuners_utils import BaseTunerLayer @@ -344,82 +294,3 @@ def check_peft_version(min_version: str) -> None: ) -def _create_lora_config( - state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None -): - from peft import LoraConfig - - if metadata is not None: - lora_config_kwargs = metadata - else: - lora_config_kwargs = get_peft_kwargs( - rank_pattern_dict, - network_alpha_dict=network_alphas, - peft_state_dict=state_dict, - is_unet=is_unet, - model_state_dict=model_state_dict, - adapter_name=adapter_name, - ) - - _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) - - # Version checks for DoRA and lora_bias - if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.") - - if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.") - - try: - return LoraConfig(**lora_config_kwargs) - except TypeError as e: - raise TypeError("`LoraConfig` class could not be instantiated.") from e - - -def _maybe_raise_error_for_ambiguous_keys(config): - rank_pattern = config["rank_pattern"].copy() - target_modules = config["target_modules"] - - for key in list(rank_pattern.keys()): - # try to detect ambiguity - # `target_modules` can also be a str, in which case this loop would loop - # over the chars of the str. 
The technically correct way to match LoRA keys - # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). - # But this cuts it for now. - exact_matches = [mod for mod in target_modules if mod == key] - substring_matches = [mod for mod in target_modules if key in mod and mod != key] - - if exact_matches and substring_matches: - if is_peft_version("<", "0.14.1"): - raise ValueError( - "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." - ) - - -def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." 
- ) - - if warn_msg: - logger.warning(warn_msg) From 96c078e48f28906e8d61deb6b8ecb85678813a51 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 8 May 2026 20:21:01 +0530 Subject: [PATCH 02/21] update --- src/diffusers/loaders/lora.py | 877 ++++++++++++++++++ src/diffusers/loaders/weight_mapping.py | 133 +++ .../models/transformers/flux/__init__.py | 25 + .../models/transformers/flux/lora.py | 476 ++++++++++ .../models/transformers/flux/model.py | 803 ++++++++++++++++ .../transformers/flux/weight_mapping.py | 315 +++++++ 6 files changed, 2629 insertions(+) create mode 100644 src/diffusers/loaders/lora.py create mode 100644 src/diffusers/loaders/weight_mapping.py create mode 100644 src/diffusers/models/transformers/flux/__init__.py create mode 100644 src/diffusers/models/transformers/flux/lora.py create mode 100644 src/diffusers/models/transformers/flux/model.py create mode 100644 src/diffusers/models/transformers/flux/weight_mapping.py diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py new file mode 100644 index 000000000000..fc5ec5a250e4 --- /dev/null +++ b/src/diffusers/loaders/lora.py @@ -0,0 +1,877 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import collections +import functools +import json +import os +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union + +import safetensors +import torch +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE + +from ..hooks.group_offloading import ( + _maybe_remove_and_reapply_group_offloading, +) +from ..models.modeling_utils import load_state_dict +from ..utils import ( + USE_PEFT_BACKEND, + _get_model_file, + delete_adapter_layers, + deprecate, + get_adapter_name, + is_accelerate_available, + is_peft_available, + is_peft_version, + logging, + recurse_remove_peft_layers, + set_weights_and_activate_adapters, +) +from ..utils.state_dict_utils import _load_sft_state_dict_metadata +from .unet_loader_utils import _maybe_expand_lora_scales + + +if is_accelerate_available(): + pass + + +if is_peft_available(): + from peft import LoraConfig, PeftConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft.tuners.tuners_utils import BaseTunerLayer + from peft.utils import get_peft_model_state_dict + from peft.utils.hotswap import ( + check_hotswap_configs_compatible, + hotswap_adapter_from_state_dict, + prepare_model_for_compiled_hotswap, + ) + + +logger = logging.get_logger(__name__) + + +# Minimum PEFT version this mixin relies on. Bumping this lets us delete the +# version-fallback branches scattered through the methods (DoRA, lora_bias, +# hotswap, set_adapter hasattr, etc.). +_MIN_PEFT_VERSION_FOR_LORA = "0.14.1" +_HAS_REQUIRED_PEFT = USE_PEFT_BACKEND and is_peft_version(">=", _MIN_PEFT_VERSION_FOR_LORA) + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" + + +# Hub-download kwargs forwarded to `_get_model_file`. 
+_HUB_KWARGS = { + "cache_dir": None, + "force_download": False, + "proxies": None, + "local_files_only": None, + "token": None, + "revision": None, + "subfolder": None, +} + + +# Per-class hook for expanding adapter weights before activation. Models that need +# expansion (currently only UNet variants) register here; everything else falls +# through to the identity default so new transformers don't need an entry. +_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( + lambda: (lambda model_cls, weights: weights), + { + "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, + }, +) + + +def _requires_peft(method): + """Guard a method with a uniform PEFT availability + minimum-version check.""" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if not _HAS_REQUIRED_PEFT: + raise ValueError( + f"`{method.__name__}()` requires PEFT >= {_MIN_PEFT_VERSION_FOR_LORA}. " + "Please install or upgrade PEFT: `pip install -U peft`." + ) + return method(self, *args, **kwargs) + + return wrapper + + +def _fuse_lora_apply(module, lora_scale=1.0, safe_fusing=False, adapter_names=None): + """Per-module callback for ``self.apply(...)`` in ``fuse_lora``.""" + if not isinstance(module, BaseTunerLayer): + return + if lora_scale != 1.0: + module.scale_layer(lora_scale) + module.merge(safe_merge=safe_fusing, adapter_names=adapter_names) + + +def _unfuse_lora_apply(module): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + +def _serialize_lora_adapter_metadata(peft_config): + """Convert a ``PeftConfig`` to a JSON string suitable for the safetensors metadata blob. + + PEFT configs may contain ``set`` values (which JSON can't serialize); coerce those + to lists first. 
+ """ + cfg = peft_config.to_dict() + for key, value in cfg.items(): + if isinstance(value, set): + cfg[key] = list(value) + return json.dumps(cfg, indent=2, sort_keys=True) + + +def _scope_state_dict_to_adapter(state_dict, adapter_name): + """Rewrite ``lora_A.weight`` / ``lora_B.weight`` keys to include the adapter name + (the format expected by ``hotswap_adapter_from_state_dict``).""" + out = {} + for k, v in state_dict.items(): + if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"): + k = k[: -len(".weight")] + f".{adapter_name}.weight" + elif k.endswith("lora_B.bias"): # lora_bias=True option + k = k[: -len(".bias")] + f".{adapter_name}.bias" + out[k] = v + return out + + +def _split_majority_and_outliers(value_dict): + """Return ``(majority, outliers)`` for ``value_dict``. + + ``majority`` is the most common value (or the lone value if all are equal, or + None for an empty dict). ``outliers`` is a sub-dict of the items whose value + differs from the majority — empty when every value matches. + """ + values = list(value_dict.values()) + if not values: + return None, {} + if len(set(values)) == 1: + return values[0], {} + majority = collections.Counter(values).most_common(1)[0][0] + return majority, {k: v for k, v in value_dict.items() if v != majority} + + +def _create_lora_config(state_dict, network_alphas, rank_dict, metadata=None): + """Build a PEFT ``LoraConfig`` from a LoRA state dict. + + ``metadata`` (when present) overrides the inferred kwargs entirely — used when a + saved adapter shipped its own serialized ``LoraConfig`` blob. Otherwise we infer: + per-module rank / alpha values that don't match the majority go into + ``rank_pattern`` / ``alpha_pattern``; the majority becomes the global default. 
+ """ + if metadata is not None: + return LoraConfig(**metadata) + + r, rank_outliers = _split_majority_and_outliers(rank_dict) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_outliers.items()} + + lora_alpha = r + alpha_pattern = {} + if network_alphas: + lora_alpha, alpha_outliers = _split_majority_and_outliers(network_alphas) + if alpha_outliers: + # PEFT-converted alpha keys (UNet / transformer LoRAs) carry ``.lora_A.``; + # raw kohya-style alphas (legacy text-encoder LoRAs) carry ``.down.``. + sample = next(iter(alpha_outliers)) + if ".lora_A." in sample: + alpha_pattern = {k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in alpha_outliers.items()} + else: + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_outliers.items()} + + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": list({name.split(".lora")[0] for name in state_dict}), + "use_dora": any("lora_magnitude_vector" in k for k in state_dict), + "lora_bias": any("lora_B" in k and k.endswith(".bias") for k in state_dict), + } + + return LoraConfig(**lora_config_kwargs) + + +def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): + if incompatible_keys is None: + return + warn_msg = "" + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model: " + f"{', '.join(lora_unexpected_keys)}. 
" + ) + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model: " + f"{', '.join(lora_missing_keys)}." + ) + if warn_msg: + logger.warning(warn_msg) + + +def _fetch_state_dict( + pretrained_model_name_or_path_or_dict, + weight_name=None, + use_safetensors=True, + allow_pickle=False, + **hub_kwargs, +): + """Load a LoRA state dict from a path/repo/dict. + + ``hub_kwargs`` are the download / file-discovery options forwarded to + ``_get_model_file`` (see ``_HUB_KWARGS`` for the canonical set). Sidecar + :func:`_fetch_lora_metadata`. + """ + if isinstance(pretrained_model_name_or_path_or_dict, dict): + return pretrained_model_name_or_path_or_dict + + source = pretrained_model_name_or_path_or_dict + local_files_only = hub_kwargs.get("local_files_only") + + # Try safetensors first when the user asked for it (or named a .safetensors file). + # Fall through to .bin if the safetensors lookup fails and pickle is allowed. 
+ prefer_safetensors = (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ) + if prefer_safetensors: + try: + name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) + model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) + return load_state_dict(model_file) + except (IOError, safetensors.SafetensorError): + if not allow_pickle: + raise + + name = weight_name or _best_guess_weight_name(source, ".bin", local_files_only) + model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME, **hub_kwargs) + return load_state_dict(model_file) + + +def _fetch_lora_metadata(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): + """Load LoRA adapter metadata from a safetensors file's sidecar. + + Returns ``None`` for non-safetensors sources (dicts, ``.bin`` files, missing + sidecar). The hub layer caches the file, so calling this after :func:`_fetch_state_dict` should not trigger a second download. + """ + if isinstance(pretrained_model_name_or_path_or_dict, dict): + return None + + source = pretrained_model_name_or_path_or_dict + local_files_only = hub_kwargs.get("local_files_only") + name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) + if not name or not name.endswith(".safetensors"): + return None + try: + model_file = _get_model_file(source, weights_name=name, **hub_kwargs) + return _load_sft_state_dict_metadata(model_file) + except (IOError, safetensors.SafetensorError): + return None + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return None + if os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in 
os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + + # Strip non-LoRA files: scheduler/optimizer state, intermediate checkpoints. + unallowed = {"scheduler", "optimizer", "checkpoint"} + targeted_files = [f for f in targeted_files if not any(s in f for s in unallowed)] + + # Prefer the canonical filenames if present. + for canonical in (LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE): + if any(f.endswith(canonical) for f in targeted_files): + targeted_files = [f for f in targeted_files if f.endswith(canonical)] + break + + if not targeted_files: + return None + if len(targeted_files) > 1: + logger.warning( + f"Provided path contains more than one weights file in the {file_extension} format. " + f"`{targeted_files[0]}` is going to be loaded; for precise control, specify a `weight_name` " + "in `load_lora_weights`." + ) + return targeted_files[0] + + +class PeftAdapterMixin: + """ + A class containing all functions for loading and using adapters weights that are supported in PEFT library. For + more details about adapters and injecting them in a base model, check out the PEFT + [documentation](https://huggingface.co/docs/peft/index). + + Install the latest version of PEFT, and use this mixin to: + + - Attach new adapters in the model. + - Attach multiple adapters and iteratively activate/deactivate them. + - Activate/deactivate all adapters from the model. + - Get a list of the active adapters. + """ + + _hf_peft_config_loaded = False + # kwargs for prepare_model_for_compiled_hotswap, if required + _lora_hotswap_kwargs: Optional[dict] = None + + @_requires_peft + def load_adapter( + self, + adapter, + adapter_name=None, + prefix="transformer", + hotswap: bool = False, + **kwargs, + ): + r""" + Add an adapter to the underlying model. 
+ + ``adapter`` can be either: + + - A ``PeftConfig`` (e.g. ``LoraConfig``) — initializes a fresh adapter with + random weights, suitable for training. + - A repo id, local path, or pre-loaded ``state_dict`` — loads pretrained + adapter weights, suitable for inference. + + For the config path, only ``adapter_name`` is used; ``prefix``, ``hotswap``, + and the download/loading kwargs apply to the pretrained path. + """ + adapter_name = adapter_name or get_adapter_name(self) + if isinstance(adapter, PeftConfig): + return self._load_adapter_from_config(adapter, adapter_name=adapter_name) + + return self._load_adapter_from_pretrained( + adapter, adapter_name=adapter_name, prefix=prefix, hotswap=hotswap, **kwargs + ) + + def _load_adapter_from_config(self, adapter_config, adapter_name="default"): + if self._hf_peft_config_loaded and adapter_name in getattr(self, "peft_config", {}): + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is + # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here. + adapter_config.base_model_name_or_path = None + inject_adapter_in_model(adapter_config, self, adapter_name) + self._hf_peft_config_loaded = True + self.set_adapters(adapter_name) + + def _load_adapter_from_pretrained( + self, + pretrained_model_name_or_path_or_dict, + adapter_name=None, + prefix="transformer", + hotswap: bool = False, + **kwargs, + ): + r""" + Loads a LoRA adapter into the underlying model. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. 
+ - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + prefix (`str`, *optional*): Prefix to filter the state dict. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). 
+ low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap + metadata: + LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to + initialize `LoraConfig`. 
+ """ + hub_kwargs = {k: kwargs.pop(k, default) for k, default in _HUB_KWARGS.items()} + hub_kwargs["user_agent"] = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + network_alphas = kwargs.pop("network_alphas", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + metadata = kwargs.pop("metadata", None) + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + allow_pickle=False, + **hub_kwargs, + ) + if not state_dict: + model_class_name = self.__class__.__name__ + logger.warning( + f"No LoRA keys associated to {model_class_name} found with the {prefix=}. " + "This is safe to ignore if LoRA state dict didn't originally have any " + f"{model_class_name} related params. You can also try specifying `prefix=None` " + "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " + "https://github.com/huggingface/diffusers/issues/new" + ) + return + + metadata = metadata or _fetch_lora_metadata( + pretrained_model_name_or_path_or_dict, weight_name=weight_name, **hub_kwargs + ) + + if network_alphas is not None and prefix is None: + raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + + if network_alphas and metadata: + raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") + + if prefix is not None: + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} + + if adapter_name in getattr(self, "peft_config", {}) and not hotswap: + raise ValueError( + f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." 
+ ) + if adapter_name not in getattr(self, "peft_config", {}) and hotswap: + raise ValueError( + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " + "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." + ) + + rank = {} + for key, val in state_dict.items(): + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: + # See https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # Disambiguates module names sharing a common prefix + # (e.g. `proj_out.weight` vs `blocks.transformer.proj_out.weight`). + rank[f"^{key}"] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] + network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata) + + # Mutating the model would otherwise fight with active offload hooks; the + # context manager strips them for the duration and restores them on exit. + peft_kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage} + with _offloading_disabled(self): + if hotswap: + self._hotswap_adapter(state_dict, lora_config, adapter_name) + incompatible_keys = None + + else: + incompatible_keys = self._inject_adapter(state_dict, lora_config, adapter_name, peft_kwargs) + self._maybe_apply_deferred_hotswap_prep(lora_config) + + _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name) + + def _inject_adapter(self, state_dict, lora_config, adapter_name, peft_kwargs): + """Inject a new adapter into ``self`` and load its weights. + + Returns the ``incompatible_keys`` reported by ``set_peft_model_state_dict``. 
+ On failure, rolls back any partial peft_config / adapter modules so the model + is left in its prior state. + """ + try: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + self._hf_peft_config_loaded = True + + return incompatible_keys + + except Exception as e: + self._rollback_adapter(adapter_name, e) + raise + + def _maybe_apply_deferred_hotswap_prep(self, lora_config): + """If ``enable_lora_hotswap`` was called before the first adapter was loaded, + we deferred ``prepare_model_for_compiled_hotswap`` until LoRA layers existed. + Apply it now (after a successful inject) and clear the stash so it only fires once.""" + if self._lora_hotswap_kwargs is None: + return + prepare_model_for_compiled_hotswap(self, config=lora_config, **self._lora_hotswap_kwargs) + self._lora_hotswap_kwargs = None + + def _hotswap_adapter(self, state_dict, lora_config, adapter_name): + """Replace the weights of an already-loaded adapter in-place. + + ``hotswap_adapter_from_state_dict`` raises on incompatible keys; reaching the + end of this function means the swap succeeded. 
+ """ + state_dict = _scope_state_dict_to_adapter(state_dict, adapter_name) + check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + try: + hotswap_adapter_from_state_dict( + model=self, state_dict=state_dict, adapter_name=adapter_name, config=lora_config + ) + except Exception as e: + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}") + self._rollback_adapter(adapter_name, e) + raise + + def _rollback_adapter(self, adapter_name, error): + """Remove ``adapter_name`` from ``self`` so failed loads don't leave partial state.""" + if hasattr(self, "peft_config"): + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + for active_adapter in module.active_adapters: + if adapter_name in active_adapter: + module.delete_adapter(adapter_name) + self.peft_config.pop(adapter_name, None) + logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{error}") + + @_requires_peft + def save_adapter( + self, + save_directory, + adapter_name: str = "default", + upcast_before_saving: bool = False, + safe_serialization: bool = True, + weight_name: Optional[str] = None, + ): + """Save the LoRA parameters corresponding to the underlying model. + + Args: + save_directory: Directory to save LoRA parameters to. Created if missing. + adapter_name: Name of the adapter to serialize. Useful when the model has + multiple adapters loaded. + upcast_before_saving: Whether to cast the underlying model to ``torch.float32`` + before serialization. + safe_serialization: Save with ``safetensors`` (default) or pickled torch save. + weight_name: Override the default filename. 
+ """ + if adapter_name not in getattr(self, "peft_config", {}): + raise ValueError(f"Adapter name {adapter_name} not found in the model.") + if os.path.isfile(save_directory): + raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") + + state_dict = get_peft_model_state_dict( + self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name + ) + + os.makedirs(save_directory, exist_ok=True) + weight_name = weight_name or (LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME) + save_path = Path(save_directory, weight_name).as_posix() + + if safe_serialization: + metadata = { + "format": "pt", + LORA_ADAPTER_METADATA_KEY: _serialize_lora_adapter_metadata(self.peft_config[adapter_name]), + } + safetensors.torch.save_file(state_dict, save_path, metadata=metadata) + else: + torch.save(state_dict, save_path) + + logger.info(f"Model weights saved in {save_path}") + + def save_lora_adapter(self, *args, **kwargs): + """Deprecated alias for :meth:`save_adapter`.""" + deprecate( + "save_lora_adapter", + "1.0.0", + "`save_lora_adapter` is deprecated; use `save_adapter` instead.", + ) + return self.save_adapter(*args, **kwargs) + + @_requires_peft + def set_adapters( + self, + adapter_names: Union[List[str], str], + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, + ): + """ + Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.). + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + adapter_weights (`Union[List[float], float]`, *optional*): + The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the + adapters. 
+ + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + ``` + """ + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # Expand weights into a list, one entry per adapter + # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." + ) + + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__] + weights = scale_expansion_fn(self, weights) + + set_weights_and_activate_adapters(self, adapter_names, weights) + + def add_adapter(self, adapter_config, adapter_name: str = "default") -> None: + """Deprecated alias for :meth:`load_adapter` with a ``PeftConfig``.""" + deprecate( + "add_adapter", + "1.0.0", + "`add_adapter` is deprecated; use `load_adapter(adapter_config)` instead.", + ) + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." 
+ ) + return self.load_adapter(adapter_config, adapter_name=adapter_name) + + def load_lora_adapter( + self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs + ): + """Deprecated alias for :meth:`load_adapter`.""" + deprecate( + "load_lora_adapter", + "1.0.0", + "`load_lora_adapter` is deprecated; use `load_adapter` instead.", + ) + return self.load_adapter(pretrained_model_name_or_path_or_dict, prefix=prefix, hotswap=hotswap, **kwargs) + + def set_adapter(self, adapter_name: Union[str, List[str]]) -> None: + """Deprecated alias for :meth:`set_adapters`. + + Note: ``set_adapters`` resets the per-adapter scale to ``1.0`` when no weights + are passed; the original ``set_adapter`` left the previous scale untouched. + """ + deprecate( + "set_adapter", + "1.0.0", + "`set_adapter` is deprecated; use `set_adapters` instead. " + "Note that `set_adapters(name)` resets the per-adapter scale to 1.0; " + "pass `weights=...` to control it explicitly.", + ) + return self.set_adapters(adapter_name) + + @_requires_peft + def disable_adapters(self) -> None: + r""" + Disable all adapters attached to the model and fallback to inference with the base model only. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + module.enable_adapters(enabled=False) + + @_requires_peft + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of + adapters to enable. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + [documentation](https://huggingface.co/docs/peft). 
+ """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + module.enable_adapters(enabled=True) + + @_requires_peft + def active_adapters(self) -> List[str]: + """Return the sorted union of active adapter names across all PEFT layers.""" + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + active = set() + for module in self.modules(): + if not isinstance(module, BaseTunerLayer): + continue + names = module.active_adapter + active.update([names] if isinstance(names, str) else names) + return sorted(active) + + @_requires_peft + def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None): + """Merge LoRA adapter weights into the base model in-place.""" + self.apply( + partial(_fuse_lora_apply, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + ) + + @_requires_peft + def unfuse_lora(self): + """Reverse of :meth:`fuse_lora` — unmerge LoRA weights from the base model.""" + self.apply(_unfuse_lora_apply) + + @_requires_peft + def delete_adapters(self, adapter_names: Optional[Union[List[str], str]] = None): + """Remove adapter(s) from the model. + + Pass specific names to delete those adapters only — the PEFT wrapper layers + (``lora_A`` / ``lora_B`` modules) stay in place, so a subsequent + :meth:`load_adapter` call can reuse them without re-injecting. + + Pass ``None`` (the default) to remove every adapter *and* strip the wrapper + layers themselves, returning the model to its pre-LoRA state. 
+ """ + if adapter_names is None: + recurse_remove_peft_layers(self) + if hasattr(self, "peft_config"): + del self.peft_config + + self._hf_peft_config_loaded = False + + else: + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + for adapter_name in adapter_names: + delete_adapter_layers(self, adapter_name) + if hasattr(self, "peft_config"): + self.peft_config.pop(adapter_name, None) + + # In-place mutation invalidates group-offload tensor refs; refresh them. + _maybe_remove_and_reapply_group_offloading(self) + + def unload_lora(self): + """Deprecated alias for :meth:`delete_adapters` (with no arguments).""" + deprecate( + "unload_lora", + "1.0.0", + "`unload_lora` is deprecated; use `delete_adapters()` (no args) for the same teardown.", + ) + return self.delete_adapters() + + def enable_lora_hotswap( + self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error" + ) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`, *optional*, defaults to `128`): + The highest rank among all the adapters that will be loaded. + + check_compiled (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing + """ + if check_compiled not in ("error", "warn", "ignore"): + raise ValueError( + f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." 
+ ) + if getattr(self, "peft_config", {}): + if check_compiled == "error": + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + if check_compiled == "warn": + logger.warning( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + self._lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py new file mode 100644 index 000000000000..fa7d5b001d09 --- /dev/null +++ b/src/diffusers/loaders/weight_mapping.py @@ -0,0 +1,133 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reusable infrastructure for converting model checkpoints between original +and diffusers naming conventions. + +A model defines its mapping by subclassing :class:`WeightMappingMixin` and +populating the class attributes (`_rename_patterns`, `_checkpoint_keys`, etc.) +plus assigning ``_map_to_diffusers`` / ``_map_from_diffusers`` callables. + +The :meth:`apply_transforms` helper drives the forward direction from a single +declarative table — see ``models/transformers/flux/weight_mapping.py`` for an +example. +""" + + +class WeightMappingMixin: + """ + Base mixin providing utilities for checkpoint weight mapping and conversion. 
+ + Subclasses should define: + - _checkpoint_key_prefixes: List of key prefixes to strip (e.g., ["model.diffusion_model."]) + - _checkpoint_keys: Set of keys to identify compatible checkpoints + - _rename_patterns: Dict of substring replacements for key renaming + - _model_variants: Dict mapping variant names to config repos + - _map_to_diffusers: Function to convert original format to diffusers format + - _map_from_diffusers: Function to convert diffusers format to original format + """ + + _checkpoint_key_prefixes: list[str] = [] + _checkpoint_keys: set[str] = set() + _rename_patterns: dict[str, str] = {} + _model_variants: dict[str, str] = {} + _map_to_diffusers = None + _map_from_diffusers = None + + @staticmethod + def _rename_key(key: str, patterns: dict[str, str]) -> str: + """Apply rename patterns to a key.""" + for old, new in patterns.items(): + key = key.replace(old, new) + return key + + @classmethod + def _normalize_checkpoint_keys(cls, state_dict: dict) -> dict: + """Strip known prefixes from state_dict keys.""" + if not cls._checkpoint_key_prefixes: + return state_dict + + result = {} + for key, value in state_dict.items(): + new_key = key + for prefix in cls._checkpoint_key_prefixes: + if key.startswith(prefix): + new_key = key[len(prefix) :] + break + result[new_key] = value + return result + + @classmethod + def _is_original_format(cls, state_dict: dict) -> bool: + """Check if state_dict is in original (non-diffusers) format.""" + if not cls._checkpoint_keys: + return False + keys = set(state_dict.keys()) + return bool(cls._checkpoint_keys & keys) + + @classmethod + def _detect_model_variant(cls, state_dict: dict) -> str | None: + """Detect which model variant a state_dict belongs to. 
Subclasses should override.""" + raise NotImplementedError(f"{cls.__name__} does not implement _detect_model_variant") + + @classmethod + def _get_model_config(cls, state_dict: dict) -> str: + """Get the default config repo for the detected variant.""" + variant = cls._detect_model_variant(state_dict) + if variant is None: + raise ValueError(f"Could not detect model variant from state_dict. Expected keys: {cls._checkpoint_keys}") + return cls._model_variants[variant] + + @staticmethod + def apply_transforms(state_dict, transforms, rename_patterns, **ctx): + """Drive a forward state-dict conversion from a list of (source, targets, fn) entries. + + Each entry is a tuple ``(source, targets, forward_fn, reverse_fn)``: + - ``source``: substring matched against each key (with surrounding dots, + e.g. ``".img_attn.qkv."``); the first matching entry wins. + - ``targets``: list of substrings substituted for ``source`` to build the + output keys. ``len(targets)`` is the fan-out (1 for a unary transform, + >1 for a split). + - ``forward_fn(value, **ctx) -> list[tensor]`` returns one tensor per + target. (``reverse_fn`` is reserved for a future + ``apply_reverse_transforms`` driver.) + + Keys that match no transform get their dots renamed via ``rename_patterns``. 
+ """ + out = {} + for key, value in state_dict.items(): + for source, targets, forward_fn, _ in transforms: + if source in key: + tensors = forward_fn(value, **ctx) + for target, tensor in zip(targets, tensors): + new_key = WeightMappingMixin._rename_key(key.replace(source, target), rename_patterns) + out[new_key] = tensor + break + else: + out[WeightMappingMixin._rename_key(key, rename_patterns)] = value + return out + + @classmethod + def map_to_diffusers(cls, state_dict: dict, **kwargs) -> dict: + """Convert state_dict from original format to diffusers format.""" + if cls._map_to_diffusers is None: + raise NotImplementedError(f"{cls.__name__} does not define _map_to_diffusers") + return cls._map_to_diffusers(state_dict, **kwargs) + + @classmethod + def map_from_diffusers(cls, state_dict: dict, **kwargs) -> dict: + """Convert state_dict from diffusers format to original format.""" + if cls._map_from_diffusers is None: + raise NotImplementedError(f"{cls.__name__} does not define _map_from_diffusers") + return cls._map_from_diffusers(state_dict, **kwargs) diff --git a/src/diffusers/models/transformers/flux/__init__.py b/src/diffusers/models/transformers/flux/__init__.py new file mode 100644 index 000000000000..2996567d82d1 --- /dev/null +++ b/src/diffusers/models/transformers/flux/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .lora import FluxTransformerLoRAMixin +from .model import ( + FluxAttention, + FluxAttnProcessor, + FluxIPAdapterAttnProcessor, + FluxPosEmbed, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, +) +from .weight_mapping import FluxTransformerWeightMappingMixin diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py new file mode 100644 index 000000000000..0a79f77377ff --- /dev/null +++ b/src/diffusers/models/transformers/flux/lora.py @@ -0,0 +1,476 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flux LoRA conversion. + +Pipeline: + 1. Detect the source format (kohya / xlabs / bfl / kontext) via + ``FluxLoRAMappingMixin._detect_lora_format``. + 2. Run the format-specific *normalizer* that rewrites keys into the + canonical "BFL-style" form: original Flux module names with + ``.lora_A`` / ``.lora_B`` suffixes. Tensor-level transforms unique + to a format (e.g., Kohya alpha scaling) happen here. + 3. Run the shared *converter* that maps canonical keys to diffusers + names by reusing the rename / QKV-split / special-key tables in + ``weight_mapping.py``, applying the LoRA-specific QKV semantics + (lora_A weight replicates, everything else chunks). 
+ +A normalizer may also emit keys that bypass step 3 (returned as a second +"extras" dict) — used for keys that don't fit the canonical intermediate +(e.g., text-encoder LoRA keys, XLabs single-block QKV without MLP). +""" + +import re + +import torch + +from ....loaders.lora_base import LoRAMappingMixin +from ....utils import logging, state_dict_all_zero +from .weight_mapping import ( + FLUX_QKV_SPLIT_PATTERNS, + FLUX_QKVMLP_SPLIT_PATTERN, + FLUX_QKVMLP_TARGETS, + FLUX_RENAME_PATTERNS, + FLUX_SPECIAL_KEYS, +) + + +logger = logging.get_logger(__name__) + + +# ============================================================================ +# Stage 3: shared canonical -> diffusers converter +# ============================================================================ +# Canonical keys are BFL-style: original Flux module names + .lora_A/.lora_B +# suffixes. The shared converter handles three cases — pure renames, QKV +# splits, special transforms — by reusing the tables from weight_mapping. + +_LORA_SUFFIXES = (".lora_A.weight", ".lora_A.bias", ".lora_B.weight", ".lora_B.bias") + +# Module-path versions (boundary dots stripped) of the weight-mapping tables. 
_QKV_PATTERNS = {
    pattern.strip("."): [target.strip(".") for target in targets]
    for pattern, targets in FLUX_QKV_SPLIT_PATTERNS.items()
}
_QKVMLP_PATTERN = FLUX_QKVMLP_SPLIT_PATTERN.strip(".")
_QKVMLP_TARGETS = [target.strip(".") for target in FLUX_QKVMLP_TARGETS]

# FLUX_SPECIAL_KEYS is keyed per parameter (".weight" / ".bias"); re-index it by
# module path. Only entries whose source and target share the same parameter
# tail are usable, and the first entry registered for a module wins.
_SPECIAL_MODULES = {}
for _full_src, _spec in FLUX_SPECIAL_KEYS.items():
    for _tail in (".weight", ".bias"):
        if not (_full_src.endswith(_tail) and _spec["target"].endswith(_tail)):
            continue
        _SPECIAL_MODULES.setdefault(
            _full_src[: -len(_tail)],
            (_spec["target"][: -len(_tail)], _spec["transform"]),
        )
        break


def _split_lora_suffix(key):
    """Split ``key`` into ``(module_path, lora_suffix)``.

    The suffix is the first entry of ``_LORA_SUFFIXES`` that ``key`` ends with.
    Keys that end with none of them come back unchanged with an empty suffix.
    """
    found = next((candidate for candidate in _LORA_SUFFIXES if key.endswith(candidate)), None)
    if found is None:
        return key, ""
    return key[: -len(found)], found


def _apply_renames(s, patterns):
    """Apply every ``old -> new`` substring replacement in ``patterns`` to ``s``, in dict order."""
    renamed = s
    for needle, replacement in patterns.items():
        renamed = renamed.replace(needle, replacement)
    return renamed


def _rename_module(module_path):
    """Rename a bare module path with ``FLUX_RENAME_PATTERNS``.

    Many rename patterns end in ".", so a bare path such as "final_layer.linear"
    would not match "final_layer.linear.". A dot is appended before renaming and
    a single trailing dot is removed from the result afterwards.
    """
    padded = module_path + "."
    return _apply_renames(padded, FLUX_RENAME_PATTERNS).removesuffix(".")


def _map_lora_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0):
    """Convert a canonical (BFL-style) Flux LoRA state dict to diffusers keys.

    Every output key is prefixed with ``transformer.``. Keys are handled by the
    first rule that applies:

    1. No LoRA suffix: the whole key is renamed and the value passed through.
    2. Module path contains a fused-QKV pattern: one output per target.
       ``.lora_A.weight`` is reused for every target; any other suffix is
       chunked evenly along dim 0.
    3. Single-block fused QKV+MLP: same fan-out, but non-``lora_A.weight``
       tensors are split along dim 0 into
       ``(inner_dim, inner_dim, inner_dim, int(inner_dim * mlp_ratio))``.
    4. Module path listed in ``_SPECIAL_MODULES``: the registered transform is
       applied to the value and the registered target name is used.
    5. Otherwise the module path is renamed and the value passed through.

    Args:
        state_dict: Mapping of canonical keys to tensors. Not modified.
        inner_dim: Size of each q/k/v slice of the fused single-block layer.
        mlp_ratio: Multiplier on ``inner_dim`` giving the MLP slice size.

    Returns:
        A new dict with diffusers-style keys.
    """
    fused_sizes = (inner_dim, inner_dim, inner_dim, int(inner_dim * mlp_ratio))
    converted = {}

    for key, tensor in state_dict.items():
        module_path, suffix = _split_lora_suffix(key)

        if suffix == "":
            converted[f"transformer.{_apply_renames(key, FLUX_RENAME_PATTERNS)}"] = tensor
            continue

        # The down-projection is shared by every layer split out of a fused one.
        replicate = suffix == ".lora_A.weight"

        qkv_pattern = next((p for p in _QKV_PATTERNS if p in module_path), None)
        if qkv_pattern is not None:
            targets = _QKV_PATTERNS[qkv_pattern]
            if replicate:
                pieces = [tensor] * len(targets)
            else:
                pieces = list(torch.chunk(tensor, len(targets), dim=0))
            for target, piece in zip(targets, pieces):
                renamed = _rename_module(module_path.replace(qkv_pattern, target))
                converted[f"transformer.{renamed}{suffix}"] = piece
            continue

        if "single_blocks." in module_path and _QKVMLP_PATTERN in module_path:
            if replicate:
                pieces = [tensor] * len(_QKVMLP_TARGETS)
            else:
                pieces = list(torch.split(tensor, fused_sizes, dim=0))
            for target, piece in zip(_QKVMLP_TARGETS, pieces):
                renamed = _rename_module(module_path.replace(_QKVMLP_PATTERN, target))
                converted[f"transformer.{renamed}{suffix}"] = piece
            continue

        special = _SPECIAL_MODULES.get(module_path)
        if special is not None:
            target_module, transform = special
            converted[f"transformer.{target_module}{suffix}"] = transform(tensor)
            continue

        converted[f"transformer.{_rename_module(module_path)}{suffix}"] = tensor

    return converted


# ============================================================================
# Stage 2a: BFL normalizer (identity)
# ============================================================================


def _normalize_bfl(state_dict):
    """Return a shallow copy of ``state_dict`` as the canonical dict, with no extras."""
    canonical = dict(state_dict)
    extras = {}
    return canonical, extras


# ============================================================================
# Stage 2b: fal Kontext normalizer (strip "base_model.model." prefix)
# ============================================================================


def _normalize_kontext(state_dict):
    """Drop a leading "base_model.model." from every key; other keys pass through. No extras."""
    canonical = {}
    for key, value in state_dict.items():
        canonical[key.removeprefix("base_model.model.")] = value
    return canonical, {}


# ============================================================================
# Stage 2c: XLabs normalizer
# ============================================================================
# XLabs key shape: [diffusion_model.]{double|single}_blocks.{i}.processor.{X}.{down|up}.weight
# Double-block X ∈ {qkv_lora1, qkv_lora2, proj_lora1, proj_lora2}.
# Single-block X ∈ {qkv_lora, proj_lora}. Single-block lacks an MLP LoRA, so
# its qkv keys can't be expressed as canonical "linear1" (which is QKV+MLP);
# we emit pre-converted diffusers keys for single blocks instead.
+ +_XLABS_DOUBLE_RENAMES = { + ".processor.proj_lora1.": ".img_attn.proj.", + ".processor.proj_lora2.": ".txt_attn.proj.", + ".processor.qkv_lora1.": ".img_attn.qkv.", + ".processor.qkv_lora2.": ".txt_attn.qkv.", +} +_XLABS_SINGLE_QKV_TARGETS = ["attn.to_q", "attn.to_k", "attn.to_v"] + + +def _normalize_xlabs(state_dict): + canonical = {} + extras = {} + for key, value in state_dict.items(): + k = key + if k.startswith("diffusion_model."): + k = k[len("diffusion_model.") :] + + if "single_blocks." in k: + block = re.search(r"single_blocks\.(\d+)", k).group(1) + base = f"transformer.single_transformer_blocks.{block}" + suffix = ".lora_A.weight" if k.endswith(".lora_A.weight") else ".lora_B.weight" + if "proj_lora" in k: + extras[f"{base}.proj_out{suffix}"] = value + elif "qkv_lora" in k: + if suffix == ".lora_A.weight": + for t in _XLABS_SINGLE_QKV_TARGETS: + extras[f"{base}.{t}{suffix}"] = value + else: + for t, chunk in zip(_XLABS_SINGLE_QKV_TARGETS, torch.chunk(value, 3, dim=0)): + extras[f"{base}.{t}{suffix}"] = chunk + continue + + # Double block: rename to canonical BFL-style; shared converter handles the QKV split. + for old, new in _XLABS_DOUBLE_RENAMES.items(): + k = k.replace(old, new) + canonical[k] = value + + return canonical, extras + + +# ============================================================================ +# Stage 2d: Kohya normalizer (sd-scripts and mixture variants) +# ============================================================================ +# Kohya keys collapse all dots into underscores in the module path, then append +# .lora_down/.lora_up/.alpha. We invert this with explicit per-suffix tables +# (the original-name underscore <-> dot mapping isn't recoverable by rule), then +# apply alpha-driven scaling so canonical tensors are pre-scaled. 
# Kohya module-name suffix -> BFL-style dotted path, for `double_blocks_<i>_*` keys.
_KOHYA_DOUBLE_SUFFIXES = {
    "img_attn_proj": "img_attn.proj",
    "img_attn_qkv": "img_attn.qkv",
    "img_mlp_0": "img_mlp.0",
    "img_mlp_2": "img_mlp.2",
    "img_mod_lin": "img_mod.lin",
    "txt_attn_proj": "txt_attn.proj",
    "txt_attn_qkv": "txt_attn.qkv",
    "txt_mlp_0": "txt_mlp.0",
    "txt_mlp_2": "txt_mlp.2",
    "txt_mod_lin": "txt_mod.lin",
}
# Same mapping for `single_blocks_<i>_*` keys.
_KOHYA_SINGLE_SUFFIXES = {
    "linear1": "linear1",
    "linear2": "linear2",
    "modulation_lin": "modulation.lin",
}
# Same mapping for modules that live outside the block stacks.
_KOHYA_GLOBAL_SUFFIXES = {
    "guidance_in_in_layer": "guidance_in.in_layer",
    "guidance_in_out_layer": "guidance_in.out_layer",
    "img_in": "img_in",
    "txt_in": "txt_in",
    "time_in_in_layer": "time_in.in_layer",
    "time_in_out_layer": "time_in.out_layer",
    "vector_in_in_layer": "vector_in.in_layer",
    "vector_in_out_layer": "vector_in.out_layer",
    "final_layer_linear": "final_layer.linear",
    "final_layer_adaLN_modulation_1": "final_layer.adaLN_modulation.1",
}


def _kohya_scale(alpha, rank):
    """Split alpha/rank into (down, up) factors so down*up == alpha/rank but stays bounded.

    Args:
        alpha: LoRA alpha, either a Python number or a 0-dim tensor.
        rank: LoRA rank (first dimension of the down weight).

    Returns:
        A ``(down, up)`` pair whose product equals ``alpha / rank``.
    """
    scale = alpha / rank
    down, up = scale, 1.0
    # A non-positive scale can never satisfy the balancing loop's exit condition
    # in a meaningful way (a negative value loops forever), so keep the whole
    # factor on the down side; the product is still alpha / rank.
    if scale <= 0:
        return down, up
    while down * 2 < up:
        down *= 2
        up /= 2
    return down, up


def _custom_replace(key, substrings):
    """Replace dots with underscores in `key` up to the first occurrence of any substring.

    The dot immediately preceding the matched substring (if any) is kept, so
    ``a.b.lora_A.weight`` becomes ``a_b.lora_A.weight``. When no substring
    matches, every dot in `key` is replaced.
    """
    pattern = "(" + "|".join(re.escape(s) for s in substrings) + ")"
    match = re.search(pattern, key)
    if not match:
        return key.replace(".", "_")
    # Step back over the separator dot so it survives in the returned key.
    boundary = match.start() - 1 if match.start() > 0 and key[match.start() - 1] == "." else match.start()
    return key[:boundary].replace(".", "_") + key[boundary:]


def _kohya_pre_filter(state_dict):
    """Drop Kohya keys we don't support (with logging), then normalize key prefixes.

    Returns a new dict that only contains keys starting with ``lora_unet_``;
    the input mapping is not modified.
    """
    state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}

    # (predicate selecting keys to drop, marker passed to `state_dict_all_zero`, label used in logs)
    drop_specs = [
        (lambda k: "position_embedding" in k, "position_embedding", "position_embedding"),
        (lambda k: ".diff_b" in k and k.startswith("lora_unet_"), ".diff_b", "diff_b"),
        (lambda k: ".norm" in k and ".diff" in k, ".diff", "diff"),
    ]
    for predicate, marker, label in drop_specs:
        if not any(predicate(k) for k in state_dict):
            continue
        if state_dict_all_zero(state_dict, marker):
            logger.info(
                f"The `{label}` LoRA params are all zeros which make them ineffective. "
                "So, we will purge them out of the current state dict to make loading possible."
            )
        else:
            logger.info(
                f"`{label}` keys found in the state dict are currently unsupported and will be filtered out. "
                "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
            )
        state_dict = {k: v for k, v in state_dict.items() if not predicate(k)}

    # Some keys come with dots in the prefix; collapse them up to lora_A/lora_B/alpha.
    limit = ["lora_A", "lora_B"]
    if any("alpha" in k for k in state_dict):
        limit.append("alpha")
    state_dict = {_custom_replace(k, limit): v for k, v in state_dict.items() if k.startswith("lora_unet_")}

    return state_dict


def _kohya_canonical_path(stub):
    """Map a Kohya stub like 'double_blocks_0_img_attn_qkv' to BFL-style 'double_blocks.0.img_attn.qkv'.

    Returns ``None`` when the stub is not present in any of the suffix tables.
    """
    m = re.match(r"double_blocks_(\d+)_(.+)$", stub)
    if m:
        i, suffix = m.group(1), m.group(2)
        bfl = _KOHYA_DOUBLE_SUFFIXES.get(suffix)
        return f"double_blocks.{i}.{bfl}" if bfl else None

    m = re.match(r"single_blocks_(\d+)_(.+)$", stub)
    if m:
        i, suffix = m.group(1), m.group(2)
        bfl = _KOHYA_SINGLE_SUFFIXES.get(suffix)
        return f"single_blocks.{i}.{bfl}" if bfl else None

    return _KOHYA_GLOBAL_SUFFIXES.get(stub)


def _normalize_kohya(state_dict):
    """Normalize a Kohya-format Flux LoRA state dict.

    Returns:
        ``(canonical, extras)``. For regular Kohya checkpoints ``canonical``
        holds alpha-scaled weights under BFL-style names and ``extras`` is
        empty. For the mixture variant ``canonical`` is empty and ``extras``
        holds keys already converted by `_convert_mixture`.
    """
    # The mixture variant must be detected on the raw keys: `_kohya_pre_filter`
    # keeps only `lora_unet_*` entries, so checking afterwards would never see
    # the `lora_transformer_*` keys and would silently drop them.
    mixture_markers = ("lora_down", "lora_up", "lora_A", "lora_B", "alpha")
    has_mixture = any(
        k.startswith("lora_transformer_") and any(marker in k for marker in mixture_markers) for k in state_dict
    )
    if has_mixture:
        # `_convert_mixture` pops entries; hand it a shallow copy so the caller's dict is untouched.
        return {}, _convert_mixture(dict(state_dict))

    state_dict = _kohya_pre_filter(state_dict)

    # Group keys per Kohya module (lora_unet_<stub>) so we can apply alpha
    # scaling, then rewrite to canonical names.
    groups = {}  # stub -> {"lora_A": full_key, "lora_B": ..., "alpha": ...}
    for key in list(state_dict):
        if not key.startswith("lora_unet_"):
            continue
        for kind in ("lora_A.weight", "lora_B.weight", "alpha"):
            tail = "." + kind
            if key.endswith(tail):
                stub = key[len("lora_unet_") : -len(tail)]
                groups.setdefault(stub, {})[kind.split(".")[0]] = key
                break

    canonical = {}

    for stub, group in groups.items():
        down_key, up_key = group.get("lora_A"), group.get("lora_B")
        if down_key is None or up_key is None:
            # Incomplete pair: leave the keys in place so they are reported below.
            continue
        rank = state_dict[down_key].shape[0]
        # A missing alpha is treated as alpha == rank, i.e. a scale of 1.
        alpha = state_dict.pop(group["alpha"]).item() if "alpha" in group else float(rank)
        d_scale, u_scale = _kohya_scale(alpha, rank)
        down = state_dict.pop(down_key) * d_scale
        up = state_dict.pop(up_key) * u_scale

        bfl = _kohya_canonical_path(stub)
        if bfl is None:
            logger.warning(f"Unsupported Kohya key: lora_unet_{stub}")
            continue
        canonical[f"{bfl}.lora_A.weight"] = down
        canonical[f"{bfl}.lora_B.weight"] = up

    if state_dict:
        logger.warning(f"Unsupported keys after Kohya normalization: {list(state_dict.keys())}")

    return canonical, {}


# ----------------------------------------------------------------------------
# Mixture variant (Kohya-trained but using lora_transformer_* keys)
# ----------------------------------------------------------------------------


def _convert_mixture(state_dict):
    """Convert Kohya mixture-format LoRA directly to diffusers keys.

    Entries are popped from `state_dict` as they are consumed. Remaining
    `lora_unet_*` keys are ignored silently; any other leftover key is logged.

    Raises:
        NotImplementedError: for a `lora_transformer_*` module with no known mapping.
        KeyError: if a module is missing its lora_A, lora_B or alpha entry.
    """
    new_state_dict = {}

    def emit(orig, diffusers_key):
        down = state_dict.pop(f"{orig}.lora_A.weight")
        up = state_dict.pop(f"{orig}.lora_B.weight")
        alpha = state_dict.pop(f"{orig}.alpha")
        rank = down.shape[0]
        d_scale, u_scale = _kohya_scale(alpha, rank)
        new_state_dict[f"{diffusers_key}.lora_A.weight"] = down * d_scale
        new_state_dict[f"{diffusers_key}.lora_B.weight"] = up * u_scale

    unique = {
        k.replace(".lora_A.weight", "").replace(".lora_B.weight", "").replace(".alpha", "")
        for k in state_dict
        if k.startswith("lora_transformer_")
    }

    # Sorted so the output key order does not depend on set iteration order.
    for k in sorted(unique):
        if k.startswith("lora_transformer_single_transformer_blocks_"):
            i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
            diffusers_key = f"single_transformer_blocks.{i}"
        elif k.startswith("lora_transformer_transformer_blocks_"):
            i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
            diffusers_key = f"transformer_blocks.{i}"
        elif k.startswith("lora_transformer_context_embedder"):
            diffusers_key = "context_embedder"
        elif k.startswith("lora_transformer_norm_out_linear"):
            diffusers_key = "norm_out.linear"
        elif k.startswith("lora_transformer_proj_out"):
            diffusers_key = "proj_out"
        elif k.startswith("lora_transformer_x_embedder"):
            diffusers_key = "x_embedder"
        elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
            i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
            diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
        elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
            i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
            diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
        elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
            i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
            diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
        else:
            raise NotImplementedError(f"Handling for key ({k}) is not implemented.")

        if "attn_" in k:
            tail = k.split("attn_")[-1]
            if "_to_out_0" in k:
                diffusers_key += ".attn.to_out.0"
            elif "_to_add_out" in k:
                diffusers_key += ".attn.to_add_out"
            elif any(qkv in k for qkv in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj")):
                diffusers_key += f".attn.{tail}"

        emit(k, diffusers_key)

    leftover = [k for k in state_dict if not k.startswith("lora_unet_")]
    if leftover:
        logger.warning(f"Unsupported mixture keys ignored: {leftover}")

    return {f"transformer.{k}": v for k, v in new_state_dict.items()}


# ============================================================================
# Top-level dispatch
# ============================================================================


# Format name (as returned by `_detect_lora_format`) -> normalizer function.
_NORMALIZERS = {
    "bfl": _normalize_bfl,
    "kontext": _normalize_kontext,
    "xlabs": _normalize_xlabs,
    "kohya": _normalize_kohya,
}


def map_lora_to_diffusers(state_dict, **kwargs):
    """Translate a Flux LoRA state dict in any supported layout to diffusers key names.

    The ``lora_down``/``lora_up`` -> ``lora_A``/``lora_B`` suffix rewrite is
    expected to have been applied by ``LoRAMappingMixin.map_lora_to_diffusers``
    before this function is called.

    Raises:
        ValueError: if the layout cannot be identified or has no normalizer.
    """
    # A dict that already carries `transformer.*` keys is treated as converted
    # (peft) weights; everything else in it is discarded.
    already_converted = {name: tensor for name, tensor in state_dict.items() if name.startswith("transformer.")}
    if already_converted:
        return already_converted

    fmt = FluxTransformerLoRAMixin._detect_lora_format(state_dict)
    normalizer = None if fmt is None else _NORMALIZERS.get(fmt)
    if normalizer is None:
        raise ValueError(
            f"Unable to determine format of LoRA weights. Supported formats are: {FluxTransformerLoRAMixin._lora_format_keys.keys()}"
        )

    canonical, extras = normalizer(state_dict)

    merged = {}
    if canonical:
        merged.update(_map_lora_to_diffusers(canonical))
    # Extras are applied last so they win on a key collision.
    merged.update(extras)
    return merged


class FluxTransformerLoRAMixin(LoRAMappingMixin):
    """Flux-specific LoRA layout detection and conversion hooks."""

    # Format name -> substrings that identify that format in state-dict keys.
    _lora_format_keys: dict[str, set[str]] = {
        "kohya": {"lora_unet_double_blocks_", "lora_unet_single_blocks_"},
        "xlabs": {".processor.qkv_lora", ".processor.proj_lora"},
        "bfl": {"time_in.in_layer.lora_A", "double_blocks.0.img_mod.lin.lora_A"},
        "kontext": {"base_model.model.double_blocks"},
    }

    _map_lora_to_diffusers = staticmethod(map_lora_to_diffusers)
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ....configuration_utils import ConfigMixin, register_to_config +from ....loaders import FluxTransformer2DLoadersMixin, PeftAdapterMixin +from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import maybe_allow_in_graph +from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ...attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ...attention_dispatch import dispatch_attention_fn +from ...cache_utils import CacheMixin +from ...embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ...modeling_outputs import Transformer2DModelOutput +from ...modeling_utils import ModelMetadata, ModelMixin +from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from .lora import FluxTransformerLoRAMixin +from .weight_mapping import FluxTransformerWeightMappingMixin + + +logger = logging.get_logger(__name__) + + +FLUX_METADATA = ModelMetadata( + supports_gradient_checkpointing=True, + no_split_modules=["FluxTransformerBlock", "FluxSingleTransformerBlock"], + 
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    """Compute q/k/v with the separate projection layers of `attn`.

    Returns:
        ``(query, key, value, encoder_query, encoder_key, encoder_value)``.
        The three encoder entries are ``None`` unless `encoder_hidden_states`
        is given and ``attn.added_kv_proj_dim`` is set.
    """
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    """Compute q/k/v with the fused ``to_qkv`` (and ``to_added_qkv``) layers of `attn`.

    Returns the same 6-tuple as `_get_projections`; the encoder entries are
    ``None`` when `encoder_hidden_states` is missing or `attn` has no
    ``to_added_qkv`` layer.
    """
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    # Plain None (not a 1-tuple) so callers can test `is None`, matching `_get_projections`.
    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    """Dispatch to the fused or unfused projection helper based on ``attn.fused_projections``."""
    if attn.fused_projections:
        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
    return _get_projections(attn, hidden_states, encoder_hidden_states)
class FluxAttnProcessor:
    """Default attention processor for `FluxAttention`.

    Projects hidden states to q/k/v, optionally prepends the encoder (context)
    projections along the sequence dimension, applies rotary embeddings and
    runs attention through `dispatch_attention_fn`.
    """

    # Class-level defaults forwarded to `dispatch_attention_fn` on every call.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for `attn`.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when
            `encoder_hidden_states` is given, otherwise only the attention
            output for `hidden_states` (no output projection is applied in
            that case).
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Split the last dimension into (heads, head_dim).
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # NOTE(review): this branch keys on the module config while the split below keys on
        # `encoder_hidden_states`; it presumably assumes both are set together - confirm.
        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Context tokens go first along dim=1; the split after attention relies on this order.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge (heads, head_dim) back into a single feature dimension.
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Undo the concatenation: the first `encoder_hidden_states.shape[1]` positions are context.
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FluxIPAdapterAttnProcessor(torch.nn.Module):
    """Flux Attention processor for IP-Adapter.

    Holds one key and one value projection per IP-Adapter (one entry of
    `num_tokens`) and, on top of the regular joint attention, computes an
    extra attention output over the projected image-prompt states.

    Args:
        hidden_size: Output width of the IP key/value projections.
        cross_attention_dim: Input width of the IP key/value projections.
        num_tokens: One entry per IP-Adapter; a scalar is treated as a single adapter.
        scale: Per-adapter weight. A scalar is broadcast to every adapter; a
            list or tuple must have the same length as `num_tokens`.
        device, dtype: Forwarded to the `nn.Linear` layers.

    Raises:
        ImportError: if PyTorch lacks `scaled_dot_product_attention`.
        ValueError: if `scale` and `num_tokens` differ in length.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(
        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]

        # Accept tuples as well as lists, mirroring `num_tokens`. Previously a tuple was
        # wrapped as `[scale] * n`, yielding a list of tuples that failed at call time.
        if isinstance(scale, tuple):
            scale = list(scale)
        elif not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )
        self.to_v_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        ip_hidden_states: Optional[List[torch.Tensor]] = None,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run joint attention plus the IP-Adapter attention branch.

        Returns:
            ``(hidden_states, encoder_hidden_states, ip_attn_output)`` when
            `encoder_hidden_states` is given, otherwise only the attention
            output for `hidden_states`. `ip_adapter_masks` is accepted but not
            used by this implementation.
        """
        batch_size = hidden_states.shape[0]

        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)
        # Keep the image-only query (before context concat and rotary) for the IP branch.
        ip_query = query

        if encoder_hidden_states is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # IP-adapter
            ip_attn_output = torch.zeros_like(hidden_states)

            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
            ):
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)

                current_ip_hidden_states = dispatch_attention_fn(
                    ip_query,
                    ip_key,
                    ip_value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False,
                    backend=self._attention_backend,
                    parallel_config=self._parallel_config,
                )
                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
                ip_attn_output += scale * current_ip_hidden_states

            return hidden_states, encoder_hidden_states, ip_attn_output
        else:
            return hidden_states
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
    """Attention module for Flux blocks; the actual computation is delegated to a processor.

    Owns the q/k/v projections and RMS norms, the optional output projection
    (skipped when `pre_only` is true) and, when `added_kv_proj_dim` is set, a
    second set of projections and norms for the encoder (context) stream.
    """

    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        # When `out_dim` is given it determines both the inner width and the head count.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # `pre_only` modules have no output projection; their caller projects the result itself.
        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Call the configured processor, forwarding only the kwargs it accepts.

        Unknown kwargs are dropped with a warning, except the IP-Adapter ones
        listed in `quiet_attn_parameters`, which are dropped silently.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    """Single-stream Flux block: context and image tokens are processed as one sequence.

    Attention and the MLP branch both read the same normalized input; their
    outputs are concatenated on the feature dimension and projected together.
    """

    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Consumes the concatenation of attention output (dim) and MLP hidden states.
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        self.attn = FluxAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(encoder_hidden_states, hidden_states)``."""
        # Remember the context length so the joint sequence can be split again at the end.
        text_seq_len = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        # 65504 is the largest finite float16 value; clipping keeps the residual sum finite.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
        return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    """Dual-stream Flux block: image and context tokens keep separate norms and feed-forwards
    but share one joint attention call."""

    # NOTE(review): `qk_norm` is accepted but not read anywhere in this class.
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = FluxAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(encoder_hidden_states, hidden_states)``."""
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # Two outputs come from the default processor, three from the IP-Adapter processor.
        # NOTE(review): any other length leaves `attn_output` unbound and fails on the next statement.
        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        # 65504 is the largest finite float16 value; only the context stream is clipped here.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
    """Builds rotary position-embedding cos/sin tables from per-token position ids."""

    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        # One rotary dimension per id axis; indexed by the last-dimension index of `ids`.
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(freqs_cos, freqs_sin)``, each the per-axis embeddings concatenated on the last dim.

        `ids` is indexed as ``ids[:, i]`` per axis, so it is expected to be 2D
        with one column per entry of `axes_dim`.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        # Frequencies use float64 except on mps/npu devices, where float32 is used instead.
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class FluxTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FluxTransformerWeightMappingMixin, + FluxTransformerLoRAMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `19`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `38`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `4096`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. 
+ """ + + _model_metadata = FLUX_METADATA + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + 
txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. 
+ if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py new file mode 100644 index 000000000000..90a3479ba8c3 --- /dev/null +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -0,0 +1,315 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ....loaders.weight_mapping import WeightMappingMixin + + +def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor: + """Swap scale and shift in AdaLayerNorm weights (original uses shift,scale; diffusers uses scale,shift).""" + shift, scale = weight.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + +# Pattern-based key renaming (substring replacements applied in order) +FLUX_RENAME_PATTERNS: dict[str, str] = { + # Global key renames + "time_in.in_layer": "time_text_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_text_embed.timestep_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "guidance_in.in_layer": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_text_embed.guidance_embedder.linear_2", + "txt_in.": "context_embedder.", + "img_in.": "x_embedder.", + "final_layer.linear.": "proj_out.", + # Double block patterns + "double_blocks.": "transformer_blocks.", + ".img_mod.lin.": ".norm1.linear.", + ".txt_mod.lin.": ".norm1_context.linear.", + ".img_attn.norm.query_norm.scale": ".attn.norm_q.weight", + ".img_attn.norm.key_norm.scale": ".attn.norm_k.weight", + ".txt_attn.norm.query_norm.scale": ".attn.norm_added_q.weight", + ".txt_attn.norm.key_norm.scale": ".attn.norm_added_k.weight", + ".img_mlp.0.": ".ff.net.0.proj.", + ".img_mlp.2.": ".ff.net.2.", + ".txt_mlp.0.": ".ff_context.net.0.proj.", + ".txt_mlp.2.": ".ff_context.net.2.", + 
".img_attn.proj.": ".attn.to_out.0.", + ".txt_attn.proj.": ".attn.to_add_out.", + # Single block patterns + "single_blocks.": "single_transformer_blocks.", + ".modulation.lin.": ".norm.linear.", + ".norm.query_norm.scale": ".attn.norm_q.weight", + ".norm.key_norm.scale": ".attn.norm_k.weight", + ".linear2.": ".proj_out.", +} + + +# -------------------------------------------------------------------------- +# Per-key transforms (split + special), unified. +# -------------------------------------------------------------------------- +# Single source of truth. Each entry is +# (source_substring, [target_substrings], forward_fn, reverse_fn) +# - source/targets include surrounding dots so they only match at module +# boundaries (e.g. ".img_attn.qkv." matches both "X.img_attn.qkv.weight" +# and "X.img_attn.qkv.bias" with one entry). +# - len(targets) == 1 -> a unary transform (e.g. AdaLN scale/shift swap). +# - len(targets) > 1 -> a split transform (forward chunks the tensor). +# - forward_fn(tensor, **ctx) -> list[tensor] of length len(targets). +# - reverse_fn(list[tensor], **ctx) -> tensor. 
+def _swap_to_list(v, **_): + return [swap_scale_shift(v)] + + +def _list_to_swap(vs, **_): + return swap_scale_shift(vs[0]) + + +def _make_chunk(n): + return lambda v, **_: torch.chunk(v, n, dim=0) + + +def _qkvmlp_split(v, inner_dim=3072, **_): + return torch.split(v, [inner_dim, inner_dim, inner_dim, inner_dim * 4], dim=0) + + +def _cat0(vs, **_): + return torch.cat(vs, dim=0) + + +FLUX_TRANSFORMS = [ + ("final_layer.adaLN_modulation.1.", ["norm_out.linear."], _swap_to_list, _list_to_swap), + (".img_attn.qkv.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v."], _make_chunk(3), _cat0), + ( + ".txt_attn.qkv.", + [".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj."], + _make_chunk(3), + _cat0, + ), + (".linear1.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."], _qkvmlp_split, _cat0), +] + + +# Backward-compat tables derived from FLUX_TRANSFORMS so the existing +# map_from_diffusers code (and any external readers, including lora.py) +# keep working without changes. +def _wrap_unary(fwd_fn): + return lambda v: fwd_fn(v)[0] + + +FLUX_SPECIAL_KEYS: dict[str, dict] = {} +FLUX_QKV_SPLIT_PATTERNS: dict[str, list[str]] = {} +FLUX_QKVMLP_SPLIT_PATTERN: str = "" +FLUX_QKVMLP_TARGETS: list[str] = [] +for _src, _tgts, _fwd, _ in FLUX_TRANSFORMS: + if len(_tgts) == 1: + for _suffix in ("weight", "bias"): + FLUX_SPECIAL_KEYS[_src + _suffix] = { + "target": _tgts[0] + _suffix, + "transform": _wrap_unary(_fwd), + } + elif _src == ".linear1.": + FLUX_QKVMLP_SPLIT_PATTERN = _src + FLUX_QKVMLP_TARGETS = list(_tgts) + else: + FLUX_QKV_SPLIT_PATTERNS[_src] = list(_tgts) + + +def _get_inner_dim(state_dict: dict[str, torch.Tensor]) -> int: + """Infer inner_dim from state_dict weights.""" + for key in state_dict: + if "single_blocks." in key and ".linear1." 
in key and key.endswith(".bias"): + # linear1 contains Q, K, V, MLP fused - Q/K/V each have inner_dim + # Total size = 3 * inner_dim + mlp_hidden_dim = 3 * inner_dim + 4 * inner_dim = 7 * inner_dim + total = state_dict[key].shape[0] + return total // 7 + return 3072 # Default + + +def map_to_diffusers( + state_dict: dict[str, torch.Tensor], + **kwargs, +) -> dict[str, torch.Tensor]: + """Convert a Flux transformer state_dict from original format to diffusers format.""" + inner_dim = _get_inner_dim(state_dict) + return WeightMappingMixin.apply_transforms(state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim) + + +# Build reverse patterns for map_from_diffusers +FLUX_RENAME_PATTERNS_REVERSE: dict[str, str] = {v: k for k, v in FLUX_RENAME_PATTERNS.items()} +FLUX_SPECIAL_KEYS_REVERSE: dict[str, dict] = { + v["target"]: {"target": k, "transform": v["transform"]} for k, v in FLUX_SPECIAL_KEYS.items() +} +FLUX_QKV_SPLIT_PATTERNS_REVERSE: dict[str, str] = { + target: pattern for pattern, targets in FLUX_QKV_SPLIT_PATTERNS.items() for target in targets +} + + +def map_from_diffusers( + state_dict: dict[str, torch.Tensor], + **kwargs, +) -> dict[str, torch.Tensor]: + """ + Convert a Flux transformer state_dict from diffusers format to original format. 
+ + Args: + state_dict: State dict in diffusers format + + Returns: + State dict in original Flux format + """ + converted_state_dict = {} + keys = list(state_dict.keys()) + + # Group keys for QKV concatenation + qkv_groups: dict[str, list[tuple[str, torch.Tensor]]] = {} + qkvmlp_groups: dict[str, list[tuple[str, torch.Tensor]]] = {} + + for key in keys: + value = state_dict[key] + + # Handle special keys with transforms + if key in FLUX_SPECIAL_KEYS_REVERSE: + spec = FLUX_SPECIAL_KEYS_REVERSE[key] + converted_state_dict[spec["target"]] = spec["transform"](value) + continue + + # Check if this is part of a QKV group (double blocks) + qkv_pattern = None + for target, pattern in FLUX_QKV_SPLIT_PATTERNS_REVERSE.items(): + if target in key: + qkv_pattern = pattern + break + + if qkv_pattern and "transformer_blocks." in key: + # Build the original key by replacing target with pattern + base_key = key + for target in FLUX_QKV_SPLIT_PATTERNS_REVERSE: + if target in base_key: + base_key = base_key.replace(target, qkv_pattern) + break + orig_key = WeightMappingMixin._rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) + + if orig_key not in qkv_groups: + qkv_groups[orig_key] = [] + qkv_groups[orig_key].append((key, value)) + continue + + # Check if this is part of a QKV+MLP group (single blocks) + is_qkvmlp = False + for target in FLUX_QKVMLP_TARGETS: + if target in key and "single_transformer_blocks." 
in key: + base_key = key.replace(target, FLUX_QKVMLP_SPLIT_PATTERN) + orig_key = WeightMappingMixin._rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) + + if orig_key not in qkvmlp_groups: + qkvmlp_groups[orig_key] = [] + qkvmlp_groups[orig_key].append((key, value)) + is_qkvmlp = True + break + + if is_qkvmlp: + continue + + # Standard rename + new_key = WeightMappingMixin._rename_key(key, FLUX_RENAME_PATTERNS_REVERSE) + converted_state_dict[new_key] = value + + # Concatenate QKV groups + for orig_key, items in qkv_groups.items(): + if len(items) == 3: + # Sort by the target pattern order + items.sort( + key=lambda x: next( + i + for i, t in enumerate( + FLUX_QKV_SPLIT_PATTERNS[".img_attn.qkv."] + if ".img_attn." in orig_key + else FLUX_QKV_SPLIT_PATTERNS[".txt_attn.qkv."] + ) + if t in x[0] + ) + ) + converted_state_dict[orig_key] = torch.cat([v for _, v in items], dim=0) + + # Concatenate QKV+MLP groups + for orig_key, items in qkvmlp_groups.items(): + if len(items) == 4: + items.sort(key=lambda x: next(i for i, t in enumerate(FLUX_QKVMLP_TARGETS) if t in x[0])) + converted_state_dict[orig_key] = torch.cat([v for _, v in items], dim=0) + + return converted_state_dict + + +class FluxTransformerWeightMappingMixin(WeightMappingMixin): + """ + Mixin providing Flux-specific weight mapping and conversion. 
+ + This mixin defines class attributes used by ModelMixin for checkpoint conversion: + - Checkpoint identification keys (shared across variants) + - Variant-specific metadata (config repos) + - Conversion function + - Default subfolder + """ + + _checkpoint_key_prefixes: list[str] = ["model.diffusion_model."] + # Distinctive keys for original format detection (only keys that use simple renaming, not splits) + _checkpoint_keys: set[str] = { + "time_in.in_layer.weight", + "double_blocks.0.img_mod.lin.weight", + } + _rename_patterns: dict[str, str] = FLUX_RENAME_PATTERNS + _model_variants: dict[str, str] = { + "flux-dev": "black-forest-labs/FLUX.1-dev", + "flux-schnell": "black-forest-labs/FLUX.1-schnell", + "flux-fill": "black-forest-labs/FLUX.1-Fill-dev", + "flux-depth": "black-forest-labs/FLUX.1-Depth-dev", + } + + _map_to_diffusers = staticmethod(map_to_diffusers) + _map_from_diffusers = staticmethod(map_from_diffusers) + _default_subfolder: str = "transformer" + + @classmethod + def _detect_model_variant(cls, state_dict: dict[str, Any]) -> str | None: + """ + Detect which Flux variant a state_dict belongs to. + + Returns the variant name (e.g., "flux-dev", "flux-schnell", "flux-fill", "flux-depth") + or None if unknown. 
+ """ + guidance_key = "guidance_in.in_layer.bias" + x_embedder_key = "img_in.weight" + + if not cls._is_original_format(state_dict): + guidance_key = cls._rename_key(guidance_key, cls._rename_patterns) + x_embedder_key = cls._rename_key(x_embedder_key, cls._rename_patterns) + + if x_embedder_key not in state_dict: + return None + + if guidance_key not in state_dict: + return "flux-schnell" + + in_channels = state_dict[x_embedder_key].shape[1] + if in_channels == 384: + return "flux-fill" + elif in_channels == 128: + return "flux-depth" + + return "flux-dev" From 8b458f891588608780c10b31a8d18df15e6da136 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 14 May 2026 22:49:28 +0530 Subject: [PATCH 03/21] update --- src/diffusers/hooks/_helpers.py | 11 + src/diffusers/loaders/lora.py | 182 +++++--- src/diffusers/loaders/lora_base.py | 109 ++--- src/diffusers/loaders/lora_pipeline.py | 237 ++-------- src/diffusers/loaders/single_file_utils.py | 2 - src/diffusers/loaders/unet.py | 16 +- src/diffusers/loaders/weight_mapping.py | 59 ++- src/diffusers/models/modeling_utils.py | 299 +++++++++---- .../models/transformers/flux/__init__.py | 5 +- .../models/transformers/flux/lora.py | 421 ++++++++---------- .../models/transformers/flux/model.py | 62 +-- .../transformers/flux/weight_mapping.py | 103 ++--- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 13 + src/diffusers/utils/peft_utils.py | 2 - 15 files changed, 766 insertions(+), 756 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 372ce4f76e91..692bec80c84d 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -31,6 +31,17 @@ class TransformerBlockMetadata: _cls: Type = None _cached_parameter_indices: dict[str, int] = None + def _register(self, cls): + """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. 
+ + Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that + opt into the decorator pattern (e.g. Flux). Models that use the legacy bulk registration + in ``_register_transformer_blocks_metadata`` are unaffected — both code paths call the + same ``TransformerBlockRegistry.register`` underneath. + """ + cls._block_metadata = self + TransformerBlockRegistry.register(cls, self) + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): kwargs = kwargs or {} if identifier in kwargs: diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index fc5ec5a250e4..2e7f4848f0fc 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -17,9 +17,10 @@ import json import os from collections import defaultdict +from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Callable, Dict, List, Literal, Optional, Set, Union import safetensors import torch @@ -27,10 +28,17 @@ from huggingface_hub.constants import HF_HUB_OFFLINE from ..hooks.group_offloading import ( + _GROUP_OFFLOADING, + _LAYER_EXECUTION_TRACKER, + _LAZY_PREFETCH_GROUP_OFFLOADING, + _apply_group_offloading, + _get_top_level_group_offload_hook, _maybe_remove_and_reapply_group_offloading, ) -from ..models.modeling_utils import load_state_dict +from ..hooks.hooks import HookRegistry +from ..models.model_loading_utils import load_state_dict from ..utils import ( + HUB_KWARGS, USE_PEFT_BACKEND, _get_model_file, delete_adapter_layers, @@ -48,7 +56,7 @@ if is_accelerate_available(): - pass + from accelerate.hooks import AlignDevicesHook, CpuOffload, add_hook_to_module, remove_hook_from_module if is_peft_available(): @@ -76,18 +84,6 @@ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" -# Hub-download kwargs forwarded to `_get_model_file`. 
-_HUB_KWARGS = { - "cache_dir": None, - "force_download": False, - "proxies": None, - "local_files_only": None, - "token": None, - "revision": None, - "subfolder": None, -} - - # Per-class hook for expanding adapter weights before activation. Models that need # expansion (currently only UNet variants) register here; everything else falls # through to the identity default so new transformers don't need an entry. @@ -171,6 +167,47 @@ def _split_majority_and_outliers(value_dict): return majority, {k: v for k, v in value_dict.items() if v != majority} +@contextmanager +def _offloading_disabled(model): + """Temporarily strip accelerate and group-offload hooks from ``model``. + + PEFT injection and weight loading mutate the model graph in ways that fight with + active offload hooks (sequential CPU offload, group offload, etc.). This context + saves the hook state, removes the hooks for the duration of the block, and + restores them on exit so existing offloading config survives a LoRA load. + """ + saved_hf_hook = None + is_sequential = False + if hasattr(model, "_hf_hook"): + hook = model._hf_hook + if isinstance(hook, CpuOffload): + saved_hf_hook = hook + elif isinstance(hook, AlignDevicesHook) or ( + hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook) + ): + saved_hf_hook = hook + is_sequential = True + if saved_hf_hook is not None: + remove_hook_from_module(model, recurse=is_sequential) + + saved_group_offload_config = None + top_level_group_hook = _get_top_level_group_offload_hook(model) + if top_level_group_hook is not None: + saved_group_offload_config = top_level_group_hook.config + registry = HookRegistry.check_if_exists_or_initialize(model) + registry.remove_hook(_GROUP_OFFLOADING, recurse=True) + registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True) + registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True) + + try: + yield + finally: + if saved_hf_hook is not None: + add_hook_to_module(model, saved_hf_hook) + if 
saved_group_offload_config is not None: + _apply_group_offloading(model, saved_group_offload_config) + + def _create_lora_config(state_dict, network_alphas, rank_dict, metadata=None): """Build a PEFT ``LoraConfig`` from a LoRA state dict. @@ -235,41 +272,23 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): logger.warning(warn_msg) -def _fetch_state_dict( - pretrained_model_name_or_path_or_dict, - weight_name=None, - use_safetensors=True, - allow_pickle=False, - **hub_kwargs, -): +def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): """Load a LoRA state dict from a path/repo/dict. + Safetensors only — pickle (``.bin``) LoRAs are no longer supported. Re-save legacy + checkpoints with ``safetensors.torch.save_file`` or load them manually with + ``torch.load`` and pass the resulting dict. + ``hub_kwargs`` are the download / file-discovery options forwarded to - ``_get_model_file`` (see ``_HUB_KWARGS`` for the canonical set). Sidecar - :func:`_fetch_lora_metadata`. + ``_get_model_file`` (see ``HUB_KWARGS`` for the canonical set). """ if isinstance(pretrained_model_name_or_path_or_dict, dict): return pretrained_model_name_or_path_or_dict source = pretrained_model_name_or_path_or_dict local_files_only = hub_kwargs.get("local_files_only") - - # Try safetensors first when the user asked for it (or named a .safetensors file). - # Fall through to .bin if the safetensors lookup fails and pickle is allowed. 
- prefer_safetensors = (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ) - if prefer_safetensors: - try: - name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) - model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) - return load_state_dict(model_file) - except (IOError, safetensors.SafetensorError): - if not allow_pickle: - raise - - name = weight_name or _best_guess_weight_name(source, ".bin", local_files_only) - model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME, **hub_kwargs) + name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) + model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) return load_state_dict(model_file) @@ -329,11 +348,17 @@ def _best_guess_weight_name( return targeted_files[0] -class PeftAdapterMixin: +class LoRAModelMixin: """ - A class containing all functions for loading and using adapters weights that are supported in PEFT library. For - more details about adapters and injecting them in a base model, check out the PEFT - [documentation](https://huggingface.co/docs/peft/index). + Single mixin for everything LoRA on a diffusers model: PEFT adapter lifecycle + (load / fuse / unfuse / set / delete / hotswap) plus foreign-format conversion + (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. + + Per-model conversion knobs live in a ``LoRAMetadata`` declared in the model's + ``lora.py`` (e.g. ``FLUX_LORA_METADATA``) and attached to the class via + ``@register_model_metadata(lora=...)``. The default no-op path just normalizes + ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state + dict unchanged. 
Install the latest version of PEFT, and use this mixin to: @@ -347,6 +372,49 @@ class PeftAdapterMixin: # kwargs for prepare_model_for_compiled_hotswap, if required _lora_hotswap_kwargs: Optional[dict] = None + # Default class-attribute values; populated per-model by ``register_model_metadata`` + # (or set directly on subclasses for legacy callers). + _lora_format_keys: Dict[str, Set[str]] = {} + _map_lora_to_diffusers: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None + + @classmethod + def _detect_lora_format(cls, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]: + """Return the format name (``"kohya"`` etc.) matched by ``state_dict``, or ``None``.""" + if not cls._lora_format_keys: + return None + keys = set(state_dict) + for fmt, fmt_keys in cls._lora_format_keys.items(): + if any(any(fk in k for k in keys) for fk in fmt_keys): + return fmt + return None + + @classmethod + def _normalize_lora_suffixes(cls, state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: + """Rewrite ``.lora_down/.lora_up`` (kohya-ish) to ``.lora_A/.lora_B`` (diffusers).""" + out: Dict[str, "torch.Tensor"] = {} + for k, v in state_dict.items(): + new_k = ( + k.replace(".lora_down.weight", ".lora_A.weight") + .replace(".lora_up.weight", ".lora_B.weight") + .replace(".down.weight", ".lora_A.weight") + .replace(".up.weight", ".lora_B.weight") + ) + out[new_k] = v + return out + + @classmethod + def map_lora_to_diffusers(cls, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: + """Canonicalize a LoRA state dict to diffusers naming. + + Default: just normalize suffixes. Models with foreign formats register a + converter via ``LoRAMetadata._map_lora_to_diffusers`` — the decorator mirrors + it onto ``cls._map_lora_to_diffusers``. 
+ """ + state_dict = cls._normalize_lora_suffixes(state_dict) + if cls._map_lora_to_diffusers is None: + return state_dict + return cls._map_lora_to_diffusers(state_dict, **kwargs) + @_requires_peft def load_adapter( self, @@ -466,22 +534,23 @@ def _load_adapter_from_pretrained( LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to initialize `LoraConfig`. """ - hub_kwargs = {k: kwargs.pop(k, default) for k, default in _HUB_KWARGS.items()} + hub_kwargs = {k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()} hub_kwargs["user_agent"] = {"file_type": "attn_procs_weights", "framework": "pytorch"} weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) network_alphas = kwargs.pop("network_alphas", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) metadata = kwargs.pop("metadata", None) - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - allow_pickle=False, - **hub_kwargs, - ) + if isinstance(pretrained_model_name_or_path_or_dict, dict): + state_dict = pretrained_model_name_or_path_or_dict + else: + source = pretrained_model_name_or_path_or_dict + name = weight_name or _best_guess_weight_name(source, ".safetensors", hub_kwargs.get("local_files_only")) + model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) + state_dict = load_state_dict(model_file) + + state_dict = self.map_lora_to_diffusers(state_dict) if not state_dict: model_class_name = self.__class__.__name__ logger.warning( @@ -601,6 +670,7 @@ def _rollback_adapter(self, adapter_name, error): if adapter_name in active_adapter: module.delete_adapter(adapter_name) self.peft_config.pop(adapter_name, None) + logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{error}") @_requires_peft @@ -875,3 +945,7 @@ def 
enable_lora_hotswap( "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." ) self._lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} + + +# Back-compat alias. Old name from the PEFT-only era; prefer ``LoRAModelMixin``. +PeftAdapterMixin = LoRAModelMixin diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 1440bcbb5cd2..62fcad958e11 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -51,81 +51,66 @@ from peft.tuners.tuners_utils import BaseTunerLayer if is_accelerate_available(): - pass + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module logger = logging.get_logger(__name__) -# Constants, fetch helpers, and the offload-disable shim now live in `loaders.lora`. -# Re-exported here for back-compat. -from .lora import ( # noqa: F401 (back-compat re-exports) - LORA_ADAPTER_METADATA_KEY, - LORA_WEIGHT_NAME, - LORA_WEIGHT_NAME_SAFE, - _best_guess_weight_name, - _fetch_lora_metadata, - _fetch_state_dict, - _func_optionally_disable_offloading, -) +def _func_optionally_disable_offloading(_pipeline): + """Optionally remove accelerate offloading hooks before mutating a pipeline's components. -class LoRAMappingMixin: - """ - Base mixin providing utilities for LoRA weight mapping and conversion. + Walks ``_pipeline.components``, detects accelerate / group-offload hooks, and removes + accelerate hooks in-place (group-offload is reapplied later by the LoRA load path). + Returns ``(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)`` so + callers know which offloading mode was active and can re-enable it after loading. 
- Subclasses should define: - - _lora_format_keys: Dict mapping format names to identifying key patterns - - _lora_rename_patterns: Dict mapping format names to rename pattern dicts - - _map_lora_to_diffusers: Staticmethod for LoRA state dict conversion + Used by pipeline-side LoRA loaders (``LoraBaseMixin._optionally_disable_offloading``) + and the legacy paths in ``peft.py`` / ``unet.py``. Model-side loading uses the + ``_offloading_disabled`` context manager in ``loaders.lora`` instead. """ + from ..hooks.group_offloading import _is_group_offload_enabled - _lora_format_keys: dict[str, set[str]] = {} - _lora_rename_patterns: dict[str, dict[str, str]] = {} - _map_lora_to_diffusers = None - - @staticmethod - def _rename_lora_key(key: str, patterns: dict[str, str]) -> str: - """Apply rename patterns to a LoRA key.""" - for old, new in patterns.items(): - key = key.replace(old, new) - return key + is_model_cpu_offload = False + is_sequential_cpu_offload = False + is_group_offload = False - @classmethod - def _detect_lora_format(cls, state_dict: dict) -> str | None: - """ - Detect the LoRA format from state dict keys. - - Returns format name (e.g., 'kohya') or None if unknown. 
- """ - if not cls._lora_format_keys: - return None + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if not isinstance(component, torch.nn.Module): + continue + is_group_offload = is_group_offload or _is_group_offload_enabled(component) + if not hasattr(component, "_hf_hook"): + continue + is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) + is_sequential_cpu_offload = is_sequential_cpu_offload or ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) - keys = set(state_dict.keys()) - for format_name, format_keys in cls._lora_format_keys.items(): - if any(any(fk in k for k in keys) for fk in format_keys): - return format_name + if is_sequential_cpu_offload or is_model_cpu_offload: + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous " + "hooks will be first removed. Then the LoRA parameters will be loaded and the hooks " + "will be applied again." 
+ ) + for _, component in _pipeline.components.items(): + if not isinstance(component, torch.nn.Module) or not hasattr(component, "_hf_hook"): + continue + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - return None + return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) - @classmethod - def _normalize_lora_suffixes(cls, state_dict: dict) -> dict: - """Normalize LoRA suffixes to diffusers format (.lora_A.weight, .lora_B.weight).""" - normalized = {} - for key, value in state_dict.items(): - new_key = key - new_key = new_key.replace(".lora_down.weight", ".lora_A.weight") - new_key = new_key.replace(".lora_up.weight", ".lora_B.weight") - new_key = new_key.replace(".down.weight", ".lora_A.weight") - new_key = new_key.replace(".up.weight", ".lora_B.weight") - normalized[new_key] = value - return normalized - @classmethod - def map_lora_to_diffusers(cls, state_dict: dict, **kwargs) -> dict: - """Normalize LoRA suffixes, then dispatch to the model-specific converter.""" - if cls._map_lora_to_diffusers is None: - raise NotImplementedError(f"{cls.__name__} does not define _map_lora_to_diffusers") - state_dict = cls._normalize_lora_suffixes(state_dict) - return cls._map_lora_to_diffusers(state_dict, **kwargs) +# Constants and fetch helpers live in ``loaders.lora`` — re-exported here for back-compat. 
+from .lora import ( # noqa: E402, F401 (intentional mid-file import: back-compat re-exports) + LORA_ADAPTER_METADATA_KEY, + LORA_WEIGHT_NAME, + LORA_WEIGHT_NAME_SAFE, + _best_guess_weight_name, + _fetch_lora_metadata, + _fetch_state_dict, +) def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a08a9738a185..2da0cfeca4d0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -41,11 +41,8 @@ _pack_dict_with_prefix, ) from .lora_conversion_utils import ( - _convert_bfl_flux_control_lora_to_diffusers, - _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux2_lora_to_diffusers, - _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, @@ -56,7 +53,6 @@ _convert_non_diffusers_qwen_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_non_diffusers_z_image_lora_to_diffusers, - _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) @@ -301,20 +297,14 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, 
cache_dir=cache_dir, force_download=force_download, @@ -323,7 +313,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -748,20 +737,14 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -770,7 +753,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -1037,20 +1019,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( 
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1059,7 +1035,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -1340,20 +1315,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1362,7 +1331,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -1555,20 +1523,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle 
= True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1577,7 +1539,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -1596,45 +1557,18 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. - is_kohya = any(".lora_down.weight" in k for k in state_dict) - if is_kohya: - state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) - # Kohya already takes care of scaling the LoRA parameters with alpha. - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - is_xlabs = any("processor" in k for k in state_dict) - if is_xlabs: - state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) - # xlabs doesn't use `alpha`. 
- return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - - is_bfl_control = any("query_norm.scale" in k for k in state_dict) - if is_bfl_control: - state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) + from ..models.transformers.flux import FluxTransformer2DModel - is_fal_kontext = any("base_model" in k for k in state_dict) - if is_fal_kontext: - state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict) + # Format-specific dispatch lives on the model: detect format (kohya/xlabs/bfl/kontext) + # and convert to diffusers naming. Unknown / diffusers-native state dicts fall + # through to the alpha-extraction path below. + is_recognized_format = FluxTransformer2DModel._detect_lora_format(state_dict) is not None or any( + k.startswith("transformer.") for k in state_dict + ) + if is_recognized_format: + state_dict = FluxTransformer2DModel.map_lora_to_diffusers(state_dict) + # Recognized formats embed alphas in the conversion (kohya scales weights; + # xlabs / bfl / kontext don't use alphas). return cls._prepare_outputs( state_dict, metadata=metadata, @@ -1643,8 +1577,8 @@ def lora_state_dict( return_metadata=return_lora_metadata, ) - # For state dicts like - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA + # Diffusers-native fallback (e.g. https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA): + # alphas ride alongside the weights as separate ``.alpha`` keys. 
keys = list(state_dict.keys()) network_alphas = {} for k in keys: @@ -2489,20 +2423,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2511,7 +2439,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -2695,20 +2622,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2717,7 +2638,6 @@ def lora_state_dict( revision=revision, 
subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -2904,20 +2824,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2926,7 +2840,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -3118,20 +3031,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, 
cache_dir=cache_dir, force_download=force_download, @@ -3140,7 +3047,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -3354,20 +3260,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3376,7 +3276,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -3764,20 +3663,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( 
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3786,7 +3679,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -3977,20 +3869,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3999,7 +3885,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4192,20 +4077,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle 
= True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4214,7 +4093,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4401,20 +4279,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4423,7 +4295,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4685,20 +4556,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = 
kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4707,7 +4572,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4971,20 +4835,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4993,7 +4851,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -5180,20 +5037,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", 
None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5202,7 +5053,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -5393,20 +5243,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5415,7 +5259,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -5609,20 +5452,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = 
kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5631,7 +5468,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -5825,20 +5661,14 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) + kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored return_lora_metadata = kwargs.pop("return_lora_metadata", False) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, - use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5847,7 +5677,6 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - allow_pickle=allow_pickle, ) metadata = _fetch_lora_metadata( 
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index fff6be067f27..170dca27ae10 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2244,8 +2244,6 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): return converted_state_dict - - def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index e0dc3dcfc2fd..1bf93d5cec42 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -75,15 +75,15 @@ def _load_adapter_from_pretrained(self, pretrained_model_name_or_path_or_dict, * See https://huggingface.co/stabilityai/control-lora and https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors. """ - from .lora import _HUB_KWARGS, _fetch_state_dict + from ..utils import HUB_KWARGS + from .lora import _fetch_state_dict # Resolve to a state_dict up-front so we can inspect / convert before the base loader. 
if isinstance(pretrained_model_name_or_path_or_dict, dict): state_dict = pretrained_model_name_or_path_or_dict else: - fetch_kwargs = {k: kwargs.get(k, default) for k, default in _HUB_KWARGS.items()} + fetch_kwargs = {k: kwargs.get(k, default) for k, default in HUB_KWARGS.items()} fetch_kwargs["weight_name"] = kwargs.get("weight_name") - fetch_kwargs["use_safetensors"] = kwargs.get("use_safetensors") state_dict = _fetch_state_dict(pretrained_model_name_or_path_or_dict, **fetch_kwargs) if not any("lora_A" in k for k in state_dict): @@ -110,15 +110,9 @@ def _load_sai_control_lora(self, state_dict, **kwargs): metadata = kwargs.get("metadata") if prefix is not None: - state_dict = { - k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - rank = { - f"^{key}": val.shape[1] - for key, val in state_dict.items() - if "lora_B" in key and val.ndim > 1 - } + rank = {f"^{key}": val.shape[1] for key, val in state_dict.items() if "lora_B" in key and val.ndim > 1} if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas if k.startswith(f"{prefix}.")] diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index fa7d5b001d09..63e3e4a683b2 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -15,38 +15,41 @@ """Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. -A model defines its mapping by subclassing :class:`WeightMappingMixin` and -populating the class attributes (`_rename_patterns`, `_checkpoint_keys`, etc.) -plus assigning ``_map_to_diffusers`` / ``_map_from_diffusers`` callables. 
+A model declares its mapping in a ``WeightMappingMetadata`` instance (typically +in its ``weight_mapping.py`` module) and attaches it via +``@register_model_metadata(weight_mapping=...)``. This mixin supplies the +generic dispatch methods that read from that metadata. The :meth:`apply_transforms` helper drives the forward direction from a single declarative table — see ``models/transformers/flux/weight_mapping.py`` for an example. """ +from typing import Optional + class WeightMappingMixin: """ Base mixin providing utilities for checkpoint weight mapping and conversion. - Subclasses should define: - - _checkpoint_key_prefixes: List of key prefixes to strip (e.g., ["model.diffusion_model."]) - - _checkpoint_keys: Set of keys to identify compatible checkpoints - - _rename_patterns: Dict of substring replacements for key renaming - - _model_variants: Dict mapping variant names to config repos - - _map_to_diffusers: Function to convert original format to diffusers format - - _map_from_diffusers: Function to convert diffusers format to original format + Per-model configuration (rename patterns, format-identifying keys, conversion + callables, etc.) lives in the model's registered ``WeightMappingMetadata`` — + declared in the model's ``weight_mapping.py`` and attached via + ``@register_model_metadata``. This mixin just supplies the dispatch methods. """ - _checkpoint_key_prefixes: list[str] = [] - _checkpoint_keys: set[str] = set() - _rename_patterns: dict[str, str] = {} - _model_variants: dict[str, str] = {} + # Default class-attribute values; populated per-model by ``register_model_metadata``. 
+ _checkpoint_key_prefixes: list = [] + _checkpoint_keys: set = set() + _rename_patterns: dict = {} + _model_variants: dict = {} _map_to_diffusers = None _map_from_diffusers = None + _detect_model_variant_fn = None + _default_subfolder: str = "transformer" @staticmethod - def _rename_key(key: str, patterns: dict[str, str]) -> str: + def _rename_key(key: str, patterns: dict) -> str: """Apply rename patterns to a key.""" for old, new in patterns.items(): key = key.replace(old, new) @@ -57,7 +60,6 @@ def _normalize_checkpoint_keys(cls, state_dict: dict) -> dict: """Strip known prefixes from state_dict keys.""" if not cls._checkpoint_key_prefixes: return state_dict - result = {} for key, value in state_dict.items(): new_key = key @@ -73,13 +75,20 @@ def _is_original_format(cls, state_dict: dict) -> bool: """Check if state_dict is in original (non-diffusers) format.""" if not cls._checkpoint_keys: return False - keys = set(state_dict.keys()) - return bool(cls._checkpoint_keys & keys) + return bool(cls._checkpoint_keys & set(state_dict.keys())) @classmethod - def _detect_model_variant(cls, state_dict: dict) -> str | None: - """Detect which model variant a state_dict belongs to. Subclasses should override.""" - raise NotImplementedError(f"{cls.__name__} does not implement _detect_model_variant") + def _detect_model_variant(cls, state_dict: dict) -> Optional[str]: + """Detect which model variant a state_dict belongs to. + + Dispatches to ``cls._detect_model_variant_fn`` (mirrored from the model's metadata); + raises if no detector is registered. + """ + if cls._detect_model_variant_fn is None: + raise NotImplementedError( + f"{cls.__name__} did not register a `_detect_model_variant_fn` in its WeightMappingMetadata." 
+ ) + return cls._detect_model_variant_fn(cls, state_dict) @classmethod def _get_model_config(cls, state_dict: dict) -> str: @@ -122,12 +131,16 @@ def apply_transforms(state_dict, transforms, rename_patterns, **ctx): def map_to_diffusers(cls, state_dict: dict, **kwargs) -> dict: """Convert state_dict from original format to diffusers format.""" if cls._map_to_diffusers is None: - raise NotImplementedError(f"{cls.__name__} does not define _map_to_diffusers") + raise NotImplementedError( + f"{cls.__name__} did not register a `_map_to_diffusers` in its WeightMappingMetadata." + ) return cls._map_to_diffusers(state_dict, **kwargs) @classmethod def map_from_diffusers(cls, state_dict: dict, **kwargs) -> dict: """Convert state_dict from diffusers format to original format.""" if cls._map_from_diffusers is None: - raise NotImplementedError(f"{cls.__name__} does not define _map_from_diffusers") + raise NotImplementedError( + f"{cls.__name__} did not register a `_map_from_diffusers` in its WeightMappingMetadata." + ) return cls._map_from_diffusers(state_dict, **kwargs) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 2d685c180644..c99016738e1a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,7 @@ import tempfile from collections import OrderedDict from contextlib import ExitStack, contextmanager, nullcontext -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields, is_dataclass from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -39,6 +39,9 @@ from typing_extensions import Self from .. 
import __version__ +from ..configuration_utils import ConfigMixin +from ..loaders.lora import LoRAModelMixin +from ..loaders.weight_mapping import WeightMappingMixin from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -46,6 +49,7 @@ FLASHPACK_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, HF_ENABLE_PARALLEL_LOADING, + HUB_KWARGS, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -230,6 +234,77 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) +@dataclass +class LoRAMetadata: + """Per-model LoRA configuration: what foreign formats this model accepts and how to convert them. + + Field names match the legacy ``cls._`` class attributes consumed by + ``LoRAModelMixin``, so the decorator can mirror them 1:1. + + Attributes: + _lora_format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying + key substrings. The first format whose substrings appear in the state dict wins. + _map_lora_to_diffusers: Callable ``(state_dict, **kwargs) -> state_dict`` that rewrites + foreign-format keys to diffusers naming. Called from + ``LoRAModelMixin.map_lora_to_diffusers`` after generic suffix normalization. + ``None`` for models that only ingest diffusers-native LoRAs. + """ + + _lora_format_keys: Dict[str, set] = field(default_factory=dict) + _map_lora_to_diffusers: Optional[Callable] = None + + +@dataclass +class IPAdapterMetadata: + """Per-model IP-Adapter configuration: how to convert IP-Adapter state dicts for this architecture. + + Field names match the legacy ``cls._`` class attributes consumed by + ``IPAdapterModelMixin``, so the decorator can mirror them 1:1. + + Attributes: + _convert_ip_adapter_attn_to_diffusers: Callable + ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the + attn-processor dict ready for ``set_attn_processor``. 
Receives the model instance because it + needs ``model.attn_processors``, ``model.config``, ``model.inner_dim``, etc. + _convert_ip_adapter_image_proj_to_diffusers: Callable + ``(model, state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image + projection layer. + """ + + _convert_ip_adapter_attn_to_diffusers: Optional[Callable] = None + _convert_ip_adapter_image_proj_to_diffusers: Optional[Callable] = None + + +@dataclass +class WeightMappingMetadata: + """Per-model checkpoint conversion metadata for single-file loading. + + Field names match the legacy ``cls._`` class attributes consumed by + ``WeightMappingMixin``, so the decorator can mirror them 1:1. + + Note: per-key rename tables and checkpoint key prefixes live in the model's + ``weight_mapping.py`` module as plain constants (e.g. ``FLUX_RENAME_PATTERNS``). + They're consumed directly by the model's ``map_to_diffusers`` / ``map_from_diffusers`` + callables and don't need to be threaded through metadata. + + Attributes: + _checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is + in the original (pre-diffusers) format. + _model_variants: Map of variant name to its default config repo on the Hub. + _map_to_diffusers / _map_from_diffusers: Callables driving the two conversion directions. + _detect_model_variant_fn: Optional ``(cls, state_dict) -> Optional[str]`` for picking the + right variant when a single checkpoint format spans multiple architectures. + _default_subfolder: Default ``subfolder`` to use when fetching configs (e.g. ``"transformer"``). 
+ """ + + _checkpoint_keys: set = field(default_factory=set) + _model_variants: Dict[str, str] = field(default_factory=dict) + _map_to_diffusers: Optional[Callable] = None + _map_from_diffusers: Optional[Callable] = None + _detect_model_variant_fn: Optional[Callable] = None + _default_subfolder: str = "transformer" + + @dataclass class ModelMetadata: """ @@ -239,32 +314,83 @@ class ModelMetadata: This is static metadata about the model class's capabilities and hints for optimization features like gradient checkpointing, offloading, and parallelism. + Field names match the legacy ``cls._`` class attributes (so the decorator + mirrors them 1:1 and existing consumer code keeps working). + Attributes: - supports_gradient_checkpointing: Whether the model supports gradient checkpointing + _supports_gradient_checkpointing: Whether the model supports gradient checkpointing for memory-efficient training. - no_split_modules: List of module class names that should NOT be split across + _no_split_modules: List of module class names that should NOT be split across devices during model parallelism. - keep_in_fp32_modules: List of module names to keep in FP32 precision when using + _keep_in_fp32_modules: List of module names to keep in FP32 precision when using lower precision dtypes for numerical stability. - skip_layerwise_casting_patterns: Tuple of module name patterns to exclude from + _skip_layerwise_casting_patterns: Tuple of module name patterns to exclude from layerwise casting operations. - supports_group_offloading: Whether the model supports group offloading. - repeated_blocks: List of module class names that repeat throughout the model, + _supports_group_offloading: Whether the model supports group offloading. + _repeated_blocks: List of module class names that repeat throughout the model, useful for optimization and pattern analysis. 
- cp_plan: Context parallel configuration plan defining how to split model + _cp_plan: Context parallel configuration plan defining how to split model components for context parallelism across devices. - keys_to_ignore_on_load_unexpected: List of keys to ignore when loading + _keys_to_ignore_on_load_unexpected: List of keys to ignore when loading unexpected keys from a checkpoint. + _lora: Per-model LoRA loading metadata. See :class:`LoRAMetadata`. + _weight_mapping: Per-model checkpoint conversion metadata. See :class:`WeightMappingMetadata`. + """ + + _supports_gradient_checkpointing: bool = False + _no_split_modules: Optional[List[str]] = None + _keep_in_fp32_modules: Optional[List[str]] = None + _skip_layerwise_casting_patterns: Optional[Tuple[str, ...]] = None + _supports_group_offloading: bool = True + _repeated_blocks: List[str] = field(default_factory=list) + _cp_plan: Optional[Dict[str, Any]] = None + _keys_to_ignore_on_load_unexpected: Optional[List[str]] = None + _lora: LoRAMetadata = field(default_factory=LoRAMetadata) + _ip_adapter: IPAdapterMetadata = field(default_factory=IPAdapterMetadata) + _weight_mapping: WeightMappingMetadata = field(default_factory=WeightMappingMetadata) + + def _register(self, cls): + """Attach this ``ModelMetadata`` to ``cls`` and mirror leaf fields to legacy class attrs. + + Walks nested dataclasses (``_lora``, ``_weight_mapping``, ``_ip_adapter``) so their + leaf fields land flat on ``cls``. Field names already starting with ``_`` map 1:1 + (``_lora_format_keys`` → ``cls._lora_format_keys``); unprefixed names get the + underscore added (``rename_patterns`` → ``cls._rename_patterns``). 
+ """ + cls._model_metadata = self + pending = [self] + while pending: + obj = pending.pop() + for f in fields(obj): + value = getattr(obj, f.name) + if is_dataclass(value): + pending.append(value) + else: + attr = f.name if f.name.startswith("_") else f"_{f.name}" + setattr(cls, attr, value) + + +def register_metadata(metadata): + """Generic class decorator that attaches metadata to the decorated class. + + Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own + attachment logic. Works for both model-level metadata (``ModelMetadata``) and + block-level metadata (``TransformerBlockMetadata``):: + + @register_metadata(FLUX_MODEL_METADATA) + class FluxTransformer2DModel(...): + ... + + @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) + class FluxTransformerBlock(nn.Module): + ... """ - supports_gradient_checkpointing: bool = False - no_split_modules: Optional[List[str]] = None - keep_in_fp32_modules: Optional[List[str]] = None - skip_layerwise_casting_patterns: Optional[Tuple[str, ...]] = None - supports_group_offloading: bool = True - repeated_blocks: List[str] = field(default_factory=list) - cp_plan: Optional[Dict[str, Any]] = None - keys_to_ignore_on_load_unexpected: Optional[List[str]] = None + def wrap(cls): + metadata._register(cls) + return cls + + return wrap def _should_convert_checkpoint(model_state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> bool: @@ -276,7 +402,7 @@ def _should_convert_checkpoint(model_state_dict: Dict[str, Any], checkpoint: Dic return not (is_subset and is_match) -class ModelMixin(torch.nn.Module, PushToHubMixin): +class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixin, PushToHubMixin): r""" Base class for all models. @@ -311,50 +437,59 @@ def get_metadata(cls) -> "ModelMetadata": """ if cls._model_metadata is not None: return cls._model_metadata + # Fallback for unmigrated models: build from the legacy per-attribute class vars. 
+ # ``lora`` / ``weight_mapping`` read from old mixin attrs if present, else stay empty. return ModelMetadata( - supports_gradient_checkpointing=cls._supports_gradient_checkpointing, - no_split_modules=cls._no_split_modules, - keep_in_fp32_modules=cls._keep_in_fp32_modules, - skip_layerwise_casting_patterns=cls._skip_layerwise_casting_patterns, - supports_group_offloading=cls._supports_group_offloading, - repeated_blocks=cls._repeated_blocks if cls._repeated_blocks else [], - cp_plan=cls._cp_plan, - keys_to_ignore_on_load_unexpected=cls._keys_to_ignore_on_load_unexpected, + _supports_gradient_checkpointing=cls._supports_gradient_checkpointing, + _no_split_modules=cls._no_split_modules, + _keep_in_fp32_modules=cls._keep_in_fp32_modules, + _skip_layerwise_casting_patterns=cls._skip_layerwise_casting_patterns, + _supports_group_offloading=cls._supports_group_offloading, + _repeated_blocks=cls._repeated_blocks if cls._repeated_blocks else [], + _cp_plan=cls._cp_plan, + _keys_to_ignore_on_load_unexpected=cls._keys_to_ignore_on_load_unexpected, + _lora=LoRAMetadata( + _lora_format_keys=getattr(cls, "_lora_format_keys", None) or {}, + _map_lora_to_diffusers=getattr(cls, "_map_lora_to_diffusers", None), + ), + _weight_mapping=WeightMappingMetadata( + _checkpoint_keys=getattr(cls, "_checkpoint_keys", None) or set(), + _model_variants=getattr(cls, "_model_variants", None) or {}, + _map_to_diffusers=getattr(cls, "_map_to_diffusers", None), + _map_from_diffusers=getattr(cls, "_map_from_diffusers", None), + _detect_model_variant_fn=getattr(cls, "_detect_model_variant_fn", None), + _default_subfolder=getattr(cls, "_default_subfolder", "transformer"), + ), + _ip_adapter=IPAdapterMetadata( + _convert_ip_adapter_attn_to_diffusers=getattr(cls, "_convert_ip_adapter_attn_to_diffusers", None), + _convert_ip_adapter_image_proj_to_diffusers=getattr( + cls, "_convert_ip_adapter_image_proj_to_diffusers", None + ), + ), ) @classmethod def _maybe_convert_state_dict(cls, model: "ModelMixin", 
state_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Convert state dict from original format to diffusers format if needed. - - This method checks if the state dict keys match the model's expected keys. - If not, it applies normalization and conversion using the model's - `_normalize_checkpoint_keys` and `map_to_diffusers` methods. + """Convert ``state_dict`` from original format to diffusers format if needed. - Args: - model: The model instance to compare against. - state_dict: The loaded state dict. + Two phases, both declared via the model's :class:`WeightMappingMetadata`: - Returns: - The state dict, potentially converted to diffusers format. + 1. ``_normalize_checkpoint_keys`` — strip known prefixes (e.g. ``model.diffusion_model.``). + Run unconditionally; idempotent and a no-op if no prefixes were registered. + 2. ``_map_to_diffusers`` — the actual format converter, only invoked if step 1 alone + didn't already make the keys match. Skipped if no converter was registered (loading + then fails downstream with a clearer key-mismatch error than a deep + ``NotImplementedError``). """ - model_state_dict = model.state_dict() - - if not _should_convert_checkpoint(model_state_dict, state_dict): - return state_dict - - normalize_fn = getattr(cls, "_normalize_checkpoint_keys", None) - if normalize_fn is not None: - state_dict = normalize_fn(state_dict) - - if not _should_convert_checkpoint(model_state_dict, state_dict): + # Step 1: always strip checkpoint key prefixes — idempotent, no-op if none registered. + state_dict = cls._normalize_checkpoint_keys(state_dict) + if not _should_convert_checkpoint(model.state_dict(), state_dict): return state_dict - map_to_diffusers_fn = getattr(cls, "map_to_diffusers", None) - if map_to_diffusers_fn is None: + # Step 2: run the per-model converter. Skip if no metadata-registered converter. 
+ if getattr(cls, "_map_to_diffusers", None) is None: return state_dict - - return map_to_diffusers_fn(state_dict) + return cls.map_to_diffusers(state_dict) def __init__(self): super().__init__() @@ -1615,14 +1750,16 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = load_single_file_checkpoint, ) - map_to_diffusers_fn = getattr(cls, "map_to_diffusers", None) - default_subfolder = getattr(cls, "_default_subfolder", None) - - if map_to_diffusers_fn is None: + # `map_to_diffusers` is inherited universally via ``WeightMappingMixin``-via-``ModelMixin``; + # the genuine "is this single-file capable?" signal is whether the model registered a + # ``_map_to_diffusers`` callable in its ``WeightMappingMetadata``. + if getattr(cls, "_map_to_diffusers", None) is None: raise ValueError( f"{cls.__name__} does not support `from_single_file`. " - f"Please ensure the model class defines `map_to_diffusers`." + f"Register a `_map_to_diffusers` callable in its `WeightMappingMetadata` " + f"(or use `from_pretrained` if the model is in diffusers format)." ) + default_subfolder = getattr(cls, "_default_subfolder", None) pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) if pretrained_model_link_or_path is not None: @@ -1632,19 +1769,16 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path + # Hub-download kwargs (cache_dir / force_download / proxies / local_files_only / token / + # revision / subfolder) consolidated via the canonical ``HUB_KWARGS`` defaults. 
+ hub_kwargs = {k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()} + config = kwargs.pop("config", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - subfolder = kwargs.pop("subfolder", None) - revision = kwargs.pop("revision", None) config_revision = kwargs.pop("config_revision", None) torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - device = kwargs.pop("device", None) + kwargs.pop("device", None) # consumed elsewhere; pop to prevent forwarding disable_mmap = kwargs.pop("disable_mmap", False) device_map = kwargs.pop("device_map", None) @@ -1659,24 +1793,20 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) if isinstance(pretrained_model_link_or_path_or_dict, dict): - checkpoint = pretrained_model_link_or_path_or_dict + state_dict = pretrained_model_link_or_path_or_dict else: - checkpoint = load_single_file_checkpoint( + # ``load_single_file_checkpoint`` takes everything in ``HUB_KWARGS`` except ``subfolder``. 
+ state_dict = load_single_file_checkpoint( pretrained_model_link_or_path_or_dict, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, disable_mmap=disable_mmap, user_agent=user_agent, + **{k: v for k, v in hub_kwargs.items() if k != "subfolder"}, ) - # Normalize checkpoint keys (strip known prefixes) if the model defines a normalizer + # Normalize state_dict keys (strip known prefixes) if the model defines a normalizer normalize_fn = getattr(cls, "_normalize_checkpoint_keys", None) if normalize_fn is not None: - checkpoint = normalize_fn(checkpoint) + state_dict = normalize_fn(state_dict) if quantization_config is not None: hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) @@ -1700,17 +1830,16 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = f"{cls.__name__} does not support automatic config detection. " f"Please provide a `config` argument or define `_get_model_config` on the model class." ) - default_pretrained_model_config_name = get_model_config_fn(checkpoint) + default_pretrained_model_config_name = get_model_config_fn(state_dict) if default_subfolder is not None: - subfolder = default_subfolder + hub_kwargs["subfolder"] = default_subfolder + # ``load_config`` consumes the hub-download kwargs; ``config_revision`` always replaces the file + # ``revision`` for the config repo, even when it is ``None`` (no fallback to the file ``revision``). 
diffusers_model_config = cls.load_config( pretrained_model_name_or_path=default_pretrained_model_config_name, - subfolder=subfolder, - local_files_only=local_files_only, - token=token, - revision=config_revision, + **{**hub_kwargs, "revision": config_revision}, ) expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} @@ -1738,15 +1867,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = else: keep_in_fp32_modules = [] - if _should_convert_checkpoint(model_state_dict, checkpoint): - checkpoint = map_to_diffusers_fn(checkpoint) + if _should_convert_checkpoint(model_state_dict, state_dict): + state_dict = cls.map_to_diffusers(state_dict) - if not checkpoint: + if not state_dict: raise SingleFileComponentError( f"Failed to load {cls.__name__}. Weights for this component appear to be missing in the checkpoint." ) - loaded_keys = list(checkpoint.keys()) + loaded_keys = list(state_dict.keys()) if hf_quantizer is not None: hf_quantizer.preprocess_model( @@ -1766,7 +1895,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = error_msgs, ) = cls._load_pretrained_model( model, - checkpoint, + state_dict, None, None, loaded_keys, diff --git a/src/diffusers/models/transformers/flux/__init__.py b/src/diffusers/models/transformers/flux/__init__.py index 2996567d82d1..f7b77eee6486 100644 --- a/src/diffusers/models/transformers/flux/__init__.py +++ b/src/diffusers/models/transformers/flux/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .lora import FluxTransformerLoRAMixin +from .ip_adapter import FLUX_IP_ADAPTER_METADATA +from .lora import FLUX_LORA_METADATA from .model import ( FluxAttention, FluxAttnProcessor, @@ -22,4 +23,4 @@ FluxTransformer2DModel, FluxTransformerBlock, ) -from .weight_mapping import FluxTransformerWeightMappingMixin +from .weight_mapping import FLUX_WEIGHT_MAPPING_METADATA diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py index 0a79f77377ff..c33140a61d1a 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/lora.py @@ -14,29 +14,31 @@ """Flux LoRA conversion. -Pipeline: - 1. Detect the source format (kohya / xlabs / bfl / kontext) via - ``FluxLoRAMappingMixin._detect_lora_format``. - 2. Run the format-specific *normalizer* that rewrites keys into the - canonical "BFL-style" form: original Flux module names with - ``.lora_A`` / ``.lora_B`` suffixes. Tensor-level transforms unique - to a format (e.g., Kohya alpha scaling) happen here. - 3. Run the shared *converter* that maps canonical keys to diffusers - names by reusing the rename / QKV-split / special-key tables in - ``weight_mapping.py``, applying the LoRA-specific QKV semantics - (lora_A weight replicates, everything else chunks). - -A normalizer may also emit keys that bypass step 3 (returned as a second -"extras" dict) — used for keys that don't fit the canonical intermediate -(e.g., text-encoder LoRA keys, XLabs single-block QKV without MLP). 
+Each supported foreign format has a top-level ``map_<format>_to_diffusers`` entry point: + + - :func:`map_bfl_to_diffusers` — original BFL repo layout + - :func:`map_kontext_to_diffusers` — fal Kontext checkpoints (BFL + ``base_model.model.`` prefix) + - :func:`map_xlabs_to_diffusers` — XLabs ``.processor.qkv_lora`` / ``.processor.proj_lora`` shape + - :func:`map_kohya_to_diffusers` — kohya sd-scripts (and "mixture" / ``lora_transformer_*`` variants) + +Each entry point produces a state dict with diffusers naming. Internally they all +funnel through :func:`_map_to_diffusers`, which converts a BFL-style state dict +(original Flux module names + ``.lora_A``/``.lora_B`` suffixes) to diffusers names by +reusing the rename / QKV-split / special-key tables in ``weight_mapping.py`` and applying +LoRA-specific QKV semantics (``lora_A.weight`` replicates across heads; everything else +chunks). + +A format-specific converter may also emit pre-converted diffusers keys directly when a +key shape doesn't fit the canonical intermediate (e.g., XLabs single-block QKV without +a paired MLP LoRA). """ import re import torch -from ....loaders.lora_base import LoRAMappingMixin from ....utils import logging, state_dict_all_zero +from ...modeling_utils import LoRAMetadata from .weight_mapping import ( FLUX_QKV_SPLIT_PATTERNS, FLUX_QKVMLP_SPLIT_PATTERN, @@ -50,11 +52,11 @@ # ============================================================================ -# Stage 3: shared canonical -> diffusers converter +# Shared canonical -> diffusers converter # ============================================================================ # Canonical keys are BFL-style: original Flux module names + .lora_A/.lora_B -# suffixes. The shared converter handles three cases — pure renames, QKV -# splits, special transforms — by reusing the tables from weight_mapping. +# suffixes. The shared converter handles three cases — pure renames, QKV splits, +# special transforms — by reusing the tables from weight_mapping. 
_LORA_SUFFIXES = (".lora_A.weight", ".lora_A.bias", ".lora_B.weight", ".lora_B.bias") @@ -70,37 +72,30 @@ break -def _split_lora_suffix(key): - for suffix in _LORA_SUFFIXES: - if key.endswith(suffix): - return key[: -len(suffix)], suffix - return key, "" - - def _apply_renames(s, patterns): for old, new in patterns.items(): s = s.replace(old, new) return s -def _rename_module(module_path): - # FLUX_RENAME_PATTERNS keys often have a trailing "."; pad-and-strip so - # bare module paths like "final_layer.linear" still match patterns like - # "final_layer.linear.". - out = _apply_renames(module_path + ".", FLUX_RENAME_PATTERNS) - return out[:-1] if out.endswith(".") else out - - -def _map_lora_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0): +def _map_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0): + """Convert a BFL-style canonical LoRA state dict to diffusers naming.""" out = {} qkvmlp_dims = (inner_dim, inner_dim, inner_dim, int(inner_dim * mlp_ratio)) for key, value in state_dict.items(): - module_path, suffix = _split_lora_suffix(key) - + # Split off the .lora_A/.lora_B suffix; non-LoRA keys pass through with renames. + suffix = next((s for s in _LORA_SUFFIXES if key.endswith(s)), "") if not suffix: out[f"transformer.{_apply_renames(key, FLUX_RENAME_PATTERNS)}"] = value continue + module_path = key[: -len(suffix)] + + # FLUX_RENAME_PATTERNS keys often end with "."; pad-and-strip so bare module paths + # like "final_layer.linear" still match patterns like "final_layer.linear.". 
+ def _rename(path): + renamed = _apply_renames(path + ".", FLUX_RENAME_PATTERNS) + return renamed[:-1] if renamed.endswith(".") else renamed qkv = next(((p, ts) for p, ts in _QKV_PATTERNS.items() if p in module_path), None) if qkv is not None: @@ -109,8 +104,7 @@ def _map_lora_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0): [value] * len(targets) if suffix == ".lora_A.weight" else list(torch.chunk(value, len(targets), dim=0)) ) for target, chunk in zip(targets, chunks): - new_module = _rename_module(module_path.replace(pattern, target)) - out[f"transformer.{new_module}{suffix}"] = chunk + out[f"transformer.{_rename(module_path.replace(pattern, target))}{suffix}"] = chunk continue if _QKVMLP_PATTERN in module_path and "single_blocks." in module_path: @@ -120,8 +114,7 @@ def _map_lora_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0): else list(torch.split(value, qkvmlp_dims, dim=0)) ) for target, chunk in zip(_QKVMLP_TARGETS, chunks): - new_module = _rename_module(module_path.replace(_QKVMLP_PATTERN, target)) - out[f"transformer.{new_module}{suffix}"] = chunk + out[f"transformer.{_rename(module_path.replace(_QKVMLP_PATTERN, target))}{suffix}"] = chunk continue if module_path in _SPECIAL_MODULES: @@ -129,39 +122,41 @@ def _map_lora_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0): out[f"transformer.{target_module}{suffix}"] = transform(value) continue - out[f"transformer.{_rename_module(module_path)}{suffix}"] = value + out[f"transformer.{_rename(module_path)}{suffix}"] = value return out # ============================================================================ -# Stage 2a: BFL normalizer (identity) +# BFL — identity (canonical form is BFL-style) # ============================================================================ -def _normalize_bfl(state_dict): - return dict(state_dict), {} +def map_bfl_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from BFL format to diffusers naming.""" + return 
_map_to_diffusers(dict(state_dict)) # ============================================================================ -# Stage 2b: fal Kontext normalizer (strip "base_model.model." prefix) +# fal Kontext — BFL with ``base_model.model.`` prefix # ============================================================================ -def _normalize_kontext(state_dict): +def map_kontext_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from fal Kontext format to diffusers naming.""" prefix = "base_model.model." canonical = {(k[len(prefix) :] if k.startswith(prefix) else k): v for k, v in state_dict.items()} - return canonical, {} + return _map_to_diffusers(canonical) # ============================================================================ -# Stage 2c: XLabs normalizer +# XLabs # ============================================================================ # XLabs key shape: [diffusion_model.]{double|single}_blocks.{i}.processor.{X}.{down|up}.weight -# Double-block X ∈ {qkv_lora1, qkv_lora2, proj_lora1, proj_lora2}. -# Single-block X ∈ {qkv_lora, proj_lora}. Single-block lacks an MLP LoRA, so -# its qkv keys can't be expressed as canonical "linear1" (which is QKV+MLP); -# we emit pre-converted diffusers keys for single blocks instead. +# Double-block X ∈ {qkv_lora1, qkv_lora2, proj_lora1, proj_lora2} — renameable to canonical +# BFL form. Single-block X ∈ {qkv_lora, proj_lora} — single blocks lack an MLP LoRA, so +# qkv keys can't be expressed as canonical "linear1" (QKV+MLP fused); we emit pre-converted +# diffusers keys for single-block extras and route only double-block keys through canonical. 
_XLABS_DOUBLE_RENAMES = { ".processor.proj_lora1.": ".img_attn.proj.", @@ -172,13 +167,12 @@ def _normalize_kontext(state_dict): _XLABS_SINGLE_QKV_TARGETS = ["attn.to_q", "attn.to_k", "attn.to_v"] -def _normalize_xlabs(state_dict): +def map_xlabs_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from XLabs format to diffusers naming.""" canonical = {} extras = {} for key, value in state_dict.items(): - k = key - if k.startswith("diffusion_model."): - k = k[len("diffusion_model.") :] + k = key.removeprefix("diffusion_model.") if "single_blocks." in k: block = re.search(r"single_blocks\.(\d+)", k).group(1) @@ -187,12 +181,13 @@ def _normalize_xlabs(state_dict): if "proj_lora" in k: extras[f"{base}.proj_out{suffix}"] = value elif "qkv_lora" in k: - if suffix == ".lora_A.weight": - for t in _XLABS_SINGLE_QKV_TARGETS: - extras[f"{base}.{t}{suffix}"] = value - else: - for t, chunk in zip(_XLABS_SINGLE_QKV_TARGETS, torch.chunk(value, 3, dim=0)): - extras[f"{base}.{t}{suffix}"] = chunk + chunks = ( + [value] * len(_XLABS_SINGLE_QKV_TARGETS) + if suffix == ".lora_A.weight" + else list(torch.chunk(value, 3, dim=0)) + ) + for t, chunk in zip(_XLABS_SINGLE_QKV_TARGETS, chunks): + extras[f"{base}.{t}{suffix}"] = chunk continue # Double block: rename to canonical BFL-style; shared converter handles the QKV split. @@ -200,18 +195,23 @@ def _normalize_xlabs(state_dict): k = k.replace(old, new) canonical[k] = value - return canonical, extras + converted = _map_to_diffusers(canonical) if canonical else {} + return {**converted, **extras} # ============================================================================ -# Stage 2d: Kohya normalizer (sd-scripts and mixture variants) +# Kohya (sd-scripts + "mixture" variant) # ============================================================================ -# Kohya keys collapse all dots into underscores in the module path, then append -# .lora_down/.lora_up/.alpha. 
We invert this with explicit per-suffix tables +# Kohya keys collapse dots into underscores in the module path, then append +# .lora_down/.lora_up/.alpha. We invert this with a single explicit suffix table # (the original-name underscore <-> dot mapping isn't recoverable by rule), then # apply alpha-driven scaling so canonical tensors are pre-scaled. -_KOHYA_DOUBLE_SUFFIXES = { +# Kohya stub-suffix → BFL form. Block stubs (``double_blocks_{i}_<suffix>`` and +# ``single_blocks_{i}_<suffix>``) look up just the trailing <suffix> here; everything +# else is a global stub that maps directly. No overlap between contexts. +_KOHYA_TO_BFL = { + # double_blocks_{i}_<suffix> "img_attn_proj": "img_attn.proj", "img_attn_qkv": "img_attn.qkv", "img_mlp_0": "img_mlp.0", @@ -222,13 +222,11 @@ def _normalize_xlabs(state_dict): "txt_mlp_0": "txt_mlp.0", "txt_mlp_2": "txt_mlp.2", "txt_mod_lin": "txt_mod.lin", -} -_KOHYA_SINGLE_SUFFIXES = { + # single_blocks_{i}_<suffix> "linear1": "linear1", "linear2": "linear2", "modulation_lin": "modulation.lin", -} -_KOHYA_GLOBAL_SUFFIXES = { + # Global stubs (used directly as canonical path) "guidance_in_in_layer": "guidance_in.in_layer", "guidance_in_out_layer": "guidance_in.out_layer", "img_in": "img_in", @@ -252,18 +250,67 @@ def _kohya_scale(alpha, rank): return down, up -def _custom_replace(key, substrings): - """Replace dots with underscores in `key` up to the first occurrence of any substring.""" - pattern = "(" + "|".join(re.escape(s) for s in substrings) + ")" - match = re.search(pattern, key) - if not match: - return key.replace(".", "_") - boundary = match.start() - 1 if match.start() > 0 and key[match.start() - 1] == "." 
else match.start() - return key[:boundary].replace(".", "_") + key[boundary:] +def _kohya_mixture_to_diffusers(state_dict): + """Convert Kohya mixture-format (``lora_transformer_*`` keys) directly to diffusers naming.""" + out = {} + unique = { + k.replace(".lora_A.weight", "").replace(".lora_B.weight", "").replace(".alpha", "") + for k in state_dict + if k.startswith("lora_transformer_") + } + + for k in unique: + if k.startswith("lora_transformer_single_transformer_blocks_"): + i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0]) + diffusers_key = f"single_transformer_blocks.{i}" + elif k.startswith("lora_transformer_transformer_blocks_"): + i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0]) + diffusers_key = f"transformer_blocks.{i}" + elif k.startswith("lora_transformer_context_embedder"): + diffusers_key = "context_embedder" + elif k.startswith("lora_transformer_norm_out_linear"): + diffusers_key = "norm_out.linear" + elif k.startswith("lora_transformer_proj_out"): + diffusers_key = "proj_out" + elif k.startswith("lora_transformer_x_embedder"): + diffusers_key = "x_embedder" + elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.text_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}" + else: + raise NotImplementedError(f"Handling for key ({k}) is not implemented.") + + if "attn_" in k: + tail = k.split("attn_")[-1] + if "_to_out_0" in k: + 
diffusers_key += ".attn.to_out.0" + elif "_to_add_out" in k: + diffusers_key += ".attn.to_add_out" + elif any(qkv in k for qkv in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj")): + diffusers_key += f".attn.{tail}" + + down = state_dict.pop(f"{k}.lora_A.weight") + up = state_dict.pop(f"{k}.lora_B.weight") + alpha = state_dict.pop(f"{k}.alpha") + d_scale, u_scale = _kohya_scale(alpha, down.shape[0]) + out[f"transformer.{diffusers_key}.lora_A.weight"] = down * d_scale + out[f"transformer.{diffusers_key}.lora_B.weight"] = up * u_scale + + leftover = [k for k in state_dict if not k.startswith("lora_unet_")] + if leftover: + logger.warning(f"Unsupported mixture keys ignored: {leftover}") + return out -def _kohya_pre_filter(state_dict): - """Drop Kohya keys we don't support (with logging), then normalize key prefixes.""" +def map_kohya_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from Kohya format (sd-scripts or mixture) to diffusers naming.""" + # ---- Pre-filter: rename prefix, drop unsupported keys, collapse leading dots. ---- state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()} drop_specs = [ @@ -274,57 +321,38 @@ def _kohya_pre_filter(state_dict): for predicate, marker, label in drop_specs: if not any(predicate(k) for k in state_dict): continue - if state_dict_all_zero(state_dict, marker): - logger.info( - f"The `{label}` LoRA params are all zeros which make them ineffective. " - "So, we will purge them out of the current state dict to make loading possible." - ) - else: - logger.info( - f"`{label}` keys found in the state dict are currently unsupported and will be filtered out. " - "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." - ) + msg = ( + f"The `{label}` LoRA params are all zeros which make them ineffective. So, we will purge them out of " + "the current state dict to make loading possible." 
+ if state_dict_all_zero(state_dict, marker) + else f"`{label}` keys found in the state dict are currently unsupported and will be filtered out. " + "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." + ) + logger.info(msg) state_dict = {k: v for k, v in state_dict.items() if not predicate(k)} # Some keys come with dots in the prefix; collapse them up to lora_A/lora_B/alpha. - limit = ["lora_A", "lora_B"] - if any("alpha" in k for k in state_dict): - limit.append("alpha") - state_dict = {_custom_replace(k, limit): v for k, v in state_dict.items() if k.startswith("lora_unet_")} - - return state_dict - + limit = ["lora_A", "lora_B"] + (["alpha"] if any("alpha" in k for k in state_dict) else []) + boundary_re = re.compile("(" + "|".join(re.escape(s) for s in limit) + ")") -def _kohya_canonical_path(stub): - """Map a Kohya stub like 'double_blocks_0_img_attn_qkv' to BFL-style 'double_blocks.0.img_attn.qkv'.""" - m = re.match(r"double_blocks_(\d+)_(.+)$", stub) - if m: - i, suffix = m.group(1), m.group(2) - bfl = _KOHYA_DOUBLE_SUFFIXES.get(suffix) - return f"double_blocks.{i}.{bfl}" if bfl else None + def _collapse_prefix(key): + match = boundary_re.search(key) + if not match: + return key.replace(".", "_") + i = match.start() + boundary = i - 1 if i > 0 and key[i - 1] == "." else i + return key[:boundary].replace(".", "_") + key[boundary:] - m = re.match(r"single_blocks_(\d+)_(.+)$", stub) - if m: - i, suffix = m.group(1), m.group(2) - bfl = _KOHYA_SINGLE_SUFFIXES.get(suffix) - return f"single_blocks.{i}.{bfl}" if bfl else None + state_dict = {_collapse_prefix(k): v for k, v in state_dict.items() if k.startswith("lora_unet_")} - return _KOHYA_GLOBAL_SUFFIXES.get(stub) - - -def _normalize_kohya(state_dict): - state_dict = _kohya_pre_filter(state_dict) - - # Mixture variant has its own prefix (lora_transformer_*); dispatch separately. - has_mixture = any( + # ---- Mixture variant has its own prefix; route to its direct converter. 
---- + if any( k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict - ) - if has_mixture: - return {}, _convert_mixture(state_dict) + ): + return _kohya_mixture_to_diffusers(state_dict) - # Group keys per Kohya module (lora_unet_) so we can apply alpha - # scaling, then rewrite to canonical names. - groups = {} # stub -> {"lora_A": full_key, "lora_B": ..., "alpha": ...} + # ---- sd-scripts variant: group by stub, apply alpha scaling, rewrite to canonical. ---- + groups = {} # stub -> {"lora_A": full_key, "lora_B": full_key, "alpha": full_key} for key in list(state_dict): if not key.startswith("lora_unet_"): continue @@ -336,7 +364,6 @@ def _normalize_kohya(state_dict): break canonical = {} - for stub, group in groups.items(): down_key, up_key = group.get("lora_A"), group.get("lora_B") if down_key is None or up_key is None: @@ -347,7 +374,18 @@ def _normalize_kohya(state_dict): down = state_dict.pop(down_key) * d_scale up = state_dict.pop(up_key) * u_scale - bfl = _kohya_canonical_path(stub) + # Map kohya stub → BFL canonical path. Block stubs strip their "{kind}_blocks_{i}_" + # prefix and look up the trailing suffix; global stubs map directly. 
+ bfl = None + for kind in ("double_blocks", "single_blocks"): + m = re.match(rf"{kind}_(\d+)_(.+)$", stub) + if m: + suffix = _KOHYA_TO_BFL.get(m.group(2)) + bfl = f"{kind}.{m.group(1)}.{suffix}" if suffix else None + break + else: + bfl = _KOHYA_TO_BFL.get(stub) + if bfl is None: logger.warning(f"Unsupported Kohya key: lora_unet_{stub}") continue @@ -357,120 +395,49 @@ def _normalize_kohya(state_dict): if state_dict: logger.warning(f"Unsupported keys after Kohya normalization: {list(state_dict.keys())}") - return canonical, {} - - -# ---------------------------------------------------------------------------- -# Mixture variant (Kohya-trained but using lora_transformer_* keys) -# ---------------------------------------------------------------------------- - - -def _convert_mixture(state_dict): - """Convert Kohya mixture-format LoRA directly to diffusers keys.""" - new_state_dict = {} - - def emit(orig, diffusers_key): - down = state_dict.pop(f"{orig}.lora_A.weight") - up = state_dict.pop(f"{orig}.lora_B.weight") - alpha = state_dict.pop(f"{orig}.alpha") - rank = down.shape[0] - d_scale, u_scale = _kohya_scale(alpha, rank) - new_state_dict[f"{diffusers_key}.lora_A.weight"] = down * d_scale - new_state_dict[f"{diffusers_key}.lora_B.weight"] = up * u_scale - - unique = { - k.replace(".lora_A.weight", "").replace(".lora_B.weight", "").replace(".alpha", "") - for k in state_dict - if k.startswith("lora_transformer_") - } - - for k in unique: - if k.startswith("lora_transformer_single_transformer_blocks_"): - i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0]) - diffusers_key = f"single_transformer_blocks.{i}" - elif k.startswith("lora_transformer_transformer_blocks_"): - i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0]) - diffusers_key = f"transformer_blocks.{i}" - elif k.startswith("lora_transformer_context_embedder"): - diffusers_key = "context_embedder" - elif k.startswith("lora_transformer_norm_out_linear"): - 
diffusers_key = "norm_out.linear" - elif k.startswith("lora_transformer_proj_out"): - diffusers_key = "proj_out" - elif k.startswith("lora_transformer_x_embedder"): - diffusers_key = "x_embedder" - elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"): - i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1]) - diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}" - elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"): - i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1]) - diffusers_key = f"time_text_embed.text_embedder.linear_{i}" - elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"): - i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1]) - diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}" - else: - raise NotImplementedError(f"Handling for key ({k}) is not implemented.") - - if "attn_" in k: - tail = k.split("attn_")[-1] - if "_to_out_0" in k: - diffusers_key += ".attn.to_out.0" - elif "_to_add_out" in k: - diffusers_key += ".attn.to_add_out" - elif any(qkv in k for qkv in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj")): - diffusers_key += f".attn.{tail}" - - emit(k, diffusers_key) - - leftover = [k for k in state_dict if not k.startswith("lora_unet_")] - if leftover: - logger.warning(f"Unsupported mixture keys ignored: {leftover}") - - return {f"transformer.{k}": v for k, v in new_state_dict.items()} + return _map_to_diffusers(canonical) # ============================================================================ # Top-level dispatch # ============================================================================ +# Per-format identifying key substrings. Single source of truth — also exported +# via ``FLUX_LORA_METADATA`` so ``LoRAModelMixin._detect_lora_format`` finds it. 
+ +_FLUX_LORA_FORMAT_KEYS: dict[str, set[str]] = { + "kohya": {"lora_unet_double_blocks_", "lora_unet_single_blocks_"}, + "xlabs": {".processor.qkv_lora", ".processor.proj_lora"}, + "bfl": {"time_in.in_layer.lora_A", "double_blocks.0.img_mod.lin.lora_A"}, + "kontext": {"base_model.model.double_blocks"}, +} - -_NORMALIZERS = { - "bfl": _normalize_bfl, - "kontext": _normalize_kontext, - "xlabs": _normalize_xlabs, - "kohya": _normalize_kohya, +_FORMAT_DISPATCH = { + "bfl": map_bfl_to_diffusers, + "kontext": map_kontext_to_diffusers, + "xlabs": map_xlabs_to_diffusers, + "kohya": map_kohya_to_diffusers, } def map_lora_to_diffusers(state_dict, **kwargs): - """Convert a Flux LoRA state_dict from any supported format to diffusers naming. + """Detect a Flux LoRA's source format and dispatch to its per-format converter. - Suffix normalization (lora_down/up -> lora_A/B) is run by - ``LoRAMappingMixin.map_lora_to_diffusers`` before this is dispatched. + Already-converted (peft) state dicts pass through after filtering to ``transformer.*`` + keys. Unknown formats (incl. diffusers-native LoRAs with raw ``.alpha`` keys) pass + through unchanged so the pipeline's diffusers-native fallback can run. """ - # Already-converted (peft) state dicts: keep only the transformer.* keys. if any(k.startswith("transformer.") for k in state_dict): return {k: v for k, v in state_dict.items() if k.startswith("transformer.")} - fmt = FluxTransformerLoRAMixin._detect_lora_format(state_dict) - if fmt is None or fmt not in _NORMALIZERS: - raise ValueError( - f"Unable to determine format of LoRA weights. 
Supported formats are: {FluxTransformerLoRAMixin._lora_format_keys.keys()}" - ) - - canonical, extras = _NORMALIZERS[fmt](state_dict) - converted = _map_lora_to_diffusers(canonical) if canonical else {} - return {**converted, **extras} - - -class FluxTransformerLoRAMixin(LoRAMappingMixin): - """Mixin providing Flux-specific LoRA format detection and conversion.""" + keys = set(state_dict) + for fmt, fmt_keys in _FLUX_LORA_FORMAT_KEYS.items(): + if any(any(fk in k for k in keys) for fk in fmt_keys): + return _FORMAT_DISPATCH[fmt](state_dict) + return state_dict - _lora_format_keys: dict[str, set[str]] = { - "kohya": {"lora_unet_double_blocks_", "lora_unet_single_blocks_"}, - "xlabs": {".processor.qkv_lora", ".processor.proj_lora"}, - "bfl": {"time_in.in_layer.lora_A", "double_blocks.0.img_mod.lin.lora_A"}, - "kontext": {"base_model.model.double_blocks"}, - } - _map_lora_to_diffusers = staticmethod(map_lora_to_diffusers) +# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. 
+FLUX_LORA_METADATA = LoRAMetadata( + _lora_format_keys=_FLUX_LORA_FORMAT_KEYS, + _map_lora_to_diffusers=map_lora_to_diffusers, +) diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 5de2682bee70..df759e71b06e 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -20,8 +20,9 @@ import torch.nn as nn import torch.nn.functional as F -from ....configuration_utils import ConfigMixin, register_to_config -from ....loaders import FluxTransformer2DLoadersMixin, PeftAdapterMixin +from ....configuration_utils import register_to_config +from ....hooks._helpers import TransformerBlockMetadata +from ....loaders.ip_adapter_model import IPAdapterModelMixin from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -35,32 +36,16 @@ get_1d_rotary_pos_embed, ) from ...modeling_outputs import Transformer2DModelOutput -from ...modeling_utils import ModelMetadata, ModelMixin +from ...modeling_utils import ModelMetadata, ModelMixin, register_metadata from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from .lora import FluxTransformerLoRAMixin -from .weight_mapping import FluxTransformerWeightMappingMixin +from .ip_adapter import FLUX_IP_ADAPTER_METADATA +from .lora import FLUX_LORA_METADATA +from .weight_mapping import FLUX_WEIGHT_MAPPING_METADATA logger = logging.get_logger(__name__) -FLUX_METADATA = ModelMetadata( - supports_gradient_checkpointing=True, - no_split_modules=["FluxTransformerBlock", "FluxSingleTransformerBlock"], - skip_layerwise_casting_patterns=("pos_embed", "norm"), - repeated_blocks=["FluxTransformerBlock", "FluxSingleTransformerBlock"], - cp_plan={ - "": { - "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, 
split_output=False), - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - }, - "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), - }, -) - - def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -372,6 +357,7 @@ def forward( @maybe_allow_in_graph +@register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0)) class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): super().__init__() @@ -426,6 +412,7 @@ def forward( @maybe_allow_in_graph +@register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0)) class FluxTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 @@ -541,15 +528,32 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +FLUX_MODEL_METADATA = ModelMetadata( + _supports_gradient_checkpointing=True, + _no_split_modules=["FluxTransformerBlock", "FluxSingleTransformerBlock"], + _skip_layerwise_casting_patterns=("pos_embed", "norm"), + _repeated_blocks=["FluxTransformerBlock", "FluxSingleTransformerBlock"], + _cp_plan={ + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": 
ContextParallelOutput(gather_dim=1, expected_dims=3), + }, + _lora=FLUX_LORA_METADATA, + _weight_mapping=FLUX_WEIGHT_MAPPING_METADATA, + _ip_adapter=FLUX_IP_ADAPTER_METADATA, +) + + +@register_metadata(FLUX_MODEL_METADATA) class FluxTransformer2DModel( ModelMixin, - ConfigMixin, - PeftAdapterMixin, - FluxTransformerWeightMappingMixin, - FluxTransformerLoRAMixin, - FluxTransformer2DLoadersMixin, - CacheMixin, AttentionMixin, + CacheMixin, + IPAdapterModelMixin, ): """ The Transformer model introduced in Flux. @@ -582,8 +586,6 @@ class FluxTransformer2DModel( The dimensions to use for the rotary positional embeddings. """ - _model_metadata = FLUX_METADATA - @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index 90a3479ba8c3..748a34fa252f 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -17,6 +17,7 @@ import torch from ....loaders.weight_mapping import WeightMappingMixin +from ...modeling_utils import WeightMappingMetadata def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor: @@ -256,60 +257,54 @@ def map_from_diffusers( return converted_state_dict -class FluxTransformerWeightMappingMixin(WeightMappingMixin): - """ - Mixin providing Flux-specific weight mapping and conversion. 
+_FLUX_CHECKPOINT_KEY_PREFIXES: list[str] = ["model.diffusion_model."] + +# Distinctive keys for original format detection (only keys that use simple renaming, not splits) +_FLUX_CHECKPOINT_KEYS: set[str] = { + "time_in.in_layer.weight", + "double_blocks.0.img_mod.lin.weight", +} +_FLUX_MODEL_VARIANTS: dict[str, str] = { + "flux-dev": "black-forest-labs/FLUX.1-dev", + "flux-schnell": "black-forest-labs/FLUX.1-schnell", + "flux-fill": "black-forest-labs/FLUX.1-Fill-dev", + "flux-depth": "black-forest-labs/FLUX.1-Depth-dev", +} + + +def detect_model_variant(cls, state_dict: dict[str, Any]) -> str | None: + """Detect which Flux variant a state_dict belongs to (``flux-dev`` / ``-schnell`` / ``-fill`` / ``-depth``). - This mixin defines class attributes used by ModelMixin for checkpoint conversion: - - Checkpoint identification keys (shared across variants) - - Variant-specific metadata (config repos) - - Conversion function - - Default subfolder + Receives ``cls`` so it can reuse the model's ``_is_original_format`` / ``_rename_key`` helpers. 
""" + guidance_key = "guidance_in.in_layer.bias" + x_embedder_key = "img_in.weight" + + if not cls._is_original_format(state_dict): + guidance_key = cls._rename_key(guidance_key, FLUX_RENAME_PATTERNS) + x_embedder_key = cls._rename_key(x_embedder_key, FLUX_RENAME_PATTERNS) + + if x_embedder_key not in state_dict: + return None + + if guidance_key not in state_dict: + return "flux-schnell" + + in_channels = state_dict[x_embedder_key].shape[1] + if in_channels == 384: + return "flux-fill" + elif in_channels == 128: + return "flux-depth" + + return "flux-dev" + - _checkpoint_key_prefixes: list[str] = ["model.diffusion_model."] - # Distinctive keys for original format detection (only keys that use simple renaming, not splits) - _checkpoint_keys: set[str] = { - "time_in.in_layer.weight", - "double_blocks.0.img_mod.lin.weight", - } - _rename_patterns: dict[str, str] = FLUX_RENAME_PATTERNS - _model_variants: dict[str, str] = { - "flux-dev": "black-forest-labs/FLUX.1-dev", - "flux-schnell": "black-forest-labs/FLUX.1-schnell", - "flux-fill": "black-forest-labs/FLUX.1-Fill-dev", - "flux-depth": "black-forest-labs/FLUX.1-Depth-dev", - } - - _map_to_diffusers = staticmethod(map_to_diffusers) - _map_from_diffusers = staticmethod(map_from_diffusers) - _default_subfolder: str = "transformer" - - @classmethod - def _detect_model_variant(cls, state_dict: dict[str, Any]) -> str | None: - """ - Detect which Flux variant a state_dict belongs to. - - Returns the variant name (e.g., "flux-dev", "flux-schnell", "flux-fill", "flux-depth") - or None if unknown. 
- """ - guidance_key = "guidance_in.in_layer.bias" - x_embedder_key = "img_in.weight" - - if not cls._is_original_format(state_dict): - guidance_key = cls._rename_key(guidance_key, cls._rename_patterns) - x_embedder_key = cls._rename_key(x_embedder_key, cls._rename_patterns) - - if x_embedder_key not in state_dict: - return None - - if guidance_key not in state_dict: - return "flux-schnell" - - in_channels = state_dict[x_embedder_key].shape[1] - if in_channels == 384: - return "flux-fill" - elif in_channels == 128: - return "flux-depth" - - return "flux-dev" +# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. +FLUX_WEIGHT_MAPPING_METADATA = WeightMappingMetadata( + _checkpoint_keys=_FLUX_CHECKPOINT_KEYS, + _model_variants=_FLUX_MODEL_VARIANTS, + _map_to_diffusers=map_to_diffusers, + _map_from_diffusers=map_from_diffusers, + _detect_model_variant_fn=detect_model_variant, + _default_subfolder="transformer", +) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index ffb94c411b3e..3c3511380549 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ GGUF_FILE_EXTENSION, HF_ENABLE_PARALLEL_LOADING, HF_MODULES_CACHE, + HUB_KWARGS, HUGGINGFACE_CO_RESOLVE_ENDPOINT, MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index cbfe2da0d32a..28b842ce2594 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -42,6 +42,19 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] + +# Canonical set of hub-download kwargs (with defaults) forwarded to ``_get_model_file`` and +# related loaders. Use ``{k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()}`` to +# extract them from a caller's ``**kwargs`` in one shot. 
+HUB_KWARGS = { + "cache_dir": None, + "force_download": False, + "proxies": None, + "local_files_only": None, + "token": None, + "revision": None, + "subfolder": None, +} DIFFUSERS_REQUEST_TIMEOUT = 60 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 058fa75c7d07..6464efb6331b 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -292,5 +292,3 @@ def check_peft_version(min_version: str) -> None: f"The version of PEFT you are using is not compatible, please use a version that is greater" f" than {min_version}" ) - - From 0d1c8855a91d76394a264b010dc11e26b851bb9c Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 14 May 2026 23:07:40 +0530 Subject: [PATCH 04/21] update --- src/diffusers/hooks/_helpers.py | 7 +- src/diffusers/loaders/lora.py | 91 ++++++------ src/diffusers/loaders/lora_base.py | 16 +-- src/diffusers/loaders/unet.py | 7 +- src/diffusers/loaders/weight_mapping.py | 38 +++-- src/diffusers/models/modeling_utils.py | 136 ++++++------------ .../models/transformers/flux/lora.py | 31 ++-- 7 files changed, 127 insertions(+), 199 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 692bec80c84d..3dcd6d057388 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -34,10 +34,9 @@ class TransformerBlockMetadata: def _register(self, cls): """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. - Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that - opt into the decorator pattern (e.g. Flux). Models that use the legacy bulk registration - in ``_register_transformer_blocks_metadata`` are unaffected — both code paths call the - same ``TransformerBlockRegistry.register`` underneath. 
+ Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that opt into the decorator + pattern (e.g. Flux). Models that use the legacy bulk registration in ``_register_transformer_blocks_metadata`` + are unaffected — both code paths call the same ``TransformerBlockRegistry.register`` underneath. """ cls._block_metadata = self TransformerBlockRegistry.register(cls, self) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 2e7f4848f0fc..4368f75f2ffd 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -128,8 +128,7 @@ def _unfuse_lora_apply(module): def _serialize_lora_adapter_metadata(peft_config): """Convert a ``PeftConfig`` to a JSON string suitable for the safetensors metadata blob. - PEFT configs may contain ``set`` values (which JSON can't serialize); coerce those - to lists first. + PEFT configs may contain ``set`` values (which JSON can't serialize); coerce those to lists first. """ cfg = peft_config.to_dict() for key, value in cfg.items(): @@ -154,9 +153,8 @@ def _scope_state_dict_to_adapter(state_dict, adapter_name): def _split_majority_and_outliers(value_dict): """Return ``(majority, outliers)`` for ``value_dict``. - ``majority`` is the most common value (or the lone value if all are equal, or - None for an empty dict). ``outliers`` is a sub-dict of the items whose value - differs from the majority — empty when every value matches. + ``majority`` is the most common value (or the lone value if all are equal, or None for an empty dict). ``outliers`` + is a sub-dict of the items whose value differs from the majority — empty when every value matches. """ values = list(value_dict.values()) if not values: @@ -171,10 +169,9 @@ def _split_majority_and_outliers(value_dict): def _offloading_disabled(model): """Temporarily strip accelerate and group-offload hooks from ``model``. 
- PEFT injection and weight loading mutate the model graph in ways that fight with - active offload hooks (sequential CPU offload, group offload, etc.). This context - saves the hook state, removes the hooks for the duration of the block, and - restores them on exit so existing offloading config survives a LoRA load. + PEFT injection and weight loading mutate the model graph in ways that fight with active offload hooks (sequential + CPU offload, group offload, etc.). This context saves the hook state, removes the hooks for the duration of the + block, and restores them on exit so existing offloading config survives a LoRA load. """ saved_hf_hook = None is_sequential = False @@ -211,10 +208,9 @@ def _offloading_disabled(model): def _create_lora_config(state_dict, network_alphas, rank_dict, metadata=None): """Build a PEFT ``LoraConfig`` from a LoRA state dict. - ``metadata`` (when present) overrides the inferred kwargs entirely — used when a - saved adapter shipped its own serialized ``LoraConfig`` blob. Otherwise we infer: - per-module rank / alpha values that don't match the majority go into - ``rank_pattern`` / ``alpha_pattern``; the majority becomes the global default. + ``metadata`` (when present) overrides the inferred kwargs entirely — used when a saved adapter shipped its own + serialized ``LoraConfig`` blob. Otherwise we infer: per-module rank / alpha values that don't match the majority go + into ``rank_pattern`` / ``alpha_pattern``; the majority becomes the global default. """ if metadata is not None: return LoraConfig(**metadata) @@ -275,12 +271,11 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): """Load a LoRA state dict from a path/repo/dict. - Safetensors only — pickle (``.bin``) LoRAs are no longer supported. 
Re-save legacy - checkpoints with ``safetensors.torch.save_file`` or load them manually with - ``torch.load`` and pass the resulting dict. + Safetensors only — pickle (``.bin``) LoRAs are no longer supported. Re-save legacy checkpoints with + ``safetensors.torch.save_file`` or load them manually with ``torch.load`` and pass the resulting dict. - ``hub_kwargs`` are the download / file-discovery options forwarded to - ``_get_model_file`` (see ``HUB_KWARGS`` for the canonical set). + ``hub_kwargs`` are the download / file-discovery options forwarded to ``_get_model_file`` (see ``HUB_KWARGS`` for + the canonical set). """ if isinstance(pretrained_model_name_or_path_or_dict, dict): return pretrained_model_name_or_path_or_dict @@ -295,8 +290,8 @@ def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, * def _fetch_lora_metadata(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): """Load LoRA adapter metadata from a safetensors file's sidecar. - Returns ``None`` for non-safetensors sources (dicts, ``.bin`` files, missing - sidecar). The hub layer caches the file, so calling this after + Returns ``None`` for non-safetensors sources (dicts, ``.bin`` files, missing sidecar). The hub layer caches the + file, so calling this after """ if isinstance(pretrained_model_name_or_path_or_dict, dict): return None @@ -350,15 +345,12 @@ def _best_guess_weight_name( class LoRAModelMixin: """ - Single mixin for everything LoRA on a diffusers model: PEFT adapter lifecycle - (load / fuse / unfuse / set / delete / hotswap) plus foreign-format conversion - (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. + Single mixin for everything LoRA on a diffusers model: PEFT adapter lifecycle (load / fuse / unfuse / set / delete + / hotswap) plus foreign-format conversion (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. - Per-model conversion knobs live in a ``LoRAMetadata`` declared in the model's - ``lora.py`` (e.g. 
``FLUX_LORA_METADATA``) and attached to the class via - ``@register_model_metadata(lora=...)``. The default no-op path just normalizes - ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state - dict unchanged. + Per-model conversion knobs live in a ``LoRAMetadata`` declared in the model's ``lora.py`` (e.g. + ``FLUX_LORA_METADATA``) and attached to the class via ``@register_model_metadata(lora=...)``. The default no-op + path just normalizes ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. Install the latest version of PEFT, and use this mixin to: @@ -406,9 +398,8 @@ def _normalize_lora_suffixes(cls, state_dict: Dict[str, "torch.Tensor"]) -> Dict def map_lora_to_diffusers(cls, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: """Canonicalize a LoRA state dict to diffusers naming. - Default: just normalize suffixes. Models with foreign formats register a - converter via ``LoRAMetadata._map_lora_to_diffusers`` — the decorator mirrors - it onto ``cls._map_lora_to_diffusers``. + Default: just normalize suffixes. Models with foreign formats register a converter via + ``LoRAMetadata._map_lora_to_diffusers`` — the decorator mirrors it onto ``cls._map_lora_to_diffusers``. """ state_dict = cls._normalize_lora_suffixes(state_dict) if cls._map_lora_to_diffusers is None: @@ -429,13 +420,13 @@ def load_adapter( ``source`` can be either: - - A ``PeftConfig`` (e.g. ``LoraConfig``) — initializes a fresh adapter with - random weights, suitable for training. - - A repo id, local path, or pre-loaded ``state_dict`` — loads pretrained - adapter weights, suitable for inference. + - A ``PeftConfig`` (e.g. ``LoraConfig``) — initializes a fresh adapter with random weights, suitable for + training. + - A repo id, local path, or pre-loaded ``state_dict`` — loads pretrained adapter weights, suitable for + inference. 
- For the config path, only ``adapter_name`` is used; ``prefix``, ``hotswap``, - and the download/loading kwargs apply to the pretrained path. + For the config path, only ``adapter_name`` is used; ``prefix``, ``hotswap``, and the download/loading kwargs + apply to the pretrained path. """ adapter_name = adapter_name or get_adapter_name(self) if isinstance(adapter, PeftConfig): @@ -620,9 +611,8 @@ def _load_adapter_from_pretrained( def _inject_adapter(self, state_dict, lora_config, adapter_name, peft_kwargs): """Inject a new adapter into ``self`` and load its weights. - Returns the ``incompatible_keys`` reported by ``set_peft_model_state_dict``. - On failure, rolls back any partial peft_config / adapter modules so the model - is left in its prior state. + Returns the ``incompatible_keys`` reported by ``set_peft_model_state_dict``. On failure, rolls back any partial + peft_config / adapter modules so the model is left in its prior state. """ try: inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs) @@ -637,8 +627,8 @@ def _inject_adapter(self, state_dict, lora_config, adapter_name, peft_kwargs): def _maybe_apply_deferred_hotswap_prep(self, lora_config): """If ``enable_lora_hotswap`` was called before the first adapter was loaded, - we deferred ``prepare_model_for_compiled_hotswap`` until LoRA layers existed. - Apply it now (after a successful inject) and clear the stash so it only fires once.""" + we deferred ``prepare_model_for_compiled_hotswap`` until LoRA layers existed. Apply it now (after a successful + inject) and clear the stash so it only fires once.""" if self._lora_hotswap_kwargs is None: return prepare_model_for_compiled_hotswap(self, config=lora_config, **self._lora_hotswap_kwargs) @@ -647,8 +637,8 @@ def _maybe_apply_deferred_hotswap_prep(self, lora_config): def _hotswap_adapter(self, state_dict, lora_config, adapter_name): """Replace the weights of an already-loaded adapter in-place. 
- ``hotswap_adapter_from_state_dict`` raises on incompatible keys; reaching the - end of this function means the swap succeeded. + ``hotswap_adapter_from_state_dict`` raises on incompatible keys; reaching the end of this function means the + swap succeeded. """ state_dict = _scope_state_dict_to_adapter(state_dict, adapter_name) check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) @@ -807,8 +797,8 @@ def load_lora_adapter( def set_adapter(self, adapter_name: Union[str, List[str]]) -> None: """Deprecated alias for :meth:`set_adapters`. - Note: ``set_adapters`` resets the per-adapter scale to ``1.0`` when no weights - are passed; the original ``set_adapter`` left the previous scale untouched. + Note: ``set_adapters`` resets the per-adapter scale to ``1.0`` when no weights are passed; the original + ``set_adapter`` left the previous scale untouched. """ deprecate( "set_adapter", @@ -879,12 +869,11 @@ def unfuse_lora(self): def delete_adapters(self, adapter_names: Optional[Union[List[str], str]] = None): """Remove adapter(s) from the model. - Pass specific names to delete those adapters only — the PEFT wrapper layers - (``lora_A`` / ``lora_B`` modules) stay in place, so a subsequent - :meth:`load_adapter` call can reuse them without re-injecting. + Pass specific names to delete those adapters only — the PEFT wrapper layers (``lora_A`` / ``lora_B`` modules) + stay in place, so a subsequent :meth:`load_adapter` call can reuse them without re-injecting. - Pass ``None`` (the default) to remove every adapter *and* strip the wrapper - layers themselves, returning the model to its pre-LoRA state. + Pass ``None`` (the default) to remove every adapter *and* strip the wrapper layers themselves, returning the + model to its pre-LoRA state. 
""" if adapter_names is None: recurse_remove_peft_layers(self) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 62fcad958e11..aded26bf3624 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -59,14 +59,14 @@ def _func_optionally_disable_offloading(_pipeline): """Optionally remove accelerate offloading hooks before mutating a pipeline's components. - Walks ``_pipeline.components``, detects accelerate / group-offload hooks, and removes - accelerate hooks in-place (group-offload is reapplied later by the LoRA load path). - Returns ``(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)`` so - callers know which offloading mode was active and can re-enable it after loading. - - Used by pipeline-side LoRA loaders (``LoraBaseMixin._optionally_disable_offloading``) - and the legacy paths in ``peft.py`` / ``unet.py``. Model-side loading uses the - ``_offloading_disabled`` context manager in ``loaders.lora`` instead. + Walks ``_pipeline.components``, detects accelerate / group-offload hooks, and removes accelerate hooks in-place + (group-offload is reapplied later by the LoRA load path). Returns ``(is_model_cpu_offload, + is_sequential_cpu_offload, is_group_offload)`` so callers know which offloading mode was active and can re-enable + it after loading. + + Used by pipeline-side LoRA loaders (``LoraBaseMixin._optionally_disable_offloading``) and the legacy paths in + ``peft.py`` / ``unet.py``. Model-side loading uses the ``_offloading_disabled`` context manager in ``loaders.lora`` + instead. 
""" from ..hooks.group_offloading import _is_group_offload_enabled diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 1bf93d5cec42..3046e2b3cdcf 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -70,10 +70,9 @@ def _load_adapter_from_pretrained(self, pretrained_model_name_or_path_or_dict, * """UNet override that handles model-specific LoRA formats before delegating to the base loader. - Converts old non-PEFT UNet LoRA naming to PEFT shape (when no key carries ``lora_A``). - - Detects SAI Control LoRA (``lora_controlnet`` marker) — that path has its own - loader because the LoraConfig needs post-create overrides the base flow doesn't expose. - See https://huggingface.co/stabilityai/control-lora and - https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors. + - Detects SAI Control LoRA (``lora_controlnet`` marker) — that path has its own loader because the LoraConfig + needs post-create overrides the base flow doesn't expose. See https://huggingface.co/stabilityai/control-lora + and https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors. """ from ..utils import HUB_KWARGS from .lora import _fetch_state_dict diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index 63e3e4a683b2..b39ec5100697 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Reusable infrastructure for converting model checkpoints between original -and diffusers naming conventions. +"""Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. -A model declares its mapping in a ``WeightMappingMetadata`` instance (typically -in its ``weight_mapping.py`` module) and attaches it via -``@register_model_metadata(weight_mapping=...)``. 
This mixin supplies the -generic dispatch methods that read from that metadata. +A model declares its mapping in a ``WeightMappingMetadata`` instance (typically in its ``weight_mapping.py`` module) +and attaches it via ``@register_model_metadata(weight_mapping=...)``. This mixin supplies the generic dispatch methods +that read from that metadata. -The :meth:`apply_transforms` helper drives the forward direction from a single -declarative table — see ``models/transformers/flux/weight_mapping.py`` for an -example. +The :meth:`apply_transforms` helper drives the forward direction from a single declarative table — see +``models/transformers/flux/weight_mapping.py`` for an example. """ from typing import Optional @@ -32,9 +29,8 @@ class WeightMappingMixin: """ Base mixin providing utilities for checkpoint weight mapping and conversion. - Per-model configuration (rename patterns, format-identifying keys, conversion - callables, etc.) lives in the model's registered ``WeightMappingMetadata`` — - declared in the model's ``weight_mapping.py`` and attached via + Per-model configuration (rename patterns, format-identifying keys, conversion callables, etc.) lives in the model's + registered ``WeightMappingMetadata`` — declared in the model's ``weight_mapping.py`` and attached via ``@register_model_metadata``. This mixin just supplies the dispatch methods. """ @@ -81,8 +77,8 @@ def _is_original_format(cls, state_dict: dict) -> bool: def _detect_model_variant(cls, state_dict: dict) -> Optional[str]: """Detect which model variant a state_dict belongs to. - Dispatches to ``cls._detect_model_variant_fn`` (mirrored from the model's metadata); - raises if no detector is registered. + Dispatches to ``cls._detect_model_variant_fn`` (mirrored from the model's metadata); raises if no detector is + registered. 
""" if cls._detect_model_variant_fn is None: raise NotImplementedError( @@ -103,14 +99,12 @@ def apply_transforms(state_dict, transforms, rename_patterns, **ctx): """Drive a forward state-dict conversion from a list of (source, targets, fn) entries. Each entry is a tuple ``(source, targets, forward_fn, reverse_fn)``: - - ``source``: substring matched against each key (with surrounding dots, - e.g. ``".img_attn.qkv."``); the first matching entry wins. - - ``targets``: list of substrings substituted for ``source`` to build the - output keys. ``len(targets)`` is the fan-out (1 for a unary transform, - >1 for a split). - - ``forward_fn(value, **ctx) -> list[tensor]`` returns one tensor per - target. (``reverse_fn`` is reserved for a future - ``apply_reverse_transforms`` driver.) + - ``source``: substring matched against each key (with surrounding dots, e.g. ``".img_attn.qkv."``); the + first matching entry wins. + - ``targets``: list of substrings substituted for ``source`` to build the output keys. ``len(targets)`` is + the fan-out (1 for a unary transform, >1 for a split). + - ``forward_fn(value, **ctx) -> list[tensor]`` returns one tensor per target. (``reverse_fn`` is reserved for + a future ``apply_reverse_transforms`` driver.) Keys that match no transform get their dots renamed via ``rename_patterns``. """ diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c99016738e1a..67c1d9a2a74f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -238,16 +238,15 @@ def _skip_init(*args, **kwargs): class LoRAMetadata: """Per-model LoRA configuration: what foreign formats this model accepts and how to convert them. - Field names match the legacy ``cls._`` class attributes consumed by - ``LoRAModelMixin``, so the decorator can mirror them 1:1. + Field names match the legacy ``cls._`` class attributes consumed by ``LoRAModelMixin``, so the decorator can + mirror them 1:1. 
Attributes: _lora_format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying key substrings. The first format whose substrings appear in the state dict wins. _map_lora_to_diffusers: Callable ``(state_dict, **kwargs) -> state_dict`` that rewrites - foreign-format keys to diffusers naming. Called from - ``LoRAModelMixin.map_lora_to_diffusers`` after generic suffix normalization. - ``None`` for models that only ingest diffusers-native LoRAs. + foreign-format keys to diffusers naming. Called from ``LoRAModelMixin.map_lora_to_diffusers`` after generic + suffix normalization. ``None`` for models that only ingest diffusers-native LoRAs. """ _lora_format_keys: Dict[str, set] = field(default_factory=dict) @@ -258,17 +257,16 @@ class LoRAMetadata: class IPAdapterMetadata: """Per-model IP-Adapter configuration: how to convert IP-Adapter state dicts for this architecture. - Field names match the legacy ``cls._`` class attributes consumed by - ``IPAdapterModelMixin``, so the decorator can mirror them 1:1. + Field names match the legacy ``cls._`` class attributes consumed by ``IPAdapterModelMixin``, so the decorator + can mirror them 1:1. Attributes: _convert_ip_adapter_attn_to_diffusers: Callable - ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the - attn-processor dict ready for ``set_attn_processor``. Receives the model instance because it - needs ``model.attn_processors``, ``model.config``, ``model.inner_dim``, etc. + ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the attn-processor + dict ready for ``set_attn_processor``. Receives the model instance because it needs + ``model.attn_processors``, ``model.config``, ``model.inner_dim``, etc. _convert_ip_adapter_image_proj_to_diffusers: Callable - ``(model, state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image - projection layer. 
+ ``(model, state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image projection layer. """ _convert_ip_adapter_attn_to_diffusers: Optional[Callable] = None @@ -279,13 +277,12 @@ class IPAdapterMetadata: class WeightMappingMetadata: """Per-model checkpoint conversion metadata for single-file loading. - Field names match the legacy ``cls._`` class attributes consumed by - ``WeightMappingMixin``, so the decorator can mirror them 1:1. + Field names match the legacy ``cls._`` class attributes consumed by ``WeightMappingMixin``, so the decorator + can mirror them 1:1. - Note: per-key rename tables and checkpoint key prefixes live in the model's - ``weight_mapping.py`` module as plain constants (e.g. ``FLUX_RENAME_PATTERNS``). - They're consumed directly by the model's ``map_to_diffusers`` / ``map_from_diffusers`` - callables and don't need to be threaded through metadata. + Note: per-key rename tables and checkpoint key prefixes live in the model's ``weight_mapping.py`` module as plain + constants (e.g. ``FLUX_RENAME_PATTERNS``). They're consumed directly by the model's ``map_to_diffusers`` / + ``map_from_diffusers`` callables and don't need to be threaded through metadata. Attributes: _checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is @@ -310,12 +307,12 @@ class ModelMetadata: """ Metadata describing model capabilities and configuration hints. - This is NOT configuration (which is saved to config.json and defines architecture). - This is static metadata about the model class's capabilities and hints for - optimization features like gradient checkpointing, offloading, and parallelism. + This is NOT configuration (which is saved to config.json and defines architecture). This is static metadata about + the model class's capabilities and hints for optimization features like gradient checkpointing, offloading, and + parallelism. 
- Field names match the legacy ``cls._`` class attributes (so the decorator - mirrors them 1:1 and existing consumer code keeps working). + Field names match the legacy ``cls._`` class attributes (so the decorator mirrors them 1:1 and existing + consumer code keeps working). Attributes: _supports_gradient_checkpointing: Whether the model supports gradient checkpointing @@ -352,10 +349,9 @@ class ModelMetadata: def _register(self, cls): """Attach this ``ModelMetadata`` to ``cls`` and mirror leaf fields to legacy class attrs. - Walks nested dataclasses (``_lora``, ``_weight_mapping``, ``_ip_adapter``) so their - leaf fields land flat on ``cls``. Field names already starting with ``_`` map 1:1 - (``_lora_format_keys`` → ``cls._lora_format_keys``); unprefixed names get the - underscore added (``rename_patterns`` → ``cls._rename_patterns``). + Walks nested dataclasses (``_lora``, ``_weight_mapping``, ``_ip_adapter``) so their leaf fields land flat on + ``cls``. Field names already starting with ``_`` map 1:1 (``_lora_format_keys`` → ``cls._lora_format_keys``); + unprefixed names get the underscore added (``rename_patterns`` → ``cls._rename_patterns``). """ cls._model_metadata = self pending = [self] @@ -373,16 +369,14 @@ def _register(self, cls): def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. - Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own - attachment logic. Works for both model-level metadata (``ModelMetadata``) and - block-level metadata (``TransformerBlockMetadata``):: + Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own attachment logic. Works for both + model-level metadata (``ModelMetadata``) and block-level metadata (``TransformerBlockMetadata``):: - @register_metadata(FLUX_MODEL_METADATA) - class FluxTransformer2DModel(...): + @register_metadata(FLUX_MODEL_METADATA) class FluxTransformer2DModel(...): ... 
- @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) - class FluxTransformerBlock(nn.Module): + @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) class + FluxTransformerBlock(nn.Module): ... """ @@ -424,49 +418,6 @@ class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixi _parallel_config = None _cp_plan = None _skip_keys = None - _model_metadata: Optional["ModelMetadata"] = None - - @classmethod - def get_metadata(cls) -> "ModelMetadata": - """ - Get the model's metadata for discovery and introspection. - - Returns a ModelMetadata instance describing the model's capabilities. - If `_model_metadata` is defined on the class, returns that directly. - Otherwise, constructs a ModelMetadata from the individual class attributes. - """ - if cls._model_metadata is not None: - return cls._model_metadata - # Fallback for unmigrated models: build from the legacy per-attribute class vars. - # ``lora`` / ``weight_mapping`` read from old mixin attrs if present, else stay empty. 
- return ModelMetadata( - _supports_gradient_checkpointing=cls._supports_gradient_checkpointing, - _no_split_modules=cls._no_split_modules, - _keep_in_fp32_modules=cls._keep_in_fp32_modules, - _skip_layerwise_casting_patterns=cls._skip_layerwise_casting_patterns, - _supports_group_offloading=cls._supports_group_offloading, - _repeated_blocks=cls._repeated_blocks if cls._repeated_blocks else [], - _cp_plan=cls._cp_plan, - _keys_to_ignore_on_load_unexpected=cls._keys_to_ignore_on_load_unexpected, - _lora=LoRAMetadata( - _lora_format_keys=getattr(cls, "_lora_format_keys", None) or {}, - _map_lora_to_diffusers=getattr(cls, "_map_lora_to_diffusers", None), - ), - _weight_mapping=WeightMappingMetadata( - _checkpoint_keys=getattr(cls, "_checkpoint_keys", None) or set(), - _model_variants=getattr(cls, "_model_variants", None) or {}, - _map_to_diffusers=getattr(cls, "_map_to_diffusers", None), - _map_from_diffusers=getattr(cls, "_map_from_diffusers", None), - _detect_model_variant_fn=getattr(cls, "_detect_model_variant_fn", None), - _default_subfolder=getattr(cls, "_default_subfolder", "transformer"), - ), - _ip_adapter=IPAdapterMetadata( - _convert_ip_adapter_attn_to_diffusers=getattr(cls, "_convert_ip_adapter_attn_to_diffusers", None), - _convert_ip_adapter_image_proj_to_diffusers=getattr( - cls, "_convert_ip_adapter_image_proj_to_diffusers", None - ), - ), - ) @classmethod def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: Dict[str, Any]) -> Dict[str, Any]: @@ -474,12 +425,11 @@ def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: Dict[str, An Two phases, both declared via the model's :class:`WeightMappingMetadata`: - 1. ``_normalize_checkpoint_keys`` — strip known prefixes (e.g. ``model.diffusion_model.``). - Run unconditionally; idempotent and a no-op if no prefixes were registered. - 2. ``_map_to_diffusers`` — the actual format converter, only invoked if step 1 alone - didn't already make the keys match. 
Skipped if no converter was registered (loading - then fails downstream with a clearer key-mismatch error than a deep - ``NotImplementedError``). + 1. ``_normalize_checkpoint_keys`` — strip known prefixes (e.g. ``model.diffusion_model.``). Run + unconditionally; idempotent and a no-op if no prefixes were registered. + 2. ``_map_to_diffusers`` — the actual format converter, only invoked if step 1 alone didn't already make the + keys match. Skipped if no converter was registered (loading then fails downstream with a clearer + key-mismatch error than a deep ``NotImplementedError``). """ # Step 1: always strip checkpoint key prefixes — idempotent, no-op if none registered. state_dict = cls._normalize_checkpoint_keys(state_dict) @@ -1696,8 +1646,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None @validate_hf_hub_args def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self: r""" - Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. - The model is set in evaluation mode (`model.eval()`) by default. + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model + is set in evaluation mode (`model.eval()`) by default. Parameters: pretrained_model_link_or_path_or_dict (`str`, *optional*): @@ -1707,20 +1657,20 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = - A path to a local *file* containing the weights of the component model. - A state dict containing the component model weights. config (`str`, *optional*): - - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline - hosted on the Hub. - - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline - component configs in Diffusers format. 
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted + on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component + configs in Diffusers format. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. torch_dtype (`torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, - overriding the cached versions if they exist. + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the - standard cache is not used. + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint. 
local_files_only (`bool`, *optional*, defaults to `False`): diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py index c33140a61d1a..d774fc5fc63e 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/lora.py @@ -16,21 +16,18 @@ Each supported foreign format has a top-level ``map__to_diffusers`` entry point: - - :func:`map_bfl_to_diffusers` — original BFL repo layout + - :func:`map_bfl_to_diffusers` — original BFL repo layout - :func:`map_kontext_to_diffusers` — fal Kontext checkpoints (BFL + ``base_model.model.`` prefix) - - :func:`map_xlabs_to_diffusers` — XLabs ``.processor.qkv_lora`` / ``.processor.proj_lora`` shape - - :func:`map_kohya_to_diffusers` — kohya sd-scripts (and "mixture" / ``lora_transformer_*`` variants) - -Each entry point produces a state dict with diffusers naming. Internally they all -funnel through :func:`_map_to_diffusers`, which converts a BFL-style state dict -(original Flux module names + ``.lora_A``/``.lora_B`` suffixes) to diffusers names by -reusing the rename / QKV-split / special-key tables in ``weight_mapping.py`` and applying -LoRA-specific QKV semantics (``lora_A.weight`` replicates across heads; everything else -chunks). - -A format-specific converter may also emit pre-converted diffusers keys directly when a -key shape doesn't fit the canonical intermediate (e.g., XLabs single-block QKV without -a paired MLP LoRA). + - :func:`map_xlabs_to_diffusers` — XLabs ``.processor.qkv_lora`` / ``.processor.proj_lora`` shape + - :func:`map_kohya_to_diffusers` — kohya sd-scripts (and "mixture" / ``lora_transformer_*`` variants) + +Each entry point produces a state dict with diffusers naming. 
Internally they all funnel through +:func:`_map_to_diffusers`, which converts a BFL-style state dict (original Flux module names + ``.lora_A``/``.lora_B`` +suffixes) to diffusers names by reusing the rename / QKV-split / special-key tables in ``weight_mapping.py`` and +applying LoRA-specific QKV semantics (``lora_A.weight`` replicates across heads; everything else chunks). + +A format-specific converter may also emit pre-converted diffusers keys directly when a key shape doesn't fit the +canonical intermediate (e.g., XLabs single-block QKV without a paired MLP LoRA). """ import re @@ -422,9 +419,9 @@ def _collapse_prefix(key): def map_lora_to_diffusers(state_dict, **kwargs): """Detect a Flux LoRA's source format and dispatch to its per-format converter. - Already-converted (peft) state dicts pass through after filtering to ``transformer.*`` - keys. Unknown formats (incl. diffusers-native LoRAs with raw ``.alpha`` keys) pass - through unchanged so the pipeline's diffusers-native fallback can run. + Already-converted (peft) state dicts pass through after filtering to ``transformer.*`` keys. Unknown formats (incl. + diffusers-native LoRAs with raw ``.alpha`` keys) pass through unchanged so the pipeline's diffusers-native fallback + can run. 
""" if any(k.startswith("transformer.") for k in state_dict): return {k: v for k, v in state_dict.items() if k.startswith("transformer.")} From a5ba74387da4b225d24a12bbb8ce03c408286e92 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 19 May 2026 14:13:09 +0530 Subject: [PATCH 05/21] update --- src/diffusers/models/attention.py | 2 ++ src/diffusers/models/modeling_utils.py | 7 +++---- src/diffusers/models/transformers/transformer_qwenimage.py | 1 - 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 36d0893734c7..f4cd1ff6856b 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -122,6 +122,8 @@ class AttentionModuleMixin: _default_processor_cls = None _available_processors = [] _supports_qkv_fusion = True + _parallel_config = None + fused_projections = False def set_processor(self, processor: AttentionProcessor) -> None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 67c1d9a2a74f..66dc026eeda6 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -326,8 +326,9 @@ class ModelMetadata: _supports_group_offloading: Whether the model supports group offloading. _repeated_blocks: List of module class names that repeat throughout the model, useful for optimization and pattern analysis. - _cp_plan: Context parallel configuration plan defining how to split model - components for context parallelism across devices. + _cp_plan: Context parallel sharding plan. Maps model input/output tensor names to + ``ContextParallelInput`` / ``ContextParallelOutput`` declarations. Universal — + applies to any tensor-sharding work, not attention-specific. _keys_to_ignore_on_load_unexpected: List of keys to ignore when loading unexpected keys from a checkpoint. _lora: Per-model LoRA loading metadata. See :class:`LoRAMetadata`. 
@@ -415,7 +416,6 @@ class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixi _skip_layerwise_casting_patterns = None _supports_group_offloading = True _repeated_blocks = [] - _parallel_config = None _cp_plan = None _skip_keys = None @@ -2077,7 +2077,6 @@ def enable_parallelism( ) config.setup(rank, world_size, device, mesh=mesh) - self._parallel_config = config for module in self.modules(): if not isinstance(module, attention_classes): diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index bdb87a385da7..2d8bc58683b2 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -877,7 +877,6 @@ def forward( `tuple` where the first element is the sample tensor. """ hidden_states = self.img_in(hidden_states) - timestep = timestep.to(hidden_states.dtype) if self.zero_cond_t: From 644c3e7afef95a15c6ba133ab0bd14e5a1637e4c Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 19 May 2026 16:19:21 +0530 Subject: [PATCH 06/21] update --- src/diffusers/hooks/_helpers.py | 12 +- src/diffusers/loaders/ip_adapter_model.py | 72 +++++++++ src/diffusers/models/modeling_utils.py | 38 ++--- src/diffusers/models/transformers/__init__.py | 52 ++++++- .../models/transformers/flux/ip_adapter.py | 141 ++++++++++++++++++ .../models/transformers/flux/model.py | 76 ++++------ .../models/transformers/transformer_chroma.py | 2 +- src/diffusers/models/transformers/utils.py | 77 ++++++++++ src/diffusers/utils/peft_utils.py | 124 ++++++++++++++- 9 files changed, 516 insertions(+), 78 deletions(-) create mode 100644 src/diffusers/loaders/ip_adapter_model.py create mode 100644 src/diffusers/models/transformers/flux/ip_adapter.py create mode 100644 src/diffusers/models/transformers/utils.py diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 3dcd6d057388..cb4c609915ee 100644 --- 
a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -35,11 +35,13 @@ def _register(self, cls): """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that opt into the decorator - pattern (e.g. Flux). Models that use the legacy bulk registration in ``_register_transformer_blocks_metadata`` - are unaffected — both code paths call the same ``TransformerBlockRegistry.register`` underneath. + pattern (e.g. Flux). Writes directly to the registry dict instead of going through + ``TransformerBlockRegistry.register`` so we don't trigger the lazy bulk-init while the decorated class's module + is mid-import (the bulk-init imports from the same module → circular). """ + self._cls = cls cls._block_metadata = self - TransformerBlockRegistry.register(cls, self) + TransformerBlockRegistry._registry[cls] = self def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): kwargs = kwargs or {} @@ -117,8 +119,8 @@ def _register(cls): def _register_attention_processors_metadata(): from ..models.attention_processor import AttnProcessor2_0 + from ..models.transformers.flux import FluxAttnProcessor from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor - from ..models.transformers.transformer_flux import FluxAttnProcessor from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0 from ..models.transformers.transformer_wan import WanAttnProcessor2_0 @@ -182,9 +184,9 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): from ..models.attention import BasicTransformerBlock, JointTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock + from ..models.transformers.flux import FluxSingleTransformerBlock, FluxTransformerBlock 
from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock - from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from ..models.transformers.transformer_hunyuan_video import ( HunyuanVideoSingleTransformerBlock, HunyuanVideoTokenReplaceSingleTransformerBlock, diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py new file mode 100644 index 000000000000..695c43af48a5 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -0,0 +1,72 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model-side IP-Adapter mixin. + +Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives +here. Per-model conversion lives in an ``IPAdapterMetadata`` declared next to the model — e.g. ``flux/ip_adapter.py`` +exports ``FLUX_IP_ADAPTER_METADATA``, which is then composed into the model's ``ModelMetadata`` and attached via +``@register_model_metadata``. +""" + +from ..models.embeddings import MultiIPAdapterImageProjection +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +class IPAdapterModelMixin: + """Metadata-driven IP-Adapter loader for diffusers transformer / UNet models. 
+ + Reads the per-model converters from ``IPAdapterMetadata`` (mirrored onto + ``cls._convert_ip_adapter_attn_to_diffusers`` and ``cls._convert_ip_adapter_image_proj_to_diffusers`` by + ``register_model_metadata``). + """ + + # No-op defaults; populated per-model by ``register_model_metadata``. + _convert_ip_adapter_attn_to_diffusers = None + _convert_ip_adapter_image_proj_to_diffusers = None + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Install IP-Adapter weights on the model. + + ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain + ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. + """ + if ( + self._convert_ip_adapter_attn_to_diffusers is None + or self._convert_ip_adapter_image_proj_to_diffusers is None + ): + raise NotImplementedError( + f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterMetadata." + ) + + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 66dc026eeda6..76ee096193e6 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -249,8 +249,8 @@ class LoRAMetadata: suffix normalization. ``None`` for models that only ingest diffusers-native LoRAs. 
""" - _lora_format_keys: Dict[str, set] = field(default_factory=dict) - _map_lora_to_diffusers: Optional[Callable] = None + _lora_format_keys: dict[str, set] = field(default_factory=dict) + _map_lora_to_diffusers: Callable | None = None @dataclass @@ -269,8 +269,8 @@ class IPAdapterMetadata: ``(model, state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image projection layer. """ - _convert_ip_adapter_attn_to_diffusers: Optional[Callable] = None - _convert_ip_adapter_image_proj_to_diffusers: Optional[Callable] = None + _convert_ip_adapter_attn_to_diffusers: Callable | None = None + _convert_ip_adapter_image_proj_to_diffusers: Callable | None = None @dataclass @@ -295,10 +295,10 @@ class WeightMappingMetadata: """ _checkpoint_keys: set = field(default_factory=set) - _model_variants: Dict[str, str] = field(default_factory=dict) - _map_to_diffusers: Optional[Callable] = None - _map_from_diffusers: Optional[Callable] = None - _detect_model_variant_fn: Optional[Callable] = None + _model_variants: dict[str, str] = field(default_factory=dict) + _map_to_diffusers: Callable | None = None + _map_from_diffusers: Callable | None = None + _detect_model_variant_fn: Callable | None = None _default_subfolder: str = "transformer" @@ -327,8 +327,8 @@ class ModelMetadata: _repeated_blocks: List of module class names that repeat throughout the model, useful for optimization and pattern analysis. _cp_plan: Context parallel sharding plan. Maps model input/output tensor names to - ``ContextParallelInput`` / ``ContextParallelOutput`` declarations. Universal — - applies to any tensor-sharding work, not attention-specific. + ``ContextParallelInput`` / ``ContextParallelOutput`` declarations. Universal — applies to any + tensor-sharding work, not attention-specific. _keys_to_ignore_on_load_unexpected: List of keys to ignore when loading unexpected keys from a checkpoint. _lora: Per-model LoRA loading metadata. See :class:`LoRAMetadata`. 
@@ -336,13 +336,13 @@ class ModelMetadata: """ _supports_gradient_checkpointing: bool = False - _no_split_modules: Optional[List[str]] = None - _keep_in_fp32_modules: Optional[List[str]] = None - _skip_layerwise_casting_patterns: Optional[Tuple[str, ...]] = None + _no_split_modules: list[str] | None = None + _keep_in_fp32_modules: list[str] | None = None + _skip_layerwise_casting_patterns: tuple[str, ...] | None = None _supports_group_offloading: bool = True - _repeated_blocks: List[str] = field(default_factory=list) - _cp_plan: Optional[Dict[str, Any]] = None - _keys_to_ignore_on_load_unexpected: Optional[List[str]] = None + _repeated_blocks: list[str] = field(default_factory=list) + _cp_plan: dict[str, Any] | None = None + _keys_to_ignore_on_load_unexpected: list[str] | None = None _lora: LoRAMetadata = field(default_factory=LoRAMetadata) _ip_adapter: IPAdapterMetadata = field(default_factory=IPAdapterMetadata) _weight_mapping: WeightMappingMetadata = field(default_factory=WeightMappingMetadata) @@ -388,7 +388,7 @@ def wrap(cls): return wrap -def _should_convert_checkpoint(model_state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> bool: +def _should_convert_checkpoint(model_state_dict: dict[str, Any], checkpoint: dict[str, Any]) -> bool: """Check if checkpoint needs conversion by comparing keys with model state dict.""" model_state_dict_keys = set(model_state_dict.keys()) checkpoint_state_dict_keys = set(checkpoint.keys()) @@ -420,7 +420,7 @@ class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixi _skip_keys = None @classmethod - def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: Dict[str, Any]) -> Dict[str, Any]: + def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: dict[str, Any]) -> dict[str, Any]: """Convert ``state_dict`` from original format to diffusers format if needed. 
Two phases, both declared via the model's :class:`WeightMappingMetadata`: @@ -1644,7 +1644,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None @classmethod @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self: + def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = None, **kwargs) -> Self: r""" Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default. diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 1f45accb3d5e..d32f0970768f 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -1,13 +1,56 @@ import sys +import types -from ...utils import is_torch_available +from ...utils import deprecate, is_torch_available + + +class _DeprecatedModuleAlias(types.ModuleType): + """Backwards-compat alias for a transformer module that's been moved to a subpackage. + + Lives only in ``sys.modules`` — no stub file. Emits a one-time ``deprecate`` warning on first attribute access, + then forwards every attribute lookup to the new target module. Used when a flat ``transformer_<name>.py`` is split + into a ``<name>/`` subpackage and we want the old import path to keep working for a release cycle. + """ + + def __init__(self, old_dotted_path: str, target: types.ModuleType): + super().__init__(target.__name__, target.__doc__) + # Bypass __getattr__ when writing internals. + self.__dict__["_target"] = target + self.__dict__["_old_path"] = old_dotted_path + self.__dict__["_warned"] = False + + def __getattr__(self, name): + if not self.__dict__["_warned"]: + self.__dict__["_warned"] = True + old = self.__dict__["_old_path"] + new = self.__dict__["_target"].__name__ + deprecate( + old, + "1.0.0", + f"Importing from `{old}` is deprecated. 
Import from `{new}` instead.", + standard_warn=True, + stacklevel=3, + ) + return getattr(self.__dict__["_target"], name) + + +def _register_legacy_module_alias(old_name: str, new_name: str) -> None: + """Register ``old_name`` as a deprecated alias for the already-loaded ``new_name`` submodule. + + Both names are relative to ``diffusers.models.transformers``. The new submodule must already be in + ``sys.modules`` (loaded by a prior ``from . import ...`` in this file). + """ + old_dotted = f"{__name__}.{old_name}" + target = sys.modules[f"{__name__}.{new_name}"] + sys.modules[old_dotted] = _DeprecatedModuleAlias(old_dotted, target) if is_torch_available(): - from . import flux + # Load flux first and install the legacy alias before any other transformer module imports, + # since some of them still pull from `transformer_flux` during their own load. + from .flux import FluxTransformer2DModel - # Register backwards compatibility alias so `from .transformer_flux import X` works - sys.modules[__name__ + ".transformer_flux"] = flux + _register_legacy_module_alias("transformer_flux", "flux") from .ace_step_transformer import AceStepTransformer1DModel from .auraflow_transformer_2d import AuraFlowTransformer2DModel @@ -15,7 +58,6 @@ from .consisid_transformer_3d import ConsisIDTransformer3DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel - from .flux import FluxTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel from .latte_transformer_3d import LatteTransformer3DModel from .lumina_nextdit2d import LuminaNextDiT2DModel diff --git a/src/diffusers/models/transformers/flux/ip_adapter.py b/src/diffusers/models/transformers/flux/ip_adapter.py new file mode 100644 index 000000000000..f6232ed64678 --- /dev/null +++ b/src/diffusers/models/transformers/flux/ip_adapter.py @@ -0,0 +1,141 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flux IP-Adapter conversion. + +Per-model converters consumed by ``IPAdapterModelMixin`` via ``FLUX_IP_ADAPTER_METADATA``: + +- ``convert_image_proj``: rewrites ``proj.weight`` → ``image_embeds.weight`` and builds an ``ImageProjection`` sized + off the source state dict (4 or 16 image-text embeds depending on the ``proj.weight`` row count). +- ``convert_attn_processors``: walks ``model.attn_processors``, skips ``single_transformer_blocks`` (Flux only attaches + IP-Adapter on the double-stream blocks), and builds one ``FluxIPAdapterAttnProcessor`` per remaining block. Reads + ``model.config.joint_attention_dim`` and ``model.inner_dim`` for the projection dimensions and pulls ``to_k_ip`` / + ``to_v_ip`` weights/biases keyed by ``key_id``. 
+""" + +from contextlib import nullcontext + +from ....models.embeddings import ImageProjection +from ....models.model_loading_utils import load_model_dict_into_meta +from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, IPAdapterMetadata +from ....utils import is_accelerate_available, is_torch_version, logging +from ....utils.torch_utils import empty_device_cache + + +logger = logging.get_logger(__name__) + + +def _resolve_init_context(low_cpu_mem_usage): + """Return ``(init_context, low_cpu_mem_usage)`` — disables low-cpu init if accelerate is missing.""" + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + if not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch " + "version or set `low_cpu_mem_usage=False`." + ) + return init_empty_weights, True + + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the " + "environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install " + "`accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip " + "install accelerate\n```\n." + ) + return nullcontext, False + + +def convert_image_proj(model, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Build a Flux ``ImageProjection`` from an IP-Adapter ``image_proj`` state dict.""" + init_context, low_cpu_mem_usage = _resolve_init_context(low_cpu_mem_usage) + + # ``proj.weight`` rows == cross_attention_dim * num_image_text_embeds. The two + # supported configurations: 4 tokens (default) and 16 tokens (when rows == 65536). 
+ num_image_text_embeds = 16 if state_dict["proj.weight"].shape[0] == 65536 else 4 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + updated_state_dict = {key.replace("proj", "image_embeds"): value for key, value in state_dict.items()} + + if low_cpu_mem_usage: + load_model_dict_into_meta( + image_projection, updated_state_dict, device_map={"": model.device}, dtype=model.dtype + ) + empty_device_cache() + else: + image_projection.load_state_dict(updated_state_dict, strict=True) + + return image_projection + + +def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Build the IP-Adapter attn-processor dict for a ``FluxTransformer2DModel``. + + Single-stream blocks keep their existing processor; double-stream blocks get a ``FluxIPAdapterAttnProcessor`` + loaded with the per-state-dict ``to_k_ip`` / ``to_v_ip`` weights. 
+ """ + from .model import FluxIPAdapterAttnProcessor + + init_context, low_cpu_mem_usage = _resolve_init_context(low_cpu_mem_usage) + + attn_procs = {} + key_id = 0 + for name in model.attn_processors: + if name.startswith("single_transformer_blocks"): + attn_procs[name] = model.attn_processors[name].__class__() + continue + + num_image_text_embeds = [16 if sd["image_proj"]["proj.weight"].shape[0] == 65536 else 4 for sd in state_dicts] + + with init_context(): + attn_procs[name] = FluxIPAdapterAttnProcessor( + hidden_size=model.inner_dim, + cross_attention_dim=model.config.joint_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + dtype=model.dtype, + device=model.device, + ) + + value_dict = {} + for i, sd in enumerate(state_dicts): + value_dict[f"to_k_ip.{i}.weight"] = sd["ip_adapter"][f"{key_id}.to_k_ip.weight"] + value_dict[f"to_v_ip.{i}.weight"] = sd["ip_adapter"][f"{key_id}.to_v_ip.weight"] + value_dict[f"to_k_ip.{i}.bias"] = sd["ip_adapter"][f"{key_id}.to_k_ip.bias"] + value_dict[f"to_v_ip.{i}.bias"] = sd["ip_adapter"][f"{key_id}.to_v_ip.bias"] + + if low_cpu_mem_usage: + load_model_dict_into_meta(attn_procs[name], value_dict, device_map={"": model.device}, dtype=model.dtype) + else: + attn_procs[name].load_state_dict(value_dict) + + key_id += 1 + + empty_device_cache() + return attn_procs + + +# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. +FLUX_IP_ADAPTER_METADATA = IPAdapterMetadata( + _convert_ip_adapter_attn_to_diffusers=convert_attn_processors, + _convert_ip_adapter_image_proj_to_diffusers=convert_image_proj, +) diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index df759e71b06e..1112b3feacc8 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -13,7 +13,7 @@ # limitations under the License. 
import inspect -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import numpy as np import torch @@ -23,7 +23,7 @@ from ....configuration_utils import register_to_config from ....hooks._helpers import TransformerBlockMetadata from ....loaders.ip_adapter_model import IPAdapterModelMixin -from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ....utils import apply_lora_scale, logging from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput from ...attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -89,8 +89,8 @@ def __call__( attn: "FluxAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, ) -> torch.Tensor: query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( attn, hidden_states, encoder_hidden_states @@ -134,9 +134,9 @@ def __call__( encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[0](hidden_states.contiguous()) hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous()) return hidden_states, encoder_hidden_states else: @@ -189,10 +189,10 @@ def __call__( attn: "FluxAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[List[torch.Tensor]] = None, - ip_adapter_masks: 
Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + ip_hidden_states: list[torch.Tensor] | None = None, + ip_adapter_masks: torch.Tensor | None = None, ) -> torch.Tensor: batch_size = hidden_states.shape[0] @@ -290,12 +290,12 @@ def __init__( dim_head: int = 64, dropout: float = 0.0, bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, out_bias: bool = True, eps: float = 1e-5, out_dim: int = None, - context_pre_only: Optional[bool] = None, + context_pre_only: bool | None = None, pre_only: bool = False, elementwise_affine: bool = True, processor=None, @@ -340,9 +340,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) @@ -384,9 +384,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: text_seq_len = encoder_hidden_states.shape[1] hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -445,9 +445,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - image_rotary_emb: 
Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( @@ -499,7 +499,7 @@ def forward( class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: List[int]): + def __init__(self, theta: int, axes_dim: list[int]): super().__init__() self.theta = theta self.axes_dim = axes_dim @@ -582,7 +582,7 @@ class FluxTransformer2DModel( The number of dimensions to use for the pooled projection. guidance_embeds (`bool`, defaults to `False`): Whether to use guidance embeddings for guidance-distilled variant of the model. - axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`): The dimensions to use for the rotary positional embeddings. 
""" @@ -591,7 +591,7 @@ def __init__( self, patch_size: int = 1, in_channels: int = 64, - out_channels: Optional[int] = None, + out_channels: int | None = None, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, @@ -599,7 +599,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + axes_dims_rope: tuple[int, int, int] = (16, 56, 56), ): super().__init__() self.out_channels = out_channels or in_channels @@ -644,6 +644,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -653,12 +654,12 @@ def forward( img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + joint_attention_kwargs: dict[str, Any] | None = None, controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, controlnet_blocks_repeat: bool = False, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: + ) -> torch.Tensor | Transformer2DModelOutput: """ The [`FluxTransformer2DModel`] forward method. @@ -685,21 +686,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
- ) - hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -795,10 +781,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index d7cc96d018b3..0a55e3202f77 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -30,7 +30,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm -from .transformer_flux import FluxAttention, FluxAttnProcessor +from .flux import FluxAttention, FluxAttnProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/utils.py b/src/diffusers/models/transformers/utils.py new file mode 100644 index 000000000000..b14ab8d206e2 --- /dev/null +++ b/src/diffusers/models/transformers/utils.py @@ -0,0 +1,77 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Shared utilities for transformer model implementations.""" + +from dataclasses import dataclass, fields + +import torch + + +@dataclass +class TransformerModuleOutput: + """Base class providing tuple-compatible iteration for structured submodule outputs. + + Doesn't declare any fields itself — subclasses define their own schema. Provides only the plumbing that lets + callers unpack positionally (``h, e = output``), index (``output[0]``), and check length, with ``None`` fields + transparently skipped so a single-stream output unpacks as a 1-tuple. This matches the legacy bare-tuple return + shape so subclasses can be adopted without touching callers. + """ + + def _as_tuple(self): + """Tuple-compat view of the dataclass: declared field order, with ``None`` values skipped.""" + return tuple(getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None) + + def __iter__(self): + return iter(self._as_tuple()) + + def __getitem__(self, idx): + return self._as_tuple()[idx] + + def __len__(self): + return len(self._as_tuple()) + + +@dataclass +class TransformerBlockOutput(TransformerModuleOutput): + """Structured return type for transformer-block ``forward`` methods. + + Replaces the historical pattern of returning bare tuples whose element ordering varied per model (e.g. Flux + returned ``(encoder_hidden_states, hidden_states)`` while CogVideoX returned ``(hidden_states, + encoder_hidden_states)``). Tuple-compatibility inherited from :class:`TransformerModuleOutput`. + + Attributes: + hidden_states: The block's primary output tensor. Always populated. + encoder_hidden_states: The text / context stream output for dual-stream blocks. ``None`` for single-stream. + """ + + hidden_states: torch.Tensor = None + encoder_hidden_states: torch.Tensor | None = None + + +@dataclass +class AttnProcessorOutput(TransformerModuleOutput): + """Structured return type for attention-processor ``__call__`` methods. 
+ + Replaces the historical pattern of returning a bare tensor for single-stream attention and a bare + ``(hidden_states, encoder_hidden_states)`` tuple for dual-stream attention. Tuple-compatibility inherited from + :class:`TransformerModuleOutput`. + + Attributes: + hidden_states: The processor's primary output tensor. Always populated. + encoder_hidden_states: The text / context stream output for dual-stream attention processors. ``None`` for + single-stream. + """ + + hidden_states: torch.Tensor = None + encoder_hidden_states: torch.Tensor | None = None diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 6464efb6331b..f2b5b4ccd822 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -22,7 +22,7 @@ from packaging import version from . import logging -from .import_utils import is_peft_available, is_torch_available +from .import_utils import is_peft_available, is_peft_version, is_torch_available from .torch_utils import empty_device_cache @@ -292,3 +292,125 @@ def check_peft_version(min_version: str) -> None: f"The version of PEFT you are using is not compatible, please use a version that is greater" f" than {min_version}" ) + + +def get_peft_kwargs( + rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None +): + rank_pattern = {} + alpha_pattern = {} + r = lora_alpha = list(rank_dict.values())[0] + + if len(set(rank_dict.values())) > 1: + # get the rank occurring the most number of times + r = collections.Counter(rank_dict.values()).most_common()[0][0] + + # for modules with rank different from the most occurring rank, add it to the `rank_pattern` + rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} + + if network_alpha_dict is not None and len(network_alpha_dict) > 0: + if len(set(network_alpha_dict.values())) > 1: + # get the alpha occurring the 
most number of times + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + + # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + if is_unet: + alpha_pattern = { + ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v + for k, v in alpha_pattern.items() + } + else: + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + else: + lora_alpha = set(network_alpha_dict.values()).pop() + + target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) + use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) + lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) + + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": target_modules, + "use_dora": use_dora, + "lora_bias": lora_bias, + } + + return lora_config_kwargs + + +def _create_lora_config( + state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None +): + from peft import LoraConfig + + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs( + rank_pattern_dict, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + is_unet=is_unet, + model_state_dict=model_state_dict, + adapter_name=adapter_name, + ) + + _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) + + if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.") + + if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError("lora_bias requires PEFT >= 0.14.0. 
Please upgrade.") + + try: + return LoraConfig(**lora_config_kwargs) + except TypeError as e: + raise TypeError("`LoraConfig` class could not be instantiated.") from e + + +def _maybe_raise_error_for_ambiguous_keys(config): + rank_pattern = config["rank_pattern"].copy() + target_modules = config["target_modules"] + + for key in list(rank_pattern.keys()): + exact_matches = [mod for mod in target_modules if mod == key] + substring_matches = [mod for mod in target_modules if key in mod and mod != key] + + if exact_matches and substring_matches: + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." + ) + + +def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): + warn_msg = "" + if incompatible_keys is not None: + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." 
+ ) + + if warn_msg: + logger.warning(warn_msg) From d73d98555a811821af5163701db03330a6a762b6 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 19 May 2026 17:43:19 +0530 Subject: [PATCH 07/21] update --- src/diffusers/loaders/weight_mapping.py | 57 +++++++++++++------ src/diffusers/models/modeling_utils.py | 41 +++++++++---- src/diffusers/models/transformers/__init__.py | 4 +- .../transformers/flux/weight_mapping.py | 14 +++-- src/diffusers/models/transformers/utils.py | 4 +- 5 files changed, 83 insertions(+), 37 deletions(-) diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index b39ec5100697..d010a1f6d750 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -38,10 +38,11 @@ class WeightMappingMixin: _checkpoint_key_prefixes: list = [] _checkpoint_keys: set = set() _rename_patterns: dict = {} - _model_variants: dict = {} + _available_configs: dict = {} _map_to_diffusers = None _map_from_diffusers = None - _detect_model_variant_fn = None + _detect_config_fn = None + _default_config: Optional[str] = None _default_subfolder: str = "transformer" @staticmethod @@ -74,25 +75,49 @@ def _is_original_format(cls, state_dict: dict) -> bool: return bool(cls._checkpoint_keys & set(state_dict.keys())) @classmethod - def _detect_model_variant(cls, state_dict: dict) -> Optional[str]: - """Detect which model variant a state_dict belongs to. + def _detect_config(cls, state_dict: dict) -> Optional[str]: + """Detect which config name from ``_available_configs`` matches this state_dict. - Dispatches to ``cls._detect_model_variant_fn`` (mirrored from the model's metadata); raises if no detector is - registered. + Dispatches to ``cls._detect_config_fn`` (mirrored from the model's metadata). If no detector is registered, + returns ``None`` so the caller can fall back to ``_default_config``. 
""" - if cls._detect_model_variant_fn is None: - raise NotImplementedError( - f"{cls.__name__} did not register a `_detect_model_variant_fn` in its WeightMappingMetadata." - ) - return cls._detect_model_variant_fn(cls, state_dict) + if cls._detect_config_fn is None: + return None + return cls._detect_config_fn(cls, state_dict) @classmethod def _get_model_config(cls, state_dict: dict) -> str: - """Get the default config repo for the detected variant.""" - variant = cls._detect_model_variant(state_dict) - if variant is None: - raise ValueError(f"Could not detect model variant from state_dict. Expected keys: {cls._checkpoint_keys}") - return cls._model_variants[variant] + """Resolve the hub repo id whose config best matches this checkpoint. + + Resolution order: + 1. Run ``_detect_config_fn`` (if registered) against the state_dict; it should return a config name from + ``_available_configs`` or ``None``. + 2. If detection returns ``None`` (or no detector is registered), fall back to ``_default_config``. + 3. Look up the chosen name in ``_available_configs`` to get the hub repo id. + """ + config_name = cls._detect_config(state_dict) or cls._default_config + if config_name is None: + available = sorted(cls._available_configs) or "" + has_detector = cls._detect_config_fn is not None + raise ValueError( + f"`{cls.__name__}.from_single_file` could not determine which config to load for this checkpoint.\n" + f"\n" + f" Detection: {'registered, but returned None for this state_dict' if has_detector else 'no `_detect_config_fn` registered'}\n" + f" Default config: not set\n" + f" Available configs: {available}\n" + f"\n" + f"To fix this, either:\n" + f' - pass `config=""` to `from_single_file(...)` to skip auto-detection, OR\n' + f" - update `{cls.__name__}`'s `WeightMappingMetadata` to register a `_detect_config_fn` that " + f"returns a name from `_available_configs`, and/or set `_default_config` to a name in " + f"`_available_configs`." 
+ ) + if config_name not in cls._available_configs: + raise ValueError( + f"{cls.__name__}: resolved config name '{config_name}' is not a key of `_available_configs` " + f"(available: {sorted(cls._available_configs)})." + ) + return cls._available_configs[config_name] @staticmethod def apply_transforms(state_dict, transforms, rename_patterns, **ctx): diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 76ee096193e6..88215a39c1ef 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -287,18 +287,27 @@ class WeightMappingMetadata: Attributes: _checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is in the original (pre-diffusers) format. - _model_variants: Map of variant name to its default config repo on the Hub. + _available_configs: + Map of short config name to hub repo id (e.g. ``{"flux-dev": "black-forest-labs/FLUX.1-dev"}``). + ``from_single_file`` resolves a config name (via detection or default) to a repo id through this map. + Single-config models can ship a one-entry dict; multi-config models like Flux list all known architectures. + The short names are stable identifiers — useful in detection logic, error messages, and tracebacks — + independent of where the configs are currently hosted on the hub. _map_to_diffusers / _map_from_diffusers: Callables driving the two conversion directions. - _detect_model_variant_fn: Optional ``(cls, state_dict) -> Optional[str]`` for picking the - right variant when a single checkpoint format spans multiple architectures. + _detect_config_fn: Optional ``(cls, state_dict) -> Optional[str]`` returning a config name (key into + ``_available_configs``) or ``None`` to defer to ``_default_config``. + _default_config: Config name (key into ``_available_configs``) used when ``_detect_config_fn`` is + unregistered or returns ``None``. 
Lets single-config models skip detection entirely and multi-config models + declare a "best guess" fallback. _default_subfolder: Default ``subfolder`` to use when fetching configs (e.g. ``"transformer"``). """ _checkpoint_keys: set = field(default_factory=set) - _model_variants: dict[str, str] = field(default_factory=dict) + _available_configs: dict[str, str] = field(default_factory=dict) _map_to_diffusers: Callable | None = None _map_from_diffusers: Callable | None = None - _detect_model_variant_fn: Callable | None = None + _detect_config_fn: Callable | None = None + _default_config: str | None = None _default_subfolder: str = "transformer" @@ -1700,14 +1709,22 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No load_single_file_checkpoint, ) - # `map_to_diffusers` is inherited universally via ``WeightMappingMixin``-via-``ModelMixin``; - # the genuine "is this single-file capable?" signal is whether the model registered a - # ``_map_to_diffusers`` callable in its ``WeightMappingMetadata``. - if getattr(cls, "_map_to_diffusers", None) is None: + # A model is "single-file capable" if its ``WeightMappingMetadata`` declares enough to (a) bring keys to + # diffusers naming and (b) resolve which config to load. The two paths are: + # - converter path: ``_map_to_diffusers`` is registered (full key conversion); + # - declarative path: ``_checkpoint_key_prefixes`` strips a known prefix and the model is otherwise + # in diffusers format (e.g. prefix-only finetunes). + # Either way, ``_available_configs`` must be non-empty so we know which config repo to fetch. + has_converter = getattr(cls, "_map_to_diffusers", None) is not None + has_prefix_only = bool(getattr(cls, "_checkpoint_key_prefixes", None)) + has_available_configs = bool(getattr(cls, "_available_configs", None)) + if not (has_available_configs and (has_converter or has_prefix_only)): raise ValueError( - f"{cls.__name__} does not support `from_single_file`. 
" - f"Register a `_map_to_diffusers` callable in its `WeightMappingMetadata` " - f"(or use `from_pretrained` if the model is in diffusers format)." + f"`{cls.__name__}.from_single_file` is not supported. " + f"The model's `WeightMappingMetadata` must register `_available_configs` (so we know which config " + f"to load) plus at least one of: `_map_to_diffusers` (full key conversion) or " + f"`_checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " + f"foreign prefix). Use `from_pretrained` if the model is already in diffusers format." ) default_subfolder = getattr(cls, "_default_subfolder", None) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d32f0970768f..6ad7f6ae551c 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,8 +37,8 @@ def __getattr__(self, name): def _register_legacy_module_alias(old_name: str, new_name: str) -> None: """Register ``old_name`` as a deprecated alias for the already-loaded ``new_name`` submodule. - Both names are relative to ``diffusers.models.transformers``. The new submodule must already be in - ``sys.modules`` (loaded by a prior ``from . import ...`` in this file). + Both names are relative to ``diffusers.models.transformers``. The new submodule must already be in ``sys.modules`` + (loaded by a prior ``from . import ...`` in this file). 
""" old_dotted = f"{__name__}.{old_name}" target = sys.modules[f"{__name__}.{new_name}"] diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index 748a34fa252f..8f8ef28f13b8 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -264,7 +264,7 @@ def map_from_diffusers( "time_in.in_layer.weight", "double_blocks.0.img_mod.lin.weight", } -_FLUX_MODEL_VARIANTS: dict[str, str] = { +_FLUX_AVAILABLE_CONFIGS: dict[str, str] = { "flux-dev": "black-forest-labs/FLUX.1-dev", "flux-schnell": "black-forest-labs/FLUX.1-schnell", "flux-fill": "black-forest-labs/FLUX.1-Fill-dev", @@ -272,8 +272,8 @@ def map_from_diffusers( } -def detect_model_variant(cls, state_dict: dict[str, Any]) -> str | None: - """Detect which Flux variant a state_dict belongs to (``flux-dev`` / ``-schnell`` / ``-fill`` / ``-depth``). +def detect_config(cls, state_dict: dict[str, Any]) -> str | None: + """Detect which Flux config name matches this state_dict. Receives ``cls`` so it can reuse the model's ``_is_original_format`` / ``_rename_key`` helpers. """ @@ -302,9 +302,13 @@ def detect_model_variant(cls, state_dict: dict[str, Any]) -> str | None: # Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. FLUX_WEIGHT_MAPPING_METADATA = WeightMappingMetadata( _checkpoint_keys=_FLUX_CHECKPOINT_KEYS, - _model_variants=_FLUX_MODEL_VARIANTS, + _available_configs=_FLUX_AVAILABLE_CONFIGS, _map_to_diffusers=map_to_diffusers, _map_from_diffusers=map_from_diffusers, - _detect_model_variant_fn=detect_model_variant, + _detect_config_fn=detect_config, + # Kicks in only when ``detect_config`` returns ``None`` (e.g. the ``img_in`` / ``x_embedder`` key is + # absent so we can't read in_channels). Most Flux checkpoints in the wild are dev-derived, so it's + # the safest fallback config to load. 
+ _default_config="flux-dev", _default_subfolder="transformer", ) diff --git a/src/diffusers/models/transformers/utils.py b/src/diffusers/models/transformers/utils.py index b14ab8d206e2..a99e2fcdc392 100644 --- a/src/diffusers/models/transformers/utils.py +++ b/src/diffusers/models/transformers/utils.py @@ -63,8 +63,8 @@ class TransformerBlockOutput(TransformerModuleOutput): class AttnProcessorOutput(TransformerModuleOutput): """Structured return type for attention-processor ``__call__`` methods. - Replaces the historical pattern of returning a bare tensor for single-stream attention and a bare - ``(hidden_states, encoder_hidden_states)`` tuple for dual-stream attention. Tuple-compatibility inherited from + Replaces the historical pattern of returning a bare tensor for single-stream attention and a bare ``(hidden_states, + encoder_hidden_states)`` tuple for dual-stream attention. Tuple-compatibility inherited from :class:`TransformerModuleOutput`. Attributes: From eefc961b647b0e456c856c9d02bacf2e20b532e2 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 20 May 2026 17:16:36 +0530 Subject: [PATCH 08/21] update --- src/diffusers/loaders/ip_adapter_model.py | 88 ++++- src/diffusers/loaders/lora.py | 102 +++-- src/diffusers/loaders/weight_mapping.py | 216 +++++++---- src/diffusers/models/modeling_utils.py | 367 +++++++++--------- .../models/transformers/flux/__init__.py | 6 +- .../models/transformers/flux/ip_adapter.py | 11 +- .../models/transformers/flux/lora.py | 10 +- .../models/transformers/flux/model.py | 18 +- .../transformers/flux/weight_mapping.py | 44 ++- 9 files changed, 480 insertions(+), 382 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index 695c43af48a5..bcfa4b7f0ccf 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -13,31 +13,84 @@ # limitations under the License. """Model-side IP-Adapter mixin. 
-Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives -here. Per-model conversion lives in a ``IPAdapterMetadata`` declared next to the model — e.g. ``flux/ip_adapter.py`` -exports ``FLUX_IP_ADAPTER_METADATA``, which is then composed into the model's ``ModelMetadata`` and attached via -``@register_model_metadata``. +Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on +:class:`IPAdapterModelMixin`. Per-model conversion lives in a :class:`IPAdapterMetadata` declared next to the model +(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER_METADATA``), composed into the model's ``ModelMetadata``, and +attached as ``cls._ip_adapter`` (an :class:`IPAdapterHandler` instance) by ``@register_metadata``. """ +from typing import Callable, Optional + from ..models.embeddings import MultiIPAdapterImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT -from ..utils import logging +from ..utils import is_torch_version, logging + + +# Local copy to avoid a circular import with ``models.modeling_utils`` — that module's +# end-of-file ``ModelMetadata()._register(ModelMixin)`` call instantiates the default +# ``IPAdapterHandler`` from here, so we can't import back into it during module load. +_LOW_CPU_MEM_USAGE_DEFAULT = is_torch_version(">=", "1.9.0") logger = logging.get_logger(__name__) +class IPAdapterHandler: + """Composition-style holder for a model class's IP-Adapter conversion callables. + + Instances are attached to model classes as ``cls._ip_adapter`` by ``IPAdapterMetadata._register``. The converter + callables receive the model instance because they need to read its config (e.g. ``attn_processors``, + ``inner_dim``). 
+ """ + + def __init__( + self, + *, + convert_attn_to_diffusers: Optional[Callable] = None, + convert_image_proj_to_diffusers: Optional[Callable] = None, + ): + self._convert_attn_fn = convert_attn_to_diffusers + self._convert_image_proj_fn = convert_image_proj_to_diffusers + + @property + def supports_ip_adapter(self) -> bool: + """Whether the model has both converters registered (required to actually load weights).""" + return self._convert_attn_fn is not None and self._convert_image_proj_fn is not None + + def convert_attn_processors(self, model, state_dicts, low_cpu_mem_usage: bool = False): + """Build the attention-processor dict for a list of IP-Adapter state dicts. + + Receives the model so the converter can inspect ``model.attn_processors``, ``model.config``, + ``model.inner_dim``, etc. + """ + if self._convert_attn_fn is None: + raise NotImplementedError( + f"{type(model).__name__} did not register `_convert_ip_adapter_attn_to_diffusers` in its " + f"IPAdapterMetadata." + ) + return self._convert_attn_fn(model, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + def convert_image_proj(self, model, image_proj_state_dict, low_cpu_mem_usage: bool = False): + """Build the image-projection module from a single IP-Adapter state dict.""" + if self._convert_image_proj_fn is None: + raise NotImplementedError( + f"{type(model).__name__} did not register `_convert_ip_adapter_image_proj_to_diffusers` in its " + f"IPAdapterMetadata." + ) + return self._convert_image_proj_fn(model, image_proj_state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + + class IPAdapterModelMixin: - """Metadata-driven IP-Adapter loader for diffusers transformer / UNet models. + """Generic IP-Adapter loader for diffusers transformer / UNet models. - Reads the per-model converters from ``IPAdapterMetadata`` (mirrored onto - ``cls._convert_ip_adapter_attn_to_diffusers`` and ``cls._convert_ip_adapter_image_proj_to_diffusers`` by - ``register_model_metadata``). 
+ The per-model conversion callables live on ``self._ip_adapter`` (an :class:`IPAdapterHandler` composed by the + metadata decorator). This mixin owns only the orchestration: dispatching to the converters, wiring up + ``MultiIPAdapterImageProjection``, and flipping ``encoder_hid_dim_type``. """ - # No-op defaults; populated per-model by ``register_model_metadata``. - _convert_ip_adapter_attn_to_diffusers = None - _convert_ip_adapter_image_proj_to_diffusers = None + # ``_ip_adapter: IPAdapterHandler`` is provided universally by ``ModelMixin`` (set via the default + # ``ModelMetadata()._register(ModelMixin)`` call). Models without IP-Adapter metadata inherit a no-op + # handler; calling ``_load_ip_adapter_weights`` on such a model raises ``NotImplementedError`` from inside + # the handler. def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): """Install IP-Adapter weights on the model. @@ -45,10 +98,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. """ - if ( - self._convert_ip_adapter_attn_to_diffusers is None - or self._convert_ip_adapter_image_proj_to_diffusers is None - ): + if not self._ip_adapter.supports_ip_adapter: raise NotImplementedError( f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterMetadata." 
) @@ -58,12 +108,12 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U self.encoder_hid_proj = None - attn_procs = self._convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + attn_procs = self._ip_adapter.convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) self.set_attn_processor(attn_procs) image_projection_layers = [] for state_dict in state_dicts: - image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + image_projection_layer = self._ip_adapter.convert_image_proj( self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage ) image_projection_layers.append(image_projection_layer) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 4368f75f2ffd..cc969334711b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -84,6 +84,60 @@ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" +class LoRAHandler: + """Composition-style holder for a model class's LoRA conversion configuration. + + Instances are attached to model classes as ``cls._lora`` by ``LoRAMetadata._register``. Owns the foreign-format + detection and conversion logic that the legacy ``LoRAModelMixin`` flattened onto the model class itself. The + mixin's public-facing methods (``load_lora_adapter``, ``fuse_lora``, etc.) stay on the mixin but read their data + from ``self._lora.X`` instead of ``self._X`` flattened attrs. + """ + + def __init__( + self, + *, + format_keys: Optional[Dict[str, Set[str]]] = None, + map_lora_to_diffusers: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None, + ): + self.format_keys = format_keys or {} + self._map_to_diffusers_fn = map_lora_to_diffusers + + def detect_format(self, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]: + """Return the format name (``"kohya"`` etc.) 
matched by ``state_dict``, or ``None``.""" + if not self.format_keys: + return None + keys = set(state_dict) + for fmt, fmt_keys in self.format_keys.items(): + if any(any(fk in k for k in keys) for fk in fmt_keys): + return fmt + return None + + @staticmethod + def normalize_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: + """Rewrite ``.lora_down/.lora_up`` (kohya-ish) to ``.lora_A/.lora_B`` (diffusers).""" + out: Dict[str, "torch.Tensor"] = {} + for k, v in state_dict.items(): + new_k = ( + k.replace(".lora_down.weight", ".lora_A.weight") + .replace(".lora_up.weight", ".lora_B.weight") + .replace(".down.weight", ".lora_A.weight") + .replace(".up.weight", ".lora_B.weight") + ) + out[new_k] = v + return out + + def map_to_diffusers(self, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: + """Canonicalize a LoRA state dict to diffusers naming. + + Default: just normalize suffixes. Models with foreign formats register a converter via + ``LoRAMetadata._map_lora_to_diffusers``. + """ + state_dict = self.normalize_suffixes(state_dict) + if self._map_to_diffusers_fn is None: + return state_dict + return self._map_to_diffusers_fn(state_dict, **kwargs) + + # Per-class hook for expanding adapter weights before activation. Models that need # expansion (currently only UNet variants) register here; everything else falls # through to the identity default so new transformers don't need an entry. @@ -360,51 +414,13 @@ class LoRAModelMixin: - Get a list of the active adapters. """ + # Runtime PEFT state — set during adapter load / hotswap setup. Not part of the metadata-driven config. _hf_peft_config_loaded = False - # kwargs for prepare_model_for_compiled_hotswap, if required _lora_hotswap_kwargs: Optional[dict] = None - # Default class-attribute values; populated per-model by ``register_model_metadata`` - # (or set directly on subclasses for legacy callers). 
- _lora_format_keys: Dict[str, Set[str]] = {} - _map_lora_to_diffusers: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None - - @classmethod - def _detect_lora_format(cls, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]: - """Return the format name (``"kohya"`` etc.) matched by ``state_dict``, or ``None``.""" - if not cls._lora_format_keys: - return None - keys = set(state_dict) - for fmt, fmt_keys in cls._lora_format_keys.items(): - if any(any(fk in k for k in keys) for fk in fmt_keys): - return fmt - return None - - @classmethod - def _normalize_lora_suffixes(cls, state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: - """Rewrite ``.lora_down/.lora_up`` (kohya-ish) to ``.lora_A/.lora_B`` (diffusers).""" - out: Dict[str, "torch.Tensor"] = {} - for k, v in state_dict.items(): - new_k = ( - k.replace(".lora_down.weight", ".lora_A.weight") - .replace(".lora_up.weight", ".lora_B.weight") - .replace(".down.weight", ".lora_A.weight") - .replace(".up.weight", ".lora_B.weight") - ) - out[new_k] = v - return out - - @classmethod - def map_lora_to_diffusers(cls, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: - """Canonicalize a LoRA state dict to diffusers naming. - - Default: just normalize suffixes. Models with foreign formats register a converter via - ``LoRAMetadata._map_lora_to_diffusers`` — the decorator mirrors it onto ``cls._map_lora_to_diffusers``. - """ - state_dict = cls._normalize_lora_suffixes(state_dict) - if cls._map_lora_to_diffusers is None: - return state_dict - return cls._map_lora_to_diffusers(state_dict, **kwargs) + # ``_lora: LoRAHandler`` is provided universally by ``ModelMixin`` (set via the default + # ``ModelMetadata()._register(ModelMixin)`` call). Models without LoRA conversion metadata inherit a no-op + # handler; ``@register_metadata(ModelMetadata(_lora=...))`` overrides it on the subclass. 
@_requires_peft def load_adapter( @@ -541,7 +557,7 @@ def _load_adapter_from_pretrained( model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) state_dict = load_state_dict(model_file) - state_dict = self.map_lora_to_diffusers(state_dict) + state_dict = self._lora.map_to_diffusers(state_dict) if not state_dict: model_class_name = self.__class__.__name__ logger.warning( diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index d010a1f6d750..fecfec4e53cb 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -14,110 +14,172 @@ """Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. -A model declares its mapping in a ``WeightMappingMetadata`` instance (typically in its ``weight_mapping.py`` module) -and attaches it via ``@register_model_metadata(weight_mapping=...)``. This mixin supplies the generic dispatch methods -that read from that metadata. - -The :meth:`apply_transforms` helper drives the forward direction from a single declarative table — see -``models/transformers/flux/weight_mapping.py`` for an example. +A model declares its mapping in a :class:`WeightMappingMetadata` instance (typically in its ``weight_mapping.py`` +module). The ``@register_metadata`` decorator instantiates a :class:`WeightMappingHandler` from that metadata and +attaches it to the model class as ``cls._weight_mapping``. Internal call sites then go through +``self._weight_mapping.X`` (e.g. ``self._weight_mapping.normalize_checkpoint_keys(state_dict)``) instead of flattening +the methods onto the model class itself. + +The :meth:`WeightMappingHandler.apply_transforms` helper drives the forward direction from a single declarative table — +see ``models/transformers/flux/weight_mapping.py`` for an example. 
""" -from typing import Optional +from typing import Callable, Optional -class WeightMappingMixin: - """ - Base mixin providing utilities for checkpoint weight mapping and conversion. +class WeightMappingHandler: + """Composition-style holder for a model class's weight-mapping configuration and helpers. - Per-model configuration (rename patterns, format-identifying keys, conversion callables, etc.) lives in the model's - registered ``WeightMappingMetadata`` — declared in the model's ``weight_mapping.py`` and attached via - ``@register_model_metadata``. This mixin just supplies the dispatch methods. + Instances are attached to model classes as ``cls._weight_mapping`` by ``WeightMappingMetadata._register``. Owns all + the data (available configs, prefixes, rename patterns, converter callables) and all the methods (rename, detect, + normalize) that the legacy ``WeightMappingMixin`` flattened onto the model class. The model class itself no longer + carries those attributes; access is always via ``cls._weight_mapping.X`` / ``self._weight_mapping.X``. """ - # Default class-attribute values; populated per-model by ``register_model_metadata``. 
- _checkpoint_key_prefixes: list = [] - _checkpoint_keys: set = set() - _rename_patterns: dict = {} - _available_configs: dict = {} - _map_to_diffusers = None - _map_from_diffusers = None - _detect_config_fn = None - _default_config: Optional[str] = None - _default_subfolder: str = "transformer" + def __init__( + self, + *, + checkpoint_keys: Optional[set] = None, + checkpoint_key_prefixes: Optional[list] = None, + rename_patterns: Optional[dict] = None, + available_configs: Optional[dict] = None, + default_config: Optional[str] = None, + default_subfolder: str = "transformer", + map_to_diffusers: Optional[Callable] = None, + map_from_diffusers: Optional[Callable] = None, + detect_config_fn: Optional[Callable] = None, + ): + self.checkpoint_keys = checkpoint_keys or set() + self.checkpoint_key_prefixes = checkpoint_key_prefixes or [] + self.rename_patterns = rename_patterns or {} + self.available_configs = available_configs or {} + self.default_config = default_config + self.default_subfolder = default_subfolder + self._map_to_diffusers_fn = map_to_diffusers + self._map_from_diffusers_fn = map_from_diffusers + self._detect_config_fn = detect_config_fn + + # ---- single-file capability ---- + + @property + def supports_single_file(self) -> bool: + """Whether the model has enough metadata to load from a single-file checkpoint. + + Requires ``available_configs`` (so a config repo can be resolved) plus either a converter callable + (``_map_to_diffusers_fn``) or a non-empty ``checkpoint_key_prefixes`` (declarative prefix-only path). 
+ """ + has_normalizer = self._map_to_diffusers_fn is not None or bool(self.checkpoint_key_prefixes) + return bool(self.available_configs) and has_normalizer + + # ---- key utilities ---- @staticmethod - def _rename_key(key: str, patterns: dict) -> str: - """Apply rename patterns to a key.""" + def rename_key(key: str, patterns: dict) -> str: + """Apply rename patterns to a key (first match wins per substring).""" for old, new in patterns.items(): key = key.replace(old, new) return key - @classmethod - def _normalize_checkpoint_keys(cls, state_dict: dict) -> dict: - """Strip known prefixes from state_dict keys.""" - if not cls._checkpoint_key_prefixes: + def is_original_format(self, state_dict: dict) -> bool: + """Check if state_dict is in original (non-diffusers) format by presence of a known foreign key.""" + if not self.checkpoint_keys: + return False + return bool(self.checkpoint_keys & set(state_dict.keys())) + + def normalize_checkpoint_keys(self, state_dict: dict) -> dict: + """Strip known foreign prefixes (e.g. ``model.diffusion_model.``) from state_dict keys.""" + if not self.checkpoint_key_prefixes: return state_dict result = {} for key, value in state_dict.items(): new_key = key - for prefix in cls._checkpoint_key_prefixes: + for prefix in self.checkpoint_key_prefixes: if key.startswith(prefix): new_key = key[len(prefix) :] break result[new_key] = value return result - @classmethod - def _is_original_format(cls, state_dict: dict) -> bool: - """Check if state_dict is in original (non-diffusers) format.""" - if not cls._checkpoint_keys: - return False - return bool(cls._checkpoint_keys & set(state_dict.keys())) + # ---- config resolution ---- - @classmethod - def _detect_config(cls, state_dict: dict) -> Optional[str]: - """Detect which config name from ``_available_configs`` matches this state_dict. + def detect_config(self, state_dict: dict) -> Optional[str]: + """Detect which config name from ``available_configs`` matches this state_dict. 
- Dispatches to ``cls._detect_config_fn`` (mirrored from the model's metadata). If no detector is registered, - returns ``None`` so the caller can fall back to ``_default_config``. + Dispatches to ``self._detect_config_fn(self, state_dict)``. If unregistered, returns ``None`` so the caller can + fall back to ``self.default_config``. """ - if cls._detect_config_fn is None: + if self._detect_config_fn is None: return None - return cls._detect_config_fn(cls, state_dict) + return self._detect_config_fn(self, state_dict) - @classmethod - def _get_model_config(cls, state_dict: dict) -> str: + def get_model_config(self, state_dict: dict) -> str: """Resolve the hub repo id whose config best matches this checkpoint. Resolution order: - 1. Run ``_detect_config_fn`` (if registered) against the state_dict; it should return a config name from - ``_available_configs`` or ``None``. - 2. If detection returns ``None`` (or no detector is registered), fall back to ``_default_config``. - 3. Look up the chosen name in ``_available_configs`` to get the hub repo id. + 1. Run ``detect_config(state_dict)`` (if a detector is registered). + 2. If detection returns ``None``, fall back to ``default_config``. + 3. Look up the chosen name in ``available_configs`` to get the hub repo id. 
""" - config_name = cls._detect_config(state_dict) or cls._default_config + config_name = self.detect_config(state_dict) or self.default_config if config_name is None: - available = sorted(cls._available_configs) or "" - has_detector = cls._detect_config_fn is not None + available = sorted(self.available_configs) or "" + has_detector = self._detect_config_fn is not None raise ValueError( - f"`{cls.__name__}.from_single_file` could not determine which config to load for this checkpoint.\n" - f"\n" - f" Detection: {'registered, but returned None for this state_dict' if has_detector else 'no `_detect_config_fn` registered'}\n" - f" Default config: not set\n" + "Could not determine which config to load for this checkpoint.\n" + "\n" + f" Detection: {'registered, but returned None for this state_dict' if has_detector else 'no detect_config_fn registered'}\n" + " Default config: not set\n" f" Available configs: {available}\n" - f"\n" - f"To fix this, either:\n" - f' - pass `config=""` to `from_single_file(...)` to skip auto-detection, OR\n' - f" - update `{cls.__name__}`'s `WeightMappingMetadata` to register a `_detect_config_fn` that " - f"returns a name from `_available_configs`, and/or set `_default_config` to a name in " - f"`_available_configs`." + "\n" + "To fix this, either:\n" + ' - pass `config=""` to `from_single_file(...)` to skip auto-detection, OR\n' + " - update the model's `WeightMappingMetadata` to register a `_detect_config_fn` that returns a " + "name from `_available_configs`, and/or set `_default_config` to a name in `_available_configs`." ) - if config_name not in cls._available_configs: + if config_name not in self.available_configs: raise ValueError( - f"{cls.__name__}: resolved config name '{config_name}' is not a key of `_available_configs` " - f"(available: {sorted(cls._available_configs)})." + f"Resolved config name '{config_name}' is not a key of `available_configs` " + f"(available: {sorted(self.available_configs)})." 
+ ) - return cls._available_configs[config_name] + return self.available_configs[config_name] + + # ---- conversion ---- + + def map_to_diffusers(self, state_dict: dict, **kwargs) -> dict: + """Convert state_dict from original format to diffusers format. + + No-op (returns ``state_dict`` unchanged) if no converter callable is registered; callers are expected to use + the prefix-only path (via :meth:`normalize_checkpoint_keys`) in that case. + """ + if self._map_to_diffusers_fn is None: + return state_dict + return self._map_to_diffusers_fn(state_dict, **kwargs) + + def maybe_convert_state_dict(self, model, state_dict: dict) -> dict: + """Bring ``state_dict`` to diffusers naming if it isn't already. Two phases: + + 1. :meth:`normalize_checkpoint_keys` — strip known prefixes (idempotent; no-op if none registered). + 2. :meth:`map_to_diffusers` — full key conversion, only invoked if step 1 alone didn't make the keys match the + model's. Skipped (no-op) if no converter callable was registered. + + Idempotent overall: calling twice produces the same result as calling once. + """ + state_dict = self.normalize_checkpoint_keys(state_dict) + model_keys = set(model.state_dict().keys()) + ckpt_keys = set(state_dict.keys()) + # If every model key is present in the checkpoint (exact match or with extras), the extras surface later + # via the missing/unexpected keys report — but no key-renaming pass is needed.
+ if model_keys.issubset(ckpt_keys): + return state_dict + return self.map_to_diffusers(state_dict) + + def map_from_diffusers(self, state_dict: dict, **kwargs) -> dict: + """Convert state_dict from diffusers format to original format.""" + if self._map_from_diffusers_fn is None: + raise NotImplementedError("No `_map_from_diffusers` callable registered for this model.") + return self._map_from_diffusers_fn(state_dict, **kwargs) + + # ---- driver for declarative transforms ---- @staticmethod def apply_transforms(state_dict, transforms, rename_patterns, **ctx): @@ -139,27 +201,9 @@ def apply_transforms(state_dict, transforms, rename_patterns, **ctx): if source in key: tensors = forward_fn(value, **ctx) for target, tensor in zip(targets, tensors): - new_key = WeightMappingMixin._rename_key(key.replace(source, target), rename_patterns) + new_key = WeightMappingHandler.rename_key(key.replace(source, target), rename_patterns) out[new_key] = tensor break else: - out[WeightMappingMixin._rename_key(key, rename_patterns)] = value + out[WeightMappingHandler.rename_key(key, rename_patterns)] = value return out - - @classmethod - def map_to_diffusers(cls, state_dict: dict, **kwargs) -> dict: - """Convert state_dict from original format to diffusers format.""" - if cls._map_to_diffusers is None: - raise NotImplementedError( - f"{cls.__name__} did not register a `_map_to_diffusers` in its WeightMappingMetadata." - ) - return cls._map_to_diffusers(state_dict, **kwargs) - - @classmethod - def map_from_diffusers(cls, state_dict: dict, **kwargs) -> dict: - """Convert state_dict from diffusers format to original format.""" - if cls._map_from_diffusers is None: - raise NotImplementedError( - f"{cls.__name__} did not register a `_map_from_diffusers` in its WeightMappingMetadata." 
- ) - return cls._map_from_diffusers(state_dict, **kwargs) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 88215a39c1ef..0593a8f811ed 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,7 @@ import tempfile from collections import OrderedDict from contextlib import ExitStack, contextmanager, nullcontext -from dataclasses import dataclass, field, fields, is_dataclass +from dataclasses import dataclass, field, fields from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -40,8 +40,6 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..loaders.lora import LoRAModelMixin -from ..loaders.weight_mapping import WeightMappingMixin from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -234,114 +232,37 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) -@dataclass -class LoRAMetadata: - """Per-model LoRA configuration: what foreign formats this model accepts and how to convert them. - - Field names match the legacy ``cls._`` class attributes consumed by ``LoRAModelMixin``, so the decorator can - mirror them 1:1. - - Attributes: - _lora_format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying - key substrings. The first format whose substrings appear in the state dict wins. - _map_lora_to_diffusers: Callable ``(state_dict, **kwargs) -> state_dict`` that rewrites - foreign-format keys to diffusers naming. Called from ``LoRAModelMixin.map_lora_to_diffusers`` after generic - suffix normalization. ``None`` for models that only ingest diffusers-native LoRAs. 
- """ - - _lora_format_keys: dict[str, set] = field(default_factory=dict) - _map_lora_to_diffusers: Callable | None = None - - -@dataclass -class IPAdapterMetadata: - """Per-model IP-Adapter configuration: how to convert IP-Adapter state dicts for this architecture. - - Field names match the legacy ``cls._`` class attributes consumed by ``IPAdapterModelMixin``, so the decorator - can mirror them 1:1. - - Attributes: - _convert_ip_adapter_attn_to_diffusers: Callable - ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the attn-processor - dict ready for ``set_attn_processor``. Receives the model instance because it needs - ``model.attn_processors``, ``model.config``, ``model.inner_dim``, etc. - _convert_ip_adapter_image_proj_to_diffusers: Callable - ``(model, state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image projection layer. - """ - - _convert_ip_adapter_attn_to_diffusers: Callable | None = None - _convert_ip_adapter_image_proj_to_diffusers: Callable | None = None - - -@dataclass -class WeightMappingMetadata: - """Per-model checkpoint conversion metadata for single-file loading. - - Field names match the legacy ``cls._`` class attributes consumed by ``WeightMappingMixin``, so the decorator - can mirror them 1:1. - - Note: per-key rename tables and checkpoint key prefixes live in the model's ``weight_mapping.py`` module as plain - constants (e.g. ``FLUX_RENAME_PATTERNS``). They're consumed directly by the model's ``map_to_diffusers`` / - ``map_from_diffusers`` callables and don't need to be threaded through metadata. - - Attributes: - _checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is - in the original (pre-diffusers) format. - _available_configs: - Map of short config name to hub repo id (e.g. ``{"flux-dev": "black-forest-labs/FLUX.1-dev"}``). - ``from_single_file`` resolves a config name (via detection or default) to a repo id through this map. 
- Single-config models can ship a one-entry dict; multi-config models like Flux list all known architectures. - The short names are stable identifiers — useful in detection logic, error messages, and tracebacks — - independent of where the configs are currently hosted on the hub. - _map_to_diffusers / _map_from_diffusers: Callables driving the two conversion directions. - _detect_config_fn: Optional ``(cls, state_dict) -> Optional[str]`` returning a config name (key into - ``_available_configs``) or ``None`` to defer to ``_default_config``. - _default_config: Config name (key into ``_available_configs``) used when ``_detect_config_fn`` is - unregistered or returns ``None``. Lets single-config models skip detection entirely and multi-config models - declare a "best guess" fallback. - _default_subfolder: Default ``subfolder`` to use when fetching configs (e.g. ``"transformer"``). - """ - - _checkpoint_keys: set = field(default_factory=set) - _available_configs: dict[str, str] = field(default_factory=dict) - _map_to_diffusers: Callable | None = None - _map_from_diffusers: Callable | None = None - _detect_config_fn: Callable | None = None - _default_config: str | None = None - _default_subfolder: str = "transformer" - - @dataclass class ModelMetadata: """ - Metadata describing model capabilities and configuration hints. + Capability flags and subsystem registrations for a diffusers model class. + + Two kinds of fields: - This is NOT configuration (which is saved to config.json and defines architecture). This is static metadata about - the model class's capabilities and hints for optimization features like gradient checkpointing, offloading, and - parallelism. + 1. **Scalar capability flags** (``_supports_gradient_checkpointing``, ``_no_split_modules``, ``_cp_plan``, etc.) — + passive declarative data; mirrored directly onto the model class. + 2. 
**Subsystem handlers** (``_lora``, ``_weight_mapping``, ``_ip_adapter``) — concrete handler instances holding + both data and behavior for foreign-format conversion. The handler is the *runtime* object the model uses; the + model class accesses them as ``cls._lora`` / ``cls._weight_mapping`` / ``cls._ip_adapter``. - Field names match the legacy ``cls._`` class attributes (so the decorator mirrors them 1:1 and existing - consumer code keeps working). + Models declare their full picture in one place via ``@register_metadata(ModelMetadata(...))``. Attributes: - _supports_gradient_checkpointing: Whether the model supports gradient checkpointing - for memory-efficient training. - _no_split_modules: List of module class names that should NOT be split across - devices during model parallelism. - _keep_in_fp32_modules: List of module names to keep in FP32 precision when using - lower precision dtypes for numerical stability. - _skip_layerwise_casting_patterns: Tuple of module name patterns to exclude from - layerwise casting operations. + _supports_gradient_checkpointing: Whether the model supports gradient checkpointing for + memory-efficient training. + _no_split_modules: Module class names that should NOT be split across devices during model + parallelism. + _keep_in_fp32_modules: Module names to keep in FP32 precision when using lower-precision dtypes. + _skip_layerwise_casting_patterns: Patterns to exclude from layerwise casting. _supports_group_offloading: Whether the model supports group offloading. - _repeated_blocks: List of module class names that repeat throughout the model, - useful for optimization and pattern analysis. - _cp_plan: Context parallel sharding plan. Maps model input/output tensor names to - ``ContextParallelInput`` / ``ContextParallelOutput`` declarations. Universal — applies to any - tensor-sharding work, not attention-specific. - _keys_to_ignore_on_load_unexpected: List of keys to ignore when loading - unexpected keys from a checkpoint. 
- _lora: Per-model LoRA loading metadata. See :class:`LoRAMetadata`. - _weight_mapping: Per-model checkpoint conversion metadata. See :class:`WeightMappingMetadata`. + _repeated_blocks: Module class names that repeat throughout the model (used by optimization passes). + _cp_plan: Context-parallel sharding plan mapping input/output tensor names to ``ContextParallelInput`` + / ``ContextParallelOutput`` declarations. Universal — applies to any tensor-sharding work. + _keys_to_ignore_on_load_unexpected: State-dict keys to silently ignore at load time. + _lora: :class:`LoRAHandler` instance owning per-model LoRA format detection and conversion. + _weight_mapping: :class:`WeightMappingHandler` instance owning per-model single-file checkpoint + conversion (prefix-stripping, key remapping, variant detection). + _ip_adapter: :class:`IPAdapterHandler` instance owning per-model IP-Adapter conversion callables. """ _supports_gradient_checkpointing: bool = False @@ -352,28 +273,50 @@ class ModelMetadata: _repeated_blocks: list[str] = field(default_factory=list) _cp_plan: dict[str, Any] | None = None _keys_to_ignore_on_load_unexpected: list[str] | None = None - _lora: LoRAMetadata = field(default_factory=LoRAMetadata) - _ip_adapter: IPAdapterMetadata = field(default_factory=IPAdapterMetadata) - _weight_mapping: WeightMappingMetadata = field(default_factory=WeightMappingMetadata) + # Handler instances; annotated ``Any`` to avoid a top-level import of the handler classes (which would + # create a cycle: ``loaders/{lora,weight_mapping,ip_adapter_model}.py`` already import from ``modeling_utils``). + _lora: Any = field(default_factory=lambda: _default_lora_handler()) + _ip_adapter: Any = field(default_factory=lambda: _default_ip_adapter_handler()) + _weight_mapping: Any = field(default_factory=lambda: _default_weight_mapping_handler()) def _register(self, cls): - """Attach this ``ModelMetadata`` to ``cls`` and mirror leaf fields to legacy class attrs. 
- - Walks nested dataclasses (``_lora``, ``_weight_mapping``, ``_ip_adapter``) so their leaf fields land flat on - ``cls``. Field names already starting with ``_`` map 1:1 (``_lora_format_keys`` → ``cls._lora_format_keys``); - unprefixed names get the underscore added (``rename_patterns`` → ``cls._rename_patterns``). + """Attach this ``ModelMetadata`` to ``cls``: stash for introspection, install handlers, mirror scalars. + + Two steps: + 1. Attach subsystem handlers directly as ``cls._lora`` / ``cls._weight_mapping`` / ``cls._ip_adapter``. + These are the runtime entry points; internal code reaches them via composition (e.g. + ``cls._weight_mapping.normalize_checkpoint_keys(...)``). + 2. Mirror the scalar capability fields (``_supports_gradient_checkpointing``, ``_no_split_modules``, + ``_cp_plan``, etc.) directly onto ``cls`` — consumed by code via bare attribute access. """ cls._model_metadata = self - pending = [self] - while pending: - obj = pending.pop() - for f in fields(obj): - value = getattr(obj, f.name) - if is_dataclass(value): - pending.append(value) - else: - attr = f.name if f.name.startswith("_") else f"_{f.name}" - setattr(cls, attr, value) + # Subsystem handlers attach directly — they're already the runtime objects. + cls._lora = self._lora + cls._weight_mapping = self._weight_mapping + cls._ip_adapter = self._ip_adapter + # Scalar capability fields mirror to ``cls``. 
+ for f in fields(self): + if f.name in {"_lora", "_weight_mapping", "_ip_adapter"}: + continue + setattr(cls, f.name, getattr(self, f.name)) + + +def _default_lora_handler(): + from ..loaders.lora import LoRAHandler + + return LoRAHandler() + + +def _default_ip_adapter_handler(): + from ..loaders.ip_adapter_model import IPAdapterHandler + + return IPAdapterHandler() + + +def _default_weight_mapping_handler(): + from ..loaders.weight_mapping import WeightMappingHandler + + return WeightMappingHandler() def register_metadata(metadata): @@ -397,16 +340,13 @@ def wrap(cls): return wrap -def _should_convert_checkpoint(model_state_dict: dict[str, Any], checkpoint: dict[str, Any]) -> bool: - """Check if checkpoint needs conversion by comparing keys with model state dict.""" - model_state_dict_keys = set(model_state_dict.keys()) - checkpoint_state_dict_keys = set(checkpoint.keys()) - is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys) - is_match = model_state_dict_keys == checkpoint_state_dict_keys - return not (is_subset and is_match) +# Deprecation message reused across the per-backend attention helpers on ``ModelMixin`` (npu / xla / xformers). +# These have been superseded by the unified ``set_attention_backend(...)`` / ``reset_attention_backend()`` API; +# each call site supplies its specific replacement call as ``{replacement}``. +_ATTENTION_API_DEPRECATION_MSG = "`ModelMixin.{name}` is deprecated. Use `{replacement}` instead." -class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixin, PushToHubMixin): +class ModelMixin(torch.nn.Module, ConfigMixin, PushToHubMixin): r""" Base class for all models. @@ -416,40 +356,13 @@ class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, WeightMappingMixi - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. """ + # Non-metadata class attrs. 
Everything that lives on ``ModelMetadata`` (capability flags, partition plans, + # subsystem handlers) is set at the bottom of this module via ``ModelMetadata()._register(ModelMixin)`` so + # ``ModelMetadata``'s dataclass field defaults are the single source of truth. config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _supports_gradient_checkpointing = False - _keys_to_ignore_on_load_unexpected = None - _no_split_modules = None - _keep_in_fp32_modules = None - _skip_layerwise_casting_patterns = None - _supports_group_offloading = True - _repeated_blocks = [] - _cp_plan = None _skip_keys = None - @classmethod - def _maybe_convert_state_dict(cls, model: "ModelMixin", state_dict: dict[str, Any]) -> dict[str, Any]: - """Convert ``state_dict`` from original format to diffusers format if needed. - - Two phases, both declared via the model's :class:`WeightMappingMetadata`: - - 1. ``_normalize_checkpoint_keys`` — strip known prefixes (e.g. ``model.diffusion_model.``). Run - unconditionally; idempotent and a no-op if no prefixes were registered. - 2. ``_map_to_diffusers`` — the actual format converter, only invoked if step 1 alone didn't already make the - keys match. Skipped if no converter was registered (loading then fails downstream with a clearer - key-mismatch error than a deep ``NotImplementedError``). - """ - # Step 1: always strip checkpoint key prefixes — idempotent, no-op if none registered. - state_dict = cls._normalize_checkpoint_keys(state_dict) - if not _should_convert_checkpoint(model.state_dict(), state_dict): - return state_dict - - # Step 2: run the per-model converter. Skip if no metadata-registered converter. - if getattr(cls, "_map_to_diffusers", None) is None: - return state_dict - return cls.map_to_diffusers(state_dict) - def __init__(self): super().__init__() @@ -522,6 +435,14 @@ def set_use_npu_flash_attention(self, valid: bool) -> None: r""" Set the switch for the npu flash attention. 
""" + deprecate( + "ModelMixin.set_use_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_npu_flash_attention", + replacement='set_attention_backend("_native_npu") / reset_attention_backend()', + ), + ) def fn_recursive_set_npu_flash_attention(module: torch.nn.Module): if hasattr(module, "set_use_npu_flash_attention"): @@ -539,6 +460,14 @@ def enable_npu_flash_attention(self) -> None: Enable npu flash attention from torch_npu """ + deprecate( + "ModelMixin.enable_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_npu_flash_attention", + replacement='set_attention_backend("_native_npu")', + ), + ) self.set_use_npu_flash_attention(True) def disable_npu_flash_attention(self) -> None: @@ -546,11 +475,28 @@ def disable_npu_flash_attention(self) -> None: disable npu flash attention from torch_npu """ + deprecate( + "ModelMixin.disable_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_npu_flash_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_npu_flash_attention(False) def set_use_xla_flash_attention( self, use_xla_flash_attention: bool, partition_spec: Callable | None = None, **kwargs ) -> None: + deprecate( + "ModelMixin.set_use_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_xla_flash_attention", + replacement='set_attention_backend("_native_xla") / reset_attention_backend()', + ), + ) + # Recursively walk through all the children. # Any children which exposes the set_use_xla_flash_attention method # gets the message @@ -569,15 +515,40 @@ def enable_xla_flash_attention(self, partition_spec: Callable | None = None, **k r""" Enable the flash attention pallals kernel for torch_xla. 
""" + deprecate( + "ModelMixin.enable_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_xla_flash_attention", + replacement='set_attention_backend("_native_xla")', + ), + ) self.set_use_xla_flash_attention(True, partition_spec, **kwargs) def disable_xla_flash_attention(self): r""" Disable the flash attention pallals kernel for torch_xla. """ + deprecate( + "ModelMixin.disable_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_xla_flash_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_xla_flash_attention(False) def set_use_memory_efficient_attention_xformers(self, valid: bool, attention_op: Callable | None = None) -> None: + deprecate( + "ModelMixin.set_use_memory_efficient_attention_xformers", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_memory_efficient_attention_xformers", + replacement='set_attention_backend("xformers") / reset_attention_backend()', + ), + ) + # Recursively walk through all the children. # Any children which exposes the set_use_memory_efficient_attention_xformers method # gets the message @@ -622,12 +593,28 @@ def enable_xformers_memory_efficient_attention(self, attention_op: Callable | No >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) ``` """ + deprecate( + "ModelMixin.enable_xformers_memory_efficient_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_xformers_memory_efficient_attention", + replacement='set_attention_backend("xformers")', + ), + ) self.set_use_memory_efficient_attention_xformers(True, attention_op) def disable_xformers_memory_efficient_attention(self) -> None: r""" Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). 
""" + deprecate( + "ModelMixin.disable_xformers_memory_efficient_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_xformers_memory_efficient_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_memory_efficient_attention_xformers(False) def enable_layerwise_casting( @@ -1560,8 +1547,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. model._fix_state_dict_keys_on_load(state_dict) - # Convert checkpoint if needed (e.g., original format to diffusers format) - state_dict = cls._maybe_convert_state_dict(model, state_dict) + # Convert checkpoint if needed (e.g., original format to diffusers format). For models that haven't + # registered weight-mapping metadata this is a no-op via the default handler. + state_dict = cls._weight_mapping.maybe_convert_state_dict(model, state_dict) if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] @@ -1709,24 +1697,21 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No load_single_file_checkpoint, ) - # A model is "single-file capable" if its ``WeightMappingMetadata`` declares enough to (a) bring keys to - # diffusers naming and (b) resolve which config to load. The two paths are: - # - converter path: ``_map_to_diffusers`` is registered (full key conversion); - # - declarative path: ``_checkpoint_key_prefixes`` strips a known prefix and the model is otherwise - # in diffusers format (e.g. prefix-only finetunes). - # Either way, ``_available_configs`` must be non-empty so we know which config repo to fetch. 
- has_converter = getattr(cls, "_map_to_diffusers", None) is not None - has_prefix_only = bool(getattr(cls, "_checkpoint_key_prefixes", None)) - has_available_configs = bool(getattr(cls, "_available_configs", None)) - if not (has_available_configs and (has_converter or has_prefix_only)): + # ``cls._weight_mapping`` is the composed ``WeightMappingHandler`` (attached by ``@register_metadata``, + # or the no-op default placeholder on ``ModelMixin``). Bind it once and reuse below. Its + # ``supports_single_file`` property checks that the model declared enough to (a) bring keys to + # diffusers naming (converter or prefix-only path) and (b) resolve which config to load + # (``available_configs`` non-empty). + _weight_mapping = cls._weight_mapping + if not _weight_mapping.supports_single_file: raise ValueError( f"`{cls.__name__}.from_single_file` is not supported. " - f"The model's `WeightMappingMetadata` must register `_available_configs` (so we know which config " - f"to load) plus at least one of: `_map_to_diffusers` (full key conversion) or " - f"`_checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " - f"foreign prefix). Use `from_pretrained` if the model is already in diffusers format." + "The model's `WeightMappingMetadata` must register `_available_configs` (so we know which config " + "to load) plus at least one of: `_map_to_diffusers` (full key conversion) or " + "`_checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " + "foreign prefix). Use `from_pretrained` if the model is already in diffusers format." 
) - default_subfolder = getattr(cls, "_default_subfolder", None) + default_subfolder = _weight_mapping.default_subfolder pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) if pretrained_model_link_or_path is not None: @@ -1745,7 +1730,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - kwargs.pop("device", None) # consumed elsewhere; pop to prevent forwarding disable_mmap = kwargs.pop("disable_mmap", False) device_map = kwargs.pop("device_map", None) @@ -1770,10 +1754,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No **{k: v for k, v in hub_kwargs.items() if k != "subfolder"}, ) - # Normalize state_dict keys (strip known prefixes) if the model defines a normalizer - normalize_fn = getattr(cls, "_normalize_checkpoint_keys", None) - if normalize_fn is not None: - state_dict = normalize_fn(state_dict) + # Normalize state_dict keys via the weight-mapping handler (strip known prefixes; no-op if none registered). + state_dict = _weight_mapping.normalize_checkpoint_keys(state_dict) if quantization_config is not None: hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) @@ -1791,14 +1773,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No "or path to a local Diffusers model repo." ) else: - get_model_config_fn = getattr(cls, "_get_model_config", None) - if get_model_config_fn is None: - raise ValueError( - f"{cls.__name__} does not support automatic config detection. " - f"Please provide a `config` argument or define `_get_model_config` on the model class." 
- ) - default_pretrained_model_config_name = get_model_config_fn(state_dict) - + default_pretrained_model_config_name = _weight_mapping.get_model_config(state_dict) if default_subfolder is not None: hub_kwargs["subfolder"] = default_subfolder @@ -1822,8 +1797,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No with ctx(): model = cls.from_config(diffusers_model_config) - model_state_dict = model.state_dict() - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") ) @@ -1834,8 +1807,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No else: keep_in_fp32_modules = [] - if _should_convert_checkpoint(model_state_dict, state_dict): - state_dict = cls.map_to_diffusers(state_dict) + # ``normalize_checkpoint_keys`` already ran earlier (before model creation) for detection; this call is + # idempotent and runs the full converter only if keys still don't match the freshly-built model. + state_dict = _weight_mapping.maybe_convert_state_dict(model, state_dict) if not state_dict: raise SingleFileComponentError( @@ -2506,6 +2480,13 @@ def recursive_find_attn_block(name, module): return state_dict +# Seed ``ModelMixin`` with the default ``ModelMetadata()``: capability flags, partition plans, and subsystem +# handler placeholders. Every subclass inherits these via MRO; ``@register_metadata(ModelMetadata(...))`` on a +# subclass overrides them with model-specific values. Doing this here means the dataclass field defaults are +# the single source of truth — no parallel copy of the same defaults declared in ``ModelMixin``'s class body. 
+ModelMetadata()._register(ModelMixin) + + class LegacyModelMixin(ModelMixin): r""" A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more diff --git a/src/diffusers/models/transformers/flux/__init__.py b/src/diffusers/models/transformers/flux/__init__.py index f7b77eee6486..c4c8707dff45 100644 --- a/src/diffusers/models/transformers/flux/__init__.py +++ b/src/diffusers/models/transformers/flux/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .ip_adapter import FLUX_IP_ADAPTER_METADATA -from .lora import FLUX_LORA_METADATA +from .ip_adapter import FLUX_IP_ADAPTER +from .lora import FLUX_LORA from .model import ( FluxAttention, FluxAttnProcessor, @@ -23,4 +23,4 @@ FluxTransformer2DModel, FluxTransformerBlock, ) -from .weight_mapping import FLUX_WEIGHT_MAPPING_METADATA +from .weight_mapping import FLUX_WEIGHT_MAPPING diff --git a/src/diffusers/models/transformers/flux/ip_adapter.py b/src/diffusers/models/transformers/flux/ip_adapter.py index f6232ed64678..1e7866c8e8e1 100644 --- a/src/diffusers/models/transformers/flux/ip_adapter.py +++ b/src/diffusers/models/transformers/flux/ip_adapter.py @@ -25,9 +25,10 @@ from contextlib import nullcontext +from ....loaders.ip_adapter_model import IPAdapterHandler from ....models.embeddings import ImageProjection from ....models.model_loading_utils import load_model_dict_into_meta -from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, IPAdapterMetadata +from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ....utils import is_accelerate_available, is_torch_version, logging from ....utils.torch_utils import empty_device_cache @@ -134,8 +135,8 @@ def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U return attn_procs -# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. 
-FLUX_IP_ADAPTER_METADATA = IPAdapterMetadata( - _convert_ip_adapter_attn_to_diffusers=convert_attn_processors, - _convert_ip_adapter_image_proj_to_diffusers=convert_image_proj, +# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. +FLUX_IP_ADAPTER = IPAdapterHandler( + convert_attn_to_diffusers=convert_attn_processors, + convert_image_proj_to_diffusers=convert_image_proj, ) diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py index d774fc5fc63e..2d2b7bb3b3ed 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/lora.py @@ -34,8 +34,8 @@ import torch +from ....loaders.lora import LoRAHandler from ....utils import logging, state_dict_all_zero -from ...modeling_utils import LoRAMetadata from .weight_mapping import ( FLUX_QKV_SPLIT_PATTERNS, FLUX_QKVMLP_SPLIT_PATTERN, @@ -433,8 +433,8 @@ def map_lora_to_diffusers(state_dict, **kwargs): return state_dict -# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. -FLUX_LORA_METADATA = LoRAMetadata( - _lora_format_keys=_FLUX_LORA_FORMAT_KEYS, - _map_lora_to_diffusers=map_lora_to_diffusers, +# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. 
+FLUX_LORA = LoRAHandler( + format_keys=_FLUX_LORA_FORMAT_KEYS, + map_lora_to_diffusers=map_lora_to_diffusers, ) diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 1112b3feacc8..0afc2b9f76c5 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -23,6 +23,7 @@ from ....configuration_utils import register_to_config from ....hooks._helpers import TransformerBlockMetadata from ....loaders.ip_adapter_model import IPAdapterModelMixin +from ....loaders.lora import LoRAModelMixin from ....utils import apply_lora_scale, logging from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -38,9 +39,9 @@ from ...modeling_outputs import Transformer2DModelOutput from ...modeling_utils import ModelMetadata, ModelMixin, register_metadata from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from .ip_adapter import FLUX_IP_ADAPTER_METADATA -from .lora import FLUX_LORA_METADATA -from .weight_mapping import FLUX_WEIGHT_MAPPING_METADATA +from .ip_adapter import FLUX_IP_ADAPTER +from .lora import FLUX_LORA +from .weight_mapping import FLUX_WEIGHT_MAPPING logger = logging.get_logger(__name__) @@ -528,7 +529,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -FLUX_MODEL_METADATA = ModelMetadata( +_METADATA = ModelMetadata( _supports_gradient_checkpointing=True, _no_split_modules=["FluxTransformerBlock", "FluxSingleTransformerBlock"], _skip_layerwise_casting_patterns=("pos_embed", "norm"), @@ -542,18 +543,19 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: }, "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), }, - _lora=FLUX_LORA_METADATA, - _weight_mapping=FLUX_WEIGHT_MAPPING_METADATA, - _ip_adapter=FLUX_IP_ADAPTER_METADATA, + _lora=FLUX_LORA, + _weight_mapping=FLUX_WEIGHT_MAPPING, + 
_ip_adapter=FLUX_IP_ADAPTER, ) -@register_metadata(FLUX_MODEL_METADATA) +@register_metadata(_METADATA) class FluxTransformer2DModel( ModelMixin, AttentionMixin, CacheMixin, IPAdapterModelMixin, + LoRAModelMixin, ): """ The Transformer model introduced in Flux. diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index 8f8ef28f13b8..2d8b128d98dc 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -16,8 +16,7 @@ import torch -from ....loaders.weight_mapping import WeightMappingMixin -from ...modeling_utils import WeightMappingMetadata +from ....loaders.weight_mapping import WeightMappingHandler def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor: @@ -148,7 +147,9 @@ def map_to_diffusers( ) -> dict[str, torch.Tensor]: """Convert a Flux transformer state_dict from original format to diffusers format.""" inner_dim = _get_inner_dim(state_dict) - return WeightMappingMixin.apply_transforms(state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim) + return WeightMappingHandler.apply_transforms( + state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim + ) # Build reverse patterns for map_from_diffusers @@ -204,7 +205,7 @@ def map_from_diffusers( if target in base_key: base_key = base_key.replace(target, qkv_pattern) break - orig_key = WeightMappingMixin._rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) + orig_key = WeightMappingHandler.rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) if orig_key not in qkv_groups: qkv_groups[orig_key] = [] @@ -216,7 +217,7 @@ def map_from_diffusers( for target in FLUX_QKVMLP_TARGETS: if target in key and "single_transformer_blocks." 
in key: base_key = key.replace(target, FLUX_QKVMLP_SPLIT_PATTERN) - orig_key = WeightMappingMixin._rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) + orig_key = WeightMappingHandler.rename_key(base_key, FLUX_RENAME_PATTERNS_REVERSE) if orig_key not in qkvmlp_groups: qkvmlp_groups[orig_key] = [] @@ -228,7 +229,7 @@ def map_from_diffusers( continue # Standard rename - new_key = WeightMappingMixin._rename_key(key, FLUX_RENAME_PATTERNS_REVERSE) + new_key = WeightMappingHandler.rename_key(key, FLUX_RENAME_PATTERNS_REVERSE) converted_state_dict[new_key] = value # Concatenate QKV groups @@ -272,17 +273,18 @@ def map_from_diffusers( } -def detect_config(cls, state_dict: dict[str, Any]) -> str | None: +def detect_config(weight_mapping, state_dict: dict[str, Any]) -> str | None: """Detect which Flux config name matches this state_dict. - Receives ``cls`` so it can reuse the model's ``_is_original_format`` / ``_rename_key`` helpers. + Receives the :class:`WeightMappingHandler` (not the model class) so it can call ``is_original_format`` and + ``rename_key`` directly on the subsystem that owns them. """ guidance_key = "guidance_in.in_layer.bias" x_embedder_key = "img_in.weight" - if not cls._is_original_format(state_dict): - guidance_key = cls._rename_key(guidance_key, FLUX_RENAME_PATTERNS) - x_embedder_key = cls._rename_key(x_embedder_key, FLUX_RENAME_PATTERNS) + if not weight_mapping.is_original_format(state_dict): + guidance_key = weight_mapping.rename_key(guidance_key, FLUX_RENAME_PATTERNS) + x_embedder_key = weight_mapping.rename_key(x_embedder_key, FLUX_RENAME_PATTERNS) if x_embedder_key not in state_dict: return None @@ -299,16 +301,18 @@ def detect_config(cls, state_dict: dict[str, Any]) -> str | None: return "flux-dev" -# Metadata constant assembled into ``ModelMetadata`` by ``flux/model.py``. 
-FLUX_WEIGHT_MAPPING_METADATA = WeightMappingMetadata( - _checkpoint_keys=_FLUX_CHECKPOINT_KEYS, - _available_configs=_FLUX_AVAILABLE_CONFIGS, - _map_to_diffusers=map_to_diffusers, - _map_from_diffusers=map_from_diffusers, - _detect_config_fn=detect_config, +# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. +FLUX_WEIGHT_MAPPING = WeightMappingHandler( + checkpoint_keys=_FLUX_CHECKPOINT_KEYS, + checkpoint_key_prefixes=_FLUX_CHECKPOINT_KEY_PREFIXES, + rename_patterns=FLUX_RENAME_PATTERNS, + available_configs=_FLUX_AVAILABLE_CONFIGS, + map_to_diffusers=map_to_diffusers, + map_from_diffusers=map_from_diffusers, + detect_config_fn=detect_config, # Kicks in only when ``detect_config`` returns ``None`` (e.g. the ``img_in`` / ``x_embedder`` key is # absent so we can't read in_channels). Most Flux checkpoints in the wild are dev-derived, so it's # the safest fallback config to load. - _default_config="flux-dev", - _default_subfolder="transformer", + default_config="flux-dev", + default_subfolder="transformer", ) From 99ba461731ce560299d5e51146a3dc8ee886457c Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 20 May 2026 18:13:57 +0530 Subject: [PATCH 09/21] update --- src/diffusers/models/modeling_utils.py | 31 +++++-------------- .../models/transformers/flux/model.py | 2 -- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0593a8f811ed..16f84229bb9c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -40,6 +40,9 @@ from .. 
import __version__ from ..configuration_utils import ConfigMixin +from ..loaders.ip_adapter_model import IPAdapterHandler +from ..loaders.lora import LoRAHandler, LoRAModelMixin +from ..loaders.weight_mapping import WeightMappingHandler from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -273,11 +276,9 @@ class ModelMetadata: _repeated_blocks: list[str] = field(default_factory=list) _cp_plan: dict[str, Any] | None = None _keys_to_ignore_on_load_unexpected: list[str] | None = None - # Handler instances; annotated ``Any`` to avoid a top-level import of the handler classes (which would - # create a cycle: ``loaders/{lora,weight_mapping,ip_adapter_model}.py`` already import from ``modeling_utils``). - _lora: Any = field(default_factory=lambda: _default_lora_handler()) - _ip_adapter: Any = field(default_factory=lambda: _default_ip_adapter_handler()) - _weight_mapping: Any = field(default_factory=lambda: _default_weight_mapping_handler()) + _lora: LoRAHandler = field(default_factory=LoRAHandler) + _ip_adapter: IPAdapterHandler = field(default_factory=IPAdapterHandler) + _weight_mapping: WeightMappingHandler = field(default_factory=WeightMappingHandler) def _register(self, cls): """Attach this ``ModelMetadata`` to ``cls``: stash for introspection, install handlers, mirror scalars. @@ -301,24 +302,6 @@ def _register(self, cls): setattr(cls, f.name, getattr(self, f.name)) -def _default_lora_handler(): - from ..loaders.lora import LoRAHandler - - return LoRAHandler() - - -def _default_ip_adapter_handler(): - from ..loaders.ip_adapter_model import IPAdapterHandler - - return IPAdapterHandler() - - -def _default_weight_mapping_handler(): - from ..loaders.weight_mapping import WeightMappingHandler - - return WeightMappingHandler() - - def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. 
@@ -346,7 +329,7 @@ def wrap(cls): _ATTENTION_API_DEPRECATION_MSG = "`ModelMixin.{name}` is deprecated. Use `{replacement}` instead." -class ModelMixin(torch.nn.Module, ConfigMixin, PushToHubMixin): +class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, PushToHubMixin): r""" Base class for all models. diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 0afc2b9f76c5..321deeccaeb4 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -23,7 +23,6 @@ from ....configuration_utils import register_to_config from ....hooks._helpers import TransformerBlockMetadata from ....loaders.ip_adapter_model import IPAdapterModelMixin -from ....loaders.lora import LoRAModelMixin from ....utils import apply_lora_scale, logging from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -555,7 +554,6 @@ class FluxTransformer2DModel( AttentionMixin, CacheMixin, IPAdapterModelMixin, - LoRAModelMixin, ): """ The Transformer model introduced in Flux. 
From 12ae376e7a28ade47d3c15a6084527c7b80ff1b7 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 21 May 2026 17:00:44 +0530 Subject: [PATCH 10/21] update --- src/diffusers/loaders/ip_adapter_model.py | 66 +++--- src/diffusers/loaders/lora.py | 114 +++++---- src/diffusers/loaders/weight_mapping.py | 120 ++++++---- src/diffusers/models/modeling_utils.py | 217 +++++++++++++----- .../models/transformers/flux/ip_adapter.py | 4 +- .../models/transformers/flux/lora.py | 2 +- .../transformers/flux/weight_mapping.py | 4 +- 7 files changed, 334 insertions(+), 193 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index bcfa4b7f0ccf..13f2fd2835c8 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Model-side IP-Adapter mixin. +"""Model-side IP-Adapter machinery. Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on -:class:`IPAdapterModelMixin`. Per-model conversion lives in a :class:`IPAdapterMetadata` declared next to the model -(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER_METADATA``), composed into the model's ``ModelMetadata``, and -attached as ``cls._ip_adapter`` (an :class:`IPAdapterHandler` instance) by ``@register_metadata``. +:class:`IPAdapterModelMixin`. Per-model conversion lives in an :class:`IPAdapterHandler` declared next to the model +(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), composed into the model's ``ModelMetadata``, and attached as +``cls._metadata._ip_adapter`` by ``@register_metadata``. 
""" +from dataclasses import dataclass from typing import Callable, Optional from ..models.embeddings import MultiIPAdapterImageProjection @@ -34,27 +35,29 @@ logger = logging.get_logger(__name__) +@dataclass class IPAdapterHandler: """Composition-style holder for a model class's IP-Adapter conversion callables. - Instances are attached to model classes as ``cls._ip_adapter`` by ``IPAdapterMetadata._register``. The converter - callables receive the model instance because they need to read its config (e.g. ``attn_processors``, - ``inner_dim``). + Attached to ``cls._metadata._ip_adapter`` by :meth:`ModelMetadata._register`. The converter callables receive the + model instance because they need to read its config (``attn_processors``, ``inner_dim``, etc.). + + Attributes: + convert_attn_to_diffusers_fn: + Callable ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the + attn-processor dict ready for ``set_attn_processor``. + convert_image_proj_to_diffusers_fn: Callable + ``(model, image_proj_state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image + projection module. 
""" - def __init__( - self, - *, - convert_attn_to_diffusers: Optional[Callable] = None, - convert_image_proj_to_diffusers: Optional[Callable] = None, - ): - self._convert_attn_fn = convert_attn_to_diffusers - self._convert_image_proj_fn = convert_image_proj_to_diffusers + convert_attn_to_diffusers_fn: Optional[Callable] = None + convert_image_proj_to_diffusers_fn: Optional[Callable] = None @property def supports_ip_adapter(self) -> bool: """Whether the model has both converters registered (required to actually load weights).""" - return self._convert_attn_fn is not None and self._convert_image_proj_fn is not None + return self.convert_attn_to_diffusers_fn is not None and self.convert_image_proj_to_diffusers_fn is not None def convert_attn_processors(self, model, state_dicts, low_cpu_mem_usage: bool = False): """Build the attention-processor dict for a list of IP-Adapter state dicts. @@ -62,28 +65,29 @@ def convert_attn_processors(self, model, state_dicts, low_cpu_mem_usage: bool = Receives the model so the converter can inspect ``model.attn_processors``, ``model.config``, ``model.inner_dim``, etc. """ - if self._convert_attn_fn is None: + if self.convert_attn_to_diffusers_fn is None: raise NotImplementedError( - f"{type(model).__name__} did not register `_convert_ip_adapter_attn_to_diffusers` in its " - f"IPAdapterMetadata." + f"{type(model).__name__} did not register `convert_attn_to_diffusers_fn` in its IPAdapterHandler." 
) - return self._convert_attn_fn(model, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + return self.convert_attn_to_diffusers_fn(model, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) def convert_image_proj(self, model, image_proj_state_dict, low_cpu_mem_usage: bool = False): """Build the image-projection module from a single IP-Adapter state dict.""" - if self._convert_image_proj_fn is None: + if self.convert_image_proj_to_diffusers_fn is None: raise NotImplementedError( - f"{type(model).__name__} did not register `_convert_ip_adapter_image_proj_to_diffusers` in its " - f"IPAdapterMetadata." + f"{type(model).__name__} did not register `convert_image_proj_to_diffusers_fn` in its " + f"IPAdapterHandler." ) - return self._convert_image_proj_fn(model, image_proj_state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + return self.convert_image_proj_to_diffusers_fn( + model, image_proj_state_dict, low_cpu_mem_usage=low_cpu_mem_usage + ) class IPAdapterModelMixin: """Generic IP-Adapter loader for diffusers transformer / UNet models. - The per-model conversion callables live on ``self._ip_adapter`` (an :class:`IPAdapterHandler` composed by the - metadata decorator). This mixin owns only the orchestration: dispatching to the converters, wiring up + The per-model conversion callables live on ``self._metadata._ip_adapter`` (an :class:`IPAdapterHandler` composed by + the metadata decorator). This mixin owns only the orchestration: dispatching to the converters, wiring up ``MultiIPAdapterImageProjection``, and flipping ``encoder_hid_dim_type``. """ @@ -98,9 +102,9 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. 
""" - if not self._ip_adapter.supports_ip_adapter: + if not self._metadata._ip_adapter.supports_ip_adapter: raise NotImplementedError( - f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterMetadata." + f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterHandler." ) if not isinstance(state_dicts, list): @@ -108,12 +112,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U self.encoder_hid_proj = None - attn_procs = self._ip_adapter.convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + attn_procs = self._metadata._ip_adapter.convert_attn_processors( + self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage + ) self.set_attn_processor(attn_procs) image_projection_layers = [] for state_dict in state_dicts: - image_projection_layer = self._ip_adapter.convert_image_proj( + image_projection_layer = self._metadata._ip_adapter.convert_image_proj( self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage ) image_projection_layers.append(image_projection_layer) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index cc969334711b..e7cc6379d91f 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -18,6 +18,7 @@ import os from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Callable, Dict, List, Literal, Optional, Set, Union @@ -84,58 +85,52 @@ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" -class LoRAHandler: - """Composition-style holder for a model class's LoRA conversion configuration. +def _normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: + """Rewrite ``.lora_down/.lora_up`` (kohya-ish) suffixes to ``.lora_A/.lora_B`` (diffusers). 
- Instances are attached to model classes as ``cls._lora`` by ``LoRAMetadata._register``. Owns the foreign-format - detection and conversion logic that the legacy ``LoRAModelMixin`` flattened onto the model class itself. The - mixin's public-facing methods (``load_lora_adapter``, ``fuse_lora``, etc.) stay on the mixin but read their data - from ``self._lora.X`` instead of ``self._X`` flattened attrs. + Universal — every LoRA state dict goes through this regardless of model. Module-level so both :class:`LoRAHandler` + (in its ``map_to_diffusers`` dispatcher) and :class:`LoRAModelMixin` (as a public ``normalize_lora_suffixes`` + utility) can call it without circular references. """ + out: Dict[str, "torch.Tensor"] = {} + for k, v in state_dict.items(): + new_k = ( + k.replace(".lora_down.weight", ".lora_A.weight") + .replace(".lora_up.weight", ".lora_B.weight") + .replace(".down.weight", ".lora_A.weight") + .replace(".up.weight", ".lora_B.weight") + ) + out[new_k] = v + return out - def __init__( - self, - *, - format_keys: Optional[Dict[str, Set[str]]] = None, - map_lora_to_diffusers: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None, - ): - self.format_keys = format_keys or {} - self._map_to_diffusers_fn = map_lora_to_diffusers - def detect_format(self, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]: - """Return the format name (``"kohya"`` etc.) matched by ``state_dict``, or ``None``.""" - if not self.format_keys: - return None - keys = set(state_dict) - for fmt, fmt_keys in self.format_keys.items(): - if any(any(fk in k for k in keys) for fk in fmt_keys): - return fmt - return None +@dataclass +class LoRAHandler: + """Composition-style holder for a model class's LoRA conversion configuration. 
- @staticmethod - def normalize_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: - """Rewrite ``.lora_down/.lora_up`` (kohya-ish) to ``.lora_A/.lora_B`` (diffusers).""" - out: Dict[str, "torch.Tensor"] = {} - for k, v in state_dict.items(): - new_k = ( - k.replace(".lora_down.weight", ".lora_A.weight") - .replace(".lora_up.weight", ".lora_B.weight") - .replace(".down.weight", ".lora_A.weight") - .replace(".up.weight", ".lora_B.weight") - ) - out[new_k] = v - return out + Attached to ``cls._metadata._lora`` by :meth:`ModelMetadata._register`. Holds the per-model foreign-format + conversion data. Public conversion utilities (``normalize_lora_suffixes``, ``detect_lora_format``) live on + :class:`LoRAModelMixin` and read from this handler. + + Attributes: + format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying key substrings. The first + format whose substrings appear in the state dict wins. + map_lora_to_diffusers_fn: Callable ``(state_dict, **kwargs) -> state_dict`` that rewrites foreign-format + keys to diffusers naming. ``None`` for models that only ingest diffusers-native LoRAs. + """ + + format_keys: Dict[str, Set[str]] = field(default_factory=dict) + map_lora_to_diffusers_fn: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None def map_to_diffusers(self, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: - """Canonicalize a LoRA state dict to diffusers naming. + """Run the per-model converter (or pass through if none is registered). - Default: just normalize suffixes. Models with foreign formats register a converter via - ``LoRAMetadata._map_lora_to_diffusers``. + Callers are expected to call :meth:`LoRAModelMixin.normalize_lora_suffixes` separately before this — the + kohya-style suffix normalization is universal and isn't this handler's responsibility. 
""" - state_dict = self.normalize_suffixes(state_dict) - if self._map_to_diffusers_fn is None: + if self.map_lora_to_diffusers_fn is None: return state_dict - return self._map_to_diffusers_fn(state_dict, **kwargs) + return self.map_lora_to_diffusers_fn(state_dict, **kwargs) # Per-class hook for expanding adapter weights before activation. Models that need @@ -402,9 +397,9 @@ class LoRAModelMixin: Single mixin for everything LoRA on a diffusers model: PEFT adapter lifecycle (load / fuse / unfuse / set / delete / hotswap) plus foreign-format conversion (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. - Per-model conversion knobs live in a ``LoRAMetadata`` declared in the model's ``lora.py`` (e.g. - ``FLUX_LORA_METADATA``) and attached to the class via ``@register_model_metadata(lora=...)``. The default no-op - path just normalizes ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. + Per-model conversion knobs live in a :class:`LoRAHandler` declared in the model's ``lora.py`` (e.g. ``FLUX_LORA``) + and attached to the class via ``@register_metadata(ModelMetadata(_lora=...))``. The default no-op handler just + normalizes ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. Install the latest version of PEFT, and use this mixin to: @@ -422,6 +417,30 @@ class LoRAModelMixin: # ``ModelMetadata()._register(ModelMixin)`` call). Models without LoRA conversion metadata inherit a no-op # handler; ``@register_metadata(ModelMetadata(_lora=...))`` overrides it on the subclass. + @staticmethod + def normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: + """Rewrite ``.lora_down/.lora_up`` (kohya-ish) suffixes to ``.lora_A/.lora_B`` (diffusers). + + Universal — applies to every LoRA state dict regardless of model. Useful as a standalone utility for callers + that want suffix normalization without running the full ``map_to_diffusers`` pipeline. 
+ """ + return _normalize_lora_suffixes(state_dict) + + def detect_lora_format(self, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]: + """Return the foreign LoRA format name (``"kohya"`` / ``"xlabs"`` / ...) matched by ``state_dict``, + or ``None`` if no registered format matches (e.g. it's already in diffusers naming). + + Reads ``self._metadata._lora.format_keys`` (the per-model registry of identifying key substrings). + """ + format_keys = self._metadata._lora.format_keys + if not format_keys: + return None + keys = set(state_dict) + for fmt, fmt_keys in format_keys.items(): + if any(any(fk in k for k in keys) for fk in fmt_keys): + return fmt + return None + @_requires_peft def load_adapter( self, @@ -557,7 +576,10 @@ def _load_adapter_from_pretrained( model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) state_dict = load_state_dict(model_file) - state_dict = self._lora.map_to_diffusers(state_dict) + # Universal suffix normalization first (kohya-style ``.lora_down/.lora_up`` → ``.lora_A/.lora_B``), then + # run the per-model foreign-format converter (no-op when none is registered). + state_dict = self.normalize_lora_suffixes(state_dict) + state_dict = self._metadata._lora.map_to_diffusers(state_dict) if not state_dict: model_class_name = self.__class__.__name__ logger.warning( diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index fecfec4e53cb..125e1d4c266b 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -14,62 +14,75 @@ """Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. -A model declares its mapping in a :class:`WeightMappingMetadata` instance (typically in its ``weight_mapping.py`` -module). 
The ``@register_metadata`` decorator instantiates a :class:`WeightMappingHandler` from that metadata and -attaches it to the model class as ``cls._weight_mapping``. Internal call sites then go through -``self._weight_mapping.X`` (e.g. ``self._weight_mapping.normalize_checkpoint_keys(state_dict)``) instead of flattening -the methods onto the model class itself. +A model declares its mapping in a :class:`WeightMappingHandler` instance (typically in its ``weight_mapping.py`` +module). The ``@register_metadata`` decorator bundles it into the model's ``ModelMetadata``, reachable as +``cls._metadata._weight_mapping``. Internal call sites go through ``cls._metadata._weight_mapping.X`` (e.g. +``cls._metadata._weight_mapping.normalize_checkpoint_keys(state_dict)``) instead of flattening the methods onto the +model class itself. The :meth:`WeightMappingHandler.apply_transforms` helper drives the forward direction from a single declarative table — see ``models/transformers/flux/weight_mapping.py`` for an example. """ +from dataclasses import dataclass, field from typing import Callable, Optional +from ..utils import logging + +logger = logging.get_logger(__name__) + + +@dataclass class WeightMappingHandler: """Composition-style holder for a model class's weight-mapping configuration and helpers. - Instances are attached to model classes as ``cls._weight_mapping`` by ``WeightMappingMetadata._register``. Owns all - the data (available configs, prefixes, rename patterns, converter callables) and all the methods (rename, detect, - normalize) that the legacy ``WeightMappingMixin`` flattened onto the model class. The model class itself no longer - carries those attributes; access is always via ``cls._weight_mapping.X`` / ``self._weight_mapping.X``. + Attached to ``cls._metadata._weight_mapping`` by :meth:`ModelMetadata._register`. 
Owns all the data (available + configs, prefixes, rename patterns, converter callables) and all the methods (rename, detect, normalize) for + single-file checkpoint loading. Internal callers reach it via ``cls._metadata._weight_mapping.X``. + + Attributes: + checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is in the original + (pre-diffusers) format. + checkpoint_key_prefixes: Foreign prefixes (e.g. ``["model.diffusion_model."]``) the handler will strip via + :meth:`normalize_checkpoint_keys`. Set this on prefix-only models to skip registering a + ``map_to_diffusers_fn`` callable. + rename_patterns: Default rename patterns shared between forward and reverse conversions (consumed by + :meth:`apply_transforms`). + available_configs: + Map of short config name to hub repo id (e.g. ``{"flux-dev": "black-forest-labs/FLUX.1-dev"}``). + default_config: Config name (key into ``available_configs``) used when ``detect_config_fn`` is + unregistered or returns ``None``. + default_subfolder: Default ``subfolder`` to use when fetching configs (e.g. ``"transformer"``). + map_to_diffusers_fn: Callable ``(state_dict, **kwargs) -> state_dict`` performing full key conversion. + ``None`` for prefix-only models. + map_from_diffusers_fn: Reverse callable (diffusers → original format). + detect_config_fn: ``(handler, state_dict) -> Optional[str]`` returning a config name from + ``available_configs``, or ``None`` to fall back to ``default_config``. 
""" - def __init__( - self, - *, - checkpoint_keys: Optional[set] = None, - checkpoint_key_prefixes: Optional[list] = None, - rename_patterns: Optional[dict] = None, - available_configs: Optional[dict] = None, - default_config: Optional[str] = None, - default_subfolder: str = "transformer", - map_to_diffusers: Optional[Callable] = None, - map_from_diffusers: Optional[Callable] = None, - detect_config_fn: Optional[Callable] = None, - ): - self.checkpoint_keys = checkpoint_keys or set() - self.checkpoint_key_prefixes = checkpoint_key_prefixes or [] - self.rename_patterns = rename_patterns or {} - self.available_configs = available_configs or {} - self.default_config = default_config - self.default_subfolder = default_subfolder - self._map_to_diffusers_fn = map_to_diffusers - self._map_from_diffusers_fn = map_from_diffusers - self._detect_config_fn = detect_config_fn + checkpoint_keys: set = field(default_factory=set) + checkpoint_key_prefixes: list = field(default_factory=list) + rename_patterns: dict = field(default_factory=dict) + available_configs: dict = field(default_factory=dict) + default_config: Optional[str] = None + default_subfolder: str = "transformer" + map_to_diffusers_fn: Optional[Callable] = None + map_from_diffusers_fn: Optional[Callable] = None + detect_config_fn: Optional[Callable] = None # ---- single-file capability ---- @property def supports_single_file(self) -> bool: - """Whether the model has enough metadata to load from a single-file checkpoint. + """Whether ``from_single_file(path)`` works for this model with no extra arguments. - Requires ``available_configs`` (so a config repo can be resolved) plus either a converter callable - (``_map_to_diffusers_fn``) or a non-empty ``checkpoint_key_prefixes`` (declarative prefix-only path). + Requires ``default_config`` to be set so config resolution always succeeds (with or without a successful + ``detect_config_fn`` call). 
Models that declare only ``available_configs`` still load via + ``from_single_file(path, config=...)``, but they don't auto-resolve and so don't count as supporting. Key + normalization is all no-op-safe; the architecture-resolution step is the only hard requirement. """ - has_normalizer = self._map_to_diffusers_fn is not None or bool(self.checkpoint_key_prefixes) - return bool(self.available_configs) and has_normalizer + return self.default_config is not None # ---- key utilities ---- @@ -105,25 +118,34 @@ def normalize_checkpoint_keys(self, state_dict: dict) -> dict: def detect_config(self, state_dict: dict) -> Optional[str]: """Detect which config name from ``available_configs`` matches this state_dict. - Dispatches to ``self._detect_config_fn(self, state_dict)``. If unregistered, returns ``None`` so the caller can + Dispatches to ``self.detect_config_fn(self, state_dict)``. If unregistered, returns ``None`` so the caller can fall back to ``self.default_config``. """ - if self._detect_config_fn is None: + if self.detect_config_fn is None: return None - return self._detect_config_fn(self, state_dict) + return self.detect_config_fn(self, state_dict) def get_model_config(self, state_dict: dict) -> str: """Resolve the hub repo id whose config best matches this checkpoint. Resolution order: 1. Run ``detect_config(state_dict)`` (if a detector is registered). - 2. If detection returns ``None``, fall back to ``default_config``. + 2. If detection returns ``None``, fall back to ``default_config`` and warn (since the user is now getting a + config that may not match the checkpoint shape). 3. Look up the chosen name in ``available_configs`` to get the hub repo id. 
""" - config_name = self.detect_config(state_dict) or self.default_config + detected = self.detect_config(state_dict) + if detected is None and self.default_config is not None and self.detect_config_fn is not None: + logger.warning( + f"Could not auto-detect a config for this checkpoint; falling back to default_config=" + f"'{self.default_config}' ({self.available_configs.get(self.default_config)}). " + f"If this is the wrong architecture, pass `config=` to `from_single_file(...)` " + f"explicitly. Known configs: {sorted(self.available_configs)}." + ) + config_name = detected or self.default_config if config_name is None: available = sorted(self.available_configs) or "" - has_detector = self._detect_config_fn is not None + has_detector = self.detect_config_fn is not None raise ValueError( "Could not determine which config to load for this checkpoint.\n" "\n" @@ -133,8 +155,8 @@ def get_model_config(self, state_dict: dict) -> str: "\n" "To fix this, either:\n" ' - pass `config=""` to `from_single_file(...)` to skip auto-detection, OR\n' - " - update the model's `WeightMappingMetadata` to register a `_detect_config_fn` that returns a " - "name from `_available_configs`, and/or set `_default_config` to a name in `_available_configs`." + " - update the model's `WeightMappingHandler` to set `detect_config_fn` (returns a name from " + "`available_configs`), and/or set `default_config` to a name in `available_configs`." ) if config_name not in self.available_configs: raise ValueError( @@ -151,9 +173,9 @@ def map_to_diffusers(self, state_dict: dict, **kwargs) -> dict: No-op (returns ``state_dict`` unchanged) if no converter callable is registered; callers are expected to use the prefix-only path (via :meth:`normalize_checkpoint_keys`) in that case. 
""" - if self._map_to_diffusers_fn is None: + if self.map_to_diffusers_fn is None: return state_dict - return self._map_to_diffusers_fn(state_dict, **kwargs) + return self.map_to_diffusers_fn(state_dict, **kwargs) def maybe_convert_state_dict(self, model, state_dict: dict) -> dict: """Bring ``state_dict`` to diffusers naming if it isn't already. Two phases: @@ -175,9 +197,9 @@ def maybe_convert_state_dict(self, model, state_dict: dict) -> dict: def map_from_diffusers(self, state_dict: dict, **kwargs) -> dict: """Convert state_dict from diffusers format to original format.""" - if self._map_from_diffusers_fn is None: - raise NotImplementedError("No `_map_from_diffusers` callable registered for this model.") - return self._map_from_diffusers_fn(state_dict, **kwargs) + if self.map_from_diffusers_fn is None: + raise NotImplementedError("No `map_from_diffusers_fn` callable registered for this model.") + return self.map_from_diffusers_fn(state_dict, **kwargs) # ---- driver for declarative transforms ---- diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 16f84229bb9c..f3ccdd90f59d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -237,71 +237,105 @@ def _skip_init(*args, **kwargs): @dataclass class ModelMetadata: - """ - Capability flags and subsystem registrations for a diffusers model class. - - Two kinds of fields: - - 1. **Scalar capability flags** (``_supports_gradient_checkpointing``, ``_no_split_modules``, ``_cp_plan``, etc.) — - passive declarative data; mirrored directly onto the model class. - 2. **Subsystem handlers** (``_lora``, ``_weight_mapping``, ``_ip_adapter``) — concrete handler instances holding - both data and behavior for foreign-format conversion. The handler is the *runtime* object the model uses; the - model class accesses them as ``cls._lora`` / ``cls._weight_mapping`` / ``cls._ip_adapter``. 
- - Models declare their full picture in one place via ``@register_metadata(ModelMetadata(...))``. - - Attributes: - _supports_gradient_checkpointing: Whether the model supports gradient checkpointing for - memory-efficient training. - _no_split_modules: Module class names that should NOT be split across devices during model - parallelism. - _keep_in_fp32_modules: Module names to keep in FP32 precision when using lower-precision dtypes. - _skip_layerwise_casting_patterns: Patterns to exclude from layerwise casting. - _supports_group_offloading: Whether the model supports group offloading. - _repeated_blocks: Module class names that repeat throughout the model (used by optimization passes). - _cp_plan: Context-parallel sharding plan mapping input/output tensor names to ``ContextParallelInput`` - / ``ContextParallelOutput`` declarations. Universal — applies to any tensor-sharding work. - _keys_to_ignore_on_load_unexpected: State-dict keys to silently ignore at load time. - _lora: :class:`LoRAHandler` instance owning per-model LoRA format detection and conversion. - _weight_mapping: :class:`WeightMappingHandler` instance owning per-model single-file checkpoint - conversion (prefix-stripping, key remapping, variant detection). - _ip_adapter: :class:`IPAdapterHandler` instance owning per-model IP-Adapter conversion callables. - """ - - _supports_gradient_checkpointing: bool = False - _no_split_modules: list[str] | None = None - _keep_in_fp32_modules: list[str] | None = None - _skip_layerwise_casting_patterns: tuple[str, ...] 
| None = None - _supports_group_offloading: bool = True - _repeated_blocks: list[str] = field(default_factory=list) - _cp_plan: dict[str, Any] | None = None - _keys_to_ignore_on_load_unexpected: list[str] | None = None - _lora: LoRAHandler = field(default_factory=LoRAHandler) - _ip_adapter: IPAdapterHandler = field(default_factory=IPAdapterHandler) - _weight_mapping: WeightMappingHandler = field(default_factory=WeightMappingHandler) + _supports_gradient_checkpointing: bool = field( + default=False, + metadata={"doc": "Whether the model supports gradient checkpointing for memory-efficient training."}, + ) + _no_split_modules: list[str] | None = field( + default=None, + metadata={"doc": "Block class names that must stay on a single device under `device_map='auto'` sharding."}, + ) + _keep_in_fp32_modules: list[str] | None = field( + default=None, + metadata={"doc": "Submodule name patterns that must remain in fp32 even when the model is cast to fp16/bf16."}, + ) + _skip_layerwise_casting_patterns: tuple[str, ...] | None = field( + default=None, + metadata={"doc": "Parameter name substrings excluded from layerwise dtype casting (e.g. embeddings, norms)."}, + ) + _supports_group_offloading: bool = field( + default=True, + metadata={"doc": "Whether the model can be loaded with `enable_group_offload` for CPU/disk-staged inference."}, + ) + _repeated_blocks: list[str] = field( + default_factory=list, + metadata={ + "doc": "Block class names safe to `torch.compile` once and reuse — enables `compile_repeated_blocks`." + }, + ) + _cp_plan: dict[str, Any] | None = field( + default=None, + metadata={ + "label": "supports_context_parallel", + "doc": "Context-parallel I/O plan: which forward inputs to scatter and outputs to gather across CP ranks.", + }, + ) + _keys_to_ignore_on_load_unexpected: list[str] | None = field( + default=None, + metadata={ + "doc": "State-dict keys silently dropped at load time instead of being surfaced as 'unexpected keys'." 
+ }, + ) + _lora: LoRAHandler = field( + default_factory=LoRAHandler, + metadata={"doc": "Foreign-format LoRA detection + conversion (kohya/xlabs/bfl/...) to diffusers naming."}, + ) + _ip_adapter: IPAdapterHandler = field( + default_factory=IPAdapterHandler, + metadata={"doc": "IP-Adapter weight conversion: attn-processor builders and image-projection construction."}, + ) + _weight_mapping: WeightMappingHandler = field( + default_factory=WeightMappingHandler, + metadata={ + "label": "supported_model_types", + "doc": "Single-file checkpoint loading: prefix stripping, key renaming, config auto-detection.", + }, + ) def _register(self, cls): - """Attach this ``ModelMetadata`` to ``cls``: stash for introspection, install handlers, mirror scalars. - - Two steps: - 1. Attach subsystem handlers directly as ``cls._lora`` / ``cls._weight_mapping`` / ``cls._ip_adapter``. - These are the runtime entry points; internal code reaches them via composition (e.g. - ``cls._weight_mapping.normalize_checkpoint_keys(...)``). - 2. Mirror the scalar capability fields (``_supports_gradient_checkpointing``, ``_no_split_modules``, - ``_cp_plan``, etc.) directly onto ``cls`` — consumed by code via bare attribute access. + """Attach this ``ModelMetadata`` to ``cls``. + + ``cls._metadata`` is the canonical umbrella for all subsystem access — internal callers reach the handlers via + ``cls._metadata._lora`` / ``._weight_mapping`` / ``._ip_adapter``. Scalar capability fields + (``_supports_gradient_checkpointing``, ``_no_split_modules``, ``_cp_plan``, etc.) are *additionally* mirrored + directly onto ``cls`` so existing code paths that do bare attribute access (e.g. ``cls._keep_in_fp32_modules`` + in ``from_pretrained``) keep working unchanged. """ - cls._model_metadata = self - # Subsystem handlers attach directly — they're already the runtime objects. 
- cls._lora = self._lora - cls._weight_mapping = self._weight_mapping - cls._ip_adapter = self._ip_adapter - # Scalar capability fields mirror to ``cls``. + cls._metadata = self + # Scalar capability fields mirror directly to ``cls``; handlers stay accessible only via ``cls._metadata``. for f in fields(self): if f.name in {"_lora", "_weight_mapping", "_ip_adapter"}: continue setattr(cls, f.name, getattr(self, f.name)) +def _render_metadata_value(value): + """Format a ``ModelMetadata`` field value for the introspection tables. + + Returns ``None`` to mean "this capability isn't present" — callers can skip the row entirely or render an empty + cell depending on whether they want a per-model view (:meth:`ModelMixin.describe_capabilities`) or a full-schema + view (:meth:`ModelMixin.doc`). + """ + if isinstance(value, LoRAHandler): + formats = sorted(value.format_keys) + return ", ".join(formats) if formats else None + if isinstance(value, WeightMappingHandler): + if not value.supports_single_file: + return None + return ", ".join(sorted(value.available_configs)) + if isinstance(value, IPAdapterHandler): + return "yes" if value.supports_ip_adapter else None + if isinstance(value, bool): + return "yes" if value else None + if value is None: + return None + if isinstance(value, dict): + return "yes" if value else None + if isinstance(value, (list, tuple)): + return ", ".join(map(str, value)) if value else None + return str(value) or None + + def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. @@ -369,6 +403,63 @@ def __getattr__(self, name: str) -> Any: # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module return super().__getattr__(name) + @classmethod + def describe_capabilities(cls) -> None: + """Print a two-column summary of the capabilities this model declares in ``cls._metadata``. 
+ + Only present capabilities are shown — fields that are ``False``, ``None``, empty, or whose handler reports it + doesn't support the subsystem are omitted. Useful as a quick "what does this model support?" introspection at + the REPL. + """ + rows = [] + for f in fields(cls._metadata): + rendered = _render_metadata_value(getattr(cls._metadata, f.name)) + if rendered is None: + continue + label = f.metadata.get("label") or f.name.lstrip("_") + rows.append((label, rendered)) + + title = f"{cls.__name__} capabilities" + if not rows: + print(f"\n{title}\n(none declared)\n") + return + + name_w = max(len(n) for n, _ in rows) + width = max(len(title), name_w + max(len(v) for _, v in rows) + 2) + + print() + print(title) + print("─" * width) + for name, val in rows: + print(f"{name:<{name_w}} {val}") + print() + + @classmethod + def doc(cls, verbose: bool = False) -> None: + """Print every :class:`ModelMetadata` field with its current value. + + Columns: ``field name | current value``. The value column is blank for fields this model leaves at the default. + Pass ``verbose=True`` to also print each field's description. + """ + rows = [] + for f in fields(cls._metadata): + rendered = _render_metadata_value(getattr(cls._metadata, f.name)) or "" + label = f.metadata.get("label") or f.name.lstrip("_") + rows.append((label, rendered, f.metadata.get("doc", ""))) + + name_w = max(len(n) for n, _, _ in rows) + val_w = max(len(v) for _, v, _ in rows) + title = f"{cls.__name__} fields" + print() + print(title) + print("─" * len(title)) + for name, val, doc in rows: + if verbose: + print(f"{name:<{name_w}} {val:<{val_w}} {doc}") + else: + print(f"{name:<{name_w}} {val}") + print() + @property def is_gradient_checkpointing(self) -> bool: """ @@ -1532,7 +1623,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None # Convert checkpoint if needed (e.g., original format to diffusers format). 
For models that haven't # registered weight-mapping metadata this is a no-op via the default handler. - state_dict = cls._weight_mapping.maybe_convert_state_dict(model, state_dict) + state_dict = cls._metadata._weight_mapping.maybe_convert_state_dict(model, state_dict) if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] @@ -1680,18 +1771,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No load_single_file_checkpoint, ) - # ``cls._weight_mapping`` is the composed ``WeightMappingHandler`` (attached by ``@register_metadata``, - # or the no-op default placeholder on ``ModelMixin``). Bind it once and reuse below. Its + # The ``WeightMappingHandler`` is composed under ``cls._metadata._weight_mapping`` (by + # ``@register_metadata``, or the no-op default on ``ModelMixin``). Bind it once and reuse below. Its # ``supports_single_file`` property checks that the model declared enough to (a) bring keys to # diffusers naming (converter or prefix-only path) and (b) resolve which config to load # (``available_configs`` non-empty). - _weight_mapping = cls._weight_mapping + _weight_mapping = cls._metadata._weight_mapping if not _weight_mapping.supports_single_file: raise ValueError( f"`{cls.__name__}.from_single_file` is not supported. " - "The model's `WeightMappingMetadata` must register `_available_configs` (so we know which config " - "to load) plus at least one of: `_map_to_diffusers` (full key conversion) or " - "`_checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " + "The model's `WeightMappingHandler` must register `available_configs` (so we know which config " + "to load) plus at least one of: `map_to_diffusers_fn` (full key conversion) or " + "`checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " "foreign prefix). Use `from_pretrained` if the model is already in diffusers format." 
) default_subfolder = _weight_mapping.default_subfolder diff --git a/src/diffusers/models/transformers/flux/ip_adapter.py b/src/diffusers/models/transformers/flux/ip_adapter.py index 1e7866c8e8e1..00012e2700f1 100644 --- a/src/diffusers/models/transformers/flux/ip_adapter.py +++ b/src/diffusers/models/transformers/flux/ip_adapter.py @@ -137,6 +137,6 @@ def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U # Handler assembled into ``ModelMetadata`` by ``flux/model.py``. FLUX_IP_ADAPTER = IPAdapterHandler( - convert_attn_to_diffusers=convert_attn_processors, - convert_image_proj_to_diffusers=convert_image_proj, + convert_attn_to_diffusers_fn=convert_attn_processors, + convert_image_proj_to_diffusers_fn=convert_image_proj, ) diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py index 2d2b7bb3b3ed..2285f60cde47 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/lora.py @@ -436,5 +436,5 @@ def map_lora_to_diffusers(state_dict, **kwargs): # Handler assembled into ``ModelMetadata`` by ``flux/model.py``. 
FLUX_LORA = LoRAHandler( format_keys=_FLUX_LORA_FORMAT_KEYS, - map_lora_to_diffusers=map_lora_to_diffusers, + map_lora_to_diffusers_fn=map_lora_to_diffusers, ) diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index 2d8b128d98dc..96d7b8b2daf0 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -307,8 +307,8 @@ def detect_config(weight_mapping, state_dict: dict[str, Any]) -> str | None: checkpoint_key_prefixes=_FLUX_CHECKPOINT_KEY_PREFIXES, rename_patterns=FLUX_RENAME_PATTERNS, available_configs=_FLUX_AVAILABLE_CONFIGS, - map_to_diffusers=map_to_diffusers, - map_from_diffusers=map_from_diffusers, + map_to_diffusers_fn=map_to_diffusers, + map_from_diffusers_fn=map_from_diffusers, detect_config_fn=detect_config, # Kicks in only when ``detect_config`` returns ``None`` (e.g. the ``img_in`` / ``x_embedder`` key is # absent so we can't read in_channels). 
Most Flux checkpoints in the wild are dev-derived, so it's From 68a3e9a5f2c26877a942c5e17e8762e2d3ec7d23 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 00:18:03 +0530 Subject: [PATCH 11/21] update --- src/diffusers/loaders/ip_adapter_model.py | 42 ++- src/diffusers/loaders/lora.py | 34 +- src/diffusers/loaders/weight_mapping.py | 110 +++--- src/diffusers/models/attention.py | 15 + src/diffusers/models/cache_utils.py | 14 + src/diffusers/models/modeling_utils.py | 333 ++++++++---------- .../models/transformers/flux/ip_adapter.py | 2 +- .../models/transformers/flux/lora.py | 2 +- .../models/transformers/flux/model.py | 41 +-- .../transformers/flux/weight_mapping.py | 70 ++-- 10 files changed, 330 insertions(+), 333 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index 13f2fd2835c8..fbc9e6867689 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -15,8 +15,8 @@ Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on :class:`IPAdapterModelMixin`. Per-model conversion lives in an :class:`IPAdapterHandler` declared next to the model -(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), composed into the model's ``ModelMetadata``, and attached as -``cls._metadata._ip_adapter`` by ``@register_metadata``. +(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as +``_ip_adapter = FLUX_IP_ADAPTER``. """ from dataclasses import dataclass @@ -26,9 +26,6 @@ from ..utils import is_torch_version, logging -# Local copy to avoid a circular import with ``models.modeling_utils`` — that module's -# end-of-file ``ModelMetadata()._register(ModelMixin)`` call instantiates the default -# ``IPAdapterHandler`` from here, so we can't import back into it during module load. 
_LOW_CPU_MEM_USAGE_DEFAULT = is_torch_version(">=", "1.9.0") @@ -39,7 +36,7 @@ class IPAdapterHandler: """Composition-style holder for a model class's IP-Adapter conversion callables. - Attached to ``cls._metadata._ip_adapter`` by :meth:`ModelMetadata._register`. The converter callables receive the + Attached as the ``_ip_adapter`` class attribute on :class:`IPAdapterModelMixin` (overridden per-model). The converter callables receive the model instance because they need to read its config (``attn_processors``, ``inner_dim``, etc.). Attributes: @@ -86,15 +83,30 @@ def convert_image_proj(self, model, image_proj_state_dict, low_cpu_mem_usage: bo class IPAdapterModelMixin: """Generic IP-Adapter loader for diffusers transformer / UNet models. - The per-model conversion callables live on ``self._metadata._ip_adapter`` (an :class:`IPAdapterHandler` composed by - the metadata decorator). This mixin owns only the orchestration: dispatching to the converters, wiring up + The per-model conversion callables live on ``self._ip_adapter`` (an :class:`IPAdapterHandler` assigned as a class + attribute by the model). This mixin owns only the orchestration: dispatching to the converters, wiring up ``MultiIPAdapterImageProjection``, and flipping ``encoder_hid_dim_type``. """ - # ``_ip_adapter: IPAdapterHandler`` is provided universally by ``ModelMixin`` (set via the default - # ``ModelMetadata()._register(ModelMixin)`` call). Models without IP-Adapter metadata inherit a no-op - # handler; calling ``_load_ip_adapter_weights`` on such a model raises ``NotImplementedError`` from inside - # the handler. + # Per-model IP-Adapter conversion config. Defaults to an empty handler; models that support IP-Adapter assign + # ``_ip_adapter = FLUX_IP_ADAPTER`` (etc.) in their class body. Calling ``_load_ip_adapter_weights`` on a + # model that didn't override raises ``NotImplementedError`` from inside the handler. 
+ _ip_adapter: IPAdapterHandler = IPAdapterHandler() + + @classmethod + def _metadata(cls): + """Contribute the ``ip_adapter`` row to :class:`ModelMetadata` when converters are registered.""" + from ..models.modeling_utils import DOCS_BASE + + if not cls._ip_adapter.supports_ip_adapter: + return {} + return { + "ip_adapter": ( + "yes", + "Supports loading IP-Adapter weights (image-conditioning adapters).", + f"{DOCS_BASE}/using-diffusers/ip_adapter", + ) + } def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): """Install IP-Adapter weights on the model. @@ -102,7 +114,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. """ - if not self._metadata._ip_adapter.supports_ip_adapter: + if not self._ip_adapter.supports_ip_adapter: raise NotImplementedError( f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterHandler." 
) @@ -112,14 +124,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U self.encoder_hid_proj = None - attn_procs = self._metadata._ip_adapter.convert_attn_processors( + attn_procs = self._ip_adapter.convert_attn_processors( self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage ) self.set_attn_processor(attn_procs) image_projection_layers = [] for state_dict in state_dicts: - image_projection_layer = self._metadata._ip_adapter.convert_image_proj( + image_projection_layer = self._ip_adapter.convert_image_proj( self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage ) image_projection_layers.append(image_projection_layer) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index e7cc6379d91f..b4639a7ccf8a 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -108,7 +108,7 @@ def _normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, class LoRAHandler: """Composition-style holder for a model class's LoRA conversion configuration. - Attached to ``cls._metadata._lora`` by :meth:`ModelMetadata._register`. Holds the per-model foreign-format + Attached as the ``_lora`` class attribute on :class:`LoRAModelMixin` (overridden per-model). Holds the per-model foreign-format conversion data. Public conversion utilities (``normalize_lora_suffixes``, ``detect_lora_format``) live on :class:`LoRAModelMixin` and read from this handler. @@ -398,7 +398,7 @@ class LoRAModelMixin: / hotswap) plus foreign-format conversion (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. Per-model conversion knobs live in a :class:`LoRAHandler` declared in the model's ``lora.py`` (e.g. ``FLUX_LORA``) - and attached to the class via ``@register_metadata(ModelMetadata(_lora=...))``. The default no-op handler just + and assigned to the model class as ``_lora = FLUX_LORA``. 
The default no-op handler just normalizes ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. Install the latest version of PEFT, and use this mixin to: @@ -409,13 +409,29 @@ class LoRAModelMixin: - Get a list of the active adapters. """ - # Runtime PEFT state — set during adapter load / hotswap setup. Not part of the metadata-driven config. + # Runtime PEFT state — set during adapter load / hotswap setup. _hf_peft_config_loaded = False _lora_hotswap_kwargs: Optional[dict] = None - # ``_lora: LoRAHandler`` is provided universally by ``ModelMixin`` (set via the default - # ``ModelMetadata()._register(ModelMixin)`` call). Models without LoRA conversion metadata inherit a no-op - # handler; ``@register_metadata(ModelMetadata(_lora=...))`` overrides it on the subclass. + # Per-model LoRA conversion config. Defaults to a no-op handler (only suffix normalization, no foreign-format + # conversion). Models override by assigning ``_lora = FLUX_LORA`` (etc.) in their class body. + _lora: LoRAHandler = LoRAHandler() + + @classmethod + def _metadata(cls): + """Contribute the ``lora_formats`` row to :class:`ModelMetadata` when foreign formats are registered.""" + from ..models.modeling_utils import DOCS_BASE + + formats = sorted(cls._lora.format_keys) + if not formats: + return {} + return { + "lora_formats": ( + ", ".join(formats), + "Foreign LoRA formats this model converts to diffusers naming on load.", + f"{DOCS_BASE}/training/lora", + ) + } @staticmethod def normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: @@ -430,9 +446,9 @@ def detect_lora_format(self, state_dict: Dict[str, "torch.Tensor"]) -> Optional[ """Return the foreign LoRA format name (``"kohya"`` / ``"xlabs"`` / ...) matched by ``state_dict``, or ``None`` if no registered format matches (e.g. it's already in diffusers naming). - Reads ``self._metadata._lora.format_keys`` (the per-model registry of identifying key substrings). 
+ Reads ``self._lora.format_keys`` (the per-model registry of identifying key substrings). """ - format_keys = self._metadata._lora.format_keys + format_keys = self._lora.format_keys if not format_keys: return None keys = set(state_dict) @@ -579,7 +595,7 @@ def _load_adapter_from_pretrained( # Universal suffix normalization first (kohya-style ``.lora_down/.lora_up`` → ``.lora_A/.lora_B``), then # run the per-model foreign-format converter (no-op when none is registered). state_dict = self.normalize_lora_suffixes(state_dict) - state_dict = self._metadata._lora.map_to_diffusers(state_dict) + state_dict = self._lora.map_to_diffusers(state_dict) if not state_dict: model_class_name = self.__class__.__name__ logger.warning( diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index 125e1d4c266b..3afc085e7fba 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -15,13 +15,9 @@ """Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. A model declares its mapping in a :class:`WeightMappingHandler` instance (typically in its ``weight_mapping.py`` -module). The ``@register_metadata`` decorator bundles it into the model's ``ModelMetadata``, reachable as -``cls._metadata._weight_mapping``. Internal call sites go through ``cls._metadata._weight_mapping.X`` (e.g. -``cls._metadata._weight_mapping.normalize_checkpoint_keys(state_dict)``) instead of flattening the methods onto the -model class itself. - -The :meth:`WeightMappingHandler.apply_transforms` helper drives the forward direction from a single declarative table — -see ``models/transformers/flux/weight_mapping.py`` for an example. +module) and assigns it to the class as ``_weight_mapping = FLUX_WEIGHT_MAPPING``. Internal call sites go through +``cls._weight_mapping.X`` (e.g. 
``cls._weight_mapping.normalize_state_dict_keys(state_dict)``) instead of flattening +the methods onto the model class itself. """ from dataclasses import dataclass, field @@ -33,22 +29,31 @@ logger = logging.get_logger(__name__) +# Foreign key prefixes seen across multiple model families' single-file checkpoints. Stripping these is +# universally safe (no model uses them as native diffusers keys), so the handler defaults to removing them on +# every load. Models with additional, family-specific prefixes can extend or override +# ``prefixes_to_remove`` on their handler. +PREFIXES_TO_REMOVE: list[str] = [ + "model.diffusion_model.", +] + + @dataclass class WeightMappingHandler: """Composition-style holder for a model class's weight-mapping configuration and helpers. - Attached to ``cls._metadata._weight_mapping`` by :meth:`ModelMetadata._register`. Owns all the data (available - configs, prefixes, rename patterns, converter callables) and all the methods (rename, detect, normalize) for - single-file checkpoint loading. Internal callers reach it via ``cls._metadata._weight_mapping.X``. + Attached as the ``_weight_mapping`` class attribute on :class:`ModelMixin` (overridden per-model). Owns all + the data (available configs, prefixes, rename patterns, converter callables) and all the methods (rename, + detect, normalize) for single-file checkpoint loading. Internal callers reach it via ``cls._weight_mapping.X``. Attributes: - checkpoint_keys: Distinctive keys whose presence indicates the checkpoint is in the original - (pre-diffusers) format. - checkpoint_key_prefixes: Foreign prefixes (e.g. ``["model.diffusion_model."]``) the handler will strip via - :meth:`normalize_checkpoint_keys`. Set this on prefix-only models to skip registering a - ``map_to_diffusers_fn`` callable. - rename_patterns: Default rename patterns shared between forward and reverse conversions (consumed by - :meth:`apply_transforms`). 
+ original_format_keys: Distinctive keys whose presence indicates the state_dict is in the original + (pre-diffusers) format. Used by :meth:`is_original_format` to decide whether key conversion is + needed. + prefixes_to_remove: Foreign prefixes (e.g. ``["model.diffusion_model."]``) the handler will strip via + :meth:`normalize_state_dict_keys`. Defaults to the shared :data:`PREFIXES_TO_REMOVE` list — most models + only need that. Extend it for family-specific wrappers; prefix-only models can rely on the default and skip + registering a ``map_to_diffusers_fn`` callable. available_configs: Map of short config name to hub repo id (e.g. ``{"flux-dev": "black-forest-labs/FLUX.1-dev"}``). default_config: Config name (key into ``available_configs``) used when ``detect_config_fn`` is @@ -61,9 +66,8 @@ class WeightMappingHandler: ``available_configs``, or ``None`` to fall back to ``default_config``. """ - checkpoint_keys: set = field(default_factory=set) - checkpoint_key_prefixes: list = field(default_factory=list) - rename_patterns: dict = field(default_factory=dict) + original_format_keys: set = field(default_factory=set) + prefixes_to_remove: list = field(default_factory=lambda: list(PREFIXES_TO_REMOVE)) available_configs: dict = field(default_factory=dict) default_config: Optional[str] = None default_subfolder: str = "transformer" @@ -94,19 +98,24 @@ def rename_key(key: str, patterns: dict) -> str: return key def is_original_format(self, state_dict: dict) -> bool: - """Check if state_dict is in original (non-diffusers) format by presence of a known foreign key.""" - if not self.checkpoint_keys: + """Check if state_dict is in the original (pre-diffusers) format by presence of a known marker key. + + Returns ``True`` only when a registered ``original_format_keys`` entry is observed in the state_dict. + Returning ``False`` means "no positive evidence of original format" — empty / unrelated / unknown + state_dicts all fall here. 
Callers treat ``False`` as "proceed with diffusers-native keys." + """ + if not self.original_format_keys: return False - return bool(self.checkpoint_keys & set(state_dict.keys())) + return bool(self.original_format_keys & set(state_dict.keys())) - def normalize_checkpoint_keys(self, state_dict: dict) -> dict: + def normalize_state_dict_keys(self, state_dict: dict) -> dict: """Strip known foreign prefixes (e.g. ``model.diffusion_model.``) from state_dict keys.""" - if not self.checkpoint_key_prefixes: + if not self.prefixes_to_remove: return state_dict result = {} for key, value in state_dict.items(): new_key = key - for prefix in self.checkpoint_key_prefixes: + for prefix in self.prefixes_to_remove: if key.startswith(prefix): new_key = key[len(prefix) :] break @@ -126,18 +135,18 @@ def detect_config(self, state_dict: dict) -> Optional[str]: return self.detect_config_fn(self, state_dict) def get_model_config(self, state_dict: dict) -> str: - """Resolve the hub repo id whose config best matches this checkpoint. + """Resolve the hub repo id whose config best matches this state_dict. Resolution order: 1. Run ``detect_config(state_dict)`` (if a detector is registered). 2. If detection returns ``None``, fall back to ``default_config`` and warn (since the user is now getting a - config that may not match the checkpoint shape). + config that may not match the state_dict shape). 3. Look up the chosen name in ``available_configs`` to get the hub repo id. """ detected = self.detect_config(state_dict) if detected is None and self.default_config is not None and self.detect_config_fn is not None: logger.warning( - f"Could not auto-detect a config for this checkpoint; falling back to default_config=" + f"Could not auto-detect a config for this state_dict; falling back to default_config=" f"'{self.default_config}' ({self.available_configs.get(self.default_config)}). " f"If this is the wrong architecture, pass `config=` to `from_single_file(...)` " f"explicitly. 
Known configs: {sorted(self.available_configs)}." @@ -147,7 +156,7 @@ def get_model_config(self, state_dict: dict) -> str: available = sorted(self.available_configs) or "" has_detector = self.detect_config_fn is not None raise ValueError( - "Could not determine which config to load for this checkpoint.\n" + "Could not determine which config to load for this state_dict.\n" "\n" f" Detection: {'registered, but returned None for this state_dict' if has_detector else 'no detect_config_fn registered'}\n" " Default config: not set\n" @@ -171,7 +180,7 @@ def map_to_diffusers(self, state_dict: dict, **kwargs) -> dict: """Convert state_dict from original format to diffusers format. No-op (returns ``state_dict`` unchanged) if no converter callable is registered; callers are expected to use - the prefix-only path (via :meth:`normalize_checkpoint_keys`) in that case. + the prefix-only path (via :meth:`normalize_state_dict_keys`) in that case. """ if self.map_to_diffusers_fn is None: return state_dict @@ -180,18 +189,18 @@ def map_to_diffusers(self, state_dict: dict, **kwargs) -> dict: def maybe_convert_state_dict(self, model, state_dict: dict) -> dict: """Bring ``state_dict`` to diffusers naming if it isn't already. Two phases: - 1. :meth:`normalize_checkpoint_keys` — strip known prefixes (idempotent; no-op if none registered). + 1. :meth:`normalize_state_dict_keys` — strip known prefixes (idempotent; no-op if none registered). 2. :meth:`map_to_diffusers` — full key conversion, only invoked if step 1 alone didn't make the keys match the model's. Skipped (no-op) if no converter callable was registered. Idempotent overall: calling twice produces the same result as calling once. 
""" - state_dict = self.normalize_checkpoint_keys(state_dict) + state_dict = self.normalize_state_dict_keys(state_dict) model_keys = set(model.state_dict().keys()) - ckpt_keys = set(state_dict.keys()) - # If the model's keys are a (strict) subset of the checkpoint's, the rest is extras we'll surface later + state_dict_keys = set(state_dict.keys()) + # If the model's keys are a (strict) subset of the state_dict's, the rest is extras we'll surface later # via the missing/unexpected keys report — but no key-renaming pass is needed. - if model_keys.issubset(ckpt_keys): + if model_keys.issubset(state_dict_keys): return state_dict return self.map_to_diffusers(state_dict) @@ -200,32 +209,3 @@ def map_from_diffusers(self, state_dict: dict, **kwargs) -> dict: if self.map_from_diffusers_fn is None: raise NotImplementedError("No `map_from_diffusers_fn` callable registered for this model.") return self.map_from_diffusers_fn(state_dict, **kwargs) - - # ---- driver for declarative transforms ---- - - @staticmethod - def apply_transforms(state_dict, transforms, rename_patterns, **ctx): - """Drive a forward state-dict conversion from a list of (source, targets, fn) entries. - - Each entry is a tuple ``(source, targets, forward_fn, reverse_fn)``: - - ``source``: substring matched against each key (with surrounding dots, e.g. ``".img_attn.qkv."``); the - first matching entry wins. - - ``targets``: list of substrings substituted for ``source`` to build the output keys. ``len(targets)`` is - the fan-out (1 for a unary transform, >1 for a split). - - ``forward_fn(value, **ctx) -> list[tensor]`` returns one tensor per target. (``reverse_fn`` is reserved for - a future ``apply_reverse_transforms`` driver.) - - Keys that match no transform get their dots renamed via ``rename_patterns``. 
- """ - out = {} - for key, value in state_dict.items(): - for source, targets, forward_fn, _ in transforms: - if source in key: - tensors = forward_fn(value, **ctx) - for target, tensor in zip(targets, tensors): - new_key = WeightMappingHandler.rename_key(key.replace(source, target), rename_patterns) - out[new_key] = tensor - break - else: - out[WeightMappingHandler.rename_key(key, rename_patterns)] = value - return out diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f4cd1ff6856b..0915b8467400 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -37,6 +37,21 @@ class AttentionMixin: + _supports_attention = True + + @classmethod + def _metadata(cls): + """Contribute the ``attention`` row to :class:`ModelMetadata` when the model inherits :class:`AttentionMixin`.""" + from .modeling_utils import DOCS_BASE + + return { + "attention": ( + "yes", + "Model contains attention modules; supports `set_attention_backend(...)`.", + f"{DOCS_BASE}/optimization/attention_backends", + ) + } + @property def attn_processors(self) -> dict[str, AttentionProcessor]: r""" diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 161fcf426f21..0d8d4d17e84c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -31,6 +31,20 @@ class CacheMixin: """ _cache_config = None + _supports_cache = True + + @classmethod + def _metadata(cls): + """Contribute the ``cache`` row to :class:`ModelMetadata` when the model inherits :class:`CacheMixin`.""" + from .modeling_utils import DOCS_BASE + + return { + "cache": ( + "yes", + "Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.", + f"{DOCS_BASE}/optimization/cache", + ) + } @property def is_cache_enabled(self) -> bool: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f3ccdd90f59d..a494be21c692 100644 --- 
a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,7 @@ import tempfile from collections import OrderedDict from contextlib import ExitStack, contextmanager, nullcontext -from dataclasses import dataclass, field, fields +from dataclasses import dataclass from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -40,8 +40,7 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..loaders.ip_adapter_model import IPAdapterHandler -from ..loaders.lora import LoRAHandler, LoRAModelMixin +from ..loaders.lora import LoRAModelMixin from ..loaders.weight_mapping import WeightMappingHandler from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod @@ -235,119 +234,63 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) -@dataclass -class ModelMetadata: - _supports_gradient_checkpointing: bool = field( - default=False, - metadata={"doc": "Whether the model supports gradient checkpointing for memory-efficient training."}, - ) - _no_split_modules: list[str] | None = field( - default=None, - metadata={"doc": "Block class names that must stay on a single device under `device_map='auto'` sharding."}, - ) - _keep_in_fp32_modules: list[str] | None = field( - default=None, - metadata={"doc": "Submodule name patterns that must remain in fp32 even when the model is cast to fp16/bf16."}, - ) - _skip_layerwise_casting_patterns: tuple[str, ...] | None = field( - default=None, - metadata={"doc": "Parameter name substrings excluded from layerwise dtype casting (e.g. 
embeddings, norms)."}, - ) - _supports_group_offloading: bool = field( - default=True, - metadata={"doc": "Whether the model can be loaded with `enable_group_offload` for CPU/disk-staged inference."}, - ) - _repeated_blocks: list[str] = field( - default_factory=list, - metadata={ - "doc": "Block class names safe to `torch.compile` once and reuse — enables `compile_repeated_blocks`." - }, - ) - _cp_plan: dict[str, Any] | None = field( - default=None, - metadata={ - "label": "supports_context_parallel", - "doc": "Context-parallel I/O plan: which forward inputs to scatter and outputs to gather across CP ranks.", - }, - ) - _keys_to_ignore_on_load_unexpected: list[str] | None = field( - default=None, - metadata={ - "doc": "State-dict keys silently dropped at load time instead of being surfaced as 'unexpected keys'." - }, - ) - _lora: LoRAHandler = field( - default_factory=LoRAHandler, - metadata={"doc": "Foreign-format LoRA detection + conversion (kohya/xlabs/bfl/...) to diffusers naming."}, - ) - _ip_adapter: IPAdapterHandler = field( - default_factory=IPAdapterHandler, - metadata={"doc": "IP-Adapter weight conversion: attn-processor builders and image-projection construction."}, - ) - _weight_mapping: WeightMappingHandler = field( - default_factory=WeightMappingHandler, - metadata={ - "label": "supported_model_types", - "doc": "Single-file checkpoint loading: prefix stripping, key renaming, config auto-detection.", - }, - ) - - def _register(self, cls): - """Attach this ``ModelMetadata`` to ``cls``. - - ``cls._metadata`` is the canonical umbrella for all subsystem access — internal callers reach the handlers via - ``cls._metadata._lora`` / ``._weight_mapping`` / ``._ip_adapter``. Scalar capability fields - (``_supports_gradient_checkpointing``, ``_no_split_modules``, ``_cp_plan``, etc.) are *additionally* mirrored - directly onto ``cls`` so existing code paths that do bare attribute access (e.g. 
``cls._keep_in_fp32_modules`` - in ``from_pretrained``) keep working unchanged. - """ - cls._metadata = self - # Scalar capability fields mirror directly to ``cls``; handlers stay accessible only via ``cls._metadata``. - for f in fields(self): - if f.name in {"_lora", "_weight_mapping", "_ip_adapter"}: - continue - setattr(cls, f.name, getattr(self, f.name)) +# Base URL for the diffusers docs. Used by each mixin's ``_metadata`` classmethod to build per-capability +# docs links. Adjust as the docs layout evolves — links are reader hints, nothing depends on them +# programmatically. +DOCS_BASE = "https://huggingface.co/docs/diffusers/main/en" -def _render_metadata_value(value): - """Format a ``ModelMetadata`` field value for the introspection tables. +@dataclass(frozen=True) +class ModelMetadata: + """Read-only snapshot of a model class's capabilities. + + Constructed by :meth:`ModelMixin.metadata`, which walks ``cls.__mro__`` collecting rows from each mixin's + ``_metadata`` classmethod. Purely a display object — printing it renders a formatted table. + Programmatic handler access (``model._lora``, ``model._weight_mapping``, ``model._ip_adapter``) goes + through the class attributes directly. - Returns ``None`` to mean "this capability isn't present" — callers can skip the row entirely or render an empty - cell depending on whether they want a per-model view (:meth:`ModelMixin.describe_capabilities`) or a full-schema - view (:meth:`ModelMixin.doc`). + ``rows`` maps a capability label to ``(value, description, docs_url)``. ``verbose=True`` (via + ``Model.metadata(verbose=True)``) renders the description and docs link under each row. 
""" - if isinstance(value, LoRAHandler): - formats = sorted(value.format_keys) - return ", ".join(formats) if formats else None - if isinstance(value, WeightMappingHandler): - if not value.supports_single_file: - return None - return ", ".join(sorted(value.available_configs)) - if isinstance(value, IPAdapterHandler): - return "yes" if value.supports_ip_adapter else None - if isinstance(value, bool): - return "yes" if value else None - if value is None: - return None - if isinstance(value, dict): - return "yes" if value else None - if isinstance(value, (list, tuple)): - return ", ".join(map(str, value)) if value else None - return str(value) or None + + rows: dict[str, tuple[str, str, str]] + verbose: bool = False + + def __repr__(self) -> str: + if not self.rows: + return "ModelMetadata(no capabilities declared)" + + name_w = max(len(n) for n in self.rows) + if not self.verbose: + lines = ["ModelMetadata:"] + for label, (value, _doc, _link) in self.rows.items(): + lines.append(f" {label:<{name_w}} {value}") + return "\n".join(lines) + + lines = ["ModelMetadata:"] + for label, (value, doc, link) in self.rows.items(): + lines.append(f" {label:<{name_w}} {value}") + if doc: + lines.append(f" {'':<{name_w}} {doc}") + if link: + lines.append(f" {'':<{name_w}} docs: {link}") + lines.append("") + return "\n".join(lines).rstrip() def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. - Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own attachment logic. Works for both - model-level metadata (``ModelMetadata``) and block-level metadata (``TransformerBlockMetadata``):: + Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own attachment logic. 
Currently used + by :class:`~diffusers.hooks._helpers.TransformerBlockMetadata` to register block-level metadata into + :class:`TransformerBlockRegistry`:: - @register_metadata(FLUX_MODEL_METADATA) class FluxTransformer2DModel(...): + @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) + class FluxTransformerBlock(nn.Module): ... - @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) class - FluxTransformerBlock(nn.Module): - ... + Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and the appropriate + subsystem mixins like :class:`LoRAModelMixin`, :class:`IPAdapterModelMixin`) — no decorator needed. """ def wrap(cls): @@ -369,22 +312,107 @@ class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, PushToHubMixin): [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and saving models. - - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. """ - # Non-metadata class attrs. Everything that lives on ``ModelMetadata`` (capability flags, partition plans, - # subsystem handlers) is set at the bottom of this module via ``ModelMetadata()._register(ModelMixin)`` so - # ``ModelMetadata``'s dataclass field defaults are the single source of truth. config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _skip_keys = None + _supports_gradient_checkpointing: bool = False + _no_split_modules: list[str] | None = None + _keep_in_fp32_modules: list[str] | None = None + _skip_layerwise_casting_patterns: tuple[str, ...] 
| list[str] | None = None + _supports_group_offloading: bool = True + _repeated_blocks: list[str] = [] + _cp_plan: dict[str, Any] | None = None + _keys_to_ignore_on_load_unexpected: list[str] | None = None + _weight_mapping: WeightMappingHandler = WeightMappingHandler() + def __init__(self): super().__init__() self._gradient_checkpointing_func = None + @classmethod + def _metadata(cls) -> dict[str, tuple[str, str, str]]: + """Return ``ModelMixin``-level rows for the metadata snapshot. + + Each row is keyed by capability label and maps to ``(value, description, docs_url)``. Only present + capabilities are returned. See :meth:`metadata` for the aggregated view across all mixins in + ``cls.__mro__``. + """ + rows: dict[str, tuple[str, str, str]] = {} + if cls._supports_gradient_checkpointing: + rows["gradient_checkpointing"] = ( + "yes", + "Trades compute for memory by recomputing activations during backward.", + f"{DOCS_BASE}/optimization/memory#gradient-checkpointing", + ) + if cls._supports_group_offloading: + rows["group_offloading"] = ( + "yes", + "Stage parameter groups on CPU/disk and stream them to the accelerator for inference.", + f"{DOCS_BASE}/optimization/memory#group-offloading", + ) + if cls._no_split_modules: + rows["no_split_modules"] = ( + ", ".join(cls._no_split_modules), + "Block class names that must stay on a single device under `device_map='auto'` sharding.", + f"{DOCS_BASE}/training/distributed_inference#device-map", + ) + if cls._keep_in_fp32_modules: + rows["keep_in_fp32_modules"] = ( + ", ".join(cls._keep_in_fp32_modules), + "Submodule name patterns that remain in fp32 even when the model is cast to fp16/bf16.", + f"{DOCS_BASE}/optimization/fp16#mixed-precision", + ) + if cls._skip_layerwise_casting_patterns: + rows["skip_layerwise_casting_patterns"] = ( + ", ".join(cls._skip_layerwise_casting_patterns), + "Parameter name substrings excluded from layerwise dtype casting (embeddings, norms, ...).", + 
f"{DOCS_BASE}/optimization/memory#layerwise-casting", + ) + if cls._repeated_blocks: + rows["repeated_blocks"] = ( + ", ".join(cls._repeated_blocks), + "Block class names safe to `torch.compile` once and reuse — enables `compile_repeated_blocks`.", + f"{DOCS_BASE}/optimization/torch2.0", + ) + if cls._cp_plan: + rows["context_parallel"] = ( + "yes", + "Forward inputs/outputs are scatter/gathered across context-parallel ranks.", + f"{DOCS_BASE}/training/distributed_inference#context-parallelism", + ) + if cls._weight_mapping.supports_single_file: + rows["supported_model_types"] = ( + ", ".join(sorted(cls._weight_mapping.available_configs)), + "Auto-resolvable configs for `from_single_file(path)` (no `config=` argument required).", + f"{DOCS_BASE}/api/loaders/single_file", + ) + return rows + + @classmethod + def metadata(cls, verbose: bool = False) -> "ModelMetadata": + """Return a read-only snapshot of this class's capabilities. + + Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod + (handled via direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen + wins on label collisions; this puts the model's own overrides ahead of inherited defaults. + + ``print(Model.metadata())`` shows the formatted table. Pass ``verbose=True`` to render each row with + a description and a link to the relevant docs section. + """ + merged: dict[str, tuple[str, str, str]] = {} + for klass in cls.__mro__: + method = klass.__dict__.get("_metadata") + if method is None: + continue + for label, info in method.__func__(cls).items(): + merged.setdefault(label, info) + return ModelMetadata(rows=merged, verbose=verbose) + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite @@ -403,63 +431,6 @@ def __getattr__(self, name: str) -> Any: # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module return super().__getattr__(name) - @classmethod - def describe_capabilities(cls) -> None: - """Print a two-column summary of the capabilities this model declares in ``cls._metadata``. - - Only present capabilities are shown — fields that are ``False``, ``None``, empty, or whose handler reports it - doesn't support the subsystem are omitted. Useful as a quick "what does this model support?" introspection at - the REPL. - """ - rows = [] - for f in fields(cls._metadata): - rendered = _render_metadata_value(getattr(cls._metadata, f.name)) - if rendered is None: - continue - label = f.metadata.get("label") or f.name.lstrip("_") - rows.append((label, rendered)) - - title = f"{cls.__name__} capabilities" - if not rows: - print(f"\n{title}\n(none declared)\n") - return - - name_w = max(len(n) for n, _ in rows) - width = max(len(title), name_w + max(len(v) for _, v in rows) + 2) - - print() - print(title) - print("─" * width) - for name, val in rows: - print(f"{name:<{name_w}} {val}") - print() - - @classmethod - def doc(cls, verbose: bool = False) -> None: - """Print every :class:`ModelMetadata` field with its current value. - - Columns: ``field name | current value``. The value column is blank for fields this model leaves at the default. - Pass ``verbose=True`` to also print each field's description. 
- """ - rows = [] - for f in fields(cls._metadata): - rendered = _render_metadata_value(getattr(cls._metadata, f.name)) or "" - label = f.metadata.get("label") or f.name.lstrip("_") - rows.append((label, rendered, f.metadata.get("doc", ""))) - - name_w = max(len(n) for n, _, _ in rows) - val_w = max(len(v) for _, v, _ in rows) - title = f"{cls.__name__} fields" - print() - print(title) - print("─" * len(title)) - for name, val, doc in rows: - if verbose: - print(f"{name:<{name_w}} {val:<{val_w}} {doc}") - else: - print(f"{name:<{name_w}} {val}") - print() - @property def is_gradient_checkpointing(self) -> bool: """ @@ -1623,7 +1594,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None # Convert checkpoint if needed (e.g., original format to diffusers format). For models that haven't # registered weight-mapping metadata this is a no-op via the default handler. - state_dict = cls._metadata._weight_mapping.maybe_convert_state_dict(model, state_dict) + state_dict = cls._weight_mapping.maybe_convert_state_dict(model, state_dict) if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] @@ -1771,19 +1742,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No load_single_file_checkpoint, ) - # The ``WeightMappingHandler`` is composed under ``cls._metadata._weight_mapping`` (by - # ``@register_metadata``, or the no-op default on ``ModelMixin``). Bind it once and reuse below. Its - # ``supports_single_file`` property checks that the model declared enough to (a) bring keys to - # diffusers naming (converter or prefix-only path) and (b) resolve which config to load - # (``available_configs`` non-empty). - _weight_mapping = cls._metadata._weight_mapping + # The ``WeightMappingHandler`` is attached as ``cls._weight_mapping`` — either overridden by the model + # (e.g. ``_weight_mapping = FLUX_WEIGHT_MAPPING``) or inherited as the empty default from ``ModelMixin``. 
+ # Its ``supports_single_file`` property checks that the model declared a ``default_config`` so config + # resolution always succeeds with no extra args from the user. + _weight_mapping = cls._weight_mapping if not _weight_mapping.supports_single_file: raise ValueError( f"`{cls.__name__}.from_single_file` is not supported. " - "The model's `WeightMappingHandler` must register `available_configs` (so we know which config " - "to load) plus at least one of: `map_to_diffusers_fn` (full key conversion) or " - "`checkpoint_key_prefixes` (prefix-only conversion for diffusers-format checkpoints with a " - "foreign prefix). Use `from_pretrained` if the model is already in diffusers format." + "The model's `WeightMappingHandler` must declare `default_config` (a key into " + "`available_configs`) so we can resolve which architecture to instantiate when the user " + "doesn't pass `config=` explicitly. Use `from_pretrained` if the model is already in " + "diffusers format." ) default_subfolder = _weight_mapping.default_subfolder @@ -1829,7 +1799,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No ) # Normalize state_dict keys via the weight-mapping handler (strip known prefixes; no-op if none registered). - state_dict = _weight_mapping.normalize_checkpoint_keys(state_dict) + state_dict = _weight_mapping.normalize_state_dict_keys(state_dict) if quantization_config is not None: hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) @@ -1881,7 +1851,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No else: keep_in_fp32_modules = [] - # ``normalize_checkpoint_keys`` already ran earlier (before model creation) for detection; this call is + # ``normalize_state_dict_keys`` already ran earlier (before model creation) for detection; this call is # idempotent and runs the full converter only if keys still don't match the freshly-built model. 
state_dict = _weight_mapping.maybe_convert_state_dict(model, state_dict) @@ -2554,13 +2524,6 @@ def recursive_find_attn_block(name, module): return state_dict -# Seed ``ModelMixin`` with the default ``ModelMetadata()``: capability flags, partition plans, and subsystem -# handler placeholders. Every subclass inherits these via MRO; ``@register_metadata(ModelMetadata(...))`` on a -# subclass overrides them with model-specific values. Doing this here means the dataclass field defaults are -# the single source of truth — no parallel copy of the same defaults declared in ``ModelMixin``'s class body. -ModelMetadata()._register(ModelMixin) - - class LegacyModelMixin(ModelMixin): r""" A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more diff --git a/src/diffusers/models/transformers/flux/ip_adapter.py b/src/diffusers/models/transformers/flux/ip_adapter.py index 00012e2700f1..df7a83aa4200 100644 --- a/src/diffusers/models/transformers/flux/ip_adapter.py +++ b/src/diffusers/models/transformers/flux/ip_adapter.py @@ -135,7 +135,7 @@ def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U return attn_procs -# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. +# Assigned to ``FluxTransformer2DModel`` as the ``_ip_adapter`` class attribute in ``flux/model.py``. FLUX_IP_ADAPTER = IPAdapterHandler( convert_attn_to_diffusers_fn=convert_attn_processors, convert_image_proj_to_diffusers_fn=convert_image_proj, diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/lora.py index 2285f60cde47..a7018d8d3507 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/lora.py @@ -433,7 +433,7 @@ def map_lora_to_diffusers(state_dict, **kwargs): return state_dict -# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. 
+# Assigned to ``FluxTransformer2DModel`` as the ``_lora`` class attribute in ``flux/model.py``. FLUX_LORA = LoRAHandler( format_keys=_FLUX_LORA_FORMAT_KEYS, map_lora_to_diffusers_fn=map_lora_to_diffusers, diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 321deeccaeb4..764f7e52a609 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -36,7 +36,7 @@ get_1d_rotary_pos_embed, ) from ...modeling_outputs import Transformer2DModelOutput -from ...modeling_utils import ModelMetadata, ModelMixin, register_metadata +from ...modeling_utils import ModelMixin, register_metadata from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from .ip_adapter import FLUX_IP_ADAPTER from .lora import FLUX_LORA @@ -528,27 +528,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -_METADATA = ModelMetadata( - _supports_gradient_checkpointing=True, - _no_split_modules=["FluxTransformerBlock", "FluxSingleTransformerBlock"], - _skip_layerwise_casting_patterns=("pos_embed", "norm"), - _repeated_blocks=["FluxTransformerBlock", "FluxSingleTransformerBlock"], - _cp_plan={ - "": { - "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - }, - "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), - }, - _lora=FLUX_LORA, - _weight_mapping=FLUX_WEIGHT_MAPPING, - _ip_adapter=FLUX_IP_ADAPTER, -) - - -@register_metadata(_METADATA) class FluxTransformer2DModel( ModelMixin, AttentionMixin, @@ -586,6 +565,24 @@ class FluxTransformer2DModel( The dimensions to use for the rotary positional 
embeddings. """ + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ("pos_embed", "norm") + _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + _lora = FLUX_LORA + _weight_mapping = FLUX_WEIGHT_MAPPING + _ip_adapter = FLUX_IP_ADAPTER + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index 96d7b8b2daf0..cdd998701b89 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -63,23 +63,16 @@ def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor: # -------------------------------------------------------------------------- # Per-key transforms (split + special), unified. # -------------------------------------------------------------------------- -# Single source of truth. Each entry is -# (source_substring, [target_substrings], forward_fn, reverse_fn) -# - source/targets include surrounding dots so they only match at module -# boundaries (e.g. ".img_attn.qkv." matches both "X.img_attn.qkv.weight" -# and "X.img_attn.qkv.bias" with one entry). +# Each entry is ``(source_substring, [target_substrings], forward_fn)``: +# - source/targets include surrounding dots so they only match at module boundaries +# (e.g. ".img_attn.qkv." matches both "X.img_attn.qkv.weight" and "X.img_attn.qkv.bias"). 
# - len(targets) == 1 -> a unary transform (e.g. AdaLN scale/shift swap). # - len(targets) > 1 -> a split transform (forward chunks the tensor). # - forward_fn(tensor, **ctx) -> list[tensor] of length len(targets). -# - reverse_fn(list[tensor], **ctx) -> tensor. def _swap_to_list(v, **_): return [swap_scale_shift(v)] -def _list_to_swap(vs, **_): - return swap_scale_shift(vs[0]) - - def _make_chunk(n): return lambda v, **_: torch.chunk(v, n, dim=0) @@ -88,20 +81,11 @@ def _qkvmlp_split(v, inner_dim=3072, **_): return torch.split(v, [inner_dim, inner_dim, inner_dim, inner_dim * 4], dim=0) -def _cat0(vs, **_): - return torch.cat(vs, dim=0) - - FLUX_TRANSFORMS = [ - ("final_layer.adaLN_modulation.1.", ["norm_out.linear."], _swap_to_list, _list_to_swap), - (".img_attn.qkv.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v."], _make_chunk(3), _cat0), - ( - ".txt_attn.qkv.", - [".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj."], - _make_chunk(3), - _cat0, - ), - (".linear1.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."], _qkvmlp_split, _cat0), + ("final_layer.adaLN_modulation.1.", ["norm_out.linear."], _swap_to_list), + (".img_attn.qkv.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v."], _make_chunk(3)), + (".txt_attn.qkv.", [".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj."], _make_chunk(3)), + (".linear1.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."], _qkvmlp_split), ] @@ -116,7 +100,7 @@ def _wrap_unary(fwd_fn): FLUX_QKV_SPLIT_PATTERNS: dict[str, list[str]] = {} FLUX_QKVMLP_SPLIT_PATTERN: str = "" FLUX_QKVMLP_TARGETS: list[str] = [] -for _src, _tgts, _fwd, _ in FLUX_TRANSFORMS: +for _src, _tgts, _fwd in FLUX_TRANSFORMS: if len(_tgts) == 1: for _suffix in ("weight", "bias"): FLUX_SPECIAL_KEYS[_src + _suffix] = { @@ -138,18 +122,38 @@ def _get_inner_dim(state_dict: dict[str, torch.Tensor]) -> int: # Total size = 3 * inner_dim + mlp_hidden_dim = 3 * inner_dim + 4 * inner_dim = 7 * inner_dim total = state_dict[key].shape[0] 
return total // 7 + return 3072 # Default +def _apply_transforms(state_dict, transforms, rename_patterns, **ctx): + """Drive a forward state-dict conversion from a list of ``(source, targets, forward_fn)`` entries. + + For each key in ``state_dict``: scan ``transforms``; the first entry whose ``source`` substring matches + expands the value via ``forward_fn(value, **ctx)`` into one tensor per target, each at a key derived by + ``key.replace(source, target)`` then ``rename_patterns``. Keys that match no transform are just renamed. + """ + out = {} + for key, value in state_dict.items(): + for source, targets, forward_fn in transforms: + if source in key: + tensors = forward_fn(value, **ctx) + for target, tensor in zip(targets, tensors): + new_key = WeightMappingHandler.rename_key(key.replace(source, target), rename_patterns) + out[new_key] = tensor + break + else: + out[WeightMappingHandler.rename_key(key, rename_patterns)] = value + return out + + def map_to_diffusers( state_dict: dict[str, torch.Tensor], **kwargs, ) -> dict[str, torch.Tensor]: """Convert a Flux transformer state_dict from original format to diffusers format.""" inner_dim = _get_inner_dim(state_dict) - return WeightMappingHandler.apply_transforms( - state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim - ) + return _apply_transforms(state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim) # Build reverse patterns for map_from_diffusers @@ -258,10 +262,8 @@ def map_from_diffusers( return converted_state_dict -_FLUX_CHECKPOINT_KEY_PREFIXES: list[str] = ["model.diffusion_model."] - # Distinctive keys for original format detection (only keys that use simple renaming, not splits) -_FLUX_CHECKPOINT_KEYS: set[str] = { +_FLUX_STATE_DICT_KEYS: set[str] = { "time_in.in_layer.weight", "double_blocks.0.img_mod.lin.weight", } @@ -295,17 +297,15 @@ def detect_config(weight_mapping, state_dict: dict[str, Any]) -> str | None: in_channels = state_dict[x_embedder_key].shape[1] if 
in_channels == 384: return "flux-fill" - elif in_channels == 128: + if in_channels == 128: return "flux-depth" return "flux-dev" -# Handler assembled into ``ModelMetadata`` by ``flux/model.py``. +# Assigned to ``FluxTransformer2DModel`` as the ``_weight_mapping`` class attribute in ``flux/model.py``. FLUX_WEIGHT_MAPPING = WeightMappingHandler( - checkpoint_keys=_FLUX_CHECKPOINT_KEYS, - checkpoint_key_prefixes=_FLUX_CHECKPOINT_KEY_PREFIXES, - rename_patterns=FLUX_RENAME_PATTERNS, + original_format_keys=_FLUX_STATE_DICT_KEYS, available_configs=_FLUX_AVAILABLE_CONFIGS, map_to_diffusers_fn=map_to_diffusers, map_from_diffusers_fn=map_from_diffusers, From 983103908b0ebaa5eb68a34c7562e1b9f0fe8f9c Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 00:22:12 +0530 Subject: [PATCH 12/21] update --- src/diffusers/loaders/ip_adapter_model.py | 13 +++++----- src/diffusers/loaders/lora.py | 10 ++++---- src/diffusers/loaders/weight_mapping.py | 15 ++++++----- src/diffusers/models/modeling_utils.py | 25 +++++++++---------- .../transformers/flux/weight_mapping.py | 7 +++--- 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index fbc9e6867689..4b934164d513 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -15,8 +15,8 @@ Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on :class:`IPAdapterModelMixin`. Per-model conversion lives in an :class:`IPAdapterHandler` declared next to the model -(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as -``_ip_adapter = FLUX_IP_ADAPTER``. +(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as ``_ip_adapter = +FLUX_IP_ADAPTER``. 
""" from dataclasses import dataclass @@ -36,8 +36,9 @@ class IPAdapterHandler: """Composition-style holder for a model class's IP-Adapter conversion callables. - Attached as the ``_ip_adapter`` class attribute on :class:`IPAdapterModelMixin` (overridden per-model). The converter callables receive the - model instance because they need to read its config (``attn_processors``, ``inner_dim``, etc.). + Attached as the ``_ip_adapter`` class attribute on :class:`IPAdapterModelMixin` (overridden per-model). The + converter callables receive the model instance because they need to read its config (``attn_processors``, + ``inner_dim``, etc.). Attributes: convert_attn_to_diffusers_fn: @@ -124,9 +125,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U self.encoder_hid_proj = None - attn_procs = self._ip_adapter.convert_attn_processors( - self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage - ) + attn_procs = self._ip_adapter.convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) self.set_attn_processor(attn_procs) image_projection_layers = [] diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index b4639a7ccf8a..95b3f56dbc27 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -108,9 +108,9 @@ def _normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, class LoRAHandler: """Composition-style holder for a model class's LoRA conversion configuration. - Attached as the ``_lora`` class attribute on :class:`LoRAModelMixin` (overridden per-model). Holds the per-model foreign-format - conversion data. Public conversion utilities (``normalize_lora_suffixes``, ``detect_lora_format``) live on - :class:`LoRAModelMixin` and read from this handler. + Attached as the ``_lora`` class attribute on :class:`LoRAModelMixin` (overridden per-model). Holds the per-model + foreign-format conversion data. 
Public conversion utilities (``normalize_lora_suffixes``, ``detect_lora_format``) + live on :class:`LoRAModelMixin` and read from this handler. Attributes: format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying key substrings. The first @@ -398,8 +398,8 @@ class LoRAModelMixin: / hotswap) plus foreign-format conversion (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. Per-model conversion knobs live in a :class:`LoRAHandler` declared in the model's ``lora.py`` (e.g. ``FLUX_LORA``) - and assigned to the model class as ``_lora = FLUX_LORA``. The default no-op handler just - normalizes ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. + and assigned to the model class as ``_lora = FLUX_LORA``. The default no-op handler just normalizes + ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. Install the latest version of PEFT, and use this mixin to: diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py index 3afc085e7fba..29e72b2c26ad 100644 --- a/src/diffusers/loaders/weight_mapping.py +++ b/src/diffusers/loaders/weight_mapping.py @@ -42,14 +42,13 @@ class WeightMappingHandler: """Composition-style holder for a model class's weight-mapping configuration and helpers. - Attached as the ``_weight_mapping`` class attribute on :class:`ModelMixin` (overridden per-model). Owns all - the data (available configs, prefixes, rename patterns, converter callables) and all the methods (rename, - detect, normalize) for single-file checkpoint loading. Internal callers reach it via ``cls._weight_mapping.X``. + Attached as the ``_weight_mapping`` class attribute on :class:`ModelMixin` (overridden per-model). Owns all the + data (available configs, prefixes, rename patterns, converter callables) and all the methods (rename, detect, + normalize) for single-file checkpoint loading. 
Internal callers reach it via ``cls._weight_mapping.X``. Attributes: original_format_keys: Distinctive keys whose presence indicates the state_dict is in the original - (pre-diffusers) format. Used by :meth:`is_original_format` to decide whether key conversion is - needed. + (pre-diffusers) format. Used by :meth:`is_original_format` to decide whether key conversion is needed. prefixes_to_remove: Foreign prefixes (e.g. ``["model.diffusion_model."]``) the handler will strip via :meth:`normalize_state_dict_keys`. Defaults to the shared :data:`PREFIXES_TO_REMOVE` list — most models only need that. Extend it for family-specific wrappers; prefix-only models can rely on the default and skip @@ -100,9 +99,9 @@ def rename_key(key: str, patterns: dict) -> str: def is_original_format(self, state_dict: dict) -> bool: """Check if state_dict is in the original (pre-diffusers) format by presence of a known marker key. - Returns ``True`` only when a registered ``original_format_keys`` entry is observed in the state_dict. - Returning ``False`` means "no positive evidence of original format" — empty / unrelated / unknown - state_dicts all fall here. Callers treat ``False`` as "proceed with diffusers-native keys." + Returns ``True`` only when a registered ``original_format_keys`` entry is observed in the state_dict. Returning + ``False`` means "no positive evidence of original format" — empty / unrelated / unknown state_dicts all fall + here. Callers treat ``False`` as "proceed with diffusers-native keys." """ if not self.original_format_keys: return False diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a494be21c692..88056fdd4bd1 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -245,9 +245,9 @@ class ModelMetadata: """Read-only snapshot of a model class's capabilities. 
Constructed by :meth:`ModelMixin.metadata`, which walks ``cls.__mro__`` collecting rows from each mixin's - ``_metadata`` classmethod. Purely a display object — printing it renders a formatted table. - Programmatic handler access (``model._lora``, ``model._weight_mapping``, ``model._ip_adapter``) goes - through the class attributes directly. + ``_metadata`` classmethod. Purely a display object — printing it renders a formatted table. Programmatic handler + access (``model._lora``, ``model._weight_mapping``, ``model._ip_adapter``) goes through the class attributes + directly. ``rows`` maps a capability label to ``(value, description, docs_url)``. ``verbose=True`` (via ``Model.metadata(verbose=True)``) renders the description and docs link under each row. @@ -285,8 +285,8 @@ def register_metadata(metadata): by :class:`~diffusers.hooks._helpers.TransformerBlockMetadata` to register block-level metadata into :class:`TransformerBlockRegistry`:: - @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) - class FluxTransformerBlock(nn.Module): + @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) class + FluxTransformerBlock(nn.Module): ... Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and the appropriate @@ -338,9 +338,8 @@ def __init__(self): def _metadata(cls) -> dict[str, tuple[str, str, str]]: """Return ``ModelMixin``-level rows for the metadata snapshot. - Each row is keyed by capability label and maps to ``(value, description, docs_url)``. Only present - capabilities are returned. See :meth:`metadata` for the aggregated view across all mixins in - ``cls.__mro__``. + Each row is keyed by capability label and maps to ``(value, description, docs_url)``. Only present capabilities + are returned. See :meth:`metadata` for the aggregated view across all mixins in ``cls.__mro__``. 
""" rows: dict[str, tuple[str, str, str]] = {} if cls._supports_gradient_checkpointing: @@ -397,12 +396,12 @@ def _metadata(cls) -> dict[str, tuple[str, str, str]]: def metadata(cls, verbose: bool = False) -> "ModelMetadata": """Return a read-only snapshot of this class's capabilities. - Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod - (handled via direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen - wins on label collisions; this puts the model's own overrides ahead of inherited defaults. + Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod (handled via + direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen wins on label collisions; + this puts the model's own overrides ahead of inherited defaults. - ``print(Model.metadata())`` shows the formatted table. Pass ``verbose=True`` to render each row with - a description and a link to the relevant docs section. + ``print(Model.metadata())`` shows the formatted table. Pass ``verbose=True`` to render each row with a + description and a link to the relevant docs section. """ merged: dict[str, tuple[str, str, str]] = {} for klass in cls.__mro__: diff --git a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/weight_mapping.py index cdd998701b89..b5eb52d4880e 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/weight_mapping.py @@ -129,9 +129,9 @@ def _get_inner_dim(state_dict: dict[str, torch.Tensor]) -> int: def _apply_transforms(state_dict, transforms, rename_patterns, **ctx): """Drive a forward state-dict conversion from a list of ``(source, targets, forward_fn)`` entries. 
- For each key in ``state_dict``: scan ``transforms``; the first entry whose ``source`` substring matches - expands the value via ``forward_fn(value, **ctx)`` into one tensor per target, each at a key derived by - ``key.replace(source, target)`` then ``rename_patterns``. Keys that match no transform are just renamed. + For each key in ``state_dict``: scan ``transforms``; the first entry whose ``source`` substring matches expands the + value via ``forward_fn(value, **ctx)`` into one tensor per target, each at a key derived by ``key.replace(source, + target)`` then ``rename_patterns``. Keys that match no transform are just renamed. """ out = {} for key, value in state_dict.items(): @@ -144,6 +144,7 @@ def _apply_transforms(state_dict, transforms, rename_patterns, **ctx): break else: out[WeightMappingHandler.rename_key(key, rename_patterns)] = value + return out From dafb81c2da3d8ef664703da988aeb9f045d7411c Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 10:35:57 +0530 Subject: [PATCH 13/21] update --- src/diffusers/loaders/ip_adapter_model.py | 2 +- src/diffusers/models/transformers/flux/__init__.py | 3 --- .../transformers/flux/{ip_adapter.py => _ip_adapter.py} | 0 .../models/transformers/flux/{lora.py => _lora.py} | 2 +- .../flux/{weight_mapping.py => _weight_mapping.py} | 4 ++-- src/diffusers/models/transformers/flux/model.py | 6 +++--- 6 files changed, 7 insertions(+), 10 deletions(-) rename src/diffusers/models/transformers/flux/{ip_adapter.py => _ip_adapter.py} (100%) rename src/diffusers/models/transformers/flux/{lora.py => _lora.py} (99%) rename src/diffusers/models/transformers/flux/{weight_mapping.py => _weight_mapping.py} (98%) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index 4b934164d513..fd87f1145f82 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -15,7 +15,7 @@ Generic orchestration (set processors, build 
``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on :class:`IPAdapterModelMixin`. Per-model conversion lives in an :class:`IPAdapterHandler` declared next to the model -(e.g. ``flux/ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as ``_ip_adapter = +(e.g. ``flux/_ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as ``_ip_adapter = FLUX_IP_ADAPTER``. """ diff --git a/src/diffusers/models/transformers/flux/__init__.py b/src/diffusers/models/transformers/flux/__init__.py index c4c8707dff45..477b552bcbb0 100644 --- a/src/diffusers/models/transformers/flux/__init__.py +++ b/src/diffusers/models/transformers/flux/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .ip_adapter import FLUX_IP_ADAPTER -from .lora import FLUX_LORA from .model import ( FluxAttention, FluxAttnProcessor, @@ -23,4 +21,3 @@ FluxTransformer2DModel, FluxTransformerBlock, ) -from .weight_mapping import FLUX_WEIGHT_MAPPING diff --git a/src/diffusers/models/transformers/flux/ip_adapter.py b/src/diffusers/models/transformers/flux/_ip_adapter.py similarity index 100% rename from src/diffusers/models/transformers/flux/ip_adapter.py rename to src/diffusers/models/transformers/flux/_ip_adapter.py diff --git a/src/diffusers/models/transformers/flux/lora.py b/src/diffusers/models/transformers/flux/_lora.py similarity index 99% rename from src/diffusers/models/transformers/flux/lora.py rename to src/diffusers/models/transformers/flux/_lora.py index a7018d8d3507..9e821be9ed6f 100644 --- a/src/diffusers/models/transformers/flux/lora.py +++ b/src/diffusers/models/transformers/flux/_lora.py @@ -36,7 +36,7 @@ from ....loaders.lora import LoRAHandler from ....utils import logging, state_dict_all_zero -from .weight_mapping import ( +from ._weight_mapping import ( FLUX_QKV_SPLIT_PATTERNS, FLUX_QKVMLP_SPLIT_PATTERN, FLUX_QKVMLP_TARGETS, diff --git 
a/src/diffusers/models/transformers/flux/weight_mapping.py b/src/diffusers/models/transformers/flux/_weight_mapping.py similarity index 98% rename from src/diffusers/models/transformers/flux/weight_mapping.py rename to src/diffusers/models/transformers/flux/_weight_mapping.py index b5eb52d4880e..26b482ebd1aa 100644 --- a/src/diffusers/models/transformers/flux/weight_mapping.py +++ b/src/diffusers/models/transformers/flux/_weight_mapping.py @@ -126,7 +126,7 @@ def _get_inner_dim(state_dict: dict[str, torch.Tensor]) -> int: return 3072 # Default -def _apply_transforms(state_dict, transforms, rename_patterns, **ctx): +def apply_transforms(state_dict, transforms, rename_patterns, **ctx): """Drive a forward state-dict conversion from a list of ``(source, targets, forward_fn)`` entries. For each key in ``state_dict``: scan ``transforms``; the first entry whose ``source`` substring matches expands the @@ -154,7 +154,7 @@ def map_to_diffusers( ) -> dict[str, torch.Tensor]: """Convert a Flux transformer state_dict from original format to diffusers format.""" inner_dim = _get_inner_dim(state_dict) - return _apply_transforms(state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim) + return apply_transforms(state_dict, FLUX_TRANSFORMS, FLUX_RENAME_PATTERNS, inner_dim=inner_dim) # Build reverse patterns for map_from_diffusers diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 764f7e52a609..6e6de8c0ee5a 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -38,9 +38,9 @@ from ...modeling_outputs import Transformer2DModelOutput from ...modeling_utils import ModelMixin, register_metadata from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from .ip_adapter import FLUX_IP_ADAPTER -from .lora import FLUX_LORA -from .weight_mapping import FLUX_WEIGHT_MAPPING +from ._ip_adapter import FLUX_IP_ADAPTER 
+from ._lora import FLUX_LORA +from ._weight_mapping import FLUX_WEIGHT_MAPPING logger = logging.get_logger(__name__) From 086b4cfc550e8f5f9fe0f66a13b70b36198bafc8 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 13:08:46 +0530 Subject: [PATCH 14/21] update --- src/diffusers/loaders/ip_adapter_model.py | 7 +- src/diffusers/loaders/lora.py | 7 +- src/diffusers/models/attention.py | 7 +- src/diffusers/models/cache_utils.py | 7 +- src/diffusers/models/modeling_utils.py | 137 +++++++++++----------- 5 files changed, 88 insertions(+), 77 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py index fd87f1145f82..36e1abe49174 100644 --- a/src/diffusers/loaders/ip_adapter_model.py +++ b/src/diffusers/loaders/ip_adapter_model.py @@ -96,14 +96,15 @@ class IPAdapterModelMixin: @classmethod def _metadata(cls): - """Contribute the ``ip_adapter`` row to :class:`ModelMetadata` when converters are registered.""" + """Contribute the ``_ip_adapter`` row to :class:`ModelMetadata` when converters are registered.""" from ..models.modeling_utils import DOCS_BASE if not cls._ip_adapter.supports_ip_adapter: return {} return { - "ip_adapter": ( - "yes", + "_ip_adapter": ( + True, + "True", "Supports loading IP-Adapter weights (image-conditioning adapters).", f"{DOCS_BASE}/using-diffusers/ip_adapter", ) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 95b3f56dbc27..f10faeba73a1 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -222,6 +222,7 @@ def _offloading_disabled(model): CPU offload, group offload, etc.). This context saves the hook state, removes the hooks for the duration of the block, and restores them on exit so existing offloading config survives a LoRA load. 
""" + saved_hf_hook = None is_sequential = False if hasattr(model, "_hf_hook"): @@ -233,6 +234,7 @@ def _offloading_disabled(model): ): saved_hf_hook = hook is_sequential = True + if saved_hf_hook is not None: remove_hook_from_module(model, recurse=is_sequential) @@ -419,14 +421,15 @@ class LoRAModelMixin: @classmethod def _metadata(cls): - """Contribute the ``lora_formats`` row to :class:`ModelMetadata` when foreign formats are registered.""" + """Contribute the ``_lora`` row to :class:`ModelMetadata` when foreign formats are registered.""" from ..models.modeling_utils import DOCS_BASE formats = sorted(cls._lora.format_keys) if not formats: return {} return { - "lora_formats": ( + "_lora": ( + formats, ", ".join(formats), "Foreign LoRA formats this model converts to diffusers naming on load.", f"{DOCS_BASE}/training/lora", diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0915b8467400..9ea264cadb19 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -41,12 +41,13 @@ class AttentionMixin: @classmethod def _metadata(cls): - """Contribute the ``attention`` row to :class:`ModelMetadata` when the model inherits :class:`AttentionMixin`.""" + """Contribute the ``_supports_attention`` row to :class:`ModelMetadata` for models inheriting :class:`AttentionMixin`.""" from .modeling_utils import DOCS_BASE return { - "attention": ( - "yes", + "_supports_attention": ( + True, + "True", "Model contains attention modules; supports `set_attention_backend(...)`.", f"{DOCS_BASE}/optimization/attention_backends", ) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 0d8d4d17e84c..5baadf8f760c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -35,12 +35,13 @@ class CacheMixin: @classmethod def _metadata(cls): - """Contribute the ``cache`` row to :class:`ModelMetadata` when the model inherits :class:`CacheMixin`.""" + 
"""Contribute the ``_supports_cache`` row to :class:`ModelMetadata` for models inheriting :class:`CacheMixin`.""" from .modeling_utils import DOCS_BASE return { - "cache": ( - "yes", + "_supports_cache": ( + True, + "True", "Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.", f"{DOCS_BASE}/optimization/cache", ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 88056fdd4bd1..5d998c30c5ea 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,6 @@ import tempfile from collections import OrderedDict from contextlib import ExitStack, contextmanager, nullcontext -from dataclasses import dataclass from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -240,44 +239,6 @@ def _skip_init(*args, **kwargs): DOCS_BASE = "https://huggingface.co/docs/diffusers/main/en" -@dataclass(frozen=True) -class ModelMetadata: - """Read-only snapshot of a model class's capabilities. - - Constructed by :meth:`ModelMixin.metadata`, which walks ``cls.__mro__`` collecting rows from each mixin's - ``_metadata`` classmethod. Purely a display object — printing it renders a formatted table. Programmatic handler - access (``model._lora``, ``model._weight_mapping``, ``model._ip_adapter``) goes through the class attributes - directly. - - ``rows`` maps a capability label to ``(value, description, docs_url)``. ``verbose=True`` (via - ``Model.metadata(verbose=True)``) renders the description and docs link under each row. 
- """ - - rows: dict[str, tuple[str, str, str]] - verbose: bool = False - - def __repr__(self) -> str: - if not self.rows: - return "ModelMetadata(no capabilities declared)" - - name_w = max(len(n) for n in self.rows) - if not self.verbose: - lines = ["ModelMetadata:"] - for label, (value, _doc, _link) in self.rows.items(): - lines.append(f" {label:<{name_w}} {value}") - return "\n".join(lines) - - lines = ["ModelMetadata:"] - for label, (value, doc, link) in self.rows.items(): - lines.append(f" {label:<{name_w}} {value}") - if doc: - lines.append(f" {'':<{name_w}} {doc}") - if link: - lines.append(f" {'':<{name_w}} docs: {link}") - lines.append("") - return "\n".join(lines).rstrip() - - def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. @@ -335,82 +296,126 @@ def __init__(self): self._gradient_checkpointing_func = None @classmethod - def _metadata(cls) -> dict[str, tuple[str, str, str]]: + def _metadata(cls) -> dict[str, tuple[Any, str, str, str]]: """Return ``ModelMixin``-level rows for the metadata snapshot. - Each row is keyed by capability label and maps to ``(value, description, docs_url)``. Only present capabilities - are returned. See :meth:`metadata` for the aggregated view across all mixins in ``cls.__mro__``. + Each row is keyed by the **class attribute name** that controls the capability (e.g. + ``"_supports_gradient_checkpointing"``) and maps to ``(value, display, description, docs_url)``. Only present + capabilities are returned. 
""" - rows: dict[str, tuple[str, str, str]] = {} + rows: dict[str, tuple[Any, str, str, str]] = {} if cls._supports_gradient_checkpointing: - rows["gradient_checkpointing"] = ( - "yes", + rows["_supports_gradient_checkpointing"] = ( + True, + "True", "Trades compute for memory by recomputing activations during backward.", f"{DOCS_BASE}/optimization/memory#gradient-checkpointing", ) if cls._supports_group_offloading: - rows["group_offloading"] = ( - "yes", + rows["_supports_group_offloading"] = ( + True, + "True", "Stage parameter groups on CPU/disk and stream them to the accelerator for inference.", f"{DOCS_BASE}/optimization/memory#group-offloading", ) if cls._no_split_modules: - rows["no_split_modules"] = ( + rows["_no_split_modules"] = ( + list(cls._no_split_modules), ", ".join(cls._no_split_modules), "Block class names that must stay on a single device under `device_map='auto'` sharding.", f"{DOCS_BASE}/training/distributed_inference#device-map", ) if cls._keep_in_fp32_modules: - rows["keep_in_fp32_modules"] = ( + rows["_keep_in_fp32_modules"] = ( + list(cls._keep_in_fp32_modules), ", ".join(cls._keep_in_fp32_modules), "Submodule name patterns that remain in fp32 even when the model is cast to fp16/bf16.", f"{DOCS_BASE}/optimization/fp16#mixed-precision", ) if cls._skip_layerwise_casting_patterns: - rows["skip_layerwise_casting_patterns"] = ( + rows["_skip_layerwise_casting_patterns"] = ( + tuple(cls._skip_layerwise_casting_patterns), ", ".join(cls._skip_layerwise_casting_patterns), "Parameter name substrings excluded from layerwise dtype casting (embeddings, norms, ...).", f"{DOCS_BASE}/optimization/memory#layerwise-casting", ) if cls._repeated_blocks: - rows["repeated_blocks"] = ( + rows["_repeated_blocks"] = ( + list(cls._repeated_blocks), ", ".join(cls._repeated_blocks), "Block class names safe to `torch.compile` once and reuse — enables `compile_repeated_blocks`.", f"{DOCS_BASE}/optimization/torch2.0", ) if cls._cp_plan: - rows["context_parallel"] = ( - 
"yes", - "Forward inputs/outputs are scatter/gathered across context-parallel ranks.", + rows["_cp_plan"] = ( + True, + "True", + "Support context parallel inference.", f"{DOCS_BASE}/training/distributed_inference#context-parallelism", ) if cls._weight_mapping.supports_single_file: - rows["supported_model_types"] = ( - ", ".join(sorted(cls._weight_mapping.available_configs)), + configs = sorted(cls._weight_mapping.available_configs) + rows["_weight_mapping"] = ( + configs, + ", ".join(configs), "Auto-resolvable configs for `from_single_file(path)` (no `config=` argument required).", f"{DOCS_BASE}/api/loaders/single_file", ) return rows @classmethod - def metadata(cls, verbose: bool = False) -> "ModelMetadata": - """Return a read-only snapshot of this class's capabilities. + def describe(cls, verbose: bool = False) -> None: + """Print this class's feature attributes, keyed by the controlling class attribute name. Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod (handled via direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen wins on label collisions; this puts the model's own overrides ahead of inherited defaults. - ``print(Model.metadata())`` shows the formatted table. Pass ``verbose=True`` to render each row with a - description and a link to the relevant docs section. + Compact form (default): two-column `` ``. With ``verbose=True``, each row is followed by an + indented description and docs link. ANSI color/style is applied when stdout is a TTY and stripped otherwise so + the output stays clean in logs and pipes. 
""" - merged: dict[str, tuple[str, str, str]] = {} - for klass in cls.__mro__: - method = klass.__dict__.get("_metadata") + import sys + + merged: dict[str, tuple[Any, str, str, str]] = {} + for mixin in cls.__mro__: + method = mixin.__dict__.get("_metadata") if method is None: continue - for label, info in method.__func__(cls).items(): - merged.setdefault(label, info) - return ModelMetadata(rows=merged, verbose=verbose) + for attr, info in method.__func__(cls).items(): + merged.setdefault(attr, info) + + if not merged: + print(f"{cls.__name__}: no feature attributes declared") + return + + is_tty = sys.stdout.isatty() + bold = "\033[1m" if is_tty else "" + dim = "\033[2m" if is_tty else "" + cyan = "\033[36m" if is_tty else "" + underline = "\033[4m" if is_tty else "" + reset = "\033[0m" if is_tty else "" + + attr_w = max(len(attr) for attr in merged) + title = f"{cls.__name__} feature attributes" + rule_width = max(len(title), attr_w + 2 + max(len(row[1]) for row in merged.values())) + lines = [ + f"{bold}{title}{reset}", + f"{dim}{'─' * rule_width}{reset}", + ] + + rows = list(merged.items()) + for i, (attr, (_value, display, doc, link)) in enumerate(rows): + lines.append(f" {bold}{cyan}{attr:<{attr_w}}{reset} {display}") + if verbose: + if doc: + lines.append(f" {dim}{doc}{reset}") + if link: + lines.append(f" {dim}See {underline}{link}{reset}") + if i < len(rows) - 1: + lines.append("") + lines.append("") + print("\n".join(lines)) def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing From c026a687aca6a99447af0b489f0f074618bb16fc Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 17:25:45 +0530 Subject: [PATCH 15/21] update --- src/diffusers/loaders/ip_adapter_model.py | 140 ------------------ src/diffusers/models/modeling_utils.py | 5 +- .../models/transformers/flux/_ip_adapter.py | 69 ++++++--- .../models/transformers/flux/model.py | 6 +- 4 files changed, 56 insertions(+), 164 
deletions(-) delete mode 100644 src/diffusers/loaders/ip_adapter_model.py diff --git a/src/diffusers/loaders/ip_adapter_model.py b/src/diffusers/loaders/ip_adapter_model.py deleted file mode 100644 index 36e1abe49174..000000000000 --- a/src/diffusers/loaders/ip_adapter_model.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Model-side IP-Adapter machinery. - -Generic orchestration (set processors, build ``MultiIPAdapterImageProjection``, flip ``encoder_hid_dim_type``) lives on -:class:`IPAdapterModelMixin`. Per-model conversion lives in an :class:`IPAdapterHandler` declared next to the model -(e.g. ``flux/_ip_adapter.py`` exports ``FLUX_IP_ADAPTER``), assigned to the model class as ``_ip_adapter = -FLUX_IP_ADAPTER``. -""" - -from dataclasses import dataclass -from typing import Callable, Optional - -from ..models.embeddings import MultiIPAdapterImageProjection -from ..utils import is_torch_version, logging - - -_LOW_CPU_MEM_USAGE_DEFAULT = is_torch_version(">=", "1.9.0") - - -logger = logging.get_logger(__name__) - - -@dataclass -class IPAdapterHandler: - """Composition-style holder for a model class's IP-Adapter conversion callables. - - Attached as the ``_ip_adapter`` class attribute on :class:`IPAdapterModelMixin` (overridden per-model). 
The - converter callables receive the model instance because they need to read its config (``attn_processors``, - ``inner_dim``, etc.). - - Attributes: - convert_attn_to_diffusers_fn: - Callable ``(model, state_dicts, low_cpu_mem_usage=False) -> dict[str, AttnProcessor]`` returning the - attn-processor dict ready for ``set_attn_processor``. - convert_image_proj_to_diffusers_fn: Callable - ``(model, image_proj_state_dict, low_cpu_mem_usage=False) -> ImageProjection`` returning the image - projection module. - """ - - convert_attn_to_diffusers_fn: Optional[Callable] = None - convert_image_proj_to_diffusers_fn: Optional[Callable] = None - - @property - def supports_ip_adapter(self) -> bool: - """Whether the model has both converters registered (required to actually load weights).""" - return self.convert_attn_to_diffusers_fn is not None and self.convert_image_proj_to_diffusers_fn is not None - - def convert_attn_processors(self, model, state_dicts, low_cpu_mem_usage: bool = False): - """Build the attention-processor dict for a list of IP-Adapter state dicts. - - Receives the model so the converter can inspect ``model.attn_processors``, ``model.config``, - ``model.inner_dim``, etc. - """ - if self.convert_attn_to_diffusers_fn is None: - raise NotImplementedError( - f"{type(model).__name__} did not register `convert_attn_to_diffusers_fn` in its IPAdapterHandler." - ) - return self.convert_attn_to_diffusers_fn(model, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - - def convert_image_proj(self, model, image_proj_state_dict, low_cpu_mem_usage: bool = False): - """Build the image-projection module from a single IP-Adapter state dict.""" - if self.convert_image_proj_to_diffusers_fn is None: - raise NotImplementedError( - f"{type(model).__name__} did not register `convert_image_proj_to_diffusers_fn` in its " - f"IPAdapterHandler." 
- ) - return self.convert_image_proj_to_diffusers_fn( - model, image_proj_state_dict, low_cpu_mem_usage=low_cpu_mem_usage - ) - - -class IPAdapterModelMixin: - """Generic IP-Adapter loader for diffusers transformer / UNet models. - - The per-model conversion callables live on ``self._ip_adapter`` (an :class:`IPAdapterHandler` assigned as a class - attribute by the model). This mixin owns only the orchestration: dispatching to the converters, wiring up - ``MultiIPAdapterImageProjection``, and flipping ``encoder_hid_dim_type``. - """ - - # Per-model IP-Adapter conversion config. Defaults to an empty handler; models that support IP-Adapter assign - # ``_ip_adapter = FLUX_IP_ADAPTER`` (etc.) in their class body. Calling ``_load_ip_adapter_weights`` on a - # model that didn't override raises ``NotImplementedError`` from inside the handler. - _ip_adapter: IPAdapterHandler = IPAdapterHandler() - - @classmethod - def _metadata(cls): - """Contribute the ``_ip_adapter`` row to :class:`ModelMetadata` when converters are registered.""" - from ..models.modeling_utils import DOCS_BASE - - if not cls._ip_adapter.supports_ip_adapter: - return {} - return { - "_ip_adapter": ( - True, - "True", - "Supports loading IP-Adapter weights (image-conditioning adapters).", - f"{DOCS_BASE}/using-diffusers/ip_adapter", - ) - } - - def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - """Install IP-Adapter weights on the model. - - ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain - ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. - """ - if not self._ip_adapter.supports_ip_adapter: - raise NotImplementedError( - f"{type(self).__name__} did not register IP-Adapter converters in its IPAdapterHandler." 
- ) - - if not isinstance(state_dicts, list): - state_dicts = [state_dicts] - - self.encoder_hid_proj = None - - attn_procs = self._ip_adapter.convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - self.set_attn_processor(attn_procs) - - image_projection_layers = [] - for state_dict in state_dicts: - image_projection_layer = self._ip_adapter.convert_image_proj( - self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage - ) - image_projection_layers.append(image_projection_layer) - - self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) - self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5d998c30c5ea..f4614163e5fd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -250,8 +250,9 @@ def register_metadata(metadata): FluxTransformerBlock(nn.Module): ... - Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and the appropriate - subsystem mixins like :class:`LoRAModelMixin`, :class:`IPAdapterModelMixin`) — no decorator needed. + Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and on subsystem + mixins like :class:`LoRAModelMixin` or model-specific ones like ``FluxIPAdapterMixin``) — no decorator + needed. """ def wrap(cls): diff --git a/src/diffusers/models/transformers/flux/_ip_adapter.py b/src/diffusers/models/transformers/flux/_ip_adapter.py index df7a83aa4200..9a813e82654d 100644 --- a/src/diffusers/models/transformers/flux/_ip_adapter.py +++ b/src/diffusers/models/transformers/flux/_ip_adapter.py @@ -11,24 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Flux IP-Adapter conversion. +"""Flux-specific IP-Adapter loading. 
-Per-model converters consumed by ``IPAdapterModelMixin`` via ``FLUX_IP_ADAPTER_METADATA``: +IP-Adapter behavior — what's in the state dict, what the attn processors look like, which blocks they bind +to — varies enough across models that a generic mixin can't really capture the orchestration. Flux owns its +own ``_load_ip_adapter_weights`` here, including the loop over blocks, the choice to skip single-stream +blocks, and the projection-dim computation. -- ``convert_image_proj``: rewrites ``proj.weight`` → ``image_embeds.weight`` and builds an ``ImageProjection`` sized - off the source state dict (4 or 16 image-text embeds depending on the ``proj.weight`` row count). -- ``convert_attn_processors``: walks ``model.attn_processors``, skips ``single_transformer_blocks`` (Flux only attaches - IP-Adapter on the double-stream blocks), and builds one ``FluxIPAdapterAttnProcessor`` per remaining block. Reads - ``model.config.joint_attention_dim`` and ``model.inner_dim`` for the projection dimensions and pulls ``to_k_ip`` / - ``to_v_ip`` weights/biases keyed by ``key_id``. +``FluxIPAdapterMixin`` is added to ``FluxTransformer2DModel``'s bases in ``flux/model.py``. Models that don't +support IP-Adapter simply don't inherit anything — there's no opt-in handler default to override. 
""" from contextlib import nullcontext -from ....loaders.ip_adapter_model import IPAdapterHandler -from ....models.embeddings import ImageProjection +from ....models.embeddings import ImageProjection, MultiIPAdapterImageProjection from ....models.model_loading_utils import load_model_dict_into_meta -from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, DOCS_BASE from ....utils import is_accelerate_available, is_torch_version, logging from ....utils.torch_utils import empty_device_cache @@ -58,7 +56,7 @@ def _resolve_init_context(low_cpu_mem_usage): return nullcontext, False -def convert_image_proj(model, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): +def _convert_image_proj(model, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): """Build a Flux ``ImageProjection`` from an IP-Adapter ``image_proj`` state dict.""" init_context, low_cpu_mem_usage = _resolve_init_context(low_cpu_mem_usage) @@ -88,7 +86,7 @@ def convert_image_proj(model, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_D return image_projection -def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): +def _convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): """Build the IP-Adapter attn-processor dict for a ``FluxTransformer2DModel``. Single-stream blocks keep their existing processor; double-stream blocks get a ``FluxIPAdapterAttnProcessor`` @@ -135,8 +133,43 @@ def convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_U return attn_procs -# Assigned to ``FluxTransformer2DModel`` as the ``_ip_adapter`` class attribute in ``flux/model.py``. -FLUX_IP_ADAPTER = IPAdapterHandler( - convert_attn_to_diffusers_fn=convert_attn_processors, - convert_image_proj_to_diffusers_fn=convert_image_proj, -) +class FluxIPAdapterMixin: + """Flux-specific IP-Adapter loader. 
Mixed into :class:`FluxTransformer2DModel`.""" + + _supports_ip_adapter = True + + @classmethod + def _metadata(cls): + """Contribute the ``_supports_ip_adapter`` row to the metadata describe() table.""" + return { + "_supports_ip_adapter": ( + True, + "True", + "Supports loading IP-Adapter weights (image-conditioning adapters).", + f"{DOCS_BASE}/using-diffusers/ip_adapter", + ) + } + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Install IP-Adapter weights on the Flux transformer. + + ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain + ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. + """ + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = _convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = _convert_image_proj( + self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 6e6de8c0ee5a..41e30f37d341 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -22,7 +22,6 @@ from ....configuration_utils import register_to_config from ....hooks._helpers import TransformerBlockMetadata -from ....loaders.ip_adapter_model import IPAdapterModelMixin from ....utils import apply_lora_scale, logging from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -38,7 +37,7 @@ from ...modeling_outputs 
import Transformer2DModelOutput from ...modeling_utils import ModelMixin, register_metadata from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ._ip_adapter import FLUX_IP_ADAPTER +from ._ip_adapter import FluxIPAdapterMixin from ._lora import FLUX_LORA from ._weight_mapping import FLUX_WEIGHT_MAPPING @@ -532,7 +531,7 @@ class FluxTransformer2DModel( ModelMixin, AttentionMixin, CacheMixin, - IPAdapterModelMixin, + FluxIPAdapterMixin, ): """ The Transformer model introduced in Flux. @@ -581,7 +580,6 @@ class FluxTransformer2DModel( _lora = FLUX_LORA _weight_mapping = FLUX_WEIGHT_MAPPING - _ip_adapter = FLUX_IP_ADAPTER @register_to_config def __init__( From ecd307d4b26a44e856638fc4aa9571f615685b2c Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 22:44:59 +0530 Subject: [PATCH 16/21] update --- src/diffusers/loaders/lora_base.py | 235 +++++++--- src/diffusers/loaders/lora_pipeline.py | 569 ++++++++++++------------- 2 files changed, 448 insertions(+), 356 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index aded26bf3624..5b5579664b55 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -22,10 +22,14 @@ import safetensors import torch +import torch.nn as nn +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE -from ..models.modeling_utils import ModelMixin +from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, + _get_model_file, convert_state_dict_to_diffusers, convert_state_dict_to_peft, delete_adapter_layers, @@ -42,6 +46,8 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.peft_utils import _create_lora_config +from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): @@ -55,62 +61,9 @@ logger = logging.get_logger(__name__) - -def 
_func_optionally_disable_offloading(_pipeline): - """Optionally remove accelerate offloading hooks before mutating a pipeline's components. - - Walks ``_pipeline.components``, detects accelerate / group-offload hooks, and removes accelerate hooks in-place - (group-offload is reapplied later by the LoRA load path). Returns ``(is_model_cpu_offload, - is_sequential_cpu_offload, is_group_offload)`` so callers know which offloading mode was active and can re-enable - it after loading. - - Used by pipeline-side LoRA loaders (``LoraBaseMixin._optionally_disable_offloading``) and the legacy paths in - ``peft.py`` / ``unet.py``. Model-side loading uses the ``_offloading_disabled`` context manager in ``loaders.lora`` - instead. - """ - from ..hooks.group_offloading import _is_group_offload_enabled - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - is_group_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if not isinstance(component, torch.nn.Module): - continue - is_group_offload = is_group_offload or _is_group_offload_enabled(component) - if not hasattr(component, "_hf_hook"): - continue - is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) - is_sequential_cpu_offload = is_sequential_cpu_offload or ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - if is_sequential_cpu_offload or is_model_cpu_offload: - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous " - "hooks will be first removed. Then the LoRA parameters will be loaded and the hooks " - "will be applied again." 
- ) - for _, component in _pipeline.components.items(): - if not isinstance(component, torch.nn.Module) or not hasattr(component, "_hf_hook"): - continue - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) - - -# Constants and fetch helpers live in ``loaders.lora`` — re-exported here for back-compat. -from .lora import ( # noqa: E402, F401 (intentional mid-file import: back-compat re-exports) - LORA_ADAPTER_METADATA_KEY, - LORA_WEIGHT_NAME, - LORA_WEIGHT_NAME_SAFE, - _best_guess_weight_name, - _fetch_lora_metadata, - _fetch_state_dict, -) +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): @@ -242,6 +195,124 @@ def _remove_text_encoder_monkey_patch(text_encoder): text_encoder._hf_peft_config_loaded = None +def _fetch_state_dict( + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, + metadata=None, +): + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. 
+ if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + metadata = _load_sft_state_dict_metadata(model_file) + + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + metadata = None + pass + + if model_file is None: + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + metadata = None + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict, metadata + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if 
f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + logger.warning( + f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`." 
+ ) + weight_name = targeted_files[0] + return weight_name + + def _pack_dict_with_prefix(state_dict, prefix): sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} return sd_with_prefix @@ -316,9 +387,7 @@ def _load_lora_into_text_encoder( network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} # create `LoraConfig` - from .lora import _create_lora_config - - lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata) + lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False) # adapter_name if adapter_name is None: @@ -363,6 +432,50 @@ def _load_lora_into_text_encoder( ) +def _func_optionally_disable_offloading(_pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. + """ + from ..hooks.group_offloading import _is_group_offload_enabled + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + is_group_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if not isinstance(component, nn.Module): + continue + is_group_offload = is_group_offload or _is_group_offload_enabled(component) + if not hasattr(component, "_hf_hook"): + continue + is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) + is_sequential_cpu_offload = is_sequential_cpu_offload or ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + if is_sequential_cpu_offload or is_model_cpu_offload: + logger.info( + "Accelerate hooks detected. 
Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + for _, component in _pipeline.components.items(): + if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"): + continue + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) + + class LoraBaseMixin: """Utility class for handling LoRAs.""" @@ -417,7 +530,7 @@ def unload_lora_weights(self): model = getattr(self, component, None) if model is not None: if issubclass(model.__class__, ModelMixin): - model.delete_adapters() + model.unload_lora() elif issubclass(model.__class__, PreTrainedModel): _remove_text_encoder_monkey_patch(model) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2da0cfeca4d0..403e5a87db61 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -35,14 +35,16 @@ LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, - _fetch_lora_metadata, _fetch_state_dict, _load_lora_into_text_encoder, _pack_dict_with_prefix, ) from .lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, + _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux2_lora_to_diffusers, + _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, @@ -53,6 +55,7 @@ _convert_non_diffusers_qwen_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_non_diffusers_z_image_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) @@ -220,6 +223,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, 
metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -231,6 +235,7 @@ def load_lora_weights( else self.text_encoder, lora_scale=self.lora_scale, adapter_name=adapter_name, + _pipeline=self, metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -297,14 +302,20 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -313,17 +324,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -401,7 +402,7 @@ def load_lora_into_unet( # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. 
logger.info(f"Loading {cls.unet_name}.") - unet.load_adapter( + unet.load_lora_adapter( state_dict, prefix=cls.unet_name, network_alphas=network_alphas, @@ -649,6 +650,7 @@ def load_lora_weights( unet=self.unet, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -660,6 +662,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -671,6 +674,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -737,14 +741,20 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -753,17 +763,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + 
allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -842,7 +842,7 @@ def load_lora_into_unet( # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. logger.info(f"Loading {cls.unet_name}.") - unet.load_adapter( + unet.load_lora_adapter( state_dict, prefix=cls.unet_name, network_alphas=network_alphas, @@ -1019,14 +1019,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1035,17 +1041,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -1093,6 +1089,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, 
metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1104,6 +1101,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1115,6 +1113,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1140,7 +1139,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -1315,14 +1314,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1331,17 +1336,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - 
revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -1390,6 +1385,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1416,7 +1412,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -1523,14 +1519,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -1539,17 +1541,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) 
is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -1557,18 +1549,45 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - from ..models.transformers.flux import FluxTransformer2DModel + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) - # Format-specific dispatch lives on the model: detect format (kohya/xlabs/bfl/kontext) - # and convert to diffusers naming. Unknown / diffusers-native state dicts fall - # through to the alpha-extraction path below. - is_recognized_format = FluxTransformer2DModel._detect_lora_format(state_dict) is not None or any( - k.startswith("transformer.") for k in state_dict - ) - if is_recognized_format: - state_dict = FluxTransformer2DModel.map_lora_to_diffusers(state_dict) - # Recognized formats embed alphas in the conversion (kohya scales weights; - # xlabs / bfl / kontext don't use alphas). + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. 
+ return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) + + is_fal_kontext = any("base_model" in k for k in state_dict) + if is_fal_kontext: + state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict) return cls._prepare_outputs( state_dict, metadata=metadata, @@ -1577,8 +1596,8 @@ def lora_state_dict( return_metadata=return_lora_metadata, ) - # Diffusers-native fallback (e.g. https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA): - # alphas ride alongside the weights as separate ``.alpha`` keys. + # For state dicts like + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys()) network_alphas = {} for k in keys: @@ -1681,6 +1700,7 @@ def load_lora_weights( transformer=transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1700,6 +1720,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -1726,7 +1747,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, @@ -2270,7 +2291,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, @@ -2423,14 +2444,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2439,17 +2466,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -2497,6 +2514,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2523,7 +2541,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -2622,14 +2640,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2638,17 +2662,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -2697,6 +2711,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2723,7 +2738,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -2824,14 +2839,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -2840,17 +2861,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -2903,6 +2914,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -2929,7 +2941,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3031,14 +3043,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3047,17 +3065,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -3119,6 +3127,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3130,6 +3139,7 @@ def load_lora_weights( else self.connectors, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, 
hotswap=hotswap, prefix=self.connectors_name, @@ -3157,7 +3167,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {prefix}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3260,14 +3270,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3276,17 +3292,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -3335,6 +3341,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3361,7 +3368,7 @@ def 
load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3663,14 +3670,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3679,17 +3692,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -3742,6 +3745,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3768,7 +3772,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -3869,14 +3873,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -3885,17 +3895,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -3949,6 +3949,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -3975,7 +3976,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4077,14 +4078,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4093,17 +4100,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -4152,6 +4149,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4178,7 +4176,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4279,14 +4277,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4295,17 +4299,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) @@ -4417,6 +4411,7 @@ def load_lora_weights( transformer=self.transformer_2, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4428,6 +4423,7 @@ def load_lora_weights( else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ 
-4454,7 +4450,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4556,14 +4552,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4572,17 +4574,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) @@ -4696,6 +4688,7 @@ def load_lora_weights( transformer=self.transformer_2, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4707,6 +4700,7 @@ def load_lora_weights( else self.transformer, 
adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4733,7 +4727,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -4835,14 +4829,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -4851,17 +4851,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -4910,6 +4900,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, 
low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -4936,7 +4927,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5037,14 +5028,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5053,17 +5050,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -5116,6 +5103,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5142,7 +5130,7 @@ 
def load_lora_into_transformer( # Load the layers corresponding to transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5243,14 +5231,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5259,17 +5253,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -5325,6 +5309,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5351,7 +5336,7 @@ def load_lora_into_transformer( # Load the layers corresponding to 
transformer. logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5452,14 +5437,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5468,17 +5459,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -5534,6 +5515,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5560,7 +5542,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, @@ -5661,14 +5643,20 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - kwargs.pop("use_safetensors", None) # safetensors-only; kwarg accepted but ignored + use_safetensors = kwargs.pop("use_safetensors", None) return_lora_metadata = kwargs.pop("return_lora_metadata", False) + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, + use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, @@ -5677,17 +5665,7 @@ def lora_state_dict( revision=revision, subfolder=subfolder, user_agent=user_agent, - ) - metadata = _fetch_lora_metadata( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, + allow_pickle=allow_pickle, ) is_dora_scale_present = any("dora_scale" in k for k in state_dict) @@ -5751,6 +5729,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, metadata=metadata, + _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -5777,7 +5756,7 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
logger.info(f"Loading {cls.transformer_name}.") - transformer.load_adapter( + transformer.load_lora_adapter( state_dict, network_alphas=None, adapter_name=adapter_name, From d66b366c7c3d10a6e3c431497950355044c81926 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 22:50:04 +0530 Subject: [PATCH 17/21] update --- src/diffusers/models/attention.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9ea264cadb19..f4cd1ff6856b 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -37,22 +37,6 @@ class AttentionMixin: - _supports_attention = True - - @classmethod - def _metadata(cls): - """Contribute the ``_supports_attention`` row to :class:`ModelMetadata` for models inheriting :class:`AttentionMixin`.""" - from .modeling_utils import DOCS_BASE - - return { - "_supports_attention": ( - True, - "True", - "Model contains attention modules; supports `set_attention_backend(...)`.", - f"{DOCS_BASE}/optimization/attention_backends", - ) - } - @property def attn_processors(self) -> dict[str, AttentionProcessor]: r""" From 1fb496a893c4df11b792a86c71c4fc4a36f92f51 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 22 May 2026 23:43:58 +0530 Subject: [PATCH 18/21] update --- src/diffusers/models/modeling_utils.py | 137 ++++++++++++------ .../models/transformers/flux/_ip_adapter.py | 12 +- 2 files changed, 101 insertions(+), 48 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f4614163e5fd..1df5cb6d2cd3 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -22,6 +22,7 @@ import os import re import shutil +import sys import tempfile from collections import OrderedDict from contextlib import ExitStack, contextmanager, nullcontext @@ -239,6 +240,90 @@ def _skip_init(*args, **kwargs): DOCS_BASE = "https://huggingface.co/docs/diffusers/main/en" +class 
ModelMetadata: + """Snapshot of a model class's feature attributes. + + Constructed by :meth:`ModelMixin.metadata` — walks ``cls.__mro__`` collecting rows from each mixin's ``_metadata`` + classmethod and exposes the raw values as attributes: + + >>> meta = FluxTransformer2DModel.metadata() >>> meta._supports_ip_adapter True >>> meta._lora ['bfl', 'kohya', + 'kontext', 'xlabs'] >>> '_supports_cache' in meta True + + ``repr(meta)`` (and ``print(meta)``) render a formatted table. Call :meth:`describe` to print the verbose variant + with descriptions and docs links. + """ + + # Internal storage is name-mangled (``self.__rows`` → ``self._ModelMetadata__rows``) so ``dir(meta)`` and + # tab-completion show only the feature attributes + ``describe``, not the snapshot's bookkeeping fields. + def __init__(self, rows: dict[str, tuple[Any, str, str, str]], cls_name: str): + self.__rows = rows + self.__cls_name = cls_name + for attr, (value, _display, _doc, _link) in rows.items(): + setattr(self, attr, value) + + def __iter__(self): + return iter(self.__rows) + + def __contains__(self, key): + return key in self.__rows + + def __len__(self): + return len(self.__rows) + + def __dir__(self): + return list(self.__rows) + ["describe", "keys", "values", "items"] + + def keys(self): + """Names of the feature attributes this snapshot exposes.""" + return self.__rows.keys() + + def values(self): + """Raw values for each feature attribute (same as ``meta.`` access).""" + return (info[0] for info in self.__rows.values()) + + def items(self): + """Pairs of ``(attribute_name, value)`` for each feature attribute.""" + return ((attr, info[0]) for attr, info in self.__rows.items()) + + def __repr__(self) -> str: + return self._render(verbose=False) + + def describe(self, verbose: bool = False) -> None: + """Print the formatted capability table. 
``verbose=True`` adds descriptions and docs links per row.""" + print(self._render(verbose=verbose)) + + def _render(self, verbose: bool) -> str: + if not self.__rows: + return f"{self.__cls_name}: no feature attributes declared" + + is_tty = sys.stdout.isatty() + bold = "\033[1m" if is_tty else "" + dim = "\033[2m" if is_tty else "" + cyan = "\033[36m" if is_tty else "" + underline = "\033[4m" if is_tty else "" + reset = "\033[0m" if is_tty else "" + + attr_w = max(len(attr) for attr in self.__rows) + title = f"{self.__cls_name} feature attributes" + rule_width = max(len(title), attr_w + 2 + max(len(row[1]) for row in self.__rows.values())) + lines = [ + f"{bold}{title}{reset}", + f"{dim}{'─' * rule_width}{reset}", + ] + + rows = list(self.__rows.items()) + for i, (attr, (_value, display, doc, link)) in enumerate(rows): + lines.append(f" {bold}{cyan}{attr:<{attr_w}}{reset} {display}") + if verbose: + if doc: + lines.append(f" {dim}{doc}{reset}") + if link: + lines.append(f" {dim}See {underline}{link}{reset}") + if i < len(rows) - 1: + lines.append("") + return "\n".join(lines) + + def register_metadata(metadata): """Generic class decorator that attaches metadata to the decorated class. @@ -250,9 +335,8 @@ def register_metadata(metadata): FluxTransformerBlock(nn.Module): ... - Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and on subsystem - mixins like :class:`LoRAModelMixin` or model-specific ones like ``FluxIPAdapterMixin``) — no decorator - needed. + Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and on subsystem mixins + like :class:`LoRAModelMixin` or model-specific ones like ``FluxIPAdapterMixin``) — no decorator needed. 
""" def wrap(cls): @@ -365,58 +449,27 @@ def _metadata(cls) -> dict[str, tuple[Any, str, str, str]]: return rows @classmethod - def describe(cls, verbose: bool = False) -> None: - """Print this class's feature attributes, keyed by the controlling class attribute name. + def metadata(cls) -> "ModelMetadata": + """Return a :class:`ModelMetadata` snapshot of this class's feature attributes. Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod (handled via direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen wins on label collisions; - this puts the model's own overrides ahead of inherited defaults. + subclass overrides win over inherited defaults. - Compact form (default): two-column `` ``. With ``verbose=True``, each row is followed by an - indented description and docs link. ANSI color/style is applied when stdout is a TTY and stripped otherwise so - the output stays clean in logs and pipes. + The returned object exposes feature values as attributes (``meta._supports_ip_adapter``, ``meta._lora``, ...), + supports ``hasattr`` / ``in`` for presence checks, and prints as a formatted table via its ``__repr__``. Call + ``meta.describe(verbose=True)`` for the verbose variant with descriptions and docs. 
""" - import sys - merged: dict[str, tuple[Any, str, str, str]] = {} for mixin in cls.__mro__: method = mixin.__dict__.get("_metadata") if method is None: continue + for attr, info in method.__func__(cls).items(): merged.setdefault(attr, info) - if not merged: - print(f"{cls.__name__}: no feature attributes declared") - return - - is_tty = sys.stdout.isatty() - bold = "\033[1m" if is_tty else "" - dim = "\033[2m" if is_tty else "" - cyan = "\033[36m" if is_tty else "" - underline = "\033[4m" if is_tty else "" - reset = "\033[0m" if is_tty else "" - - attr_w = max(len(attr) for attr in merged) - title = f"{cls.__name__} feature attributes" - rule_width = max(len(title), attr_w + 2 + max(len(row[1]) for row in merged.values())) - lines = [ - f"{bold}{title}{reset}", - f"{dim}{'─' * rule_width}{reset}", - ] - - rows = list(merged.items()) - for i, (attr, (_value, display, doc, link)) in enumerate(rows): - lines.append(f" {bold}{cyan}{attr:<{attr_w}}{reset} {display}") - if verbose: - if doc: - lines.append(f" {dim}{doc}{reset}") - if link: - lines.append(f" {dim}See {underline}{link}{reset}") - if i < len(rows) - 1: - lines.append("") - lines.append("") - print("\n".join(lines)) + return ModelMetadata(merged, cls.__name__) def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing diff --git a/src/diffusers/models/transformers/flux/_ip_adapter.py b/src/diffusers/models/transformers/flux/_ip_adapter.py index 9a813e82654d..eb679d09ec4f 100644 --- a/src/diffusers/models/transformers/flux/_ip_adapter.py +++ b/src/diffusers/models/transformers/flux/_ip_adapter.py @@ -13,13 +13,13 @@ # limitations under the License. """Flux-specific IP-Adapter loading. -IP-Adapter behavior — what's in the state dict, what the attn processors look like, which blocks they bind -to — varies enough across models that a generic mixin can't really capture the orchestration. 
Flux owns its -own ``_load_ip_adapter_weights`` here, including the loop over blocks, the choice to skip single-stream -blocks, and the projection-dim computation. +IP-Adapter behavior — what's in the state dict, what the attn processors look like, which blocks they bind to — varies +enough across models that a generic mixin can't really capture the orchestration. Flux owns its own +``_load_ip_adapter_weights`` here, including the loop over blocks, the choice to skip single-stream blocks, and the +projection-dim computation. -``FluxIPAdapterMixin`` is added to ``FluxTransformer2DModel``'s bases in ``flux/model.py``. Models that don't -support IP-Adapter simply don't inherit anything — there's no opt-in handler default to override. +``FluxIPAdapterMixin`` is added to ``FluxTransformer2DModel``'s bases in ``flux/model.py``. Models that don't support +IP-Adapter simply don't inherit anything — there's no opt-in handler default to override. """ from contextlib import nullcontext From b4302317d6dcd0fa493c5458fe1b0cafbac5b891 Mon Sep 17 00:00:00 2001 From: DN6 Date: Sat, 23 May 2026 00:31:31 +0530 Subject: [PATCH 19/21] update --- src/diffusers/loaders/unet.py | 98 +++++-------------- src/diffusers/models/modeling_utils.py | 21 +--- .../models/transformers/flux/model.py | 2 + src/diffusers/utils/__init__.py | 1 + 4 files changed, 30 insertions(+), 92 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 3046e2b3cdcf..9dab3bc667ea 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -35,17 +35,16 @@ from ..utils import ( USE_PEFT_BACKEND, _get_model_file, - convert_sai_sd_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, deprecate, get_adapter_name, + get_peft_kwargs, is_accelerate_available, is_peft_version, is_torch_version, logging, ) from ..utils.torch_utils import empty_device_cache -from .lora import _create_lora_config from .lora_base import _func_optionally_disable_offloading from 
.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers @@ -66,76 +65,6 @@ class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def _load_adapter_from_pretrained(self, pretrained_model_name_or_path_or_dict, **kwargs): - """UNet override that handles model-specific LoRA formats before delegating to the base loader. - - - Converts old non-PEFT UNet LoRA naming to PEFT shape (when no key carries ``lora_A``). - - Detects SAI Control LoRA (``lora_controlnet`` marker) — that path has its own loader because the LoraConfig - needs post-create overrides the base flow doesn't expose. See https://huggingface.co/stabilityai/control-lora - and https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors. - """ - from ..utils import HUB_KWARGS - from .lora import _fetch_state_dict - - # Resolve to a state_dict up-front so we can inspect / convert before the base loader. - if isinstance(pretrained_model_name_or_path_or_dict, dict): - state_dict = pretrained_model_name_or_path_or_dict - else: - fetch_kwargs = {k: kwargs.get(k, default) for k, default in HUB_KWARGS.items()} - fetch_kwargs["weight_name"] = kwargs.get("weight_name") - state_dict = _fetch_state_dict(pretrained_model_name_or_path_or_dict, **fetch_kwargs) - - if not any("lora_A" in k for k in state_dict): - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if "lora_controlnet" in state_dict: - state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) - return self._load_sai_control_lora(state_dict, **kwargs) - - # Hand the (possibly-converted) state_dict to the base loader. It hits the - # dict-passthrough branch in `_fetch_state_dict` and runs the rest of the flow. 
- return super()._load_adapter_from_pretrained(state_dict, **kwargs) - - def _load_sai_control_lora(self, state_dict, **kwargs): - """Bespoke loader for SAI Control LoRA: same flow as the base, plus LoraConfig overrides - (``lora_alpha`` follows ``r``, all biases trained, ``exclude_modules`` repurposed).""" - from .lora import _maybe_warn_for_unhandled_keys, _offloading_disabled - - adapter_name = kwargs.get("adapter_name") or get_adapter_name(self) - network_alphas = kwargs.get("network_alphas") - prefix = kwargs.get("prefix", "transformer") - hotswap = kwargs.get("hotswap", False) - low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage", False) - metadata = kwargs.get("metadata") - - if prefix is not None: - state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - - rank = {f"^{key}": val.shape[1] for key, val in state_dict.items() if "lora_B" in key and val.ndim > 1} - - if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [k for k in network_alphas if k.startswith(f"{prefix}.")] - network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata) - - # SAI Control LoRA overrides: alpha follows rank; all biases are trained. 
- lora_config.lora_alpha = lora_config.r - lora_config.alpha_pattern = lora_config.rank_pattern - lora_config.bias = "all" - lora_config.modules_to_save = lora_config.exclude_modules - lora_config.exclude_modules = None - - peft_kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage} - with _offloading_disabled(self): - if hotswap: - self._hotswap_adapter(state_dict, lora_config, adapter_name) - incompatible_keys = None - else: - incompatible_keys = self._inject_adapter(state_dict, lora_config, adapter_name, peft_kwargs) - self._maybe_apply_deferred_hotswap_prep(lora_config) - _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name) - @validate_hf_hub_args def load_attn_procs(self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], **kwargs): r""" @@ -372,7 +301,7 @@ def _process_lora( if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") - from peft import inject_adapter_in_model, set_peft_model_state_dict + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -410,7 +339,28 @@ def _process_lora( if "lora_B" in key: rank[key] = val.shape[1] - lora_config = _create_lora_config(state_dict, network_alphas, rank) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1df5cb6d2cd3..1a43797d220b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -40,7 +40,6 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..loaders.lora import LoRAModelMixin from ..loaders.weight_mapping import WeightMappingHandler from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod @@ -352,7 +351,7 @@ def wrap(cls): _ATTENTION_API_DEPRECATION_MSG = "`ModelMixin.{name}` is deprecated. Use `{replacement}` instead." -class ModelMixin(torch.nn.Module, ConfigMixin, LoRAModelMixin, PushToHubMixin): +class ModelMixin(torch.nn.Module, ConfigMixin, PushToHubMixin): r""" Base class for all models. @@ -382,12 +381,7 @@ def __init__(self): @classmethod def _metadata(cls) -> dict[str, tuple[Any, str, str, str]]: - """Return ``ModelMixin``-level rows for the metadata snapshot. - - Each row is keyed by the **class attribute name** that controls the capability (e.g. - ``"_supports_gradient_checkpointing"``) and maps to ``(value, display, description, docs_url)``. Only present - capabilities are returned. - """ + """Return ``ModelMixin``-level rows for the metadata snapshot.""" rows: dict[str, tuple[Any, str, str, str]] = {} if cls._supports_gradient_checkpointing: rows["_supports_gradient_checkpointing"] = ( @@ -450,16 +444,7 @@ def _metadata(cls) -> dict[str, tuple[Any, str, str, str]]: @classmethod def metadata(cls) -> "ModelMetadata": - """Return a :class:`ModelMetadata` snapshot of this class's feature attributes. 
- - Walks ``cls.__mro__`` and merges rows from each ancestor class's own ``_metadata`` classmethod (handled via - direct ``__dict__`` lookup so the aggregator never recurses into itself). First-seen wins on label collisions; - subclass overrides win over inherited defaults. - - The returned object exposes feature values as attributes (``meta._supports_ip_adapter``, ``meta._lora``, ...), - supports ``hasattr`` / ``in`` for presence checks, and prints as a formatted table via its ``__repr__``. Call - ``meta.describe(verbose=True)`` for the verbose variant with descriptions and docs. - """ + """Return a :class:`ModelMetadata` snapshot of this class's feature attributes.""" merged: dict[str, tuple[Any, str, str, str]] = {} for mixin in cls.__mro__: method = mixin.__dict__.get("_metadata") diff --git a/src/diffusers/models/transformers/flux/model.py b/src/diffusers/models/transformers/flux/model.py index 41e30f37d341..b2201a52413f 100644 --- a/src/diffusers/models/transformers/flux/model.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -22,6 +22,7 @@ from ....configuration_utils import register_to_config from ....hooks._helpers import TransformerBlockMetadata +from ....loaders.lora import LoRAModelMixin from ....utils import apply_lora_scale, logging from ....utils.torch_utils import maybe_allow_in_graph from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -529,6 +530,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class FluxTransformer2DModel( ModelMixin, + LoRAModelMixin, AttentionMixin, CacheMixin, FluxIPAdapterMixin, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3c3511380549..f2ba49878710 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -139,6 +139,7 @@ check_peft_version, delete_adapter_layers, get_adapter_name, + get_peft_kwargs, recurse_remove_peft_layers, scale_lora_layers, set_adapter_layers, From 5eaaa3f0eaa6bc7a104a36625a898b8761bd598c 
Mon Sep 17 00:00:00 2001 From: DN6 Date: Sat, 23 May 2026 00:55:18 +0530 Subject: [PATCH 20/21] update --- src/diffusers/models/modeling_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1a43797d220b..0c14f4ca1077 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1635,10 +1635,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. model._fix_state_dict_keys_on_load(state_dict) - # Convert checkpoint if needed (e.g., original format to diffusers format). For models that haven't - # registered weight-mapping metadata this is a no-op via the default handler. - state_dict = cls._weight_mapping.maybe_convert_state_dict(model, state_dict) - if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] else: From 774807bb77e6cbfba4235fbc353d48e33b6d0c39 Mon Sep 17 00:00:00 2001 From: DN6 Date: Sat, 23 May 2026 01:12:11 +0530 Subject: [PATCH 21/21] update --- src/diffusers/utils/peft_utils.py | 107 ++++++++++++++++-------------- 1 file changed, 58 insertions(+), 49 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index f2b5b4ccd822..65bcfe631e97 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -150,6 +150,56 @@ def unscale_lora_layers(model, weight: float | None = None): module.set_scale(adapter_name, 1.0) +def get_peft_kwargs( + rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None +): + rank_pattern = {} + alpha_pattern = {} + r = lora_alpha = list(rank_dict.values())[0] + + if len(set(rank_dict.values())) > 1: + # get the rank occurring the most number of times + r = collections.Counter(rank_dict.values()).most_common()[0][0] + + # for modules with rank different from 
the most occurring rank, add it to the `rank_pattern` + rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} + + if network_alpha_dict is not None and len(network_alpha_dict) > 0: + if len(set(network_alpha_dict.values())) > 1: + # get the alpha occurring the most number of times + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + + # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + if is_unet: + alpha_pattern = { + ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v + for k, v in alpha_pattern.items() + } + else: + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + else: + lora_alpha = set(network_alpha_dict.values()).pop() + + target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) + use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) + # for now we know that the "bias" keys are only associated with `lora_B`. 
+ lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) + + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": target_modules, + "use_dora": use_dora, + "lora_bias": lora_bias, + } + + return lora_config_kwargs + + def get_adapter_name(model): from peft.tuners.tuners_utils import BaseTunerLayer @@ -294,55 +344,6 @@ def check_peft_version(min_version: str) -> None: ) -def get_peft_kwargs( - rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None -): - rank_pattern = {} - alpha_pattern = {} - r = lora_alpha = list(rank_dict.values())[0] - - if len(set(rank_dict.values())) > 1: - # get the rank occurring the most number of times - r = collections.Counter(rank_dict.values()).most_common()[0][0] - - # for modules with rank different from the most occurring rank, add it to the `rank_pattern` - rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) - rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} - - if network_alpha_dict is not None and len(network_alpha_dict) > 0: - if len(set(network_alpha_dict.values())) > 1: - # get the alpha occurring the most number of times - lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] - - # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` - alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) - if is_unet: - alpha_pattern = { - ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v - for k, v in alpha_pattern.items() - } - else: - alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} - else: - lora_alpha = set(network_alpha_dict.values()).pop() - - target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) - use_dora = 
any("lora_magnitude_vector" in k for k in peft_state_dict) - lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict) - - lora_config_kwargs = { - "r": r, - "lora_alpha": lora_alpha, - "rank_pattern": rank_pattern, - "alpha_pattern": alpha_pattern, - "target_modules": target_modules, - "use_dora": use_dora, - "lora_bias": lora_bias, - } - - return lora_config_kwargs - - def _create_lora_config( state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ): @@ -362,6 +363,7 @@ def _create_lora_config( _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) + # Version checks for DoRA and lora_bias if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.") @@ -381,6 +383,11 @@ def _maybe_raise_error_for_ambiguous_keys(config): target_modules = config["target_modules"] for key in list(rank_pattern.keys()): + # try to detect ambiguity + # `target_modules` can also be a str, in which case this loop would loop + # over the chars of the str. The technically correct way to match LoRA keys + # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). + # But this cuts it for now. exact_matches = [mod for mod in target_modules if mod == key] substring_matches = [mod for mod in target_modules if key in mod and mod != key] @@ -394,6 +401,7 @@ def _maybe_raise_error_for_ambiguous_keys(config): def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): warn_msg = "" if incompatible_keys is not None: + # Check only for unexpected keys. unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] @@ -403,6 +411,7 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): f" {', '.join(lora_unexpected_keys)}. 
" ) + # Filter missing keys specific to the current adapter. missing_keys = getattr(incompatible_keys, "missing_keys", None) if missing_keys: lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]