-
Notifications
You must be signed in to change notification settings - Fork 55
Support Qwen3.5 #515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Support Qwen3.5 #515
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5f6d359
Support Qwen3.5
chenyushuo c09b62f
Upgrade the dependency versions of vllm, transformers, and megatron
chenyushuo d296f8b
Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/sup…
chenyushuo 5267498
fix pre commit
chenyushuo 4a15384
upgrade vllm
chenyushuo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| from dataclasses import dataclass | ||
| from functools import wraps | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from transformers.models.qwen3_5.modeling_qwen3_5 import ( | ||
| BaseModelOutputWithPast, | ||
| Cache, | ||
| Qwen3_5CausalLMOutputWithPast, | ||
| Qwen3_5DynamicCache, | ||
| Qwen3_5ForConditionalGeneration, | ||
| Qwen3_5ModelOutputWithPast, | ||
| TransformersKwargs, | ||
| Unpack, | ||
| capture_outputs, | ||
| create_causal_mask, | ||
| merge_with_config_defaults, | ||
| ) | ||
|
|
||
|
|
||
| # TODO: may optimize this function | ||
| def ulysses_gated_delta_net_forward_decorator(func): | ||
| @wraps(func) | ||
| def wrapper( | ||
| hidden_states: torch.Tensor, | ||
| cache_params: Qwen3_5DynamicCache | None = None, | ||
| cache_position: torch.LongTensor | None = None, | ||
| attention_mask: torch.Tensor | None = None, | ||
| ): | ||
| from verl.utils.ulysses import ( | ||
| gather_outputs_and_unpad, | ||
| get_ulysses_sequence_parallel_world_size, | ||
| slice_input_tensor, | ||
| ) | ||
|
|
||
| ulysses_sp_size = get_ulysses_sequence_parallel_world_size() | ||
| if ulysses_sp_size > 1: | ||
| hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1) | ||
|
|
||
| output = func(hidden_states, cache_params, cache_position, attention_mask) | ||
|
|
||
| if ulysses_sp_size > 1: | ||
| output = slice_input_tensor(output, dim=1, padding=False) | ||
| return output | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| @merge_with_config_defaults | ||
| @capture_outputs | ||
| def qwen35_text_forward( | ||
| self, | ||
| input_ids: torch.LongTensor | None = None, | ||
| attention_mask: torch.Tensor | None = None, | ||
| position_ids: torch.LongTensor | None = None, | ||
| past_key_values: Cache | None = None, | ||
| inputs_embeds: torch.FloatTensor | None = None, | ||
| use_cache: bool | None = None, | ||
| cache_position: torch.LongTensor | None = None, | ||
| **kwargs: Unpack[TransformersKwargs], | ||
| ) -> BaseModelOutputWithPast: | ||
| if (input_ids is None) ^ (inputs_embeds is not None): | ||
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") | ||
|
|
||
| if inputs_embeds is None: | ||
| inputs_embeds = self.embed_tokens(input_ids) | ||
|
|
||
| if use_cache and past_key_values is None: | ||
| past_key_values = Qwen3_5DynamicCache(config=self.config) | ||
|
|
||
| if cache_position is None: | ||
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 | ||
| cache_position = torch.arange( | ||
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device | ||
| ) | ||
|
|
||
| # mrope: the hard coded `3` is for temporal, height and width. | ||
| if position_ids is None: | ||
| position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) | ||
| elif position_ids.ndim == 2: | ||
| position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) | ||
|
|
||
| if position_ids.ndim == 3 and position_ids.shape[0] == 4: | ||
| text_position_ids = position_ids[0] | ||
| position_ids = position_ids[1:] | ||
| else: | ||
| text_position_ids = position_ids[0] | ||
|
|
||
| causal_mask = create_causal_mask( | ||
| config=self.config, | ||
| inputs_embeds=inputs_embeds, | ||
| attention_mask=attention_mask, | ||
| cache_position=cache_position, | ||
| past_key_values=past_key_values, | ||
| position_ids=text_position_ids, | ||
| ) | ||
| linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) | ||
|
|
||
| hidden_states = inputs_embeds | ||
| position_embeddings = self.rotary_emb(hidden_states, position_ids) | ||
|
|
||
| for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): | ||
| layer_mask = ( | ||
| linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask | ||
| ) | ||
|
|
||
| hidden_states = decoder_layer( | ||
| hidden_states, | ||
| position_embeddings=position_embeddings, | ||
| attention_mask=layer_mask, | ||
| position_ids=text_position_ids, | ||
| past_key_values=past_key_values, | ||
| use_cache=use_cache, | ||
| cache_position=cache_position, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| hidden_states = self.norm(hidden_states) | ||
|
|
||
| return Qwen3_5ModelOutputWithPast( | ||
| last_hidden_state=hidden_states, | ||
| past_key_values=past_key_values, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast): | ||
| log_probs: Optional[torch.FloatTensor] = None | ||
| entropy: Optional[torch.FloatTensor] = None | ||
|
|
||
|
|
||
| def forward_with_torch_backend( | ||
| self: Qwen3_5ForConditionalGeneration, | ||
| input_ids: torch.LongTensor = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| temperature: float = 1.0, | ||
| **kwargs, | ||
| ) -> tuple | Qwen3_5CausalLMOutputForPPO: | ||
| from verl.utils.experimental.torch_functional import FusedLinearForPPO | ||
|
|
||
| outputs = self.model(input_ids=input_ids, **kwargs) | ||
| hidden_states = outputs[0] | ||
|
|
||
| # Loss calculations | ||
| if labels is not None: | ||
| rolled_labels = torch.roll(labels, shifts=-1, dims=-1) | ||
| elif input_ids is not None: | ||
| rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) | ||
| else: | ||
| raise RuntimeError( | ||
| "To use forward_with_torch_backend, either labels or input_ids must be provided." | ||
| ) | ||
|
|
||
| fused_linear_for_ppo = FusedLinearForPPO() | ||
| log_probs, entropy = fused_linear_for_ppo.forward( | ||
| hidden_states=hidden_states, | ||
| vocab_weights=self.lm_head.weight, | ||
| input_ids=rolled_labels, | ||
| temperature=temperature, | ||
| ) | ||
| return Qwen3_5CausalLMOutputForPPO( | ||
| log_probs=log_probs, | ||
| entropy=entropy, | ||
| hidden_states=outputs.hidden_states, | ||
| ) | ||
|
|
||
|
|
||
| def forward_with_triton_backend( | ||
| self: Qwen3_5ForConditionalGeneration, | ||
| input_ids: torch.LongTensor = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| temperature: float = 1.0, | ||
| **kwargs, | ||
| ) -> tuple | Qwen3_5CausalLMOutputForPPO: | ||
| from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy | ||
|
|
||
| outputs = self.model(input_ids=input_ids, **kwargs) | ||
| hidden_states = outputs[0] | ||
|
|
||
| # Loss calculations | ||
| if labels is not None: | ||
| rolled_labels = torch.roll(labels, shifts=-1, dims=-1) | ||
| elif input_ids is not None: | ||
| rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) | ||
| else: | ||
| raise RuntimeError( | ||
| "To use forward_with_triton_backend, either labels or input_ids must be provided." | ||
| ) | ||
|
|
||
| log_probs, entropy = linear_cross_entropy( | ||
| hidden_states, | ||
| self.lm_head.weight, | ||
| rolled_labels, | ||
| temperature, | ||
| "none", | ||
| ) | ||
| return Qwen3_5CausalLMOutputForPPO( | ||
| log_probs=log_probs, | ||
| entropy=entropy, | ||
| hidden_states=outputs.hidden_states, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.