-
Notifications
You must be signed in to change notification settings - Fork 0
[WIP]for full-fine-tune-support #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sft-v5
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -242,13 +242,14 @@ | |||||||||||||||||
| try: | ||||||||||||||||||
| from kt_kernel.sft import ( | ||||||||||||||||||
| get_kt_lora_params, | ||||||||||||||||||
| get_kt_trainable_params, | ||||||||||||||||||
| kt_adapt_peft_lora, | ||||||||||||||||||
| load_kt_moe_from_adapter, | ||||||||||||||||||
| save_kt_moe_to_adapter, | ||||||||||||||||||
| update_kt_lora_pointers, | ||||||||||||||||||
| ) | ||||||||||||||||||
| except ImportError: | ||||||||||||||||||
| get_kt_lora_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None | ||||||||||||||||||
| get_kt_lora_params = get_kt_trainable_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||
|
|
@@ -1518,6 +1519,20 @@ def _inner_training_loop( | |||||||||||||||||
| logger.info(f" Total optimization steps = {max_steps:,}") | ||||||||||||||||||
| logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Include KT wrapper trainable params (gate_proj_buf etc.) which are invisible | ||||||||||||||||||
| # to model.parameters() because BaseSFTMoEWrapper is not nn.Module. | ||||||||||||||||||
| if self.is_kt_enabled and get_kt_trainable_params is not None: | ||||||||||||||||||
| try: | ||||||||||||||||||
| kt_model_for_count = self.accelerator.unwrap_model(model, keep_torch_compile=False) | ||||||||||||||||||
| kt_trainable = get_kt_trainable_params(kt_model_for_count) | ||||||||||||||||||
| kt_trainable_numel = sum(p.numel() for p in kt_trainable) if kt_trainable else 0 | ||||||||||||||||||
| if kt_trainable_numel > 0: | ||||||||||||||||||
| logger.info(f" Number of KT trainable parameters = {kt_trainable_numel:,}") | ||||||||||||||||||
| total_trainable = get_model_param_count(model, trainable_only=True) + kt_trainable_numel | ||||||||||||||||||
| logger.info(f" Total trainable parameters (model + KT) = {total_trainable:,}") | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| pass # Non-critical: best-effort count | ||||||||||||||||||
|
Comment on lines
+1533
to
+1534
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Catching a broad …
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| if resume_from_checkpoint is not None: | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| f" Resuming training from checkpoint with epoch {epochs_trained} and global step {self.state.global_step}" | ||||||||||||||||||
|
|
@@ -1688,16 +1703,21 @@ def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpo | |||||||||||||||||
| if kt_model is not None and kt_adapt_peft_lora is not None: | ||||||||||||||||||
| kt_adapt_peft_lora(kt_model) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Inject fused expert LoRA params into existing optimizer's last param group | ||||||||||||||||||
| # Inject KT trainable params (LoRA or full weight grad) into existing optimizer's last param group | ||||||||||||||||||
| # (cannot use add_param_group — lr_scheduler is already created with fixed group count) | ||||||||||||||||||
| if self.optimizer is not None and get_kt_lora_params is not None: | ||||||||||||||||||
| kt_lora_params = get_kt_lora_params(kt_model) | ||||||||||||||||||
| if kt_lora_params: | ||||||||||||||||||
| if self.optimizer is not None and get_kt_trainable_params is not None: | ||||||||||||||||||
| kt_params = get_kt_trainable_params(kt_model) | ||||||||||||||||||
| if kt_params: | ||||||||||||||||||
| existing_ids = {id(p) for group in self.optimizer.param_groups for p in group["params"]} | ||||||||||||||||||
| new_params = [p for p in kt_lora_params if id(p) not in existing_ids] | ||||||||||||||||||
| new_params = [p for p in kt_params if id(p) not in existing_ids] | ||||||||||||||||||
| if new_params: | ||||||||||||||||||
| self.optimizer.param_groups[-1]["params"].extend(new_params) | ||||||||||||||||||
| logger.info(f"Injected {len(new_params)} fused expert LoRA params into optimizer") | ||||||||||||||||||
| has_full_weight_grad = any( | ||||||||||||||||||
| getattr(w, "_full_weight_grad", False) | ||||||||||||||||||
| for w in getattr(kt_model, "_kt_wrappers", []) | ||||||||||||||||||
| ) | ||||||||||||||||||
|
Comment on lines
+1715
to
+1718
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If …
Suggested change
|
||||||||||||||||||
| param_type = "base weight + LoRA" if has_full_weight_grad else "LoRA" | ||||||||||||||||||
| logger.info(f"Injected {len(new_params)} fused expert {param_type} params into optimizer") | ||||||||||||||||||
|
|
||||||||||||||||||
| # load checkpoint | ||||||||||||||||||
| if resume_from_checkpoint is not None: | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current logic for counting KT trainable parameters might lead to double counting if some parameters returned by
`get_kt_trainable_params` are already present in `model.parameters()`. This is consistent with the check performed during parameter injection in `_prepare_for_training`. Filtering these out ensures the reported counts are accurate.