[Refactor] refactor packing in RL train controller and train worker #1393
Conversation
Pull request overview
This PR refactors the packing logic in the RL training controller and worker components to improve token balancing and code organization. The key changes introduce a Karmarkar-Karp algorithm for balanced partitioning, extract helper methods for better code maintainability, and restructure how data batches are distributed across workers.
Key Changes
- Introduces sequence-length balanced partitioning using the Karmarkar-Karp differencing algorithm to better distribute workload across devices (a simplified sketch of the idea follows this list)
- Refactors the worker's `fit` method to accept a nested list structure `list[list[WorkerInputItem]]` instead of a flat list, aligning with the new per-step packing approach
- Extracts reusable helper methods (`_resolve_ray_data`, `_apply_rollout_is_correction`, `_create_padding_sample`, `_pack`, `_balance_split_batch`) to reduce code duplication and improve maintainability
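For orientation, here is a minimal sketch of what sequence-length balanced partitioning does. This is a simplified greedy variant (longest-processing-time first), not the Karmarkar-Karp implementation added in `xtuner/v1/rl/utils.py`; the function name is illustrative.

```python
import heapq


def greedy_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int) -> list[list[int]]:
    """Assign each sequence to the currently lightest partition (LPT heuristic).

    Returns k partitions of indices into ``seqlen_list``. A simpler stand-in for
    the Karmarkar-Karp differencing method, which typically balances more tightly.
    """
    heap = [(0, p) for p in range(k_partitions)]  # (total_tokens, partition_index)
    heapq.heapify(heap)
    partitions: list[list[int]] = [[] for _ in range(k_partitions)]

    # Place the longest sequences first so the final totals stay close.
    for idx in sorted(range(len(seqlen_list)), key=lambda i: seqlen_list[i], reverse=True):
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)
        heapq.heappush(heap, (total + seqlen_list[idx], p))
    return partitions


if __name__ == "__main__":
    print(greedy_seqlen_balanced_partitions([900, 400, 350, 300, 150, 100], k_partitions=2))
    # [[0, 4, 5], [1, 2, 3]] -> token totals 1150 vs 1050
```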
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| xtuner/v1/rl/utils.py | Adds Karmarkar-Karp algorithm implementation with get_seqlen_balanced_partitions function for balanced workload distribution across partitions |
| xtuner/v1/rl/base/worker.py | Refactors fit method to handle nested batch structure, extracts ray data resolution and importance sampling logic into separate methods, adds get_worker_cfg accessor method |
| xtuner/v1/rl/base/controller.py | Major refactoring of packing logic with new balanced splitting, padding creation, and improved data distribution across workers with per-step gradient accumulation support |
xtuner/v1/rl/utils.py (outdated)

```python
# Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/seqlen_balancing.py
def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
    # see: https://en.wikipedia.org/wiki/Largest_differencing_method
    class Set:
```
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
xtuner/v1/rl/utils.py (outdated)

```python
                return len(self.items) < len(other.items)
            return self.items < other.items

    class State:
```
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
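One way to address both comments, assuming the classes are only compared inside heaps and sorts: decorate them with `functools.total_ordering`, which derives `__le__`, `__gt__`, and `__ge__` from `__lt__` plus `__eq__`. A sketch for `Set` (the fields mirror the fragment above, not the full adapted implementation; `State` can be treated the same way):

```python
import functools


@functools.total_ordering
class Set:
    """Partition bucket: running token sum plus the item indices it holds."""

    def __init__(self) -> None:
        self.sum = 0
        self.items: list[int] = []

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Set):
            return NotImplemented
        return self.sum == other.sum and self.items == other.items

    def __lt__(self, other: "Set") -> bool:
        # Same tie-breaking order as the fragment above: sum, then size, then items.
        if self.sum != other.sum:
            return self.sum < other.sum
        if len(self.items) != len(other.items):
            return len(self.items) < len(other.items)
        return self.items < other.items
```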
xtuner/v1/rl/base/controller.py (outdated)

```python
        get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")

        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
        max_packs_per_card = [0] * optimizer_steps
```
rename to max_packed_batch_num_per_step
`max_packs_per_step` is more accurate: the maximum number of packs per step.
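To make the naming discussion concrete, here is a hedged sketch of what the counter tracks; `max_packs_per_step`, `pad_to_max_packs_per_step`, and `make_padding_pack` are illustrative names, not the PR's exact helpers. Every `(rank, step)` bucket is padded up to the largest pack count seen for that step, so all ranks run the same number of forward/backward passes per optimizer step.

```python
def pad_to_max_packs_per_step(
    packed_data_batches: list[list[list[dict]]],  # indexed as [dp_rank][step][pack]
    make_padding_pack,  # callable returning a dummy pack, e.g. a _create_padding_sample-style helper
) -> None:
    optimizer_steps = len(packed_data_batches[0])
    # Largest number of packs any rank holds at each optimizer step.
    max_packs_per_step = [
        max(len(rank_batches[step]) for rank_batches in packed_data_batches)
        for step in range(optimizer_steps)
    ]
    # Pad lighter ranks so every rank issues the same number of
    # forward/backward passes per step (keeps collectives in sync).
    for rank_batches in packed_data_batches:
        for step in range(optimizer_steps):
            while len(rank_batches[step]) < max_packs_per_step[step]:
                rank_batches[step].append(make_padding_pack())
```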
xtuner/v1/rl/base/worker.py (outdated)

```python
        # old logprobs are inplaced updated in compute_actor_logprobs
        loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
        loss_ctx_input_list, metrics = self._apply_rollout_is_correction(
```
Great! The previously long `fit` function is now well structured and much easier to read.
Great PR description! It lets me review in the same order as above : ) Additionally, when you write out the core function call chain corresponding to your original design in the "Key Changes" section, you will find that some high-level functions are missing from the implementation. Unit tests can play the same role sometimes: for example, if you want to write a unit test for the core padding function, you first need to abstract the related code pieces into a function.
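To illustrate the point about testability: once the padding logic is a pure function such as the `pad_to_max_packs_per_step` sketch shown earlier (a hypothetical name, not the PR's actual helper), a unit test needs no Ray workers or controller state. A minimal pytest sketch:

```python
def test_padding_equalizes_packs_per_step():
    # 2 dp ranks, 2 optimizer steps; rank 0 is heavier at step 0.
    batches = [
        [[{"id": 0}, {"id": 1}], [{"id": 2}]],  # rank 0
        [[{"id": 3}], []],                      # rank 1
    ]
    pad_to_max_packs_per_step(batches, make_padding_pack=lambda: {"pad": True})

    for step in range(2):
        pack_counts = {len(rank_batches[step]) for rank_batches in batches}
        assert len(pack_counts) == 1, f"ranks disagree on pack count at step {step}"
```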
xtuner/v1/rl/base/worker.py (outdated)

```python
        del data_batches

        # old logprobs are inplaced updated in compute_actor_logprobs
        loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
```
There is an optimization opportunity here. `self._resolve_ray_data` is a relatively time-consuming operation; it could be overlapped with `self.compute_actor_logprobs` so that the cross-node data loading cost is hidden.
Writing this may be a bit involved, so if you don't want to change it now, you could add a TODO.
Done, added a TODO for now.
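For the TODO, one possible shape of the overlap (a sketch only; `fit_steps_with_prefetch` is an illustrative name, and `_resolve_ray_data` is assumed here to return `(seq_ctx_list, loss_ctx_input_list)` for one step): resolve the next step's ray data on a background thread while the GPU computes logprobs for the current step.

```python
from concurrent.futures import ThreadPoolExecutor


def fit_steps_with_prefetch(worker, data_batches_per_step):
    """Hide cross-node data fetch latency behind logprob computation.

    ``worker`` is assumed to expose ``_resolve_ray_data`` (network/CPU bound)
    and ``compute_actor_logprobs`` (GPU bound), as in the refactored worker.
    """
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(worker._resolve_ray_data, data_batches_per_step[0])
        for step in range(len(data_batches_per_step)):
            seq_ctx_list, loss_ctx_input_list = future.result()
            # Kick off the next step's data resolution before the GPU-heavy call.
            if step + 1 < len(data_batches_per_step):
                future = pool.submit(worker._resolve_ray_data, data_batches_per_step[step + 1])
            loss_ctx_input_list = worker.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
            # ... remaining per-step work (IS correction, loss, backward) ...
```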
xtuner/v1/rl/base/controller.py (outdated)

```python
        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
        optimizer_steps = self.worker_cfg.optimizer_steps

        batches_per_dp_group: list[list[WorkerInputItem]]
```
Some corner cases may not be covered. For example, if `optimizer_steps=16` but there are fewer than 16 data samples, will the code raise an error? I suggest covering cases like this with rigorous unit tests.
done
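As a concrete version of that corner case, a pytest sketch against the greedy partition sketch from earlier in this thread (not the PR's actual `_balance_split_batch`): with far more partitions than samples, the split should not raise, and the empty partitions are what the downstream padding samples have to fill.

```python
def test_more_partitions_than_samples():
    seqlens = [512, 256, 128]  # only 3 samples vs. optimizer_steps=16 worth of partitions
    parts = greedy_seqlen_balanced_partitions(seqlens, k_partitions=16)

    assert len(parts) == 16
    assert sorted(i for p in parts for i in p) == [0, 1, 2]  # every sample placed exactly once
    assert sum(1 for p in parts if not p) == 13              # the rest need padding downstream
```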
xtuner/v1/rl/base/controller.py (outdated)

```diff
             handles.append(
                 worker.fit.remote(  # type: ignore[attr-defined]
-                    data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
+                    data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
```
The `::dp_size` logic cannot be removed.
done
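For reference, the slicing semantics the reviewer is pointing at, shown on a toy layout (the step-major ordering below is an assumption for illustration): the start index `worker_idx // data_replicate_size` selects a DP rank, and the `::dp_size` stride then collects that rank's batch for every optimizer step, which a plain single index would drop.

```python
# Hypothetical flat layout: one packed batch per (optimizer step, dp rank),
# stored step-major, i.e. index = step * dp_size + dp_rank.
dp_size, optimizer_steps = 4, 3
packed = [f"step{s}/rank{r}" for s in range(optimizer_steps) for r in range(dp_size)]

dp_rank = 2  # e.g. worker_idx // data_replicate_size
print(packed[dp_rank::dp_size])
# ['step0/rank2', 'step1/rank2', 'step2/rank2']
# packed[dp_rank] alone would return only the step-0 batch for this rank.
```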
Motivation
The current xtuner data distribution mechanism has a pack allocation issue that leads to unstable training step counts and hurts training effectiveness.
The data distribution pipeline consists of three stages:
1. Split `data_batch` by token count, creating one pack per 32K tokens and producing N packs
2. Distribute the N packs across the M data-parallel ranks
3. Group each rank's packs for gradient accumulation according to the `optimizer_step` parameter

When `N/M` is not divisible by `optimizer_step`, the actual training steps fail to match the expected value.
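An illustrative instance of the mismatch (the numbers and the step-count formula below are chosen for illustration, under the simplifying assumption that each rank runs `ceil(packs / optimizer_step)` optimizer steps):

```python
import math

N, M = 10, 4             # packs and dp ranks (illustrative values)
optimizer_step = 2       # configured gradient-accumulation steps

packs_per_rank = [N // M + (1 if r < N % M else 0) for r in range(M)]
steps_per_rank = [math.ceil(p / optimizer_step) for p in packs_per_rank]
print(packs_per_rank)    # [3, 3, 2, 2]
print(steps_per_rank)    # [2, 2, 1, 1] -> ranks no longer agree on the step count
```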
Key Changes: Use the RLDataPacker module in TrainingController
Use the `DataBatchPacker` module in `TrainingController` to pack data batches. Note that `DataBatchPacker` was introduced in #1438. `DataBatchPacker` supports the pack strategies `["balance", "greedy", "native"]`.
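To show how the strategy names are commonly interpreted (a conceptual sketch only; the actual semantics and API of `DataBatchPacker` in #1438 may differ): "balance" targets roughly equal tokens per pack, as in the Karmarkar-Karp partitioning above, while "greedy" and "native" behave roughly like the following.

```python
def pack_greedy(seqlens: list[int], max_tokens: int) -> list[list[int]]:
    """Greedy: visit sequences long-to-short and drop each into the first pack it fits."""
    packs: list[list[int]] = []
    totals: list[int] = []
    for i in sorted(range(len(seqlens)), key=lambda j: seqlens[j], reverse=True):
        for p, total in enumerate(totals):
            if total + seqlens[i] <= max_tokens:
                packs[p].append(i)
                totals[p] += seqlens[i]
                break
        else:
            packs.append([i])
            totals.append(seqlens[i])
    return packs


def pack_native(seqlens: list[int], max_tokens: int) -> list[list[int]]:
    """Native: keep arrival order and start a new pack once the token budget is exceeded."""
    packs: list[list[int]] = [[]]
    total = 0
    for i, n in enumerate(seqlens):
        if packs[-1] and total + n > max_tokens:
            packs.append([])
            total = 0
        packs[-1].append(i)
        total += n
    return packs
```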