223 changes: 222 additions & 1 deletion agentlightning/verl/trainer.py
@@ -200,6 +200,10 @@ def _validate(self):
return test_metrics

def _train_step(self, batch_dict: dict) -> dict:
# Check if RAFT mode is enabled
if self.config.algorithm.adv_estimator == "raft":
return self._train_step_raft(batch_dict)

# Isolate in a separate method to automatically recycle the variables before validation.
batch: DataProto = DataProto.from_single_dict(batch_dict)
metrics = {}
@@ -388,6 +392,223 @@ def _train_step(self, batch_dict: dict) -> dict:

return metrics

def _train_step_raft(self, batch_dict: dict) -> dict:
"""
RAFT training step: a simplified training loop that trains only on samples with reward r=1.

RAFT (Rejection sampling Adaptive Fine-Tuning) differs from GRPO/PPO in that it:
1. Uses rejection sampling: only samples with reward r=1 are kept
2. Uses a simple loss: standard cross-entropy (NLL) instead of an advantage-weighted loss
3. Needs no critic: no value function estimation is required
4. Needs no advantage: no advantage or GAE computation is required
"""
batch: DataProto = DataProto.from_single_dict(batch_dict)
metrics = {}
timing_raw = {}

with _timer("step", timing_raw):
# When agent mode is enabled, we read the batch as it is.
gen_batch = batch

# Generate rollouts and collect data
with _timer("gen", timing_raw):
self.async_rollout_manager.wake_up()
self.agent_mode_daemon.set_up_data_and_server(
gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
)
self.agent_mode_daemon.run_until_all_finished()
batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
max_prompt_length=self.config.data.max_prompt_length,
max_response_length=self.config.data.max_response_length,
device=gen_batch.batch["fake_ids"].device,
)
metrics.update(agent_metrics)
self.agent_mode_daemon.clear_data_and_server()
self.async_rollout_manager.sleep()

# RAFT Step 1: Rejection Sampling - Filter to keep only r=1 samples
with _timer("rejection_sampling", timing_raw):
# Extract rewards from token_level_scores (sum to get sequence-level reward)
# The reward is stored at the last token position in token_level_scores
sequence_rewards = batch.batch["token_level_scores"].sum(dim=-1) # (batch_size,)

# Binary reward: 1.0 for success, 0.0 for failure
# In RAFT, we only keep samples with reward == 1.0
is_positive_reward = (sequence_rewards == 1.0)
positive_indices = is_positive_reward.nonzero(as_tuple=True)[0]

# Log rejection sampling statistics
n_total = len(batch)
n_positive = len(positive_indices)
n_rejected = n_total - n_positive
metrics["raft/n_total_samples"] = n_total
metrics["raft/n_positive_samples"] = n_positive
metrics["raft/n_rejected_samples"] = n_rejected
metrics["raft/rejection_rate"] = n_rejected / n_total if n_total > 0 else 0.0
metrics["raft/positive_rate"] = n_positive / n_total if n_total > 0 else 0.0

# If no positive samples, skip this training step
if n_positive == 0:
metrics["raft/loss"] = 0.0
metrics["raft/skipped_no_positive_samples"] = 1
return metrics

# Filter batch to keep only positive samples
positive_batch = batch[positive_indices.cpu().tolist()]

# RAFT Step 2: Compute response mask for the filtered batch
positive_batch.batch["response_mask"] = compute_response_mask(positive_batch)

# Set uid (required by update_actor, similar to GRPO)
# uid is used by algorithms like GRPO and should be aligned with the data id
if "data_id_list" in positive_batch.non_tensor_batch:
positive_batch.non_tensor_batch["uid"] = positive_batch.non_tensor_batch["data_id_list"]

# Drop samples with prompts that are too long
keep_indices = (~positive_batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
metrics["raft/n_triplets_prompt_too_long"] = (
positive_batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
)
if len(keep_indices) == 0:
metrics["raft/loss"] = 0.0
metrics["raft/skipped_all_dropped"] = 1
return metrics
positive_batch = positive_batch[keep_indices]

# Shuffle and truncate to a multiple of the mini batch size for efficient training
mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
n_transition = len(positive_batch)
random_indices = list(range(n_transition))
random.shuffle(random_indices)
positive_batch.reorder(torch.tensor(random_indices).type(torch.int32))
n_remained_transition = n_transition // mini_batch_size * mini_batch_size
positive_batch = positive_batch[list(range(n_remained_transition))]
metrics["raft/n_triplets_dropped_remainder"] = n_transition - n_remained_transition

# Balance batch if enabled
if self.config.trainer.balance_batch:
self._balance_batch(positive_batch, metrics=metrics)

# RAFT Step 3: Prepare batch for RAFT loss computation
# No critic values are used; advantages/returns are filled with neutral values in Step 5
raft_batch = positive_batch
max_response_length = raft_batch.batch["responses"].shape[-1]

# RAFT Step 4: Prepare batch for actor update
# Need to compute old_log_probs and set required meta_info fields
with _timer("prepare_raft_batch", timing_raw):
# Ensure uid is set (may have been lost during filtering)
if "data_id_list" in raft_batch.non_tensor_batch:
raft_batch.non_tensor_batch["uid"] = raft_batch.non_tensor_batch["data_id_list"]

# Compute global_token_num (required by update_actor)
raft_batch.meta_info["global_token_num"] = torch.sum(raft_batch.batch["attention_mask"], dim=-1).tolist()

# Pad batch for distributed training before computing log_probs
raft_batch, pad_size_prep = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)

# Compute old_log_probs (required by update_actor, similar to GRPO)
# This is needed even for RAFT because update_actor expects this field
old_log_prob = self.actor_rollout_wg.compute_log_prob(raft_batch)
entropys = old_log_prob.batch["entropys"]
response_masks = raft_batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
raft_batch = raft_batch.union(old_log_prob)

# Set required meta_info fields (similar to GRPO)
raft_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
# Temperature is required by update_actor (from config or default 0.7)
raft_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.get("temperature", 0.7)

# Unpad before setting advantages
raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep)

# RAFT Step 5: Pure SFT update
# Use standard cross-entropy loss via PPO with advantages=1.0 and disabled clipping
# Note: PPO loss with advantages=1.0 and no clipping becomes equivalent to SFT
with _timer("update_actor_sft", timing_raw):
# Set advantages to 1.0 (no advantage weighting, pure SFT)
# This makes the PPO loss equivalent to standard cross-entropy when clipping is disabled
raft_batch.batch["advantages"] = torch.ones(
(len(raft_batch), max_response_length),
device=raft_batch.batch["input_ids"].device,
dtype=torch.float32
)
raft_batch.batch["returns"] = raft_batch.batch["advantages"].clone()

# Remove any existing values field (no critic in RAFT)
if "values" in raft_batch.batch:
raft_batch.batch.pop("values")

# Pad again for distributed training before update_actor
raft_batch, pad_size_actor = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)

# Temporarily disable PPO clipping for pure SFT
original_clip_low = self.config.actor_rollout_ref.actor.get("clip_ratio_low", 0.2)
original_clip_high = self.config.actor_rollout_ref.actor.get("clip_ratio_high", 0.3)

# Disable clipping: clip_ratio_low = 1 gives a lower bound of 1 - 1 = 0 and
# clip_ratio_high = 1000 gives an upper bound of 1 + 1000 = 1001, so
# clip(ratio, 0, 1001) leaves any realistic ratio value unrestricted
self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1
self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1000

try:
# Update actor with pure SFT loss
# With advantages=1.0 and clipping disabled, this becomes standard cross-entropy
# This mimics SFTTrainer.compute_loss() behavior
actor_output = self.actor_rollout_wg.update_actor(raft_batch)

The update_actor function that gets called is still the RL one, not one from an "SFTTrainer"? I'm not sure verl even has an SFTTrainer...

Author

verl does not have an SFTTrainer. SFTTrainer inherits from transformers.Trainer and requires the complete HuggingFace Trainer infrastructure, whereas verl uses Ray distributed training and custom worker groups. Directly adopting SFTTrainer would disrupt verl's existing architecture. Additionally, SFTTrainer and verl use different data formats.

Makes sense. Then why can SFT be implemented this way? Will update_actor actually compute the SFT loss as we expect?

Author

Here is the derivation of SFT from PPO.

The standard PPO objective (with clipping) is:
L = -E[ min( rA, clip(r, 1-ϵ, 1+ϵ)A ) ]
where r = exp(log π(a|s) - log π_old(a|s)) is the importance ratio, which usually stays in [0, 2].

To adapt the PPO framework for supervised fine-tuning (SFT), we set the following neutral conditions:

Set the advantages: A = 1
Disable clipping: ϵ = 1, so clip(r, 1-1, 1+1) = clip(r, 0, 2) ≈ r

With A = 1 and clipping effectively disabled, the loss simplifies to:
L = -E[ min( r, clip(r, 1-ϵ, 1+ϵ) ) ] = -E[r]
and since log r = log π(a|s) - log π_old(a|s):
L = -E[ exp( log π(a|s) - log π_old(a|s) ) ]

In SFT we directly optimize the current policy without relying on importance sampling. When the old policy is equal to (or very close to) the current policy, this objective reduces to the negative log-likelihood loss, which is what we want to minimize:
L = -E[ log π(a|s) ]
For language models, this is the standard cross-entropy loss.

In summary, by:

Setting A = 1.0 to remove the advantage weighting
Disabling clipping to remove the PPO clipping mechanism

the loss effectively degenerates to the standard cross-entropy (negative log-likelihood) loss, i.e. the SFT loss. Thus, under these specific conditions, the PPO framework becomes equivalent to standard supervised learning (SFT).
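
As a quick sanity check of this argument, here is a minimal standalone sketch (illustrative numbers only, not verl's actual loss code) showing that the clipped surrogate with A = 1 and an inactive clip range yields the same gradient as the NLL loss when the old policy coincides with the current one:

import torch

# Illustrative per-token log-probs under the current policy
log_prob = torch.tensor([-1.3, -0.7, -2.1], requires_grad=True)
old_log_prob = log_prob.detach().clone()           # old policy == current policy, no grad
advantages = torch.ones_like(old_log_prob)         # A = 1, as in the RAFT step

ratio = torch.exp(log_prob - old_log_prob)         # r = pi / pi_old (== 1 here)
clipped = torch.clamp(ratio, 1 - 1.0, 1 + 1000.0)  # clip range wide enough to be inactive
ppo_loss = -torch.min(ratio * advantages, clipped * advantages).mean()

nll_loss = -log_prob.mean()                        # standard SFT / cross-entropy objective

g_ppo, = torch.autograd.grad(ppo_loss, log_prob, retain_graph=True)
g_nll, = torch.autograd.grad(nll_loss, log_prob)
print(torch.allclose(g_ppo, g_nll))                # True: identical gradients at pi == pi_old

The equivalence is exact only while π_old and π coincide (i.e. right after compute_log_prob); once the actor takes gradient steps within an epoch, the ratio drifts away from 1 and the objective is only approximately the NLL loss.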

actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

# Extract and log the SFT loss
# Use actor loss from update_actor output (same as GRPO)
# Note: update_actor returns "actor/pg_loss" not "actor/loss"
if "actor/pg_loss" in actor_output_metrics:
metrics["raft/loss"] = actor_output_metrics["actor/pg_loss"]
elif "actor/loss" in actor_output_metrics:
metrics["raft/loss"] = actor_output_metrics["actor/loss"]
else:
# Fallback: use a default value if loss not found
metrics["raft/loss"] = 0.0
finally:
# Restore original clipping ratios
self.config.actor_rollout_ref.actor["clip_ratio_low"] = original_clip_low
self.config.actor_rollout_ref.actor["clip_ratio_high"] = original_clip_high

# Log that we're using pure SFT update (like SFTTrainer)
metrics["raft/pure_sft_update"] = 1.0

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with _timer("dump_rollout_generations", timing_raw):
# Unpad for logging
log_batch = unpad_dataproto(raft_batch, pad_size_actor)
inputs = self.tokenizer.batch_decode(log_batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(log_batch.batch["responses"], skip_special_tokens=True)
# Get scores from the filtered batch
log_scores = log_batch.batch["token_level_scores"].sum(dim=-1).cpu().tolist()
self._dump_generations(
inputs=inputs,
outputs=outputs,
scores=log_scores,
reward_extra_infos_dict={},
dump_path=rollout_data_dir,
)

# Compute training metrics
# Note: We skip critic metrics for RAFT since there's no critic
metrics.update(compute_timing_metrics(batch=raft_batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=raft_batch, timing_raw=timing_raw, n_gpus=n_gpus))

return metrics

def fit(self):
logger = Tracking(
project_name=self.config.trainer.project_name,
Expand Down Expand Up @@ -496,4 +717,4 @@ def fit(self):
return

progress_bar.update(1)
self.global_steps += 1
self.global_steps += 1