Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage #13406
dg845 merged 6 commits into huggingface:main
Conversation
|
The profiling was done with 2 steps, but this sync happens on every transformer forward call, so at 20 inference steps this eliminates ~1.5s of CPU-GPU sync overhead per run. Under torch.compile the impact is larger since GPU queues are deeper, so each sync stalls longer (80ms vs 76ms in eager). |
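As a sanity check on the figure quoted above, the ~1.5s estimate follows directly from the measured per-sync stall; the numbers below are taken straight from the profile in this thread:

```python
# Back-of-envelope check of the overhead figure quoted above, using the
# measured numbers from this thread: ~76 ms per cudaStreamSynchronize,
# one sync per transformer forward call.
sync_ms = 76           # measured eager-mode sync stall per forward call
steps = 20             # typical number of inference steps
total_s = sync_ms * steps / 1000
print(f"~{total_s:.2f}s of sync overhead per run")  # → ~1.52s of sync overhead per run
```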
|
Oh, and this fix applies to all QwenImage variants (Edit, EditPlus, Layered) since they share the same transformer.
|
@akshan-main thanks for this! In the second plot, could you tell which one of the blocks the reported duration belongs to? |
|
The selected slice in the after image is the transformer_forward user_annotation itself (~439ms), wrapping the full QwenImageTransformer2DModel.forward. I am highlighting the specific sub-block where the 76ms cudaStreamSynchronize used to sit (in the before screenshot); it is now gone.
|
~439ms is for entire transformer_forward block |
|
Friendly ping @dg845, hey! seeking your review/ interpretation |
```python
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
    if device not in self._device_freq_cache:
        self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
    return self._device_freq_cache[device]
```
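The caching pattern above can be sketched without torch at all. In this minimal stand-in (hypothetical `FakeTensor`/`Rope` names, a string standing in for `torch.device`), the dict keyed by device ensures the simulated host-to-device copy happens only once per device, no matter how many forward calls occur:

```python
# Minimal sketch of the per-device cache pattern, with plain Python
# objects standing in for torch tensors and devices (no torch needed).
class FakeTensor:
    transfer_count = 0  # counts simulated CPU->GPU copies across all tensors

    def to(self, device):
        FakeTensor.transfer_count += 1
        return self

class Rope:
    def __init__(self):
        self.pos_freqs = FakeTensor()
        self.neg_freqs = FakeTensor()
        self._device_freq_cache = {}

    def _get_device_freqs(self, device):
        # First call per device pays the transfer; later calls hit the cache.
        if device not in self._device_freq_cache:
            self._device_freq_cache[device] = (
                self.pos_freqs.to(device),
                self.neg_freqs.to(device),
            )
        return self._device_freq_cache[device]

rope = Rope()
for _ in range(20):  # 20 forward calls on the same device
    rope._get_device_freqs("cuda:0")
print(FakeTensor.transfer_count)  # → 2 (one per tensor, not 40)
```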
Suggested change:

```diff
- def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
-     """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
-     if device not in self._device_freq_cache:
-         self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
-     return self._device_freq_cache[device]
+ @lru_cache_unless_export(maxsize=None)
+ def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+     """Return pos_freqs and neg_freqs on the given device."""
+     return self.pos_freqs.to(device), self.neg_freqs.to(device)
```
I think this might be slightly cleaner since lru_cache_unless_export should handle the cases where we're compiling or exporting the model correctly.
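Outside of compile/export, `lru_cache_unless_export` (a diffusers-internal helper) behaves like `functools.lru_cache`, so the suggested method caches per device argument just like the manual dict did. A hedged stand-in using plain `lru_cache` to illustrate the memoization behavior:

```python
# Stand-in sketch: plain functools.lru_cache illustrates the caching
# behavior of lru_cache_unless_export in the non-export/non-compile case.
from functools import lru_cache

transfers = {"count": 0}

@lru_cache(maxsize=None)
def get_device_freqs(device: str):
    transfers["count"] += 1  # stands in for the CPU->GPU copy
    return f"freqs@{device}"

for _ in range(20):  # 20 forward calls on the same device
    get_device_freqs("cuda:0")
print(transfers["count"])  # → 1: only the first call per device does the copy
```

One usual caveat with `lru_cache` on a method (as in the suggestion) is that the cache also keys on `self`, which is the desired behavior here: each rope module keeps its own per-device entries.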
|
|
```python
# DO NOT USE register_buffer HERE; IT WILL CAUSE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
self.scale_rope = scale_rope
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
```
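A minimal illustration of the register_buffer caveat flagged in the comment above: a module-wide dtype cast (e.g. `model.to(torch.bfloat16)`) also casts registered buffers, and casting a complex value to a real dtype keeps only the real component. A pure-Python analog of that loss, with no torch dependency:

```python
# Pure-Python analog of casting a complex tensor to a real dtype:
# only the real component survives; the imaginary part is discarded.
z = 1.0 + 2.0j
real_only = z.real
print(real_only)  # → 1.0 (the 2j imaginary part is gone)
```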
```diff
- self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
```
Follow-up change to #13406 (comment).
|
|
@bot /style |
|
Style bot fixed some files and pushed the changes. |
dg845
left a comment
Thanks for the PR! Left one suggestion about using lru_cache_unless_export instead of caching manually.
|
@dg845 done! switched both QwenEmbedRope and QwenEmbedLayer3DRope to lru_cache_unless_export |
|
Hi @akshan-main, have you also profiled the QwenImage pipeline when using torch.compile? |
|
profiling compile before/after now, will update with numbers |
|
@dg845 profiled compile before/after. torch.compile() already eliminates the big syncs on its own (0 big syncs both before and after); the fix specifically targets eager mode.
|
@sayakpaul @dg845 don't think the failures are related to my PR |
|
I will let @dg845 take care of the final merging. I am looking into the failing tests (unrelated to your PR). I also got this script to compare QwenImage outputs on this branch and main:

```python
"""Compare QwenImagePipeline outputs between current branch and main."""
import subprocess
import sys

import torch

from diffusers import DiffusionPipeline

REPO_ID = "Qwen/Qwen-Image"
PROMPT = "A cat holding a sign that says hello world"
NUM_INFERENCE_STEPS = 2
HEIGHT = 256
WIDTH = 256


def get_output(pipe):
    output = pipe(
        PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        height=HEIGHT,
        width=WIDTH,
        generator=torch.manual_seed(0),
        output_type="np",
    ).images[0]
    return output


def main():
    current_branch = (
        subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode().strip()
    )
    print(f"Current branch: {current_branch}")

    # --- Current branch ---
    print("Loading pipeline on current branch...")
    pipe = DiffusionPipeline.from_pretrained(REPO_ID, torch_dtype=torch.bfloat16).to("cuda")
    print("Computing output on current branch...")
    output_current = get_output(pipe)
    del pipe
    torch.cuda.empty_cache()

    # --- main branch ---
    print("Checking out main branch...")
    subprocess.check_call(["git", "checkout", "main"])
    # Reload diffusers from main
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".", "--quiet"])
    print("Loading pipeline on main branch...")
    pipe = DiffusionPipeline.from_pretrained(REPO_ID, torch_dtype=torch.bfloat16).to("cuda")
    print("Computing output on main branch...")
    output_main = get_output(pipe)
    del pipe
    torch.cuda.empty_cache()

    # --- Restore original branch ---
    print(f"Restoring branch: {current_branch}")
    subprocess.check_call(["git", "checkout", current_branch])

    # --- Compare ---
    max_diff = abs(output_current - output_main).max()
    mean_diff = abs(output_current - output_main).mean()
    print(f"\nMax absolute difference: {max_diff}")
    print(f"Mean absolute difference: {mean_diff}")
    if max_diff < 1e-3:
        print("PASSED: Outputs match.")
    else:
        print("FAILED: Outputs differ significantly.")


if __name__ == "__main__":
    main()
```
|
@sayakpaul outputs should match since the fix only changes how the freqs are cached, not the computation itself |
|
Friendly ping @dg845, the same tests are failing.
|
Merging as the CI failures should be unrelated. |
|
Thanks @sayakpaul @dg845 for the opportunity to contribute!
What does this PR do?
Part of #13401
`QwenEmbedRope.forward()` copies `pos_freqs` and `neg_freqs` from CPU to GPU via `.to(device)` on every transformer forward call. These tensors are fixed at init and never change, so the repeated transfer triggers an unnecessary `cudaStreamSynchronize` (~76ms each).

Added `_get_device_freqs()` that caches the GPU copy on first call. Applied to both `QwenEmbedRope` and `QwenEmbedLayer3DRope`. (`register_buffer` can't be used here because it drops the imaginary part of complex tensors.)

Profiling (A100 80GB, eager, 2 steps, 1024x1024)
Before (76ms cudaStreamSynchronize inside transformer_forward):
After (no sync gap):
Profiled with the tooling from #13356. Reproduction notebook.
Who can review?
@sayakpaul @dg845