Add LongRoPE support and fp64 RoPE precompute for Phi-3 / Phi-4 family #19235
Summary:
Adds LongRoPE plumbing and an fp64 cos/sin precompute pass to
hf_precompute_freqs_cis. Together these eliminate Phi-4 Mini decode-time
n-gram repetition under both XNNPACK and Vulkan delegates.
Phi-3 and Phi-4 family models use HF's "longrope" RoPE scaling, which
multiplies cos/sin by an attention_factor (~1.19 for Phi-4 Mini) and
divides inv_freq element-wise by a per-dimension short_factor (when
seq_len <= original_max_position_embeddings) or long_factor. ET's
hf_precompute_freqs_cis was vanilla RoPE -- missing both terms.
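For context, a minimal sketch of the longrope math described above (illustrative function and parameter names, not the actual hf_precompute_freqs_cis signature; short_factor/long_factor are assumed to be per-dimension lists of length dim // 2):

```python
import torch

def longrope_freqs_sketch(dim, seq_len, theta, short_factor, long_factor,
                          original_max_position_embeddings, attention_factor):
    # Pick the per-dimension rescale: short_factor inside the original
    # training context, long_factor beyond it.
    if seq_len <= original_max_position_embeddings:
        ext = torch.tensor(short_factor, dtype=torch.float64)
    else:
        ext = torch.tensor(long_factor, dtype=torch.float64)

    # Vanilla RoPE inverse frequencies, divided element-wise by the factor.
    inv_freq = 1.0 / (
        ext * theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    )
    freqs = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)

    # Both tables are scaled by attention_factor (~1.19 for Phi-4 Mini),
    # then stored as fp32; the fp64 intermediate is the second fix below.
    cos = (freqs.cos() * attention_factor).to(torch.float32)
    sin = (freqs.sin() * attention_factor).to(torch.float32)
    return cos, sin
```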
At typical export configurations the dominant effect is the missing
attention_factor, which leaves attention scores ~1.42x softer than
the model was trained for. Compounded across 32 layers this pushes
Phi-4 Mini's narrow top-2 logit margins past their tipping point and
triggers greedy-decode n-gram repetition; the same error explains
prior on-device looping observed under both XNNPACK and Vulkan.
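For reference, the ~1.42x figure follows because attention_factor scales the cos/sin tables used for both the query and key rotations, so pre-softmax attention scores move with its square:

```python
attention_factor = 1.19               # approximate Phi-4 Mini value quoted above
score_scale = attention_factor ** 2   # the factor enters both q and k rotations
# score_scale ~= 1.42: without it, every pre-softmax score is ~1.42x smaller
# than the model saw in training.
```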
Adds LongRoPE plumbing through ModelArgs (short_factor, long_factor,
original_max_position_embeddings, max_position_embeddings,
rope_scaling_attention_factor) and into hf_precompute_freqs_cis,
with attention_factor derived as
sqrt(1 + log(scaling)/log(original_max)) when not explicitly set.
The non-HF precompute_freqs_cis path is left vanilla; longrope models
must set use_hf_rope=True (noted in Rope.__init__).
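A sketch of that default derivation (the 131072/4096 figures are illustrative, chosen to match the ~1.19 value quoted above; the model's real values come from its HF config):

```python
import math

def derived_attention_factor(max_position_embeddings, original_max_position_embeddings):
    # Default used when rope_scaling_attention_factor is not set explicitly.
    scaling = max_position_embeddings / original_max_position_embeddings
    return math.sqrt(1 + math.log(scaling) / math.log(original_max_position_embeddings))

# e.g. a 131072-token window over a 4096-token original context:
# sqrt(1 + ln(32) / ln(4096)) ~= 1.19
print(derived_attention_factor(131072, 4096))
```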
Also moves the cos/sin precompute to fp64, casting to fp32 once at
the end. After LongRoPE corrects the 19% scale error, fp32 ULP-level
rounding in the cos/sin tables becomes the next-largest contributor
to logit drift -- and it is load-bearing on Vulkan under sampling: with fp32
precompute, 1 of 2 T=0.5 trajectories collapsed into a 4-gram loop
("avoiding data and data biases") even with LongRoPE applied. fp64
precompute is one-time at construction (microseconds on a few-KB
table); runtime tables remain fp32, so no inference-time cost.
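A toy comparison of the two precompute precisions (illustrative dims, vanilla RoPE for brevity):

```python
import torch

dim, seq_len, theta = 64, 4096, 10000.0

def cos_table(dtype):
    # Build the RoPE cos table entirely in `dtype`, then store as fp32.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    freqs = torch.outer(torch.arange(seq_len, dtype=dtype), inv_freq)
    return freqs.cos().to(torch.float32)

# The stored tables differ only at the ULP level, but per the summary that
# rounding is the next-largest contributor to logit drift once the
# attention_factor error is fixed.
max_diff = (cos_table(torch.float32) - cos_table(torch.float64)).abs().max()
```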
Wires the LongRoPE fields into examples/models/phi_4_mini/config/config.json
sourced from HF's Phi-4 Mini config.
Test Plan:
Validated end-to-end on Samsung Galaxy S25:
- Eager bf16 (host, 12 threads): 3/3 loop-free at T=0 greedy and
T=0.5 sampling x 2 seeds.
- XNNPACK 8da4w-g32 on device: 3/3 loop-free, ~20.7 tok/s decode.
- Vulkan 8da4w-g32 on device: 3/3 loop-free, ~17.1 tok/s decode.
Reproduced across two distinct S25 units to confirm the result is not
device-specific. Verified that omitting either fix regresses Vulkan
sampling: LongRoPE alone leaves residual sampling loops, and fp64 alone
was already known to be insufficient.
Co-Authored-By: Claude <noreply@anthropic.com>