diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index a5a9bb4279..12649f11d5 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-unittest:20260228 + image: trinity-rft-unittest:20260310 cap_add: - SYS_PTRACE pull_policy: never @@ -34,7 +34,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-unittest:20260228 + image: trinity-rft-unittest:20260310 cap_add: - SYS_PTRACE pull_policy: never diff --git a/pyproject.toml b/pyproject.toml index 6cbdf19b3d..ba2a51b191 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,11 +52,12 @@ trinity = "trinity.cli.launcher:main" [project.optional-dependencies] vllm = [ - "vllm>=0.10.2,<=0.16.0,!=0.11.0,!=0.12.0", + "vllm>=0.10.2,<=0.17.1,!=0.11.0,!=0.12.0", # v0.11 has bug when prefix-caching is enabled so we exclude it # v0.12 has a huge performance regression so we exclude it - # v0.10.2 is the most stable version, but we allow up to 0.16.0 for new features - # v0.16.0 is required for transformers>=5.0.0 + # v0.10.2 is the most stable version, but we allow up to 0.17.1 for new features + # For v0.16 and v0.17, the default dependencies require transformers < 5. + # We have patched vLLM to support transformers >= 5.0.0. ] data = [ "py-data-juicer>=1.4.3" @@ -80,10 +81,10 @@ dev = [ "viztracer", ] megatron = [ - "megatron-core[mlm]==0.15.0", + "megatron-core[mlm]>=0.15.0", # if you found "undefined symbol" error in transformer engine # reinstall it with --no-build-isolation and `--no-cache-dir` flag - # "transformer_engine[pytorch]==2.10.0", + "transformer_engine[pytorch]>=2.10.0", # Install mbridge from main branch (unreleased version) # "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612", @@ -109,7 +110,7 @@ mm = [ ] flash_attn = [ - "flash-attn==2.8.1" + "flash-attn>=2.8.1" ] [tool.setuptools.packages.find] diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 86d2e83634..09968b7971 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -63,14 +63,16 @@ def __init__( os.environ["VLLM_CACHE_ROOT"] = os.path.expanduser( f"~/.cache/vllm/{config.bundle_indices}" ) + self.tokenization_kwargs = { + "truncate_prompt_tokens": config.max_prompt_tokens + if config.enable_prompt_truncation + else None + } self.default_sampling_params = vllm.SamplingParams( n=1, temperature=config.temperature, max_tokens=config.max_response_tokens, min_tokens=config.min_response_tokens, - truncate_prompt_tokens=( - config.max_prompt_tokens if config.enable_prompt_truncation else None - ), skip_special_tokens=True, include_stop_str_in_output=False, output_kind=RequestOutputKind.FINAL_ONLY, @@ -78,6 +80,7 @@ def __init__( top_p=config.top_p, top_k=config.top_k, ignore_eos=config.ignore_eos, + **(self.tokenization_kwargs if self.vllm_version <= parse_version("0.16.0") else {}), ) self.ray_namespace = config.ray_namespace self.request_id = 0 @@ -417,11 +420,17 @@ async def sample( async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any: # Send the request to the LLM engine. 
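+        # On vLLM <= 0.16, prompt truncation is configured once via SamplingParams
+        # (see __init__ above); on newer vLLM it must be passed per request through
+        # tokenization_kwargs, which generate_kwargs below takes care of.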
self.request_id += 1 + generate_kwargs = ( + {"tokenization_kwargs": self.tokenization_kwargs} + if self.vllm_version > parse_version("0.16.0") + else {} + ) stream = self.async_llm.generate( request_id=str(self.request_id), prompt=prompt, sampling_params=self._create_sampling_params(**kwargs), lora_request=lora_request, + **generate_kwargs, ) # Consume the stream until the request is finished. diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py index 6fa7b99fe8..ed64fb6c03 100644 --- a/trinity/common/models/vllm_patch/__init__.py +++ b/trinity/common/models/vllm_patch/__init__.py @@ -17,8 +17,22 @@ def vllm_patch(): trf_version = parse_version(transformers.__version__) vllm_version = parse_version(vllm.__version__) - if trf_version >= parse_version("5.0.0") and vllm_version < parse_version("0.16.0"): - raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.") + if trf_version >= parse_version("5.0.0"): + if vllm_version < parse_version("0.16.0"): + raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.") + + from transformers.configuration_utils import PreTrainedConfig + + original_init = PreTrainedConfig.__init__ + + def new_init(self, *args, **kwargs): + if "ignore_keys_at_rope_validation" in kwargs: + kwargs["ignore_keys_at_rope_validation"] = set( + kwargs["ignore_keys_at_rope_validation"] + ) + original_init(self, *args, **kwargs) + + PreTrainedConfig.__init__ = new_init def get_vllm_version(): diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 89b116a954..1496b62154 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -13,10 +13,10 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" version = get_vllm_version() - if version < parse_version("0.10.2") or version > parse_version("0.16.0"): + if version < parse_version("0.10.2") or version >= parse_version("0.18.0"): raise ValueError( f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.2, <= 0.16.0." + "This patch requires vllm version >= 0.10.2, < 0.18.0." 
) is_v0102 = version == parse_version("0.10.2") diff --git a/trinity/common/patch/qwen3_5.py b/trinity/common/patch/qwen3_5.py new file mode 100644 index 0000000000..80e87913d4 --- /dev/null +++ b/trinity/common/patch/qwen3_5.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass +from functools import wraps +from typing import Optional + +import torch +from transformers.models.qwen3_5.modeling_qwen3_5 import ( + BaseModelOutputWithPast, + Cache, + Qwen3_5CausalLMOutputWithPast, + Qwen3_5DynamicCache, + Qwen3_5ForConditionalGeneration, + Qwen3_5ModelOutputWithPast, + TransformersKwargs, + Unpack, + capture_outputs, + create_causal_mask, + merge_with_config_defaults, +) + + +# TODO: may optimize this function +def ulysses_gated_delta_net_forward_decorator(func): + @wraps(func) + def wrapper( + hidden_states: torch.Tensor, + cache_params: Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + from verl.utils.ulysses import ( + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_world_size, + slice_input_tensor, + ) + + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + if ulysses_sp_size > 1: + hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1) + + output = func(hidden_states, cache_params, cache_position, attention_mask) + + if ulysses_sp_size > 1: + output = slice_input_tensor(output, dim=1, padding=False) + return output + + return wrapper + + +@merge_with_config_defaults +@capture_outputs +def qwen35_text_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], +) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. 
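+    # A leading 4th row, when present, carries text-only position ids; the branch
+    # below splits it off into `text_position_ids` and keeps the remaining 3 rows
+    # as the mrope position ids.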
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = ( + linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + ) + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@dataclass +class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_with_torch_backend( + self: Qwen3_5ForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen3_5CausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError( + "To use forward_with_torch_backend, either labels or input_ids must be provided." + ) + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + return Qwen3_5CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_triton_backend( + self: Qwen3_5ForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen3_5CausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError( + "To use forward_with_triton_backend, either labels or input_ids must be provided." 
+    )
+
+    log_probs, entropy = linear_cross_entropy(
+        hidden_states,
+        self.lm_head.weight,
+        rolled_labels,
+        temperature,
+        "none",
+    )
+    return Qwen3_5CausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
+        hidden_states=outputs.hidden_states,
+    )
diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py
index 56dafd4c49..c95a1c2ac7 100644
--- a/trinity/trainer/verl/megatron_checkpoint_manager.py
+++ b/trinity/trainer/verl/megatron_checkpoint_manager.py
@@ -21,18 +21,28 @@
 from collections.abc import Callable
 from dataclasses import asdict
 
+import megatron
 import ray
 import torch
 import torch.distributed
+from megatron.core import dist_checkpointing, mpu
 from megatron.core.transformer.enums import AttnBackend
+from packaging import version
 from transformers import GenerationConfig
 from verl.utils.checkpoint.megatron_checkpoint_manager import (
     MegatronCheckpointManager as OldMegatronCheckpointManager,
 )
-from verl.utils.checkpoint.megatron_checkpoint_manager import logger
+from verl.utils.checkpoint.megatron_checkpoint_manager import (
+    is_non_local,
+    load_dist_checkpointing,
+    logger,
+)
 from verl.utils.fs import local_mkdir_safe
 from verl.utils.logger import log_with_rank
-from verl.utils.megatron.dist_checkpointing import save_dist_checkpointing
+from verl.utils.megatron.dist_checkpointing import (
+    FullyParallelSaveStrategyWrapper,
+    get_default_save_sharded_strategy,
+)
 from verl.utils.megatron_utils import (
     get_dist_checkpoint_path,
     get_hf_model_checkpoint_path,
@@ -43,6 +53,42 @@
 from trinity.trainer.verl.verl_trainer import CheckpointMonitor
 from trinity.utils.log import get_logger
 
+mcore_ge_014 = version.parse(megatron.core.__version__) >= version.parse("0.14.0")
+if not mcore_ge_014:
+    logger.warning(
+        "Detected megatron.core %s; upgrading to >= 0.14.0 is recommended for better checkpoint compatibility",
+        megatron.core.__version__,
+    )
+
+
+# TODO: remove after upgrading to verl > 0.7.0; https://github.com/verl-project/verl/pull/5154
+def save_dist_checkpointing(
+    sharded_state_dict,
+    ckpt_path,
+    async_save=False,
+    content_metadata=None,
+):
+    validate_sharding_integrity = True
+    # Get checkpointing strategies
+    save_strategy = get_default_save_sharded_strategy("torch_dist")
+    save_strategy = FullyParallelSaveStrategyWrapper(
+        save_strategy, mpu.get_data_parallel_group(with_context_parallel=True)
+    )
+
+    # Save model sharded state dicts; content_metadata is only accepted by megatron.core >= 0.14.0, see
+    # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.14.0/megatron/core/optimizer/distrib_optimizer.py#L1109-L1123
+    save_kwargs = dict(
+        sharded_strategy=save_strategy,
+        async_sharded_save=async_save,
+        validate_access_integrity=validate_sharding_integrity,
+    )
+    if content_metadata is not None and mcore_ge_014:
+        save_kwargs["content_metadata"] = content_metadata
+
+    return dist_checkpointing.save(sharded_state_dict, ckpt_path, **save_kwargs)
+
 
 class MegatronCheckpointManager(OldMegatronCheckpointManager):
     """
@@ -71,6 +117,256 @@ def __init__(
         self.latest_extra_state_save_step = None
         self.latest_hf_model_save_step = None
 
+    # TODO: remove after upgrading to verl > 0.7.0; https://github.com/verl-project/verl/pull/5154
+    def generate_state_dict(
+        self,
+        generate_model: bool = True,
+        generate_optimizer: bool = True,
+        generate_extra: bool = True,
+        is_loading: bool = False,
+        metadata: dict | None = None,
+    ):
+        # For save dist checkpointing
+        state_dict = {}
+        base_metadata = 
metadata or self._build_sharded_state_dict_metadata()
+
+        # Always generate the model state dict: the optimizer's sharded_state_dict
+        # below needs it, and it is popped at the end when generate_model is False.
+        # All ranks save the model to reduce memory pressure; note that the sharded
+        # state dict is collected across DP groups, which itself costs memory.
+        for vpp_rank, model in enumerate(self.model):
+            if len(self.model) > 1:
+                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
+                key = f"model{vpp_rank}"
+            else:
+                key = "model"
+            if hasattr(model, "module"):
+                model = model.module
+
+            # GPTModel's sharded_state_dict requires metadata["dp_cp_group"] when MTP is enabled
+            model_metadata = dict(base_metadata)
+            model_metadata["dp_cp_group"] = mpu.get_data_parallel_group(with_context_parallel=True)
+            kwargs = {"metadata": model_metadata}
+            state_dict[key] = model.sharded_state_dict(**kwargs)
+
+        # Optimizer State Dict
+        if generate_optimizer:
+            torch.distributed.barrier()
+            sharded_state_dict_kwargs = {"is_loading": is_loading}
+            # metadata is only accepted by megatron.core >= 0.14.0, see
+            # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.14.0/megatron/core/optimizer/distrib_optimizer.py#L1109-L1123
+            if base_metadata is not None and mcore_ge_014:
+                sharded_state_dict_kwargs["metadata"] = base_metadata
+            optimizer_sharded_states = self.optimizer.sharded_state_dict(
+                state_dict, **sharded_state_dict_kwargs
+            )
+            state_dict["optimizer"] = optimizer_sharded_states
+
+            if self.lr_scheduler is not None:
+                lr_state_dict = self.lr_scheduler.state_dict()
+                state_dict["lr_scheduler"] = lr_state_dict
+
+        if not generate_model:
+            # Model entries were only needed to build the optimizer state dict above;
+            # drop them all, including the per-stage "model{vpp_rank}" keys.
+            for key in [k for k in state_dict if k.startswith("model")]:
+                state_dict.pop(key, None)
+
+        # RNG States State Dict
+        if generate_extra:
+            torch.distributed.barrier()
+            rng_state = self.get_rng_state()
+            state_dict["rng_state"] = rng_state
+
+        return state_dict
+
+    # TODO: remove after upgrading to verl > 0.7.0; https://github.com/verl-project/verl/pull/5154
+    def _build_sharded_state_dict_metadata(self) -> dict:
+        """Builds metadata used for sharded_state_dict versioning.
+
+        The whole content metadata is passed to the model and optimizer
+        ``sharded_state_dict`` methods and therefore affects only the logic behind
+        sharded_state_dict creation.
+        The content metadata should be minimalistic, ideally flat (or with a single nesting level)
+        and with semantically meaningful flag names (e.g. `distrib_optim_sharding_type`).
+        In particular, a simple integer (or SemVer) versioning flag (e.g. `metadata['version'] = 3.4`)
+        is discouraged, because the metadata is shared by all models and optimizers and it is
+        practically impossible to enforce a linearly increasing versioning for this whole space.
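+
+        Example of the metadata produced below (distributed optimizer on
+        megatron.core >= 0.14.0 with default config flags)::
+
+            {"distrib_optim_sharding_type": "dp_reshardable",
+             "singleton_local_shards": False,
+             "chained_optim_avoid_prefix": True}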
+        """
+        metadata: dict = {}
+
+        if not mcore_ge_014:
+            # For backward compatibility with Megatron core < v0.14.0
+            if self.use_distributed_optimizer:
+                metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space"
+            return metadata
+
+        if self.use_distributed_optimizer:
+            megatron_config = getattr(self.config, self.role, self.config).megatron
+            dist_ckpt_optim_fully_reshardable = megatron_config.dist_ckpt_optim_fully_reshardable
+            distrib_optim_fully_reshardable_mem_efficient = (
+                megatron_config.distrib_optim_fully_reshardable_mem_efficient
+            )
+            if dist_ckpt_optim_fully_reshardable:
+                metadata["distrib_optim_sharding_type"] = "fully_reshardable"
+                metadata[
+                    "distrib_optim_fully_reshardable_mem_efficient"
+                ] = distrib_optim_fully_reshardable_mem_efficient
+            else:
+                metadata["distrib_optim_sharding_type"] = "dp_reshardable"
+
+        metadata["singleton_local_shards"] = False
+        metadata["chained_optim_avoid_prefix"] = True
+        return metadata
+
+    # TODO: remove after upgrading to verl > 0.7.0; https://github.com/verl-project/verl/pull/5154
+    def load_checkpoint(  # noqa: C901
+        self, local_path: str, hdfs_path: str = None, del_local_after_load=False
+    ):
+        if local_path is not None:
+            assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist."
+
+        # Allow-list optimizer classes so torch's safe serialization can load the optimizer dist_ckpt
+        try:
+            import transformer_engine
+
+            torch.serialization.add_safe_globals([torch.optim.AdamW])
+            torch.serialization.add_safe_globals(
+                [transformer_engine.pytorch.optimizers.fused_adam.FusedAdam]
+            )
+        except Exception:
+            pass
+
+        dist_checkpoint_path = get_dist_checkpoint_path(local_path)
+
+        load_content_metadata = getattr(dist_checkpointing, "load_content_metadata", None)
+        if load_content_metadata is None:
+            # For backward compatibility
+            sharded_sd_metadata = None
+        else:
+            sharded_sd_metadata = load_content_metadata(checkpoint_dir=dist_checkpoint_path)
+        if sharded_sd_metadata is None:
+            if self.use_distributed_optimizer:
+                # Backward-compatibility with old checkpoints which don't have content versioning
+                # Can be removed after ending support for MLM optimizer checkpoints with MCore < v0.13
+                # (for MCore v0.13+ checkpoints `sharded_sd_metadata is not None`)
+                sharded_sd_metadata = {
+                    "distrib_optim_sharding_type": "fully_sharded_model_space",
+                }
+            else:
+                sharded_sd_metadata = self._build_sharded_state_dict_metadata()
+
+        # Get State Dict for loading
+        sharded_state_dict = self.generate_state_dict(
+            self.should_load_model and self.use_dist_checkpointing,
+            self.should_load_optimizer,
+            self.should_load_extra,
+            is_loading=True,
+            metadata=sharded_sd_metadata,
+        )
+        log_with_rank(
+            f"Generated state dict for loading: {sharded_state_dict.keys()}",
+            rank=self.rank,
+            logger=logger,
+        )
+
+        # Load Dist Checkpointing
+        state_dict = load_dist_checkpointing(
+            sharded_state_dict=sharded_state_dict,
+            ckpt_dir=dist_checkpoint_path,
+        )
+
+        if self.should_load_model and self.use_dist_checkpointing:
+            assert "model" in state_dict or any(
+                f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model))
+            ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
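+            # One entry per (virtual) pipeline stage: plain "model" for a single
+            # model chunk, "model{vpp_rank}" when virtual pipeline parallelism is on.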
+            for vpp_rank, model in enumerate(self.model):
+                if len(self.model) == 1:
+                    model_state_dict = state_dict["model"]
+                else:
+                    assert (
+                        f"model{vpp_rank}" in state_dict
+                    ), f"model{vpp_rank} not found in state_dict"
+                    model_state_dict = state_dict[f"model{vpp_rank}"]
+                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
+                self.model[vpp_rank].load_state_dict(model_state_dict)
+            log_with_rank(
+                f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger
+            )
+
+        # Skip HF checkpoint loading if PEFT is used
+        elif self.should_load_model and self.use_hf_checkpoint and self.peft_cls is None:
+            hf_model_path = get_hf_model_checkpoint_path(local_path)
+            if self.vanilla_bridge:
+                self.bridge.load_weights(self.model, hf_model_path)
+            else:
+                self.bridge.load_hf_weights(self.model, hf_model_path)
+            log_with_rank(
+                f"Loaded HF model checkpoint from {hf_model_path} with bridge",
+                rank=self.rank,
+                logger=logger,
+            )
+        # Load PEFT adapter checkpoint if available
+        if self.should_load_model and self.peft_cls is not None:
+            adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint")
+            if os.path.exists(adapter_ckpt_path):
+                from verl.utils.megatron_peft_utils import load_adapter_checkpoint
+
+                # TODO: a better format for the adapter checkpoint, waiting for megatron-bridge support
+
+                load_adapter_checkpoint(
+                    self.model,
+                    adapter_ckpt_path,
+                )
+                log_with_rank(
+                    f"Loaded adapter checkpoint from {adapter_ckpt_path}",
+                    rank=self.rank,
+                    logger=logger,
+                )
+            else:
+                log_with_rank(
+                    f"PEFT config is set but no adapter checkpoint found at {adapter_ckpt_path}",
+                    rank=self.rank,
+                    logger=logger,
+                )
+
+        if self.should_load_optimizer:
+            assert (
+                "optimizer" in state_dict
+            ), f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
+            optimizer_state_dict = state_dict["optimizer"]
+            self.optimizer.load_state_dict(optimizer_state_dict)
+            log_with_rank(
+                f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger
+            )
+            if self.use_checkpoint_opt_param_scheduler:
+                assert "lr_scheduler" in state_dict, (
+                    f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file "
+                    f"{local_path}."
+                )
+                lr_scheduler_state_dict = state_dict["lr_scheduler"]
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
+                    log_with_rank(
+                        f"Loaded LR scheduler checkpoint from {local_path}",
+                        rank=self.rank,
+                        logger=logger,
+                    )
+
+        if self.should_load_extra:
+            assert (
+                "rng_state" in state_dict
+            ), f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
+            rng_state = state_dict["rng_state"]
+            self.load_rng_states(rng_state)
+            log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger)
+
+        if del_local_after_load:
+            try:
+                if is_non_local(local_path):
+                    os.remove(local_path)
+            except Exception as e:
+                log_with_rank(
+                    f"Failed to remove the local checkpoint file after loading; exception {e} will be ignored",
+                    rank=self.rank,
+                    logger=logger,
+                )
+
     def _save_state_dict(self, local_path, global_step) -> bool:
         """
         Save the model state dict to the specified local path.
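Note: both `_save_state_dict` branches below build `sharded_sd_metadata` once and hand it to `save_dist_checkpointing`, which forwards it to `dist_checkpointing.save` only on megatron.core >= 0.14.0. A minimal, self-contained sketch of that version gate; the helper name and flag values are illustrative only, not part of this patch:

    from packaging import version

    def build_save_kwargs(mcore_version: str, content_metadata: dict | None) -> dict:
        # dist_checkpointing.save() accepts content_metadata only on mcore >= 0.14.0;
        # on older cores the metadata is silently dropped and the checkpoint then
        # carries no content versioning.
        kwargs: dict = {"async_sharded_save": False, "validate_access_integrity": True}
        if content_metadata is not None and version.parse(mcore_version) >= version.parse("0.14.0"):
            kwargs["content_metadata"] = content_metadata
        return kwargs

    print(build_save_kwargs("0.13.1", {"distrib_optim_sharding_type": "dp_reshardable"}))
    print(build_save_kwargs("0.14.0", {"distrib_optim_sharding_type": "dp_reshardable"}))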
@@ -93,8 +389,12 @@ def _save_state_dict(self, local_path, global_step) -> bool: # together in a state dict, we save them in one time if self.use_dist_checkpointing: # Generate state dict for saving + sharded_sd_metadata = self._build_sharded_state_dict_metadata() state_dict = self.generate_state_dict( - self.should_save_model, self.should_save_optimizer, self.should_save_extra + self.should_save_model, + self.should_save_optimizer, + self.should_save_extra, + metadata=sharded_sd_metadata, ) # log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) # for vpp_rank, model in enumerate(self.model): @@ -110,6 +410,7 @@ def _save_state_dict(self, local_path, global_step) -> bool: sharded_state_dict=state_dict, ckpt_path=dist_checkpoint_path, async_save=self.checkpoint_config.async_save, + content_metadata=sharded_sd_metadata, ) # Synchronize all async save requests @@ -123,10 +424,12 @@ def _save_state_dict(self, local_path, global_step) -> bool: self.use_hf_checkpoint ), "When not using distributed checkpointing, use_hf_checkpoint should be True." # Generate optimizer and exra state dicts + sharded_sd_metadata = self._build_sharded_state_dict_metadata() state_dict = self.generate_state_dict( generate_model=False, generate_optimizer=self.should_save_optimizer, generate_extra=self.should_save_extra, + metadata=sharded_sd_metadata, ) # Save optimizer and extra states to local path # Start Async save if enabled @@ -134,6 +437,7 @@ def _save_state_dict(self, local_path, global_step) -> bool: sharded_state_dict=state_dict, ckpt_path=dist_checkpoint_path, async_save=self.checkpoint_config.async_save, + content_metadata=sharded_sd_metadata, ) # Synchronize all async save requests diff --git a/trinity/trainer/verl/monkey_patch.py b/trinity/trainer/verl/monkey_patch.py index d0ba880a6c..28105f05c0 100644 --- a/trinity/trainer/verl/monkey_patch.py +++ b/trinity/trainer/verl/monkey_patch.py @@ -1,10 +1,77 @@ +import importlib import sys +from typing import Dict, Optional, Set import torch from transformers.modeling_utils import PreTrainedModel from trinity.utils.log import get_logger +# Map model types to their specific implementation modules. +# To extend support for a new model, simply add an entry here. +MODEL_TYPE_TO_MODULE_MAP: Dict[str, str] = { + "qwen2_5_vl": "verl.models.transformers.qwen2_vl", + "qwen2_vl": "verl.models.transformers.qwen2_vl", + "qwen3_vl": "verl.models.transformers.qwen3_vl", + "qwen3_vl_moe": "verl.models.transformers.qwen3_vl", + "qwen3_5": "trinity.common.patch.qwen3_5", + "qwen3_5_moe": "trinity.common.patch.qwen3_5", + "glm4v": "verl.models.transformers.glm4v", +} + +DEFAULT_MODULE_PATH = "verl.models.transformers.dense_common" +VALID_BACKENDS: Set[str] = {"triton", "torch"} + + +# modified from verl.models.transformers.monkey_patch.patch_forward_with_backends +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: Optional[str] = None, +) -> None: + """ + Monkey-patch the model's forward method with optimized backend implementations. + + Args: + model: The model to patch. + use_fused_kernels: Whether to enable fused kernels. + fused_kernels_backend: The backend to use ('triton' or 'torch'). + """ + logger = get_logger(__name__) + + # 1. Validation & Early Exit + if not use_fused_kernels: + return + + if fused_kernels_backend not in VALID_BACKENDS: + logger.warning( + f"Skipping patch for {model.__class__.__name__}: " + f"Invalid backend '{fused_kernels_backend}'. 
Choose from {VALID_BACKENDS}." + ) + return + + # 2. Resolve Module Path + model_type: str = getattr(model.config, "model_type", None) + module_path = MODEL_TYPE_TO_MODULE_MAP.get(model_type, DEFAULT_MODULE_PATH) + + # 3. Dynamic Import + try: + backend_module = importlib.import_module(module_path) + except ImportError as e: + logger.error(f"Failed to import {module_path} for {model.__class__.__name__}: {e}") + return + + # 4. Select and Apply Forward Function + func_name = f"forward_with_{fused_kernels_backend}_backend" + patched_forward = getattr(backend_module, func_name, None) + + if patched_forward is None: + logger.error(f"Function '{func_name}' not found in {module_path}") + return + + model.__class__.forward = patched_forward + logger.info(f"Applied {fused_kernels_backend.upper()} backend for {model.__class__.__name__}") + # modified from verl.models.transformers.monkey_patch.apply_monkey_patch def apply_monkey_patch( # noqa: C901 @@ -33,7 +100,6 @@ def apply_monkey_patch( # noqa: C901 """ from verl.models.transformers.monkey_patch import ( _ulysses_flash_attention_forward, - patch_forward_with_backends, patch_vlm_for_ulysses_input_slicing, ) from verl.utils.import_utils import is_trl_available @@ -127,6 +193,53 @@ def state_dict(self, *args, **kwargs): patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel) patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel) + elif model.config.model_type in ["qwen3_5", "qwen3_5_moe"]: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeTextModel, + ) + + # Step 1: bug fix in transformers==5.2.0 + # see https://github.com/huggingface/transformers/pull/44382 + if "Qwen3_5TextDecoderLayer" in model._no_split_modules: + model._no_split_modules.remove("Qwen3_5TextDecoderLayer") + model.model._no_split_modules.remove("Qwen3_5TextDecoderLayer") + if "Qwen3_5MoeTextDecoderLayer" in model._no_split_modules: + model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") + model.model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") + + # see https://github.com/huggingface/transformers/pull/44399 + if is_transformers_version_in_range(max_version="5.3.0"): + from trinity.common.patch.qwen3_5 import qwen35_text_forward + + Qwen3_5TextModel.forward = qwen35_text_forward + Qwen3_5MoeTextModel.forward = qwen35_text_forward + + # Step 2: patch input for multimodal sequence parallelism + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(Qwen3_5TextModel) + patch_vlm_for_ulysses_input_slicing(Qwen3_5MoeTextModel) + + from trinity.common.patch.qwen3_5 import ( + ulysses_gated_delta_net_forward_decorator, + ) + + for layer in model.model.language_model.layers: + if layer.layer_type == "linear_attention": + layer.linear_attn.forward = ulysses_gated_delta_net_forward_decorator( + layer.linear_attn.forward + ) + + # Step 3: patch verl.utils.flops_counter + from verl.utils.flops_counter import ESTIMATE_FUNC, _estimate_qwen2_flops + + ESTIMATE_FUNC.update( + { + "qwen3_5": _estimate_qwen2_flops, + "qwen3_5_moe": _estimate_qwen2_flops, + } + ) + elif model.config.model_type == "glm4v": # Step 1: patch model to support image-text mixed data diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index e047c8f347..b0fb804f42 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -6,7 +6,7 @@ import numpy as np import torch -from transformers import ProcessorMixin +from transformers import PreTrainedModel 
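+# get_rope_index is now looked up on a (weight-free) PreTrainedModel instead of the
+# processor; the processor-based hf_processor helper is removed further below.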
from verl import DataProto from verl.trainer.ppo.metric_utils import _compute_response_info from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path @@ -23,7 +23,7 @@ def to_data_proto( - experiences: List[Experience], pad_token_id: int, processor: ProcessorMixin, logger: Logger + experiences: List[Experience], pad_token_id: int, model: PreTrainedModel, logger: Logger ) -> DataProto: # noqa: C901 """Convert List[Experience] to verl DataProto.""" assert len(experiences) > 0, "No experiences provided." @@ -84,7 +84,8 @@ def to_data_proto( if all(getattr(exp, attr, None) is not None for exp in experiences): batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length) - if processor is not None: + if hasattr(model, "get_rope_index"): + # used for multi-modal model import inspect # Adapted from verl/experimental/agent_loop/agent_loop.py @@ -94,13 +95,23 @@ def to_data_proto( input_ids = batch_dict["input_ids"][idx].unsqueeze(0) attention_mask = batch_dict["attention_mask"][idx].unsqueeze(0) - get_rope_index_sig = inspect.signature(processor.get_rope_index) + get_rope_index_sig = inspect.signature(model.get_rope_index) get_rope_index_kwargs = {} - for key in mm_inputs.keys(): - if key in get_rope_index_sig.parameters: - get_rope_index_kwargs[key] = mm_inputs[key] - - vision_position_ids, _ = processor.get_rope_index( + for key in get_rope_index_sig.parameters: + if key in {"self", "input_ids", "attention_mask", "kwargs"}: + continue + elif key == "mm_token_type_ids": + pad_data = torch.zeros_like(input_ids) + if key in mm_inputs: + data = mm_inputs.pop(key) + start = max_prompt_length - exp.prompt_length + end = start + data.size(1) + pad_data[:, start:end] = data + get_rope_index_kwargs[key] = pad_data + else: + get_rope_index_kwargs[key] = mm_inputs.get(key, None) + + vision_position_ids, _ = model.get_rope_index( input_ids=input_ids, attention_mask=attention_mask, **get_rope_index_kwargs, @@ -253,67 +264,6 @@ def get_latest_hf_checkpoint_path(config: Config): return hf_checkpoint_dir -# modified from verl/utils/tokenizer.py:hf_processor -# bug fix for processor -def hf_processor(name_or_path, **kwargs): - """Create a huggingface processor to process multimodal data. - - Args: - name_or_path (str): The name of the processor. - - Returns: - transformers.ProcessorMixin: The pretrained processor. 
- """ - import types - import warnings - - from transformers import AutoConfig, AutoProcessor - - try: - processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) - config = AutoConfig.from_pretrained(name_or_path, **kwargs) - - # Bind vlm model's get_rope_index method to processor - processor.config = config - match processor.__class__.__name__: - case "Qwen2VLProcessor": - from transformers.models.qwen2_vl import Qwen2VLModel - - processor.get_rope_index = types.MethodType(Qwen2VLModel.get_rope_index, processor) - case "Qwen2_5_VLProcessor": - from transformers.models.qwen2_5_vl import Qwen2_5_VLModel - - processor.get_rope_index = types.MethodType( - Qwen2_5_VLModel.get_rope_index, processor - ) - case "Qwen3VLProcessor": - from transformers.models.qwen3_vl import Qwen3VLModel - - processor.get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, processor) - case "Glm4vImageProcessor" | "Glm4vProcessor": - from transformers.models.glm4v import Glm4vModel - - processor.get_rope_index = types.MethodType(Glm4vModel.get_rope_index, processor) - case "Glm46VProcessor": - from transformers.models.glm46v import Glm46VModel - - processor.get_rope_index = types.MethodType(Glm46VModel.get_rope_index, processor) - case _: - raise ValueError(f"Unsupported processor type: {processor.__class__.__name__}") - except Exception as e: - processor = None - # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid - # silent failure - warnings.warn( - f"Failed to create processor: {e}. This may affect multimodal processing", stacklevel=1 - ) - # Avoid load tokenizer, see: - # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 - if processor is not None and "Processor" not in processor.__class__.__name__: - processor = None - return processor - - # modified from verl/utils/fsdp_utils.py:apply_fsdp2 # bug fix for transformers v5 def apply_fsdp2(model, fsdp_kwargs, config): diff --git a/trinity/trainer/verl/verl_config.py b/trinity/trainer/verl/verl_config.py index 689b3231b4..040d7e089e 100644 --- a/trinity/trainer/verl/verl_config.py +++ b/trinity/trainer/verl/verl_config.py @@ -129,6 +129,8 @@ class MegatronConfig: use_distributed_optimizer: bool = True use_dist_checkpointing: bool = False dist_checkpointing_path: Optional[str] = None + dist_ckpt_optim_fully_reshardable: bool = False + distrib_optim_fully_reshardable_mem_efficient: bool = False seed: int = 42 override_ddp_config: dict = field(default_factory=dict) override_transformer_config: OverrideTransformerConfig = field( diff --git a/trinity/trainer/verl/verl_trainer.py b/trinity/trainer/verl/verl_trainer.py index 31582f0e34..b22af2715f 100644 --- a/trinity/trainer/verl/verl_trainer.py +++ b/trinity/trainer/verl/verl_trainer.py @@ -11,6 +11,8 @@ import ray import torch +import transformers +from accelerate import init_empty_weights from omegaconf import OmegaConf from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import ( @@ -24,7 +26,7 @@ Role, create_colocated_worker_cls, ) -from verl.utils import hf_tokenizer +from verl.utils import hf_processor, hf_tokenizer from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.debug import marked_timer from verl.utils.fs import copy_local_path_from_hdfs @@ -37,7 +39,7 @@ from trinity.common.constants import SaveStrategy from trinity.common.experience import Experience from trinity.trainer.trainer import TrainEngineWrapper 
-from trinity.trainer.verl.utils import compute_data_metrics, hf_processor, to_data_proto +from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto from trinity.utils.log import get_logger @@ -198,6 +200,15 @@ def __init__( tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) # processor for multimodal LLM, could be None processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + hf_config = transformers.AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code + ) + with init_empty_weights(): + self.empty_model = transformers.AutoModel.from_config( + hf_config, trust_remote_code=trust_remote_code + ) + from verl.single_controller.ray import RayWorkerGroup ray_worker_group_cls = RayWorkerGroup @@ -447,7 +458,9 @@ async def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 - batch = to_data_proto(batch_exps, self.tokenizer.pad_token_id, self.processor, self.logger) + batch = to_data_proto( + batch_exps, self.tokenizer.pad_token_id, self.empty_model, self.logger + ) metrics = {} self.global_steps += 1 timing_raw = {}
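Note: the `init_empty_weights` block above exists only so `to_data_proto` can call structure-dependent helpers such as `get_rope_index`; no parameters are allocated or loaded. A minimal sketch of the trick; the checkpoint id is an arbitrary example, and it assumes a transformers version where `get_rope_index` lives on the VL model class:

    import transformers
    from accelerate import init_empty_weights

    # Only the config is downloaded; the model skeleton is built on the meta device,
    # so this works even for models far larger than available memory.
    config = transformers.AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    with init_empty_weights():
        empty_model = transformers.AutoModel.from_config(config)

    print(type(empty_model).__name__, hasattr(empty_model, "get_rope_index"))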