fix: correct qwen image tile mask weights #257
Code Review
This pull request corrects the indexing and weight calculations for the bottom and right borders in the _build_tile_mask method of the Qwen image autoencoder. The reviewer suggests vectorizing the loops for both border calculations using torch.arange and broadcasting to eliminate Python loop overhead and improve PyTorch performance.
    for y in range(border_h):
        mask[:, :, :, h - 1 - y, :] *= y / border_h
        mask[:, :, :, h - border_h + y, :] *= 1 - y / border_h
Using a Python loop to perform element-wise multiplication on PyTorch tensors is inefficient because it incurs Python loop overhead and launches multiple small CUDA kernels. We can vectorize this operation using torch.arange and broadcasting, which is much faster and more idiomatic.
Note: The is_top and is_left blocks can also be vectorized in a similar manner.
-   for y in range(border_h):
-       mask[:, :, :, h - 1 - y, :] *= y / border_h
-       mask[:, :, :, h - border_h + y, :] *= 1 - y / border_h
+   weights = 1.0 - torch.arange(border_h, device=mask.device, dtype=mask.dtype) / border_h
+   mask[:, :, :, h - border_h :, :] *= weights.view(1, 1, 1, border_h, 1)
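As a quick sanity check on the suggestion above, the following sketch compares the per-row loop with the vectorized form on a small tensor. It assumes the mask is a 5-D tensor with height on dimension 3, and that the loop being replaced is the `h - border_h + y` line only. The function names `bottom_ramp_loop` and `bottom_ramp_vectorized` are hypothetical and not part of the PR.

```python
import torch


def bottom_ramp_loop(mask: torch.Tensor, border_h: int) -> torch.Tensor:
    """Reference version: scale each bottom-border row in a Python loop."""
    out = mask.clone()
    h = out.shape[3]  # assumption: height is dimension 3 of a 5-D mask
    for y in range(border_h):
        out[:, :, :, h - border_h + y, :] *= 1 - y / border_h
    return out


def bottom_ramp_vectorized(mask: torch.Tensor, border_h: int) -> torch.Tensor:
    """Vectorized version: one broadcast multiply over the whole border."""
    out = mask.clone()
    h = out.shape[3]
    # weights are 1, 1 - 1/border_h, ..., 1/border_h from the first border row down
    weights = 1.0 - torch.arange(border_h, device=out.device, dtype=out.dtype) / border_h
    out[:, :, :, h - border_h:, :] *= weights.view(1, 1, 1, border_h, 1)
    return out


mask = torch.ones(1, 1, 1, 8, 6)
loop_result = bottom_ramp_loop(mask, border_h=4)
vec_result = bottom_ramp_vectorized(mask, border_h=4)
assert torch.allclose(loop_result, vec_result)
```

With `border_h=4`, the last four rows are scaled by 1.0, 0.75, 0.5, and 0.25, so the final row keeps a nonzero weight of `1 / border_h`.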
    for x in range(border_w):
        mask[:, :, :, :, w - 1 - x] *= x / border_w
        mask[:, :, :, :, w - border_w + x] *= 1 - x / border_w
Similarly, this loop can be vectorized to avoid Python loop overhead and multiple CUDA kernel launches, improving performance.
-   for x in range(border_w):
-       mask[:, :, :, :, w - 1 - x] *= x / border_w
-       mask[:, :, :, :, w - border_w + x] *= 1 - x / border_w
+   weights = 1.0 - torch.arange(border_w, device=mask.device, dtype=mask.dtype) / border_w
+   mask[:, :, :, :, w - border_w :] *= weights.view(1, 1, 1, 1, border_w)
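Since the bottom and right suggestions differ only in the dimension they touch, one possible follow-up is a single helper that takes the dimension as a parameter. This is a minimal sketch, not code from the PR: `apply_trailing_ramp` is a hypothetical name, and the 5-D layout with height on dimension 3 and width on dimension 4 is assumed from the indexing in the diff.

```python
import torch


def apply_trailing_ramp(mask: torch.Tensor, border: int, dim: int) -> torch.Tensor:
    """Scale the last `border` entries along `dim` by 1, 1 - 1/border, ..., 1/border.

    Modifies `mask` in place and returns it. Hypothetical helper for illustration.
    """
    if border <= 0:
        return mask
    weights = 1.0 - torch.arange(border, device=mask.device, dtype=mask.dtype) / border
    # Reshape the weights so they broadcast along `dim` only.
    shape = [1] * mask.dim()
    shape[dim] = border
    size = mask.shape[dim]
    # narrow() returns a view, so mul_() writes through to `mask`.
    mask.narrow(dim, size - border, border).mul_(weights.view(shape))
    return mask


# Usage: bottom border (dim 3) and right border (dim 4) on a toy mask.
ramp_mask = torch.ones(1, 1, 1, 4, 6)
apply_trailing_ramp(ramp_mask, border=2, dim=3)
apply_trailing_ramp(ramp_mask, border=3, dim=4)
```

Entries in the bottom-right corner are scaled by both ramps, which matches what two sequential loops over the same mask would produce.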
31c6d5a to b9e4259
No description provided.