
Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage#13406

Merged
dg845 merged 6 commits into huggingface:main from akshan-main:fix-qwenimage-rope-sync
Apr 10, 2026

Conversation

@akshan-main
Contributor

@akshan-main akshan-main commented Apr 3, 2026

What does this PR do?

Part of #13401

QwenEmbedRope.forward() copies pos_freqs and neg_freqs from CPU to GPU via .to(device) on every transformer forward call. These tensors are fixed at init and never change, so the repeated transfer triggers an unnecessary cudaStreamSynchronize (~76ms each).

Added _get_device_freqs() that caches the GPU copy on first call. Applied to both QwenEmbedRope and QwenEmbedLayer3DRope.

(register_buffer can't be used here because it drops the imaginary part of complex tensors)
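The caching pattern can be sketched in plain Python (a toy stand-in with hypothetical names: strings replace the tensors, and a `transfer` callable models `.to(device)`):

```python
class DeviceFreqCache:
    """Toy sketch of the per-device caching used in the PR (not the real class)."""

    def __init__(self, pos_freqs, neg_freqs):
        self.pos_freqs = pos_freqs
        self.neg_freqs = neg_freqs
        self._device_freq_cache = {}  # device -> (pos, neg), filled once per device

    def get(self, device, transfer):
        # transfer(obj, device) stands in for tensor.to(device); it runs only
        # on the first request for each device, so later forwards skip the copy.
        if device not in self._device_freq_cache:
            self._device_freq_cache[device] = (
                transfer(self.pos_freqs, device),
                transfer(self.neg_freqs, device),
            )
        return self._device_freq_cache[device]
```

Subsequent calls for the same device return the cached pair without touching `transfer`, which is what removes the repeated CPU-GPU copy and its sync.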

Profiling (A100 80GB, eager, 2 steps, 1024x1024)

                                         BEFORE     AFTER
-------------------------------  --------------  --------
Big syncs (>50 ms)                            3         0
Big sync total (ms)                       228.7       0.0
Big sync durations (ms)        76.6, 76.4, 75.7      none
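For reference, the BEFORE column can be recomputed from the listed sync durations with a small helper (hypothetical name, not part of the actual profiling tooling):

```python
def summarize_big_syncs(durations_ms, threshold_ms=50.0):
    """Count and total the sync events exceeding threshold_ms."""
    big = [d for d in durations_ms if d > threshold_ms]
    return len(big), sum(big)

# With the measured durations [76.6, 76.4, 75.7]: 3 big syncs totaling ~228.7 ms.
```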

Before (76ms cudaStreamSynchronize inside transformer_forward):

[screenshot: before_sync profiler trace]

After (no sync gap):

[screenshot: after_sync profiler trace]

Profiled with the tooling from #13356. Reproduction notebook.


Before submitting

Who can review?

@sayakpaul @dg845

@akshan-main
Contributor Author

akshan-main commented Apr 3, 2026

The profiling was done with 2 steps, but this sync happens on every transformer forward call, so at 20 inference steps this eliminates ~1.5s of CPU-GPU sync overhead per run. Under torch.compile the impact per sync is larger since GPU queues are deeper, so each sync stalls longer (80ms vs 76ms in eager).
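The ~1.5s figure follows from simple arithmetic (assuming one ~76ms sync per transformer forward call in eager mode):

```python
sync_ms = 76.0   # measured cudaStreamSynchronize cost per forward (eager)
steps = 20       # typical inference step count
overhead_s = sync_ms * steps / 1000.0  # total avoidable sync time per run
# 76 ms per step x 20 steps = 1.52 s, i.e. the ~1.5 s quoted above
```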

@akshan-main
Contributor Author

Oh, and this fix applies to all QwenImage variants (Edit, EditPlus, Layered) since they share the same transformer.

@dg845 dg845 requested review from dg845 and sayakpaul April 8, 2026 05:39
@sayakpaul
Member

@akshan-main thanks for this! In the second plot, could you tell which one of the blocks the reported duration belongs to?

Member

@sayakpaul sayakpaul left a comment

Seems like a clean fix to me. But I will let @dg845 make the final merge.

@akshan-main
Contributor Author

The selected slice in the "after" image is the transformer_forward user_annotation itself (~439ms), wrapping the full QwenImageTransformer2DModel.forward.

I am highlighting the sub-block where the 76ms cudaStreamSynchronize used to sit (visible in the "before" screenshot); it is now gone.

@akshan-main
Contributor Author

~439ms is for the entire transformer_forward block.

@akshan-main
Contributor Author

Friendly ping @dg845, seeking your review/interpretation!

Comment on lines +237 to +241
```python
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
    if device not in self._device_freq_cache:
        self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
    return self._device_freq_cache[device]
```
Collaborator

Suggested change

Before:

```python
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
    if device not in self._device_freq_cache:
        self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
    return self._device_freq_cache[device]
```

After:

```python
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pos_freqs and neg_freqs on the given device."""
    return self.pos_freqs.to(device), self.neg_freqs.to(device)
```

I think this might be slightly cleaner since lru_cache_unless_export should handle the cases where we're compiling or exporting the model correctly.
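For readers unfamiliar with the helper: I believe `lru_cache_unless_export` behaves like `functools.lru_cache` except that it bypasses the cache during compile/export. The memoization half can be sketched with the stdlib decorator (toy stand-in: strings replace tensors, concatenation replaces `.to(device)`):

```python
from functools import lru_cache

class FreqHolder:
    """Toy stand-in for the RoPE module: strings replace tensors."""

    def __init__(self, pos, neg):
        self.pos_freqs = pos
        self.neg_freqs = neg

    @lru_cache(maxsize=None)  # keyed on (self, device); transfer runs once per device
    def _get_device_freqs(self, device):
        # "@device" concatenation models tensor.to(device)
        return self.pos_freqs + "@" + device, self.neg_freqs + "@" + device
```

One known caveat of `lru_cache` on a method is that the cache keeps `self` alive, which is acceptable here since the module lives for the whole pipeline anyway.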


```python
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
self.scale_rope = scale_rope
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
```
Collaborator

Suggested change (delete this line):

```python
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
```

Follow-up change to #13406 (comment).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dg845
Collaborator

dg845 commented Apr 9, 2026

@bot /style

@github-actions
Contributor

github-actions bot commented Apr 9, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions bot added models size/S PR with diff < 50 LOC labels Apr 9, 2026
Collaborator

@dg845 dg845 left a comment

Thanks for the PR! Left one suggestion about using lru_cache_unless_export instead of caching manually.

@github-actions github-actions bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 9, 2026
@akshan-main akshan-main requested a review from dg845 April 9, 2026 06:40
@akshan-main
Contributor Author

@dg845 done! Switched both QwenEmbedRope and QwenEmbedLayer3DRope to lru_cache_unless_export.
CI testing shouldn't be an issue now either.

Collaborator

@dg845 dg845 left a comment

Thanks!

@dg845
Collaborator

dg845 commented Apr 9, 2026

Hi @akshan-main, have you also profiled the QwenImage pipeline when using torch.compile? Are the CPU to GPU syncs also eliminated by this fix in that case?

@akshan-main
Contributor Author

Profiling compile before/after now; will update with numbers.

@akshan-main
Contributor Author

@dg845 profiled compile before/after. torch.compile() already eliminates the big syncs on its own (0 big syncs both before and after). The fix specifically targets eager mode.


@akshan-main akshan-main requested a review from dg845 April 9, 2026 08:09
@akshan-main akshan-main requested a review from sayakpaul April 10, 2026 05:34
@akshan-main
Contributor Author

@sayakpaul @dg845 I don't think the failures are related to my PR.

@sayakpaul
Member

I will let @dg845 take care of the final merging. I am looking into the failing tests (unrelated to your PR).

I also got this script to compare QwenImage outputs on this branch and main branch and they pass:

"""Compare QwenImagePipeline outputs between current branch and main."""

import subprocess
import sys

import torch
from diffusers import DiffusionPipeline


REPO_ID = "Qwen/Qwen-Image"
PROMPT = "A cat holding a sign that says hello world"
NUM_INFERENCE_STEPS = 2
HEIGHT = 256
WIDTH = 256


def get_output(pipe):
    output = pipe(
        PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        height=HEIGHT,
        width=WIDTH,
        generator=torch.manual_seed(0),
        output_type="np",
    ).images[0]
    return output


def main():
    current_branch = (
        subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode().strip()
    )
    print(f"Current branch: {current_branch}")

    # --- Current branch ---
    print("Loading pipeline on current branch...")
    pipe = DiffusionPipeline.from_pretrained(REPO_ID, torch_dtype=torch.bfloat16).to("cuda")
    print("Computing output on current branch...")
    output_current = get_output(pipe)

    del pipe
    torch.cuda.empty_cache()

    # --- main branch ---
    print("Checking out main branch...")
    subprocess.check_call(["git", "checkout", "main"])

    # Reload diffusers from main
    # (caveat: modules already imported in this process are not re-imported;
    # for a strict comparison, run each branch's generation in a fresh process)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".", "--quiet"])

    print("Loading pipeline on main branch...")
    pipe = DiffusionPipeline.from_pretrained(REPO_ID, torch_dtype=torch.bfloat16).to("cuda")
    print("Computing output on main branch...")
    output_main = get_output(pipe)

    del pipe
    torch.cuda.empty_cache()

    # --- Restore original branch ---
    print(f"Restoring branch: {current_branch}")
    subprocess.check_call(["git", "checkout", current_branch])

    # --- Compare ---
    max_diff = abs(output_current - output_main).max()
    mean_diff = abs(output_current - output_main).mean()

    print(f"\nMax absolute difference:  {max_diff}")
    print(f"Mean absolute difference: {mean_diff}")

    if max_diff < 1e-3:
        print("PASSED: Outputs match.")
    else:
        print("FAILED: Outputs differ significantly.")


if __name__ == "__main__":
    main()

@akshan-main
Contributor Author

@sayakpaul outputs should match, since the fix only changes how the freqs are cached, not the computation itself.

@akshan-main
Contributor Author

Friendly ping @dg845, the same tests are failing.

Collaborator

@dg845 dg845 left a comment

Thanks!

@dg845
Collaborator

dg845 commented Apr 10, 2026

Merging as the CI failures should be unrelated.

@dg845 dg845 merged commit 4548e68 into huggingface:main Apr 10, 2026
10 of 14 checks passed
@akshan-main
Contributor Author

Thanks @sayakpaul @dg845 for the opportunity to contribute!
