Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/integrations/kt.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class HfTrainerKTConfig:
"kt_lora_alpha": ("ACCELERATE_KT_LORA_ALPHA", float),
"kt_model_max_length": ("ACCELERATE_KT_MODEL_MAX_LENGTH", int),
"kt_skip_expert_loading": ("ACCELERATE_KT_SKIP_EXPERT_LOADING", bool),
"kt_train_mode": ("ACCELERATE_KT_TRAIN_MODE", str),
"kt_full_weight_grad": ("ACCELERATE_KT_FULL_WEIGHT_GRAD", bool),
}

def __init__(self, kt_config_dict: Any | None):
Expand Down
34 changes: 27 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,14 @@
try:
from kt_kernel.sft import (
get_kt_lora_params,
get_kt_trainable_params,
kt_adapt_peft_lora,
load_kt_moe_from_adapter,
save_kt_moe_to_adapter,
update_kt_lora_pointers,
)
except ImportError:
get_kt_lora_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None
get_kt_lora_params = get_kt_trainable_params = kt_adapt_peft_lora = load_kt_moe_from_adapter = save_kt_moe_to_adapter = update_kt_lora_pointers = None


if TYPE_CHECKING:
Expand Down Expand Up @@ -1518,6 +1519,20 @@ def _inner_training_loop(
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

# Include KT wrapper trainable params (gate_proj_buf etc.) which are invisible
# to model.parameters() because BaseSFTMoEWrapper is not nn.Module.
if self.is_kt_enabled and get_kt_trainable_params is not None:
try:
kt_model_for_count = self.accelerator.unwrap_model(model, keep_torch_compile=False)
kt_trainable = get_kt_trainable_params(kt_model_for_count)
kt_trainable_numel = sum(p.numel() for p in kt_trainable) if kt_trainable else 0
Comment on lines +1527 to +1528

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for counting KT trainable parameters might lead to double counting if some parameters returned by get_kt_trainable_params are already present in model.parameters(). This is consistent with the check performed during parameter injection in _prepare_for_training. Filtering these out ensures the reported counts are accurate.

                kt_trainable = get_kt_trainable_params(kt_model_for_count)
                model_param_ids = {id(p) for p in model.parameters()}
                kt_trainable = [p for p in kt_trainable if id(p) not in model_param_ids] if kt_trainable else []
                kt_trainable_numel = sum(p.numel() for p in kt_trainable)

if kt_trainable_numel > 0:
logger.info(f" Number of KT trainable parameters = {kt_trainable_numel:,}")
total_trainable = get_model_param_count(model, trainable_only=True) + kt_trainable_numel
logger.info(f" Total trainable parameters (model + KT) = {total_trainable:,}")
except Exception:
pass # Non-critical: best-effort count
Comment on lines +1533 to +1534

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception and silently passing can make it difficult to diagnose issues if the parameter counting logic fails unexpectedly. It is recommended to at least log the exception at a debug level to aid in troubleshooting.

Suggested change
except Exception:
pass # Non-critical: best-effort count
except Exception as e:
logger.debug(f"Could not count KT trainable parameters: {e}")


if resume_from_checkpoint is not None:
logger.info(
f" Resuming training from checkpoint with epoch {epochs_trained} and global step {self.state.global_step}"
Expand Down Expand Up @@ -1688,16 +1703,21 @@ def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpo
if kt_model is not None and kt_adapt_peft_lora is not None:
kt_adapt_peft_lora(kt_model)

# Inject fused expert LoRA params into existing optimizer's last param group
# Inject KT trainable params (LoRA or full weight grad) into existing optimizer's last param group
# (cannot use add_param_group — lr_scheduler is already created with fixed group count)
if self.optimizer is not None and get_kt_lora_params is not None:
kt_lora_params = get_kt_lora_params(kt_model)
if kt_lora_params:
if self.optimizer is not None and get_kt_trainable_params is not None:
kt_params = get_kt_trainable_params(kt_model)
if kt_params:
existing_ids = {id(p) for group in self.optimizer.param_groups for p in group["params"]}
new_params = [p for p in kt_lora_params if id(p) not in existing_ids]
new_params = [p for p in kt_params if id(p) not in existing_ids]
if new_params:
self.optimizer.param_groups[-1]["params"].extend(new_params)
logger.info(f"Injected {len(new_params)} fused expert LoRA params into optimizer")
has_full_weight_grad = any(
getattr(w, "_full_weight_grad", False)
for w in getattr(kt_model, "_kt_wrappers", [])
)
Comment on lines +1715 to +1718

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If kt_model is a PeftModel, the _kt_wrappers attribute might be located on the base model rather than the wrapper itself. Using unwrap_peft_model ensures that the check for _full_weight_grad is performed on the correct module level where these attributes are typically attached.

Suggested change
has_full_weight_grad = any(
getattr(w, "_full_weight_grad", False)
for w in getattr(kt_model, "_kt_wrappers", [])
)
has_full_weight_grad = any(
getattr(w, "_full_weight_grad", False)
for w in getattr(unwrap_peft_model(kt_model), "_kt_wrappers", [])
)

param_type = "base weight + LoRA" if has_full_weight_grad else "LoRA"
logger.info(f"Injected {len(new_params)} fused expert {param_type} params into optimizer")

# load checkpoint
if resume_from_checkpoint is not None:
Expand Down