diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index fda6d1f933..5fa22e366d 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -596,8 +596,8 @@ def rl_train(argv: Sequence[str], kwargs: dict): # Run evaluation before training if trainer_config.num_test_batches > 0: - # Update vllm with model parameters from checkpoint - rl_cluster.rollout.update_params(nnx.state(actor_model)) + # Update vllm with model parameters from checkpoint, excluding RNG state + rl_cluster.rollout.update_params(nnx.state(actor_model, nnx.Not(nnx.RngState))) (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( trainer_config,