Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final
os.makedirs(val_save_dir)

original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
lambda image_url_or_path: (
load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)
)(args.val_image_url_or_path)

if torch.backends.mps.is_available():
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
logger = logging.get_logger(__name__)

_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
lambda: lambda model_cls, weights: weights,
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
Expand Down
63 changes: 62 additions & 1 deletion src/diffusers/models/_modeling_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel


@dataclass
Expand Down Expand Up @@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


@dataclass
class TensorParallelConfig:
"""
Configuration for tensor parallelism.

Tensor parallelism shards weight matrices (column-wise and row-wise) across devices.
Each device computes a partial result; an AllReduce/AllGather at layer boundaries
reconstructs the full output. Uses ``torch.distributed.tensor.parallelize_module``
with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles.

On Neuron, use the ``_pre_shard_and_tp`` workaround from
``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug
on large tensors (>= 5120x5120).

Args:
tp_degree (`int`, defaults to `1`):
Number of devices to shard across. Must be a divisor of the number of
attention heads (and FFN hidden dimensions) of the model being parallelised.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use. If provided, ``tp_degree`` is inferred from
``mesh.size()`` and the argument is ignored. Useful when combining TP with
other parallelism strategies (e.g. CP) that share the same mesh.
"""

tp_degree: int = 1
mesh: torch.distributed.device_mesh.DeviceMesh | None = None

_rank: int = None
_world_size: int = None
_device: torch.device = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None

def __post_init__(self):
if self.tp_degree < 1:
raise ValueError("`tp_degree` must be >= 1.")

def setup(
self,
rank: int,
world_size: int,
device: torch.device,
mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
):
self._rank = rank
self._world_size = world_size
self._device = device
if mesh is not None:
self._mesh = mesh
elif self.mesh is not None:
self._mesh = self.mesh
else:
from torch.distributed.device_mesh import init_device_mesh

device_type = str(device).split(":")[0]
self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",))


@dataclass
class ParallelConfig:
"""
Expand All @@ -150,9 +206,12 @@ class ParallelConfig:
Args:
context_parallel_config (`ContextParallelConfig`, *optional*):
Configuration for context parallelism.
tensor_parallel_config (`TensorParallelConfig`, *optional*):
Configuration for tensor parallelism.
"""

context_parallel_config: ContextParallelConfig | None = None
tensor_parallel_config: TensorParallelConfig | None = None

_rank: int = None
_world_size: int = None
Expand All @@ -173,6 +232,8 @@ def setup(
self._mesh = mesh
if self.context_parallel_config is not None:
self.context_parallel_config.setup(rank, world_size, device, mesh)
if self.tensor_parallel_config is not None:
self.tensor_parallel_config.setup(rank, world_size, device, mesh)


@dataclass(frozen=True)
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/models/transformers/transformer_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
is_neuron = ids.device.type == "neuron"
freqs_dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
for i in range(len(self.axes_dim)):
cos, sin = get_1d_rotary_pos_embed(
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float |
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
is_neuron = sample.device.type == "neuron"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
Expand Down
8 changes: 5 additions & 3 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_neuronx_available,
is_torch_xla_available,
logging,
replace_example_docstring,
Expand Down Expand Up @@ -862,7 +863,7 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

# 4. Prepare timesteps
if XLA_AVAILABLE:
if XLA_AVAILABLE or is_torch_neuronx_available():
timestep_device = "cpu"
else:
timestep_device = device
Expand Down Expand Up @@ -914,10 +915,11 @@ def __call__(
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
is_npu = latent_model_input.device.type == "npu"
is_neuron = latent_model_input.device.type == "neuron"
if isinstance(current_timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,11 @@ def __call__(
)

# 4. Prepare timesteps
if XLA_AVAILABLE:
# Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where
# dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep()
# are incompatible with static-graph compilation.
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
if XLA_AVAILABLE or is_neuron_device:
timestep_device = "cpu"
else:
timestep_device = device
Expand Down Expand Up @@ -1195,15 +1199,23 @@ def __call__(
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region.
# index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs.
if is_neuron_device:
latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device)
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds
# For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support
# int64 ops; the compiled UNet graph requires a float32 timestep input on-device.
t_unet = t.to(torch.float32).to(device) if is_neuron_device else t
noise_pred = self.unet(
latent_model_input,
t,
t_unet,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
Expand All @@ -1222,7 +1234,13 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device.
if is_neuron_device:
latents = self.scheduler.step(
noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False
)[0].to(device)
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
is_timm_available,
is_torch_available,
is_torch_mlu_available,
is_torch_neuronx_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
Expand Down Expand Up @@ -249,6 +250,10 @@ def is_torch_mlu_available():
return _torch_mlu_available


def is_torch_neuronx_available():
    """Return whether the `torch_neuronx` package (AWS Neuron SDK) is installed in the environment."""
    return _torch_neuronx_available


def is_flax_available():
return _flax_available

Expand Down Expand Up @@ -579,6 +584,10 @@ def is_av_available():
"""


TORCH_NEURONX_IMPORT_ERROR = """
{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/
"""

BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
Expand Down Expand Up @@ -609,6 +618,7 @@ def is_av_available():
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)),
]
)

Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
is_peft_available,
is_timm_available,
is_torch_available,
is_torch_neuronx_available,
is_torch_version,
is_torchao_available,
is_torchsde_available,
Expand Down Expand Up @@ -113,6 +114,8 @@
torch_device = "cuda"
elif torch.xpu.is_available():
torch_device = "xpu"
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
torch_device = torch.neuron.current_device()
else:
torch_device = "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
Expand Down
Loading
Loading