From c80eddf4e2c9f55a1ceb3616490e381927d4a742 Mon Sep 17 00:00:00 2001 From: SurbhiJainUSC Date: Thu, 30 Apr 2026 00:58:57 +0000 Subject: [PATCH] Exclude RNG state when updating vllm with checkpoint for pre-rl evaluation --- src/maxtext/trainers/post_train/rl/train_rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index fda6d1f933..5fa22e366d 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -596,8 +596,8 @@ def rl_train(argv: Sequence[str], kwargs: dict): # Run evaluation before training if trainer_config.num_test_batches > 0: - # Update vllm with model parameters from checkpoint - rl_cluster.rollout.update_params(nnx.state(actor_model)) + # Update vllm with model parameters from checkpoint, excluding RNG state + rl_cluster.rollout.update_params(nnx.state(actor_model, nnx.Not(nnx.RngState))) (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( trainer_config,