Trainer/KT integration: support KT trainable params beyond fused-expert LoRA.

* integrations/kt.py: HfTrainerKTConfig maps two more keys to environment
  variables: kt_train_mode -> ACCELERATE_KT_TRAIN_MODE (str) and
  kt_full_weight_grad -> ACCELERATE_KT_FULL_WEIGHT_GRAD (bool).
* trainer.py: import get_kt_trainable_params from kt_kernel.sft (bound to None
  when the import fails, like its siblings).
* trainer.py, _inner_training_loop: when KT is enabled, additionally log the
  KT trainable parameter count and the model + KT total. The count stays
  best-effort, but a failure is now reported at debug level instead of being
  swallowed by a bare `pass`.
* trainer.py, _prepare_for_training: inject the params returned by
  get_kt_trainable_params (instead of get_kt_lora_params) into the optimizer's
  last param group, and label the log message according to whether any wrapper
  in `_kt_wrappers` has `_full_weight_grad` set.

NOTE(review): leading whitespace in the hunks below was reconstructed; if the
context does not match the target tree, apply with
`git apply --ignore-whitespace`.

diff --git a/src/transformers/integrations/kt.py b/src/transformers/integrations/kt.py
index 2018f635c540..09419c3e49fe 100644
--- a/src/transformers/integrations/kt.py
+++ b/src/transformers/integrations/kt.py
@@ -48,6 +48,8 @@ class HfTrainerKTConfig:
         "kt_lora_alpha": ("ACCELERATE_KT_LORA_ALPHA", float),
         "kt_model_max_length": ("ACCELERATE_KT_MODEL_MAX_LENGTH", int),
         "kt_skip_expert_loading": ("ACCELERATE_KT_SKIP_EXPERT_LOADING", bool),
+        "kt_train_mode": ("ACCELERATE_KT_TRAIN_MODE", str),
+        "kt_full_weight_grad": ("ACCELERATE_KT_FULL_WEIGHT_GRAD", bool),
     }
 
     def __init__(self, kt_config_dict: Any | None):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8f2968b75487..f3403dc862d2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -242,13 +242,14 @@
 try:
     from kt_kernel.sft import (
         get_kt_lora_params,
+        get_kt_trainable_params,
         kt_adapt_peft_lora,
         load_kt_moe_from_adapter,
         save_kt_moe_to_adapter,
         update_kt_lora_pointers,
     )
 except ImportError:
-    get_kt_lora_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None
+    get_kt_lora_params = get_kt_trainable_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None
 
 
 if TYPE_CHECKING:
@@ -1518,6 +1519,20 @@ def _inner_training_loop(
         logger.info(f" Total optimization steps = {max_steps:,}")
         logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
 
+        # Include KT wrapper trainable params (gate_proj_buf etc.) which are invisible
+        # to model.parameters() because BaseSFTMoEWrapper is not nn.Module.
+        if self.is_kt_enabled and get_kt_trainable_params is not None:
+            try:
+                kt_model_for_count = self.accelerator.unwrap_model(model, keep_torch_compile=False)
+                kt_trainable = get_kt_trainable_params(kt_model_for_count)
+                kt_trainable_numel = sum(p.numel() for p in kt_trainable) if kt_trainable else 0
+                if kt_trainable_numel > 0:
+                    logger.info(f" Number of KT trainable parameters = {kt_trainable_numel:,}")
+                    total_trainable = get_model_param_count(model, trainable_only=True) + kt_trainable_numel
+                    logger.info(f" Total trainable parameters (model + KT) = {total_trainable:,}")
+            except Exception as err:  # Non-critical: best-effort count, but leave a trace for debugging
+                logger.debug(f"Could not count KT trainable parameters: {err}")
+
         if resume_from_checkpoint is not None:
             logger.info(
                 f" Resuming training from checkpoint with epoch {epochs_trained} and global step {self.state.global_step}"
@@ -1688,16 +1703,21 @@ def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpo
             if kt_model is not None and kt_adapt_peft_lora is not None:
                 kt_adapt_peft_lora(kt_model)
 
-            # Inject fused expert LoRA params into existing optimizer's last param group
+            # Inject KT trainable params (LoRA or full weight grad) into existing optimizer's last param group
             # (cannot use add_param_group — lr_scheduler is already created with fixed group count)
-            if self.optimizer is not None and get_kt_lora_params is not None:
-                kt_lora_params = get_kt_lora_params(kt_model)
-                if kt_lora_params:
+            if self.optimizer is not None and get_kt_trainable_params is not None:
+                kt_params = get_kt_trainable_params(kt_model)
+                if kt_params:
                     existing_ids = {id(p) for group in self.optimizer.param_groups for p in group["params"]}
-                    new_params = [p for p in kt_lora_params if id(p) not in existing_ids]
+                    new_params = [p for p in kt_params if id(p) not in existing_ids]
                     if new_params:
                         self.optimizer.param_groups[-1]["params"].extend(new_params)
-                        logger.info(f"Injected {len(new_params)} fused expert LoRA params into optimizer")
+                        has_full_weight_grad = any(
+                            getattr(w, "_full_weight_grad", False)
+                            for w in getattr(kt_model, "_kt_wrappers", [])
+                        )
+                        param_type = "base weight + LoRA" if has_full_weight_grad else "LoRA"
+                        logger.info(f"Injected {len(new_params)} fused expert {param_type} params into optimizer")
 
         # load checkpoint
         if resume_from_checkpoint is not None: