Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class FieldDescriptions:
t5_encoder = "T5 tokenizer and text encoder"
glm_encoder = "GLM (THUDM) tokenizer and text encoder"
qwen3_encoder = "Qwen3 tokenizer and text encoder"
mistral_encoder = "Mistral tokenizer/processor and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
Expand All @@ -171,6 +172,7 @@ class FieldDescriptions:
sd3_model = "SD3 model (MMDiTX) to load"
cogview4_model = "CogView4 model (Transformer) to load"
z_image_model = "Z-Image model (Transformer) to load"
flux2_dev_model = "FLUX.2 [dev] model (Transformer) to load"
qwen_image_model = "Qwen Image Edit model (Transformer) to load"
qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
Expand Down
176 changes: 176 additions & 0 deletions invokeai/app/invocations/flux2_dev_lora_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""FLUX.2 [dev] LoRA loader invocations.

Mirror of the Klein LoRA loader, but routes encoder LoRAs to the Mistral text
encoder rather than the Qwen3 encoder.
"""

from typing import Optional

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
LoRAField,
MistralEncoderField,
ModelIdentifierField,
TransformerField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, Flux2VariantType, ModelType


@invocation_output("flux2_dev_lora_loader_output")
class Flux2DevLoRALoaderOutput(BaseInvocationOutput):
"""FLUX.2 [dev] LoRA loader output."""

transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Transformer"
)
mistral_encoder: Optional[MistralEncoderField] = OutputField(
default=None, description=FieldDescriptions.mistral_encoder, title="Mistral Encoder"
)


@invocation(
"flux2_dev_lora_loader",
title="Apply LoRA - FLUX.2 [dev]",
tags=["lora", "model", "flux", "flux2", "dev"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2DevLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA to a FLUX.2 [dev] transformer and/or its Mistral text encoder."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
mistral_encoder: MistralEncoderField | None = InputField(
default=None,
title="Mistral Encoder",
description=FieldDescriptions.mistral_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> Flux2DevLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")

lora_config = context.models.get_config(lora_key)
lora_variant = getattr(lora_config, "variant", None)

# Warn if LoRA variant doesn't match transformer variant. A Klein LoRA on a
# dev transformer is virtually guaranteed to produce shape errors.
if lora_variant and self.transformer is not None:
transformer_config = context.models.get_config(self.transformer.transformer.key)
transformer_variant = getattr(transformer_config, "variant", None)
if transformer_variant and lora_variant != transformer_variant:
context.logger.warning(
f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
f"but transformer is {transformer_variant.value}. This may cause shape errors."
)
if lora_variant != Flux2VariantType.Dev:
context.logger.warning(
f"LoRA '{lora_config.name}' is a {lora_variant.value} LoRA but is being applied "
"via the FLUX.2 [dev] loader. Use the Klein loader for Klein LoRAs."
)

# Check for duplicate keys.
if self.transformer and any(existing.lora.key == lora_key for existing in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.mistral_encoder and any(existing.lora.key == lora_key for existing in self.mistral_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to Mistral encoder.')

output = Flux2DevLoRALoaderOutput()
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
if self.mistral_encoder is not None:
output.mistral_encoder = self.mistral_encoder.model_copy(deep=True)
output.mistral_encoder.loras.append(LoRAField(lora=self.lora, weight=self.weight))
return output


@invocation(
"flux2_dev_lora_collection_loader",
title="Apply LoRA Collection - FLUX.2 [dev]",
tags=["lora", "model", "flux", "flux2", "dev"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2DevLoRACollectionLoader(BaseInvocation):
"""Apply a collection of LoRAs to a FLUX.2 [dev] transformer and/or Mistral encoder."""

loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None,
description="LoRA models and weights. May be a single LoRA or collection.",
title="LoRAs",
)
transformer: Optional[TransformerField] = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
mistral_encoder: MistralEncoderField | None = InputField(
default=None,
title="Mistral Encoder",
description=FieldDescriptions.mistral_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> Flux2DevLoRALoaderOutput:
output = Flux2DevLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []

if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.mistral_encoder is not None:
output.mistral_encoder = self.mistral_encoder.model_copy(deep=True)

for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)

lora_config = context.models.get_config(lora.lora.key)
lora_variant = getattr(lora_config, "variant", None)
if lora_variant and self.transformer is not None:
transformer_config = context.models.get_config(self.transformer.transformer.key)
transformer_variant = getattr(transformer_config, "variant", None)
if transformer_variant and lora_variant != transformer_variant:
context.logger.warning(
f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
f"but transformer is {transformer_variant.value}. This may cause shape errors."
)

added_loras.append(lora.lora.key)

if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)
if self.mistral_encoder is not None and output.mistral_encoder is not None:
output.mistral_encoder.loras.append(lora)

return output
179 changes: 179 additions & 0 deletions invokeai/app/invocations/flux2_dev_model_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""FLUX.2 [dev] model loader invocation.

Loads a FLUX.2 [dev] transformer with its Mistral Small 3.1 text encoder and the
shared FLUX.2 32-channel VAE.
"""

from typing import Literal, Optional

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
MistralEncoderField,
ModelIdentifierField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
Flux2VariantType,
ModelFormat,
ModelType,
SubModelType,
)


@invocation_output("flux2_dev_model_loader_output")
class Flux2DevModelLoaderOutput(BaseInvocationOutput):
"""FLUX.2 [dev] model loader output."""

transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
mistral_encoder: MistralEncoderField = OutputField(
description=FieldDescriptions.mistral_encoder, title="Mistral Encoder"
)
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="Max sequence length for the Mistral encoder.",
title="Max Seq Length",
)


@invocation(
"flux2_dev_model_loader",
title="Main Model - FLUX.2 [dev]",
tags=["model", "flux", "flux2", "dev", "mistral"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2DevModelLoaderInvocation(BaseInvocation):
"""Load a FLUX.2 [dev] transformer plus its Mistral text encoder and VAE.

FLUX.2 [dev] is a 32B guidance-distilled rectified flow transformer that uses
Mistral Small 3.1 (24B) as its sole text encoder, sharing the 32-channel
AutoencoderKLFlux2 VAE with FLUX.2 Klein.

When the transformer is a Diffusers-format checkpoint, both VAE and Mistral
encoder can be extracted directly from the main model. For single-file
safetensors or GGUF transformers, you must supply standalone VAE and
Mistral encoder models, or point at a Diffusers FLUX.2 [dev] checkout for
sub-model extraction.
"""

model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux2_dev_model,
input=Input.Direct,
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.Main,
title="Transformer",
)

vae_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone FLUX.2 VAE (AutoencoderKLFlux2). "
"If not provided, the VAE is extracted from the Diffusers source model.",
input=Input.Direct,
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.VAE,
title="VAE",
)

mistral_encoder_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone Mistral text encoder. Required when the transformer is "
"a single-file safetensors or GGUF without a sibling Diffusers source.",
input=Input.Direct,
ui_model_type=ModelType.MistralEncoder,
title="Mistral Encoder",
)

mistral_source_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Diffusers FLUX.2 [dev] model to extract VAE and/or Mistral encoder from. "
"Use this if you don't have separate VAE / Mistral encoder models. "
"Ignored if both are provided separately.",
input=Input.Direct,
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.Main,
ui_model_format=ModelFormat.Diffusers,
title="Mistral Source (Diffusers)",
)

max_seq_len: Literal[256, 512] = InputField(
default=512,
description="Max sequence length for the Mistral encoder. FLUX.2 [dev] uses 512 by default.",
title="Max Seq Length",
)

def invoke(self, context: InvocationContext) -> Flux2DevModelLoaderOutput:
# Validate the selected main model is FLUX.2 [dev], not Klein.
main_config = context.models.get_config(self.model)
variant = getattr(main_config, "variant", None)
if variant is not None and variant != Flux2VariantType.Dev:
raise ValueError(
f"FLUX.2 [dev] loader requires a FLUX.2 [dev] transformer, "
f"but the selected model is variant '{variant.value}'. "
"Use the FLUX.2 Klein loader for Klein variants."
)

transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
main_is_diffusers = main_config.format == ModelFormat.Diffusers

# Resolve VAE.
if self.vae_model is not None:
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
elif main_is_diffusers:
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
elif self.mistral_source_model is not None:
self._validate_diffusers_format(context, self.mistral_source_model, "Mistral Source")
vae = self.mistral_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
else:
raise ValueError(
"No VAE source provided. Single-file / GGUF transformers require a separate VAE. "
"Options:\n"
" 1. Set 'VAE' to a standalone FLUX.2 VAE model\n"
" 2. Set 'Mistral Source' to a Diffusers FLUX.2 [dev] model to extract the VAE from"
)

# Resolve Mistral encoder.
if self.mistral_encoder_model is not None:
tokenizer = self.mistral_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.mistral_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif main_is_diffusers:
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif self.mistral_source_model is not None:
self._validate_diffusers_format(context, self.mistral_source_model, "Mistral Source")
tokenizer = self.mistral_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.mistral_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
else:
raise ValueError(
"No Mistral encoder source provided. Single-file / GGUF transformers require a separate "
"text encoder. Options:\n"
" 1. Set 'Mistral Encoder' to a standalone Mistral Small 3.1 text encoder model\n"
" 2. Set 'Mistral Source' to a Diffusers FLUX.2 [dev] model to extract the encoder from"
)

return Flux2DevModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
mistral_encoder=MistralEncoderField(tokenizer=tokenizer, text_encoder=text_encoder),
vae=VAEField(vae=vae),
max_seq_len=self.max_seq_len,
)

def _validate_diffusers_format(
self, context: InvocationContext, model: ModelIdentifierField, model_name: str
) -> None:
config = context.models.get_config(model)
if config.format != ModelFormat.Diffusers:
raise ValueError(
f"The {model_name} model must be a Diffusers format model. "
f"The selected model '{config.name}' is in {config.format.value} format."
)
Loading
Loading