8 changes: 8 additions & 0 deletions examples/models/llama/model_args.py
@@ -123,6 +123,14 @@ class ModelArgs:
     use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
     rope_scale_factor: int = 8
     high_freq_factor: int = 4
+    # LongRoPE (https://arxiv.org/abs/2402.13753) used by Phi-3 / Phi-4 family.
+    # Mirrors HF's rope_scaling.{short_factor,long_factor,attention_factor}
+    # plus original_max_position_embeddings / max_position_embeddings.
+    rope_scaling_short_factor: Optional[list] = None
+    rope_scaling_long_factor: Optional[list] = None
+    original_max_position_embeddings: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    rope_scaling_attention_factor: Optional[float] = None
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
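For context on how these new fields line up with a Hugging Face checkpoint: a minimal sketch of pulling them out of an HF-style config.json. The helper name is illustrative, not part of this PR; HF's Phi-3 / Phi-4 configs nest short_factor/long_factor/attention_factor under rope_scaling but keep the two max-position fields at the top level.

```python
import json


def longrope_overrides_from_hf_config(path: str) -> dict:
    """Illustrative only: map HF rope_scaling keys onto the new ModelArgs fields."""
    with open(path) as f:
        cfg = json.load(f)
    scaling = cfg.get("rope_scaling") or {}
    return {
        "rope_scaling_short_factor": scaling.get("short_factor"),
        "rope_scaling_long_factor": scaling.get("long_factor"),
        "rope_scaling_attention_factor": scaling.get("attention_factor"),
        "original_max_position_embeddings": cfg.get("original_max_position_embeddings"),
        "max_position_embeddings": cfg.get("max_position_embeddings"),
    }
```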
86 changes: 75 additions & 11 deletions examples/models/llama/rope.py
@@ -136,32 +136,80 @@ def forward(

 # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
-# Current only support non-long rope.
+# Supports both vanilla HF RoPE and LongRoPE
+# (https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
+# `_compute_longrope_parameters`), used by Phi-3 / Phi-4 family.
 def hf_precompute_freqs_cis(
     dim: int,
     end: int,
     theta: float,
     partial_rotary_factor: float = 1.0,
     device: Union[str, torch.device] = "cpu",
+    short_factor: Optional[list] = None,
+    long_factor: Optional[list] = None,
+    original_max_pos: Optional[int] = None,
+    max_position_embeddings: Optional[int] = None,
+    attention_factor: Optional[float] = None,
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)

     # Short factor scaling.
-    freqs = 1.0 / (
+    # Compute the RoPE table in fp64 to minimize ULP-level drift; cast to fp32
+    # once at the end. Phi-4 Mini's narrow decode-time logit margins make the
+    # exported model sensitive to 1-ULP differences in freqs_cos / freqs_sin
+    # under sampling, especially on the Vulkan delegate.
+    inv_freq = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
+        ** (
+            torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(torch.float64)
+            / dim
+        )
     )
-    # TODO: support long factor scaling.
+
+    # LongRoPE: divide inv_freq element-wise by short_factor or long_factor.
+    # Selection mirrors HF: long_factor when seq_len > original_max_position_embeddings.
+    longrope_active = (short_factor is not None) or (long_factor is not None)
+    if longrope_active:
+        chosen = (
+            long_factor
+            if (original_max_pos is not None and end > original_max_pos)
+            else short_factor
+        )
+        if chosen is None:
+            # Fall back to whichever factor was provided.
+            chosen = short_factor if long_factor is None else long_factor
+        ext_factors = torch.tensor(chosen, dtype=torch.float64, device=device)
+        assert ext_factors.numel() == inv_freq.numel(), (
+            f"LongRoPE factor length {ext_factors.numel()} must equal dim/2 "
+            f"({inv_freq.numel()})"
+        )
+        inv_freq = inv_freq / ext_factors
+
+        # Derive attention_factor if not provided (matches HF's
+        # _compute_longrope_parameters default).
+        if attention_factor is None and original_max_pos is not None:
+            ref_max_pos = (
+                max_position_embeddings if max_position_embeddings is not None else end
+            )
+            scaling_factor = ref_max_pos / original_max_pos
+            if scaling_factor <= 1.0:
+                attention_factor = 1.0
+            else:
+                attention_factor = math.sqrt(
+                    1 + math.log(scaling_factor) / math.log(original_max_pos)
+                )
+
-    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
-    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
-        freqs  # pyre-ignore
-    )
-    freqs = torch.outer(t, freqs).float()  # pyre-ignore
+    t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(torch.float64)
+    freqs = torch.outer(t, inv_freq).to(torch.float64)  # pyre-ignore
     emb = torch.cat((freqs, freqs), dim=-1)
-    freqs_cos = torch.cos(emb)
-    freqs_sin = torch.sin(emb)
+    cos_tab = torch.cos(emb)
+    sin_tab = torch.sin(emb)
+    if attention_factor is not None and attention_factor != 1.0:
+        cos_tab = cos_tab * attention_factor
+        sin_tab = sin_tab * attention_factor
+    freqs_cos = cos_tab.to(torch.float32)
+    freqs_sin = sin_tab.to(torch.float32)
     return freqs_cos, freqs_sin
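To make the derived scale concrete: a standalone check of the attention_factor formula above, plugging in the Phi-4 Mini values from the config change further down (a sketch independent of the diff, not production code):

```python
import math

# Values taken from examples/models/phi_4_mini/config/config.json below.
original_max_pos = 4096
max_position_embeddings = 131072

scaling_factor = max_position_embeddings / original_max_pos  # 32.0
assert scaling_factor > 1.0  # so the sqrt branch is taken
# ln(32) / ln(4096) = 5/12 exactly, so attention_factor = sqrt(17/12) ~= 1.1902.
attention_factor = math.sqrt(1 + math.log(scaling_factor) / math.log(original_max_pos))
print(f"{attention_factor:.4f}")  # 1.1902
```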


@@ -241,9 +289,25 @@ def __init__(self, params: ModelArgs):
                 hf_precompute_freqs_cis,
                 partial_rotary_factor=self.params.partial_rotary_factor,
                 device=getattr(self.params, "device", "cpu"),
+                short_factor=getattr(self.params, "rope_scaling_short_factor", None),
+                long_factor=getattr(self.params, "rope_scaling_long_factor", None),
+                original_max_pos=getattr(
+                    self.params, "original_max_position_embeddings", None
+                ),
+                max_position_embeddings=getattr(
+                    self.params, "max_position_embeddings", None
+                ),
+                attention_factor=getattr(
+                    self.params, "rope_scaling_attention_factor", None
+                ),
             )
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
+            # NOTE: precompute_freqs_cis (the non-HF path) does not implement
+            # LongRoPE today. Models using rope_scaling.type == "longrope" must
+            # set use_hf_rope=True. If a future model needs LongRoPE on the
+            # vanilla path, mirror the short_factor/long_factor/attention_factor
+            # plumbing from hf_precompute_freqs_cis.
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
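A usage sketch of the plumbing above, exercising the short/long selection end to end. The import path, and the toy dim/theta/factor values, are assumptions for illustration (not Phi-4 Mini's real ones); the assertions follow from the code as written in this diff.

```python
import math

import torch

from examples.models.llama.rope import hf_precompute_freqs_cis  # import path assumed

dim, theta = 8, 10000.0  # toy values; factor lists must have dim/2 entries
short = [1.0] * (dim // 2)
long = [2.0] * (dim // 2)

# end <= original_max_pos: short factors are selected, and the derived
# attention_factor works out to 1.0 (scaling_factor = 2048/4096 <= 1).
cos_s, sin_s = hf_precompute_freqs_cis(
    dim, 2048, theta, short_factor=short, long_factor=long, original_max_pos=4096
)
# end > original_max_pos: long factors are selected; scaling_factor = 8192/4096 = 2,
# so attention_factor = sqrt(1 + ln(2)/ln(4096)) = sqrt(13/12) ~= 1.0408.
cos_l, sin_l = hf_precompute_freqs_cis(
    dim, 8192, theta, short_factor=short, long_factor=long, original_max_pos=4096
)

assert cos_s.shape == (2048, dim) and cos_s.dtype == torch.float32
# Position 0 has angle 0, so cos there equals the applied attention_factor.
assert torch.isclose(cos_l[0, 0], torch.tensor(math.sqrt(13 / 12)).float())
```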
6 changes: 5 additions & 1 deletion examples/models/phi_4_mini/config/config.json
@@ -11,5 +11,9 @@
"vocab_size": 200064,
"use_hf_rope": true,
"partial_rotary_factor": 0.75,
"attention_qkv_bias": false
"attention_qkv_bias": false,
"original_max_position_embeddings": 4096,
"max_position_embeddings": 131072,
"rope_scaling_short_factor": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"rope_scaling_long_factor": [1.0, 1.118320672, 1.250641126, 1.398617824, 1.564103225, 1.74916897, 1.956131817, 2.187582649, 2.446418898, 2.735880826, 3.059592084, 3.421605075, 3.826451687, 4.279200023, 4.785517845, 5.351743533, 5.984965424, 6.693110555, 7.485043894, 8.370679318, 9.36110372, 10.4687158, 11.70738129, 13.09260651, 14.64173252, 16.37415215, 18.31155283, 20.47818807, 22.90118105, 25.61086418, 28.64115884, 32.03, 32.1, 32.13, 32.23, 32.6, 32.61, 32.64, 32.66, 32.7, 32.71, 32.93, 32.97, 33.28, 33.49, 33.5, 44.16, 47.77]
}
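Both factor lists have 48 entries, which is exactly what the assert in hf_precompute_freqs_cis requires: one scale per inv_freq entry, i.e. dim/2 after partial_rotary_factor is applied. A quick consistency check; head_dim = 128 is an assumption taken from Phi-4 Mini's published HF config (hidden_size 3072 / 24 attention heads) and is not part of the file shown above.

```python
head_dim = 128  # assumed, not in this config file
partial_rotary_factor = 0.75  # from the config above

rotary_dim = int(head_dim * partial_rotary_factor)  # 96 rotary dimensions
assert rotary_dim // 2 == 48  # matches len(short_factor) == len(long_factor)
```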