diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7582a56505f7..848e26630825 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -370,6 +370,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
+ - local: api/models/joyai_image_transformer3d
+ title: JoyAIImageTransformer3DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
@@ -466,6 +468,8 @@
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
+ - local: api/models/autoencoder_kl_joyai_image
+ title: JoyAIImageVAE
- local: api/models/autoencoder_rae
title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
@@ -558,6 +562,8 @@
title: Kandinsky 5.0 Image
- local: api/pipelines/kolors
title: Kolors
+ - local: api/pipelines/joyai_image
+ title: JoyAI-Image
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
diff --git a/docs/source/en/api/models/autoencoder_kl_joyai_image.md b/docs/source/en/api/models/autoencoder_kl_joyai_image.md
new file mode 100644
index 000000000000..e65909092da6
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_joyai_image.md
@@ -0,0 +1,35 @@
+
+
+# JoyAIImageVAE
+
+The 3D variational autoencoder (VAE) model with KL loss used in JoyAI-Image by JDopensource.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import JoyAIImageVAE
+vae = JoyAIImageVAE.from_pretrained("path/to/checkpoint", subfolder="vae", torch_dtype=torch.bfloat16)
+```
+
+
+## JoyAIImageVAE
+
+[[autodoc]] JoyAIImageVAE
+ - decode
+ - all
+
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
\ No newline at end of file
diff --git a/docs/source/en/api/models/joyai_image_transformer3d.md b/docs/source/en/api/models/joyai_image_transformer3d.md
new file mode 100644
index 000000000000..c90ac5e50ea4
--- /dev/null
+++ b/docs/source/en/api/models/joyai_image_transformer3d.md
@@ -0,0 +1,26 @@
+
+
+# JoyAIImageTransformer3DModel
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import JoyAIImageTransformer3DModel
+transformer = JoyAIImageTransformer3DModel.from_pretrained("path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+
+## JoyAIImageTransformer3DModel
+
+[[autodoc]] JoyAIImageTransformer3DModel
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/joyai_image.md b/docs/source/en/api/pipelines/joyai_image.md
new file mode 100644
index 000000000000..7ed482a66978
--- /dev/null
+++ b/docs/source/en/api/pipelines/joyai_image.md
@@ -0,0 +1,128 @@
+
+
+# JoyAI-Image
+
+
+
+
+
+JoyAI-Image is a multimodal foundation model specialized in instruction-guided image editing. It enables precise and controllable edits by leveraging strong spatial understanding, including scene parsing, relational grounding, and instruction decomposition, allowing complex modifications to be applied accurately to specified regions.
+
+
+### Key Features
+- 🌟 **Unified Multimodal Understanding and Generation**: Combines powerful image understanding with generation capabilities in a single model.
+- 🌟 **Spatial Editing**: Supports precise spatial editing including object movement, rotation, and camera control.
+- 🌟 **Instruction Following**: Accurately interprets user instructions for image modifications while preserving image quality.
+- 🌟 **Qwen2.5-VL Integration**: Leverages Qwen2.5-VL for enhanced multimodal understanding.
+
+For more details, please refer to the [JoyAI-Image GitHub](https://github.com/jd-opensource/JoyAI-Image).
+
+
+## Usage Example
+
+```py
+import torch
+from diffusers import JoyAIImagePipeline
+
+pipe = JoyAIImagePipeline.from_pretrained("path/to/converted/checkpoint", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "Move the apple into the red box and finally remove the red box."
+image = pipe(
+ prompt,
+ image=input_image,
+ num_inference_steps=30,
+ guidance_scale=5.0,
+).images[0]
+image.save("./output.png")
+```
+
+
+### Supported Prompt Patterns
+
+#### 1. Object Move
+```text
+Move the {object} into the red box and finally remove the red box.
+```
+
+#### 2. Object Rotation
+```text
+Rotate the {object} to show the {view} view.
+```
+Supported views: front, right, left, rear, front right, front left, rear right, rear left
+
+#### 3. Camera Control
+```text
+Move the camera.
+- Camera rotation: Yaw {y_rotation}°, Pitch {p_rotation}°.
+- Camera zoom: in/out/unchanged.
+- Keep the 3D scene static; only change the viewpoint.
+```
+
+This pipeline was contributed by JDopensource Team. The original codebase can be found [here](https://github.com/jd-opensource/JoyAI-Image).
+
+
+## Available Models
+
+
+
+
+ Models
+ Type
+ Description
+ Download Link
+
+
+
+
+ JoyAI‑Image‑Edit
+ Image Editing
+ Final Release. Specialized model for instruction-guided image editing.
+
+ 🤗 Huggingface
+
+
+
+
+
+
+## Converting Original Checkpoint to Diffusers Format
+
+If you have the original JoyAI checkpoint, you can convert it to diffusers format using the provided conversion script:
+
+```bash
+python scripts/convert_joyai_image_to_diffusers.py \
+ --source_path /path/to/original/JoyAI-Image-Edit \
+ --output_path /path/to/converted/checkpoint \
+ --dtype bf16
+```
+
+After conversion, load the model with:
+
+```py
+from diffusers import JoyAIImagePipeline
+pipe = JoyAIImagePipeline.from_pretrained("/path/to/converted/checkpoint")
+```
+
+
+## JoyAIImagePipeline
+
+[[autodoc]] JoyAIImagePipeline
+ - all
+ - __call__
+
+
+## JoyAIImagePipelineOutput
+
+[[autodoc]] pipelines.joyai_image.pipeline_output.JoyAIImagePipelineOutput
+
diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md
index c3e493c63d6a..bac1a810529e 100644
--- a/docs/source/en/api/pipelines/overview.md
+++ b/docs/source/en/api/pipelines/overview.md
@@ -49,6 +49,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
| [Kandinsky 3](kandinsky3) | text2image, image2image |
+| [JoyAI-Image](joyai_image) | image editing |
 | [Kolors](kolors) | text2image |
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
diff --git a/scripts/convert_joyai_image_to_diffusers.py b/scripts/convert_joyai_image_to_diffusers.py
new file mode 100644
index 000000000000..343885a76f6f
--- /dev/null
+++ b/scripts/convert_joyai_image_to_diffusers.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration
+
+from diffusers import JoyAIImagePipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models.autoencoders.autoencoder_kl_joyai_image import JoyAIImageVAE
+from diffusers.models.transformers.transformer_joyai_image import JoyAIImageTransformer3DModel
+from diffusers.schedulers.scheduling_joyai_flow_match_discrete import JoyAIFlowMatchDiscreteScheduler
+
+
+DTYPE_MAP = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+PRECISION_TO_TYPE = {
+ "fp32": torch.float32,
+ "float32": torch.float32,
+ "fp16": torch.float16,
+ "float16": torch.float16,
+ "bf16": torch.bfloat16,
+ "bfloat16": torch.bfloat16,
+}
+
+
+@dataclass
+class JoyAIImageSourceConfig:
+ source_root: Path
+ dit_precision: str = "bf16"
+ vae_precision: str = "bf16"
+ text_encoder_precision: str = "bf16"
+ text_token_max_length: int = 2048
+ enable_multi_task_training: bool = False
+ dit_arch_config: dict[str, Any] = field(
+ default_factory=lambda: {
+ "hidden_size": 4096,
+ "in_channels": 16,
+ "heads_num": 32,
+ "mm_double_blocks_depth": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "rope_dim_list": [16, 56, 56],
+ "text_states_dim": 4096,
+ "rope_type": "rope",
+ "dit_modulation_type": "wanx",
+ "theta": 10000,
+ "attn_backend": "flash_attn",
+ }
+ )
+ scheduler_arch_config: dict[str, Any] = field(
+ default_factory=lambda: {
+ "num_train_timesteps": 1000,
+ "shift": 4.0,
+ }
+ )
+
+ @property
+ def text_encoder_arch_config(self) -> dict[str, Any]:
+ return {"params": {"text_encoder_ckpt": str(self.source_root / "JoyAI-Image-Und")}}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Convert a raw JoyAI checkpoint directory to standard diffusers format."
+ )
+ parser.add_argument(
+ "--source_path", type=str, required=True, help="Path to the original JoyAI checkpoint directory"
+ )
+ parser.add_argument(
+ "--output_path", type=str, required=True, help="Output path for the converted diffusers checkpoint"
+ )
+ parser.add_argument(
+ "--dtype", type=str, default="bf16", choices=sorted(DTYPE_MAP), help="Component dtype to load and save"
+ )
+ parser.add_argument("--device", type=str, default="cpu", help="Device used while loading the raw JoyAI checkpoint")
+ parser.add_argument(
+ "--safe_serialization",
+        action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Save diffusers weights with safetensors when supported (default: True)",
+ )
+ return parser.parse_args()
+
+
+def dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]:
+ if torch_dtype is None:
+ return None
+ for name, value in PRECISION_TO_TYPE.items():
+ if value == torch_dtype and name in {"fp32", "fp16", "bf16"}:
+ return name
+ raise ValueError(f"Unsupported torch dtype for JoyAI conversion: {torch_dtype}")
+
+
+def resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]:
+ if manifest_value is None:
+ return None
+ path = Path(manifest_value)
+ if path.parts and path.parts[0] == source_root.name:
+ path = Path(*path.parts[1:])
+ return source_root / path
+
+
+def is_joyai_source_dir(path: Path) -> bool:
+ return (
+ path.is_dir()
+ and (path / "infer_config.py").is_file()
+ and (path / "manifest.json").is_file()
+ and (path / "transformer").is_dir()
+ and (path / "vae").is_dir()
+ )
+
+
+def load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]:
+ state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+ if "model" in state:
+ state = state["model"]
+ return state
+
+
+def load_joyai_components(
+ source_root: Union[str, Path],
+ torch_dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[str, torch.device]] = None,
+) -> dict[str, Any]:
+ source_root = Path(source_root)
+ if not is_joyai_source_dir(source_root):
+ raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}")
+
+ precision = dtype_to_precision(torch_dtype)
+ cfg = JoyAIImageSourceConfig(source_root=source_root)
+
+ manifest = json.loads((source_root / "manifest.json").read_text())
+ transformer_ckpt = resolve_manifest_path(source_root, manifest.get("transformer_ckpt"))
+ vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth"
+ text_encoder_ckpt = source_root / "JoyAI-Image-Und"
+
+ if precision is not None:
+ cfg.dit_precision = precision
+ cfg.vae_precision = precision
+ cfg.text_encoder_precision = precision
+
+ load_device = torch.device(device) if device is not None else torch.device("cpu")
+ transformer = JoyAIImageTransformer3DModel(
+ dtype=PRECISION_TO_TYPE[cfg.dit_precision],
+ device=load_device,
+ **cfg.dit_arch_config,
+ )
+ state_dict = load_transformer_state_dict(transformer_ckpt)
+ if "img_in.weight" in state_dict and transformer.img_in.weight.shape != state_dict["img_in.weight"].shape:
+ value = state_dict["img_in.weight"]
+ padded = value.new_zeros(transformer.img_in.weight.shape)
+ padded[:, : value.shape[1], :, :, :] = value
+ state_dict["img_in.weight"] = padded
+ transformer.load_state_dict(state_dict, strict=True)
+ transformer = transformer.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]).eval()
+
+ vae = JoyAIImageVAE(
+ pretrained=str(vae_ckpt),
+ torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision],
+ device=load_device,
+ )
+ vae = vae.to(device=load_device, dtype=PRECISION_TO_TYPE[cfg.vae_precision]).eval()
+ text_encoder = (
+ Qwen3VLForConditionalGeneration.from_pretrained(
+ str(text_encoder_ckpt),
+ dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision],
+ local_files_only=True,
+ trust_remote_code=True,
+ )
+ .to(load_device)
+ .eval()
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ str(text_encoder_ckpt),
+ local_files_only=True,
+ trust_remote_code=True,
+ )
+ processor = AutoProcessor.from_pretrained(
+ str(text_encoder_ckpt),
+ local_files_only=True,
+ trust_remote_code=True,
+ )
+ scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config)
+
+ return {
+ "args": cfg,
+ "processor": processor,
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ "vae": vae,
+ }
+
+
+def _sanitize_config_value(value: Any) -> Any:
+ if isinstance(value, (torch.dtype, torch.device)):
+ raise TypeError("Drop non-JSON torch config values")
+ if isinstance(value, Path):
+ return str(value)
+ if isinstance(value, dict):
+ sanitized = {}
+ for key, item in value.items():
+ try:
+ sanitized[key] = _sanitize_config_value(item)
+ json.dumps(sanitized[key])
+ except TypeError:
+ continue
+ return sanitized
+ if isinstance(value, (list, tuple)):
+ sanitized = []
+ for item in value:
+ try:
+ converted = _sanitize_config_value(item)
+ json.dumps(converted)
+ sanitized.append(converted)
+ except TypeError:
+ continue
+ return sanitized
+ return value
+
+
+def _sanitize_component_config(component: Any) -> None:
+ config = getattr(component, "config", None)
+ if config is None:
+ return
+
+ sanitized_config = {}
+ for key, value in dict(config).items():
+ try:
+ sanitized_value = _sanitize_config_value(value)
+ json.dumps(sanitized_value)
+ sanitized_config[key] = sanitized_value
+ except TypeError:
+ continue
+
+ component._internal_dict = FrozenDict(sanitized_config)
+
+
+def _sanitize_pipeline_for_export(pipeline: JoyAIImagePipeline) -> None:
+ for component_name in ["vae", "transformer", "scheduler"]:
+ _sanitize_component_config(getattr(pipeline, component_name, None))
+
+
+def main():
+ args = parse_args()
+ source_path = Path(args.source_path)
+ output_path = Path(args.output_path)
+
+ if not source_path.exists():
+ raise ValueError(f"Source path does not exist: {source_path}")
+
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ components = load_joyai_components(
+ source_root=source_path,
+ torch_dtype=DTYPE_MAP[args.dtype],
+ device=args.device,
+ )
+ pipeline = JoyAIImagePipeline(
+ vae=components["vae"],
+ text_encoder=components["text_encoder"],
+ tokenizer=components["tokenizer"],
+ transformer=components["transformer"],
+ scheduler=components["scheduler"],
+ processor=components["processor"],
+ args=components["args"],
+ )
+ _sanitize_pipeline_for_export(pipeline)
+ pipeline.save_pretrained(output_path, safe_serialization=args.safe_serialization)
+
+ print(f"Converted JoyAI checkpoint saved to: {output_path}")
+ print(f"Load with: JoyAIImagePipeline.from_pretrained({str(output_path)!r})")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 0f74c0bbcb4a..baeda1625a09 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -209,6 +209,7 @@
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLQwenImage",
+ "JoyAIImageVAE",
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
@@ -245,6 +246,7 @@
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
+ "JoyAIImageTransformer3DModel",
"HunyuanVideo15Transformer3DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
@@ -376,6 +378,7 @@
"FlowMatchLCMScheduler",
"HeliosDMDScheduler",
"HeliosScheduler",
+ "JoyAIFlowMatchDiscreteScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -545,6 +548,7 @@
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
+ "JoyAIImagePipeline",
"HunyuanImagePipeline",
"HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
@@ -1044,6 +1048,8 @@
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
+ JoyAIImageTransformer3DModel,
+ JoyAIImageVAE,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
@@ -1169,6 +1175,7 @@
HeliosScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
+ JoyAIFlowMatchDiscreteScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 7ded56049833..7dfa6c34d58c 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -48,6 +48,7 @@
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
+ _import_structure["autoencoders.autoencoder_kl_joyai_image"] = ["JoyAIImageVAE"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -110,6 +111,7 @@
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
+ _import_structure["transformers.transformer_joyai_image"] = ["JoyAIImageTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -178,6 +180,7 @@
AutoencoderTiny,
AutoencoderVidTok,
ConsistencyDecoderVAE,
+ JoyAIImageVAE,
VQModel,
)
from .cache_utils import CacheMixin
@@ -228,6 +231,7 @@
HunyuanVideo15Transformer3DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
+ JoyAIImageTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index 609146ec340d..a54a005f6812 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -9,6 +9,7 @@
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
+from .autoencoder_kl_joyai_image import JoyAIImageVAE
from .autoencoder_kl_kvae import AutoencoderKLKVAE
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py
new file mode 100644
index 000000000000..c2d08d20c166
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py
@@ -0,0 +1,812 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This model is adapted from https://github.com/jd-opensource/JoyAI-Image
+
+from contextlib import nullcontext
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__)
+CACHE_T = 2
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanCausalConv3d
+class CausalConv3d(nn.Conv3d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int | tuple[int, int, int],
+ stride: int | tuple[int, int, int] = 1,
+ padding: int | tuple[int, int, int] = 0,
+ ) -> None:
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+ return super().forward(x)
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanRMS_norm
+class RMS_norm(nn.Module):
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
+ t in str(x.dtype) for t in ("float4_", "float8_")
+ )
+ normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
+ x.dtype
+ )
+
+ return normalized * self.scale * self.gamma + self.bias
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanUpsample
+class Upsample(nn.Upsample):
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResample
+class Resample(nn.Module):
+ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ if upsample_out_dim is None:
+ upsample_out_dim = dim // 2
+
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+ )
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ cache_x = torch.cat(
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+ )
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.resample(x)
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResidualBlock
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ dropout: float = 0.0,
+ non_linearity: str = "silu",
+ ) -> None:
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = RMS_norm(in_dim, images=False)
+ self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1)
+ self.norm2 = RMS_norm(out_dim, images=False)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1)
+ self.conv_shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.conv_shortcut(x)
+
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.dropout(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv2(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv2(x)
+
+ return x + h
+
+
+# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanAttentionBlock
+class AttentionBlock(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ identity = x
+ batch_size, channels, time, height, width = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+ x = self.norm(x)
+
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+ x = self.proj(x)
+ x = x.view(batch_size, time, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4)
+
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.ModuleList(downsamples)
+
+ # middle blocks
+ self.middle = nn.ModuleList(
+ [
+ ResidualBlock(out_dim, out_dim, dropout),
+ AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout),
+ ]
+ )
+
+ # output blocks
+ self.head = nn.ModuleList(
+ [
+ RMS_norm(out_dim, images=False),
+ nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
+ ]
+ )
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+ else:
+ x = layer(x)
+
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat(
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+ )
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+    """Causal 3D VAE decoder (Wan-style).
+
+    Pipeline: latent conv-in -> residual/attention middle -> progressive
+    (spatial, optionally temporal) upsampling -> RMS-norm + conv head to RGB.
+
+    ``feat_cache`` / ``feat_idx`` implement chunked causal decoding: every
+    ``CausalConv3d`` stores the trailing ``CACHE_T`` input frames so the next
+    temporal chunk is decoded with the correct causal context.
+    """
+
+    def __init__(
+        self,
+        dim=128,
+        z_dim=4,
+        dim_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        attn_scales=[],
+        temperal_upsample=[False, True, True],
+        dropout=0.0,
+    ):
+        # NOTE(review): the mutable default arguments (lists) above are shared
+        # across instances; harmless here since they are only read, never mutated.
+        super().__init__()
+        self.dim = dim
+        self.z_dim = z_dim
+        self.dim_mult = dim_mult
+        self.num_res_blocks = num_res_blocks
+        self.attn_scales = attn_scales
+        self.temperal_upsample = temperal_upsample
+
+        # dimensions: the decoder walks dim_mult in reverse (coarse -> fine)
+        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+        scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+        # init block
+        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+        # middle blocks
+        self.middle = nn.ModuleList(
+            [
+                ResidualBlock(dims[0], dims[0], dropout),
+                AttentionBlock(dims[0]),
+                ResidualBlock(dims[0], dims[0], dropout),
+            ]
+        )
+
+        # upsample blocks
+        upsamples = []
+        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+            # residual (+attention) blocks
+            # stages after the first start with half the channels, because the
+            # preceding Resample halves them (Wan decoder convention)
+            if i == 1 or i == 2 or i == 3:
+                in_dim = in_dim // 2
+            for _ in range(num_res_blocks + 1):
+                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+                if scale in attn_scales:
+                    upsamples.append(AttentionBlock(out_dim))
+                in_dim = out_dim
+
+            # upsample block
+            if i != len(dim_mult) - 1:
+                mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+                upsamples.append(Resample(out_dim, mode=mode))
+                scale *= 2.0
+        self.upsamples = nn.ModuleList(upsamples)
+
+        # output blocks
+        self.head = nn.ModuleList(
+            [
+                RMS_norm(out_dim, images=False),
+                nn.SiLU(),
+                CausalConv3d(out_dim, 3, 3, padding=1),
+            ]
+        )
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        # NOTE(review): mutable default `feat_idx` is shared across calls; the
+        # callers in this file always pass a fresh list, so the default is
+        # effectively never used.
+        ## conv1
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        # middle: only ResidualBlocks consume the cache
+        for layer in self.middle:
+            if isinstance(layer, ResidualBlock) and feat_cache is not None:
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+            else:
+                x = layer(x)
+
+        for layer in self.upsamples:
+            if feat_cache is not None:
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+            else:
+                x = layer(x)
+
+        # head: the final CausalConv3d caches its own trailing frames
+        for layer in self.head:
+            if isinstance(layer, CausalConv3d) and feat_cache is not None:
+                idx = feat_idx[0]
+                cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # cache last frame of last two chunk
+                    cache_x = torch.cat(
+                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                    )
+                x = layer(x, feat_cache[idx])
+                feat_cache[idx] = cache_x
+                feat_idx[0] += 1
+            else:
+                x = layer(x)
+        return x
+
+
+def count_conv3d(model):
+    """Return the number of CausalConv3d modules in *model* (sizes the feature cache)."""
+    count = 0
+    for m in model.modules():
+        if isinstance(m, CausalConv3d):
+            count += 1
+    return count
+
+
+class WanVAE_(nn.Module):
+    """Inner Wan-style causal video VAE: Encoder3d + 1x1 quant convs + Decoder3d.
+
+    Encoding/decoding is chunked along time (first frame alone, then groups of
+    4 frames) using per-CausalConv3d feature caches managed by ``clear_cache``.
+    """
+
+    def __init__(
+        self,
+        dim=128,
+        z_dim=4,
+        dim_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        attn_scales=[],
+        temperal_downsample=[True, True, False],
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.z_dim = z_dim
+        self.dim_mult = dim_mult
+        self.num_res_blocks = num_res_blocks
+        self.attn_scales = attn_scales
+        self.temperal_downsample = temperal_downsample
+        # decoder upsamples time in the reverse order of encoder downsampling
+        self.temperal_upsample = temperal_downsample[::-1]
+
+        # modules
+        self.encoder = Encoder3d(
+            dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+        )
+        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)  # projects to stacked (mu, log_var)
+        self.conv2 = CausalConv3d(z_dim, z_dim, 1)  # post-quant projection before decoding
+        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
+
+    @property
+    def quant_conv(self):
+        # Alias matching diffusers' AutoencoderKL naming.
+        return self.conv1
+
+    @property
+    def post_quant_conv(self):
+        # Alias matching diffusers' AutoencoderKL naming.
+        return self.conv2
+
+    def _encode_frames(self, x):
+        """Encode (B, C, T, H, W) chunk-wise: the first frame alone, then groups of 4.
+
+        NOTE(review): assumes T == 1 (mod 4); for other lengths the trailing
+        frames fall outside the chunking and are dropped — confirm against callers.
+        """
+        num_frames = x.shape[2]
+        num_chunks = 1 + (num_frames - 1) // 4
+
+        for chunk_idx in range(num_chunks):
+            self._enc_conv_idx = [0]  # restart the per-conv cache cursor for each chunk
+            if chunk_idx == 0:
+                encoded = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+            else:
+                encoded_chunk = self.encoder(
+                    x[:, :, 1 + 4 * (chunk_idx - 1) : 1 + 4 * chunk_idx, :, :],
+                    feat_cache=self._enc_feat_map,
+                    feat_idx=self._enc_conv_idx,
+                )
+                encoded = torch.cat([encoded, encoded_chunk], dim=2)
+        return encoded
+
+    def _decode_frames(self, x):
+        """Decode latents one temporal step at a time, reusing the decoder cache."""
+        num_frames = x.shape[2]
+        for frame_idx in range(num_frames):
+            self._conv_idx = [0]  # restart the per-conv cache cursor for each step
+            decoded_chunk = self.decoder(
+                x[:, :, frame_idx : frame_idx + 1, :, :],
+                feat_cache=self._feat_map,
+                feat_idx=self._conv_idx,
+            )
+            if frame_idx == 0:
+                decoded = decoded_chunk
+            else:
+                decoded = torch.cat([decoded, decoded_chunk], dim=2)
+        return decoded
+
+    def forward(self, x):
+        """Full VAE pass: returns (reconstruction, mu, log_var)."""
+        mu, log_var = self.encode(x)
+        z = self.reparameterize(mu, log_var)
+        x_recon = self.decode(z)
+        return x_recon, mu, log_var
+
+    def encode(self, x, scale=None, return_posterior=False):
+        """Encode to latent statistics.
+
+        Returns (mu, log_var) when ``scale`` is None or ``return_posterior`` is
+        True; otherwise returns a normalized latent using ``scale = (mean, inv_std)``.
+        NOTE(review): on the scaled path the value named ``mu`` is a
+        reparameterized *sample*, not the distribution mean.
+        """
+        self.clear_cache()
+        encoded = self._encode_frames(x)
+        mu, log_var = self.quant_conv(encoded).chunk(2, dim=1)
+        if scale is None or return_posterior:
+            return mu, log_var
+
+        mu = self.reparameterize(mu, log_var)
+        if isinstance(scale[0], torch.Tensor):
+            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+        else:
+            mu = (mu - scale[0]) * scale[1]
+        self.clear_cache()
+        return mu
+
+    def decode(self, z, scale=None):
+        """Decode latents; ``scale = (mean, inv_std)`` undoes latent normalization first."""
+        self.clear_cache()
+        if scale is not None:
+            if isinstance(scale[0], torch.Tensor):
+                z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+            else:
+                z = z / scale[1] + scale[0]
+        decoded = self._decode_frames(self.post_quant_conv(z))
+        self.clear_cache()
+        return decoded
+
+    def reparameterize(self, mu, log_var):
+        """Standard VAE reparameterization trick: mu + sigma * eps."""
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def sample(self, imgs, deterministic=False, scale=None):
+        """Encode and sample a normalized latent.
+
+        NOTE(review): when ``deterministic`` is False this dereferences
+        ``scale[0]``, so the default ``scale=None`` would raise TypeError —
+        callers must supply ``scale`` on that path.
+        """
+        mu, log_var = self.encode(imgs)
+        if deterministic:
+            return mu
+        # clamp log_var for numerical stability before exponentiating
+        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+        mu = mu + std * torch.randn_like(std)
+        if isinstance(scale[0], torch.Tensor):
+            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+        else:
+            mu = (mu - scale[0]) * scale[1]
+        self.clear_cache()
+        return mu
+
+    def clear_cache(self):
+        """(Re)initialize the per-conv feature caches for decoder and encoder."""
+        self._conv_num = count_conv3d(self.decoder)
+        self._conv_idx = [0]
+        self._feat_map = [None] * self._conv_num
+        # cache encode
+        self._enc_conv_num = count_conv3d(self.encoder)
+        self._enc_conv_idx = [0]
+        self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _build_video_vae(z_dim=None, use_meta=False, **kwargs):
+    """Build the JoyAI/Wan-derived VAE backbone without loading external weights.
+
+    ``use_meta`` constructs the module on the meta device (no memory allocated),
+    intended for the case where real weights are loaded afterwards with
+    ``load_state_dict(..., assign=True)``.
+    """
+    cfg = {
+        "dim": 96,
+        "z_dim": z_dim,
+        "dim_mult": [1, 2, 4, 4],
+        "num_res_blocks": 2,
+        "attn_scales": [],
+        "temperal_downsample": [False, True, True],
+        "dropout": 0.0,
+    }
+    cfg.update(**kwargs)
+
+    if use_meta:
+        with torch.device("meta"):
+            return WanVAE_(**cfg)
+    return WanVAE_(**cfg)
+
+
+def _remap_joyai_vae_state_dict_keys(pretrained_state_dict):
+    """Rename original-checkpoint keys (positional ``residual.N`` indices) to the
+    named submodules used by this implementation (norm1/conv1/norm2/conv2,
+    conv_shortcut)."""
+    remapped_state_dict = {}
+    for key, value in pretrained_state_dict.items():
+        key = key.replace(".residual.0.gamma", ".norm1.gamma")
+        key = key.replace(".residual.2.weight", ".conv1.weight")
+        key = key.replace(".residual.2.bias", ".conv1.bias")
+        key = key.replace(".residual.3.gamma", ".norm2.gamma")
+        key = key.replace(".residual.6.weight", ".conv2.weight")
+        key = key.replace(".residual.6.bias", ".conv2.bias")
+        key = key.replace(".shortcut.weight", ".conv_shortcut.weight")
+        key = key.replace(".shortcut.bias", ".conv_shortcut.bias")
+        remapped_state_dict[key] = value
+    return remapped_state_dict
+
+
+def _load_pretrained_weights(model, pretrained_path):
+    """Load a ``.safetensors`` or torch checkpoint into *model*; no-op for an empty path.
+
+    NOTE(review): ``torch.load`` without ``weights_only=True`` unpickles
+    arbitrary objects — only trusted checkpoints should be loaded here.
+    """
+    if not pretrained_path:
+        return model
+
+    logger.info(f"loading {pretrained_path}")
+
+    if pretrained_path.endswith(".safetensors"):
+        from safetensors.torch import load_file
+
+        pretrained_state_dict = load_file(pretrained_path, device="cpu")
+    else:
+        pretrained_state_dict = torch.load(pretrained_path, map_location="cpu")
+
+    pretrained_state_dict = _remap_joyai_vae_state_dict_keys(pretrained_state_dict)
+    # assign=True swaps tensors in place, which also materializes meta-device params
+    model.load_state_dict(pretrained_state_dict, assign=True)
+    return model
+
+
+def _video_vae(pretrained_path=None, z_dim=None, use_meta=False, **kwargs):
+    """Convenience wrapper: build the backbone, then optionally load weights from disk."""
+    model = _build_video_vae(z_dim=z_dim, use_meta=use_meta, **kwargs)
+    return _load_pretrained_weights(model, pretrained_path)
+
+
+class JoyAIImageVAE(ModelMixin, ConfigMixin, AutoencoderMixin, FromOriginalModelMixin):
+    """Diffusers-facing wrapper around the Wan-style causal 3D VAE used by JoyAI-Image.
+
+    Latents are normalized per-channel with ``latents_mean`` / ``latents_std``
+    (kept as the ``mean`` / ``std`` buffers) instead of a scalar scaling factor.
+    """
+
+    def __init__(
+        self,
+        pretrained: str = "",
+        torch_dtype: torch.dtype = torch.float32,  # NOTE(review): appears unused in this constructor — kept for config compat?
+        device: str | torch.device = "cpu",  # NOTE(review): appears unused in this constructor — kept for config compat?
+        z_dim: int = 16,
+        latent_channels: int | None = None,
+        dim: int = 96,
+        dim_mult: list[int] | tuple[int, ...] = (1, 2, 4, 4),
+        num_res_blocks: int = 2,
+        attn_scales: list[float] | tuple[float, ...] = (),
+        temperal_downsample: list[bool] | tuple[bool, ...] = (False, True, True),
+        dropout: float = 0.0,
+        latents_mean: list[float] | tuple[float, ...] = (
+            -0.7571,
+            -0.7089,
+            -0.9113,
+            0.1075,
+            -0.1745,
+            0.9653,
+            -0.1517,
+            1.5508,
+            0.4134,
+            -0.0715,
+            0.5517,
+            -0.3632,
+            -0.1922,
+            -0.9497,
+            0.2503,
+            -0.2921,
+        ),
+        latents_std: list[float] | tuple[float, ...] = (
+            2.8184,
+            1.4541,
+            2.3275,
+            2.6558,
+            1.2196,
+            1.7708,
+            2.6052,
+            2.0743,
+            3.2687,
+            2.1526,
+            2.8652,
+            1.5579,
+            1.6382,
+            1.1253,
+            2.8251,
+            1.9160,
+        ),
+        spatial_compression_ratio: int = 8,
+        temporal_compression_ratio: int = 4,
+    ):
+        super().__init__()
+
+        # latent_channels is the diffusers-style alias for z_dim and wins if given
+        if latent_channels is not None:
+            z_dim = latent_channels
+
+        self.register_to_config(
+            pretrained=pretrained,
+            z_dim=z_dim,
+            dim=dim,
+            dim_mult=list(dim_mult),
+            num_res_blocks=num_res_blocks,
+            attn_scales=list(attn_scales),
+            temperal_downsample=list(temperal_downsample),
+            dropout=dropout,
+            latent_channels=z_dim,
+            latents_mean=list(latents_mean),
+            latents_std=list(latents_std),
+            spatial_compression_ratio=spatial_compression_ratio,
+            temporal_compression_ratio=temporal_compression_ratio,
+        )
+
+        # per-channel latent normalization statistics
+        self.register_buffer("mean", torch.tensor(latents_mean, dtype=torch.float32), persistent=True)
+        self.register_buffer("std", torch.tensor(latents_std, dtype=torch.float32), persistent=True)
+
+        self.ffactor_spatial = spatial_compression_ratio
+        self.ffactor_temporal = temporal_compression_ratio
+
+        # a non-empty pretrained path means real weights follow, so build on meta
+        use_meta = bool(pretrained)
+        self.model = _video_vae(
+            pretrained_path=pretrained,
+            z_dim=z_dim,
+            dim=dim,
+            dim_mult=list(dim_mult),
+            num_res_blocks=num_res_blocks,
+            attn_scales=list(attn_scales),
+            temperal_downsample=list(temperal_downsample),
+            dropout=dropout,
+            use_meta=use_meta,
+        )
+        self.model.eval()
+
+    def _latent_scale_tensors(self, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return (mean, 1/std) broadcastable to (1, C, 1, 1, 1) on the given device/dtype."""
+        mean = self.mean.to(device=device, dtype=dtype).view(1, -1, 1, 1, 1)
+        inv_std = self.std.to(device=device, dtype=dtype).reciprocal().view(1, -1, 1, 1, 1)
+        return mean, inv_std
+
+    @apply_forward_hook
+    def encode(self, videos: torch.Tensor, return_dict: bool = True, return_posterior: bool = False, **kwargs):
+        """Encode pixel input (frames on dim 2, per the inner model's chunking).
+
+        Returns the raw, unnormalized ``(mean, logvar)`` pair when
+        ``return_posterior`` is True; otherwise a DiagonalGaussianDistribution
+        over normalized latents (wrapped in AutoencoderKLOutput if ``return_dict``).
+        """
+        # force fp32 math on CUDA regardless of an outer autocast context
+        autocast_context = (
+            torch.amp.autocast(device_type="cuda", dtype=torch.float32)
+            if videos.device.type == "cuda"
+            else nullcontext()
+        )
+        with autocast_context:
+            mean, logvar = self.model.encode(videos, scale=None, return_posterior=True)
+            if return_posterior:
+                return mean, logvar
+
+            latent_mean, latent_inv_std = self._latent_scale_tensors(mean.device, mean.dtype)
+            scaled_mean = (mean - latent_mean) * latent_inv_std
+            # scaling a Gaussian by 1/std shifts logvar by 2*log(1/std)
+            scaled_logvar = logvar + 2 * torch.log(latent_inv_std)
+            posterior = DiagonalGaussianDistribution(torch.cat([scaled_mean, scaled_logvar], dim=1))
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    @apply_forward_hook
+    def decode(self, zs: torch.Tensor, return_dict: bool = True, **kwargs):
+        """Decode normalized latents to pixels clamped to [-1, 1].
+
+        Decodes one batch element at a time to bound the feature-cache memory.
+        """
+        autocast_context = (
+            torch.amp.autocast(device_type="cuda", dtype=torch.float32) if zs.device.type == "cuda" else nullcontext()
+        )
+        with autocast_context:
+            mean, inv_std = self._latent_scale_tensors(zs.device, zs.dtype)
+            # the inner model re-views these flattened stats per channel
+            scale = [mean.view(-1), inv_std.view(-1)]
+            videos = [self.model.decode(z.unsqueeze(0), scale=scale).clamp_(-1, 1).squeeze(0) for z in zs]
+            videos = torch.stack(videos, dim=0)
+
+        if not return_dict:
+            return (videos,)
+        return DecoderOutput(sample=videos)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: torch.Generator | None = None,
+    ):
+        """Encode then decode *sample*; sample or take the posterior mode for latents."""
+        posterior = self.encode(sample).latent_dist
+        latents = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
+        return self.decode(latents, return_dict=return_dict)
+
+
+# Backward-compatible alias to the original upstream class name.
+WanxVAE = JoyAIImageVAE
+
+__all__ = ["JoyAIImageVAE", "WanxVAE"]
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 45157ee91808..ce5a0a7cf5e3 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -34,6 +34,7 @@
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
+ from .transformer_joyai_image import JoyAIImageTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_joyai_image.py b/src/diffusers/models/transformers/transformer_joyai_image.py
new file mode 100644
index 000000000000..33e3cfff866b
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_joyai_image.py
@@ -0,0 +1,612 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This model is adapted from https://github.com/jd-opensource/JoyAI-Image
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import ModelMixin
+from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn
+from diffusers.models.embeddings import (
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
+from diffusers.models.normalization import RMSNorm
+
+
+def _create_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
+    """Factory for modulation layers; only the 'wanx' table-based variant exists."""
+    factory_kwargs = {"dtype": dtype, "device": device}
+    if modulate_type == "wanx":
+        return _WanModulation(hidden_size, factor, **factory_kwargs)
+    raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
+
+
+class _WanModulation(nn.Module):
+    """Modulation layer for WanX.
+
+    Holds a learned (1, factor, hidden_size) table that is broadcast-added to
+    the conditioning vector and split into ``factor`` shift/scale/gate chunks.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        factor: int,
+        dtype=None,
+        device=None,
+    ):
+        super().__init__()
+        self.factor = factor
+        # NOTE(review): dividing zeros by sqrt(hidden_size) is a no-op; the
+        # table is presumably overwritten by checkpoint weights — confirm.
+        self.modulate_table = nn.Parameter(
+            torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, requires_grad=True
+        )
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        # Accepts (B, hidden_size) or (B, k, hidden_size); returns `factor`
+        # tensors of shape (B, hidden_size) each.
+        if len(x.shape) != 3:
+            x = x.unsqueeze(1)
+        return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
+
+
+class JoyAIJointAttnProcessor:
+    """Default processor for JoyAIJointAttention.
+
+    Dispatches to the preferred attention backend and falls back to the native
+    (SDPA) backend if the preferred backend fails at call time.
+    """
+
+    _attention_backend = None
+    _parallel_config = None
+
+    def __call__(
+        self,
+        attn: "JoyAIJointAttention",
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        attention_kwargs: Optional[dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        attention_kwargs = attention_kwargs or {}
+        # NOTE(review): backend string is "torch_spda" (sic) as configured upstream.
+        backend = AttentionBackendName.NATIVE if attn.backend == "torch_spda" else AttentionBackendName.FLASH_VARLEN
+
+        try:
+            return dispatch_attention_fn(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                attention_kwargs=attention_kwargs,
+                backend=backend,
+                parallel_config=self._parallel_config,
+            )
+        except (RuntimeError, ValueError, TypeError):
+            # Fallback when e.g. flash-attn is unavailable or rejects these inputs.
+            return dispatch_attention_fn(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                attention_kwargs=attention_kwargs,
+                backend=AttentionBackendName.NATIVE,
+                parallel_config=self._parallel_config,
+            )
+
+
+class JoyAIJointAttention(nn.Module, AttentionModuleMixin):
+    """Thin attention module: holds only the backend choice and a processor.
+
+    QKV projections live in the surrounding transformer block; this module
+    just runs the attention kernel on already-projected tensors.
+    """
+
+    _default_processor_cls = JoyAIJointAttnProcessor
+    _available_processors = [JoyAIJointAttnProcessor]
+    _supports_qkv_fusion = False
+
+    def __init__(self, backend: str = "flash_attn", processor=None) -> None:
+        super().__init__()
+        self.backend = backend
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        attention_kwargs: Optional[dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        return self.processor(self, query, key, value, attention_mask, attention_kwargs)
+
+
+class JoyAIImageTransformerBlock(nn.Module):
+    """Joint text-image transformer block.
+
+    MMDiT-style: image and text streams have separate QKV/MLP weights and
+    modulation tables, but attend jointly over the concatenated token sequence.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        heads_num: int,
+        mlp_width_ratio: float,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        dit_modulation_type: Optional[str] = "wanx",
+        attn_backend: str = "flash_attn",
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.dit_modulation_type = dit_modulation_type
+        self.heads_num = heads_num
+        head_dim = hidden_size // heads_num
+        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
+
+        # image stream: modulation -> norm -> fused QKV -> per-head QK-norm -> proj -> MLP
+        self.img_mod = _create_modulation(
+            modulate_type=self.dit_modulation_type,
+            hidden_size=hidden_size,
+            factor=6,
+            **factory_kwargs,
+        )
+        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
+        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
+        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
+        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
+
+        # text stream mirrors the image stream
+        self.txt_mod = _create_modulation(
+            modulate_type=self.dit_modulation_type,
+            hidden_size=hidden_size,
+            factor=6,
+            **factory_kwargs,
+        )
+        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
+        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
+        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6)
+        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
+        self.attn = JoyAIJointAttention(attn_backend)
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
+
+    @staticmethod
+    def _modulate(
+        hidden_states: torch.Tensor, shift: torch.Tensor | None = None, scale: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """AdaLN modulation: ``x * (1 + scale) + shift`` with either part optional."""
+        if scale is None and shift is None:
+            return hidden_states
+        if shift is None:
+            return hidden_states * (1 + scale.unsqueeze(1))
+        if scale is None:
+            return hidden_states + shift.unsqueeze(1)
+        return hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+    @staticmethod
+    def _apply_gate(hidden_states: torch.Tensor, gate: torch.Tensor | None = None, tanh: bool = False) -> torch.Tensor:
+        """Scale a residual branch by a (optionally tanh-squashed) per-sample gate."""
+        if gate is None:
+            return hidden_states
+        if tanh:
+            return hidden_states * gate.unsqueeze(1).tanh()
+        return hidden_states * gate.unsqueeze(1)
+
+    def forward(
+        self,
+        img: torch.Tensor,
+        txt: torch.Tensor,
+        vec: torch.Tensor,
+        vis_freqs_cis: Optional[tuple] = None,
+        txt_freqs_cis: Optional[tuple] = None,
+        attn_kwargs: Optional[dict] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """One joint block step; returns the updated (img, txt) token streams."""
+        # six modulation vectors per stream: shift/scale/gate for attn and MLP
+        (
+            img_mod1_shift,
+            img_mod1_scale,
+            img_mod1_gate,
+            img_mod2_shift,
+            img_mod2_scale,
+            img_mod2_gate,
+        ) = self.img_mod(vec)
+        (
+            txt_mod1_shift,
+            txt_mod1_scale,
+            txt_mod1_gate,
+            txt_mod2_shift,
+            txt_mod2_scale,
+            txt_mod2_gate,
+        ) = self.txt_mod(vec)
+
+        # image stream: norm + modulate, project to per-head Q/K/V
+        img_modulated = self.img_norm1(img)
+        img_modulated = self._modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
+        img_qkv = self.img_attn_qkv(img_modulated)
+        batch_size, image_sequence_length, _ = img_qkv.shape
+        img_qkv = img_qkv.view(batch_size, image_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4)
+        img_q, img_k, img_v = img_qkv.unbind(0)
+        img_q = self.img_attn_q_norm(img_q).to(img_v)
+        img_k = self.img_attn_k_norm(img_k).to(img_v)
+
+        # RoPE is applied to the image tokens only
+        if vis_freqs_cis is not None:
+            img_q = apply_rotary_emb(img_q, vis_freqs_cis, sequence_dim=1)
+            img_k = apply_rotary_emb(img_k, vis_freqs_cis, sequence_dim=1)
+
+        # text stream
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = self._modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
+        txt_qkv = self.txt_attn_qkv(txt_modulated)
+        _, text_sequence_length, _ = txt_qkv.shape
+        txt_qkv = txt_qkv.view(batch_size, text_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4)
+        txt_q, txt_k, txt_v = txt_qkv.unbind(0)
+        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
+        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
+
+        if txt_freqs_cis is not None:
+            raise NotImplementedError("RoPE text is not supported for inference")
+
+        # joint attention over [image tokens ; text tokens]
+        attention_output = self.attn(
+            torch.cat((img_q, txt_q), dim=1),
+            torch.cat((img_k, txt_k), dim=1),
+            torch.cat((img_v, txt_v), dim=1),
+            attention_mask=attn_kwargs.get("attention_mask") if attn_kwargs is not None else None,
+            attention_kwargs=attn_kwargs,
+        )
+        attention_output = attention_output.flatten(2, 3)
+        # split the joint output back by image sequence length
+        image_attention_output = attention_output[:, : img.shape[1]]
+        text_attention_output = attention_output[:, img.shape[1] :]
+
+        # gated residual updates: attention then MLP, per stream
+        img = img + self._apply_gate(self.img_attn_proj(image_attention_output), gate=img_mod1_gate)
+        img = img + self._apply_gate(
+            self.img_mlp(self._modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
+            gate=img_mod2_gate,
+        )
+
+        txt = txt + self._apply_gate(self.txt_attn_proj(text_attention_output), gate=txt_mod1_gate)
+        txt = txt + self._apply_gate(
+            self.txt_mlp(self._modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
+            gate=txt_mod2_gate,
+        )
+
+        return img, txt
+
+
+class JoyAITimeTextEmbedding(nn.Module):
+    """Condition embedder: timestep -> modulation vector, text -> projected states.
+
+    The timestep path is sinusoidal features -> MLP -> SiLU -> linear to
+    ``time_proj_dim`` (6 * hidden_size modulation chunks in the caller).
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        time_freq_dim: int,
+        time_proj_dim: int,
+        text_embed_dim: int,
+    ):
+        super().__init__()
+
+        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+        self.act_fn = nn.SiLU()
+        self.time_proj = nn.Linear(dim, time_proj_dim)
+        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return (modulation_states, projected_text_states)."""
+        timestep = self.timesteps_proj(timestep)
+
+        # cast to the embedder's weight dtype (skip int8, e.g. quantized weights)
+        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+            timestep = timestep.to(time_embedder_dtype)
+        timestep_embedding = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        modulation_states = self.time_proj(self.act_fn(timestep_embedding))
+        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+
+        return modulation_states, encoder_hidden_states
+
+
+class JoyAIImageTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin):
+    """JoyAI-Image diffusion transformer (MMDiT-style joint text/image blocks).
+
+    Patchifies a (B, C, T, H, W) latent with a strided Conv3d, runs
+    ``mm_double_blocks_depth`` joint blocks conditioned on timestep + text,
+    then unpatchifies back to latent space. A 6-D multi-item layout
+    (B, items, C, T, H, W) is supported by folding items into the time axis.
+    """
+
+    # shard at transformer-block granularity under FSDP
+    _fsdp_shard_conditions: list = [lambda name, module: isinstance(module, JoyAIImageTransformerBlock)]
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: tuple[int, int, int] = (1, 2, 2),
+        in_channels: int = 4,
+        out_channels: int = None,
+        hidden_size: int = 3072,
+        heads_num: int = 24,
+        text_states_dim: int = 4096,
+        mlp_width_ratio: float = 4.0,
+        mm_double_blocks_depth: int = 20,
+        rope_dim_list: tuple[int, int, int] = (16, 56, 56),
+        rope_type: str = "rope",
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        dit_modulation_type: str = "wanx",
+        attn_backend: str = "flash_attn",
+        theta: int = 256,
+    ):
+        # NOTE(review): plain attributes are assigned before super().__init__();
+        # this works for non-tensor values but is fragile ordering for nn.Module.
+        self.out_channels = out_channels or in_channels
+        self.patch_size = patch_size
+        self.hidden_size = hidden_size
+        self.heads_num = heads_num
+        self.rope_dim_list = rope_dim_list
+        self.dit_modulation_type = dit_modulation_type
+        self.rope_type = rope_type
+        self.theta = theta
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if hidden_size % heads_num != 0:
+            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
+
+        # patch embedding: Conv3d with kernel == stride == patch_size
+        self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+        self.condition_embedder = JoyAITimeTextEmbedding(
+            dim=hidden_size,
+            time_freq_dim=256,
+            time_proj_dim=hidden_size * 6,  # 6 modulation vectors per stream
+            text_embed_dim=text_states_dim,
+        )
+
+        self.double_blocks = nn.ModuleList(
+            [
+                JoyAIImageTransformerBlock(
+                    self.hidden_size,
+                    self.heads_num,
+                    mlp_width_ratio=mlp_width_ratio,
+                    dit_modulation_type=self.dit_modulation_type,
+                    attn_backend=attn_backend,
+                    **factory_kwargs,
+                )
+                for _ in range(mm_double_blocks_depth)
+            ]
+        )
+
+        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        # NOTE(review): uses the raw `out_channels` argument; with the default
+        # None this is `None * int` and raises TypeError — presumably should be
+        # self.out_channels. Confirm against upstream.
+        self.proj_out = nn.Linear(hidden_size, out_channels * math.prod(patch_size), **factory_kwargs)
+
+    @staticmethod
+    def _get_meshgrid_nd(start, *args, dim=2):
+        """Build an N-D meshgrid from integer sizes or ranges.
+
+        Call patterns: ``(num)``, ``(start, stop)``, or ``(start, stop, num)``;
+        each value may be an int (broadcast to all ``dim`` axes) or a
+        length-``dim`` sequence. Returns a (dim, *sizes) float32 tensor.
+        """
+
+        def as_tuple(value):
+            if isinstance(value, int):
+                return (value,) * dim
+            if len(value) == dim:
+                return value
+            raise ValueError(f"Expected length {dim} or int, but got {value}")
+
+        if len(args) == 0:
+            num = as_tuple(start)
+            start = (0,) * dim
+            stop = num
+        elif len(args) == 1:
+            start = as_tuple(start)
+            stop = as_tuple(args[0])
+            num = [stop[i] - start[i] for i in range(dim)]
+        elif len(args) == 2:
+            start = as_tuple(start)
+            stop = as_tuple(args[0])
+            num = as_tuple(args[1])
+        else:
+            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+        axis_grid = []
+        for i in range(dim):
+            a, b, n = start[i], stop[i], num[i]
+            # n+1 points then drop the endpoint -> n evenly spaced coordinates
+            g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+            axis_grid.append(g)
+        grid = torch.meshgrid(*axis_grid, indexing="ij")
+        grid = torch.stack(grid, dim=0)
+
+        return grid
+
+    @staticmethod
+    def _get_nd_rotary_pos_embed(
+        rope_dim_list,
+        start,
+        *args,
+        theta=10000.0,
+        use_real=False,
+        text_sequence_length=None,
+    ):
+        """Build visual and optional text rotary embeddings.
+
+        One 1-D RoPE per axis (feature widths from ``rope_dim_list``),
+        concatenated along the feature dim. When ``text_sequence_length`` is
+        given, text positions continue after the largest visual position.
+        """
+
+        grid = JoyAIImageTransformer3DModel._get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
+
+        embs = []
+        for i in range(len(rope_dim_list)):
+            emb = get_1d_rotary_pos_embed(
+                rope_dim_list[i],
+                grid[i].reshape(-1),
+                theta=theta,
+                use_real=use_real,
+            )
+            embs.append(emb)
+
+        # use_real=True yields (cos, sin) pairs; otherwise complex freqs
+        if use_real:
+            cos = torch.cat([emb[0] for emb in embs], dim=1)
+            sin = torch.cat([emb[1] for emb in embs], dim=1)
+            vis_emb = (cos, sin)
+        else:
+            vis_emb = torch.cat(embs, dim=1)
+        if text_sequence_length is not None:
+            embs_txt = []
+            vis_max_ids = grid.view(-1).max().item()
+            # text positions start just past the last visual position
+            text_positions = torch.arange(text_sequence_length) + vis_max_ids + 1
+            for i in range(len(rope_dim_list)):
+                emb = get_1d_rotary_pos_embed(
+                    rope_dim_list[i],
+                    text_positions,
+                    theta=theta,
+                    use_real=use_real,
+                )
+                embs_txt.append(emb)
+            if use_real:
+                cos = torch.cat([emb[0] for emb in embs_txt], dim=1)
+                sin = torch.cat([emb[1] for emb in embs_txt], dim=1)
+                txt_emb = (cos, sin)
+            else:
+                txt_emb = torch.cat(embs_txt, dim=1)
+        else:
+            txt_emb = None
+        return vis_emb, txt_emb
+
+    def get_rotary_pos_embed(self, image_grid_size, text_sequence_length=None):
+        """Return (image_rotary_emb, text_rotary_emb) for a (T, H, W) latent grid."""
+        target_ndim = 3
+
+        # left-pad the grid with 1s up to 3 dims (e.g. a pure-2D grid)
+        if len(image_grid_size) != target_ndim:
+            image_grid_size = [1] * (target_ndim - len(image_grid_size)) + image_grid_size
+        head_dim = self.hidden_size // self.heads_num
+        rope_dim_list = self.rope_dim_list
+        if rope_dim_list is None:
+            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+        image_rotary_emb, text_rotary_emb = self._get_nd_rotary_pos_embed(
+            rope_dim_list,
+            image_grid_size,
+            text_sequence_length=text_sequence_length,
+            theta=self.theta,
+            use_real=True,
+        )
+        return image_rotary_emb, text_rotary_emb
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_mask: torch.Tensor = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        """Denoise ``hidden_states`` ((B, C, T, H, W) or (B, items, C, T, H, W)).
+
+        Returns {"sample", "encoder_hidden_states"} when ``return_dict``,
+        otherwise the (image, text) tuple.
+        """
+        # multi-item layout: rotate the last item to the front, then fold items
+        # into the temporal axis so the 5-D path below applies unchanged
+        is_multi_item = len(hidden_states.shape) == 6
+        num_items = 0
+        if is_multi_item:
+            num_items = hidden_states.shape[1]
+            if num_items > 1:
+                assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
+                hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
+            batch_size, num_items, channels, frames_per_item, height, width = hidden_states.shape
+            hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape(
+                batch_size, channels, num_items * frames_per_item, height, width
+            )
+
+        _, _, output_frames, output_height, output_width = hidden_states.shape
+        latent_frames, latent_height, latent_width = (
+            output_frames // self.patch_size[0],
+            output_height // self.patch_size[1],
+            output_width // self.patch_size[2],
+        )
+        # patchify: (B, C, T, H, W) -> (B, seq, hidden)
+        image_hidden_states = self.img_in(hidden_states).flatten(2).transpose(1, 2)
+
+        if encoder_hidden_states_mask is None:
+            # no mask supplied: attend to all text tokens
+            encoder_hidden_states_mask = torch.ones(
+                (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
+                dtype=torch.bool,
+                device=image_hidden_states.device,
+            )
+        else:
+            encoder_hidden_states_mask = encoder_hidden_states_mask.to(
+                device=image_hidden_states.device, dtype=torch.bool
+            )
+        modulation_states, text_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
+        # (B, 6*hidden) -> (B, 6, hidden) for the table-based modulation layers
+        if modulation_states.shape[-1] > self.hidden_size:
+            modulation_states = modulation_states.unflatten(1, (6, -1))
+
+        text_seq_len = text_hidden_states.shape[1]
+        image_seq_len = image_hidden_states.shape[1]
+        image_rotary_emb, text_rotary_emb = self.get_rotary_pos_embed(
+            image_grid_size=(latent_frames, latent_height, latent_width),
+            text_sequence_length=text_seq_len if self.rope_type == "mrope" else None,
+        )
+
+        # joint mask: image tokens are always attended; text obeys its mask
+        attention_mask = torch.cat(
+            [
+                torch.ones(
+                    (encoder_hidden_states_mask.shape[0], image_seq_len),
+                    dtype=torch.bool,
+                    device=encoder_hidden_states_mask.device,
+                ),
+                encoder_hidden_states_mask.bool(),
+            ],
+            dim=1,
+        )
+        attention_kwargs = {
+            "thw": [latent_frames, latent_height, latent_width],
+            "txt_len": text_seq_len,
+            "attention_mask": attention_mask,
+        }
+
+        for block in self.double_blocks:
+            image_hidden_states, text_hidden_states = block(
+                image_hidden_states,
+                text_hidden_states,
+                modulation_states,
+                image_rotary_emb,
+                text_rotary_emb,
+                attention_kwargs,
+            )
+
+        image_seq_len = image_hidden_states.shape[1]
+        hidden_states = torch.cat((image_hidden_states, text_hidden_states), dim=1)
+        image_hidden_states = hidden_states[:, :image_seq_len, ...]
+        image_hidden_states = self.proj_out(self.norm_out(image_hidden_states))
+        image_hidden_states = self.unpatchify(image_hidden_states, latent_frames, latent_height, latent_width)
+
+        # undo the multi-item folding and the item rotation from the input side
+        if is_multi_item:
+            batch_size, channels, total_frames, height, width = image_hidden_states.shape
+            image_hidden_states = image_hidden_states.reshape(
+                batch_size, channels, num_items, total_frames // num_items, height, width
+            ).permute(0, 2, 1, 3, 4, 5)
+            if num_items > 1:
+                image_hidden_states = torch.cat(
+                    [
+                        image_hidden_states[:, 1:],
+                        image_hidden_states[:, :1],
+                    ],
+                    dim=1,
+                )
+
+        if return_dict:
+            return {"sample": image_hidden_states, "encoder_hidden_states": text_hidden_states}
+        return image_hidden_states, text_hidden_states
+
+    def unpatchify(self, hidden_states, latent_frames, latent_height, latent_width):
+        """Invert patchification: (B, seq, C*prod(patch)) -> (B, C, T, H, W)."""
+        channels = self.out_channels
+        patch_frames, patch_height, patch_width = self.patch_size
+        assert latent_frames * latent_height * latent_width == hidden_states.shape[1]
+
+        hidden_states = hidden_states.reshape(
+            shape=(
+                hidden_states.shape[0],
+                latent_frames,
+                latent_height,
+                latent_width,
+                patch_frames,
+                patch_height,
+                patch_width,
+                channels,
+            )
+        )
+        # interleave grid and patch axes back into contiguous T/H/W
+        hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
+
+        return hidden_states.reshape(
+            shape=(
+                hidden_states.shape[0],
+                channels,
+                latent_frames * patch_frames,
+                latent_height * patch_height,
+                latent_width * patch_width,
+            )
+        )
+
+
+# Public API of this module.
+__all__ = ["JoyAIImageTransformer3DModel"]
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 05aad6e349f6..432fc4034f22 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -263,6 +263,7 @@
_import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
+ _import_structure["joyai_image"] = ["JoyAIImagePipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
@@ -706,6 +707,7 @@
)
from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline
from .hunyuandit import HunyuanDiTPipeline
+ from .joyai_image import JoyAIImagePipeline
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
diff --git a/src/diffusers/pipelines/joyai_image/__init__.py b/src/diffusers/pipelines/joyai_image/__init__.py
new file mode 100644
index 000000000000..cacb9296401a
--- /dev/null
+++ b/src/diffusers/pipelines/joyai_image/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from diffusers.utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
# Lazy import structure for the JoyAI-Image pipeline submodule (standard
# diffusers `_LazyModule` boilerplate).
_dummy_objects = {}
_import_structure = {}

try:
    # The pipeline needs both torch and transformers; fall back to dummy
    # placeholder objects (which raise a helpful error on use) when missing.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from diffusers.utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_joyai_image"] = ["JoyAIImagePipeline"]
    _import_structure["pipeline_output"] = ["JoyAIImagePipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports for type checkers / slow-import mode.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from diffusers.utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_joyai_image import JoyAIImagePipeline
        from .pipeline_output import JoyAIImagePipelineOutput
else:
    import sys

    # Replace this module with a lazy proxy so heavy dependencies are only
    # imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py
new file mode 100644
index 000000000000..d17ec3ce568f
--- /dev/null
+++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py
@@ -0,0 +1,776 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This pipeline is adapted from https://github.com/jd-opensource/JoyAI-Image
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, PreTrainedTokenizerBase, Qwen3VLForConditionalGeneration
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders.autoencoder_kl_joyai_image import JoyAIImageVAE
+from diffusers.models.transformers.transformer_joyai_image import JoyAIImageTransformer3DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, empty_device_cache, get_device
+from diffusers.schedulers.scheduling_joyai_flow_match_discrete import JoyAIFlowMatchDiscreteScheduler
+from diffusers.utils import is_accelerate_available, is_accelerate_version, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+from .pipeline_output import JoyAIImagePipelineOutput
+
+
+logger = logging.get_logger(__name__)
+
# Map user-facing precision strings (e.g. from `args.dit_precision` /
# `args.vae_precision`) to torch dtypes.
PRECISION_TO_TYPE = {
    "fp32": torch.float32,
    "float32": torch.float32,
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
}


# Chat-style templates wrapped around the raw user prompt before Qwen encoding.
# `{}` is substituted with the user text (see `_get_qwen_prompt_embeds` /
# `encode_prompt_multiple_images`). Do not alter these strings: they must match
# the templates the text encoder was trained with.
PROMPT_TEMPLATE_ENCODE = {
    "image": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
    "multiple_images": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
    "video": "<|im_start|>system\n \nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}

# Number of leading template tokens to strip from the encoder output
# (tokenized length of the system prefix for each template).
PROMPT_TEMPLATE_START_IDX = {
    "image": 34,
    "multiple_images": 34,
    "video": 91,
}
+
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call the scheduler's `set_timesteps` and return `(timesteps, num_inference_steps)`.

    Mirrors the shared diffusers helper: at most one of `timesteps` / `sigmas`
    may be passed, and either is forwarded only when the scheduler's
    `set_timesteps` signature accepts it.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Custom integer timesteps: the scheduler must accept `timesteps=`.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: the scheduler must accept `sigmas=`.
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
+
+
class JoyAIImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image and image-edit generation with JoyAI-Image.

    Args:
        vae: 3D KL-VAE used to encode reference images and decode latents.
        text_encoder: Qwen3-VL model producing prompt embeddings.
        tokenizer: Tokenizer matching `text_encoder`.
        transformer: JoyAI 3D diffusion transformer.
        scheduler: Flow-match discrete scheduler.
        processor: Optional Qwen-VL processor (needed for image-edit prompts).
        args: Optional namespace/config carrying JoyAI runtime options.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = ["processor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: JoyAIImageVAE,
        text_encoder: Qwen3VLForConditionalGeneration,
        tokenizer: PreTrainedTokenizerBase,
        transformer: JoyAIImageTransformer3DModel,
        scheduler: JoyAIFlowMatchDiscreteScheduler,
        # `Any | None` (PEP 604) evaluates at class-definition time and raises
        # on Python < 3.10; use the typing spelling like the rest of the file.
        processor: Optional[Any] = None,
        args: Optional[Any] = None,
    ):
        super().__init__()
        self.args = args

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            processor=processor,
            transformer=transformer,
            scheduler=scheduler,
        )

        self.enable_multi_task = bool(getattr(self.args, "enable_multi_task_training", False))
        # Derive spatial/temporal compression factors from the VAE when exposed.
        if hasattr(self.vae, "ffactor_spatial"):
            self.vae_scale_factor = self.vae.ffactor_spatial
            self.vae_scale_factor_temporal = self.vae.ffactor_temporal
        else:
            self.vae_scale_factor = 8
            self.vae_scale_factor_temporal = 4
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        # Image-edit mode needs a Qwen-VL processor; optionally load it from a
        # checkpoint path given in `args.text_encoder_arch_config`.
        self.qwen_processor = processor
        text_encoder_ckpt = None
        text_encoder_cfg = getattr(self.args, "text_encoder_arch_config", None)
        if isinstance(text_encoder_cfg, dict):
            text_encoder_params = text_encoder_cfg.get("params", {})
            text_encoder_ckpt = text_encoder_params.get("text_encoder_ckpt")
        if self.qwen_processor is None and text_encoder_ckpt is not None:
            self.qwen_processor = AutoProcessor.from_pretrained(
                text_encoder_ckpt,
                local_files_only=True,
                trust_remote_code=True,
            )

        self.text_token_max_length = int(getattr(self.args, "text_token_max_length", 2048))
        self.prompt_template_encode = PROMPT_TEMPLATE_ENCODE
        self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX
        # VAE decode runs in fp32 by default for numerical stability.
        self._joyai_force_vae_fp32 = True
+
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ return torch.split(selected, valid_lengths.tolist(), dim=0)
+
def _get_qwen_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    template_type: str = "image",
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode text-only prompts with the Qwen tokenizer and text encoder.

    The chat-template prefix (`drop_idx` tokens) is stripped from each sample,
    and the samples are re-padded with zeros to a common length capped at
    `self.text_token_max_length`. Returns `(embeds, attention_mask)`.
    """
    device = device or self._get_runtime_execution_device()
    dtype = dtype or next(self.text_encoder.parameters()).dtype

    prompts = [prompt] if isinstance(prompt, str) else prompt
    template = self.prompt_template_encode[template_type]
    drop_idx = self.prompt_template_encode_start_idx[template_type]
    templated = [template.format(text) for text in prompts]

    txt_tokens = self.tokenizer(
        templated,
        max_length=self.text_token_max_length + drop_idx,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    encoder_output = self._run_text_encoder(
        input_ids=txt_tokens.input_ids,
        attention_mask=txt_tokens.attention_mask,
    )
    last_hidden = encoder_output.hidden_states[-1]

    # Strip padding and the template prefix from every sample.
    per_sample = self._extract_masked_hidden(last_hidden, txt_tokens.attention_mask)
    per_sample = [chunk[drop_idx:] for chunk in per_sample]
    masks = [torch.ones(chunk.size(0), dtype=torch.long, device=chunk.device) for chunk in per_sample]
    max_seq_len = min(
        self.text_token_max_length,
        max(chunk.size(0) for chunk in per_sample),
        max(m.size(0) for m in masks),
    )
    # Right-pad every sample (and its mask) with zeros up to `max_seq_len`.
    prompt_embeds = torch.stack(
        [
            torch.cat([chunk, chunk.new_zeros(max_seq_len - chunk.size(0), chunk.size(1))])
            for chunk in per_sample
        ]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in masks]
    )
    return prompt_embeds.to(dtype=dtype, device=device), encoder_attention_mask
+
def encode_prompt_multiple_images(
    self,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    images: Optional[List[Any]] = None,
    template_type: str = "multiple_images",
    max_sequence_length: Optional[int] = None,
    drop_vit_feature: bool = False,
):
    """Encode an image-edit prompt (text plus reference images) with the Qwen-VL processor.

    Returns `(prompt_embeds, prompt_embeds_mask)` with the template prefix (or,
    when `drop_vit_feature` is set, everything up to the last image token)
    dropped, truncated from the left to `max_sequence_length`.
    """
    if self.qwen_processor is None:
        raise ValueError("Qwen processor is required for JoyAI image-edit prompt encoding.")
    device = device or self._get_runtime_execution_device()
    template = self.prompt_template_encode[template_type]
    drop_idx = self.prompt_template_encode_start_idx[template_type]
    # Fix: accept a bare string like `encode_prompt` does — iterating a str
    # below would otherwise process it character by character.
    prompt = [prompt] if isinstance(prompt, str) else prompt
    # Each newline in the raw prompt is a placeholder for one reference image.
    prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt]
    prompt = [template.format(p) for p in prompt]
    inputs = self.qwen_processor(text=prompt, images=images, padding=True, return_tensors="pt").to(device)
    encoder_hidden_states = self._run_text_encoder(**inputs)
    last_hidden_states = encoder_hidden_states.hidden_states[-1]
    if drop_vit_feature:
        # Drop everything up to and including the last <|vision_end|> token (id 151653).
        input_ids = inputs["input_ids"]
        vlm_image_end_idx = torch.where(input_ids[0] == 151653)[0][-1]
        drop_idx = int(vlm_image_end_idx.item()) + 1
    prompt_embeds = last_hidden_states[:, drop_idx:]
    prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:]
    if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
        # Keep the tail: the template suffix and user text sit at the end.
        prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
        prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
    return prompt_embeds, prompt_embeds_mask
+
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    images: Optional[List[Any]] = None,
    device: Optional[torch.device] = None,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
    template_type: str = "image",
    drop_vit_feature: bool = False,
):
    """Encode a prompt to `(embeds, mask)`, dispatching to the image-edit
    encoder whenever reference `images` are supplied."""
    if images is not None:
        return self.encode_prompt_multiple_images(
            prompt=prompt,
            images=images,
            device=device,
            max_sequence_length=max_sequence_length,
            drop_vit_feature=drop_vit_feature,
        )

    device = device or self._get_runtime_execution_device()
    prompts = [prompt] if isinstance(prompt, str) else prompt
    batch_size = prompt_embeds.shape[0] if prompt_embeds is not None else len(prompts)

    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompts, template_type, device)

    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    # Duplicate embeddings per requested generation.
    seq_len = prompt_embeds.shape[1]
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
    # NOTE(review): the 2-D mask goes through the same 3-element `repeat`,
    # which tiles along a different axis than the embeddings when
    # num_videos_per_prompt > 1 with batch > 1 — kept as-is to mirror upstream.
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)
    return prompt_embeds, prompt_embeds_mask
+
def check_inputs(
    self,
    prompt: Optional[Union[str, List[str]]],
    height: int,
    width: int,
    images: Optional[List[Any]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
) -> None:
    """Validate the mutually exclusive / co-required `__call__` arguments."""
    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
    # Pre-computed embeddings are only usable together with their masks.
    if prompt_embeds is not None and prompt_embeds_mask is None:
        raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` must also be passed.")
    if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` must also be passed."
        )
+
+ def _vae_compute_dtype(self) -> torch.dtype:
+ if getattr(self, "_joyai_force_vae_fp32", False):
+ return torch.float32
+ if hasattr(self.vae, "model"):
+ return next(self.vae.model.parameters()).dtype
+ return next(self.vae.parameters()).dtype
+
+ def _get_runtime_execution_device(self) -> torch.device:
+ override = getattr(self, "_joyai_execution_device_override", None)
+ if override is not None:
+ return torch.device(override)
+ return self._execution_device
+
+ def _is_sequential_cpu_offload_enabled(self) -> bool:
+ return bool(getattr(self, "_joyai_sequential_cpu_offload_enabled", False))
+
def enable_manual_cpu_offload(
    self,
    # `torch.device | str` (PEP 604) would raise at import time on Python < 3.10.
    device: Union[torch.device, str],
    components: Optional[List[str]] = None,
) -> None:
    """Enable manual CPU offload for selected components.

    The named components (default: text encoder and VAE) are parked on CPU and
    moved to `device` only for the duration of their forward calls.

    Args:
        device: Accelerator device used while a component is needed.
        components: Component names to offload; must be keys of `self.components`.

    Raises:
        ValueError: If a requested component name is unknown.
    """
    runtime_device = torch.device(device)
    component_names = set(components or ["text_encoder", "vae"])

    invalid_components = [name for name in component_names if name not in self.components]
    if invalid_components:
        raise ValueError(f"Unknown components for manual cpu offload: {invalid_components}")

    self._joyai_execution_device_override = runtime_device
    self._joyai_sequential_cpu_offload_enabled = True
    self._joyai_manual_offload_components = component_names

    # Park the selected modules on CPU immediately.
    for name in component_names:
        component = getattr(self, name, None)
        if isinstance(component, torch.nn.Module):
            component.to("cpu")
+
+ def _uses_manual_sequential_offload(self, component_name: str) -> bool:
+ manual_components = getattr(self, "_joyai_manual_offload_components", set())
+ return self._is_sequential_cpu_offload_enabled() and component_name in manual_components
+
def _offload_component_to_cpu(self, component_name: str):
    """Move a component back to CPU and release the accelerator cache."""
    component = getattr(self, component_name, None)
    if component is None:
        return
    component.to("cpu")
    device_type = getattr(self._get_runtime_execution_device(), "type", "cuda")
    empty_device_cache(device_type)
+
def _run_text_encoder(self, **inputs):
    """Run the text encoder, shuttling it on/off the accelerator when manually offloaded."""
    if not self._uses_manual_sequential_offload("text_encoder"):
        return self.text_encoder(**inputs, output_hidden_states=True)
    self.text_encoder.to(self._get_runtime_execution_device())
    try:
        return self.text_encoder(**inputs, output_hidden_states=True)
    finally:
        # Always park the encoder back on CPU, even if the forward call raised.
        self._offload_component_to_cpu("text_encoder")
+
+ def _get_vae_scale(self, device: torch.device, dtype: torch.dtype):
+ mean = getattr(self.vae, "mean", None)
+ std = getattr(self.vae, "std", None)
+ if mean is None or std is None:
+ return None
+ mean = mean.to(device=device, dtype=dtype)
+ std = std.to(device=device, dtype=dtype)
+ return [mean, 1.0 / std]
+
def _encode_with_vae(self, videos: torch.Tensor) -> torch.Tensor:
    """Encode pixel videos into VAE latents, honoring manual CPU offload."""
    device = self._get_runtime_execution_device()
    compute_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "vae_precision", "bf16"), videos.dtype)
    videos = videos.to(device=device, dtype=compute_dtype)

    # Manual-offload path: move the raw VAE module in, encode, always move it back out.
    if self._uses_manual_sequential_offload("vae") and hasattr(self.vae, "model"):
        scale = self._get_vae_scale(device=device, dtype=compute_dtype)
        self.vae.model.to(device=device, dtype=compute_dtype)
        try:
            return self.vae.model.encode(videos, scale=scale)
        finally:
            self.vae.model.to("cpu")
            empty_device_cache(device.type)

    # Keep the VAE normalization statistics on the compute device/dtype.
    if hasattr(self.vae, "mean"):
        self.vae.mean = self.vae.mean.to(device=device, dtype=compute_dtype)
    if hasattr(self.vae, "std"):
        self.vae.std = self.vae.std.to(device=device, dtype=compute_dtype)
    if hasattr(self.vae, "scale"):
        self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
    if hasattr(self.vae, "config"):
        if hasattr(self.vae.config, "latents_mean"):
            self.vae.config.latents_mean = self.vae.mean
        if hasattr(self.vae.config, "latents_std"):
            self.vae.config.latents_std = self.vae.std

    self.vae.to(device=device, dtype=compute_dtype)
    encoded = self.vae.encode(videos)
    # diffusers-style VAEs return a distribution; raw ones return the latents directly.
    return encoded.latent_dist.sample() if hasattr(encoded, "latent_dist") else encoded
+
def _decode_with_vae(self, latents: torch.Tensor):
    """Decode latents to pixel space, honoring manual CPU offload."""
    device = self._get_runtime_execution_device()
    compute_dtype = self._vae_compute_dtype()
    latents = latents.to(device=device, dtype=compute_dtype)

    # Manual-offload path: decode per sample to cap peak memory, always park the VAE after.
    if self._uses_manual_sequential_offload("vae") and hasattr(self.vae, "model"):
        scale = self._get_vae_scale(device=device, dtype=compute_dtype)
        self.vae.model.to(device=device, dtype=compute_dtype)
        try:
            decoded = [
                self.vae.model.decode(sample.unsqueeze(0), scale=scale).clamp_(-1, 1).squeeze(0)
                for sample in latents
            ]
            return torch.stack(decoded, dim=0)
        finally:
            self.vae.model.to("cpu")
            empty_device_cache(device.type)

    # Keep the VAE normalization statistics on the compute device/dtype.
    if hasattr(self.vae, "mean"):
        self.vae.mean = self.vae.mean.to(device=device, dtype=compute_dtype)
    if hasattr(self.vae, "std"):
        self.vae.std = self.vae.std.to(device=device, dtype=compute_dtype)
    if hasattr(self.vae, "scale"):
        self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
    if hasattr(self.vae, "config"):
        if hasattr(self.vae.config, "latents_mean"):
            self.vae.config.latents_mean = self.vae.mean
        if hasattr(self.vae.config, "latents_std"):
            self.vae.config.latents_std = self.vae.std

    self.vae.to(device=device, dtype=compute_dtype)
    return self.vae.decode(latents, return_dict=False)[0]
+
def prepare_latents(
    self,
    batch_size,
    num_items,
    num_channels_latents,
    height,
    width,
    video_length,
    dtype,
    device,
    generator,
    latents=None,
    reference_images=None,
):
    """Create initial latents of shape (B, items, C, T', H', W').

    When `reference_images` are supplied, their VAE encodings fill the first
    `num_items - 1` item slots and only the last slot starts from noise.
    Returns `(latents, condition)`; `condition` is always None in this adaptation.
    """
    latent_shape = (
        batch_size,
        num_items,
        num_channels_latents,
        (video_length - 1) // self.vae_scale_factor_temporal + 1,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}."
        )

    if latents is not None:
        latents = latents.to(device=device, dtype=dtype)
    elif reference_images:
        # Encode each reference image; pixels are mapped from [0, 255] to [-1, 1].
        pixel_refs = [torch.from_numpy(np.array(img.convert("RGB"))) for img in reference_images]
        pixel_refs = torch.stack(pixel_refs).to(device=device, dtype=dtype)
        pixel_refs = pixel_refs / 127.5 - 1.0
        pixel_refs = pixel_refs.permute(0, 3, 1, 2).unsqueeze(2)
        ref_latents = self._encode_with_vae(pixel_refs)
        ref_latents = ref_latents.reshape(latent_shape[0], num_items - 1, *ref_latents.shape[1:])
        noise = randn_tensor(
            (latent_shape[0], 1, *latent_shape[2:]), generator=generator, device=device, dtype=dtype
        )
        latents = torch.cat([ref_latents, noise], dim=1)
    else:
        latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    if not self.enable_multi_task:
        return latents, None
    raise NotImplementedError("JoyAI multi-task conditioning is not implemented in the diffusers adaptation yet.")
+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Optional[Union[torch.device, str]] = None):
    r"""
    Offload all models to CPU, keeping the JoyAI text encoder and VAE under
    manual (whole-module) offload instead of accelerate's per-weight hooks.

    Args:
        gpu_id: Accelerator id to use; mutually exclusive with an indexed `device`.
        device: Accelerator device (e.g. "cuda"); defaults to the detected accelerator.

    Raises:
        ImportError: If `accelerate>=0.14.0` is unavailable.
        RuntimeError: If no accelerator device is found.
        ValueError: If the pipeline is device-mapped, or both `gpu_id` and an
            indexed `device` are given.
    """
    # NOTE: `int | None` / `torch.device | str` (PEP 604) annotations would raise
    # at import time on Python < 3.10, hence the typing.Optional/Union spelling.
    if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
        from accelerate import cpu_offload
    else:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    self._maybe_raise_error_if_group_offload_active(raise_error=True)
    self.remove_all_hooks()

    is_pipeline_device_mapped = self._is_pipeline_device_mapped()
    if is_pipeline_device_mapped:
        raise ValueError(
            "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
        )

    if device is None:
        device = get_device()
        if device == "cpu":
            raise RuntimeError("`enable_sequential_cpu_offload` requires accelerator, but not found")

    torch_device = torch.device(device)
    device_index = torch_device.index
    if gpu_id is not None and device_index is not None:
        raise ValueError(
            f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
            f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
        )

    self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
    device_type = torch_device.type
    device = torch.device(f"{device_type}:{self._offload_gpu_id}")
    self._offload_device = device

    if self.device.type != "cpu":
        orig_device_type = self.device.type
        self.to("cpu", silence_dtype_warnings=True)
        empty_device_cache(orig_device_type)

    # Text encoder and VAE are shuttled manually (see `_run_text_encoder` /
    # `_encode_with_vae`) rather than via accelerate's sequential hooks.
    self._joyai_manual_offload_components = {"text_encoder", "vae"}

    for name, model in self.components.items():
        if not isinstance(model, torch.nn.Module):
            continue

        if name in self._exclude_from_cpu_offload:
            model.to(device)
            continue

        if name in self._joyai_manual_offload_components:
            model.to("cpu")
            continue

        offload_buffers = len(model._parameters) > 0
        params = list(model.parameters())
        on_cpu = len(params) == 0 or all(param.device.type == "cpu" for param in params)
        state_dict = model.state_dict() if on_cpu else None
        cpu_offload(model, device, offload_buffers=offload_buffers, state_dict=state_dict)

    self._joyai_sequential_cpu_offload_enabled = True
+
@property
def guidance_scale(self):
    # CFG scale recorded at the start of `__call__`.
    return self._guidance_scale


@property
def do_classifier_free_guidance(self):
    # CFG is active whenever the scale exceeds 1.
    return self._guidance_scale > 1


@property
def num_timesteps(self):
    # Number of denoising steps of the most recent run.
    return self._num_timesteps


@property
def interrupt(self):
    # Cooperative cancellation flag checked on every denoising step.
    return self._interrupt
+
def pad_sequence(self, x: torch.Tensor, target_length: int):
    """Right-pad with zeros, or left-truncate, `x` along dim 1 to `target_length`."""
    current_length = x.shape[1]
    if current_length >= target_length:
        # Keep the tail: prompt content sits at the end of the sequence.
        return x[:, -target_length:]
    deficit = target_length - current_length
    padding = torch.zeros((x.shape[0], deficit, *x.shape[2:]), dtype=x.dtype, device=x.device)
    return torch.cat([x, padding], dim=1)
+
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    height: int,
    width: int,
    num_frames: int = 1,
    images: Optional[List[Any]] = None,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 4096,
    drop_vit_feature: bool = False,
    **kwargs,
):
    r"""
    Run text-to-image (or image-edit, when `images` are given) generation.

    Args:
        prompt: Text prompt(s) to condition on.
        height / width: Output resolution in pixels.
        num_frames: 1 selects the image prompt template, >1 the video one.
        images: Optional reference images enabling image-edit mode.
        num_inference_steps / timesteps / sigmas: Denoising schedule controls.
        guidance_scale: CFG scale; classifier-free guidance runs when > 1.
        output_type: "pil", "pt", "np", or "latent".
        return_dict: Return a `JoyAIImagePipelineOutput` when True, else a tuple.

    Returns:
        `JoyAIImagePipelineOutput` or a one-element tuple with the images.
    """
    self.check_inputs(
        prompt,
        height,
        width,
        images=images,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        negative_prompt_embeds_mask=negative_prompt_embeds_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._interrupt = False

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._get_runtime_execution_device()
    template_type = "image" if num_frames == 1 else "video"
    # One latent "item" per reference image plus one for the generated output.
    num_items = 1 if images is None or len(images) == 0 else 1 + len(images)

    prompt_embeds, prompt_embeds_mask = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        prompt_embeds_mask=prompt_embeds_mask,
        images=images,
        device=device,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        template_type=template_type,
        drop_vit_feature=drop_vit_feature,
    )

    if self.do_classifier_free_guidance:
        if negative_prompt is None and negative_prompt_embeds is None:
            # Empty negative prompt, mirroring the positive template's image slots.
            default_negative_prompt = ""
            if num_items <= 1:
                negative_prompt = [f"<|im_start|>user\n{default_negative_prompt}<|im_end|>\n"] * batch_size
            else:
                image_tokens = "\n" * (num_items - 1)
                negative_prompt = [
                    f"<|im_start|>user\n{image_tokens}{default_negative_prompt}<|im_end|>\n"
                ] * batch_size

        negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=negative_prompt_embeds_mask,
            images=images,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            template_type=template_type,
        )

        # Pad both streams to a common length, then stack [uncond, cond].
        max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        prompt_embeds = torch.cat(
            [
                self.pad_sequence(negative_prompt_embeds, max_seq_len),
                self.pad_sequence(prompt_embeds, max_seq_len),
            ]
        )
        if prompt_embeds_mask is not None:
            prompt_embeds_mask = torch.cat(
                [
                    self.pad_sequence(negative_prompt_embeds_mask, max_seq_len),
                    self.pad_sequence(prompt_embeds_mask, max_seq_len),
                ]
            )

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
    )

    num_channels_latents = self.vae.config.latent_channels
    latents, condition = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_items,
        num_channels_latents,
        height,
        width,
        num_frames,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        reference_images=images,
    )

    target_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "dit_precision", "bf16"), prompt_embeds.dtype)
    autocast_enabled = target_dtype != torch.float32 and device.type == "cuda"
    vae_dtype = self._vae_compute_dtype()
    vae_autocast_enabled = vae_dtype != torch.float32 and device.type == "cuda"

    self._num_timesteps = len(timesteps)
    if num_items > 1:
        # Reference-image latents are clean conditions and must not be denoised.
        ref_latents = latents[:, : (num_items - 1)].clone()

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            if num_items > 1:
                # Re-pin the reference slots after the previous scheduler step.
                latents[:, : (num_items - 1)] = ref_latents.clone()

            latents_ = torch.cat([latents, condition], dim=2) if condition is not None else latents
            latent_model_input = torch.cat([latents_] * 2) if self.do_classifier_free_guidance else latents_
            latent_model_input = latent_model_input.to(device=device, dtype=target_dtype)
            prompt_embeds_input = prompt_embeds.to(device=device, dtype=target_dtype)
            t_expand = t.repeat(latent_model_input.shape[0])

            with torch.autocast(device_type=device.type, dtype=target_dtype, enabled=autocast_enabled):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=t_expand,
                    encoder_hidden_states=prompt_embeds_input,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    return_dict=False,
                )[0]

            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                # Rescale the guided prediction to the conditional norm (guidance rescale).
                cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True)
                noise_norm = torch.norm(noise_pred, dim=2, keepdim=True)
                noise_pred = noise_pred * (cond_norm / noise_norm)

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Fix: `locals()` inside a dict comprehension only sees the
                # comprehension's own scope on Python < 3.12 (pre-PEP 709),
                # so snapshot the function locals first.
                frame_locals = locals()
                callback_kwargs = {k: frame_locals[k] for k in callback_on_step_end_tensor_inputs}
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "latent":
        image = latents
    else:
        # Merge (batch, items) for decoding, then keep the generated slot's first frame.
        latents = latents.reshape(-1, *latents.shape[2:])
        with torch.autocast(device_type=device.type, dtype=vae_dtype, enabled=vae_autocast_enabled):
            decoded = self._decode_with_vae(latents)
        # NOTE(review): assumes num_videos_per_prompt == 1 here — confirm before
        # relying on larger values.
        decoded = decoded.reshape(batch_size, num_items, *decoded.shape[1:])
        image = decoded[:, -1, :, 0]
        image = (image / 2 + 0.5).clamp(0, 1)

    self.maybe_free_model_hooks()

    if output_type == "pt":
        output_image = image.cpu().float()
    elif output_type == "pil":
        output_image = self.image_processor.numpy_to_pil(image.cpu().permute(0, 2, 3, 1).float().numpy())
    else:
        output_image = image.cpu().permute(0, 2, 3, 1).float().numpy()

    if not return_dict:
        return (output_image,)
    return JoyAIImagePipelineOutput(images=output_image)
+
+
+__all__ = ["JoyAIImagePipeline"]
diff --git a/src/diffusers/pipelines/joyai_image/pipeline_output.py b/src/diffusers/pipelines/joyai_image/pipeline_output.py
new file mode 100644
index 000000000000..131da308bed5
--- /dev/null
+++ b/src/diffusers/pipelines/joyai_image/pipeline_output.py
@@ -0,0 +1,29 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+from PIL import Image
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class JoyAIImagePipelineOutput(BaseOutput):
+    """
+    Output class for the JoyAI-Image pipeline.
+
+    Attributes:
+        images:
+            The generated images. The pipeline produces a list of `PIL.Image.Image` for
+            `output_type="pil"`, a `np.ndarray` for `output_type="np"`, and a `torch.Tensor`
+            for `output_type="pt"`. NOTE(review): the annotation below only covers the
+            single-PIL/NumPy cases — confirm whether the list/tensor cases should be added.
+    """
+
+    # Generated images returned to the caller; see the docstring for the per-output_type shape.
+    images: Union[Image.Image, np.ndarray]
+
+
+__all__ = ["JoyAIImagePipelineOutput"]
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index b1f75bed7dc5..10f23a0d770b 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -61,6 +61,10 @@
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
+ _import_structure["scheduling_joyai_flow_match_discrete"] = [
+ "JoyAIFlowMatchDiscreteScheduler",
+ "JoyAIFlowMatchDiscreteSchedulerOutput",
+ ]
_import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
_import_structure["scheduling_helios"] = ["HeliosScheduler"]
_import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"]
@@ -172,6 +176,10 @@
from .scheduling_helios_dmd import HeliosDMDScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
+ from .scheduling_joyai_flow_match_discrete import (
+ JoyAIFlowMatchDiscreteScheduler,
+ JoyAIFlowMatchDiscreteSchedulerOutput,
+ )
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
from .scheduling_lcm import LCMScheduler
diff --git a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py
new file mode 100644
index 000000000000..b3acaaba10e6
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py
@@ -0,0 +1,83 @@
+# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+
+
+class JoyAIFlowMatchDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ reverse: bool = True,
+ use_dynamic_shifting: bool = False,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ base_image_seq_len: int = 256,
+ max_image_seq_len: int = 4096,
+ invert_sigmas: bool = False,
+ shift_terminal: float | None = None,
+ use_karras_sigmas: bool = False,
+ use_exponential_sigmas: bool = False,
+ use_beta_sigmas: bool = False,
+ time_shift_type: str = "exponential",
+ stochastic_sampling: bool = False,
+ ):
+ super().__init__(
+ num_train_timesteps=num_train_timesteps,
+ shift=shift,
+ use_dynamic_shifting=use_dynamic_shifting,
+ base_shift=base_shift,
+ max_shift=max_shift,
+ base_image_seq_len=base_image_seq_len,
+ max_image_seq_len=max_image_seq_len,
+ invert_sigmas=invert_sigmas,
+ shift_terminal=shift_terminal,
+ use_karras_sigmas=use_karras_sigmas,
+ use_exponential_sigmas=use_exponential_sigmas,
+ use_beta_sigmas=use_beta_sigmas,
+ time_shift_type=time_shift_type,
+ stochastic_sampling=stochastic_sampling,
+ )
+ self.register_to_config(reverse=reverse)
+
+ def sd3_time_shift(self, timesteps: torch.Tensor) -> torch.Tensor:
+ return (self.config.shift * timesteps) / (1 + (self.config.shift - 1) * timesteps)
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device, None] = None,
+ **kwargs,
+ ):
+ self.num_inference_steps = num_inference_steps
+
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
+ sigmas = self.sd3_time_shift(sigmas)
+
+ if not self.config.reverse:
+ sigmas = 1 - sigmas
+
+ self.sigmas = sigmas.to(device=device)
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
+ self._step_index = None
diff --git a/tests/models/autoencoders/test_models_autoencoder_joyai_image.py b/tests/models/autoencoders/test_models_autoencoder_joyai_image.py
new file mode 100644
index 000000000000..a6180c462fb0
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_joyai_image.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright 2026 HuggingFace Inc.
+
+import unittest
+
+from diffusers import JoyAIImageVAE
+
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+enable_full_determinism()
+
+
+class JoyAIImageVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+ model_class = JoyAIImageVAE
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_joyai_image_vae_config(self):
+ return {
+ "dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+ sample = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": sample}
+
+ @property
+ def dummy_input_tiling(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (128, 128)
+ sample = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": sample}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_joyai_image_vae_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def prepare_init_args_and_inputs_for_tiling(self):
+ init_dict = self.get_joyai_image_vae_config()
+ inputs_dict = self.dummy_input_tiling
+ return init_dict, inputs_dict
+
+ @unittest.skip("Gradient checkpointing has not been implemented yet")
+ def test_gradient_checkpointing_is_applied(self):
+ pass
+
+ @unittest.skip("Test not supported")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_training(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_joyai_image.py b/tests/models/transformers/test_models_transformer_joyai_image.py
new file mode 100644
index 000000000000..aab8530b3b16
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_joyai_image.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+
+import torch
+
+from diffusers import JoyAIImageTransformer3DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..testing_utils import AttentionTesterMixin, BaseModelTesterConfig, ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class JoyAIImageTransformerTesterConfig(BaseModelTesterConfig):
+ @property
+ def model_class(self):
+ return JoyAIImageTransformer3DModel
+
+ @property
+ def output_shape(self) -> tuple[int, int, int, int]:
+ return (4, 2, 4, 4)
+
+ @property
+ def main_input_name(self) -> str:
+ return "hidden_states"
+
+ @property
+ def generator(self):
+ return torch.Generator("cpu").manual_seed(0)
+
+ def get_init_dict(self) -> dict[str, int | float | tuple[int, int, int] | str]:
+ return {
+ "patch_size": (1, 2, 2),
+ "in_channels": 4,
+ "out_channels": 4,
+ "hidden_size": 32,
+ "heads_num": 4,
+ "text_states_dim": 24,
+ "mlp_width_ratio": 2.0,
+ "mm_double_blocks_depth": 2,
+ "rope_dim_list": (2, 2, 4),
+ "rope_type": "rope",
+ "attn_backend": "torch_spda",
+ "theta": 1000,
+ }
+
+ def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
+ hidden_states = randn_tensor((batch_size, 4, 2, 4, 4), generator=self.generator, device=torch_device)
+ timestep = torch.tensor([1.0] * batch_size, device=torch_device)
+ encoder_hidden_states = randn_tensor((batch_size, 6, 24), generator=self.generator, device=torch_device)
+ encoder_hidden_states_mask = torch.tensor(
+ [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]], device=torch_device, dtype=torch.long
+ )
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_mask": encoder_hidden_states_mask,
+ }
+
+
+class TestJoyAIImageTransformer(JoyAIImageTransformerTesterConfig, ModelTesterMixin):
+    # Runs the shared ModelTesterMixin suite against the JoyAI transformer config; no extra tests.
+    pass
+
+
+class TestJoyAIImageTransformerAttention(JoyAIImageTransformerTesterConfig, AttentionTesterMixin):
+    def test_exposes_attention_processors(self):
+        """The model must expose one attention processor per double-stream block."""
+        model = self.model_class(**self.get_init_dict()).to(torch_device)
+
+        assert hasattr(model, "attn_processors")
+        assert len(model.attn_processors) == len(model.double_blocks)
diff --git a/tests/pipelines/joyai_image/test_pipeline_joyai_image.py b/tests/pipelines/joyai_image/test_pipeline_joyai_image.py
new file mode 100644
index 000000000000..5adae87b6fbb
--- /dev/null
+++ b/tests/pipelines/joyai_image/test_pipeline_joyai_image.py
@@ -0,0 +1,56 @@
+from unittest.mock import patch
+
+from diffusers import DiffusionPipeline, JoyAIImagePipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.pipelines.joyai_image import pipeline_joyai_image
+
+
+class _DummyModule:
+    """Minimal stand-in for a pipeline component; only object identity matters in these tests."""
+
+    pass
+
+
+def test_joyai_pipeline_uses_base_from_pretrained():
+    """The pipeline must inherit `from_pretrained` from `DiffusionPipeline`, not override it."""
+    # Compare the underlying functions of the two classmethods for identity.
+    assert JoyAIImagePipeline.from_pretrained.__func__ is DiffusionPipeline.from_pretrained.__func__
+
+
+def test_joyai_pipeline_does_not_expose_source_loader_api():
+    """The pipeline class must not carry a bespoke `from_joyai_sources` loader entry point."""
+    assert not hasattr(JoyAIImagePipeline, "from_joyai_sources")
+
+
+def test_joyai_pipeline_module_does_not_expose_raw_source_helpers():
+    """The pipeline module must not re-export raw-checkpoint loading helpers."""
+    assert not hasattr(pipeline_joyai_image, "load_joyai_components")
+
+
+def test_joyai_pipeline_keeps_passed_processor_without_reloading():
+    """`__init__` must register an explicitly passed `processor` and never reload one from disk."""
+    # Build a bare instance (bypassing __init__) so __init__ can be driven manually with stubs.
+    pipe = object.__new__(JoyAIImagePipeline)
+    pipe._internal_dict = FrozenDict({})
+    # Minimal stand-ins for the attributes __init__ reads: the raw text-encoder checkpoint path
+    # and the VAE down-sampling factors.
+    pipe.args = type("Args", (), {"text_encoder_arch_config": {"params": {"text_encoder_ckpt": "/tmp/raw"}}})()
+    pipe.vae = type("VAE", (), {"ffactor_spatial": 8, "ffactor_temporal": 4})()
+
+    registered = {}
+
+    def fake_register_modules(**kwargs):
+        # Record what __init__ registers and mirror register_modules' attribute side effect.
+        registered.update(kwargs)
+        for key, value in kwargs.items():
+            setattr(pipe, key, value)
+
+    pipe.register_modules = fake_register_modules
+
+    processor = _DummyModule()
+    # If __init__ tried to reload a processor it would go through AutoProcessor.from_pretrained;
+    # patching it lets us assert it was never called.
+    with patch(
+        "diffusers.pipelines.joyai_image.pipeline_joyai_image.AutoProcessor.from_pretrained"
+    ) as mock_from_pretrained:
+        JoyAIImagePipeline.__init__(
+            pipe,
+            vae=pipe.vae,
+            text_encoder=_DummyModule(),
+            tokenizer=_DummyModule(),
+            transformer=_DummyModule(),
+            scheduler=_DummyModule(),
+            processor=processor,
+            args=pipe.args,
+        )
+
+    assert pipe.qwen_processor is processor
+    mock_from_pretrained.assert_not_called()
+    assert registered["processor"] is processor
diff --git a/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py b/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py
new file mode 100644
index 000000000000..e24466ea136d
--- /dev/null
+++ b/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py
@@ -0,0 +1,37 @@
+import tempfile
+
+import torch
+
+from diffusers import JoyAIFlowMatchDiscreteScheduler
+from diffusers.utils import logging
+
+from .test_schedulers import CaptureLogger
+
+
+def test_joyai_scheduler_roundtrip_config_has_no_unexpected_warning():
+ scheduler = JoyAIFlowMatchDiscreteScheduler(num_train_timesteps=1000, shift=4.0, reverse=True)
+ logger = logging.get_logger("diffusers.configuration_utils")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ with CaptureLogger(logger) as cap_logger:
+ config = JoyAIFlowMatchDiscreteScheduler.load_config(tmpdirname)
+ reloaded = JoyAIFlowMatchDiscreteScheduler.from_config(config)
+
+ assert isinstance(reloaded, JoyAIFlowMatchDiscreteScheduler)
+ assert cap_logger.out == ""
+
+
+def test_joyai_scheduler_reloaded_instance_supports_step():
+ scheduler = JoyAIFlowMatchDiscreteScheduler(num_train_timesteps=1000, shift=4.0, reverse=True)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_config(tmpdirname)
+ reloaded = JoyAIFlowMatchDiscreteScheduler.from_pretrained(tmpdirname)
+
+ reloaded.set_timesteps(2)
+ sample = torch.zeros(1, 2, 2)
+ model_output = torch.zeros_like(sample)
+ prev_sample = reloaded.step(model_output, reloaded.timesteps[0], sample, return_dict=False)[0]
+
+ assert prev_sample.shape == sample.shape