
Add ernie image #13432

Open
HsiaWinter wants to merge 13 commits into huggingface:main from HsiaWinter:add-ernie-image

Conversation

@HsiaWinter

What does this PR do?

We have introduced a new text-to-image model called ERNIE-Image, which will soon be open-sourced to the community. This PR includes the model architecture definition, the pipeline, as well as the related documentation and test files.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot added documentation Improvements or additions to documentation models tests utils pipelines size/L PR with diff > 200 LOC labels Apr 8, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks for the PR!
I left some feedback.

@yiyixuxu yiyixuxu requested a review from dg845 April 8, 2026 09:02
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 8, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks!
I left a few more comments.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 9, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks! Left two small comments.
Let's merge this soon.

@yiyixuxu
Collaborator

yiyixuxu commented Apr 9, 2026

@claude can you do a review here also? Please keep these 3 notes in mind as well during your review:

  1. Compare the ERNIE model/pipeline to others like Qwen/Flux, and let us know if you find any significant inconsistencies.
  2. If you see any unused code paths, let us know.
  3. Look over the PR comments I made and check whether the same patterns we caught/fixed still exist elsewhere in the code.

@github-actions
Contributor

github-actions bot commented Apr 9, 2026

Claude Code is working…

I'll analyze this and get back to you.

View job run

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 10, 2026
@yiyixuxu
Collaborator

@bot /style

@github-actions
Contributor

Style fix is beginning. View the workflow run here.

Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks!

@yiyixuxu
Collaborator

Can you add the new doc pages to https://github.com/huggingface/diffusers/actions/runs/24222913924/job/70733036127?pr=13432#step:16:80

and also run make fix-copies?

Comment on lines +31 to +35
from ...utils import BaseOutput
from ..normalization import RMSNorm
from ..attention_processor import Attention
from ..attention_dispatch import dispatch_attention_fn
from ..attention import AttentionMixin, AttentionModuleMixin
Collaborator


Suggested change
from ...utils import BaseOutput
from ..normalization import RMSNorm
from ..attention_processor import Attention
from ..attention_dispatch import dispatch_attention_fn
from ..attention import AttentionMixin, AttentionModuleMixin
from ...utils import BaseOutput, logging
from ..normalization import RMSNorm
from ..attention_processor import Attention
from ..attention_dispatch import dispatch_attention_fn
from ..attention import AttentionMixin, AttentionModuleMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

logger is used in line 216 below:

https://github.com/HsiaWinter/diffusers/blob/c482b0d953ef8704bde3319d723c900619fb1fe5/src/diffusers/models/transformers/transformer_ernie_image.py#L216-L218

but is not currently defined.
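The fix above follows the usual module-level logger pattern: define the logger once next to the imports so any function later in the file can use it. The suggestion uses diffusers' wrapper (`from ...utils import logging` plus `logging.get_logger(__name__)`); the sketch below illustrates the same pattern with the stdlib directly, with a hypothetical helper standing in for the code at line 216.

```python
import logging

# Module-level logger: defined once near the imports, usable anywhere below.
logger = logging.getLogger(__name__)


def check_sequence_length(seq_len, max_len=4096):
    """Hypothetical helper: warn (rather than raise) when input is too long."""
    if seq_len > max_len:
        logger.warning("Sequence length %d exceeds expected maximum %d", seq_len, max_len)
        return False
    return True
```

Without the module-level definition, any code path that only warns would raise a NameError the first time it ran.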

Author


fix

torch.backends.cuda.matmul.allow_tf32 = False


class ErnieImageTransformerTests(ModelTesterMixin, unittest.TestCase):
Collaborator


Actually, can you write the test in the new format, using BaseModelTesterConfig?
See this PR as a reference: https://github.com/huggingface/diffusers/pull/13344/changes

width=1024,
num_inference_steps=50,
guidance_scale=5.0,
generator=generator,
Collaborator


Suggested change
generator=generator,
generator=torch.Generator("cuda").manual_seed(42),

generator is used here but not defined in the example.
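The reason the doc example needs a defined, seeded generator: two generators built from the same seed produce identical streams, so readers can reproduce the example's output. `torch.Generator("cuda").manual_seed(42)` plays this role in the pipeline call; the stdlib sketch below (no torch) only illustrates the principle.

```python
import random

# Two generators seeded identically yield the same sequence of draws,
# which is what makes a documentation example reproducible.
def sample(seed, n=3):
    gen = random.Random(seed)
    return [gen.random() for _ in range(n)]
```

An undefined `generator` name, by contrast, would make the copy-pasted example fail with a NameError.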

Author


fix

width=1024,
num_inference_steps=8,
guidance_scale=5.0,
generator=generator,
Collaborator


Suggested change
generator=generator,
generator=torch.Generator("cuda").manual_seed(42),

Same suggestion as in #13432 (comment).

Author


fix


pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# 如果显存不足,可以开启offload
Collaborator


Suggested change
# 如果显存不足,可以开启offload
# If you are running low on GPU VRAM, you can enable offloading

nit: use the English translation of the comment since this file is in the en docs

Author


fix


pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# 如果显存不足,可以开启offload
Collaborator


Suggested change
# 如果显存不足,可以开启offload
# If you are running low on GPU VRAM, you can enable offloading

Same as #13432 (comment).

Author


fix

self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps)
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None):
Collaborator


Suggested change
def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None):
def forward(self, x, rotary_pos_emb, temb: tuple[torch.Tensor, ...], attention_mask: torch.Tensor | None = None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb

nit: I think it would be a little cleaner if we put all the modulation parameters as a tuple in a temb argument and then unpacked it inside forward.
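The suggested pattern can be sketched in plain Python (no torch): the six AdaLN-style modulation parameters travel as one `temb` tuple and are unpacked on the first line of forward. The arithmetic below is illustrative only, not ERNIE's actual block math.

```python
# `temb` bundles (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
# so the forward signature stays short and callers pass one object.
def forward(x, temb):
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
    # Attention branch stand-in: modulate, then gate the residual update.
    h = x + gate_msa * (x * (1 + scale_msa) + shift_msa)
    # MLP branch stand-in: same modulate-then-gate shape.
    return h + gate_mlp * (h * (1 + scale_mlp) + shift_mlp)
```

With both gates at zero the block reduces to the identity, which is the usual AdaLN-zero initialization behavior.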

Author


fix

pe=pe,
pe_tokenizer=pe_tokenizer,
)
self.vae_scale_factor = 16 # VAE downsample factor
Collaborator

@dg845 dg845 Apr 10, 2026


Suggested change
self.vae_scale_factor = 16 # VAE downsample factor
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 # VAE downsample factor

nit: I think it would be better to try to get the VAE scale factor from the VAE config if possible so that it's easier to use different VAEs if necessary (not sure if the suggestion is exactly right).
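The idea can be sketched as below. Note that several diffusers pipelines use `2 ** (len(block_out_channels) - 1)` (the first block does not downsample), so a 5-entry `block_out_channels` list yields a factor of 16; the dummy config here is a hypothetical stand-in for the real VAE.

```python
from types import SimpleNamespace

# Derive the scale factor from the VAE config instead of hard-coding 16,
# so swapping in a different VAE keeps the pipeline consistent.
def get_vae_scale_factor(vae, default=16):
    if vae is not None and hasattr(vae.config, "block_out_channels"):
        return 2 ** (len(vae.config.block_out_channels) - 1)
    return default  # fall back when no VAE is loaded

# Hypothetical 5-level VAE config standing in for the real one.
dummy_vae = SimpleNamespace(
    config=SimpleNamespace(block_out_channels=[128, 256, 512, 512, 512])
)
```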

Author


fix

# Latent dimensions
latent_h = height // self.vae_scale_factor
latent_w = width // self.vae_scale_factor
latent_channels = 128 # After patchify
Collaborator


Suggested change
latent_channels = 128 # After patchify
latent_channels = self.transformer.config.in_channels # 128 after patchify

nit: get latent_channels from the transformer config so that the pipeline code is more robust to different transformers
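A sketch of the same idea, with a fallback preserving the current hard-coded value; `in_channels` mirrors the attribute named in the suggestion, and the dummy config is a hypothetical stand-in for the real ERNIE transformer.

```python
from types import SimpleNamespace

# Read latent_channels from the transformer config instead of hard-coding 128,
# so the pipeline follows whatever transformer is actually loaded.
def get_latent_channels(transformer, default=128):
    if transformer is not None and hasattr(transformer.config, "in_channels"):
        return transformer.config.in_channels
    return default

# Hypothetical config standing in for the real transformer.
dummy_transformer = SimpleNamespace(config=SimpleNamespace(in_channels=128))
```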

Author


fix

Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left a few small comments.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 10, 2026
Labels

documentation Improvements or additions to documentation models pipelines size/L PR with diff > 200 LOC tests utils


3 participants