diff --git a/.gitignore b/.gitignore index 657c99fb8..28f4b3792 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,9 @@ env **.pyc **.txt *.log +*.npy weights/ +slurm_outputs/ # SSIM test outputs fastvideo/tests/ssim/generated_videos/ @@ -82,4 +84,4 @@ docs/distillation/examples/ !assets/videos/**/*.mp4 dmd_t2v_output/ -preprocess_output_text/ +preprocess_output_text/ \ No newline at end of file diff --git a/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml b/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml new file mode 100644 index 000000000..7f3a0daf9 --- /dev/null +++ b/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml @@ -0,0 +1,91 @@ +# DMD2 distillation: Wan 2.1 T2V 1.3B (teacher 50-step -> student 8-step). +# +# - Teacher: frozen pretrained Wan 2.1 T2V 1.3B +# - Student: trainable, initialized from the same pretrained weights +# - Critic: trainable, initialized from the same pretrained weights +# - Validation: 8-step SDE sampling + +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method + rollout_mode: simulate + generator_update_interval: 5 + real_score_guidance_scale: 4.5 + dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125] + + # Critic optimizer (required — no fallback to training.optimizer) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 1 + hsdp_shard_dim: 8 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 2.0e-6 + betas: [0.0, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wan2.1_dmd2_8steps + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wan + run_name: wan2.1_dmd2_8steps_cfg4.5 + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + max_grad_norm: 1.0 + validation: + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 50 + sampling_steps: [8] + sampler_kind: sde + sampling_timesteps: [1000, 850, 700, 550, 350, 275, 200, 125] + guidance_scale: 6.0 + +pipeline: + flow_shift: 8 diff --git a/examples/train/example.yaml b/examples/train/example.yaml new file mode 100644 index 000000000..f025d47b1 --- /dev/null +++ b/examples/train/example.yaml @@ -0,0 +1,208 @@ +# ============================================================================== +# Full configuration reference for fastvideo.train +# +# Legend: +# [TYPED] — parsed into a typed dataclass; fields are validated with +# defaults. Unknown keys are silently ignored. +# [FREE] — free-form dict passed as-is to the target class / method. +# Keys depend on the _target_ class constructor / method_config. +# [RESOLVED] — parsed by PipelineConfig.from_kwargs(); auto-populated from +# the model's config files. Only scalar overrides are useful. +# ============================================================================== + +# ------------------------------------------------------------------------------ +# models: [FREE] +# +# Each role is instantiated via _target_(*, training_config=..., **kwargs). +# Keys here are constructor kwargs of the _target_ class (e.g. WanModel). +# You can define any role name (student, teacher, critic, etc.). +# ------------------------------------------------------------------------------ +models: + student: + _target_: fastvideo.train.models.wan.WanModel # required + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers # required: HF repo or local path + trainable: true # default: true + disable_custom_init_weights: false # default: false + flow_shift: 3.0 # default: 3.0 + enable_gradient_checkpointing_type: null # default: null (falls back to training.model) + + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +# ------------------------------------------------------------------------------ +# method: [FREE] +# +# Instantiated via _target_(*, cfg=RunConfig, role_models=...). +# All keys besides _target_ are available in self.method_config (a plain dict). +# Keys depend entirely on the method class. +# ------------------------------------------------------------------------------ +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method # required + + # --- DMD2-specific keys (read from self.method_config) --- + rollout_mode: simulate # required: "simulate" or "data_latent" + generator_update_interval: 5 # default: 1 + dmd_denoising_steps: [1000, 750, 500, 250] # SDE timestep schedule + + # Critic optimizer (all required — no fallback) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + + # CFG conditioning policy (optional) + # cfg_uncond: + # on_missing: error # "error" or "ignore" + # text: keep # "keep", "zero", "drop", "negative_prompt" + # image: keep # "keep", "zero", "drop" + # action: keep # "keep", "zero", "drop" + + # --- FineTuneMethod keys (if using finetune instead) --- + # _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod + # attn_kind: vsa # "dense" or "vsa" + # use_ema: false + +# ------------------------------------------------------------------------------ +# training: [TYPED] -> TrainingConfig +# +# Every field below has a typed default. Unknown keys are ignored. +# ------------------------------------------------------------------------------ +training: + + # --- training.distributed [TYPED] -> DistributedConfig --- + distributed: + num_gpus: 8 # default: 1 + tp_size: 1 # default: 1 + sp_size: 1 # default: 1 (defaults to num_gpus in loader) + hsdp_replicate_dim: 1 # default: 1 + hsdp_shard_dim: 8 # default: -1 (defaults to num_gpus in loader) + pin_cpu_memory: false # default: false + + # --- training.data [TYPED] -> DataConfig --- + data: + data_path: data/my_dataset # default: "" + train_batch_size: 1 # default: 1 + dataloader_num_workers: 4 # default: 0 + training_cfg_rate: 0.1 # default: 0.0 + seed: 1000 # default: 0 + num_height: 448 # default: 0 + num_width: 832 # default: 0 + num_latent_t: 20 # default: 0 + num_frames: 77 # default: 0 + + # --- training.optimizer [TYPED] -> OptimizerConfig --- + # Note: only for the student optimizer. Critic optimizer is in method config. + optimizer: + learning_rate: 2.0e-6 # default: 0.0 + betas: [0.9, 0.999] # default: [0.9, 0.999] + weight_decay: 0.01 # default: 0.0 + lr_scheduler: constant # default: "constant" + lr_warmup_steps: 0 # default: 0 + lr_num_cycles: 0 # default: 0 + lr_power: 0.0 # default: 0.0 + min_lr_ratio: 0.5 # default: 0.5 + + # --- training.loop [TYPED] -> TrainingLoopConfig --- + loop: + max_train_steps: 10000 # default: 0 + gradient_accumulation_steps: 1 # default: 1 + + # --- training.checkpoint [TYPED] -> CheckpointConfig --- + checkpoint: + output_dir: outputs/my_run # default: "" + resume_from_checkpoint: "" # default: "" (or use --resume-from-checkpoint CLI) + training_state_checkpointing_steps: 1000 # default: 0 (disabled) + checkpoints_total_limit: 3 # default: 0 (keep all) + + # --- training.tracker [TYPED] -> TrackerConfig --- + tracker: + trackers: [] # default: [] (auto-adds "wandb" if project_name is set) + project_name: my_project # default: "fastvideo" + run_name: my_run # default: "" + + # --- training.vsa [TYPED] -> VSAConfig --- + vsa: + sparsity: 0.0 # default: 0.0 (0.0 = disabled) + decay_rate: 0.0 # default: 0.0 + decay_interval_steps: 0 # default: 0 + + # --- training.model [TYPED] -> ModelTrainingConfig --- + model: + weighting_scheme: uniform # default: "uniform" + logit_mean: 0.0 # default: 0.0 + logit_std: 1.0 # default: 1.0 + mode_scale: 1.0 # default: 1.0 + precondition_outputs: false # default: false + moba_config: {} # default: {} + enable_gradient_checkpointing_type: full # default: null ("full" or null) + + # --- training top-level [TYPED] --- + dit_precision: fp32 # default: "fp32" (master weight precision) + # model_path: ... # default: "" (auto-derived from models.student.init_from) + +# ------------------------------------------------------------------------------ +# callbacks: [FREE] +# +# Each callback is instantiated via _target_(*, **kwargs). +# The callback name (e.g. "grad_clip") is arbitrary — only _target_ matters. +# training_config is injected automatically (not from YAML). +# ------------------------------------------------------------------------------ +callbacks: + + # --- GradNormClipCallback --- + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback # optional if using default registry + max_grad_norm: 1.0 # default: 0.0 (0.0 = disabled) + log_grad_norms: false # default: false + + # --- EMACallback --- + # ema: + # _target_: fastvideo.train.callbacks.ema.EMACallback + # type: constant # default: "constant" ("constant", "power", "halflife") + # beta: 0.9999 # default: 0.9999 (for constant type) + # gamma: 16.97 # default: 16.97 (for power type) + # ema_halflife_kimg: 500.0 # default: 500.0 (for halflife type) + # ema_rampup_ratio: 0.05 # default: 0.05 (for halflife type) + # start_iter: 0 # default: 0 + # batch_size: 1 # default: 1 + + # --- ValidationCallback --- + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback # optional if using default registry + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline # required + dataset_file: path/to/validation.json # required + every_steps: 100 # default: 100 + sampling_steps: [4] # default: [40] + sampler_kind: sde # default: "ode" (use "sde" for few-step distilled models) + scheduler_target: null # default: null (_target_ for scheduler class, e.g. + # fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler + # fastvideo.models.schedulers.scheduling_flow_unipc_multistep.FlowUniPCMultistepScheduler) + guidance_scale: 5.0 # default: null (uses model default) + num_frames: null # default: null (derived from training.data) + output_dir: null # default: null (falls back to training.checkpoint.output_dir) + sampling_timesteps: null # default: null (explicit timestep list for SDE) + rollout_mode: parallel # default: "parallel" ("parallel" or "streaming") + +# ------------------------------------------------------------------------------ +# pipeline: [RESOLVED] -> PipelineConfig +# +# Parsed by PipelineConfig.from_kwargs(). Most fields are auto-populated from +# the model's config files (vae_config, dit_config, text_encoder_configs, etc.). +# Only scalar overrides are typically needed here. +# ------------------------------------------------------------------------------ +pipeline: + flow_shift: 3 # default: null (model-specific) + # flow_shift_sr: null # default: null (super-resolution shift) + # embedded_cfg_scale: 6.0 # default: 6.0 + # is_causal: false # default: false + # vae_tiling: true # default: true + # vae_sp: true # default: true + # disable_autocast: false # default: false diff --git a/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml b/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml new file mode 100644 index 000000000..f4ac89599 --- /dev/null +++ b/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml @@ -0,0 +1,82 @@ +# V3 config: Wan 2.1 T2V 1.3B finetune with VSA (phase 3.4, 0.9 sparsity). +# +# Uses _target_-based instantiation — each model role is an independent +# class instance; the method class is resolved directly from the YAML. + +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + +method: + _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod + attn_kind: vsa + use_ema: true + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 8 + hsdp_shard_dim: 1 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.1 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 1.0e-6 + betas: [0.9, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/phase3.4_wan2.1_finetune_vsa_0.9_v3 + training_state_checkpointing_steps: 1000 + weight_only_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wan_r + run_name: phase3.4_wan_finetune_vsa_0.9_v3 + + model: + enable_gradient_checkpointing_type: full + + vsa: + sparsity: 0.9 + decay_rate: 0.03 + decay_interval_steps: 1 + +callbacks: + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback + max_grad_norm: 1.0 + ema: + _target_: fastvideo.train.callbacks.ema.EMACallback + beta: 0.9999 + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 50 + sampling_steps: [50] + guidance_scale: 5.0 + +pipeline: + flow_shift: 3 + sampler_kind: ode diff --git a/examples/train/rfc.md b/examples/train/rfc.md new file mode 100644 index 000000000..35ed587c8 --- /dev/null +++ b/examples/train/rfc.md @@ -0,0 +1,139 @@ + + +## 1) File Structure + +fastvideo/train/ + trainer.py # Training loop; calls method.train_one_step() + models/ + base.py # BaseModel ABC: predict_x0, add_noise, backward, ... + wan/ + wan.py # Wan model loader + methods/ + base.py # DistillMethod base; methods provide train_one_step + distribution_matching/ + dmd2.py # DMD2 distillation (student/teacher/critic) + self_forcing.py # Self-forcing distillation + fine_tuning/ + finetune.py # SFT finetuning (student only) + dfsft.py # Distribution-free SFT + knowledge_distillation/ + consistency_model/ + callbacks/ + callback.py # CallbackDict registry + grad_clip.py # Gradient clipping + optional per-module norm logging + validation.py # Periodic validation via inference pipeline + ema.py # EMA weight averaging + entrypoint/ + train.py # YAML-only CLI entrypoint (torchrun -m fastvideo.train.entrypoint.train) + dcp_to_diffusers.py # Checkpoint conversion + utils/ + config.py # YAML parser -> RunConfig + builder.py # build_from_config: instantiate models, method, dataloader + instantiate.py # _target_ based instantiation + training_config.py # TrainingConfig dataclass (all training settings with defaults) + dataloader.py # Dataset / dataloader construction + moduleloader.py # Dynamic module import + module_state.py # apply_trainable(): requires_grad + train/eval + optimizer.py # Optimizer construction + tracking.py # W&B tracker (owned by trainer) + checkpoint.py # Save/resume with DCP + validation.py # Validation helpers + +By this design, we only need a YAML config to train different models using different methods. +Models declare `_target_` to select the model class; methods declare `_target_` to select the method class. +Current code: https://github.com/FoundationResearch/FastVideo/tree/distill1/fastvideo/train + +DMD2 Distillation, Self-Forcing, SFT, and DFSFT are tested on Wan. + +Current supported models: Wan. +Current supported methods: DMD2, Self-Forcing, SFT, DFSFT. + +Feedbacks are highly welcome! + + +## 2) Example YAML (DMD2 8-step) + +```yaml +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method + rollout_mode: simulate + generator_update_interval: 5 + real_score_guidance_scale: 3.5 + dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125] + + # Critic optimizer (required) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 2.0e-6 + betas: [0.0, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wan2.1_dmd2_8steps + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wan + run_name: wan2.1_dmd2_8steps + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + max_grad_norm: 1.0 + validation: + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 100 + sampling_steps: [8] + sampler_kind: sde + sampling_timesteps: [1000, 850, 700, 550, 350, 275, 200, 125] + guidance_scale: 6.0 + +pipeline: + flow_shift: 8 +``` diff --git a/examples/train/run.sh b/examples/train/run.sh new file mode 100755 index 000000000..da809abb9 --- /dev/null +++ b/examples/train/run.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Launch distillation training from a v3 YAML config. +# +# Usage: +# bash examples/distillation/refactor/run.sh [extra flags] +# +# Examples: +# bash examples/distillation/refactor/run.sh examples/distillation/refactor/self_forcing_wangame_causal_v3.yaml +# bash examples/distillation/refactor/run.sh examples/distillation/refactor/dfsft_wangame_causal_v3.yaml --dry-run +# bash examples/distillation/refactor/run.sh examples/distillation/refactor/dfsft_wangame_causal_v3.yaml \ +# --override-output-dir outputs/my_run +# +# Logs are written to logs/_.log (and also printed to stdout). + +set -euo pipefail + +CONFIG="${1:?Usage: $0 [extra flags...]}" +shift + +# ── GPU / node settings ────────────────────────────────────────────── +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +NUM_GPUS="${NUM_GPUS:-1}" +NNODES="${NNODES:-1}" +NODE_RANK="${NODE_RANK:-0}" +MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +MASTER_PORT="${MASTER_PORT:-29501}" + +# ── W&B ────────────────────────────────────────────────────────────── +export WANDB_API_KEY="${WANDB_API_KEY:-}" +export WANDB_MODE="${WANDB_MODE:-online}" + +# ── Log file ───────────────────────────────────────────────────────── +CONFIG_NAME="$(basename "${CONFIG}" .yaml)" +TIMESTAMP="$(date +%Y%m%d_%H%M%S)" +LOG_DIR="${LOG_DIR:-examples/distillation/refactor}" +mkdir -p "${LOG_DIR}" +LOG_FILE="${LOG_DIR}/${CONFIG_NAME}_${TIMESTAMP}.log" + +source ~/conda/miniconda/bin/activate +conda activate alexfv + +echo "=== Distillation Training ===" +echo "Config: ${CONFIG}" +echo "Num GPUs: ${NUM_GPUS}" +echo "Num Nodes: ${NNODES}" +echo "Node Rank: ${NODE_RANK}" +echo "Master: ${MASTER_ADDR}:${MASTER_PORT}" +echo "Extra args: $*" +echo "Log file: ${LOG_FILE}" +echo "==============================" + +torchrun \ + --nnodes "${NNODES}" \ + --node_rank "${NODE_RANK}" \ + --nproc_per_node "${NUM_GPUS}" \ + --master_addr "${MASTER_ADDR}" \ + --master_port "${MASTER_PORT}" \ + fastvideo/train/entrypoint/train.py \ + --config "${CONFIG}" \ + "$@" \ + 2>&1 | tee "${LOG_FILE}" diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 7ea6d1284..d30d4e96d 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -3,13 +3,15 @@ from fastvideo.configs.models.dits.hunyuangamecraft import HunyuanGameCraftConfig from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig from fastvideo.configs.models.dits.hunyuanvideo15 import HunyuanVideo15Config +from fastvideo.configs.models.dits.lingbotworld import LingBotWorldVideoConfig from fastvideo.configs.models.dits.longcat import LongCatVideoConfig from fastvideo.configs.models.dits.ltx2 import LTX2VideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig from fastvideo.configs.models.dits.hyworld import HYWorldConfig __all__ = [ - "HunyuanVideoConfig", "HunyuanVideo15Config", "HunyuanGameCraftConfig", - "WanVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig", - "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig" + "HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig", + "StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig", + "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig", + "LingBotWorldVideoConfig", "HunyuanGameCraftConfig", "WanVideoConfig" ] diff --git a/fastvideo/configs/pipelines/__init__.py b/fastvideo/configs/pipelines/__init__.py index 2ea5882d9..4d7838638 100644 --- a/fastvideo/configs/pipelines/__init__.py +++ b/fastvideo/configs/pipelines/__init__.py @@ -5,6 +5,7 @@ from fastvideo.configs.pipelines.hunyuan15 import Hunyuan15T2V480PConfig, Hunyuan15T2V720PConfig from fastvideo.configs.pipelines.hunyuangamecraft import HunyuanGameCraftPipelineConfig from fastvideo.configs.pipelines.hyworld import HYWorldConfig +from fastvideo.configs.pipelines.lingbotworld import LingBotWorldI2V480PConfig from fastvideo.configs.pipelines.ltx2 import LTX2T2VConfig from fastvideo.registry import get_pipeline_config_cls_from_name from fastvideo.configs.pipelines.wan import (SelfForcingWanT2V480PConfig, @@ -15,7 +16,8 @@ "HunyuanConfig", "FastHunyuanConfig", "HunyuanGameCraftPipelineConfig", "PipelineConfig", "Hunyuan15T2V480PConfig", "Hunyuan15T2V720PConfig", "WanT2V480PConfig", "WanI2V480PConfig", "WanT2V720PConfig", - "WanI2V720PConfig", "SelfForcingWanT2V480PConfig", "CosmosConfig", - "Cosmos25Config", "LTX2T2VConfig", "HYWorldConfig", + "WanI2V720PConfig", "StepVideoT2VConfig", "SelfForcingWanT2V480PConfig", + "CosmosConfig", "Cosmos25Config", "LTX2T2VConfig", "HYWorldConfig", + "SD35Config", "LingBotWorldI2V480PConfig", "get_pipeline_config_cls_from_name" ] diff --git a/fastvideo/configs/pipelines/base.py b/fastvideo/configs/pipelines/base.py index 83df65f8d..f60fb600a 100644 --- a/fastvideo/configs/pipelines/base.py +++ b/fastvideo/configs/pipelines/base.py @@ -69,6 +69,16 @@ class PipelineConfig: # DMD parameters dmd_denoising_steps: list[int] | None = field(default=None) + # Sampler kind (controls the denoising loop semantics). + # - "ode": deterministic solver-style loop (default) + # - "sde": stochastic loop with noise injection + sampler_kind: str = "ode" + + # ODE solver selection when `sampler_kind="ode"`. + # - "unipc": FlowUniPCMultistepScheduler (default) + # - "euler": FlowMatchEulerDiscreteScheduler + ode_solver: str = "unipc" + # Wan2.2 TI2V parameters ti2v_task: bool = False boundary_ratio: float | None = None @@ -175,6 +185,14 @@ def add_cli_args(parser: FlexibleArgumentParser, help= "Comma-separated list of denoising steps (e.g., '1000,757,522')", ) + parser.add_argument( + f"--{prefix_with_dot}sampler-kind", + type=str, + choices=["ode", "sde"], + dest=f"{prefix_with_dot.replace('-', '_')}sampler_kind", + default=PipelineConfig.sampler_kind, + help="Sampling loop kind: ode (default) or sde.", + ) # Add VAE configuration arguments from fastvideo.configs.models.vaes.base import VAEConfig diff --git a/fastvideo/configs/sample/wan.py b/fastvideo/configs/sample/wan.py index a96cf0d29..b6eaf6989 100644 --- a/fastvideo/configs/sample/wan.py +++ b/fastvideo/configs/sample/wan.py @@ -63,13 +63,13 @@ class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam): @dataclass class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam): """Sampling parameters for Wan2.1 Fun 1.3B InP model.""" - height: int = 480 - width: int = 832 - num_frames: int = 81 - fps: int = 16 + height: int = 352 + width: int = 640 + num_frames: int = 77 + fps: int = 25 negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" - guidance_scale: float = 6.0 - num_inference_steps: int = 50 + guidance_scale: float = 1.0 + num_inference_steps: int = 40 @dataclass diff --git a/fastvideo/dataset/parquet_dataset_map_style.py b/fastvideo/dataset/parquet_dataset_map_style.py index dac622497..d1ce92a9f 100644 --- a/fastvideo/dataset/parquet_dataset_map_style.py +++ b/fastvideo/dataset/parquet_dataset_map_style.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import hashlib import os import pickle import random @@ -36,6 +37,7 @@ def __init__( global_rank: int, drop_last: bool = True, drop_first_row: bool = False, + reshuffle_each_epoch: bool = True, seed: int = 0, ): self.batch_size = batch_size @@ -45,34 +47,40 @@ def __init__( self.num_sp_groups = num_sp_groups self.global_rank = global_rank self.sp_world_size = sp_world_size + self.drop_first_row = drop_first_row + self.reshuffle_each_epoch = reshuffle_each_epoch + self._build_indices(0) # build indices for epoch 0 to initialize the sampler + + def _build_indices(self, epoch: int) -> None: # ── epoch-level RNG ──────────────────────────────────────────────── - rng = torch.Generator().manual_seed(self.seed) + rng = torch.Generator().manual_seed(self.seed + epoch) # Create a random permutation of all indices global_indices = torch.randperm(self.dataset_size, generator=rng) - if drop_first_row: + dataset_size = self.dataset_size + if self.drop_first_row: # drop 0 in global_indices global_indices = global_indices[global_indices != 0] - self.dataset_size = self.dataset_size - 1 + dataset_size = dataset_size - 1 if self.drop_last: # For drop_last=True, we: # 1. Ensure total samples is divisible by (batch_size * num_sp_groups) # 2. This guarantees each SP group gets same number of complete batches # 3. Prevents uneven batch sizes across SP groups at end of epoch - num_batches = self.dataset_size // self.batch_size + num_batches = dataset_size // self.batch_size num_global_batches = num_batches // self.num_sp_groups global_indices = global_indices[:num_global_batches * self.num_sp_groups * self.batch_size] else: - if self.dataset_size % (self.num_sp_groups * self.batch_size) != 0: + if dataset_size % (self.num_sp_groups * self.batch_size) != 0: # add more indices to make it divisible by (batch_size * num_sp_groups) padding_size = self.num_sp_groups * self.batch_size - ( - self.dataset_size % (self.num_sp_groups * self.batch_size)) + dataset_size % (self.num_sp_groups * self.batch_size)) logger.info("Padding the dataset from %d to %d", - self.dataset_size, self.dataset_size + padding_size) + dataset_size, dataset_size + padding_size) global_indices = torch.cat( [global_indices, global_indices[:padding_size]]) @@ -84,6 +92,11 @@ def __init__( logger.info("Dataset size for each sp group: %d", len(sp_group_local_indices)) + def set_epoch(self, epoch: int) -> None: + if not self.reshuffle_each_epoch: + return + self._build_indices(epoch) + def __iter__(self): indices = self.sp_group_local_indices for i in range(0, len(indices), self.batch_size): @@ -94,11 +107,90 @@ def __len__(self): return len(self.sp_group_local_indices) // self.batch_size +def _parse_data_path_specs(path: str) -> list[tuple[str, int]]: + """ + Parse data_path into a list of (directory, repeat_count). + Syntax: comma-separated entries; each entry is "path" (default 1) or "path:N" (N = repeat count). + N=0 means skip that path (convenience to disable without removing). If no ":" present, default is 1. + Example: "/dir1:2,/dir2,/dir3:0" -> dir1 2x, dir2 1x, dir3 skipped. + """ + specs: list[tuple[str, int]] = [] + for part in path.split(","): + part = part.strip() + if not part: + continue + if ":" in part: + p, _, count_str = part.rpartition(":") + p = p.strip() + try: + count = int(count_str.strip()) + except ValueError: + raise ValueError( + f"data_path repeat count must be an integer, got {count_str!r}" + ) from None + if count < 0: + raise ValueError( + f"data_path repeat count must be >= 0, got {count}" + ) + specs.append((p, count)) + else: + specs.append((part, 1)) + return specs + + +def _scan_parquet_files_for_path(p: str) -> tuple[list[str], list[int]]: + """Return (file_paths, row_lengths) for a single directory.""" + file_names: list[str] = [] + for root, _, files in os.walk(p): + for file in sorted(files): + if file.endswith(".parquet"): + file_names.append(os.path.join(root, file)) + lengths = [] + for file_path in tqdm.tqdm( + file_names, desc="Reading parquet files to get lengths"): + lengths.append(pq.ParquetFile(file_path).metadata.num_rows) + logger.info("Found %d parquet files with %d total rows", len(file_names), sum(lengths)) + return file_names, lengths + + def get_parquet_files_and_length(path: str): - dataset_root = os.path.realpath(os.path.expanduser(path)) - # Check if cached info exists - cache_dir = os.path.join(dataset_root, "map_style_cache") - cache_file = os.path.join(cache_dir, "file_info.pkl") + """ + Collect parquet file paths and row lengths from one or more directories. + path: single directory, or comma-separated "path" or "path:N" (N = repeat count). + E.g. "/dir1:2,/dir2:1" -> dir1's files appear 2x (oversampled), dir2 once. + """ + path_specs = _parse_data_path_specs(path) + if not path_specs: + raise ValueError( + "data_path must be a non-empty path or comma-separated path specs" + ) + # Use first path with count > 0 for cache_dir (single-path case only) + first_path = next( + (p for p, c in path_specs if c > 0), + path_specs[0][0], + ) + is_single_no_repeat = ( + len(path_specs) == 1 and path_specs[0][1] == 1 + ) + effective_path = path.strip() + # Single path, no repeat: cache under that path (backward compatible). + # Multi-path or repeat: cache in a neutral dir keyed by hash of full path spec, + # so we never reuse "first path's" cache and the cached list is the merged list. + if is_single_no_repeat: + cache_dir = os.path.join(first_path, "map_style_cache") + cache_suffix = "file_info.pkl" + else: + neutral_root = os.environ.get( + "FASTVIDEO_MAP_STYLE_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache", "fastvideo", "map_style_cache"), + ) + cache_dir = neutral_root + cache_suffix = ( + "file_info_" + + hashlib.md5(effective_path.encode()).hexdigest()[:16] + + ".pkl" + ) + cache_file = os.path.join(cache_dir, cache_suffix) # Only rank 0 checks for cache and scans files if needed if get_world_rank() == 0: @@ -152,30 +244,31 @@ def get_parquet_files_and_length(path: str): # If cache not loaded (either doesn't exist or failed to load), scan files if not cache_loaded: - logger.info("Scanning parquet files to get lengths") - lengths = [] - file_names = [] - for root, _, files in os.walk(dataset_root): - for file in sorted(files): - if file.endswith('.parquet'): - file_path = os.path.realpath(os.path.join(root, file)) - file_names.append(file_path) - if len(file_names) == 0: - raise FileNotFoundError( - "No parquet files found under dataset path: " - f"{path}. " - "Please verify this path points to preprocessed parquet " - "data.") - for file_path in tqdm.tqdm( - file_names, desc="Reading parquet files to get lengths"): - num_rows = pq.ParquetFile(file_path).metadata.num_rows - lengths.append(num_rows) - # sort according to file name to ensure all rank has the same order - file_names_sorted, lengths_sorted = zip(*sorted(zip(file_names, - lengths, - strict=True), - key=lambda x: x[0]), - strict=True) + logger.info( + "Scanning parquet files (path specs: %s)", + [(p, c) for p, c in path_specs], + ) + # Build list with repeats; use (path, length, sort_index) for stable order + # Skip paths with count 0 (no I/O for disabled paths) + combined: list[tuple[str, int, int]] = [] + sort_index = 0 + for p, count in path_specs: + if count == 0: + continue + fnames, lens = _scan_parquet_files_for_path(p) + for _ in range(count): + for f, ln in zip(fnames, lens, strict=True): + combined.append((f, ln, sort_index)) + sort_index += 1 + if not combined: + raise ValueError( + "No parquet files found in the dataset (paths: %s)" + % [p for p, _ in path_specs] + ) + combined.sort(key=lambda x: (x[0], x[2])) + file_names_sorted = tuple(x[0] for x in combined) + lengths_sorted = tuple(x[1] for x in combined) + # Save the cache os.makedirs(cache_dir, exist_ok=True) with open(cache_file, "wb") as f: @@ -275,6 +368,7 @@ def __init__( seed: int = 42, drop_last: bool = True, drop_first_row: bool = False, + reshuffle_each_epoch: bool = False, text_padding_length: int = 512, ): super().__init__() @@ -297,6 +391,7 @@ def __init__( global_rank=get_world_rank(), drop_last=drop_last, drop_first_row=drop_first_row, + reshuffle_each_epoch=reshuffle_each_epoch, seed=seed, ) logger.info("Dataset initialized with %d parquet files and %d rows", @@ -369,6 +464,7 @@ def build_parquet_map_style_dataloader( cfg_rate=0.0, drop_last=True, drop_first_row=False, + reshuffle_each_epoch=False, text_padding_length=512, seed=42) -> tuple[LatentsParquetMapStyleDataset, StatefulDataLoader]: dataset = LatentsParquetMapStyleDataset( @@ -377,6 +473,7 @@ def build_parquet_map_style_dataloader( cfg_rate=cfg_rate, drop_last=drop_last, drop_first_row=drop_first_row, + reshuffle_each_epoch=reshuffle_each_epoch, text_padding_length=text_padding_length, parquet_schema=parquet_schema, seed=seed) diff --git a/fastvideo/dataset/validation_dataset.py b/fastvideo/dataset/validation_dataset.py index 5ab467d75..cf97e8bc0 100644 --- a/fastvideo/dataset/validation_dataset.py +++ b/fastvideo/dataset/validation_dataset.py @@ -4,6 +4,7 @@ import pathlib import datasets +import numpy as np from torch.utils.data import IterableDataset from fastvideo.distributed import (get_sp_world_size, get_world_rank, @@ -16,8 +17,9 @@ class ValidationDataset(IterableDataset): - def __init__(self, filename: str): + def __init__(self, filename: str, num_samples: int | None = None): super().__init__() + self.num_samples = num_samples self.filename = pathlib.Path(filename) # get directory of filename @@ -58,6 +60,12 @@ def __init__(self, filename: str): # Convert to list to get total samples self.all_samples = list(data) + + # Limit number of samples if specified + if self.num_samples is not None and self.num_samples < len(self.all_samples): + self.all_samples = self.all_samples[:self.num_samples] + logger.info("Limiting validation samples to %s", self.num_samples) + self.original_total_samples = len(self.all_samples) # Extend samples to be a multiple of DP degree (num_sp_groups) @@ -160,5 +168,25 @@ def __iter__(self): else: sample["control_video"] = load_video(control_video_path) + if sample.get("action_path", None) is not None: + action_path = sample["action_path"] + action_path = os.path.join(self.dir, action_path) + sample["action_path"] = action_path + if not pathlib.Path(action_path).is_file(): + logger.warning("Action file %s does not exist.", action_path) + else: + try: + action_data = np.load(action_path, allow_pickle=True) + num_frames = sample["num_frames"] + if action_data.dtype == object: action_data = action_data.item() + if isinstance(action_data, dict): + sample["keyboard_cond"] = action_data["keyboard"][:num_frames] + sample["mouse_cond"] = action_data["mouse"][:num_frames] + else: + sample["keyboard_cond"] = action_data[:num_frames] + except Exception as e: + logger.error("Error loading action file %s: %s", + action_path, e) + sample = {k: v for k, v in sample.items() if v is not None} yield sample diff --git a/fastvideo/fastvideo_args.py b/fastvideo/fastvideo_args.py index 52d511624..f8bf83523 100644 --- a/fastvideo/fastvideo_args.py +++ b/fastvideo/fastvideo_args.py @@ -802,6 +802,7 @@ class TrainingArgs(FastVideoArgs): """ data_path: str = "" dataloader_num_workers: int = 0 + reshuffle_each_epoch: bool = True num_height: int = 0 num_width: int = 0 num_frames: int = 0 @@ -830,6 +831,7 @@ class TrainingArgs(FastVideoArgs): validation_sampling_steps: str = "" validation_guidance_scale: str = "" validation_steps: float = 0.0 + validation_num_samples: int | None = None # Limit number of validation samples (None = use all) log_validation: bool = False trackers: list[str] = dataclasses.field(default_factory=list) tracker_project_name: str = "" @@ -891,6 +893,19 @@ class TrainingArgs(FastVideoArgs): lora_training: bool = False ltx2_first_frame_conditioning_p: float = 0.1 + # Action-only training: freeze base DiT, only train action modules + train_action_only: bool = False + + # Which action modules to train (only effective when train_action_only=True): + # "both" – action_embedder + prope_proj (default) + # "action_mlp" – action_embedder only + # "prope" – prope_proj only + action_train_target: str = "both" + + # Action warmup: keep action modules (action_embedder, to_out_prope) at zero + # for this many steps to let the base model stabilize first, then enable them. + action_warmup_steps: int = 0 + # distillation args generator_update_interval: int = 5 dfake_gen_update_ratio: int = 5 # self-forcing: how often to train generator vs critic @@ -902,6 +917,7 @@ class TrainingArgs(FastVideoArgs): fake_score_betas: str = "0.9,0.999" # betas for fake score optimizer, format: "beta1,beta2" training_state_checkpointing_steps: int = 0 # for resuming training weight_only_checkpointing_steps: int = 0 # for inference + best_checkpoint_start_step: int = 0 # save best checkpoint (by mf_angle_err_mean) after this step; 0 = disabled log_visualization: bool = False visualization_steps: int = 0 # simulate generator forward to match inference @@ -967,14 +983,22 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument("--data-path", - type=str, - required=True, - help="Path to parquet files") + parser.add_argument( + "--data-path", + type=str, + required=True, + help= + "Path to parquet files (comma-separated for multiple; path:N for repeat count)" + ) parser.add_argument("--dataloader-num-workers", type=int, required=True, help="Number of workers for dataloader") + parser.add_argument( + "--reshuffle-each-epoch", + action=StoreBoolean, + default=TrainingArgs.reshuffle_each_epoch, + help="Whether to reshuffle dataset order each epoch") parser.add_argument("--num-height", type=int, required=True, @@ -1064,6 +1088,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--validation-steps", type=float, help="Number of validation steps") + parser.add_argument( + "--validation-num-samples", + type=int, + help="Limit number of validation samples (default: use all)") parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results") @@ -1098,6 +1126,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--weight-only-checkpointing-steps", type=int, help="Steps between weight-only checkpoints (for inference)") + parser.add_argument( + "--best-checkpoint-start-step", + type=int, + help="Save best checkpoint (by mf_angle_err_mean) after this " + "step; 0 = disabled") parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from") @@ -1252,6 +1285,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Probability of conditioning on the first frame during LTX-2 training", ) + # Action-only training (freeze base model, only train action params) + parser.add_argument( + "--train-action-only", + action=StoreBoolean, + help="Whether to only train action-related parameters " + "(action_embedder and to_out_prope) while freezing base model") + + # Action warmup: keep action modules frozen for N steps + parser.add_argument("--action-warmup-steps", + type=int, + default=0, + help="Number of steps to keep action modules " + "(action_embedder, to_out_prope) frozen to let " + "the base model stabilize first") + # V-MoBA parameters parser.add_argument( "--moba-config-path", @@ -1348,6 +1396,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=TrainingArgs.context_noise, help="Context noise level for cache updates") + parser.add_argument( + "--action-train-target", + type=str, + default=TrainingArgs.action_train_target, + choices=["both", "action_mlp", "prope"], + help="Which action modules to train while freezing the base model", + ) return parser diff --git a/fastvideo/models/dits/hyworld/pose.py b/fastvideo/models/dits/hyworld/pose.py index b1b3f5df3..99d535308 100644 --- a/fastvideo/models/dits/hyworld/pose.py +++ b/fastvideo/models/dits/hyworld/pose.py @@ -10,13 +10,15 @@ """ import json +import logging import numpy as np import torch from scipy.spatial.transform import Rotation as R from typing import Union, Optional -from .trajectory import generate_camera_trajectory_local +from fastvideo.models.dits.hyworld.trajectory import generate_camera_trajectory_local +logger = logging.getLogger(__name__) # Mapping from one-hot action encoding to single label mapping = { @@ -411,3 +413,271 @@ def compute_num_frames(latent_num: int) -> int: Number of video frames """ return (latent_num - 1) * 4 + 1 + +def reformat_keyboard_and_mouse_tensors(keyboard_tensor, mouse_tensor): + """ + Reformat the keyboard and mouse tensors to the format compatible with HyWorld. + """ + num_frames = keyboard_tensor.shape[0] + assert (num_frames - 1) % 4 == 0, "num_frames must be a multiple of 4" + assert mouse_tensor.shape[0] == num_frames, "mouse_tensor must have the same number of frames as keyboard_tensor" + keyboard_tensor = keyboard_tensor[1:, :] + mouse_tensor = mouse_tensor[1:, :] + groups = keyboard_tensor.view(-1, 4, keyboard_tensor.shape[1]) + if not (groups == groups[:, 0:1]).all(dim=1).all(): + logger.warning(f"keyboard_tensor has different values for each group: {groups}") + groups = mouse_tensor.view(-1, 4, mouse_tensor.shape[1]) + if not (groups == groups[:, 0:1]).all(dim=1).all(): + logger.warning(f"mouse_tensor has different values for each group: {groups}") + + return keyboard_tensor[::4], mouse_tensor[::4] + +def process_custom_actions(keyboard_tensor, mouse_tensor, forward_speed=DEFAULT_FORWARD_SPEED): + """ + Process custom keyboard and mouse tensors into model inputs (viewmats, intrinsics, action_labels). + Assumes inputs correspond to each LATENT frame. + """ + if keyboard_tensor.ndim == 3: + keyboard_tensor = keyboard_tensor.squeeze(0) + if mouse_tensor.ndim == 3: + mouse_tensor = mouse_tensor.squeeze(0) + keyboard_tensor, mouse_tensor = reformat_keyboard_and_mouse_tensors(keyboard_tensor, mouse_tensor) + + motions = [] + + # 1. Translate tensors to motions for trajectory generation + for t in range(keyboard_tensor.shape[0]): + frame_motion = {} + + # --- Translation --- + # MatrixGame convention: 0:W, 1:S, 2:A, 3:D + fwd = 0.0 + if keyboard_tensor[t, 0] > 0.5: fwd += forward_speed # W + if keyboard_tensor[t, 1] > 0.5: fwd -= forward_speed # S + if fwd != 0: frame_motion["forward"] = fwd + + rgt = 0.0 + if keyboard_tensor[t, 2] > 0.5: rgt -= forward_speed # A (Left is negative Right) + if keyboard_tensor[t, 3] > 0.5: rgt += forward_speed # D (Right) + if rgt != 0: frame_motion["right"] = rgt + + # --- Rotation --- + # MatrixGame convention: mouse is [Pitch, Yaw] (or Y, X) + # Apply scaling (e.g. to match HyWorld distribution) + pitch = mouse_tensor[t, 0].item() + yaw = mouse_tensor[t, 1].item() + + if abs(pitch) > 1e-4: frame_motion["pitch"] = pitch + if abs(yaw) > 1e-4: frame_motion["yaw"] = yaw + + motions.append(frame_motion) + + # 2. Generate Camera Trajectory + # generate_camera_trajectory_local returns T+1 poses (starting at Identity) + # We take the first T poses to match the latent count. + # Pose 0 is Identity. Pose 1 is Identity + Motion[0]. + poses = generate_camera_trajectory_local(motions) + # poses = np.array(poses[:T]) + + # 3. Compute Viewmats (w2c) and Intrinsics + w2c_list = [] + intrinsic_list = [] + + # Setup default intrinsic (normalized) + K = np.array(DEFAULT_INTRINSIC) + K[0, 0] /= K[0, 2] * 2 + K[1, 1] /= K[1, 2] * 2 + K[0, 2] = 0.5 + K[1, 2] = 0.5 + + for i in range(len(poses)): + c2w = np.array(poses[i]) + w2c = np.linalg.inv(c2w) + w2c_list.append(w2c) + intrinsic_list.append(K) + + viewmats = torch.as_tensor(np.array(w2c_list)) + intrinsics = torch.as_tensor(np.array(intrinsic_list)) + + # 4. Generate Action Labels by analyzing the generated trajectory + # This ensures consistency with complex simultaneous movements, exactly as pose_to_input does. + + # Calculate relative camera-to-world transforms + # c2ws = inverse(viewmats) + c2ws = np.linalg.inv(np.array(w2c_list)) + + # Calculate relative movement between frames + # relative_c2w[i] = inv(c2ws[i-1]) @ c2ws[i] + C_inv = np.linalg.inv(c2ws[:-1]) + relative_c2w = np.zeros_like(c2ws) + relative_c2w[0, ...] = c2ws[0, ...] # First is anchor + relative_c2w[1:, ...] = C_inv @ c2ws[1:, ...] + + # Initialize one-hot action encodings + trans_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32) + rotate_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32) + + move_norm_valid = 0.0001 + + # Skip index 0 (anchor/identity) + for i in range(1, relative_c2w.shape[0]): + move_dirs = relative_c2w[i, :3, 3] # direction vector + move_norms = np.linalg.norm(move_dirs) + + if move_norms > move_norm_valid: # threshold for movement + move_norm_dirs = move_dirs / move_norms + angles_rad = np.arccos(move_norm_dirs.clip(-1.0, 1.0)) + trans_angles_deg = angles_rad * (180.0 / np.pi) # convert to degrees + else: + trans_angles_deg = np.zeros(3) + + R_rel = relative_c2w[i, :3, :3] + r = R.from_matrix(R_rel) + rot_angles_deg = r.as_euler("xyz", degrees=True) + + # Determine movement actions based on trajectory + # Note: HyWorld logic checks if rotation is small before assigning translation labels + # to avoid ambiguity in TPS mode, but here we generally want to capture the dominant movement. + tps = False # Default assumption, can be made an arg if needed + + if move_norms > move_norm_valid: + if (not tps) or ( + tps and abs(rot_angles_deg[1]) < 5e-2 and abs(rot_angles_deg[0]) < 5e-2 + ): + # Z-axis (Forward/Back) + if trans_angles_deg[2] < 60: + trans_one_hot[i, 0] = 1 # forward + elif trans_angles_deg[2] > 120: + trans_one_hot[i, 1] = 1 # backward + + # X-axis (Right/Left) + if trans_angles_deg[0] < 60: + trans_one_hot[i, 2] = 1 # right + elif trans_angles_deg[0] > 120: + trans_one_hot[i, 3] = 1 # left + + # Determine rotation actions + # Y-axis (Yaw) + if rot_angles_deg[1] > 5e-2: + rotate_one_hot[i, 0] = 1 # right + elif rot_angles_deg[1] < -5e-2: + rotate_one_hot[i, 1] = 1 # left + + # X-axis (Pitch) + if rot_angles_deg[0] > 5e-2: + rotate_one_hot[i, 2] = 1 # up + elif rot_angles_deg[0] < -5e-2: + rotate_one_hot[i, 3] = 1 # down + + trans_one_hot = torch.tensor(trans_one_hot) + rotate_one_hot = torch.tensor(rotate_one_hot) + + # Convert to single labels + trans_label = one_hot_to_one_dimension(trans_one_hot) + rotate_label = one_hot_to_one_dimension(rotate_one_hot) + action_labels = trans_label * 9 + rotate_label + + return viewmats, intrinsics, action_labels + +if __name__ == "__main__": + print("Running comparison test between process_custom_actions and pose_to_input...") + + def test_process_custom_actions(pose_string: str, keyboard: torch.Tensor, mouse: torch.Tensor, latent_num: int): + # Run process_custom_actions + # Note: We need to pass float tensors + print("Running process_custom_actions...") + viewmats_1, intrinsics_1, labels_1 = process_custom_actions( + keyboard, mouse + ) + + print(f"Running pose_to_input with string: '{pose_string}'...") + viewmats_2, intrinsics_2, labels_2 = pose_to_input( + pose_string, latent_num=latent_num + ) + + # print(f"Viewmats: {viewmats_1} vs \n {viewmats_2}") + # print(f"Intrinsics: {intrinsics_1} vs \n {intrinsics_2}") + # print(f"Labels: {labels_1} vs \n {labels_2}") + # 3. Compare Results + print("\nComparison Results:") + + # Check Shapes + print(f"Shapes (Viewmats): {viewmats_1.shape} vs {viewmats_2.shape}") + assert viewmats_1.shape == viewmats_2.shape, "Shape mismatch for viewmats" + + # Check Values + # Viewmats + diff_viewmats = (viewmats_1 - viewmats_2).abs().max().item() + print(f"Max difference in Viewmats: {diff_viewmats}") + if diff_viewmats < 1e-5: + print("✅ Viewmats match.") + else: + print("❌ Viewmats mismatch.") + + # Check intrinsics + diff_intrinsics = (intrinsics_1 - intrinsics_2).abs().max().item() + print(f"Max difference in Intrinsics: {diff_intrinsics}") + if diff_intrinsics < 1e-5: + print("✅ Intrinsics match.") + else: + print("❌ Intrinsics mismatch.") + + # Check labels + diff_labels = (labels_1 - labels_2).abs().max().item() + print(f"Max difference in Labels: {diff_labels}") + if diff_labels < 1e-5: + print("✅ Labels match.") + else: + print("❌ Labels mismatch.") + + print("All checks passed.") + + # Define shared parameters + + latent_num = 13 + pose_string = "w-2, a-3, s-1, d-6" + + num_frames = 4 * (latent_num - 1) + 1 + keyboard = torch.zeros((num_frames, 6)) + mouse = torch.zeros((num_frames, 2)) + + # Frame 0 is ignored/start + # Frames 1-8: Press W (index 0) + keyboard[1:9, 0] = 1.0 + # Frames 9-20: Press A (index 2) + keyboard[9:21, 2] = 1.0 + # Frames 21-24: Press S (index 1) + keyboard[21:25, 1] = 1.0 + # Frames 25-48: Press D (index 3) + keyboard[25:49, 3] = 1.0 + + test_process_custom_actions(pose_string, keyboard, mouse, latent_num) + + # Test keyboard AND mouse + latent_num = 25 + pose_string = "w-2, up-2, a-3, down-4, s-1, left-2, d-6, right-4" + + num_frames = 4 * (latent_num - 1) + 1 + keyboard = torch.zeros((num_frames, 6)) + mouse = torch.zeros((num_frames, 2)) + + # Frame 0 is ignored/start + # Frames 1-8: Press W (index 0) + keyboard[1:9, 0] = 1.0 + # Frames 17-28: Press A (index 2) + keyboard[17:29, 2] = 1.0 + # Frames 45-48: Press S (index 1) + keyboard[45:49, 1] = 1.0 + # Frames 57-80: Press D (index 3) + keyboard[57:81, 3] = 1.0 + + # Frames 9-16: Press Up (index 4) + mouse[9:17, 0] = DEFAULT_PITCH_SPEED + # Frames 25-32: Press Down (index 5) + mouse[29:45, 0] = -DEFAULT_PITCH_SPEED + # Frames 41-48: Press Left (index 6) + mouse[49:57, 1] = -DEFAULT_YAW_SPEED + # Frames 57-64: Press Right (index 7) + mouse[81:, 1] = DEFAULT_YAW_SPEED + + test_process_custom_actions(pose_string, keyboard, mouse, latent_num) diff --git a/fastvideo/models/dits/matrixgame/utils.py b/fastvideo/models/dits/matrixgame/utils.py index 4dd937699..c7dfd6743 100644 --- a/fastvideo/models/dits/matrixgame/utils.py +++ b/fastvideo/models/dits/matrixgame/utils.py @@ -301,119 +301,238 @@ def parse_config(config, mode="universal"): # NOTE: drawing functions are commented out to avoid cv2/libGL dependency. # -# def draw_rounded_rectangle(image, top_left, bottom_right, color, radius=10, alpha=0.5): -# overlay = image.copy() -# x1, y1 = top_left -# x2, y2 = bottom_right -# -# cv2.rectangle(overlay, (x1 + radius, y1), (x2 - radius, y2), color, -1) -# cv2.rectangle(overlay, (x1, y1 + radius), (x2, y2 - radius), color, -1) -# cv2.ellipse(overlay, (x1 + radius, y1 + radius), (radius, radius), 180, 0, 90, color, -1) -# cv2.ellipse(overlay, (x2 - radius, y1 + radius), (radius, radius), 270, 0, 90, color, -1) -# cv2.ellipse(overlay, (x1 + radius, y2 - radius), (radius, radius), 90, 0, 90, color, -1) -# cv2.ellipse(overlay, (x2 - radius, y2 - radius), (radius, radius), 0, 0, 90, color, -1) -# cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) -# -# def draw_keys_on_frame(frame, keys, key_size=(80, 50), spacing=20, bottom_margin=30, mode='universal'): -# h, w, _ = frame.shape -# horison_shift = 90 -# vertical_shift = -20 -# horizon_shift_all = 50 -# key_positions = { -# "W": (w // 2 - key_size[0] // 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] * 2 + vertical_shift - 20), -# "A": (w // 2 - key_size[0] * 2 + 5 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# "S": (w // 2 - key_size[0] // 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# "D": (w // 2 + key_size[0] - 5 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# } -# key_icon = {"W": "W", "A": "A", "S": "S", "D": "D", "left": "left", "right": "right"} -# if mode == 'templerun': -# key_positions.update({ -# "left": (w // 2 + key_size[0] * 2 + spacing * 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# "right": (w // 2 + key_size[0] * 3 + spacing * 7 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift) -# }) -# -# for key, (x, y) in key_positions.items(): -# is_pressed = keys.get(key, False) -# top_left = (x, y) -# if key in ["left", "right"]: -# bottom_right = (x + key_size[0] + 40, y + key_size[1]) -# else: -# bottom_right = (x + key_size[0], y + key_size[1]) -# -# color = (0, 255, 0) if is_pressed else (200, 200, 200) -# alpha = 0.8 if is_pressed else 0.5 -# draw_rounded_rectangle(frame, top_left, bottom_right, color, radius=10, alpha=alpha) -# -# text_size = cv2.getTextSize(key, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0] -# if key in ["left", "right"]: -# text_x = x + (key_size[0] + 40 - text_size[0]) // 2 -# else: -# text_x = x + (key_size[0] - text_size[0]) // 2 -# text_y = y + (key_size[1] + text_size[1]) // 2 -# cv2.putText(frame, key_icon[key], (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) -# -# def overlay_icon(frame, icon, position, scale=1.0, rotation=0): -# x, y = position -# h, w, _ = icon.shape -# -# scaled_width = int(w * scale) -# scaled_height = int(h * scale) -# icon_resized = cv2.resize(icon, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA) -# -# center = (scaled_width // 2, scaled_height // 2) -# rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0) -# icon_rotated = cv2.warpAffine( -# icon_resized, rotation_matrix, (scaled_width, scaled_height), -# flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0) -# ) -# -# h, w, _ = icon_rotated.shape -# frame_h, frame_w, _ = frame.shape -# -# top_left_x = max(0, int(x - w // 2)) -# top_left_y = max(0, int(y - h // 2)) -# bottom_right_x = min(frame_w, int(x + w // 2)) -# bottom_right_y = min(frame_h, int(y + h // 2)) -# -# icon_x_start = max(0, int(-x + w // 2)) -# icon_y_start = max(0, int(-y + h // 2)) -# icon_x_end = icon_x_start + (bottom_right_x - top_left_x) -# icon_y_end = icon_y_start + (bottom_right_y - top_left_y) -# -# icon_region = icon_rotated[icon_y_start:icon_y_end, icon_x_start:icon_x_end] -# alpha = icon_region[:, :, 3] / 255.0 -# icon_rgb = icon_region[:, :, :3] -# -# frame_region = frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] -# for c in range(3): -# frame_region[:, :, c] = (1 - alpha) * frame_region[:, :, c] + alpha * icon_rgb[:, :, c] -# frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = frame_region -# -# def process_video(input_video, output_video, config, mouse_icon_path, -# mouse_scale=1.0, mouse_rotation=0, process_icon=True, mode='universal'): -# key_data, mouse_data = parse_config(config, mode=mode) -# fps = 12 -# -# mouse_icon = cv2.imread(mouse_icon_path, cv2.IMREAD_UNCHANGED) -# -# out_video = [] -# for frame_idx, frame in enumerate(input_video): -# frame = np.ascontiguousarray(frame) -# if process_icon: -# keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) -# draw_keys_on_frame(frame, keys, key_size=(50, 50), spacing=10, bottom_margin=20, mode=mode) -# if mode == 'universal': -# frame_width = frame.shape[1] -# frame_height = frame.shape[0] -# mouse_position = mouse_data.get(frame_idx, (frame_width // 2, frame_height // 2)) -# overlay_icon(frame, mouse_icon, mouse_position, scale=mouse_scale, rotation=mouse_rotation) -# out_video.append(frame / 255) -# -# export_to_video(out_video, output_video, fps=fps) -# logger.info(f"Video saved to {output_video}") +import cv2 +import numpy as np +from diffusers.utils import export_to_video + +def draw_rounded_rectangle(image, top_left, bottom_right, color, radius=10, alpha=0.5): + overlay = image.copy() + x1, y1 = top_left + x2, y2 = bottom_right + + cv2.rectangle(overlay, (x1 + radius, y1), (x2 - radius, y2), color, -1) + cv2.rectangle(overlay, (x1, y1 + radius), (x2, y2 - radius), color, -1) + cv2.ellipse(overlay, (x1 + radius, y1 + radius), (radius, radius), 180, 0, 90, color, -1) + cv2.ellipse(overlay, (x2 - radius, y1 + radius), (radius, radius), 270, 0, 90, color, -1) + cv2.ellipse(overlay, (x1 + radius, y2 - radius), (radius, radius), 90, 0, 90, color, -1) + cv2.ellipse(overlay, (x2 - radius, y2 - radius), (radius, radius), 0, 0, 90, color, -1) + cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) + +def draw_keys_on_frame(frame, keys, key_size=(30, 30), spacing=5, top_margin=15, mode='universal'): + """Draw WASD keys on the left top of the frame.""" + h, w, _ = frame.shape + + # Left top positioning + left_margin = 15 + gap = 3 # Gap between keys + + key_positions = { + "W": (left_margin + key_size[0] + gap, + top_margin), + "A": (left_margin, + top_margin + key_size[1] + gap), + "S": (left_margin + key_size[0] + gap, + top_margin + key_size[1] + gap), + "D": (left_margin + (key_size[0] + gap) * 2, + top_margin + key_size[1] + gap), + } + key_icon = {"W": "W", "A": "A", "S": "S", "D": "D", "left": "L", "right": "R"} + if mode == 'templerun': + key_positions.update({ + "left": (left_margin + (key_size[0] + gap) * 3 + 10, + top_margin + key_size[1] + gap), + "right": (left_margin + (key_size[0] + gap) * 4 + 15, + top_margin + key_size[1] + gap) + }) + + for key, (x, y) in key_positions.items(): + is_pressed = keys.get(key, False) + top_left = (x, y) + bottom_right = (x + key_size[0], y + key_size[1]) + + color = (0, 255, 0) if is_pressed else (200, 200, 200) + alpha = 0.8 if is_pressed else 0.5 + draw_rounded_rectangle(frame, top_left, bottom_right, color, radius=5, alpha=alpha) + + text_size = cv2.getTextSize(key_icon[key], cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + text_x = x + (key_size[0] - text_size[0]) // 2 + text_y = y + (key_size[1] + text_size[1]) // 2 + cv2.putText(frame, key_icon[key], (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) + +def overlay_icon(frame, icon, position, scale=1.0, rotation=0): + x, y = position + h, w, _ = icon.shape + + scaled_width = int(w * scale) + scaled_height = int(h * scale) + icon_resized = cv2.resize(icon, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA) + + center = (scaled_width // 2, scaled_height // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0) + icon_rotated = cv2.warpAffine( + icon_resized, rotation_matrix, (scaled_width, scaled_height), + flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0) + ) + + h, w, _ = icon_rotated.shape + frame_h, frame_w, _ = frame.shape + + top_left_x = max(0, int(x - w // 2)) + top_left_y = max(0, int(y - h // 2)) + bottom_right_x = min(frame_w, int(x + w // 2)) + bottom_right_y = min(frame_h, int(y + h // 2)) + + icon_x_start = max(0, int(-x + w // 2)) + icon_y_start = max(0, int(-y + h // 2)) + icon_x_end = icon_x_start + (bottom_right_x - top_left_x) + icon_y_end = icon_y_start + (bottom_right_y - top_left_y) + + icon_region = icon_rotated[icon_y_start:icon_y_end, icon_x_start:icon_x_end] + alpha = icon_region[:, :, 3] / 255.0 + icon_rgb = icon_region[:, :, :3] + + frame_region = frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] + for c in range(3): + frame_region[:, :, c] = (1 - alpha) * frame_region[:, :, c] + alpha * icon_rgb[:, :, c] + frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = frame_region + +def process_video(input_video, output_video, config, mouse_icon_path, + mouse_scale=1.0, mouse_rotation=0, process_icon=True, mode='universal'): + key_data, mouse_data = parse_config(config, mode=mode) + fps = 12 + + mouse_icon = cv2.imread(mouse_icon_path, cv2.IMREAD_UNCHANGED) + + out_video = [] + for frame_idx, frame in enumerate(input_video): + frame = np.ascontiguousarray(frame) + if process_icon: + keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) + draw_keys_on_frame(frame, keys, key_size=(50, 50), spacing=10, bottom_margin=20, mode=mode) + if mode == 'universal': + frame_width = frame.shape[1] + frame_height = frame.shape[0] + mouse_position = mouse_data.get(frame_idx, (frame_width // 2, frame_height // 2)) + overlay_icon(frame, mouse_icon, mouse_position, scale=mouse_scale, rotation=mouse_rotation) + out_video.append(frame / 255) + + export_to_video(out_video, output_video, fps=fps) + logger.info(f"Video saved to {output_video}") + + +def parse_npy_action(action_path): + """Convert npy action file to key_data and mouse_data dict format.""" + action_data = np.load(action_path, allow_pickle=True).item() + keyboard_data = action_data['keyboard'] # shape: (num_frames, 6) -> [W, S, A, D, left, right] + mouse_data = action_data.get('mouse', None) # shape: (num_frames, 2) -> [Pitch, Yaw] + + # MatrixGame convention: 0:W, 1:S, 2:A, 3:D, 4:left, 5:right + key_names = ["W", "S", "A", "D", "left", "right"] + key_data = {} + for frame_idx, keys in enumerate(keyboard_data): + key_data[frame_idx] = {key_names[i]: bool(keys[i]) for i in range(len(key_names))} + + # MatrixGame convention: mouse is [Pitch, Yaw] + mouse_dict = {} + if mouse_data is not None: + for frame_idx, (pitch, yaw) in enumerate(mouse_data): + mouse_dict[frame_idx] = {"pitch": float(pitch), "yaw": float(yaw)} + + return key_data, mouse_dict + + +def draw_mouse_on_frame(frame, pitch, yaw, top_margin=15): + """Draw crosshair with direction arrow on the right top of the frame.""" + h, w, _ = frame.shape + + # Right top positioning + right_margin = 15 + crosshair_radius = 25 + + # Position crosshair on the right top + crosshair_x = w - right_margin - crosshair_radius + crosshair_y = top_margin + crosshair_radius + + # Yaw affects horizontal direction, pitch affects vertical + dx = int(yaw * crosshair_radius * 8) # Scale for visibility + dy = int(-pitch * crosshair_radius * 8) # Negative because y increases downward + + # Clamp arrow length + max_arrow = crosshair_radius - 5 + dx = max(-max_arrow, min(max_arrow, dx)) + dy = max(-max_arrow, min(max_arrow, dy)) + + # Draw crosshair background + cv2.circle(frame, (crosshair_x, crosshair_y), crosshair_radius, (50, 50, 50), -1) + cv2.circle(frame, (crosshair_x, crosshair_y), crosshair_radius, (200, 200, 200), 1) + cv2.line(frame, (crosshair_x - crosshair_radius + 5, crosshair_y), + (crosshair_x + crosshair_radius - 5, crosshair_y), (100, 100, 100), 1) + cv2.line(frame, (crosshair_x, crosshair_y - crosshair_radius + 5), + (crosshair_x, crosshair_y + crosshair_radius - 5), (100, 100, 100), 1) + + # Draw direction arrow + if abs(dx) > 1 or abs(dy) > 1: + cv2.arrowedLine(frame, (crosshair_x, crosshair_y), (crosshair_x + dx, crosshair_y + dy), + (0, 255, 0), 2, tipLength=0.3) + + +def process_video_with_npy(input_video, output_video, action_path, fps=12, mode='universal'): + """Process video with overlay using npy action file. + + Uses existing draw_keys_on_frame function. + """ + key_data, mouse_data = parse_npy_action(action_path) + + out_video = [] + for frame_idx, frame in enumerate(input_video): + frame = np.ascontiguousarray(frame) + keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) + draw_keys_on_frame(frame, keys, mode=mode) + + # Draw pitch and yaw + mouse = mouse_data.get(frame_idx, {"pitch": 0.0, "yaw": 0.0}) + draw_mouse_on_frame(frame, mouse["pitch"], mouse["yaw"]) + + out_video.append(frame / 255.0) + + export_to_video(out_video, output_video, fps=fps) + logger.info(f"Video saved to {output_video}") + + +if __name__ == "__main__": + import argparse + import cv2 + + parser = argparse.ArgumentParser(description="Overlay keyboard actions on video") + parser.add_argument("--video", type=str, required=True, help="Path to input video (.mp4)") + parser.add_argument("--action", type=str, required=True, help="Path to action file (.npy)") + parser.add_argument("--output", type=str, default=None, help="Path to output video (default: input_with_overlay.mp4)") + parser.add_argument("--fps", type=int, default=12, help="Output video FPS") + args = parser.parse_args() + + # Load video frames using cv2 + cap = cv2.VideoCapture(args.video) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {args.video}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + + print(f"Loaded {len(frames)} frames from video") + + # Set output path + if args.output is None: + base_name = args.video.rsplit('.', 1)[0] + output_path = f"{base_name}_with_overlay.mp4" + else: + output_path = args.output + + # Process video with overlay using existing functions + process_video_with_npy(frames, output_path, args.action, fps=args.fps) + print(f"Video with overlay saved to: {output_path}") diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py index 6ee1b28c3..832d0d0d3 100644 --- a/fastvideo/models/loader/component_loader.py +++ b/fastvideo/models/loader/component_loader.py @@ -23,7 +23,6 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.layers.quantization import get_quantization_config from fastvideo.logger import init_logger -from fastvideo.models.encoders.base import TextEncoder from fastvideo.models.hf_transformer_utils import get_diffusers_config from fastvideo.models.loader.fsdp_load import maybe_load_fsdp_model, shard_model from fastvideo.models.loader.utils import set_default_torch_dtype @@ -268,11 +267,12 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): gemma_path = candidate gemma_path_from_candidate = True model_config["gemma_model_path"] = gemma_path - if gemma_path and not gemma_path_from_candidate: - if not os.path.isabs(gemma_path): - model_config["gemma_model_path"] = os.path.normpath( - os.path.join(repo_root, gemma_path) - ) + if gemma_path and not gemma_path_from_candidate and not os.path.isabs( + gemma_path + ): + model_config["gemma_model_path"] = os.path.normpath( + os.path.join(repo_root, gemma_path) + ) transformer_config_path = os.path.join( repo_root, "transformer", "config.json" ) @@ -280,12 +280,11 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): try: with open(transformer_config_path, encoding="utf-8") as f: transformer_config = json.load(f) - if ( + if (( "connector_double_precision_rope" not in model_config or not model_config["connector_double_precision_rope"] - ): - if transformer_config.get("double_precision_rope") is True: - model_config["connector_double_precision_rope"] = True + ) and transformer_config.get("double_precision_rope") is True): + model_config["connector_double_precision_rope"] = True if "connector_rope_type" not in model_config: rope_type = transformer_config.get("rope_type") if rope_type is not None: @@ -539,7 +538,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): tokenizer_cfg_path = os.path.join(resolved_model_path, "config.json") if os.path.exists(tokenizer_cfg_path): try: - with open(tokenizer_cfg_path, "r") as f: + with open(tokenizer_cfg_path) as f: tokenizer_cfg = json.load(f) if isinstance(tokenizer_cfg, dict) and ( tokenizer_cfg.get("_class_name") == "AutoProcessor" @@ -928,7 +927,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): try: upsampler_cfg = deepcopy(fastvideo_args.pipeline_config.upsampler_config[0]) upsampler_cfg.update_model_config(config_dict) - except Exception as e: + except Exception: upsampler_cfg = deepcopy(fastvideo_args.pipeline_config.upsampler_config[1]) upsampler_cfg.update_model_config(config_dict) diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py index 9ba60320a..d9a3b6150 100644 --- a/fastvideo/models/loader/fsdp_load.py +++ b/fastvideo/models/loader/fsdp_load.py @@ -138,7 +138,7 @@ def maybe_load_fsdp_model( weight_iterator = safetensors_weights_iterator(weight_dir_list) param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) - load_model_from_full_model_state_dict( + incompatible_keys, unexpected_keys = load_model_from_full_model_state_dict( model, weight_iterator, device, @@ -147,6 +147,9 @@ def maybe_load_fsdp_model( cpu_offload=cpu_offload, param_names_mapping=param_names_mapping_fn, ) + if incompatible_keys or unexpected_keys: + logger.warning("Incompatible keys: %s", incompatible_keys) + logger.warning("Unexpected keys: %s", unexpected_keys) for n, p in chain(model.named_parameters(), model.named_buffers()): if p.is_meta: raise RuntimeError( @@ -339,8 +342,19 @@ def load_model_from_full_model_state_dict( logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys) - # List of allowed parameter name patterns - ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress", "proj_l"] # Can be extended as needed + # List of allowed parameter name patterns (whitelist for new params not in checkpoint) + ALLOWED_NEW_PARAM_PATTERNS = [ + "gate_compress", + "proj_l", + "to_out_prope", + "action_embedder", + "patch_embedding_wancamctrl", + "cam_conditioner", + ] # Can be extended as needed + + # Patterns for params that need kaiming_uniform init (input projections need non-zero for gradient flow) + KAIMING_INIT_PATTERNS = ["fc_in.weight"] + for new_param_name in unused_keys: if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS): @@ -350,17 +364,31 @@ def load_model_from_full_model_state_dict( f"New parameter '{new_param_name}' is not supported. " f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed." ) + + # Check if this param needs kaiming init (non-zero) for gradient flow + use_kaiming = any(pattern in new_param_name for pattern in KAIMING_INIT_PATTERNS) + meta_sharded_param = meta_sd.get(new_param_name) if not hasattr(meta_sharded_param, "device_mesh"): - # Initialize with zeros - sharded_tensor = torch.zeros_like(meta_sharded_param, - device=device, - dtype=param_dtype) + # Non-sharded tensor + if use_kaiming: + import math + sharded_tensor = torch.empty_like(meta_sharded_param, device=device, dtype=param_dtype) + nn.init.kaiming_uniform_(sharded_tensor, a=math.sqrt(5)) + logger.info(f"Initialized {new_param_name} with kaiming_uniform_") + else: + # Initialize with zeros (output projections for residual behavior) + sharded_tensor = torch.zeros_like(meta_sharded_param, device=device, dtype=param_dtype) else: - # Initialize with zeros and distribute - full_tensor = torch.zeros_like(meta_sharded_param, - device=device, - dtype=param_dtype) + # Sharded tensor (DTensor) + if use_kaiming: + import math + full_tensor = torch.empty_like(meta_sharded_param, device=device, dtype=param_dtype) + nn.init.kaiming_uniform_(full_tensor, a=math.sqrt(5)) + logger.info(f"Initialized {new_param_name} with kaiming_uniform_") + else: + # Initialize with zeros and distribute + full_tensor = torch.zeros_like(meta_sharded_param, device=device, dtype=param_dtype) sharded_tensor = distribute_tensor( full_tensor, meta_sharded_param.device_mesh, diff --git a/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py b/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py index 7cffcb0f3..7a5106126 100644 --- a/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py @@ -1,74 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 """ -Wan video diffusion pipeline implementation. +Legacy Wan DMD pipeline entrypoint. -This module contains an implementation of the Wan video diffusion pipeline -using the modular pipeline architecture. +Historically FastVideo exposed a dedicated `WanDMDPipeline` class that wired a +stochastic (SDE-style) denoising loop. Phase 3.2 makes sampling loop selection +explicit via `pipeline_config.sampler_kind`, so this file becomes a thin +compatibility wrapper around `WanPipeline`. """ from fastvideo.fastvideo_args import FastVideoArgs -from fastvideo.logger import init_logger -from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( - FlowMatchEulerDiscreteScheduler) -from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline +from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline -# isort: off -from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, - DmdDenoisingStage, InputValidationStage, - LatentPreparationStage, - TextEncodingStage, - TimestepPreparationStage) -# isort: on -logger = init_logger(__name__) - - -class WanDMDPipeline(LoRAPipeline, ComposedPipelineBase): - """ - Wan video diffusion pipeline with LoRA support. - """ - - _required_config_modules = [ - "text_encoder", "tokenizer", "vae", "transformer", "scheduler" - ] +class WanDMDPipeline(WanPipeline): + """Compatibility wrapper for SDE sampling on Wan.""" def initialize_pipeline(self, fastvideo_args: FastVideoArgs): - - self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( - shift=fastvideo_args.pipeline_config.flow_shift) + fastvideo_args.pipeline_config.sampler_kind = "sde" + return super().initialize_pipeline(fastvideo_args) def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: - """Set up pipeline stages with proper dependency injection.""" - - self.add_stage(stage_name="input_validation_stage", - stage=InputValidationStage()) - - self.add_stage(stage_name="prompt_encoding_stage", - stage=TextEncodingStage( - text_encoders=[self.get_module("text_encoder")], - tokenizers=[self.get_module("tokenizer")], - )) - - self.add_stage(stage_name="conditioning_stage", - stage=ConditioningStage()) - - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) - - self.add_stage(stage_name="latent_preparation_stage", - stage=LatentPreparationStage( - scheduler=self.get_module("scheduler"), - transformer=self.get_module("transformer", None), - use_btchw_layout=True)) - - self.add_stage(stage_name="denoising_stage", - stage=DmdDenoisingStage( - transformer=self.get_module("transformer"), - scheduler=self.get_module("scheduler"))) - - self.add_stage(stage_name="decoding_stage", - stage=DecodingStage(vae=self.get_module("vae"))) + fastvideo_args.pipeline_config.sampler_kind = "sde" + return super().create_pipeline_stages(fastvideo_args) EntryClass = WanDMDPipeline diff --git a/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py b/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py index ed4d870c6..dd3ff1538 100644 --- a/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py @@ -12,10 +12,12 @@ from fastvideo.pipelines.lora_pipeline import LoRAPipeline # isort: off -from fastvideo.pipelines.stages import ( - ImageEncodingStage, ConditioningStage, DecodingStage, DmdDenoisingStage, - ImageVAEEncodingStage, InputValidationStage, LatentPreparationStage, - TextEncodingStage, TimestepPreparationStage) +from fastvideo.pipelines.stages import (ImageEncodingStage, ConditioningStage, + DecodingStage, DmdDenoisingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage) # isort: on from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( FlowMatchEulerDiscreteScheduler) @@ -55,10 +57,6 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) - self.add_stage(stage_name="latent_preparation_stage", stage=LatentPreparationStage( scheduler=self.get_module("scheduler"), diff --git a/fastvideo/pipelines/basic/wan/wan_pipeline.py b/fastvideo/pipelines/basic/wan/wan_pipeline.py index 64c4a0685..78c2e4dad 100644 --- a/fastvideo/pipelines/basic/wan/wan_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_pipeline.py @@ -8,13 +8,16 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.logger import init_logger -from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( - FlowUniPCMultistepScheduler) +from fastvideo.pipelines.samplers.wan import ( + build_wan_scheduler, + get_wan_sampler_kind, + wan_use_btchw_layout, +) from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage, LatentPreparationStage, - TextEncodingStage, + SdeDenoisingStage, TextEncodingStage, TimestepPreparationStage) logger = init_logger(__name__) @@ -30,12 +33,14 @@ class WanPipeline(LoRAPipeline, ComposedPipelineBase): ] def initialize_pipeline(self, fastvideo_args: FastVideoArgs): - # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers. - self.modules["scheduler"] = FlowUniPCMultistepScheduler( - shift=fastvideo_args.pipeline_config.flow_shift) + sampler_kind = get_wan_sampler_kind(fastvideo_args) + self.modules["scheduler"] = build_wan_scheduler(fastvideo_args, + sampler_kind) def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: """Set up pipeline stages with proper dependency injection.""" + sampler_kind = get_wan_sampler_kind(fastvideo_args) + use_btchw_layout = wan_use_btchw_layout(sampler_kind) self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage()) @@ -49,22 +54,32 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) + if sampler_kind == "ode": + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) self.add_stage(stage_name="latent_preparation_stage", stage=LatentPreparationStage( scheduler=self.get_module("scheduler"), - transformer=self.get_module("transformer", None))) - - self.add_stage(stage_name="denoising_stage", - stage=DenoisingStage( - transformer=self.get_module("transformer"), - transformer_2=self.get_module("transformer_2", None), - scheduler=self.get_module("scheduler"), - vae=self.get_module("vae"), - pipeline=self)) + transformer=self.get_module("transformer", None), + use_btchw_layout=use_btchw_layout)) + + if sampler_kind == "sde": + self.add_stage(stage_name="denoising_stage", + stage=SdeDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + )) + else: + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module( + "transformer_2", None), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"), + pipeline=self)) self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), diff --git a/fastvideo/pipelines/pipeline_batch_info.py b/fastvideo/pipelines/pipeline_batch_info.py index 768816cfa..9d433da4d 100644 --- a/fastvideo/pipelines/pipeline_batch_info.py +++ b/fastvideo/pipelines/pipeline_batch_info.py @@ -161,6 +161,10 @@ class ForwardBatch: # Timesteps timesteps: torch.Tensor | None = None + # Optional explicit denoising-loop timesteps (sampler-specific). + # When set, some samplers (e.g. SDE-style rollout) will iterate this list + # instead of `timesteps` produced by `TimestepPreparationStage`. + sampling_timesteps: torch.Tensor | None = None timestep: torch.Tensor | float | int | None = None step_index: int | None = None boundary_ratio: float | None = None diff --git a/fastvideo/pipelines/preprocess/v1_preprocess.py b/fastvideo/pipelines/preprocess/v1_preprocess.py index 18455d70f..6e734877a 100644 --- a/fastvideo/pipelines/preprocess/v1_preprocess.py +++ b/fastvideo/pipelines/preprocess/v1_preprocess.py @@ -67,7 +67,8 @@ def main(args) -> None: else: raise ValueError( f"Invalid preprocess task: {args.preprocess_task}. " - f"Valid options: t2v, i2v, ode_trajectory, text_only, matrixgame") + f"Valid options: t2v, i2v, ode_trajectory, text_only, matrixgame" + ) logger.info("Preprocess task: %s using %s", args.preprocess_task, PreprocessPipeline.__name__) @@ -111,12 +112,14 @@ def main(args) -> None: parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO parser.add_argument("--flow_shift", type=float, default=None) - parser.add_argument( - "--preprocess_task", - type=str, - default="t2v", - choices=["t2v", "i2v", "text_only", "ode_trajectory", "matrixgame"], - help="Type of preprocessing task to run") + parser.add_argument("--preprocess_task", + type=str, + default="t2v", + choices=[ + "t2v", "i2v", "text_only", "ode_trajectory", + "matrixgame" + ], + help="Type of preprocessing task to run") parser.add_argument("--train_fps", type=int, default=30) parser.add_argument("--use_image_num", type=int, default=0) parser.add_argument("--text_max_length", type=int, default=256) diff --git a/fastvideo/pipelines/samplers/__init__.py b/fastvideo/pipelines/samplers/__init__.py new file mode 100644 index 000000000..638a9e532 --- /dev/null +++ b/fastvideo/pipelines/samplers/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.pipelines.samplers.base import SamplerKind + +__all__ = [ + "SamplerKind", +] diff --git a/fastvideo/pipelines/samplers/base.py b/fastvideo/pipelines/samplers/base.py new file mode 100644 index 000000000..da4ef502c --- /dev/null +++ b/fastvideo/pipelines/samplers/base.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Literal + +SamplerKind = Literal["ode", "sde"] + + +def normalize_sampler_kind( + raw: str | None, + *, + where: str, + default: SamplerKind = "ode", +) -> SamplerKind: + if raw is None: + return default + + kind = str(raw).strip().lower() + if kind == "ode": + return "ode" + if kind == "sde": + return "sde" + + raise ValueError( + f"Unknown sampler kind at {where}: {raw!r} (expected ode|sde)") diff --git a/fastvideo/pipelines/samplers/wan.py b/fastvideo/pipelines/samplers/wan.py new file mode 100644 index 000000000..22fd9d3f3 --- /dev/null +++ b/fastvideo/pipelines/samplers/wan.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, ) +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, ) +from fastvideo.pipelines.samplers.base import SamplerKind, normalize_sampler_kind + + +def get_wan_sampler_kind(fastvideo_args: FastVideoArgs) -> SamplerKind: + raw = getattr(fastvideo_args.pipeline_config, "sampler_kind", None) + return normalize_sampler_kind(raw, where="pipeline_config.sampler_kind") + + +def build_wan_scheduler(fastvideo_args: FastVideoArgs, kind: SamplerKind): + shift = fastvideo_args.pipeline_config.flow_shift + if kind == "sde": + return FlowMatchEulerDiscreteScheduler(shift=shift) + + ode_solver_raw = getattr(fastvideo_args.pipeline_config, "ode_solver", + "unipc") + ode_solver = str(ode_solver_raw).strip().lower( + ) if ode_solver_raw is not None else "unipc" + if ode_solver in {"unipc", "unipc_multistep", "multistep"}: + return FlowUniPCMultistepScheduler(shift=shift) + if ode_solver in {"euler", "flowmatch", "flowmatch_euler"}: + return FlowMatchEulerDiscreteScheduler(shift=shift) + + raise ValueError("Unknown pipeline_config.ode_solver for wan pipelines: " + f"{ode_solver_raw!r} (expected 'unipc' or 'euler').") + + +def wan_use_btchw_layout(kind: SamplerKind) -> bool: + return kind == "sde" diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index 9896539c0..e59bd2a4d 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -13,7 +13,7 @@ from fastvideo.pipelines.stages.denoising import ( Cosmos25AutoDenoisingStage, Cosmos25DenoisingStage, Cosmos25V2WDenoisingStage, Cosmos25T2WDenoisingStage, CosmosDenoisingStage, - DenoisingStage, DmdDenoisingStage) + DenoisingStage, DmdDenoisingStage, SdeDenoisingStage) from fastvideo.pipelines.stages.sr_denoising import SRDenoisingStage from fastvideo.pipelines.stages.encoding import EncodingStage from fastvideo.pipelines.stages.image_encoding import ( @@ -33,7 +33,9 @@ LTX2LatentPreparationStage) from fastvideo.pipelines.stages.ltx2_text_encoding import LTX2TextEncodingStage from fastvideo.pipelines.stages.matrixgame_denoising import ( - MatrixGameCausalDenoisingStage) + MatrixGameCausalDenoisingStage, + MatrixGameCausalOdeDenoisingStage, +) from fastvideo.pipelines.stages.hyworld_denoising import HYWorldDenoisingStage from fastvideo.pipelines.stages.gamecraft_denoising import GameCraftDenoisingStage from fastvideo.pipelines.stages.text_encoding import (Cosmos25TextEncodingStage, @@ -61,9 +63,11 @@ "LTX2AudioDecodingStage", "ConditioningStage", "DenoisingStage", + "SdeDenoisingStage", "DmdDenoisingStage", "CausalDMDDenosingStage", "MatrixGameCausalDenoisingStage", + "MatrixGameCausalOdeDenoisingStage", "HYWorldDenoisingStage", "GameCraftDenoisingStage", "CosmosDenoisingStage", diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 1b11478ce..49c40a7d5 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -17,8 +17,6 @@ from fastvideo.forward_context import set_forward_context from fastvideo.logger import init_logger from fastvideo.models.loader.component_loader import TransformerLoader -from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( - FlowMatchEulerDiscreteScheduler) from fastvideo.models.utils import pred_noise_to_pred_video from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage @@ -159,6 +157,35 @@ def forward( }, ) + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + viewmats, intrinsics, action_labels = process_custom_actions( + batch.keyboard_cond, batch.mouse_cond) + camera_action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "viewmats": + viewmats.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "Ks": + intrinsics.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "action": + action_labels.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + }, + ) + # Legacy action-conditioning helper removed. + # num_frames = batch.num_frames + # latent_height = batch.height // 8 + # latent_width = batch.width // 8 + # c2ws_plucker_emb = process_lingbot_actions( + # num_frames, batch.keyboard_cond, batch.mouse_cond, + # latent_height=latent_height, latent_width=latent_width + # ).to(get_local_torch_device(), dtype=target_dtype) + else: + camera_action_kwargs = {} + action_kwargs = self.prepare_extra_func_kwargs( self.transformer.forward, { @@ -419,8 +446,8 @@ def forward( **image_kwargs, **pos_cond_kwargs, **action_kwargs, - **camera_kwargs, **timesteps_r_kwarg, + **camera_action_kwargs, ) if batch.do_classifier_free_guidance: @@ -438,8 +465,7 @@ def forward( **image_kwargs, **neg_cond_kwargs, **action_kwargs, - **camera_kwargs, - **timesteps_r_kwarg, + **camera_action_kwargs, ) noise_pred_text = noise_pred @@ -527,10 +553,19 @@ def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]: Returns: The prepared kwargs. """ - extra_step_kwargs = {} + signature = inspect.signature(func) + if any(p.kind == inspect.Parameter.VAR_KEYWORD + for p in signature.parameters.values()): + # If the callee accepts `**kwargs`, do not filter by signature. + # This is important for models that route parameters internally via + # `forward(*args, **kwargs)` (e.g. causal Wangame), where filtering + # would incorrectly drop conditioning kwargs like `action`. + return dict(kwargs) + + accepted = set(signature.parameters.keys()) + extra_step_kwargs: dict[str, Any] = {} for k, v in kwargs.items(): - accepts = k in set(inspect.signature(func).parameters.keys()) - if accepts: + if k in accepted: extra_step_kwargs[k] = v return extra_step_kwargs @@ -1171,14 +1206,16 @@ def verify_output(self, batch: ForwardBatch, return self._t2w.verify_output(batch, fastvideo_args) -class DmdDenoisingStage(DenoisingStage): - """ - Denoising stage for DMD. +class SdeDenoisingStage(DenoisingStage): + """Denoising stage for SDE-style sampling. + + This stage runs a stochastic rollout loop: + - predict x0 at timestep t + - inject fresh noise to reach the next timestep """ def __init__(self, transformer, scheduler) -> None: super().__init__(transformer, scheduler) - self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0) def forward( self, @@ -1202,16 +1239,6 @@ def forward( autocast_enabled = (target_dtype != torch.float32 ) and not fastvideo_args.disable_autocast - # Get timesteps and calculate warmup steps - timesteps = batch.timesteps - - # TODO(will): remove this once we add input/output validation for stages - if timesteps is None: - raise ValueError("Timesteps must be provided") - num_inference_steps = batch.num_inference_steps - num_warmup_steps = len( - timesteps) - num_inference_steps * self.scheduler.order - # Prepare image latents and embeddings for I2V generation image_embeds = batch.image_embeds if len(image_embeds) > 0: @@ -1237,6 +1264,37 @@ def forward( }, ) + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats, intrinsics, action_labels = process_custom_actions( + batch.keyboard_cond, batch.mouse_cond) + camera_action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "viewmats": + viewmats.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "Ks": + intrinsics.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "action": + action_labels.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + }, + ) + else: + camera_action_kwargs = {} + + action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "mouse_cond": batch.mouse_cond, + "keyboard_cond": batch.keyboard_cond, + "c2ws_plucker_emb": batch.c2ws_plucker_emb, + }, + ) + # Get latents and embeddings assert batch.latents is not None, "latents must be provided" latents = batch.latents @@ -1245,14 +1303,29 @@ def forward( prompt_embeds = batch.prompt_embeds assert not torch.isnan( prompt_embeds[0]).any(), "prompt_embeds contains nan" - timesteps = torch.tensor( - fastvideo_args.pipeline_config.dmd_denoising_steps, - dtype=torch.long, - device=get_local_torch_device()) + loop_timesteps = batch.sampling_timesteps + if loop_timesteps is None: + legacy = getattr(fastvideo_args.pipeline_config, + "dmd_denoising_steps", None) + if legacy is not None: + loop_timesteps = torch.tensor(legacy, dtype=torch.long) + else: + loop_timesteps = batch.timesteps + + if loop_timesteps is None: + raise ValueError( + "SDE sampling requires `batch.sampling_timesteps` (preferred) " + "or `pipeline_config.dmd_denoising_steps`.") + if not isinstance(loop_timesteps, torch.Tensor): + loop_timesteps = torch.tensor(loop_timesteps, dtype=torch.long) + if loop_timesteps.ndim != 1: + raise ValueError("Expected 1D `sampling_timesteps`, got shape " + f"{tuple(loop_timesteps.shape)}") + loop_timesteps = loop_timesteps.to(get_local_torch_device()) # Run denoising loop - with self.progress_bar(total=len(timesteps)) as progress_bar: - for i, t in enumerate(timesteps): + with self.progress_bar(total=len(loop_timesteps)) as progress_bar: + for i, t in enumerate(loop_timesteps): # Skip if interrupted if hasattr(self, 'interrupt') and self.interrupt: continue @@ -1326,6 +1399,8 @@ def forward( guidance=guidance_expand, **image_kwargs, **pos_cond_kwargs, + **action_kwargs, + **camera_action_kwargs, ).permute(0, 2, 1, 3, 4) pred_video = pred_noise_to_pred_video( @@ -1335,13 +1410,15 @@ def forward( scheduler=self.scheduler).unflatten( 0, pred_noise.shape[:2]) - if i < len(timesteps) - 1: - next_timestep = timesteps[i + 1] * torch.ones( + if i < len(loop_timesteps) - 1: + next_timestep = loop_timesteps[i + 1] * torch.ones( [1], dtype=torch.long, device=pred_video.device) - noise = torch.randn(video_raw_latent_shape, - dtype=pred_video.dtype, - generator=batch.generator[0]).to( - self.device) + noise = torch.randn( + video_raw_latent_shape, + dtype=pred_video.dtype, + generator=batch.generator[0] if isinstance( + batch.generator, list) else batch.generator).to( + self.device) latents = self.scheduler.add_noise( pred_video.flatten(0, 1), noise.flatten(0, 1), next_timestep).unflatten(0, pred_video.shape[:2]) @@ -1349,11 +1426,7 @@ def forward( latents = pred_video # Update progress bar - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and - (i + 1) % self.scheduler.order == 0 - and progress_bar is not None): - progress_bar.update() + progress_bar.update() # Gather results if using sequence parallelism latents = latents.permute(0, 2, 1, 3, 4) @@ -1361,3 +1434,7 @@ def forward( batch.latents = latents return batch + + +# Backwards-compatible alias (legacy pipelines still import this symbol). +DmdDenoisingStage = SdeDenoisingStage diff --git a/fastvideo/pipelines/stages/matrixgame_denoising.py b/fastvideo/pipelines/stages/matrixgame_denoising.py index f6373fc79..5ae7bf7b7 100644 --- a/fastvideo/pipelines/stages/matrixgame_denoising.py +++ b/fastvideo/pipelines/stages/matrixgame_denoising.py @@ -54,6 +54,9 @@ class BlockProcessingContext: image_kwargs: dict[str, Any] pos_cond_kwargs: dict[str, Any] + viewmats_full: torch.Tensor | None = None + intrinsics_full: torch.Tensor | None = None + action_full: torch.Tensor | None = None def get_kv_cache(self, timestep_val: float) -> list[dict[Any, Any]]: if self.boundary_timestep is not None: @@ -97,10 +100,12 @@ def __init__(self, -1) except Exception: self.local_attn_size = -1 + try: + self.local_attn_size = getattr(self.transformer.model, + "local_attn_size", -1) + except Exception: + self.local_attn_size = -1 - assert self.local_attn_size != -1, ( - f"local_attn_size must be set for Matrix-Game causal inference, " - f"got {self.local_attn_size}. Check MatrixGameWanVideoArchConfig.") assert self.num_frame_per_block > 0, ( f"num_frame_per_block must be positive, got {self.num_frame_per_block}" ) @@ -126,7 +131,10 @@ def forward( ) and not fastvideo_args.disable_autocast latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] - patch_size = self.transformer.patch_size + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size patch_ratio = patch_size[-1] * patch_size[-2] self.frame_seq_length = latent_seq_length // patch_ratio @@ -166,6 +174,31 @@ def forward( prompt_embeds = batch.prompt_embeds assert torch.isnan(prompt_embeds[0]).sum() == 0 + viewmats_full = None + intrinsics_full = None + action_full = None + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats_list = [] + intrinsics_list = [] + action_list = [] + for bi in range(b): + vm, ks, action = process_custom_actions(batch.keyboard_cond[bi], + batch.mouse_cond[bi]) + viewmats_list.append(vm) + intrinsics_list.append(ks) + action_list.append(action) + viewmats_full = torch.stack(viewmats_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + intrinsics_full = torch.stack(intrinsics_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + action_full = torch.stack(action_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + kv_cache1 = self._initialize_kv_cache(batch_size=latents.shape[0], dtype=target_dtype, device=latents.device) @@ -225,6 +258,9 @@ def forward( "context_noise", 0), image_kwargs=image_kwargs, pos_cond_kwargs=pos_cond_kwargs, + viewmats_full=viewmats_full, + intrinsics_full=intrinsics_full, + action_full=action_full, ) context_noise = getattr(fastvideo_args.pipeline_config, "context_noise", @@ -240,6 +276,8 @@ def forward( action_kwargs = self._prepare_action_kwargs( batch, start_index, current_num_frames) + camera_action_kwargs = self._prepare_camera_action_kwargs( + ctx, start_index, current_num_frames) current_latents = self._process_single_block( current_latents=current_latents, @@ -249,6 +287,7 @@ def forward( timesteps=timesteps, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, progress_bar=progress_bar, ) @@ -263,6 +302,7 @@ def forward( current_num_frames=current_num_frames, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, context_noise=context_noise, ) @@ -324,9 +364,9 @@ def _initialize_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), }) return kv_cache @@ -362,9 +402,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), }) kv_cache_mouse.append({ "k": @@ -382,9 +422,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - 0, + torch.zeros((), dtype=torch.long, device=device), }) return kv_cache_mouse, kv_cache_keyboard @@ -418,6 +458,18 @@ def _initialize_crossattn_cache(self, batch_size: int, max_text_len: int, }) return crossattn_cache + def _prepare_camera_action_kwargs( + self, ctx: BlockProcessingContext, start_index: int, + current_num_frames: int) -> dict[str, Any]: + if ctx.action_full is None or ctx.viewmats_full is None or ctx.intrinsics_full is None: + return {} + end_index = start_index + current_num_frames + return { + "viewmats": ctx.viewmats_full[:, start_index:end_index], + "Ks": ctx.intrinsics_full[:, start_index:end_index], + "action": ctx.action_full[:, start_index:end_index], + } + def _process_single_block( self, current_latents: torch.Tensor, @@ -427,6 +479,7 @@ def _process_single_block( timesteps: torch.Tensor, ctx: BlockProcessingContext, action_kwargs: dict[str, Any], + camera_action_kwargs: dict[str, Any], noise_generator: Callable[[tuple, torch.dtype, int], torch.Tensor] | None = None, progress_bar: Any | None = None, @@ -445,7 +498,16 @@ def _process_single_block( independent_first_frame = getattr(self.transformer, 'independent_first_frame', False) - if batch.image_latent is not None and independent_first_frame and start_index == 0: + if batch.image_latent is not None and not independent_first_frame: + image_latent_chunk = batch.image_latent[:, :, start_index: + start_index + + current_num_frames, :, :] + latent_model_input = torch.cat([ + latent_model_input, + image_latent_chunk.to(ctx.target_dtype) + ], + dim=1) + elif batch.image_latent is not None and independent_first_frame and start_index == 0: latent_model_input = torch.cat([ latent_model_input, batch.image_latent.to(ctx.target_dtype) @@ -495,6 +557,7 @@ def _process_single_block( "crossattn_cache": ctx.crossattn_cache, "current_start": start_index * self.frame_seq_length, "start_frame": start_index, + "is_cache": False, } if self.use_action_module and current_model == self.transformer: @@ -510,6 +573,7 @@ def _process_single_block( latent_model_input, prompt_embeds, t_expanded_noise, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, **model_kwargs, @@ -582,6 +646,7 @@ def _update_context_cache( current_num_frames: int, ctx: BlockProcessingContext, action_kwargs: dict[str, Any], + camera_action_kwargs: dict[str, Any], context_noise: float, ) -> None: prompt_embeds = batch.prompt_embeds @@ -592,6 +657,17 @@ def _update_context_cache( device=latents_device, dtype=torch.long) * int(context_noise) context_bcthw = current_latents.to(ctx.target_dtype) + context_input = context_bcthw + independent_first_frame = getattr(self.transformer, + "independent_first_frame", False) + if batch.image_latent is not None and not independent_first_frame: + image_context_chunk = batch.image_latent[:, :, + start_index:start_index + + current_num_frames, :, :] + context_input = torch.cat( + [context_input, + image_context_chunk.to(ctx.target_dtype)], + dim=1) with torch.autocast(device_type="cuda", dtype=ctx.target_dtype, @@ -605,6 +681,7 @@ def _update_context_cache( "crossattn_cache": ctx.crossattn_cache, "current_start": start_index * self.frame_seq_length, "start_frame": start_index, + "is_cache": True, } if self.use_action_module: @@ -617,26 +694,398 @@ def _update_context_cache( context_model_kwargs.update(action_kwargs) if ctx.boundary_timestep is not None and self.transformer_2 is not None: - self.transformer_2( - context_bcthw, + cache_update_ret_2 = self.transformer_2( + context_input, prompt_embeds, t_context, kv_cache=ctx.kv_cache2, crossattn_cache=ctx.crossattn_cache, current_start=start_index * self.frame_seq_length, start_frame=start_index, + is_cache=True, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, ) + if isinstance(cache_update_ret_2, + list) and len(cache_update_ret_2) > 0: + ctx.kv_cache2 = cache_update_ret_2 - self.transformer( - context_bcthw, + cache_update_ret = self.transformer( + context_input, prompt_embeds, t_context, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, **context_model_kwargs, ) + if isinstance(cache_update_ret, list) and len(cache_update_ret) > 0: + ctx.kv_cache1 = cache_update_ret + + +class MatrixGameCausalOdeDenoisingStage(MatrixGameCausalDenoisingStage): + """Causal ODE denoising for MatrixGame. + + This is the deterministic counterpart of `MatrixGameCausalDenoisingStage`. + It performs block-by-block causal rollout, but uses the scheduler's ODE-style + `step()` update (no re-noising between steps). + """ + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + timesteps = batch.timesteps + if timesteps is None: + raise ValueError( + "MatrixGameCausalOdeDenoisingStage requires batch.timesteps. " + "Make sure TimestepPreparationStage runs before this stage.") + + target_dtype = torch.bfloat16 + autocast_enabled = (target_dtype != torch.float32 + ) and not fastvideo_args.disable_autocast + + latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size + patch_ratio = patch_size[-1] * patch_size[-2] + self.frame_seq_length = latent_seq_length // patch_ratio + + timesteps = timesteps.to(get_local_torch_device()) + + boundary_ratio = getattr(fastvideo_args.pipeline_config.dit_config, + "boundary_ratio", None) + if boundary_ratio is not None: + boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps + else: + boundary_timestep = None + + image_embeds = batch.image_embeds + if len(image_embeds) > 0: + assert torch.isnan(image_embeds[0]).sum() == 0 + image_embeds = [ + image_embed.to(target_dtype) for image_embed in image_embeds + ] + + # directly set the kwarg. + image_kwargs = {"encoder_hidden_states_image": image_embeds} + pos_cond_kwargs: dict[str, Any] = {} + + assert batch.latents is not None, "latents must be provided" + latents = batch.latents + b, c, t, h, w = latents.shape + + prompt_embeds = batch.prompt_embeds + assert torch.isnan(prompt_embeds[0]).sum() == 0 + + viewmats_full = None + intrinsics_full = None + action_full = None + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats_list = [] + intrinsics_list = [] + action_list = [] + for bi in range(b): + vm, ks, action = process_custom_actions(batch.keyboard_cond[bi], + batch.mouse_cond[bi]) + viewmats_list.append(vm) + intrinsics_list.append(ks) + action_list.append(action) + viewmats_full = torch.stack(viewmats_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + intrinsics_full = torch.stack(intrinsics_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + action_full = torch.stack(action_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + + kv_cache1 = self._initialize_kv_cache(batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + kv_cache2 = None + if boundary_timestep is not None: + kv_cache2 = self._initialize_kv_cache(batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + + kv_cache_mouse = None + kv_cache_keyboard = None + if self.use_action_module: + kv_cache_mouse, kv_cache_keyboard = self._initialize_action_kv_cache( + batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + + crossattn_cache = self._initialize_crossattn_cache( + batch_size=latents.shape[0], + max_text_len=257, # 1 CLS + 256 patch tokens + dtype=target_dtype, + device=latents.device) + + if t % self.num_frame_per_block != 0: + raise ValueError( + "num_frames must be divisible by num_frame_per_block for causal denoising" + ) + num_blocks = t // self.num_frame_per_block + block_sizes = [self.num_frame_per_block] * num_blocks + start_index = 0 + + if boundary_timestep is not None: + block_sizes[0] = 1 + + ctx = BlockProcessingContext( + batch=batch, + block_idx=0, + start_index=0, + kv_cache1=kv_cache1, + kv_cache2=kv_cache2, + kv_cache_mouse=kv_cache_mouse, + kv_cache_keyboard=kv_cache_keyboard, + crossattn_cache=crossattn_cache, + timesteps=timesteps, + block_sizes=block_sizes, + noise_pool=None, + fastvideo_args=fastvideo_args, + target_dtype=target_dtype, + autocast_enabled=autocast_enabled, + boundary_timestep=boundary_timestep, + high_noise_timesteps=None, + context_noise=getattr(fastvideo_args.pipeline_config, + "context_noise", 0), + image_kwargs=image_kwargs, + pos_cond_kwargs=pos_cond_kwargs, + viewmats_full=viewmats_full, + intrinsics_full=intrinsics_full, + action_full=action_full, + ) + + context_noise = getattr(fastvideo_args.pipeline_config, "context_noise", + 0) + + with self.progress_bar(total=len(block_sizes) * + len(timesteps)) as progress_bar: + for block_idx, current_num_frames in enumerate(block_sizes): + ctx.block_idx = block_idx + ctx.start_index = start_index + current_latents = latents[:, :, start_index:start_index + + current_num_frames, :, :] + + # The scheduler maintains an internal `step_index` (and potentially + # additional multistep state, e.g. UniPC). Since causal streaming runs + # a full denoising trajectory *per block*, reset that state before + # each block rollout. + self._reset_scheduler_state_for_new_rollout() + + action_kwargs = self._prepare_action_kwargs( + batch, start_index, current_num_frames) + camera_action_kwargs = self._prepare_camera_action_kwargs( + ctx, start_index, current_num_frames) + + current_latents = self._process_single_block_ode( + current_latents=current_latents, + batch=batch, + start_index=start_index, + current_num_frames=current_num_frames, + timesteps=timesteps, + ctx=ctx, + action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, + progress_bar=progress_bar, + ) + + latents[:, :, start_index:start_index + + current_num_frames, :, :] = current_latents + + # Update KV caches with clean context + self._update_context_cache( + current_latents=current_latents, + batch=batch, + start_index=start_index, + current_num_frames=current_num_frames, + ctx=ctx, + action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, + context_noise=context_noise, + ) + + start_index += current_num_frames + + if boundary_timestep is not None: + num_frames_to_remove = self.num_frame_per_block - 1 + if num_frames_to_remove > 0: + latents = latents[:, :, :-num_frames_to_remove, :, :] + + batch.latents = latents + return batch + + def _reset_scheduler_state_for_new_rollout(self) -> None: + scheduler = self.scheduler + + # Common diffusers-like state. + if hasattr(scheduler, "_step_index"): + scheduler._step_index = None # type: ignore[attr-defined] + if hasattr(scheduler, "_begin_index"): + scheduler._begin_index = None # type: ignore[attr-defined] + + # UniPC multistep state (FlowUniPCMultistepScheduler) needs additional reset + # between independent trajectories. + if hasattr(scheduler, "model_outputs") and hasattr(scheduler, "config"): + try: + solver_order = int( + getattr(scheduler.config, "solver_order", 0) or 0) + except Exception: + solver_order = 0 + if solver_order > 0: + scheduler.model_outputs = [ + None + ] * solver_order # type: ignore[attr-defined] + if hasattr(scheduler, "timestep_list") and hasattr(scheduler, "config"): + try: + solver_order = int( + getattr(scheduler.config, "solver_order", 0) or 0) + except Exception: + solver_order = 0 + if solver_order > 0: + scheduler.timestep_list = [ + None + ] * solver_order # type: ignore[attr-defined] + if hasattr(scheduler, "lower_order_nums"): + scheduler.lower_order_nums = 0 # type: ignore[attr-defined] + if hasattr(scheduler, "last_sample"): + scheduler.last_sample = None # type: ignore[attr-defined] + + def _process_single_block_ode( + self, + *, + current_latents: torch.Tensor, + batch: ForwardBatch, + start_index: int, + current_num_frames: int, + timesteps: torch.Tensor, + ctx: BlockProcessingContext, + action_kwargs: dict[str, Any], + camera_action_kwargs: dict[str, Any], + progress_bar: Any | None = None, + ) -> torch.Tensor: + prompt_embeds = batch.prompt_embeds + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + { + "generator": batch.generator, + "eta": batch.eta, + }, + ) + + for i, t_cur in enumerate(timesteps): + if ctx.boundary_timestep is not None and t_cur < ctx.boundary_timestep: + current_model = self.transformer_2 if self.transformer_2 is not None else self.transformer + else: + current_model = self.transformer + + latent_model_input = current_latents.to(ctx.target_dtype) + + independent_first_frame = getattr(self.transformer, + "independent_first_frame", False) + if batch.image_latent is not None and not independent_first_frame: + image_latent_chunk = batch.image_latent[:, :, start_index: + start_index + + current_num_frames, :, :] + latent_model_input = torch.cat([ + latent_model_input, + image_latent_chunk.to(ctx.target_dtype) + ], + dim=1) + elif (batch.image_latent is not None and independent_first_frame + and start_index == 0): + latent_model_input = torch.cat([ + latent_model_input, + batch.image_latent.to(ctx.target_dtype) + ], + dim=2) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t_cur) + + # Build attention metadata if VSA is available + if vsa_available and self.attn_backend == VideoSparseAttentionBackend: + self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls( + ) + if self.attn_metadata_builder_cls is not None: + self.attn_metadata_builder = self.attn_metadata_builder_cls( + ) + h, w = current_latents.shape[-2:] + attn_metadata = self.attn_metadata_builder.build( + current_timestep=i, + raw_latent_shape=(current_num_frames, h, w), + patch_size=ctx.fastvideo_args.pipeline_config. + dit_config.patch_size, + VSA_sparsity=ctx.fastvideo_args.VSA_sparsity, + device=get_local_torch_device(), + ) + assert attn_metadata is not None, "attn_metadata cannot be None" + else: + attn_metadata = None + else: + attn_metadata = None + + with torch.autocast(device_type="cuda", + dtype=ctx.target_dtype, + enabled=ctx.autocast_enabled), \ + set_forward_context(current_timestep=i, + attn_metadata=attn_metadata, + forward_batch=batch): + t_expanded = t_cur * torch.ones( + (latent_model_input.shape[0], current_num_frames), + device=latent_model_input.device, + dtype=t_cur.dtype) + + model_kwargs = { + "kv_cache": ctx.get_kv_cache(t_cur), + "crossattn_cache": ctx.crossattn_cache, + "current_start": start_index * self.frame_seq_length, + "start_frame": start_index, + "is_cache": False, + } + + if self.use_action_module and current_model == self.transformer: + model_kwargs.update({ + "kv_cache_mouse": + ctx.kv_cache_mouse, + "kv_cache_keyboard": + ctx.kv_cache_keyboard, + }) + model_kwargs.update(action_kwargs) + + noise_pred = current_model( + latent_model_input, + prompt_embeds, + t_expanded, + **camera_action_kwargs, + **ctx.image_kwargs, + **ctx.pos_cond_kwargs, + **model_kwargs, + ) + + current_latents = self.scheduler.step( + noise_pred, + t_cur, + current_latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + if progress_bar is not None: + progress_bar.update() + + return current_latents def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch: @@ -645,7 +1094,10 @@ def streaming_reset(self, batch: ForwardBatch, ) and not fastvideo_args.disable_autocast latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] - patch_size = self.transformer.patch_size + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size patch_ratio = patch_size[-1] * patch_size[-2] self.frame_seq_length = latent_seq_length // patch_ratio @@ -821,6 +1273,7 @@ def streaming_noise_generator(shape: tuple, dtype: torch.dtype, timesteps=ctx.timesteps, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs={}, noise_generator=streaming_noise_generator, ) @@ -835,6 +1288,7 @@ def streaming_noise_generator(shape: tuple, dtype: torch.dtype, current_num_frames=current_num_frames, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs={}, context_noise=ctx.context_noise, ) diff --git a/fastvideo/registry.py b/fastvideo/registry.py index 230d878fa..99a3aad7f 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -32,20 +32,11 @@ TurboDiffusionT2V_1_3B_Config, ) from fastvideo.configs.pipelines.wan import ( - FastWan2_1_T2V_480P_Config, - FastWan2_2_TI2V_5B_Config, - MatrixGameI2V480PConfig, - SelfForcingWan2_2_T2V480PConfig, - SelfForcingWanT2V480PConfig, - WANV2VConfig, - Wan2_2_I2V_A14B_Config, - Wan2_2_T2V_A14B_Config, - Wan2_2_TI2V_5B_Config, - WanI2V480PConfig, - WanI2V720PConfig, - WanT2V480PConfig, - WanT2V720PConfig, -) + FastWan2_1_T2V_480P_Config, FastWan2_2_TI2V_5B_Config, + MatrixGameI2V480PConfig, SelfForcingWan2_2_T2V480PConfig, + SelfForcingWanT2V480PConfig, WANV2VConfig, Wan2_2_I2V_A14B_Config, + Wan2_2_T2V_A14B_Config, Wan2_2_TI2V_5B_Config, WanI2V480PConfig, + WanI2V720PConfig, WanT2V480PConfig, WanT2V720PConfig) from fastvideo.configs.pipelines.sd35 import SD35Config from fastvideo.configs.sample.base import SamplingParam from fastvideo.configs.sample.cosmos import ( @@ -558,7 +549,6 @@ def _register_configs() -> None: "FastVideo/SFWan2.2-I2V-A14B-Preview-Diffusers", ], ) - # SD3.5 register_configs( sampling_param_cls=SD35SamplingParam, diff --git a/fastvideo/tests/distillation/test_optimizer_scheduler_alignment.py b/fastvideo/tests/distillation/test_optimizer_scheduler_alignment.py new file mode 100644 index 000000000..98b2174b9 --- /dev/null +++ b/fastvideo/tests/distillation/test_optimizer_scheduler_alignment.py @@ -0,0 +1,97 @@ +import torch + +from fastvideo.train.methods.base import TrainingMethod + + +class _FakeScheduler: + def __init__(self) -> None: + self.step_calls = 0 + + def step(self) -> None: + self.step_calls += 1 + + +class _FakeOptimizer(torch.optim.Optimizer): + def __init__(self) -> None: + super().__init__([torch.zeros((), requires_grad=True)], {}) + self.step_calls = 0 + self.zero_grad_calls = 0 + + def step(self, closure=None): # noqa: ANN001, ANN201 + self.step_calls += 1 + if closure is not None: + closure() + + def zero_grad(self, *args, **kwargs): # noqa: ANN002, ANN003, ANN201 + self.zero_grad_calls += 1 + + +class _FakeModel: + transformer = None + + def on_train_start(self) -> None: + pass + + def get_rng_generators(self) -> dict: + return {} + + +class _FakeCfg: + class training: + pass + + method: dict = {} + validation: dict = {} + + +class _ScheduleMethod(TrainingMethod): + def __init__(self, interval: int) -> None: + self.student_opt = _FakeOptimizer() + self.critic_opt = _FakeOptimizer() + self.student_sched = _FakeScheduler() + self.critic_sched = _FakeScheduler() + cfg = _FakeCfg() + cfg.method = {} + cfg.validation = {} + role_models = {"student": _FakeModel()} # type: ignore[dict-item] + super().__init__(cfg=cfg, role_models=role_models) + self.interval = interval + + @property + def _optimizer_dict(self): # noqa: ANN201 + return {"student": self.student_opt, "critic": self.critic_opt} + + @property + def _lr_scheduler_dict(self): # noqa: ANN201 + return {"student": self.student_sched, "critic": self.critic_sched} + + def _update_student(self, iteration: int) -> bool: + return iteration % self.interval == 0 + + def single_train_step(self, batch, iteration, *, current_vsa_sparsity=0.0): # noqa: ANN001, ANN201 + loss = torch.zeros((), requires_grad=True) + return {"total_loss": loss}, {}, {} + + def get_optimizers(self, iteration): # noqa: ANN001, ANN201 + optimizers = [self.critic_opt] + if self._update_student(iteration): + optimizers.append(self.student_opt) + return optimizers + + def get_lr_schedulers(self, iteration): # noqa: ANN001, ANN201 + schedulers = [self.critic_sched] + if self._update_student(iteration): + schedulers.append(self.student_sched) + return schedulers + + +def test_optimizer_scheduler_alignment() -> None: + method = _ScheduleMethod(interval=5) + + for step in range(1, 11): + method.optimizers_schedulers_step(step) + + assert method.critic_opt.step_calls == 10 + assert method.critic_sched.step_calls == 10 + assert method.student_opt.step_calls == 2 + assert method.student_sched.step_calls == 2 diff --git a/fastvideo/train/.style.yapf b/fastvideo/train/.style.yapf new file mode 100644 index 000000000..c9a88d5a6 --- /dev/null +++ b/fastvideo/train/.style.yapf @@ -0,0 +1,3 @@ +[style] +based_on_style = pep8 +column_limit = 120 diff --git a/fastvideo/train/__init__.py b/fastvideo/train/__init__.py new file mode 100644 index 000000000..fed6b183c --- /dev/null +++ b/fastvideo/train/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.trainer import Trainer + +__all__ = [ + "Trainer", +] diff --git a/fastvideo/train/callbacks/__init__.py b/fastvideo/train/callbacks/__init__.py new file mode 100644 index 000000000..23334280a --- /dev/null +++ b/fastvideo/train/callbacks/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.callbacks.callback import ( + Callback, + CallbackDict, +) +from fastvideo.train.callbacks.ema import ( + EMACallback, +) +from fastvideo.train.callbacks.grad_clip import ( + GradNormClipCallback, +) +from fastvideo.train.callbacks.validation import ( + ValidationCallback, +) + +__all__ = [ + "Callback", + "CallbackDict", + "EMACallback", + "GradNormClipCallback", + "ValidationCallback", +] diff --git a/fastvideo/train/callbacks/callback.py b/fastvideo/train/callbacks/callback.py new file mode 100644 index 000000000..b44c7dc07 --- /dev/null +++ b/fastvideo/train/callbacks/callback.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Callback base class and CallbackDict manager. + +Adapted from FastGen's callback pattern to FastVideo's types. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TYPE_CHECKING + +from fastvideo.logger import init_logger +from fastvideo.train.utils.instantiate import instantiate + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +logger = init_logger(__name__) + +# Well-known callback names that don't need ``_target_`` in YAML. +_BUILTIN_CALLBACKS: dict[str, str] = { + "grad_clip": "fastvideo.train.callbacks.grad_clip.GradNormClipCallback", + "validation": "fastvideo.train.callbacks.validation.ValidationCallback", + "ema": "fastvideo.train.callbacks.ema.EMACallback", +} + + +class Callback: + """Base callback with no-op hooks. + + Subclasses override whichever hooks they need. The + ``training_config`` and ``method`` attributes are set by + ``CallbackDict`` after instantiation. + """ + + training_config: TrainingConfig + method: TrainingMethod + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_training_step_end( + self, + method: TrainingMethod, + loss_dict: dict[str, Any], + iteration: int = 0, + ) -> None: + pass + + def on_before_optimizer_step( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_validation_begin( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_validation_end( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_train_end( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def state_dict(self) -> dict[str, Any]: + return {} + + def load_state_dict( + self, state_dict: dict[str, Any], + ) -> None: + pass + + +class CallbackDict: + """Manages a collection of named callbacks. + + Instantiates each callback from its ``_target_`` config and + dispatches hook calls to all registered callbacks. + """ + + def __init__( + self, + callback_configs: dict[str, dict[str, Any]], + training_config: TrainingConfig, + ) -> None: + self._callbacks: dict[str, Callback] = {} + if not callback_configs: + return + for name, cb_cfg in callback_configs.items(): + cb_cfg = dict(cb_cfg) + if "_target_" not in cb_cfg: + if name in _BUILTIN_CALLBACKS: + cb_cfg["_target_"] = ( + _BUILTIN_CALLBACKS[name] + ) + else: + logger.warning( + "Callback %r is missing " + "'_target_', skipping: %s", + name, + cb_cfg, + ) + continue + logger.info( + "Instantiating callback %r: %s", + name, + cb_cfg, + ) + cb = instantiate(cb_cfg) + if not isinstance(cb, Callback): + raise TypeError( + f"Callback {name!r} resolved to " + f"{type(cb).__name__}, expected a " + f"Callback subclass." + ) + cb.training_config = training_config + self._callbacks[name] = cb + + def __getattr__( + self, method_name: str, + ) -> Callable[..., Any]: + if method_name.startswith("_"): + raise AttributeError(method_name) + + if method_name == "state_dict": + + def _state_dict() -> dict[str, Any]: + return { + n: cb.state_dict() + for n, cb in self._callbacks.items() + } + + return _state_dict + + if method_name == "load_state_dict": + + def _load_state_dict( + state_dict: dict[str, Any], + ) -> None: + for n, cb in self._callbacks.items(): + if n in state_dict: + cb.load_state_dict(state_dict[n]) + else: + logger.warning( + "Callback %r not found in " + "checkpoint.", + n, + ) + + return _load_state_dict + + def _dispatch(*args: Any, **kwargs: Any) -> None: + for cb in self._callbacks.values(): + fn = getattr(cb, method_name, None) + if fn is None: + continue + if not callable(fn): + continue + fn(*args, **kwargs) + + return _dispatch diff --git a/fastvideo/train/callbacks/ema.py b/fastvideo/train/callbacks/ema.py new file mode 100644 index 000000000..328e39a76 --- /dev/null +++ b/fastvideo/train/callbacks/ema.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +"""EMA (Exponential Moving Average) callback. + +Updates EMA shadow weights after each training step. The model owns +the EMA network (created by ``ModelBase._setup_ema``); this callback +only performs the ``lerp_`` update. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + +from fastvideo.logger import init_logger +from fastvideo.train.callbacks.callback import Callback + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + +logger = init_logger(__name__) + + +class EMACallback(Callback): + """Update EMA parameters after each optimizer step. + + The EMA network lives on the method (``method.ema``). + If the method was created with ``use_ema: false``, the callback + detects this at train start and disables itself gracefully. + + Supports three beta strategies: + - ``constant``: fixed ``beta`` every step. + - ``power``: ``(1 - 1/t)^(gamma+1)``. + - ``halflife``: half-life in k-images with optional ramp-up. + """ + + def __init__( + self, + *, + type: str = "constant", + beta: float = 0.9999, + gamma: float = 16.97, + ema_halflife_kimg: float = 500.0, + ema_rampup_ratio: float | None = 0.05, + start_iter: int = 0, + batch_size: int = 1, + ) -> None: + self._type = str(type) + self._beta = float(beta) + self._gamma = float(gamma) + self._ema_halflife_kimg = float(ema_halflife_kimg) + self._ema_rampup_ratio = ( + float(ema_rampup_ratio) + if ema_rampup_ratio is not None + else None + ) + self._start_iter = int(start_iter) + self._batch_size = int(batch_size) + self._enabled = True + + # ---------------------------------------------------------- + # Hooks + # ---------------------------------------------------------- + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + ema = getattr(method, "ema", None) + if ema is None: + self._enabled = False + logger.info( + "EMA not found on method; " + "EMA callback disabled.", + ) + return + + assert not ema.training, ( + "EMA should be in eval mode" + ) + for name, p in ema.named_parameters(): + assert not p.requires_grad, ( + f"EMA parameter {name} should not " + f"require gradients" + ) + + def on_training_step_end( + self, + method: TrainingMethod, + loss_dict: dict[str, Any], + iteration: int = 0, + ) -> None: + if not self._enabled: + return + + if iteration < self._start_iter: + return + if iteration == self._start_iter: + logger.info( + "Starting EMA %r updates at iteration %d.", + "ema", + iteration, + ) + + beta = self._compute_beta(iteration) + ema = method.ema + ema_state = ema.state_dict() + + with torch.no_grad(): + for name, p_net in ( + method.student.transformer.named_parameters() + ): + full = self._gather_full(p_net) + ema_key = name.replace( + "_checkpoint_wrapped_module.", "", + ) + if ema_key not in ema_state: + if iteration == self._start_iter: + logger.warning( + "EMA param %r not found, " + "skipping.", + ema_key, + ) + continue + ema_p = ema_state[ema_key] + val = full.to( + device=ema_p.device, + dtype=ema_p.dtype, + ) + if iteration == self._start_iter: + ema_p.copy_(val) + else: + ema_p.lerp_(val, 1.0 - beta) + + for name, buf in ( + method.student.transformer.named_buffers() + ): + if name in ema_state: + ema_state[name].copy_( + buf.to( + device=ema_state[name].device, + dtype=ema_state[name].dtype, + ) + ) + + tracker = getattr(method, "tracker", None) + if tracker is not None: + tracker.log( + {"ema/beta": beta}, + iteration, + ) + + # ---------------------------------------------------------- + # Beta strategies + # ---------------------------------------------------------- + + def _compute_beta(self, iteration: int) -> float: + if self._type == "constant": + return self._beta + if self._type == "power": + it = max(iteration, 1) + return (1.0 - 1.0 / it) ** (self._gamma + 1) + if self._type == "halflife": + return self._halflife_beta(iteration) + raise ValueError( + f"Invalid EMA type: {self._type!r}" + ) + + def _halflife_beta(self, iteration: int) -> float: + hl_nimg = self._ema_halflife_kimg * 1000.0 + cur_nimg = iteration * self._batch_size + if self._ema_rampup_ratio is not None: + hl_nimg = min( + hl_nimg, + cur_nimg * self._ema_rampup_ratio, + ) + return 0.5 ** ( + self._batch_size / max(hl_nimg, 1e-8) + ) + + # ---------------------------------------------------------- + # FSDP helper + # ---------------------------------------------------------- + + @staticmethod + def _gather_full( + param: torch.Tensor, + ) -> torch.Tensor: + if hasattr(param, "full_tensor"): + if param.device.type == "cpu": + return param.to("cuda").full_tensor() + return param.full_tensor() + return param diff --git a/fastvideo/train/callbacks/grad_clip.py b/fastvideo/train/callbacks/grad_clip.py new file mode 100644 index 000000000..f8d445422 --- /dev/null +++ b/fastvideo/train/callbacks/grad_clip.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Gradient norm clipping callback. + +Clips gradients on modules returned by +``method.get_grad_clip_targets()`` before the optimizer step. +Optionally logs per-module grad norms to the tracker. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from fastvideo.logger import init_logger +from fastvideo.train.callbacks.callback import Callback +from fastvideo.train.utils.optimizer import ( + clip_grad_norm_if_needed, +) + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + +logger = init_logger(__name__) + + +class GradNormClipCallback(Callback): + """Clip gradient norms before the optimizer step. + + ``max_grad_norm`` must be set explicitly in the callback + config (``callbacks.grad_clip.max_grad_norm``). + """ + + def __init__( + self, + *, + max_grad_norm: float = 0.0, + log_grad_norms: bool = False, + ) -> None: + self._max_grad_norm = float(max_grad_norm) + self._log_grad_norms = bool(log_grad_norms) + + def on_before_optimizer_step( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + max_norm = self._max_grad_norm + if max_norm <= 0.0: + return + + targets = method.get_grad_clip_targets(iteration) + tracker = getattr(method, "tracker", None) + + for name, module in targets.items(): + grad_norm = clip_grad_norm_if_needed( + module, max_norm, + ) + if ( + self._log_grad_norms + and tracker is not None + and grad_norm > 0.0 + ): + tracker.log( + {f"grad_norm/{name}": grad_norm}, + iteration, + ) diff --git a/fastvideo/train/callbacks/validation.py b/fastvideo/train/callbacks/validation.py new file mode 100644 index 000000000..29f2edc9e --- /dev/null +++ b/fastvideo/train/callbacks/validation.py @@ -0,0 +1,767 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Validation callback for inference-pipeline driven validation. + +All configuration is read from the YAML ``callbacks.validation`` +section. The pipeline class is resolved from +``pipeline_target``. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, TYPE_CHECKING + +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from torch.utils.data import DataLoader + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.validation_dataset import ( + ValidationDataset, ) +from fastvideo.distributed import ( + get_sp_group, + get_world_group, +) +from fastvideo.logger import init_logger +from fastvideo.pipelines import ForwardBatch +from fastvideo.train.callbacks.callback import Callback +from fastvideo.train.utils.instantiate import resolve_target +from fastvideo.train.utils.moduleloader import ( + make_inference_args, ) +from fastvideo.training.trackers import DummyTracker +from fastvideo.utils import shallow_asdict + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +logger = init_logger(__name__) + + +@dataclass(slots=True) +class _ValidationStepResult: + videos: list[list[np.ndarray]] + captions: list[str] + + +class ValidationCallback(Callback): + """Generic validation callback driven entirely by YAML + config. + + Works with any pipeline that follows the + ``PipelineCls.from_pretrained(...)`` + ``pipeline.forward()`` + contract. + """ + + def __init__( + self, + *, + pipeline_target: str, + dataset_file: str, + every_steps: int = 100, + sampling_steps: list[int] | None = None, + sampler_kind: str = "ode", + scheduler_target: str | None = None, + guidance_scale: float | None = None, + num_frames: int | None = None, + output_dir: str | None = None, + sampling_timesteps: list[int] | None = None, + rollout_mode: str = "parallel", + **pipeline_kwargs: Any, + ) -> None: + self.pipeline_target = str(pipeline_target) + self.dataset_file = str(dataset_file) + self.every_steps = int(every_steps) + self.sampling_steps = ( + [int(s) for s in sampling_steps] + if sampling_steps + else [40] + ) + self.sampler_kind = str(sampler_kind) + self.scheduler_target = ( + str(scheduler_target) + if scheduler_target is not None + else None + ) + self.guidance_scale = ( + float(guidance_scale) + if guidance_scale is not None + else None + ) + self.num_frames = ( + int(num_frames) if num_frames is not None + else None + ) + self.output_dir = ( + str(output_dir) if output_dir is not None + else None + ) + self.sampling_timesteps = ( + [int(s) for s in sampling_timesteps] + if sampling_timesteps is not None + else None + ) + self.rollout_mode = str(rollout_mode) + self.pipeline_kwargs = dict(pipeline_kwargs) + + # Set after on_train_start. + self._pipeline: Any | None = None + self._pipeline_key: tuple[Any, ...] | None = None + self._sampling_param: SamplingParam | None = None + self.tracker: Any = DummyTracker() + self.validation_random_generator: ( + torch.Generator | None + ) = None + self.seed: int = 0 + + # ---------------------------------------------------------- + # Callback hooks + # ---------------------------------------------------------- + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + self.method = method + tc = self.training_config + + self.world_group = get_world_group() + self.sp_group = get_sp_group() + self.global_rank = self.world_group.rank + self.rank_in_sp_group = ( + self.sp_group.rank_in_group + ) + self.sp_world_size = self.sp_group.world_size + + seed = tc.data.seed + if seed is None: + raise ValueError( + "training.data.seed must be set " + "for validation" + ) + self.seed = int(seed) + self.validation_random_generator = ( + torch.Generator(device="cpu").manual_seed( + self.seed + ) + ) + + tracker = getattr(method, "tracker", None) + if tracker is not None: + self.tracker = tracker + + def on_validation_begin( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + if self.every_steps <= 0: + return + if iteration % self.every_steps != 0: + return + + self._run_validation(method, iteration) + + # ---------------------------------------------------------- + # Core validation logic + # ---------------------------------------------------------- + + def _run_validation( + self, + method: TrainingMethod, + step: int, + ) -> None: + tc = self.training_config + # Use EMA transformer for validation when available. + transformer = method.transformer_inference + was_training = bool( + getattr(transformer, "training", False) + ) + + output_dir = ( + self.output_dir + or tc.checkpoint.output_dir + ) + + # For streaming SDE pipelines we may need to + # temporarily set dmd_denoising_steps on + # pipeline_config. + old_dmd_denoising_steps = getattr( + tc.pipeline_config, + "dmd_denoising_steps", + None, + ) + try: + transformer.eval() + num_sp_groups = ( + self.world_group.world_size + // self.sp_group.world_size + ) + + for num_inference_steps in self.sampling_steps: + self._maybe_set_dmd_denoising_steps( + tc, + num_inference_steps, + ) + + result = self._run_validation_for_steps( + num_inference_steps, + transformer=transformer, + ) + + if self.rank_in_sp_group != 0: + continue + + if self.global_rank == 0: + all_videos = list(result.videos) + all_captions = list(result.captions) + for sp_idx in range( + 1, num_sp_groups + ): + src = ( + sp_idx * self.sp_world_size + ) + recv_v = ( + self.world_group.recv_object( + src=src + ) + ) + recv_c = ( + self.world_group.recv_object( + src=src + ) + ) + all_videos.extend(recv_v) + all_captions.extend(recv_c) + + os.makedirs( + output_dir, exist_ok=True, + ) + video_filenames: list[str] = [] + sp = self._get_sampling_param() + for i, video in enumerate(all_videos): + fname = os.path.join( + output_dir, + f"validation_step_{step}" + f"_inference_steps_" + f"{num_inference_steps}" + f"_video_{i}.mp4", + ) + imageio.mimsave( + fname, + video, + fps=sp.fps, + ) + video_filenames.append(fname) + + video_logs = [] + for fname, cap in zip( + video_filenames, + all_captions, + strict=True, + ): + art = self.tracker.video( + fname, caption=cap, + ) + if art is not None: + video_logs.append(art) + if video_logs: + logs = { + f"validation_videos_" + f"{num_inference_steps}" + f"_steps": video_logs + } + self.tracker.log_artifacts( + logs, step, + ) + else: + self.world_group.send_object( + result.videos, dst=0, + ) + self.world_group.send_object( + result.captions, dst=0, + ) + finally: + if hasattr(tc.pipeline_config, "dmd_denoising_steps"): + tc.pipeline_config.dmd_denoising_steps = ( + old_dmd_denoising_steps + ) + if was_training: + transformer.train() + + def _maybe_set_dmd_denoising_steps( + self, + tc: TrainingConfig, + num_inference_steps: int, + ) -> None: + """Set dmd_denoising_steps on pipeline_config for + streaming SDE validation.""" + if self.rollout_mode != "streaming": + return + if self.sampler_kind != "sde": + return + if self.sampling_timesteps is not None: + tc.pipeline_config.dmd_denoising_steps = ( # type: ignore[union-attr] + list(self.sampling_timesteps) + ) + else: + timesteps = np.linspace( + 1000, 0, int(num_inference_steps), + ) + tc.pipeline_config.dmd_denoising_steps = [ # type: ignore[union-attr] + int(max(0, min(1000, round(t)))) + for t in timesteps + ] + + # Also set any pipeline-specific kwargs from + # YAML (e.g. dmd_denoising_steps override). + pk = self.pipeline_kwargs + if "dmd_denoising_steps" in pk: + tc.pipeline_config.dmd_denoising_steps = [ # type: ignore[union-attr] + int(s) + for s in pk["dmd_denoising_steps"] + ] + + # ---------------------------------------------------------- + # Pipeline management + # ---------------------------------------------------------- + + def _get_sampling_param(self) -> SamplingParam: + if self._sampling_param is None: + self._sampling_param = ( + SamplingParam.from_pretrained( + self.training_config.model_path + ) + ) + return self._sampling_param + + def _get_pipeline( + self, + *, + transformer: torch.nn.Module, + ) -> Any: + key = ( + id(transformer), + self.rollout_mode, + self.sampler_kind, + self.scheduler_target, + ) + if ( + self._pipeline is not None + and self._pipeline_key == key + ): + return self._pipeline + + tc = self.training_config + PipelineCls = resolve_target(self.pipeline_target) + flow_shift = getattr( + tc.pipeline_config, "flow_shift", None, + ) + + kwargs: dict[str, Any] = { + "inference_mode": True, + "sampler_kind": self.sampler_kind, + "loaded_modules": { + "transformer": transformer, + }, + "tp_size": tc.distributed.tp_size, + "sp_size": tc.distributed.sp_size, + "num_gpus": tc.distributed.num_gpus, + "pin_cpu_memory": ( + tc.distributed.pin_cpu_memory + ), + "dit_cpu_offload": True, + } + if flow_shift is not None: + kwargs["flow_shift"] = float(flow_shift) + + # Build and inject a scheduler if target is set. + scheduler = self._build_scheduler(flow_shift) + if scheduler is not None: + kwargs["loaded_modules"]["scheduler"] = ( + scheduler + ) + + self._pipeline = PipelineCls.from_pretrained( + tc.model_path, **kwargs, + ) + self._pipeline_key = key + return self._pipeline + + def _build_scheduler( + self, flow_shift: float | None, + ) -> Any | None: + """Build scheduler from ``scheduler_target``.""" + if self.scheduler_target is None: + return None + if flow_shift is None: + return None + + SchedulerCls = resolve_target( + self.scheduler_target + ) + return SchedulerCls(shift=float(flow_shift)) + + # ---------------------------------------------------------- + # Batch preparation + # ---------------------------------------------------------- + + def _prepare_validation_batch( + self, + sampling_param: SamplingParam, + validation_batch: dict[str, Any], + num_inference_steps: int, + ) -> ForwardBatch: + tc = self.training_config + + sampling_param.prompt = validation_batch["prompt"] + sampling_param.height = tc.data.num_height + sampling_param.width = tc.data.num_width + sampling_param.num_inference_steps = int( + num_inference_steps + ) + sampling_param.data_type = "video" + if self.guidance_scale is not None: + sampling_param.guidance_scale = float( + self.guidance_scale + ) + sampling_param.seed = self.seed + + # image_path for I2V pipelines. + img_path = ( + validation_batch.get("image_path") + or validation_batch.get("video_path") + ) + if img_path is not None and ( + img_path.startswith("http") + or os.path.isfile(img_path) + ): + sampling_param.image_path = img_path + + temporal_compression_factor = int( + tc.pipeline_config.vae_config.arch_config.temporal_compression_ratio # type: ignore[union-attr] + ) + default_num_frames = ( + (tc.data.num_latent_t - 1) + * temporal_compression_factor + + 1 + ) + if self.num_frames is not None: + sampling_param.num_frames = int( + self.num_frames + ) + else: + sampling_param.num_frames = int( + default_num_frames + ) + + latents_size = [ + (sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, + sampling_param.width // 8, + ] + n_tokens = ( + latents_size[0] + * latents_size[1] + * latents_size[2] + ) + + sampling_timesteps_tensor = ( + torch.tensor( + [int(s) for s in self.sampling_timesteps], + dtype=torch.long, + ) + if self.sampling_timesteps is not None + else None + ) + + inference_args = make_inference_args( + tc, model_path=tc.model_path, + ) + + batch = ForwardBatch( + **shallow_asdict(sampling_param), + latents=None, + generator=self.validation_random_generator, + n_tokens=n_tokens, + eta=0.0, + VSA_sparsity=tc.vsa.sparsity, + timesteps=sampling_timesteps_tensor, + sampling_timesteps=sampling_timesteps_tensor, + ) + batch._inference_args = inference_args # type: ignore[attr-defined] + + # Conditionally set I2V / action-conditioning fields. + if ( + "image" in validation_batch + and validation_batch["image"] is not None + ): + batch.pil_image = validation_batch["image"] + + self._maybe_set_action_conds( + batch, validation_batch, sampling_param, + ) + return batch + + def _maybe_set_action_conds( + self, + batch: ForwardBatch, + validation_batch: dict[str, Any], + sampling_param: SamplingParam, + ) -> None: + """Set keyboard_cond / mouse_cond on the batch if + present in the dataset.""" + target_len = int(sampling_param.num_frames) + + if ( + "keyboard_cond" in validation_batch + and validation_batch["keyboard_cond"] + is not None + ): + kb = torch.as_tensor( + validation_batch["keyboard_cond"] + ).to(dtype=torch.bfloat16) + if kb.ndim == 3 and kb.shape[0] == 1: + kb = kb.squeeze(0) + if kb.ndim != 2: + raise ValueError( + "validation keyboard_cond must have" + " shape (T, K), got " + f"{tuple(kb.shape)}" + ) + if kb.shape[0] > target_len: + kb = kb[:target_len] + elif kb.shape[0] < target_len: + pad = torch.zeros( + ( + target_len - kb.shape[0], + kb.shape[1], + ), + dtype=kb.dtype, + device=kb.device, + ) + kb = torch.cat([kb, pad], dim=0) + batch.keyboard_cond = kb.unsqueeze(0) + + if ( + "mouse_cond" in validation_batch + and validation_batch["mouse_cond"] + is not None + ): + mc = torch.as_tensor( + validation_batch["mouse_cond"] + ).to(dtype=torch.bfloat16) + if mc.ndim == 3 and mc.shape[0] == 1: + mc = mc.squeeze(0) + if mc.ndim != 2: + raise ValueError( + "validation mouse_cond must have " + "shape (T, 2), got " + f"{tuple(mc.shape)}" + ) + if mc.shape[0] > target_len: + mc = mc[:target_len] + elif mc.shape[0] < target_len: + pad = torch.zeros( + ( + target_len - mc.shape[0], + mc.shape[1], + ), + dtype=mc.dtype, + device=mc.device, + ) + mc = torch.cat([mc, pad], dim=0) + batch.mouse_cond = mc.unsqueeze(0) + + # ---------------------------------------------------------- + # Post-processing + # ---------------------------------------------------------- + + def _post_process_validation_frames( + self, + frames: list[np.ndarray], + batch: ForwardBatch, + ) -> list[np.ndarray]: + """Overlay action indicators if conditions present.""" + keyboard_cond = getattr(batch, "keyboard_cond", None) + mouse_cond = getattr(batch, "mouse_cond", None) + if keyboard_cond is None and mouse_cond is None: + return frames + + try: + from fastvideo.models.dits.matrixgame.utils import ( + draw_keys_on_frame, + draw_mouse_on_frame, + ) + except Exception as e: + logger.warning( + "Action overlay unavailable: %s", e, + ) + return frames + + if ( + keyboard_cond is not None + and torch.is_tensor(keyboard_cond) + ): + keyboard_np = ( + keyboard_cond.squeeze(0) + .detach() + .cpu() + .float() + .numpy() + ) + else: + keyboard_np = None + + if ( + mouse_cond is not None + and torch.is_tensor(mouse_cond) + ): + mouse_np = ( + mouse_cond.squeeze(0) + .detach() + .cpu() + .float() + .numpy() + ) + else: + mouse_np = None + + key_names = ["W", "S", "A", "D", "left", "right"] + processed: list[np.ndarray] = [] + for fi, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + if ( + keyboard_np is not None + and fi < len(keyboard_np) + ): + keys = { + key_names[i]: bool( + keyboard_np[fi, i] + ) + for i in range( + min( + len(key_names), + int(keyboard_np.shape[1]), + ) + ) + } + draw_keys_on_frame( + frame, keys, mode="universal", + ) + if ( + mouse_np is not None + and fi < len(mouse_np) + ): + pitch = float(mouse_np[fi, 0]) + yaw = float(mouse_np[fi, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + processed.append(frame) + return processed + + # ---------------------------------------------------------- + # Validation loop + # ---------------------------------------------------------- + + def _run_validation_for_steps( + self, + num_inference_steps: int, + *, + transformer: torch.nn.Module, + ) -> _ValidationStepResult: + tc = self.training_config + pipeline = self._get_pipeline( + transformer=transformer, + ) + sampling_param = self._get_sampling_param() + + dataset = ValidationDataset(self.dataset_file) + dataloader = DataLoader( + dataset, batch_size=None, num_workers=0, + ) + + inference_args = make_inference_args( + tc, model_path=tc.model_path, + ) + + videos: list[list[np.ndarray]] = [] + captions: list[str] = [] + + for validation_batch in dataloader: + batch = self._prepare_validation_batch( + sampling_param, + validation_batch, + num_inference_steps, + ) + + assert ( + batch.prompt is not None + and isinstance(batch.prompt, str) + ) + captions.append(batch.prompt) + + with torch.no_grad(): + output_batch = pipeline.forward( + batch, inference_args, + ) + + samples = output_batch.output.cpu() + if self.rank_in_sp_group != 0: + continue + + video = rearrange( + samples, "b c t h w -> t b c h w", + ) + frames: list[np.ndarray] = [] + for x in video: + x = torchvision.utils.make_grid( + x, nrow=6, + ) + x = ( + x.transpose(0, 1) + .transpose(1, 2) + .squeeze(-1) + ) + frames.append( + (x * 255).numpy().astype(np.uint8) + ) + frames = ( + self._post_process_validation_frames( + frames, batch, + ) + ) + videos.append(frames) + + return _ValidationStepResult( + videos=videos, captions=captions, + ) + + # ---------------------------------------------------------- + # State management + # ---------------------------------------------------------- + + def state_dict(self) -> dict[str, Any]: + state: dict[str, Any] = {} + if self.validation_random_generator is not None: + state["validation_rng"] = ( + self.validation_random_generator.get_state() + ) + return state + + def load_state_dict( + self, state_dict: dict[str, Any], + ) -> None: + rng_state = state_dict.get("validation_rng") + if ( + rng_state is not None + and self.validation_random_generator is not None + ): + self.validation_random_generator.set_state( + rng_state + ) diff --git a/fastvideo/train/entrypoint/__init__.py b/fastvideo/train/entrypoint/__init__.py new file mode 100644 index 000000000..988131360 --- /dev/null +++ b/fastvideo/train/entrypoint/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/fastvideo/train/entrypoint/dcp_to_diffusers.py b/fastvideo/train/entrypoint/dcp_to_diffusers.py new file mode 100644 index 000000000..a62dde639 --- /dev/null +++ b/fastvideo/train/entrypoint/dcp_to_diffusers.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Convert a DCP training checkpoint to a diffusers-style model directory. + +Works on a single GPU regardless of how many GPUs were used for training +(DCP handles resharding automatically). + +Usage (no torchrun needed):: + + python -m fastvideo.train.entrypoint.dcp_to_diffusers \ + --checkpoint /path/to/checkpoint-1000 \ + --output-dir /path/to/diffusers_output + +Or with torchrun (also fine):: + + torchrun --nproc_per_node=1 \ + -m fastvideo.train.entrypoint.dcp_to_diffusers \ + --checkpoint ... --output-dir ... + +The checkpoint must contain ``metadata.json`` (written by +``CheckpointManager``). If the checkpoint predates metadata +support, pass ``--config`` explicitly to provide the training +YAML. +""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Any + +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + + +def _ensure_distributed() -> None: + """Set up a single-process distributed env if needed. + + When running under ``torchrun`` the env vars are already set. + For plain ``python`` we fill in the minimum required vars so + that ``init_process_group`` succeeds with world_size=1. + """ + for key, default in [ + ("RANK", "0"), + ("LOCAL_RANK", "0"), + ("WORLD_SIZE", "1"), + ("MASTER_ADDR", "127.0.0.1"), + ("MASTER_PORT", "29500"), + ]: + os.environ.setdefault(key, default) + + +def _save_role_pretrained( + *, + role: str, + base_model_path: str, + output_dir: str, + module_names: list[str] | None = None, + overwrite: bool = False, + model: Any, +) -> str: + """Export a role's modules into a diffusers-style model dir. + + Produces a ``model_path`` loadable by + ``PipelineComponentLoader`` (``model_index.json``, + ``transformer/``, ``vae/``, etc. copied from + ``base_model_path``). + """ + import shutil + from pathlib import Path + + import torch + import torch.distributed as dist + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) + + from fastvideo.utils import maybe_download_model + + def _rank() -> int: + if dist.is_available() and dist.is_initialized(): + return int(dist.get_rank()) + return 0 + + def _barrier() -> None: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + local_base = Path( + maybe_download_model(str(base_model_path)) + ).resolve() + dst = Path( + os.path.expanduser(str(output_dir)) + ).resolve() + + if _rank() == 0: + if dst.exists(): + if overwrite: + shutil.rmtree(dst, ignore_errors=True) + else: + raise FileExistsError( + f"Refusing to overwrite existing " + f"directory: {dst}. " + "Pass --overwrite to replace it." + ) + + def _copy_or_link(src: str, dest: str) -> None: + try: + os.link(src, dest) + except OSError: + shutil.copy2(src, dest) + + logger.info( + "Creating pretrained export dir at %s " + "(base=%s)", dst, local_base, + ) + shutil.copytree( + local_base, dst, symlinks=True, + copy_function=_copy_or_link, + ) + + _barrier() + + modules: dict[str, torch.nn.Module] = {} + if model.transformer is not None: + modules["transformer"] = model.transformer + + if module_names is None: + module_names = sorted(modules.keys()) + + for module_name in module_names: + if module_name not in modules: + raise KeyError( + f"Role {role!r} does not have module " + f"{module_name!r}. " + f"Available: {sorted(modules.keys())}" + ) + + module_dir = dst / module_name + if not module_dir.is_dir(): + raise FileNotFoundError( + f"Export directory missing component " + f"dir {module_name!r}: {module_dir}" + ) + + options = StateDictOptions( + full_state_dict=True, cpu_offload=True, + ) + state_dict = get_model_state_dict( + modules[module_name], options=options, + ) + + if _rank() == 0: + for path in module_dir.glob("*.safetensors"): + path.unlink(missing_ok=True) + + tensor_state: dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + raise TypeError( + f"Expected tensor in state_dict " + f"for {module_name}.{key}, " + f"got {type(value).__name__}" + ) + tensor_state[key] = value.detach().cpu() + + from safetensors.torch import save_file + + out_path = module_dir / "model.safetensors" + logger.info( + "Saving %s weights to %s (%s tensors)", + module_name, out_path, + len(tensor_state), + ) + save_file(tensor_state, str(out_path)) + + _barrier() + + return str(dst) + + +def convert( + *, + checkpoint_dir: str, + output_dir: str, + config_path: str | None = None, + role: str = "student", + overwrite: bool = False, +) -> str: + """Load a DCP checkpoint and export as a diffusers model. + + Returns the path to the exported model directory. + """ + _ensure_distributed() + + from fastvideo.distributed import ( + maybe_init_distributed_environment_and_model_parallel, + ) + from fastvideo.train.utils.builder import build_from_config + from fastvideo.train.utils.checkpoint import ( + CheckpointManager, + _resolve_resume_checkpoint, + ) + from fastvideo.train.utils.config import ( + RunConfig, + load_run_config, + ) + + import torch.distributed.checkpoint as dcp + + # -- Resolve checkpoint directory -- + resolved = _resolve_resume_checkpoint( + checkpoint_dir, output_dir=checkpoint_dir, + ) + dcp_dir = resolved / "dcp" + if not dcp_dir.is_dir(): + raise FileNotFoundError( + f"Missing dcp/ under {resolved}" + ) + + # -- Obtain config -- + cfg: RunConfig + if config_path is not None: + cfg = load_run_config(config_path) + else: + metadata = CheckpointManager.load_metadata( + resolved + ) + raw_config = metadata.get("config") + if raw_config is None: + raise ValueError( + "Checkpoint metadata.json does not " + "contain 'config'. Pass --config " + "explicitly." + ) + cfg = _run_config_from_raw(raw_config) + + tc = cfg.training + + # -- Init distributed (1 GPU is enough; DCP reshards) -- + maybe_init_distributed_environment_and_model_parallel( + tp_size=1, sp_size=1, + ) + + # Override distributed config so model loading uses 1 GPU. + tc.distributed.tp_size = 1 + tc.distributed.sp_size = 1 + tc.distributed.num_gpus = 1 + tc.distributed.hsdp_replicate_dim = 1 + tc.distributed.hsdp_shard_dim = 1 + + # -- Build model (loads pretrained weights + FSDP) -- + _, method, _, _ = build_from_config(cfg) + + # -- Load DCP weights into the model -- + states = method.checkpoint_state() + logger.info( + "Loading DCP checkpoint from %s", resolved, + ) + dcp.load(states, checkpoint_id=str(dcp_dir)) + + # -- Export to diffusers format -- + model = method._role_models[role] + base_model_path = str(tc.model_path) + if not base_model_path: + raise ValueError( + "Cannot determine base_model_path from " + "config. Ensure models.student.init_from " + "is set." + ) + + logger.info( + "Exporting role=%s to %s (base=%s)", + role, + output_dir, + base_model_path, + ) + result = _save_role_pretrained( + role=role, + base_model_path=base_model_path, + output_dir=output_dir, + overwrite=overwrite, + model=model, + ) + logger.info("Export complete: %s", result) + return result + + +def _run_config_from_raw( + raw: dict[str, Any], +) -> Any: + """Reconstruct a RunConfig from a raw config dict. + + This mirrors ``load_run_config`` but operates on an + already-parsed dict (from metadata.json) instead of + reading from a YAML file. + """ + from fastvideo.train.utils.config import ( + RunConfig, + _build_training_config, + _parse_pipeline_config, + _require_mapping, + _require_str, + ) + + models_raw = _require_mapping( + raw.get("models"), where="models", + ) + models: dict[str, dict[str, Any]] = {} + for role_key, model_cfg_raw in models_raw.items(): + role_str = _require_str( + role_key, where="models.", + ) + model_cfg = _require_mapping( + model_cfg_raw, + where=f"models.{role_str}", + ) + models[role_str] = dict(model_cfg) + + method_raw = _require_mapping( + raw.get("method"), where="method", + ) + method = dict(method_raw) + + callbacks_raw = raw.get("callbacks", None) + callbacks: dict[str, dict[str, Any]] = ( + _require_mapping( + callbacks_raw, where="callbacks", + ) + if callbacks_raw is not None + else {} + ) + + pipeline_config = _parse_pipeline_config( + raw, models=models, + ) + + training_raw = _require_mapping( + raw.get("training"), where="training", + ) + t = dict(training_raw) + training = _build_training_config( + t, + models=models, + pipeline_config=pipeline_config, + ) + + return RunConfig( + models=models, + method=method, + training=training, + callbacks=callbacks, + raw=raw, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Convert a DCP training checkpoint to a " + "diffusers-style model directory. " + "Only 1 GPU needed (DCP reshards " + "automatically)." + ), + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help=( + "Path to checkpoint- dir, its dcp/ " + "subdir, or an output_dir (auto-picks " + "latest)." + ), + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Destination for the diffusers model.", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help=( + "Training YAML config. If omitted, read " + "from checkpoint metadata.json." + ), + ) + parser.add_argument( + "--role", + type=str, + default="student", + help="Role to export (default: student).", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output-dir if it exists.", + ) + args = parser.parse_args(sys.argv[1:]) + + convert( + checkpoint_dir=args.checkpoint, + output_dir=args.output_dir, + config_path=args.config, + role=args.role, + overwrite=args.overwrite, + ) + + +if __name__ == "__main__": + main() diff --git a/fastvideo/train/entrypoint/train.py b/fastvideo/train/entrypoint/train.py new file mode 100644 index 000000000..1c253004e --- /dev/null +++ b/fastvideo/train/entrypoint/train.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +"""YAML-only training entrypoint. + +Usage:: + + torchrun --nproc_per_node= -m fastvideo.train.entrypoint.train \ + --config path/to/run.yaml +""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Any + +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + + +def run_training_from_config( + config_path: str, + *, + dry_run: bool = False, + resume_from_checkpoint: str | None = None, + override_output_dir: str | None = None, +) -> None: + """YAML-only training entrypoint (schema v2).""" + + from fastvideo.distributed import ( + maybe_init_distributed_environment_and_model_parallel, + ) + from fastvideo.train import Trainer + from fastvideo.train.utils.checkpoint import ( + CheckpointConfig, + CheckpointManager, + ) + from fastvideo.train.utils.builder import build_from_config + from fastvideo.train.utils.config import load_run_config + + cfg = load_run_config(config_path) + tc = cfg.training + + if resume_from_checkpoint is not None: + tc.checkpoint.resume_from_checkpoint = str( + resume_from_checkpoint + ) + if override_output_dir is not None: + tc.checkpoint.output_dir = str(override_output_dir) + + maybe_init_distributed_environment_and_model_parallel( + tc.distributed.tp_size, + tc.distributed.sp_size, + ) + + _, method, dataloader, start_step = build_from_config( + cfg + ) + + if dry_run: + logger.info( + "Dry-run: config parsed and " + "build_from_config succeeded." + ) + return + + trainer = Trainer( + tc, + config=cfg.resolved_config(), + callback_configs=cfg.callbacks, + ) + + # Attach the exact YAML used for this run to the + # tracker (e.g., W&B Files). + trainer.tracker.log_file( + os.path.abspath(os.path.expanduser(config_path)), + name="run.yaml", + ) + + ckpt_config = CheckpointConfig( + save_steps=int( + tc.checkpoint.training_state_checkpointing_steps + or 0 + ), + keep_last=int( + tc.checkpoint.checkpoints_total_limit or 0 + ), + ) + + checkpoint_manager = CheckpointManager( + method=method, + dataloader=dataloader, + output_dir=tc.checkpoint.output_dir, + config=ckpt_config, + callbacks=trainer.callbacks, + raw_config=cfg.raw, + ) + + trainer.run( + method, + dataloader=dataloader, + max_steps=tc.loop.max_train_steps, + start_step=start_step, + checkpoint_manager=checkpoint_manager, + ) + + +def main(args: Any) -> None: + config_path = str(args.config) + dry_run = bool(args.dry_run) + resume_from_checkpoint = getattr( + args, "resume_from_checkpoint", None + ) + override_output_dir = getattr( + args, "override_output_dir", None + ) + logger.info( + "Starting training from config=%s", + config_path, + ) + run_training_from_config( + config_path, + dry_run=dry_run, + resume_from_checkpoint=resume_from_checkpoint, + override_output_dir=override_output_dir, + ) + logger.info("Training completed") + + +if __name__ == "__main__": + argv = sys.argv + parser = argparse.ArgumentParser( + description="YAML-only training entrypoint.", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help=( + "Path to training YAML config (schema v2)." + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Parse config and build runtime, " + "but do not start training." + ), + ) + parser.add_argument( + "--resume-from-checkpoint", + type=str, + default=None, + help=( + "Path to a checkpoint directory " + "(checkpoint-), its 'dcp/' subdir, " + "or an output_dir containing checkpoints " + "(auto-picks latest)." + ), + ) + parser.add_argument( + "--override-output-dir", + type=str, + default=None, + help=( + "Override training.output_dir from YAML " + "(useful for repeated runs)." + ), + ) + args = parser.parse_args(argv[1:]) + main(args) diff --git a/fastvideo/train/methods/__init__.py b/fastvideo/train/methods/__init__.py new file mode 100644 index 000000000..61fd6ef2e --- /dev/null +++ b/fastvideo/train/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.methods.base import TrainingMethod + +__all__ = [ + "TrainingMethod", + "DMD2Method", + "FineTuneMethod", + "SelfForcingMethod", + "DiffusionForcingSFTMethod", +] + + +def __getattr__(name: str) -> object: + if name == "DMD2Method": + from fastvideo.train.methods.distribution_matching.dmd2 import DMD2Method + return DMD2Method + if name == "FineTuneMethod": + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + return FineTuneMethod + if name == "SelfForcingMethod": + from fastvideo.train.methods.distribution_matching.self_forcing import SelfForcingMethod + return SelfForcingMethod + if name == "DiffusionForcingSFTMethod": + from fastvideo.train.methods.fine_tuning.dfsft import DiffusionForcingSFTMethod + return DiffusionForcingSFTMethod + raise AttributeError(name) diff --git a/fastvideo/train/methods/base.py b/fastvideo/train/methods/base.py new file mode 100644 index 000000000..06b63a3d7 --- /dev/null +++ b/fastvideo/train/methods/base.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Literal, cast + +import torch + +from fastvideo.logger import init_logger +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.checkpoint import _RoleModuleContainer +from fastvideo.training.checkpointing_utils import ( + ModelWrapper, + OptimizerWrapper, + RandomStateWrapper, + SchedulerWrapper, +) + +logger = init_logger(__name__) + +LogScalar = float | int | torch.Tensor + + +class TrainingMethod(torch.nn.Module, ABC): + """Base training method (algorithm layer). + + Subclasses own their role models (student, teacher, critic, …) as + plain attributes and manage optimizers directly — no ``RoleManager`` + or ``RoleHandle``. + + The constructor receives *role_models* (a ``dict[str, ModelBase]``) + and a *cfg* object. It calls ``init_preprocessors`` on the student + and builds ``self.role_modules`` for FSDP wrapping. + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__() + self.tracker: Any | None = None + self._role_models: dict[str, ModelBase] = dict(role_models) + + self.student = role_models["student"] + self.training_config = cfg.training + self.method_config: dict[str, Any] = dict(cfg.method) + self.validation_config: dict[str, Any] = dict( + getattr(cfg, "validation", {}) or {} + ) + self._use_ema: bool = bool( + self.method_config.get("use_ema", False) + ) + + # Build nn.ModuleDict for FSDP / checkpoint visibility. + self.role_modules = torch.nn.ModuleDict() + for role, model in role_models.items(): + mods: dict[str, torch.nn.Module] = {} + transformer = getattr(model, "transformer", None) + if isinstance(transformer, torch.nn.Module): + mods["transformer"] = transformer + if mods: + self.role_modules[role] = torch.nn.ModuleDict(mods) + + self._setup_ema() + + # ------------------------------------------------------------------ + # EMA + # ------------------------------------------------------------------ + + def _setup_ema(self) -> None: + """Create EMA copy of student transformer. + + Called at the end of ``__init__``, before FSDP wrapping. + Only acts when ``use_ema: true`` is set in method config. + """ + if not self._use_ema: + return + logger.info( + "Initializing EMA from student transformer", + ) + ema = copy.deepcopy(self.student.transformer) + ema.eval().requires_grad_(False) + self.ema = ema + # Register in role_modules for FSDP / checkpoint. + if "student" not in self.role_modules: + self.role_modules["student"] = ( + torch.nn.ModuleDict() + ) + self.role_modules["student"]["ema"] = ema # type: ignore[index] + + @property + def transformer_inference(self) -> torch.nn.Module: + """Return EMA transformer for inference if available.""" + if self._use_ema: + ema = getattr(self, "ema", None) + if ema is not None: + return ema + return self.student.transformer + + # ------------------------------------------------------------------ + + def set_tracker(self, tracker: Any) -> None: + self.tracker = tracker + + @abstractmethod + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + raise NotImplementedError + + @abstractmethod + def get_optimizers( + self, iteration: int, + ) -> Sequence[torch.optim.Optimizer]: + raise NotImplementedError + + @abstractmethod + def get_lr_schedulers( + self, iteration: int, + ) -> Sequence[Any]: + raise NotImplementedError + + @property + @abstractmethod + def _optimizer_dict(self) -> dict[str, Any]: + ... + + @property + @abstractmethod + def _lr_scheduler_dict(self) -> dict[str, Any]: + ... + + def checkpoint_state(self) -> dict[str, Any]: + """Return DCP-ready checkpoint state for all trainable roles. + + Keys follow the convention: + ``roles..``, ``optimizers.``, + ``schedulers.``, ``random_state.*``. + """ + states: dict[str, Any] = {} + + for role, model in self._role_models.items(): + if not getattr(model, "_trainable", False): + continue + + modules: dict[str, torch.nn.Module] = {} + if model.transformer is not None: + modules["transformer"] = model.transformer + ema = getattr(self, "ema", None) + if role == "student" and ema is not None: + modules["ema"] = ema + + container = _RoleModuleContainer(modules) + + for module_name, module in modules.items(): + states[ + f"roles.{role}.{module_name}" + ] = ModelWrapper(module) + + opt = self._optimizer_dict.get(role) + if opt is not None: + states[ + f"optimizers.{role}" + ] = OptimizerWrapper(container, opt) + + sched = self._lr_scheduler_dict.get(role) + if sched is not None: + states[ + f"schedulers.{role}" + ] = SchedulerWrapper(sched) + + # RNG states. + states["random_state"] = RandomStateWrapper(None) + for name, gen in ( + self.get_rng_generators() or {} + ).items(): + if gen is not None: + states[ + f"random_state.{name}" + ] = RandomStateWrapper(gen) + + return states + + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + del outputs + grad_accum_rounds = max(1, int(grad_accum_rounds)) + (loss_map["total_loss"] / grad_accum_rounds).backward() + + def optimizers_schedulers_step( + self, iteration: int, + ) -> None: + for optimizer in self.get_optimizers(iteration): + optimizer.step() + for scheduler in self.get_lr_schedulers(iteration): + scheduler.step() + + def optimizers_zero_grad( + self, iteration: int, + ) -> None: + for optimizer in self.get_optimizers(iteration): + try: + optimizer.zero_grad(set_to_none=True) + except TypeError: + optimizer.zero_grad() + + # -- Shared hooks (override in subclasses as needed) -- + + def get_grad_clip_targets( + self, iteration: int, + ) -> dict[str, torch.nn.Module]: + """Return modules whose gradients should be clipped. + + Override in subclasses to add/conditionally include + modules (e.g. critic, conditionally student). + Default: student transformer. + """ + return {"student": self.student.transformer} + + def on_train_start(self) -> None: + self.student.on_train_start() + + def get_rng_generators( + self, + ) -> dict[str, torch.Generator]: + generators: dict[str, torch.Generator] = {} + + student_gens = self.student.get_rng_generators() + generators.update(student_gens) + + return generators + + @staticmethod + def _parse_attn_kind( + raw: Any, + ) -> Literal["dense", "vsa"]: + if raw in (None, ""): + return "dense" + kind = str(raw).strip().lower() + if kind not in {"dense", "vsa"}: + raise ValueError( + "method_config.attn_kind must be one of " + f"{{'dense', 'vsa'}}, got {raw!r}." + ) + return cast(Literal["dense", "vsa"], kind) diff --git a/fastvideo/train/methods/consistency_model/__init__.py b/fastvideo/train/methods/consistency_model/__init__.py new file mode 100644 index 000000000..324710b84 --- /dev/null +++ b/fastvideo/train/methods/consistency_model/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/fastvideo/train/methods/distribution_matching/__init__.py b/fastvideo/train/methods/distribution_matching/__init__.py new file mode 100644 index 000000000..4edb43cf7 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.methods.distribution_matching.dmd2 import DMD2Method +from fastvideo.train.methods.distribution_matching.self_forcing import ( + SelfForcingMethod, ) + +__all__ = [ + "DMD2Method", + "SelfForcingMethod", +] diff --git a/fastvideo/train/methods/distribution_matching/dmd2.py b/fastvideo/train/methods/distribution_matching/dmd2.py new file mode 100644 index 000000000..9c2e07ef9 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/dmd2.py @@ -0,0 +1,745 @@ +# SPDX-License-Identifier: Apache-2.0 +"""DMD2 distillation method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) +from fastvideo.train.utils.config import ( + get_optional_float, + get_optional_int, + parse_betas, +) + + +class DMD2Method(TrainingMethod): + """DMD2 distillation algorithm (method layer). + + Owns role model instances directly: + - ``self.student`` — trainable student :class:`ModelBase` + - ``self.teacher`` — frozen teacher :class:`ModelBase` + - ``self.critic`` — trainable critic :class:`ModelBase` + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError( + "DMD2Method requires role 'student'" + ) + if "teacher" not in role_models: + raise ValueError( + "DMD2Method requires role 'teacher'" + ) + if "critic" not in role_models: + raise ValueError( + "DMD2Method requires role 'critic'" + ) + + self.teacher = role_models["teacher"] + self.critic = role_models["critic"] + + if not self.student._trainable: + raise ValueError( + "DMD2Method requires student to be trainable" + ) + if self.teacher._trainable: + raise ValueError( + "DMD2Method requires teacher to be " + "non-trainable" + ) + if not self.critic._trainable: + raise ValueError( + "DMD2Method requires critic to be trainable" + ) + self._cfg_uncond = self._parse_cfg_uncond() + self._rollout_mode = self._parse_rollout_mode() + self._denoising_step_list: torch.Tensor | None = ( + None + ) + + # Initialize preprocessors on student. + self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict( + self, + ) -> dict[str, torch.optim.Optimizer]: + return { + "student": self._student_optimizer, + "critic": self._critic_optimizer, + } + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return { + "student": self._student_lr_scheduler, + "critic": self._critic_lr_scheduler, + } + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + latents_source: Literal["data", "zeros"] = "data" + if self._rollout_mode == "simulate": + latents_source = "zeros" + + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source=latents_source, + ) + + update_student = self._should_update_student( + iteration + ) + + generator_loss = torch.zeros( + (), + device=training_batch.latents.device, + dtype=training_batch.latents.dtype, + ) + student_ctx = None + if update_student: + generator_pred_x0 = self._student_rollout( + training_batch, with_grad=True + ) + student_ctx = ( + training_batch.timesteps, + training_batch.attn_metadata_vsa, + ) + generator_loss = self._dmd_loss( + generator_pred_x0, training_batch + ) + + ( + fake_score_loss, + critic_ctx, + critic_outputs, + ) = self._critic_flow_matching_loss(training_batch) + + total_loss = generator_loss + fake_score_loss + loss_map = { + "total_loss": total_loss, + "generator_loss": generator_loss, + "fake_score_loss": fake_score_loss, + } + + outputs: dict[str, Any] = dict(critic_outputs) + outputs["_fv_backward"] = { + "update_student": update_student, + "student_ctx": student_ctx, + "critic_ctx": critic_ctx, + } + metrics: dict[str, LogScalar] = { + "update_student": float(update_student) + } + return loss_map, outputs, metrics + + # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + backward_ctx = outputs.get("_fv_backward") + if not isinstance(backward_ctx, dict): + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + + update_student = bool( + backward_ctx.get("update_student", False) + ) + if update_student: + student_ctx = backward_ctx.get("student_ctx") + if student_ctx is None: + raise RuntimeError( + "Missing student backward context" + ) + self.student.backward( + loss_map["generator_loss"], + student_ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + critic_ctx = backward_ctx.get("critic_ctx") + if critic_ctx is None: + raise RuntimeError( + "Missing critic backward context" + ) + self.critic.backward( + loss_map["fake_score_loss"], + critic_ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + optimizers: list[torch.optim.Optimizer] = [] + optimizers.append(self._critic_optimizer) + if self._should_update_student(iteration): + optimizers.append(self._student_optimizer) + return optimizers + + # TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + schedulers: list[Any] = [] + schedulers.append(self._critic_lr_scheduler) + if self._should_update_student(iteration): + schedulers.append(self._student_lr_scheduler) + return schedulers + + # TrainingMethod override: get_grad_clip_targets + def get_grad_clip_targets( + self, iteration: int, + ) -> dict[str, torch.nn.Module]: + targets: dict[str, torch.nn.Module] = {} + if self._should_update_student(iteration): + targets["student"] = ( + self.student.transformer + ) + targets["critic"] = self.critic.transformer + return targets + + def _parse_rollout_mode( + self, + ) -> Literal["simulate", "data_latent"]: + raw = self.method_config.get( + "rollout_mode", None + ) + if raw is None: + raise ValueError( + "method_config.rollout_mode must be set " + "for DMD2" + ) + if not isinstance(raw, str): + raise ValueError( + "method_config.rollout_mode must be a " + "string, " + f"got {type(raw).__name__}" + ) + mode = raw.strip().lower() + if mode in ("simulate", "sim"): + return "simulate" + if mode in ("data_latent", "data", "vae_latent"): + return "data_latent" + raise ValueError( + "method_config.rollout_mode must be one of " + "{simulate, data_latent}, got " + f"{raw!r}" + ) + + def _parse_cfg_uncond( + self, + ) -> dict[str, Any] | None: + raw = self.method_config.get("cfg_uncond", None) + if raw is None: + return None + if not isinstance(raw, dict): + raise ValueError( + "method_config.cfg_uncond must be a dict " + f"when set, got {type(raw).__name__}" + ) + + cfg: dict[str, Any] = dict(raw) + + on_missing_raw = cfg.get("on_missing", "error") + if on_missing_raw is None: + on_missing_raw = "error" + if not isinstance(on_missing_raw, str): + raise ValueError( + "method_config.cfg_uncond.on_missing must " + "be a string, got " + f"{type(on_missing_raw).__name__}" + ) + on_missing = on_missing_raw.strip().lower() + if on_missing not in {"error", "ignore"}: + raise ValueError( + "method_config.cfg_uncond.on_missing must " + "be one of {error, ignore}, got " + f"{on_missing_raw!r}" + ) + cfg["on_missing"] = on_missing + + for channel, policy_raw in list(cfg.items()): + if channel == "on_missing": + continue + if policy_raw is None: + continue + if not isinstance(policy_raw, str): + raise ValueError( + "method_config.cfg_uncond values must " + "be strings, got " + f"{channel}=" + f"{type(policy_raw).__name__}" + ) + policy = policy_raw.strip().lower() + allowed = {"keep", "zero", "drop"} + if channel == "text": + allowed = {*allowed, "negative_prompt"} + if policy not in allowed: + raise ValueError( + "method_config.cfg_uncond values must " + "be one of " + f"{sorted(allowed)}, got " + f"{channel}={policy_raw!r}" + ) + cfg[channel] = policy + + return cfg + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + + # Student optimizer/scheduler. + student_lr = float(tc.optimizer.learning_rate) + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) + + # Critic optimizer/scheduler — must be set in + # method config. + critic_lr_raw = get_optional_float( + self.method_config, + "fake_score_learning_rate", + where="method.fake_score_learning_rate", + ) + if critic_lr_raw is None or critic_lr_raw == 0.0: + raise ValueError( + "method.fake_score_learning_rate must " + "be set to a positive value" + ) + critic_lr = float(critic_lr_raw) + + critic_betas_raw = self.method_config.get( + "fake_score_betas", None + ) + if critic_betas_raw is None: + raise ValueError( + "method.fake_score_betas must be set " + "(e.g. [0.0, 0.999])" + ) + critic_betas = parse_betas( + critic_betas_raw, + where="method.fake_score_betas", + ) + + critic_sched_raw = self.method_config.get( + "fake_score_lr_scheduler", None + ) + if critic_sched_raw is None: + raise ValueError( + "method.fake_score_lr_scheduler must " + "be set (e.g. 'constant')" + ) + critic_sched = str(critic_sched_raw) + critic_params = [ + p + for p in self.critic.transformer.parameters() + if p.requires_grad + ] + ( + self._critic_optimizer, + self._critic_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=critic_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=critic_lr, + betas=critic_betas, + scheduler_name=critic_sched, + ) + + def _should_update_student( + self, iteration: int, + ) -> bool: + interval = get_optional_int( + self.method_config, + "generator_update_interval", + where="method.generator_update_interval", + ) + if interval is None: + interval = 1 + if interval <= 0: + return True + return iteration % interval == 0 + + def _get_denoising_step_list( + self, device: torch.device, + ) -> torch.Tensor: + if ( + self._denoising_step_list is not None + and self._denoising_step_list.device == device + ): + return self._denoising_step_list + + raw = self.method_config.get( + "dmd_denoising_steps", None + ) + if not isinstance(raw, list) or not raw: + raise ValueError( + "method_config.dmd_denoising_steps must " + "be set for DMD2 distillation" + ) + + steps = torch.tensor( + [int(s) for s in raw], + dtype=torch.long, + device=device, + ) + + warp = self.method_config.get( + "warp_denoising_step", None + ) + if warp is None: + warp = False + if bool(warp): + timesteps = torch.cat(( + self.student.noise_scheduler.timesteps.to( + "cpu" + ), + torch.tensor( + [0], dtype=torch.float32 + ), + )).to(device) + steps = timesteps[1000 - steps] + + self._denoising_step_list = steps + return steps + + def _sample_rollout_timestep( + self, device: torch.device, + ) -> torch.Tensor: + step_list = self._get_denoising_step_list(device) + index = torch.randint( + 0, + len(step_list), + [1], + device=device, + dtype=torch.long, + ) + return step_list[index] + + def _student_rollout( + self, batch: Any, *, with_grad: bool, + ) -> torch.Tensor: + latents = batch.latents + device = latents.device + dtype = latents.dtype + step_list = self._get_denoising_step_list(device) + + if self._rollout_mode != "simulate": + timestep = self._sample_rollout_timestep( + device + ) + noise = torch.randn( + latents.shape, device=device, dtype=dtype + ) + noisy_latents = self.student.add_noise( + latents, noise, timestep + ) + pred_x0 = self.student.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + batch.dmd_latent_vis_dict[ + "generator_timestep" + ] = timestep + return pred_x0 + + target_timestep_idx = torch.randint( + 0, + len(step_list), + [1], + device=device, + dtype=torch.long, + ) + target_timestep_idx_int = int( + target_timestep_idx.item() + ) + target_timestep = step_list[target_timestep_idx] + + current_noise_latents = torch.randn( + latents.shape, device=device, dtype=dtype + ) + current_noise_latents_copy = ( + current_noise_latents.clone() + ) + + max_target_idx = len(step_list) - 1 + noise_latents: list[torch.Tensor] = [] + noise_latent_index = target_timestep_idx_int - 1 + + if max_target_idx > 0: + with torch.no_grad(): + for step_idx in range(max_target_idx): + current_timestep = step_list[step_idx] + current_timestep_tensor = ( + current_timestep + * torch.ones( + 1, + device=device, + dtype=torch.long, + ) + ) + + pred_clean = self.student.predict_x0( + current_noise_latents, + current_timestep_tensor, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + next_timestep = step_list[step_idx + 1] + next_timestep_tensor = ( + next_timestep + * torch.ones( + 1, + device=device, + dtype=torch.long, + ) + ) + noise = torch.randn( + latents.shape, + device=device, + dtype=pred_clean.dtype, + ) + current_noise_latents = ( + self.student.add_noise( + pred_clean, + noise, + next_timestep_tensor, + ) + ) + noise_latents.append( + current_noise_latents.clone() + ) + + if noise_latent_index >= 0: + if noise_latent_index >= len(noise_latents): + raise RuntimeError( + "noise_latent_index is out of bounds" + ) + noisy_input = noise_latents[noise_latent_index] + else: + noisy_input = current_noise_latents_copy + + if with_grad: + pred_x0 = self.student.predict_x0( + noisy_input, + target_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + else: + with torch.no_grad(): + pred_x0 = self.student.predict_x0( + noisy_input, + target_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + batch.dmd_latent_vis_dict[ + "generator_timestep" + ] = target_timestep.float().detach() + return pred_x0 + + def _critic_flow_matching_loss( + self, batch: Any, + ) -> tuple[torch.Tensor, Any, dict[str, Any]]: + with torch.no_grad(): + generator_pred_x0 = self._student_rollout( + batch, with_grad=False + ) + + device = generator_pred_x0.device + fake_score_timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + fake_score_timestep = ( + self.student.shift_and_clamp_timestep( + fake_score_timestep + ) + ) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_x0 = self.student.add_noise( + generator_pred_x0, noise, fake_score_timestep + ) + + pred_noise = self.critic.predict_noise( + noisy_x0, + fake_score_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + target = noise - generator_pred_x0 + flow_matching_loss = torch.mean( + (pred_noise - target)**2 + ) + + batch.fake_score_latent_vis_dict = { + "generator_pred_video": generator_pred_x0, + "fake_score_timestep": fake_score_timestep, + } + outputs = { + "fake_score_latent_vis_dict": ( + batch.fake_score_latent_vis_dict + ) + } + return ( + flow_matching_loss, + (batch.timesteps, batch.attn_metadata), + outputs, + ) + + def _dmd_loss( + self, + generator_pred_x0: torch.Tensor, + batch: Any, + ) -> torch.Tensor: + guidance_scale = get_optional_float( + self.method_config, + "real_score_guidance_scale", + where="method.real_score_guidance_scale", + ) + if guidance_scale is None: + guidance_scale = 1.0 + device = generator_pred_x0.device + + with torch.no_grad(): + timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + timestep = ( + self.student.shift_and_clamp_timestep( + timestep + ) + ) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_latents = self.student.add_noise( + generator_pred_x0, noise, timestep + ) + + faker_x0 = self.critic.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_cond_x0 = self.teacher.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_uncond_x0 = self.teacher.predict_x0( + noisy_latents, + timestep, + batch, + conditional=False, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_cfg_x0 = real_uncond_x0 + ( + real_cond_x0 - real_uncond_x0 + ) * guidance_scale + + denom = torch.abs( + generator_pred_x0 - real_cfg_x0 + ).mean() + grad = (faker_x0 - real_cfg_x0) / denom + grad = torch.nan_to_num(grad) + + loss = 0.5 * F.mse_loss( + generator_pred_x0.float(), + ( + generator_pred_x0.float() - grad.float() + ).detach(), + ) + return loss diff --git a/fastvideo/train/methods/distribution_matching/self_forcing.py b/fastvideo/train/methods/distribution_matching/self_forcing.py new file mode 100644 index 000000000..ae2547bd5 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/self_forcing.py @@ -0,0 +1,571 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Self-Forcing distillation method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal, TYPE_CHECKING + +import torch +import torch.distributed as dist + +from fastvideo.train.models.base import ( + CausalModelBase, + ModelBase, +) +from fastvideo.train.methods.distribution_matching.dmd2 import ( + DMD2Method, ) +from fastvideo.train.utils.config import ( + get_optional_float, + get_optional_int, +) +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler, ) +from fastvideo.models.utils import pred_noise_to_pred_video + +if TYPE_CHECKING: + from fastvideo.pipelines import TrainingBatch + + +def _require_bool(raw: Any, *, where: str) -> bool: + if isinstance(raw, bool): + return raw + raise ValueError(f"Expected bool at {where}, got {type(raw).__name__}") + + +def _require_str(raw: Any, *, where: str) -> str: + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"Expected non-empty string at {where}") + return raw + + +class SelfForcingMethod(DMD2Method): + """Self-Forcing DMD2 (distribution matching) method. + + Requires a causal student implementing ``CausalModelBase``. + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__( + cfg=cfg, + role_models=role_models, + ) + + # Validate causal student. + if not isinstance(self.student, CausalModelBase): + raise ValueError("SelfForcingMethod requires a causal student " + "implementing CausalModelBase.") + + if self._rollout_mode != "simulate": + raise ValueError("SelfForcingMethod only supports " + "method_config.rollout_mode='simulate'") + + mcfg = self.method_config + + chunk_size = get_optional_int( + mcfg, + "chunk_size", + where="method_config.chunk_size", + ) + if chunk_size is None: + chunk_size = 3 + if chunk_size <= 0: + raise ValueError("method_config.chunk_size must be a positive " + f"integer, got {chunk_size}") + self._chunk_size = int(chunk_size) + + sample_type_raw = mcfg.get("student_sample_type", "sde") + sample_type = _require_str( + sample_type_raw, + where="method_config.student_sample_type", + ) + sample_type = sample_type.strip().lower() + if sample_type not in {"sde", "ode"}: + raise ValueError("method_config.student_sample_type must be one " + f"of {{sde, ode}}, got {sample_type_raw!r}") + self._student_sample_type: Literal["sde", "ode"] = ( + sample_type # type: ignore[assignment] + ) + + same_step_raw = mcfg.get("same_step_across_blocks", False) + if same_step_raw is None: + same_step_raw = False + self._same_step_across_blocks = _require_bool( + same_step_raw, + where="method_config.same_step_across_blocks", + ) + + last_step_raw = mcfg.get("last_step_only", False) + if last_step_raw is None: + last_step_raw = False + self._last_step_only = _require_bool( + last_step_raw, + where="method_config.last_step_only", + ) + + context_noise = get_optional_float( + mcfg, + "context_noise", + where="method_config.context_noise", + ) + if context_noise is None: + context_noise = 0.0 + if context_noise < 0.0: + raise ValueError("method_config.context_noise must be >= 0, " + f"got {context_noise}") + self._context_noise = float(context_noise) + + enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True) + if enable_grad_raw is None: + enable_grad_raw = True + self._enable_gradient_in_rollout = _require_bool( + enable_grad_raw, + where="method_config.enable_gradient_in_rollout", + ) + + start_grad_frame = get_optional_int( + mcfg, + "start_gradient_frame", + where="method_config.start_gradient_frame", + ) + if start_grad_frame is None: + start_grad_frame = 0 + if start_grad_frame < 0: + raise ValueError("method_config.start_gradient_frame must be " + f">= 0, got {start_grad_frame}") + self._start_gradient_frame = int(start_grad_frame) + + shift = float(getattr( + self.training_config.pipeline_config, + "flow_shift", + 0.0, + ) or 0.0) + self._sf_scheduler = SelfForcingFlowMatchScheduler( + num_inference_steps=1000, + num_train_timesteps=int(self.student.num_train_timesteps), + shift=shift, + sigma_min=0.0, + extra_one_step=True, + training=True, + ) + + self._sf_denoising_step_list: torch.Tensor | None = None + + def _get_denoising_step_list(self, device: torch.device) -> torch.Tensor: + if (self._sf_denoising_step_list is not None and self._sf_denoising_step_list.device == device): + return self._sf_denoising_step_list + + raw = self.method_config.get("dmd_denoising_steps", None) + if not isinstance(raw, list) or not raw: + raise ValueError("method_config.dmd_denoising_steps must be set " + "for self_forcing") + steps = torch.tensor( + [int(s) for s in raw], + dtype=torch.long, + device=device, + ) + + warp = self.method_config.get("warp_denoising_step", None) + if warp is None: + warp = False + if bool(warp): + timesteps = torch.cat(( + self._sf_scheduler.timesteps.to("cpu"), + torch.tensor([0], dtype=torch.float32), + )).to(device) + steps = timesteps[int(self.student.num_train_timesteps) - steps] + + self._sf_denoising_step_list = steps + return steps + + def _predict_x0_with_scheduler( + self, + model: ModelBase, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + attn_kind: Literal["dense", "vsa"], + ) -> torch.Tensor: + pred_noise = model.predict_noise( + noisy_latents, + timestep, + batch, + conditional=conditional, + cfg_uncond=self._cfg_uncond, + attn_kind=attn_kind, + ) + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noisy_latents.flatten(0, 1), + timestep=timestep, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + def _sf_add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + b, t = clean_latents.shape[:2] + noisy = self._sf_scheduler.add_noise( + clean_latents.flatten(0, 1), + noise.flatten(0, 1), + timestep, + ).unflatten(0, (b, t)) + return noisy + + def _timestep_to_sigma(self, timestep: torch.Tensor) -> torch.Tensor: + sigmas = self._sf_scheduler.sigmas.to(device=timestep.device, dtype=torch.float32) + timesteps = self._sf_scheduler.timesteps.to(device=timestep.device, dtype=torch.float32) + t = timestep.to(device=timestep.device, dtype=torch.float32) + if t.ndim == 2: + t = t.flatten(0, 1) + elif t.ndim == 1 and t.numel() == 1: + t = t.expand(1) + elif t.ndim != 1: + raise ValueError("Invalid timestep shape: " + f"{tuple(timestep.shape)}") + idx = torch.argmin( + (timesteps.unsqueeze(0) - t.unsqueeze(1)).abs(), + dim=1, + ) + return sigmas[idx] + + def _sample_exit_indices( + self, + *, + num_blocks: int, + num_steps: int, + device: torch.device, + ) -> list[int]: + if num_blocks <= 0: + return [] + if num_steps <= 0: + raise ValueError("num_steps must be positive") + + shape = ((1, ) if self._same_step_across_blocks else (num_blocks, )) + + if not dist.is_initialized() or dist.get_rank() == 0: + if self._last_step_only: + indices = torch.full( + shape, + num_steps - 1, + dtype=torch.long, + device=device, + ) + else: + indices = torch.randint( + low=0, + high=num_steps, + size=shape, + device=device, + ) + else: + indices = torch.empty(shape, dtype=torch.long, device=device) + + if dist.is_initialized(): + dist.broadcast(indices, src=0) + + if self._same_step_across_blocks: + return [int(indices.item()) for _ in range(num_blocks)] + return [int(i) for i in indices.tolist()] + + def _student_rollout(self, batch: Any, *, with_grad: bool) -> torch.Tensor: + if not isinstance(self.student, CausalModelBase): + raise ValueError("SelfForcingMethod requires a causal student " + "implementing CausalModelBase.") + return self._student_rollout_streaming(batch, with_grad=with_grad) + + def _student_rollout_streaming(self, batch: Any, *, with_grad: bool) -> torch.Tensor: + assert isinstance(self.student, CausalModelBase) + latents = batch.latents + if latents is None: + raise RuntimeError("TrainingBatch.latents is required for " + "self-forcing rollout") + if latents.ndim != 5: + raise ValueError("TrainingBatch.latents must be [B, T, C, H, W]" + f", got shape={tuple(latents.shape)}") + + device = latents.device + dtype = latents.dtype + batch_size = int(latents.shape[0]) + num_frames = int(latents.shape[1]) + + denoising_steps = self._get_denoising_step_list(device) + num_steps = int(denoising_steps.numel()) + + noise_full = torch.randn_like(latents, device=device, dtype=dtype) + + chunk = int(self._chunk_size) + if chunk <= 0: + raise ValueError("chunk_size must be positive") + + remaining = num_frames % chunk + num_blocks = num_frames // chunk + if num_blocks == 0: + num_blocks = 1 + remaining = num_frames + + exit_indices = self._sample_exit_indices( + num_blocks=num_blocks, + num_steps=num_steps, + device=device, + ) + + denoised_blocks: list[torch.Tensor] = [] + + cache_tag = "pos" + self.student.clear_caches(cache_tag=cache_tag) + + for block_idx in range(num_blocks): + if block_idx == 0: + start = 0 + end = remaining + chunk if remaining else chunk + else: + start = remaining + block_idx * chunk + end = remaining + (block_idx + 1) * chunk + start = int(start) + end = int(min(end, num_frames)) + if start >= end: + break + + noisy_block = noise_full[:, start:end] + exit_idx = int(exit_indices[block_idx]) + + for step_idx, current_timestep in enumerate(denoising_steps): + exit_flag = step_idx == exit_idx + + timestep_block = (current_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + )) + + enable_grad = (bool(with_grad) and bool(self._enable_gradient_in_rollout) and torch.is_grad_enabled() + and start >= int(self._start_gradient_frame)) + + if not exit_flag: + with torch.no_grad(): + pred_noise = (self.student.predict_noise_streaming( + noisy_block, + timestep_block, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=False, + cur_start_frame=start, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + )) + if pred_noise is None: + raise RuntimeError("predict_noise_streaming " + "returned None " + "(store_kv=False)") + pred_x0_chunk = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=(noisy_block.flatten(0, 1)), + timestep=timestep_block, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + + if step_idx + 1 >= num_steps: + break + next_timestep = denoising_steps[step_idx + 1] + if self._student_sample_type == "sde": + noisy_block = self._sf_add_noise( + pred_x0_chunk, + torch.randn_like(pred_x0_chunk), + next_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ), + ) + else: + sigma_cur = self._timestep_to_sigma(timestep_block).view(batch_size, end - start, 1, 1, 1) + sigma_next = self._timestep_to_sigma(next_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + )).view(batch_size, end - start, 1, 1, 1) + eps = (noisy_block - (1 - sigma_cur) * pred_x0_chunk) / sigma_cur.clamp_min(1e-8) + noisy_block = ((1 - sigma_next) * pred_x0_chunk + sigma_next * eps) + continue + + with torch.set_grad_enabled(enable_grad): + pred_noise = (self.student.predict_noise_streaming( + noisy_block, + timestep_block, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=False, + cur_start_frame=start, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + )) + if pred_noise is None: + raise RuntimeError("predict_noise_streaming returned " + "None (store_kv=False)") + pred_x0_chunk = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=(noisy_block.flatten(0, 1)), + timestep=timestep_block, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + break + + denoised_blocks.append(pred_x0_chunk) + + with torch.no_grad(): + if self._context_noise > 0.0: + context_timestep = torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ) * float(self._context_noise) + context_latents = self._sf_add_noise( + pred_x0_chunk.detach(), + torch.randn_like(pred_x0_chunk), + context_timestep, + ) + else: + context_timestep = torch.zeros( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ) + context_latents = pred_x0_chunk.detach() + + _ = self.student.predict_noise_streaming( + context_latents, + context_timestep, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=True, + cur_start_frame=start, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + if not denoised_blocks: + raise RuntimeError("Self-forcing rollout produced no blocks") + + self.student.clear_caches(cache_tag=cache_tag) + return torch.cat(denoised_blocks, dim=1) + + def _critic_flow_matching_loss(self, batch: Any) -> tuple[torch.Tensor, Any, dict[str, Any]]: + with torch.no_grad(): + generator_pred_x0 = self._student_rollout(batch, with_grad=False) + + device = generator_pred_x0.device + fake_score_timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + fake_score_timestep = (self.student.shift_and_clamp_timestep(fake_score_timestep)) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_x0 = self._sf_add_noise(generator_pred_x0, noise, fake_score_timestep) + + pred_noise = self.critic.predict_noise( + noisy_x0, + fake_score_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + target = noise - generator_pred_x0 + flow_matching_loss = torch.mean((pred_noise - target)**2) + + batch.fake_score_latent_vis_dict = { + "generator_pred_video": generator_pred_x0, + "fake_score_timestep": fake_score_timestep, + } + outputs = {"fake_score_latent_vis_dict": (batch.fake_score_latent_vis_dict)} + return ( + flow_matching_loss, + (batch.timesteps, batch.attn_metadata), + outputs, + ) + + def _dmd_loss( + self, + generator_pred_x0: torch.Tensor, + batch: Any, + ) -> torch.Tensor: + guidance_scale = get_optional_float( + self.method_config, + "real_score_guidance_scale", + where="method.real_score_guidance_scale", + ) + if guidance_scale is None: + guidance_scale = 1.0 + device = generator_pred_x0.device + + with torch.no_grad(): + timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + timestep = self.student.shift_and_clamp_timestep(timestep) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_latents = self._sf_add_noise(generator_pred_x0, noise, timestep) + + faker_x0 = self._predict_x0_with_scheduler( + self.critic, + noisy_latents, + timestep, + batch, + conditional=True, + attn_kind="dense", + ) + real_cond_x0 = self._predict_x0_with_scheduler( + self.teacher, + noisy_latents, + timestep, + batch, + conditional=True, + attn_kind="dense", + ) + real_uncond_x0 = self._predict_x0_with_scheduler( + self.teacher, + noisy_latents, + timestep, + batch, + conditional=False, + attn_kind="dense", + ) + real_cfg_x0 = real_uncond_x0 + (real_cond_x0 - real_uncond_x0) * guidance_scale + + denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean() + grad = (faker_x0 - real_cfg_x0) / denom + grad = torch.nan_to_num(grad) + + loss = 0.5 * torch.mean((generator_pred_x0.float() - (generator_pred_x0.float() - grad.float()).detach())**2) + return loss diff --git a/fastvideo/train/methods/fine_tuning/__init__.py b/fastvideo/train/methods/fine_tuning/__init__.py new file mode 100644 index 000000000..6f862df4e --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.train.methods.fine_tuning.dfsft import DiffusionForcingSFTMethod + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + +__all__ = [ + "DiffusionForcingSFTMethod", + "FineTuneMethod", +] + + +def __getattr__(name: str) -> object: + # Lazy import to avoid circular imports during registry bring-up. + if name == "DiffusionForcingSFTMethod": + from fastvideo.train.methods.fine_tuning.dfsft import ( + DiffusionForcingSFTMethod, ) + + return DiffusionForcingSFTMethod + if name == "FineTuneMethod": + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + + return FineTuneMethod + raise AttributeError(name) diff --git a/fastvideo/train/methods/fine_tuning/dfsft.py b/fastvideo/train/methods/fine_tuning/dfsft.py new file mode 100644 index 000000000..4f91110a2 --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/dfsft.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Diffusion-forcing SFT method (DFSFT; algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) + + +class DiffusionForcingSFTMethod(TrainingMethod): + """Diffusion-forcing SFT (DFSFT): train only ``student`` + with inhomogeneous timesteps. + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError("DFSFT requires role 'student'") + if not self.student._trainable: + raise ValueError( + "DFSFT requires student to be trainable" + ) + self._attn_kind: Literal["dense", "vsa"] = ( + self._parse_attn_kind( + self.method_config.get("attn_kind", None) + ) + ) + + self._chunk_size = self._parse_chunk_size( + self.method_config.get("chunk_size", None) + ) + self._timestep_index_range = ( + self._parse_timestep_index_range() + ) + + # Initialize preprocessors on student. + self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict(self) -> dict[str, Any]: + return {"student": self._student_optimizer} + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return {"student": self._student_lr_scheduler} + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + del iteration + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source="data", + ) + + if training_batch.latents is None: + raise RuntimeError( + "prepare_batch() must set TrainingBatch.latents" + ) + + clean_latents = training_batch.latents + if not torch.is_tensor(clean_latents): + raise TypeError( + "TrainingBatch.latents must be a torch.Tensor" + ) + if clean_latents.ndim != 5: + raise ValueError( + "TrainingBatch.latents must be " + "[B, T, C, H, W], got " + f"shape={tuple(clean_latents.shape)}" + ) + + batch_size, num_latents = ( + int(clean_latents.shape[0]), + int(clean_latents.shape[1]), + ) + + expected_chunk = getattr( + self.student.transformer, + "num_frame_per_block", + None, + ) + if ( + expected_chunk is not None + and int(expected_chunk) != int(self._chunk_size) + ): + raise ValueError( + "DFSFT chunk_size must match " + "transformer.num_frame_per_block for " + f"causal training (got {self._chunk_size}, " + f"expected {expected_chunk})." + ) + + timestep_indices = self._sample_t_inhom_indices( + batch_size=batch_size, + num_latents=num_latents, + device=clean_latents.device, + ) + sp_size = int( + self.training_config.distributed.sp_size + ) + sp_group = getattr(self.student, "sp_group", None) + if ( + sp_size > 1 + and sp_group is not None + and hasattr(sp_group, "broadcast") + ): + sp_group.broadcast(timestep_indices, src=0) + + scheduler = self.student.noise_scheduler + if scheduler is None: + raise ValueError( + "DFSFT requires student.noise_scheduler" + ) + + schedule_timesteps = scheduler.timesteps.to( + device=clean_latents.device, dtype=torch.float32 + ) + schedule_sigmas = scheduler.sigmas.to( + device=clean_latents.device, + dtype=clean_latents.dtype, + ) + t_inhom = schedule_timesteps[timestep_indices] + + # Override the homogeneous timesteps from prepare_batch + # so that set_forward_context (in predict_noise and + # backward) receives the correct per-chunk timesteps. + training_batch.timesteps = t_inhom + + noise = getattr(training_batch, "noise", None) + if noise is None: + noise = torch.randn_like(clean_latents) + else: + if not torch.is_tensor(noise): + raise TypeError( + "TrainingBatch.noise must be a " + "torch.Tensor when set" + ) + noise = noise.permute(0, 2, 1, 3, 4).to( + dtype=clean_latents.dtype + ) + + noisy_latents = self.student.add_noise( + clean_latents, + noise, + t_inhom.flatten(), + ) + + pred = self.student.predict_noise( + noisy_latents, + t_inhom, + training_batch, + conditional=True, + attn_kind=self._attn_kind, + ) + + if bool( + self.training_config.model.precondition_outputs + ): + sigmas = schedule_sigmas[timestep_indices] + sigmas = sigmas.unsqueeze(-1).unsqueeze( + -1 + ).unsqueeze(-1) + pred_x0 = noisy_latents - pred * sigmas + loss = F.mse_loss( + pred_x0.float(), clean_latents.float() + ) + else: + target = noise - clean_latents + loss = F.mse_loss( + pred.float(), target.float() + ) + + if self._attn_kind == "vsa": + attn_metadata = training_batch.attn_metadata_vsa + else: + attn_metadata = training_batch.attn_metadata + + loss_map = {"total_loss": loss, "dfsft_loss": loss} + outputs: dict[str, Any] = { + "_fv_backward": ( + training_batch.timesteps, + attn_metadata, + ) + } + metrics: dict[str, LogScalar] = {} + return loss_map, outputs, metrics + + # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + ctx = outputs.get("_fv_backward") + if ctx is None: + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + self.student.backward( + loss_map["total_loss"], + ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + del iteration + return [self._student_optimizer] + + # TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + del iteration + return [self._student_lr_scheduler] + + def _parse_chunk_size(self, raw: Any) -> int: + if raw in (None, ""): + return 3 + if isinstance(raw, bool): + raise ValueError( + "method_config.chunk_size must be an int, " + "got bool" + ) + if isinstance(raw, float) and not raw.is_integer(): + raise ValueError( + "method_config.chunk_size must be an int, " + "got float" + ) + if isinstance(raw, str) and not raw.strip(): + raise ValueError( + "method_config.chunk_size must be an int, " + "got empty string" + ) + try: + value = int(raw) + except (TypeError, ValueError) as e: + raise ValueError( + "method_config.chunk_size must be an int, " + f"got {type(raw).__name__}" + ) from e + if value <= 0: + raise ValueError( + "method_config.chunk_size must be > 0" + ) + return value + + def _parse_ratio( + self, + raw: Any, + *, + where: str, + default: float, + ) -> float: + if raw in (None, ""): + return float(default) + if isinstance(raw, bool): + raise ValueError( + f"{where} must be a number/string, got bool" + ) + if isinstance(raw, int | float): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError( + f"{where} must be a number/string, " + f"got {type(raw).__name__}" + ) + + def _parse_timestep_index_range( + self, + ) -> tuple[int, int]: + scheduler = self.student.noise_scheduler + if scheduler is None: + raise ValueError( + "DFSFT requires student.noise_scheduler" + ) + num_steps = int( + getattr( + scheduler, "config", scheduler + ).num_train_timesteps + ) + + min_ratio = self._parse_ratio( + self.method_config.get( + "min_timestep_ratio", None + ), + where="method.min_timestep_ratio", + default=0.0, + ) + max_ratio = self._parse_ratio( + self.method_config.get( + "max_timestep_ratio", None + ), + where="method.max_timestep_ratio", + default=1.0, + ) + + if not ( + 0.0 <= min_ratio <= 1.0 + and 0.0 <= max_ratio <= 1.0 + ): + raise ValueError( + "DFSFT timestep ratios must be in [0,1], " + f"got min={min_ratio}, max={max_ratio}" + ) + if max_ratio < min_ratio: + raise ValueError( + "method_config.max_timestep_ratio must be " + ">= min_timestep_ratio" + ) + + min_index = int(min_ratio * num_steps) + max_index = int(max_ratio * num_steps) + min_index = max(0, min(min_index, num_steps - 1)) + max_index = max(0, min(max_index, num_steps - 1)) + + if max_index <= min_index: + max_index = min(num_steps - 1, min_index + 1) + + return min_index, max_index + 1 + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + student_lr = float(tc.optimizer.learning_rate) + if student_lr <= 0.0: + raise ValueError( + "training.learning_rate must be > 0 " + "for dfsft" + ) + + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) + + def _sample_t_inhom_indices( + self, + *, + batch_size: int, + num_latents: int, + device: torch.device, + ) -> torch.Tensor: + chunk_size = self._chunk_size + num_chunks = ( + (num_latents + chunk_size - 1) // chunk_size + ) + low, high = self._timestep_index_range + chunk_indices = torch.randint( + low=low, + high=high, + size=(batch_size, num_chunks), + device=device, + dtype=torch.long, + ) + expanded = chunk_indices.repeat_interleave( + chunk_size, dim=1 + ) + return expanded[:, :num_latents] diff --git a/fastvideo/train/methods/fine_tuning/finetune.py b/fastvideo/train/methods/fine_tuning/finetune.py new file mode 100644 index 000000000..cf7dc3139 --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/finetune.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Supervised finetuning method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) + + +class FineTuneMethod(TrainingMethod): + """Supervised finetuning: only ``student`` participates.""" + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError( + "FineTuneMethod requires role 'student'" + ) + if not self.student._trainable: + raise ValueError( + "FineTuneMethod requires student to be " + "trainable" + ) + self._attn_kind: Literal["dense", "vsa"] = ( + self._parse_attn_kind( + self.method_config.get("attn_kind", None) + ) + ) + + # Initialize preprocessors on student. + self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict(self) -> dict[str, Any]: + return {"student": self._student_optimizer} + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return {"student": self._student_lr_scheduler} + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + del iteration + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source="data", + ) + + if training_batch.latents is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.latents" + ) + if training_batch.noisy_model_input is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.noisy_model_input" + ) + if training_batch.noise is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.noise" + ) + if training_batch.sigmas is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.sigmas" + ) + if training_batch.timesteps is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.timesteps" + ) + + clean_latents = training_batch.latents + noisy_latents = ( + training_batch.noisy_model_input.permute( + 0, 2, 1, 3, 4 + ) + ) + noise = training_batch.noise.permute( + 0, 2, 1, 3, 4 + ) + sigmas = training_batch.sigmas + timesteps = training_batch.timesteps + + pred = self.student.predict_noise( + noisy_latents, + timesteps, + training_batch, + conditional=True, + attn_kind=self._attn_kind, + ) + + if bool( + self.training_config.model.precondition_outputs + ): + pred_x0 = noisy_latents - pred * sigmas + loss = F.mse_loss( + pred_x0.float(), clean_latents.float() + ) + else: + target = noise - clean_latents + loss = F.mse_loss( + pred.float(), target.float() + ) + + if self._attn_kind == "vsa": + attn_metadata = training_batch.attn_metadata_vsa + else: + attn_metadata = training_batch.attn_metadata + + loss_map = { + "total_loss": loss, + "finetune_loss": loss, + } + outputs: dict[str, Any] = { + "_fv_backward": ( + training_batch.timesteps, + attn_metadata, + ) + } + metrics: dict[str, LogScalar] = {} + return loss_map, outputs, metrics + + # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + ctx = outputs.get("_fv_backward") + if ctx is None: + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + self.student.backward( + loss_map["total_loss"], + ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + del iteration + return [self._student_optimizer] + + # TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + del iteration + return [self._student_lr_scheduler] + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + + student_lr = float(tc.optimizer.learning_rate) + if student_lr <= 0.0: + raise ValueError( + "training.learning_rate must be > 0 " + "for finetune" + ) + + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) diff --git a/fastvideo/train/methods/knowledge_distillation/__init__.py b/fastvideo/train/methods/knowledge_distillation/__init__.py new file mode 100644 index 000000000..324710b84 --- /dev/null +++ b/fastvideo/train/methods/knowledge_distillation/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/fastvideo/train/models/__init__.py b/fastvideo/train/models/__init__.py new file mode 100644 index 000000000..56b47b1af --- /dev/null +++ b/fastvideo/train/models/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Model build plugins for Phase 2/2.9 distillation. + +These are "model plugins" selected by ``recipe.family`` / ``roles..family``. +""" diff --git a/fastvideo/train/models/base.py b/fastvideo/train/models/base.py new file mode 100644 index 000000000..d74406278 --- /dev/null +++ b/fastvideo/train/models/base.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Literal, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + from fastvideo.pipelines import TrainingBatch + + +class ModelBase(ABC): + """Per-role model instance. + + Every role (student, teacher, critic, …) gets its own ``ModelBase`` + instance. Each instance owns its own ``transformer`` and + ``noise_scheduler``. Heavyweight resources (VAE, dataloader, RNG + seeds) are loaded lazily via :meth:`init_preprocessors`, which the + method calls **only on the student**. + """ + + transformer: torch.nn.Module + noise_scheduler: Any + _trainable: bool + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_preprocessors(self, training_config: TrainingConfig) -> None: + """Load VAE, build dataloader, seed RNGs. + + Called only on the student by the method's ``__init__``. + Default is a no-op so teacher/critic instances skip this. + """ + + def on_train_start(self) -> None: + """Called once before the training loop begins.""" + + def get_rng_generators(self) -> dict[str, torch.Generator]: + """Return RNG generators for checkpoint resume.""" + return {} + + # ------------------------------------------------------------------ + # Timestep helpers + # ------------------------------------------------------------------ + + @property + def num_train_timesteps(self) -> int: + """Return the scheduler's training timestep horizon.""" + return int(self.noise_scheduler.num_train_timesteps) + + def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor: + """Apply model/pipeline timestep shifting and clamp.""" + return timestep + + # ------------------------------------------------------------------ + # Runtime primitives + # ------------------------------------------------------------------ + + @abstractmethod + def prepare_batch( + self, + raw_batch: dict[str, Any], + *, + current_vsa_sparsity: float = 0.0, + latents_source: Literal["data", "zeros"] = "data", + ) -> TrainingBatch: + """Convert a dataloader batch into forward primitives.""" + + @abstractmethod + def add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + """Apply forward-process noise at *timestep*.""" + + @abstractmethod + def predict_noise( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + """Predict noise/flow for the given noisy latents.""" + + @abstractmethod + def predict_x0( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + """Predict x0 for the given noisy latents.""" + + @abstractmethod + def backward( + self, + loss: torch.Tensor, + ctx: Any, + *, + grad_accum_rounds: int, + ) -> None: + """Backward that may restore forward-context.""" + + +class CausalModelBase(ModelBase): + """Extension for causal / streaming model plugins. + + Cache state is internal to the model instance and keyed by + *cache_tag* (no role handle needed). + """ + + @abstractmethod + def clear_caches(self, *, cache_tag: str = "pos") -> None: + """Clear internal caches before starting a new rollout.""" + + @abstractmethod + def predict_noise_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + """Streaming predict-noise that may update internal caches.""" + + @abstractmethod + def predict_x0_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + """Streaming predict-x0 that may update internal caches.""" diff --git a/fastvideo/train/models/wan/__init__.py b/fastvideo/train/models/wan/__init__.py new file mode 100644 index 000000000..d79381246 --- /dev/null +++ b/fastvideo/train/models/wan/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Wan model plugin package.""" + +from fastvideo.train.models.wan.wan import ( + WanModel as WanModel, ) diff --git a/fastvideo/train/models/wan/wan.py b/fastvideo/train/models/wan/wan.py new file mode 100644 index 000000000..3267e18fb --- /dev/null +++ b/fastvideo/train/models/wan/wan.py @@ -0,0 +1,716 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Wan model plugin (per-role instance).""" + +from __future__ import annotations + +import copy +import gc +from typing import Any, Literal, TYPE_CHECKING + +import torch + +import fastvideo.envs as envs +from fastvideo.configs.sample import SamplingParam +from fastvideo.distributed import ( + get_local_torch_device, + get_sp_group, + get_world_group, +) +from fastvideo.forward_context import set_forward_context +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, ) +from fastvideo.models.utils import pred_noise_to_pred_video +from fastvideo.pipelines import TrainingBatch +from fastvideo.pipelines.basic.wan.wan_pipeline import ( + WanPipeline, ) +from fastvideo.pipelines.pipeline_batch_info import ( + ForwardBatch, ) +from fastvideo.training.activation_checkpoint import ( + apply_activation_checkpointing, ) +from fastvideo.training.training_utils import ( + compute_density_for_timestep_sampling, + get_sigmas, + normalize_dit_input, + shift_timestep, +) +from fastvideo.utils import ( + is_vmoba_available, + is_vsa_available, + set_random_seed, +) + +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.module_state import ( + apply_trainable, ) +from fastvideo.train.utils.moduleloader import ( + load_module_from_path, ) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +try: + from fastvideo.attention.backends.video_sparse_attn import ( + VideoSparseAttentionMetadataBuilder, ) + from fastvideo.attention.backends.vmoba import ( + VideoMobaAttentionMetadataBuilder, ) +except Exception: + VideoSparseAttentionMetadataBuilder = None # type: ignore[assignment] + VideoMobaAttentionMetadataBuilder = None # type: ignore[assignment] + + +class WanModel(ModelBase): + """Wan per-role model: owns transformer + noise_scheduler.""" + + def __init__( + self, + *, + init_from: str, + training_config: TrainingConfig, + trainable: bool = True, + disable_custom_init_weights: bool = False, + flow_shift: float = 3.0, + enable_gradient_checkpointing_type: str + | None = None, + ) -> None: + self._init_from = str(init_from) + self._trainable = bool(trainable) + + transformer = load_module_from_path( + model_path=self._init_from, + module_type="transformer", + training_config=training_config, + disable_custom_init_weights=(disable_custom_init_weights), + override_transformer_cls_name=("WanTransformer3DModel"), + ) + transformer = apply_trainable(transformer, trainable=self._trainable) + # Fall back to training_config.model if not set on the + # model YAML section directly. + ckpt_type = ( + enable_gradient_checkpointing_type + or getattr( + getattr(training_config, "model", None), + "enable_gradient_checkpointing_type", + None, + ) + ) + if self._trainable and ckpt_type: + transformer = apply_activation_checkpointing( + transformer, + checkpointing_type=ckpt_type, + ) + self.transformer = transformer + + self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift))) + + # Filled by init_preprocessors (student only). + self.vae: Any = None + self.training_config: TrainingConfig = training_config + self.dataloader: Any = None + self.validator: Any = None + self.start_step: int = 0 + + self.world_group: Any = None + self.sp_group: Any = None + self.device: Any = get_local_torch_device() + + self.noise_random_generator: (torch.Generator | None) = None + self.noise_gen_cuda: torch.Generator | None = None + + self.negative_prompt_embeds: (torch.Tensor | None) = None + self.negative_prompt_attention_mask: (torch.Tensor | None) = None + + # Timestep mechanics. + self.timestep_shift: float = float(flow_shift) + self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps) + self.min_timestep: int = 0 + self.max_timestep: int = self.num_train_timestep + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_preprocessors(self, training_config: TrainingConfig) -> None: + self.vae = load_module_from_path( + model_path=str(training_config.model_path), + module_type="vae", + training_config=training_config, + ) + + self.world_group = get_world_group() + self.sp_group = get_sp_group() + + self._init_timestep_mechanics() + + from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_t2v, ) + from fastvideo.train.utils.dataloader import ( + build_parquet_t2v_train_dataloader, ) + + text_len = ( + training_config.pipeline_config.text_encoder_configs[ # type: ignore[union-attr] + 0].arch_config.text_len) + self.dataloader = build_parquet_t2v_train_dataloader( + training_config.data, + text_len=int(text_len), + parquet_schema=pyarrow_schema_t2v, + ) + self.start_step = 0 + + @property + def num_train_timesteps(self) -> int: + return int(self.num_train_timestep) + + def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor: + timestep = shift_timestep( + timestep, + self.timestep_shift, + self.num_train_timestep, + ) + return timestep.clamp(self.min_timestep, self.max_timestep) + + def on_train_start(self) -> None: + assert self.training_config is not None + seed = self.training_config.data.seed + if seed is None: + raise ValueError("training.data.seed must be set " + "for training") + + global_rank = int(getattr(self.world_group, "rank", 0)) + sp_world_size = int(self.training_config.distributed.sp_size or 1) + if sp_world_size > 1: + sp_group_seed = int(seed) + (global_rank // sp_world_size) + set_random_seed(sp_group_seed) + else: + set_random_seed(int(seed) + global_rank) + + self.noise_random_generator = torch.Generator(device="cpu").manual_seed(int(seed)) + self.noise_gen_cuda = torch.Generator(device=self.device).manual_seed(int(seed)) + + self.ensure_negative_conditioning() + + def get_rng_generators(self, ) -> dict[str, torch.Generator]: + generators: dict[str, torch.Generator] = {} + if self.noise_random_generator is not None: + generators["noise_cpu"] = (self.noise_random_generator) + if self.noise_gen_cuda is not None: + generators["noise_cuda"] = self.noise_gen_cuda + return generators + + # ------------------------------------------------------------------ + # Runtime primitives + # ------------------------------------------------------------------ + + def prepare_batch( + self, + raw_batch: dict[str, Any], + *, + current_vsa_sparsity: float = 0.0, + latents_source: Literal["data", "zeros"] = "data", + ) -> TrainingBatch: + self.ensure_negative_conditioning() + assert self.training_config is not None + tc = self.training_config + + dtype = self._get_training_dtype() + device = self.device + + training_batch = TrainingBatch(current_vsa_sparsity=current_vsa_sparsity) + encoder_hidden_states = raw_batch["text_embedding"] + encoder_attention_mask = raw_batch["text_attention_mask"] + infos = raw_batch.get("info_list") + + if latents_source == "zeros": + batch_size = encoder_hidden_states.shape[0] + vae_config = ( + tc.pipeline_config.vae_config.arch_config # type: ignore[union-attr] + ) + num_channels = vae_config.z_dim + spatial_compression_ratio = (vae_config.spatial_compression_ratio) + latent_height = (tc.data.num_height // spatial_compression_ratio) + latent_width = (tc.data.num_width // spatial_compression_ratio) + latents = torch.zeros( + batch_size, + num_channels, + tc.data.num_latent_t, + latent_height, + latent_width, + device=device, + dtype=dtype, + ) + elif latents_source == "data": + if "vae_latent" not in raw_batch: + raise ValueError("vae_latent not found in batch " + "and latents_source='data'") + latents = raw_batch["vae_latent"] + latents = latents[:, :, :tc.data.num_latent_t] + latents = latents.to(device, dtype=dtype) + else: + raise ValueError(f"Unknown latents_source: " + f"{latents_source!r}") + + training_batch.latents = latents + training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype)) + training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype)) + training_batch.infos = infos + + training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae) + training_batch = self._prepare_dit_inputs(training_batch) + training_batch = self._build_attention_metadata(training_batch) + + training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata) + if training_batch.attn_metadata is not None: + training_batch.attn_metadata.VSA_sparsity = 0.0 # type: ignore[attr-defined] + + return training_batch + + def add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + b, t = clean_latents.shape[:2] + noisy = self.noise_scheduler.add_noise( + clean_latents.flatten(0, 1), + noise.flatten(0, 1), + timestep, + ).unflatten(0, (b, t)) + return noisy + + def predict_x0( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = noisy_latents.dtype + if conditional: + text_dict = batch.conditional_dict + if text_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + else: + text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond) + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict)) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noisy_latents.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = noisy_latents.dtype + if conditional: + text_dict = batch.conditional_dict + if text_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + else: + text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond) + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict)) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + return pred_noise + + def backward( + self, + loss: torch.Tensor, + ctx: Any, + *, + grad_accum_rounds: int, + ) -> None: + timesteps, attn_metadata = ctx + with set_forward_context( + current_timestep=timesteps, + attn_metadata=attn_metadata, + ): + (loss / max(1, int(grad_accum_rounds))).backward() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_training_dtype(self) -> torch.dtype: + return torch.bfloat16 + + def _init_timestep_mechanics(self) -> None: + assert self.training_config is not None + tc = self.training_config + self.timestep_shift = float(tc.pipeline_config.flow_shift # type: ignore[union-attr] + ) + self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps) + # min/max timestep ratios now come from method_config; + # default to full range. + self.min_timestep = 0 + self.max_timestep = self.num_train_timestep + + def ensure_negative_conditioning(self) -> None: + if self.negative_prompt_embeds is not None: + return + + assert self.training_config is not None + tc = self.training_config + world_group = self.world_group + device = self.device + dtype = self._get_training_dtype() + + from fastvideo.train.utils.moduleloader import ( + make_inference_args, ) + + neg_embeds: torch.Tensor | None = None + neg_mask: torch.Tensor | None = None + + if world_group.rank_in_group == 0: + sampling_param = SamplingParam.from_pretrained(tc.model_path) + negative_prompt = sampling_param.negative_prompt + + inference_args = make_inference_args(tc, model_path=tc.model_path) + + prompt_pipeline = WanPipeline.from_pretrained( + tc.model_path, + args=inference_args, + inference_mode=True, + loaded_modules={"transformer": self.transformer}, + tp_size=tc.distributed.tp_size, + sp_size=tc.distributed.sp_size, + num_gpus=tc.distributed.num_gpus, + pin_cpu_memory=(tc.distributed.pin_cpu_memory), + dit_cpu_offload=True, + ) + + batch_negative = ForwardBatch( + data_type="video", + prompt=negative_prompt, + prompt_embeds=[], + prompt_attention_mask=[], + ) + result_batch = prompt_pipeline.prompt_encoding_stage( # type: ignore[attr-defined] + batch_negative, + inference_args, + ) + + neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype) + neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype)) + + del prompt_pipeline + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + meta = torch.zeros((2, ), device=device, dtype=torch.int64) + if world_group.rank_in_group == 0: + assert neg_embeds is not None + assert neg_mask is not None + meta[0] = neg_embeds.ndim + meta[1] = neg_mask.ndim + world_group.broadcast(meta, src=0) + embed_ndim, mask_ndim = ( + int(meta[0].item()), + int(meta[1].item()), + ) + + max_ndim = 8 + embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64) + mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64) + if world_group.rank_in_group == 0: + assert neg_embeds is not None + assert neg_mask is not None + embed_shape[:embed_ndim] = torch.tensor( + list(neg_embeds.shape), + device=device, + dtype=torch.int64, + ) + mask_shape[:mask_ndim] = torch.tensor( + list(neg_mask.shape), + device=device, + dtype=torch.int64, + ) + world_group.broadcast(embed_shape, src=0) + world_group.broadcast(mask_shape, src=0) + + embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist()) + mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist()) + + if world_group.rank_in_group != 0: + neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype) + neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype) + assert neg_embeds is not None + assert neg_mask is not None + + world_group.broadcast(neg_embeds, src=0) + world_group.broadcast(neg_mask, src=0) + + self.negative_prompt_embeds = neg_embeds + self.negative_prompt_attention_mask = neg_mask + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: + if self.noise_random_generator is None: + raise RuntimeError("on_train_start() must be called before " + "prepare_batch()") + assert self.training_config is not None + tc = self.training_config + + u = compute_density_for_timestep_sampling( + weighting_scheme=tc.model.weighting_scheme, + batch_size=batch_size, + generator=self.noise_random_generator, + logit_mean=tc.model.logit_mean, + logit_std=tc.model.logit_std, + mode_scale=tc.model.mode_scale, + ) + indices = (u * self.noise_scheduler.config.num_train_timesteps).long() + return self.noise_scheduler.timesteps[indices].to(device=device) + + def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + latents_shape = training_batch.raw_latent_shape + patch_size = ( + tc.pipeline_config.dit_config.patch_size # type: ignore[union-attr] + ) + current_vsa_sparsity = (training_batch.current_vsa_sparsity) + assert latents_shape is not None + assert training_batch.timesteps is not None + + if (envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"): + if (not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VIDEO_SPARSE_ATTN, but " + "fastvideo_kernel is not correctly " + "installed or detected.") + training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build( # type: ignore[misc] + raw_latent_shape=latents_shape[2:5], + current_timestep=(training_batch.timesteps), + patch_size=patch_size, + VSA_sparsity=current_vsa_sparsity, + device=self.device, + ) + elif (envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN"): + if (not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VMOBA_ATTN, but fastvideo_kernel " + "(or flash_attn>=2.7.4) is not " + "correctly installed.") + moba_params = tc.model.moba_config.copy() + moba_params.update({ + "current_timestep": (training_batch.timesteps), + "raw_latent_shape": (training_batch.raw_latent_shape[2:5]), + "patch_size": patch_size, + "device": self.device, + }) + training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(** + moba_params) # type: ignore[misc] + else: + training_batch.attn_metadata = None + + return training_batch + + def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + latents = training_batch.latents + assert isinstance(latents, torch.Tensor) + batch_size = latents.shape[0] + + if self.noise_gen_cuda is None: + raise RuntimeError("on_train_start() must be called before " + "prepare_batch()") + + noise = torch.randn( + latents.shape, + generator=self.noise_gen_cuda, + device=latents.device, + dtype=latents.dtype, + ) + timesteps = self._sample_timesteps(batch_size, latents.device) + if int(tc.distributed.sp_size or 1) > 1: + self.sp_group.broadcast(timesteps, src=0) + + sigmas = get_sigmas( + self.noise_scheduler, + latents.device, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + ) + noisy_model_input = ((1.0 - sigmas) * latents + sigmas * noise) + + training_batch.noisy_model_input = (noisy_model_input) + training_batch.timesteps = timesteps + training_batch.sigmas = sigmas + training_batch.noise = noise + training_batch.raw_latent_shape = latents.shape + + training_batch.conditional_dict = { + "encoder_hidden_states": (training_batch.encoder_hidden_states), + "encoder_attention_mask": (training_batch.encoder_attention_mask), + } + + if (self.negative_prompt_embeds is not None and self.negative_prompt_attention_mask is not None): + neg_embeds = self.negative_prompt_embeds + neg_mask = (self.negative_prompt_attention_mask) + if (neg_embeds.shape[0] == 1 and batch_size > 1): + neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous() + if (neg_mask.shape[0] == 1 and batch_size > 1): + neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous() + training_batch.unconditional_dict = { + "encoder_hidden_states": neg_embeds, + "encoder_attention_mask": neg_mask, + } + + training_batch.latents = (training_batch.latents.permute(0, 2, 1, 3, 4)) + return training_batch + + def _build_distill_input_kwargs( + self, + noise_input: torch.Tensor, + timestep: torch.Tensor, + text_dict: dict[str, torch.Tensor] | None, + ) -> dict[str, Any]: + if text_dict is None: + raise ValueError("text_dict cannot be None for " + "Wan distillation") + return { + "hidden_states": noise_input.permute(0, 2, 1, 3, 4), + "encoder_hidden_states": text_dict["encoder_hidden_states"], + "encoder_attention_mask": text_dict["encoder_attention_mask"], + "timestep": timestep, + "return_dict": False, + } + + def _get_transformer(self, timestep: torch.Tensor) -> torch.nn.Module: + return self.transformer + + def _get_uncond_text_dict( + self, + batch: TrainingBatch, + *, + cfg_uncond: dict[str, Any] | None, + ) -> dict[str, torch.Tensor]: + if cfg_uncond is None: + text_dict = getattr(batch, "unconditional_dict", None) + if text_dict is None: + raise RuntimeError("Missing unconditional_dict; " + "ensure_negative_conditioning() " + "may have failed") + return text_dict + + on_missing_raw = cfg_uncond.get("on_missing", "error") + if not isinstance(on_missing_raw, str): + raise ValueError("method_config.cfg_uncond.on_missing " + "must be a string, got " + f"{type(on_missing_raw).__name__}") + on_missing = on_missing_raw.strip().lower() + if on_missing not in {"error", "ignore"}: + raise ValueError("method_config.cfg_uncond.on_missing " + "must be one of {error, ignore}, got " + f"{on_missing_raw!r}") + + for channel, policy_raw in cfg_uncond.items(): + if channel in {"on_missing", "text"}: + continue + if policy_raw is None: + continue + if not isinstance(policy_raw, str): + raise ValueError("method_config.cfg_uncond values " + "must be strings, got " + f"{channel}=" + f"{type(policy_raw).__name__}") + policy = policy_raw.strip().lower() + if policy == "keep": + continue + if on_missing == "ignore": + continue + raise ValueError("WanModel does not support " + "cfg_uncond channel " + f"{channel!r} (policy={policy!r}). " + "Set cfg_uncond.on_missing=ignore or " + "remove the channel.") + + text_policy_raw = cfg_uncond.get("text", None) + if text_policy_raw is None: + text_policy = "negative_prompt" + elif not isinstance(text_policy_raw, str): + raise ValueError("method_config.cfg_uncond.text must be " + "a string, got " + f"{type(text_policy_raw).__name__}") + else: + text_policy = (text_policy_raw.strip().lower()) + + if text_policy in {"negative_prompt"}: + text_dict = getattr(batch, "unconditional_dict", None) + if text_dict is None: + raise RuntimeError("Missing unconditional_dict; " + "ensure_negative_conditioning() " + "may have failed") + return text_dict + if text_policy == "keep": + if batch.conditional_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + return batch.conditional_dict + if text_policy == "zero": + if batch.conditional_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + cond = batch.conditional_dict + enc = cond["encoder_hidden_states"] + mask = cond["encoder_attention_mask"] + if not torch.is_tensor(enc) or not torch.is_tensor(mask): + raise TypeError("conditional_dict must contain " + "tensor text inputs") + return { + "encoder_hidden_states": (torch.zeros_like(enc)), + "encoder_attention_mask": (torch.zeros_like(mask)), + } + if text_policy == "drop": + raise ValueError("cfg_uncond.text=drop is not supported " + "for Wan. Use " + "{negative_prompt, keep, zero}.") + raise ValueError("cfg_uncond.text must be one of " + "{negative_prompt, keep, zero, drop}, got " + f"{text_policy_raw!r}") diff --git a/fastvideo/train/trainer.py b/fastvideo/train/trainer.py new file mode 100644 index 000000000..b44fe57a5 --- /dev/null +++ b/fastvideo/train/trainer.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, TYPE_CHECKING + +import torch +from tqdm.auto import tqdm + +from fastvideo.distributed import get_sp_group, get_world_group +from fastvideo.train.callbacks.callback import CallbackDict +from fastvideo.train.methods.base import TrainingMethod +from fastvideo.train.utils.tracking import build_tracker + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + + +def _coerce_log_scalar(value: Any, *, where: str) -> float: + if isinstance(value, torch.Tensor): + if value.numel() != 1: + raise ValueError(f"Expected scalar tensor at {where}, " + f"got shape={tuple(value.shape)}") + return float(value.detach().item()) + if isinstance(value, float | int): + return float(value) + raise TypeError(f"Expected a scalar (float/int/Tensor) at " + f"{where}, got {type(value).__name__}") + + +@dataclass(slots=True) +class TrainLoopState: + step: int + accum_iter: int + current_vsa_sparsity: float + + +class Trainer: + + def __init__( + self, + training_config: TrainingConfig, + *, + config: dict[str, Any] | None = None, + callback_configs: dict[str, dict[str, Any]] + | None = None, + ) -> None: + self.training_config = training_config + self.world_group = get_world_group() + self.sp_group = get_sp_group() + self.global_rank = self.world_group.rank + self.local_rank = self.world_group.local_rank + self.tracker = build_tracker( + training_config.tracker, + training_config.checkpoint, + config=config, + ) + self.callbacks = CallbackDict( + callback_configs or {}, + training_config, + ) + + def _iter_dataloader(self, dataloader: Any) -> Iterator[dict[str, Any]]: + data_iter = iter(dataloader) + while True: + batch = next(data_iter, None) + if batch is None: + data_iter = iter(dataloader) + batch = next(data_iter) + yield batch + + def _get_current_vsa_sparsity(self, step: int) -> float: + tc = self.training_config + vsa_sparsity = tc.vsa.sparsity + vsa_decay_rate = tc.vsa.decay_rate + vsa_decay_interval_steps = (tc.vsa.decay_interval_steps) + if vsa_decay_interval_steps > 1: + current_decay_times = min( + step // vsa_decay_interval_steps, + int(vsa_sparsity // vsa_decay_rate), + ) + return current_decay_times * vsa_decay_rate + return vsa_sparsity + + def run( + self, + method: TrainingMethod, + *, + dataloader: Any, + max_steps: int, + start_step: int = 0, + checkpoint_manager: Any | None = None, + ) -> None: + tc = self.training_config + grad_accum = max( + 1, + int(tc.loop.gradient_accumulation_steps or 1), + ) + + method.set_tracker(self.tracker) + method.on_train_start() + + resume_from_checkpoint = (tc.checkpoint.resume_from_checkpoint or "") + if checkpoint_manager is not None: + resumed_step = (checkpoint_manager.maybe_resume(resume_from_checkpoint=(resume_from_checkpoint))) + if resumed_step is not None: + start_step = int(resumed_step) + + self.callbacks.on_train_start( + method, iteration=start_step, + ) + self.callbacks.on_validation_begin( + method, iteration=start_step, + ) + method.optimizers_zero_grad(start_step) + + data_stream = self._iter_dataloader(dataloader) + progress = tqdm( + range(start_step + 1, max_steps + 1), + initial=start_step, + desc="Steps", + disable=self.local_rank > 0, + ) + for step in progress: + t0 = time.perf_counter() + current_vsa_sparsity = (self._get_current_vsa_sparsity(step)) + + loss_sums: dict[str, float] = {} + metric_sums: dict[str, float] = {} + for accum_iter in range(grad_accum): + batch = next(data_stream) + loss_map, outputs, step_metrics = method.single_train_step( + batch, + step, + current_vsa_sparsity=(current_vsa_sparsity), + ) + + method.backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum, + ) + + for k, v in loss_map.items(): + if isinstance(v, torch.Tensor): + loss_sums[k] = loss_sums.get(k, 0.0) + float(v.detach().item()) + for k, v in step_metrics.items(): + if k in loss_sums: + raise ValueError(f"Metric key {k!r} collides " + "with loss key. Use a " + "different name (e.g. prefix " + "with 'train/').") + metric_sums[k] = metric_sums.get(k, 0.0) + _coerce_log_scalar( + v, + where=("method.single_train_step()" + f".metrics[{k!r}]"), + ) + + self.callbacks.on_before_optimizer_step( + method, iteration=step, + ) + method.optimizers_schedulers_step(step) + method.optimizers_zero_grad(step) + + metrics = {k: v / grad_accum for k, v in loss_sums.items()} + metrics.update({k: v / grad_accum for k, v in metric_sums.items()}) + metrics["step_time_sec"] = (time.perf_counter() - t0) + metrics["vsa_sparsity"] = float(current_vsa_sparsity) + if self.global_rank == 0 and metrics: + self.tracker.log(metrics, step) + + self.callbacks.on_training_step_end( + method, metrics, iteration=step, + ) + + if checkpoint_manager is not None: + checkpoint_manager.maybe_save(step) + + self.callbacks.on_validation_begin( + method, iteration=step, + ) + self.callbacks.on_validation_end( + method, iteration=step, + ) + + self.callbacks.on_train_end( + method, iteration=max_steps, + ) + + if checkpoint_manager is not None: + checkpoint_manager.save_final(max_steps) + + self.tracker.finish() diff --git a/fastvideo/train/utils/__init__.py b/fastvideo/train/utils/__init__.py new file mode 100644 index 000000000..d7eba4033 --- /dev/null +++ b/fastvideo/train/utils/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Distillation utilities shared across families/methods/entrypoints.""" diff --git a/fastvideo/train/utils/builder.py b/fastvideo/train/utils/builder.py new file mode 100644 index 000000000..d6d8d9976 --- /dev/null +++ b/fastvideo/train/utils/builder.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Assembly: build method + dataloader from a ``_target_``-based config.""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from fastvideo.train.utils.instantiate import ( + instantiate, + resolve_target, +) +from fastvideo.train.utils.config import RunConfig + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + from fastvideo.train.methods.base import TrainingMethod + + +def build_from_config(cfg: RunConfig, ) -> tuple[TrainingConfig, TrainingMethod, Any, int]: + """Build method + dataloader from a v3 run config. + + 1. Instantiate each model in ``cfg.models`` via ``_target_``. + 2. Resolve the method class from ``cfg.method["_target_"]`` + and construct it with ``(cfg=cfg, role_models=...)``. + 3. Return ``(training_args, method, dataloader, start_step)``. + """ + from fastvideo.train.models.base import ModelBase + + # --- 1. Build role model instances --- + role_models: dict[str, ModelBase] = {} + for role, model_cfg in cfg.models.items(): + model = instantiate( + model_cfg, training_config=cfg.training) + if not isinstance(model, ModelBase): + raise TypeError(f"models.{role}._target_ must resolve to a " + f"ModelBase subclass, got {type(model).__name__}") + role_models[role] = model + + # --- 2. Build method --- + method_cfg = dict(cfg.method) + method_target = str(method_cfg.pop("_target_")) + method_cls = resolve_target(method_target) + + # The student model provides the dataloader. + student = role_models.get("student") + + method = method_cls( + cfg=cfg, + role_models=role_models, + ) + + # --- 3. Gather dataloader and start_step --- + dataloader = (getattr(student, "dataloader", None) if student is not None else None) + start_step = int(getattr(student, "start_step", 0) if student is not None else 0) + + return cfg.training, method, dataloader, start_step diff --git a/fastvideo/train/utils/checkpoint.py b/fastvideo/train/utils/checkpoint.py new file mode 100644 index 000000000..62166dde7 --- /dev/null +++ b/fastvideo/train/utils/checkpoint.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import re +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + +_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint-(\d+)$") + + +def _is_stateful(obj: Any) -> bool: + return callable(getattr(obj, "state_dict", None)) and callable(getattr(obj, "load_state_dict", None)) + + +def _rank() -> int: + if dist.is_available() and dist.is_initialized(): + return int(dist.get_rank()) + return 0 + + +def _barrier() -> None: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def _parse_step_from_dir(checkpoint_dir: Path) -> int: + match = _CHECKPOINT_DIR_RE.match(checkpoint_dir.name) + if not match: + raise ValueError(f"Invalid checkpoint directory name {checkpoint_dir.name!r}; " + "expected 'checkpoint-'") + return int(match.group(1)) + + +def _find_latest_checkpoint(output_dir: Path) -> Path | None: + if not output_dir.exists(): + return None + + candidates: list[tuple[int, Path]] = [] + for child in output_dir.iterdir(): + if not child.is_dir(): + continue + if not _CHECKPOINT_DIR_RE.match(child.name): + continue + if not (child / "dcp").is_dir(): + continue + try: + step = _parse_step_from_dir(child) + except Exception: + continue + candidates.append((step, child)) + + if not candidates: + return None + candidates.sort(key=lambda x: x[0]) + return candidates[-1][1] + + +def _resolve_resume_checkpoint(resume_from_checkpoint: str, *, output_dir: str) -> Path: + """Resolve a user-provided resume path to a concrete checkpoint dir. + + Accepted values: + - /path/to/output_dir/checkpoint- + - /path/to/output_dir/checkpoint-/dcp + - /path/to/output_dir (auto-pick latest checkpoint-*/dcp) + """ + + raw = os.path.expanduser(str(resume_from_checkpoint)) + path = Path(raw).resolve() + if not path.exists(): + raise FileNotFoundError(f"resume_from_checkpoint not found: {path}") + + if path.is_dir() and path.name == "dcp": + path = path.parent + + if path.is_dir() and _CHECKPOINT_DIR_RE.match(path.name): + if not (path / "dcp").is_dir(): + raise FileNotFoundError(f"Missing dcp dir under checkpoint: {path / 'dcp'}") + return path + + # Treat as output_dir -> pick latest. + latest = _find_latest_checkpoint(path) + if latest is not None: + return latest + + # Give a clearer error message. + out = Path(os.path.expanduser(str(output_dir))).resolve() + raise ValueError("Could not resolve resume checkpoint. Expected a checkpoint directory " + f"named 'checkpoint-' (with 'dcp/' inside), or an output_dir " + f"containing such checkpoints. Got: {path} (output_dir={out}).") + + +class _RoleModuleContainer(torch.nn.Module): + """Ephemeral container to expose multiple role modules as a single + ``nn.Module``. + + Used by ``OptimizerWrapper`` which expects a single root module + covering all parameters owned by the optimizer. + """ + + def __init__(self, modules: dict[str, torch.nn.Module]) -> None: + super().__init__() + for name, module in modules.items(): + self.add_module(name, module) + + +class _CallbackStateWrapper: + """Wraps a CallbackDict for DCP save/load.""" + + def __init__(self, callbacks: Any) -> None: + self._callbacks = callbacks + + def state_dict(self) -> dict[str, Any]: + return self._callbacks.state_dict() + + def load_state_dict( + self, state_dict: dict[str, Any], + ) -> None: + self._callbacks.load_state_dict(state_dict) + + +@dataclass(slots=True) +class CheckpointConfig: + save_steps: int + keep_last: int + + +class CheckpointManager: + """Role-based checkpoint manager for training runtime. + + - Checkpoint policy lives in YAML (via TrainingArgs fields). + - Resume path is typically provided via CLI (``--resume-from-checkpoint``). + """ + + def __init__( + self, + *, + method: Any, + dataloader: Any, + output_dir: str, + config: CheckpointConfig, + callbacks: Any | None = None, + raw_config: dict[str, Any] | None = None, + ) -> None: + self.method = method + self.dataloader = dataloader + self.output_dir = str(output_dir) + self.config = config + self._callbacks = callbacks + self._raw_config = raw_config + self._last_saved_step: int | None = None + + def _build_states(self) -> dict[str, Any]: + states: dict[str, Any] = self.method.checkpoint_state() + + # Dataloader (optional but recommended for exact resume). + if _is_stateful(self.dataloader): + states["dataloader"] = self.dataloader + + # Callback state (e.g. EMA shadow weights, validation RNG). + if self._callbacks is not None and _is_stateful(self._callbacks): + states["callbacks"] = _CallbackStateWrapper( + self._callbacks, + ) + + return states + + def _checkpoint_dir(self, step: int) -> Path: + return Path(self.output_dir) / f"checkpoint-{step}" + + def _dcp_dir(self, step: int) -> Path: + return self._checkpoint_dir(step) / "dcp" + + def maybe_save(self, step: int) -> None: + save_steps = int(self.config.save_steps or 0) + if save_steps <= 0: + return + if step % save_steps != 0: + return + if self._last_saved_step == step: + return + self.save(step) + + def save_final(self, step: int) -> None: + save_steps = int(self.config.save_steps or 0) + if save_steps <= 0: + return + self.save(step) + + def save(self, step: int) -> None: + checkpoint_dir = self._checkpoint_dir(step) + dcp_dir = self._dcp_dir(step) + os.makedirs(dcp_dir, exist_ok=True) + + states = self._build_states() + if _rank() == 0: + logger.info( + "Saving checkpoint to %s", checkpoint_dir, + ) + self._write_metadata(checkpoint_dir, step) + dcp.save(states, checkpoint_id=str(dcp_dir)) + _barrier() + self._last_saved_step = step + + self._cleanup_old_checkpoints() + + def _write_metadata( + self, checkpoint_dir: Path, step: int, + ) -> None: + metadata: dict[str, Any] = {"step": step} + if self._raw_config is not None: + metadata["config"] = self._raw_config + meta_path = checkpoint_dir / "metadata.json" + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + @staticmethod + def load_metadata( + checkpoint_dir: str | Path, + ) -> dict[str, Any]: + """Read ``metadata.json`` from a checkpoint dir.""" + meta_path = Path(checkpoint_dir) / "metadata.json" + if not meta_path.is_file(): + raise FileNotFoundError( + f"No metadata.json in {checkpoint_dir}" + ) + with open(meta_path, encoding="utf-8") as f: + return json.load(f) # type: ignore[no-any-return] + + def maybe_resume(self, *, resume_from_checkpoint: str | None) -> int | None: + if not resume_from_checkpoint: + return None + + resolved = _resolve_resume_checkpoint( + resume_from_checkpoint, + output_dir=self.output_dir, + ) + step = _parse_step_from_dir(resolved) + + states = self._build_states() + logger.info("Loading Phase 2 checkpoint from %s", resolved) + dcp.load(states, checkpoint_id=str(resolved / "dcp")) + _barrier() + logger.info("Checkpoint loaded; resuming from step=%s", step) + return step + + def _cleanup_old_checkpoints(self) -> None: + keep_last = int(self.config.keep_last or 0) + if keep_last <= 0: + return + + if _rank() != 0: + _barrier() + return + + output_dir = Path(self.output_dir) + candidates: list[tuple[int, Path]] = [] + for child in output_dir.iterdir(): + if not child.is_dir(): + continue + if not _CHECKPOINT_DIR_RE.match(child.name): + continue + try: + step = _parse_step_from_dir(child) + except Exception: + continue + candidates.append((step, child)) + + candidates.sort(key=lambda x: x[0]) + to_delete = candidates[:-keep_last] if len(candidates) > keep_last else [] + for step, path in to_delete: + logger.info("Removing old checkpoint (keep_last=%s): %s", keep_last, path) + shutil.rmtree(path, ignore_errors=True) + + _barrier() diff --git a/fastvideo/train/utils/config.py b/fastvideo/train/utils/config.py new file mode 100644 index 000000000..704362b98 --- /dev/null +++ b/fastvideo/train/utils/config.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Training run config (``_target_`` based YAML).""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml + +from fastvideo.train.utils.training_config import ( + CheckpointConfig, + DataConfig, + TrainingConfig, + DistributedConfig, + ModelTrainingConfig, + OptimizerConfig, + TrackerConfig, + TrainingLoopConfig, + VSAConfig, +) + + +@dataclass(slots=True) +class RunConfig: + """Parsed run config loaded from YAML.""" + + models: dict[str, dict[str, Any]] + method: dict[str, Any] + training: TrainingConfig + callbacks: dict[str, dict[str, Any]] + raw: dict[str, Any] + + def resolved_config(self) -> dict[str, Any]: + """Return a fully-resolved config dict with defaults. + + Suitable for logging to W&B so that every parameter + (including defaults) is visible. + """ + import dataclasses + + def _safe_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return { + f.name: _safe_asdict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + if not callable(getattr(obj, f.name)) + } + if isinstance(obj, dict): + return {k: _safe_asdict(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(_safe_asdict(v) for v in obj) + return obj + + resolved: dict[str, Any] = {} + resolved["models"] = dict(self.models) + resolved["method"] = dict(self.method) + resolved["training"] = _safe_asdict(self.training) + resolved["callbacks"] = dict(self.callbacks) + return resolved + + +# ---- parsing helpers (kept for use by methods) ---- + + +def _resolve_existing_file(path: str) -> str: + if not path: + return path + expanded = os.path.expanduser(path) + resolved = Path(expanded).resolve() + if not resolved.exists(): + raise FileNotFoundError(f"Config file not found: {resolved}") + if not resolved.is_file(): + raise ValueError(f"Expected a file path, got: {resolved}") + return str(resolved) + + +def _require_mapping(raw: Any, *, where: str) -> dict[str, Any]: + if not isinstance(raw, dict): + raise ValueError(f"Expected mapping at {where}, " + f"got {type(raw).__name__}") + return raw + + +def _require_str(raw: Any, *, where: str) -> str: + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"Expected non-empty string at {where}") + return raw + + +def get_optional_int(mapping: dict[str, Any], key: str, *, where: str) -> int | None: + raw = mapping.get(key) + if raw is None: + return None + if isinstance(raw, bool): + raise ValueError(f"Expected int at {where}, got bool") + if isinstance(raw, int): + return int(raw) + if isinstance(raw, float) and raw.is_integer(): + return int(raw) + if isinstance(raw, str) and raw.strip(): + return int(raw) + raise ValueError(f"Expected int at {where}, " + f"got {type(raw).__name__}") + + +def get_optional_float(mapping: dict[str, Any], key: str, *, where: str) -> float | None: + raw = mapping.get(key) + if raw is None: + return None + if isinstance(raw, bool): + raise ValueError(f"Expected float at {where}, got bool") + if isinstance(raw, int | float): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError(f"Expected float at {where}, " + f"got {type(raw).__name__}") + + +def parse_betas(raw: Any, *, where: str) -> tuple[float, float]: + if raw is None: + raise ValueError(f"Missing betas for {where}") + if isinstance(raw, tuple | list) and len(raw) == 2: + return float(raw[0]), float(raw[1]) + if isinstance(raw, str): + parts = [p.strip() for p in raw.split(",") if p.strip()] + if len(parts) != 2: + raise ValueError(f"Expected betas as 'b1,b2' at {where}, " + f"got {raw!r}") + return float(parts[0]), float(parts[1]) + raise ValueError(f"Expected betas as 'b1,b2' at {where}, " + f"got {type(raw).__name__}") + + +# ---- config convenience helpers ---- + + +def require_positive_int( + mapping: dict[str, Any], + key: str, + *, + default: int | None = None, + where: str | None = None, +) -> int: + """Read an int that must be > 0.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_int(mapping, key, where=loc) + if val is None or val <= 0: + raise ValueError(f"{loc} must be a positive integer, got {raw!r}") + return val + + +def require_non_negative_int( + mapping: dict[str, Any], + key: str, + *, + default: int | None = None, + where: str | None = None, +) -> int: + """Read an int that must be >= 0.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_int(mapping, key, where=loc) + if val is None or val < 0: + raise ValueError(f"{loc} must be a non-negative integer, " + f"got {raw!r}") + return val + + +def require_non_negative_float( + mapping: dict[str, Any], + key: str, + *, + default: float | None = None, + where: str | None = None, +) -> float: + """Read a float that must be >= 0.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_float(mapping, key, where=loc) + if val is None or val < 0.0: + raise ValueError(f"{loc} must be a non-negative float, " + f"got {raw!r}") + return val + + +def require_choice( + mapping: dict[str, Any], + key: str, + choices: set[str] | frozenset[str], + *, + default: str | None = None, + where: str | None = None, +) -> str: + """Read a string that must be one of *choices*.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + if default not in choices: + raise ValueError(f"Default {default!r} not in {choices}") + return default + raise ValueError(f"Missing required key {loc!r}") + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"{loc} must be a non-empty string, " + f"got {type(raw).__name__}") + val = raw.strip().lower() + if val not in choices: + raise ValueError(f"{loc} must be one of {sorted(choices)}, " + f"got {raw!r}") + return val + + +def require_bool( + mapping: dict[str, Any], + key: str, + *, + default: bool | None = None, + where: str | None = None, +) -> bool: + """Read a bool value.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + if not isinstance(raw, bool): + raise ValueError(f"{loc} must be a bool, " + f"got {type(raw).__name__}") + return raw + + +def _parse_pipeline_config( + cfg: dict[str, Any], + *, + models: dict[str, dict[str, Any]], +) -> Any: + """Resolve PipelineConfig from the ``pipeline:`` YAML key.""" + from fastvideo.configs.pipelines.base import PipelineConfig + + pipeline_raw = cfg.get("pipeline") + if pipeline_raw is None: + return None + + # Derive model_path from models.student.init_from — + # needed by PipelineConfig.from_kwargs. + model_path: str | None = None + student_cfg = models.get("student") + if student_cfg is not None: + init_from = student_cfg.get("init_from") + if init_from is not None: + model_path = str(init_from) + + kwargs: dict[str, Any] = {"pipeline_config": pipeline_raw} + if model_path is not None: + kwargs["model_path"] = model_path + + if isinstance(pipeline_raw, str): + kwargs["pipeline_config"] = _resolve_existing_file( + pipeline_raw) + + return PipelineConfig.from_kwargs(kwargs) + + +def _build_training_config( + t: dict[str, Any], + *, + models: dict[str, dict[str, Any]], + pipeline_config: Any, +) -> TrainingConfig: + """Build TrainingConfig from nested training: YAML.""" + d = dict(t.get("distributed", {}) or {}) + da = dict(t.get("data", {}) or {}) + o = dict(t.get("optimizer", {}) or {}) + lo = dict(t.get("loop", {}) or {}) + ck = dict(t.get("checkpoint", {}) or {}) + tr = dict(t.get("tracker", {}) or {}) + vs = dict(t.get("vsa", {}) or {}) + m = dict(t.get("model", {}) or {}) + + num_gpus = int(d.get("num_gpus", 1) or 1) + + betas_raw = o.get("betas", "0.9,0.999") + betas = parse_betas(betas_raw, + where="training.optimizer.betas") + + model_path = str(t.get("model_path", "") or "") + if not model_path: + student_cfg = models.get("student") + if student_cfg is not None: + init_from = student_cfg.get("init_from") + if init_from is not None: + model_path = str(init_from) + + return TrainingConfig( + distributed=DistributedConfig( + num_gpus=num_gpus, + tp_size=int(d.get("tp_size", 1) or 1), + sp_size=int( + d.get("sp_size", num_gpus) or num_gpus), + hsdp_replicate_dim=int( + d.get("hsdp_replicate_dim", 1) or 1), + hsdp_shard_dim=int( + d.get("hsdp_shard_dim", num_gpus) + or num_gpus), + pin_cpu_memory=bool( + d.get("pin_cpu_memory", False)), + ), + data=DataConfig( + data_path=str(da.get("data_path", "") or ""), + train_batch_size=int( + da.get("train_batch_size", 1) or 1), + dataloader_num_workers=int( + da.get("dataloader_num_workers", 0) or 0), + training_cfg_rate=float( + da.get("training_cfg_rate", 0.0) or 0.0), + seed=int(da.get("seed", 0) or 0), + num_height=int( + da.get("num_height", 0) or 0), + num_width=int(da.get("num_width", 0) or 0), + num_latent_t=int( + da.get("num_latent_t", 0) or 0), + num_frames=int( + da.get("num_frames", 0) or 0), + ), + optimizer=OptimizerConfig( + learning_rate=float( + o.get("learning_rate", 0.0) or 0.0), + betas=betas, + weight_decay=float( + o.get("weight_decay", 0.0) or 0.0), + lr_scheduler=str( + o.get("lr_scheduler", "constant") + or "constant"), + lr_warmup_steps=int( + o.get("lr_warmup_steps", 0) or 0), + lr_num_cycles=int( + o.get("lr_num_cycles", 0) or 0), + lr_power=float( + o.get("lr_power", 0.0) or 0.0), + min_lr_ratio=float( + o.get("min_lr_ratio", 0.5) or 0.5), + ), + loop=TrainingLoopConfig( + max_train_steps=int( + lo.get("max_train_steps", 0) or 0), + gradient_accumulation_steps=int( + lo.get("gradient_accumulation_steps", 1) + or 1), + ), + checkpoint=CheckpointConfig( + output_dir=str( + ck.get("output_dir", "") or ""), + resume_from_checkpoint=str( + ck.get("resume_from_checkpoint", "") + or ""), + training_state_checkpointing_steps=int( + ck.get( + "training_state_checkpointing_steps", + 0) or 0), + checkpoints_total_limit=int( + ck.get("checkpoints_total_limit", 0) + or 0), + ), + tracker=TrackerConfig( + trackers=list( + tr.get("trackers", []) or []), + project_name=str( + tr.get("project_name", "fastvideo") + or "fastvideo"), + run_name=str(tr.get("run_name", "") or ""), + ), + vsa=VSAConfig( + sparsity=float( + vs.get("sparsity", 0.0) or 0.0), + decay_rate=float( + vs.get("decay_rate", 0.0) or 0.0), + decay_interval_steps=int( + vs.get("decay_interval_steps", 0) or 0), + ), + model=ModelTrainingConfig( + weighting_scheme=str( + m.get("weighting_scheme", "uniform") + or "uniform"), + logit_mean=float( + m.get("logit_mean", 0.0) or 0.0), + logit_std=float( + m.get("logit_std", 1.0) or 1.0), + mode_scale=float( + m.get("mode_scale", 1.0) or 1.0), + precondition_outputs=bool( + m.get("precondition_outputs", False)), + moba_config=dict( + m.get("moba_config", {}) or {}), + enable_gradient_checkpointing_type=( + m.get( + "enable_gradient_checkpointing_type" + )), + ), + pipeline_config=pipeline_config, + model_path=model_path, + dit_precision=str( + t.get("dit_precision", "fp32") or "fp32"), + ) + + +def load_run_config(path: str) -> RunConfig: + """Load a run config from YAML. + + Expected top-level keys: ``models``, ``method``, + ``training`` (nested), and optionally ``callbacks`` + and ``pipeline``. + """ + path = _resolve_existing_file(path) + with open(path, encoding="utf-8") as f: + raw = yaml.safe_load(f) + cfg = _require_mapping(raw, where=path) + + # --- models --- + models_raw = _require_mapping( + cfg.get("models"), where="models") + models: dict[str, dict[str, Any]] = {} + for role, model_cfg_raw in models_raw.items(): + role_str = _require_str( + role, where="models.") + model_cfg = _require_mapping( + model_cfg_raw, where=f"models.{role_str}") + if "_target_" not in model_cfg: + raise ValueError( + f"models.{role_str} must have a " + "'_target_' key") + models[role_str] = dict(model_cfg) + + # --- method --- + method_raw = _require_mapping( + cfg.get("method"), where="method") + if "_target_" not in method_raw: + raise ValueError( + "method must have a '_target_' key") + method = dict(method_raw) + + # --- callbacks --- + callbacks_raw = cfg.get("callbacks", None) + if callbacks_raw is None: + callbacks: dict[str, dict[str, Any]] = {} + else: + callbacks = _require_mapping( + callbacks_raw, where="callbacks") + + # --- pipeline config --- + pipeline_config = _parse_pipeline_config( + cfg, models=models) + + # --- training config --- + training_raw = _require_mapping( + cfg.get("training"), where="training") + t = dict(training_raw) + training = _build_training_config( + t, models=models, + pipeline_config=pipeline_config) + + return RunConfig( + models=models, + method=method, + training=training, + callbacks=callbacks, + raw=cfg, + ) diff --git a/fastvideo/train/utils/dataloader.py b/fastvideo/train/utils/dataloader.py new file mode 100644 index 000000000..a1b22d3ba --- /dev/null +++ b/fastvideo/train/utils/dataloader.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + DataConfig, ) + + +def build_parquet_t2v_train_dataloader( + data_config: DataConfig, + *, + text_len: int, + parquet_schema: Any, +) -> Any: + """Build a parquet dataloader for T2V-style datasets.""" + + from fastvideo.dataset import ( + build_parquet_map_style_dataloader, ) + + _dataset, dataloader = (build_parquet_map_style_dataloader( + data_config.data_path, + data_config.train_batch_size, + num_data_workers=(data_config.dataloader_num_workers), + parquet_schema=parquet_schema, + cfg_rate=data_config.training_cfg_rate, + drop_last=True, + text_padding_length=int(text_len), + seed=int(data_config.seed or 0), + )) + return dataloader diff --git a/fastvideo/train/utils/instantiate.py b/fastvideo/train/utils/instantiate.py new file mode 100644 index 000000000..1d1e3f306 --- /dev/null +++ b/fastvideo/train/utils/instantiate.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +"""``_target_``-based instantiation utilities. + +These helpers resolve a dotted Python path to a class and instantiate it, +filtering constructor kwargs through ``inspect.signature`` so that only +recognized parameters are forwarded. Unrecognized keys emit a warning +rather than raising — this keeps YAML configs forward-compatible when +a class drops a parameter in a later version. +""" + +from __future__ import annotations + +import importlib +import inspect +import warnings +from typing import Any + + +def resolve_target(target: str) -> type: + """Import and return the class (or callable) at *target*. + + *target* must be a fully-qualified dotted path, e.g. + ``"fastvideo.train.models.wan.WanModel"``. + """ + if not isinstance(target, str) or not target.strip(): + raise ValueError(f"_target_ must be a non-empty dotted path string, " + f"got {target!r}") + target = target.strip() + parts = target.rsplit(".", 1) + if len(parts) != 2: + raise ValueError(f"_target_ must contain at least one dot " + f"(module.ClassName), got {target!r}") + module_path, attr_name = parts + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as exc: + raise ImportError(f"Cannot import module {module_path!r} " + f"(from _target_={target!r})") from exc + try: + cls = getattr(module, attr_name) + except AttributeError as exc: + raise ImportError(f"Module {module_path!r} has no attribute " + f"{attr_name!r} (from _target_={target!r})") from exc + return cls + + +def instantiate(cfg: dict[str, Any], **extra: Any) -> Any: + """Instantiate the class specified by ``cfg["_target_"]``. + + All remaining keys in *cfg* (minus ``_target_``) plus any *extra* + keyword arguments are forwarded to the constructor. Keys that do + not match an ``__init__`` parameter are silently warned about and + dropped, so callers can safely pass a superset. + """ + if not isinstance(cfg, dict): + raise TypeError(f"instantiate() expects a dict with '_target_', " + f"got {type(cfg).__name__}") + target_str = cfg.get("_target_") + if target_str is None: + raise KeyError("Config dict is missing '_target_' key") + + cls = resolve_target(str(target_str)) + kwargs: dict[str, Any] = {k: v for k, v in cfg.items() if k != "_target_"} + kwargs.update(extra) + + sig = inspect.signature(cls.__init__) + params = sig.parameters + has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + if not has_var_keyword: + valid_names = { + name + for name, p in params.items() if p.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + valid_names.discard("self") + unrecognized = set(kwargs) - valid_names + if unrecognized: + warnings.warn( + f"instantiate({target_str}): dropping unrecognized " + f"kwargs {sorted(unrecognized)}", + stacklevel=2, + ) + for key in unrecognized: + del kwargs[key] + + return cls(**kwargs) diff --git a/fastvideo/train/utils/module_state.py b/fastvideo/train/utils/module_state.py new file mode 100644 index 000000000..6d28a005f --- /dev/null +++ b/fastvideo/train/utils/module_state.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import torch + + +def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module: + """Apply train/eval mode + requires_grad based on a role's trainable flag.""" + + module.requires_grad_(bool(trainable)) + if trainable: + module.train() + else: + module.eval() + return module diff --git a/fastvideo/train/utils/moduleloader.py b/fastvideo/train/utils/moduleloader.py new file mode 100644 index 000000000..7d18db197 --- /dev/null +++ b/fastvideo/train/utils/moduleloader.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from typing import Any, TYPE_CHECKING + +import torch + +from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.fastvideo_args import ExecutionMode, TrainingArgs +from fastvideo.models.loader.component_loader import ( + PipelineComponentLoader, ) +from fastvideo.utils import ( + maybe_download_model, + verify_model_config_and_directory, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +# ------------------------------------------------------------------ +# TrainingArgs builders (only place that creates FastVideoArgs) +# ------------------------------------------------------------------ + + +def _make_training_args( + tc: TrainingConfig, + *, + model_path: str, +) -> TrainingArgs: + """Build a TrainingArgs for PipelineComponentLoader.""" + pipeline_config = tc.pipeline_config or PipelineConfig() + # Propagate dit_precision from TrainingConfig to PipelineConfig + # so that TransformerLoader.load() picks up the correct + # default_dtype (e.g. fp32 master weights for training). + if tc.dit_precision and tc.dit_precision != pipeline_config.dit_precision: + pipeline_config.dit_precision = tc.dit_precision + return TrainingArgs( + model_path=model_path, + mode=ExecutionMode.DISTILLATION, + inference_mode=False, + pipeline_config=pipeline_config, + num_gpus=tc.distributed.num_gpus, + tp_size=tc.distributed.tp_size, + sp_size=tc.distributed.sp_size, + hsdp_replicate_dim=tc.distributed.hsdp_replicate_dim, + hsdp_shard_dim=tc.distributed.hsdp_shard_dim, + pin_cpu_memory=tc.distributed.pin_cpu_memory, + dit_cpu_offload=False, + dit_layerwise_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=False, + image_encoder_cpu_offload=False, + use_fsdp_inference=False, + enable_torch_compile=False, + ) + + +def make_inference_args( + tc: TrainingConfig, + *, + model_path: str, +) -> TrainingArgs: + """Build a TrainingArgs for inference (validation / pipelines).""" + args = _make_training_args(tc, model_path=model_path) + args.inference_mode = True + args.mode = ExecutionMode.INFERENCE + args.dit_cpu_offload = True + args.VSA_sparsity = tc.vsa.sparsity + return args + + +# ------------------------------------------------------------------ +# Module loading +# ------------------------------------------------------------------ + + +def load_module_from_path( + *, + model_path: str, + module_type: str, + training_config: TrainingConfig, + disable_custom_init_weights: bool = False, + override_transformer_cls_name: str | None = None, +) -> torch.nn.Module: + """Load a single pipeline component module. + + Accepts a ``TrainingConfig`` and internally builds the + ``TrainingArgs`` needed by ``PipelineComponentLoader``. + """ + fastvideo_args: Any = _make_training_args( + training_config, model_path=model_path) + + local_model_path = maybe_download_model(model_path) + config = verify_model_config_and_directory(local_model_path) + + if module_type not in config: + raise ValueError(f"Module {module_type!r} not found in " + f"config at {local_model_path}") + + module_info = config[module_type] + if module_info is None: + raise ValueError(f"Module {module_type!r} has null value in " + f"config at {local_model_path}") + + transformers_or_diffusers, _architecture = module_info + component_path = os.path.join(local_model_path, module_type) + + old_override: str | None = None + if override_transformer_cls_name is not None: + old_override = getattr( + fastvideo_args, + "override_transformer_cls_name", + None, + ) + fastvideo_args.override_transformer_cls_name = str(override_transformer_cls_name) + + if disable_custom_init_weights: + fastvideo_args._loading_teacher_critic_model = True + try: + module = PipelineComponentLoader.load_module( + module_name=module_type, + component_model_path=component_path, + transformers_or_diffusers=(transformers_or_diffusers), + fastvideo_args=fastvideo_args, + ) + finally: + if disable_custom_init_weights and hasattr(fastvideo_args, "_loading_teacher_critic_model"): + del fastvideo_args._loading_teacher_critic_model + if override_transformer_cls_name is not None: + if old_override is None: + if hasattr( + fastvideo_args, + "override_transformer_cls_name", + ): + fastvideo_args.override_transformer_cls_name = (None) + else: + fastvideo_args.override_transformer_cls_name = (old_override) + + if not isinstance(module, torch.nn.Module): + raise TypeError(f"Loaded {module_type!r} is not a " + f"torch.nn.Module: {type(module)}") + return module diff --git a/fastvideo/train/utils/optimizer.py b/fastvideo/train/utils/optimizer.py new file mode 100644 index 000000000..43a79d98d --- /dev/null +++ b/fastvideo/train/utils/optimizer.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from fastvideo.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, + get_scheduler, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + OptimizerConfig, + TrainingLoopConfig, + ) + + +def build_optimizer_and_scheduler( + *, + params: list[torch.nn.Parameter], + optimizer_config: OptimizerConfig, + loop_config: TrainingLoopConfig, + learning_rate: float, + betas: tuple[float, float], + scheduler_name: str, +) -> tuple[torch.optim.Optimizer, object]: + """Build an AdamW optimizer and LR scheduler. + + Returns ``(optimizer, lr_scheduler)`` so the caller can store them + as method-level attributes. + """ + if not params: + raise ValueError("No trainable parameters passed to " + "build_optimizer_and_scheduler") + + optimizer = torch.optim.AdamW( + params, + lr=float(learning_rate), + betas=betas, + weight_decay=float(optimizer_config.weight_decay), + eps=1e-8, + ) + + scheduler = get_scheduler( + str(scheduler_name), + optimizer=optimizer, + num_warmup_steps=int(optimizer_config.lr_warmup_steps), + num_training_steps=int(loop_config.max_train_steps), + num_cycles=int(optimizer_config.lr_num_cycles), + power=float(optimizer_config.lr_power), + min_lr_ratio=float(optimizer_config.min_lr_ratio), + last_epoch=-1, + ) + + return optimizer, scheduler + + +def clip_grad_norm_if_needed( + module: torch.nn.Module, + max_grad_norm: float, +) -> float: + if max_grad_norm <= 0.0: + return 0.0 + grad_norm = (clip_grad_norm_while_handling_failing_dtensor_cases( + [p for p in module.parameters()], + max_grad_norm, + foreach=None, + )) + return (float(grad_norm.item()) if grad_norm is not None else 0.0) diff --git a/fastvideo/train/utils/tracking.py b/fastvideo/train/utils/tracking.py new file mode 100644 index 000000000..7ad28a2e4 --- /dev/null +++ b/fastvideo/train/utils/tracking.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from typing import Any, TYPE_CHECKING + +from fastvideo.distributed import get_world_group +from fastvideo.training.trackers import ( + initialize_trackers, + Trackers, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + CheckpointConfig, + TrackerConfig, + ) + + +def build_tracker( + tracker_config: TrackerConfig, + checkpoint_config: CheckpointConfig, + *, + config: dict[str, Any] | None, +) -> Any: + """Build a tracker instance for a distillation run.""" + + world_group = get_world_group() + + trackers = list(tracker_config.trackers) + if not trackers and str(tracker_config.project_name): + trackers.append(Trackers.WANDB.value) + if world_group.rank != 0: + trackers = [] + + tracker_log_dir = (checkpoint_config.output_dir or os.getcwd()) + if trackers: + tracker_log_dir = os.path.join(tracker_log_dir, "tracker") + + tracker_config_dict = config if trackers else None + tracker_run_name = tracker_config.run_name or None + project = (tracker_config.project_name or "fastvideo") + + return initialize_trackers( + trackers, + experiment_name=project, + config=tracker_config_dict, + log_dir=tracker_log_dir, + run_name=tracker_run_name, + ) diff --git a/fastvideo/train/utils/training_config.py b/fastvideo/train/utils/training_config.py new file mode 100644 index 000000000..0167751db --- /dev/null +++ b/fastvideo/train/utils/training_config.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Typed training config — replaces TrainingArgs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.configs.pipelines.base import PipelineConfig + + +@dataclass(slots=True) +class DistributedConfig: + num_gpus: int = 1 + tp_size: int = 1 + sp_size: int = 1 + hsdp_replicate_dim: int = 1 + hsdp_shard_dim: int = -1 + pin_cpu_memory: bool = False + + +@dataclass(slots=True) +class DataConfig: + data_path: str = "" + train_batch_size: int = 1 + dataloader_num_workers: int = 0 + training_cfg_rate: float = 0.0 + seed: int = 0 + num_height: int = 0 + num_width: int = 0 + num_latent_t: int = 0 + num_frames: int = 0 + + +@dataclass(slots=True) +class OptimizerConfig: + learning_rate: float = 0.0 + betas: tuple[float, float] = (0.9, 0.999) + weight_decay: float = 0.0 + lr_scheduler: str = "constant" + lr_warmup_steps: int = 0 + lr_num_cycles: int = 0 + lr_power: float = 0.0 + min_lr_ratio: float = 0.5 + + +@dataclass(slots=True) +class TrainingLoopConfig: + max_train_steps: int = 0 + gradient_accumulation_steps: int = 1 + + +@dataclass(slots=True) +class CheckpointConfig: + output_dir: str = "" + resume_from_checkpoint: str = "" + training_state_checkpointing_steps: int = 0 + checkpoints_total_limit: int = 0 + + +@dataclass(slots=True) +class TrackerConfig: + trackers: list[str] = field(default_factory=list) + project_name: str = "fastvideo" + run_name: str = "" + + +@dataclass(slots=True) +class VSAConfig: + sparsity: float = 0.0 + decay_rate: float = 0.0 + decay_interval_steps: int = 0 + + +@dataclass(slots=True) +class ModelTrainingConfig: + weighting_scheme: str = "uniform" + logit_mean: float = 0.0 + logit_std: float = 1.0 + mode_scale: float = 1.0 + precondition_outputs: bool = False + moba_config: dict = field(default_factory=dict) + enable_gradient_checkpointing_type: str | None = None + + +@dataclass(slots=True) +class TrainingConfig: + distributed: DistributedConfig = field(default_factory=DistributedConfig) + data: DataConfig = field(default_factory=DataConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + loop: TrainingLoopConfig = field(default_factory=TrainingLoopConfig) + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + tracker: TrackerConfig = field(default_factory=TrackerConfig) + vsa: VSAConfig = field(default_factory=VSAConfig) + model: ModelTrainingConfig = field(default_factory=ModelTrainingConfig) + pipeline_config: PipelineConfig | None = None + model_path: str = "" + dit_precision: str = "fp32" diff --git a/fastvideo/train/utils/validation.py b/fastvideo/train/utils/validation.py new file mode 100644 index 000000000..5d7722d97 --- /dev/null +++ b/fastvideo/train/utils/validation.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, Literal, cast + +from fastvideo.train.utils.config import get_optional_int + + +def is_validation_enabled(cfg: dict[str, Any]) -> bool: + if not cfg: + return False + enabled = cfg.get("enabled") + if enabled is None: + return True + if isinstance(enabled, bool): + return bool(enabled) + raise ValueError("training.validation.enabled must be a bool when set, got " + f"{type(enabled).__name__}") + + +def parse_validation_every_steps(cfg: dict[str, Any]) -> int: + raw = cfg.get("every_steps") + if raw is None: + raise ValueError("training.validation.every_steps must be set when validation is enabled") + if isinstance(raw, bool): + raise ValueError("training.validation.every_steps must be an int, got bool") + if isinstance(raw, int): + return int(raw) + if isinstance(raw, float) and raw.is_integer(): + return int(raw) + if isinstance(raw, str) and raw.strip(): + return int(raw) + raise ValueError("training.validation.every_steps must be an int, got " + f"{type(raw).__name__}") + + +def parse_validation_dataset_file(cfg: dict[str, Any]) -> str: + raw = cfg.get("dataset_file") + if not isinstance(raw, str) or not raw.strip(): + raise ValueError("training.validation.dataset_file must be set when validation is enabled") + return raw.strip() + + +def parse_validation_sampling_steps(cfg: dict[str, Any]) -> list[int]: + raw = cfg.get("sampling_steps") + steps: list[int] = [] + if raw is None or raw == "": + raise ValueError("training.validation.sampling_steps must be set for validation") + if isinstance(raw, bool): + raise ValueError("validation sampling_steps must be an int/list/str, got bool") + if isinstance(raw, int) or (isinstance(raw, float) and raw.is_integer()): + steps = [int(raw)] + elif isinstance(raw, str): + steps = [int(s) for s in raw.split(",") if str(s).strip()] + elif isinstance(raw, list): + steps = [int(s) for s in raw] + else: + raise ValueError("validation sampling_steps must be an int/list/str, got " + f"{type(raw).__name__}") + return [s for s in steps if int(s) > 0] + + +def parse_validation_guidance_scale(cfg: dict[str, Any]) -> float | None: + raw = cfg.get("guidance_scale") + if raw in (None, ""): + return None + if isinstance(raw, bool): + raise ValueError("validation guidance_scale must be a number/string, got bool") + if isinstance(raw, (int, float)): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError("validation guidance_scale must be a number/string, got " + f"{type(raw).__name__}") + + +def parse_validation_sampler_kind( + cfg: dict[str, Any], + *, + default: Literal["ode", "sde"], +) -> Literal["ode", "sde"]: + raw = cfg.get("sampler_kind", default) + if raw is None: + raw = default + if not isinstance(raw, str): + raise ValueError("training.validation.sampler_kind must be a string when set, got " + f"{type(raw).__name__}") + kind = raw.strip().lower() + if kind not in {"ode", "sde"}: + raise ValueError("training.validation.sampler_kind must be one of {ode, sde}, got " + f"{raw!r}") + return cast(Literal["ode", "sde"], kind) + + +def parse_validation_rollout_mode( + cfg: dict[str, Any], + *, + default: Literal["parallel", "streaming"] = "parallel", +) -> Literal["parallel", "streaming"]: + raw = cfg.get("rollout_mode", default) + if raw is None: + raw = default + if not isinstance(raw, str): + raise ValueError("training.validation.rollout_mode must be a string when set, got " + f"{type(raw).__name__}") + mode = raw.strip().lower() + if mode not in {"parallel", "streaming"}: + raise ValueError("training.validation.rollout_mode must be one of {parallel, streaming}, " + f"got {raw!r}") + return cast(Literal["parallel", "streaming"], mode) + + +def parse_validation_ode_solver( + cfg: dict[str, Any], + *, + sampler_kind: Literal["ode", "sde"], +) -> str | None: + raw = cfg.get("ode_solver") + if raw in (None, ""): + return None + if sampler_kind != "ode": + raise ValueError("training.validation.ode_solver is only valid when " + "training.validation.sampler_kind='ode'") + if not isinstance(raw, str): + raise ValueError("training.validation.ode_solver must be a string when set, got " + f"{type(raw).__name__}") + solver = raw.strip().lower() + if solver in {"unipc", "unipc_multistep", "multistep"}: + return "unipc" + if solver in {"euler", "flowmatch", "flowmatch_euler"}: + return "euler" + raise ValueError("training.validation.ode_solver must be one of {unipc, euler}, got " + f"{raw!r}") + + +def parse_validation_output_dir(cfg: dict[str, Any]) -> str | None: + raw = cfg.get("output_dir") + if raw is None: + return None + if not isinstance(raw, str): + raise ValueError("training.validation.output_dir must be a string when set, got " + f"{type(raw).__name__}") + return raw + + +def parse_validation_num_frames(cfg: dict[str, Any]) -> int | None: + num_frames = get_optional_int(cfg, "num_frames", where="training.validation.num_frames") + if num_frames is not None and num_frames <= 0: + raise ValueError("training.validation.num_frames must be > 0 when set") + return num_frames diff --git a/fastvideo/training/checkpointing_utils.py b/fastvideo/training/checkpointing_utils.py index bc6aeed55..a3d4e84e2 100644 --- a/fastvideo/training/checkpointing_utils.py +++ b/fastvideo/training/checkpointing_utils.py @@ -21,10 +21,25 @@ def state_dict(self) -> dict[str, Any]: state_dict = get_model_state_dict( self.model) # type: ignore[no-any-return] # filter out non-trainable parameters - param_requires_grad = set([ - k for k, v in dict(self.model.named_parameters()).items() - if v.requires_grad - ]) + param_requires_grad: set[str] = set() + for name, param in self.model.named_parameters(): + if not bool(param.requires_grad): + continue + param_requires_grad.add(name) + + # Activation checkpointing wraps modules with an internal attribute + # `_checkpoint_wrapped_module`, which changes the *parameter name* + # observed via `named_parameters()`: + # + # named_parameters: blocks.0._checkpoint_wrapped_module.weight + # state_dict: blocks.0.weight + # + # `get_model_state_dict()` returns the unwrapped key names, so we + # also add the unwrapped form for filtering. + if "._checkpoint_wrapped_module." in name: + param_requires_grad.add( + name.replace("._checkpoint_wrapped_module.", ".") + ) state_dict = { k: v for k, v in state_dict.items() if k in param_requires_grad diff --git a/fastvideo/training/distillation_pipeline.py b/fastvideo/training/distillation_pipeline.py index 8abcfa955..e505186bd 100644 --- a/fastvideo/training/distillation_pipeline.py +++ b/fastvideo/training/distillation_pipeline.py @@ -112,7 +112,7 @@ def load_modules(self, if training_args.real_score_model_path: logger.info("Loading real score transformer from: %s", training_args.real_score_model_path) - training_args.override_transformer_cls_name = "WanTransformer3DModel" + # training_args.override_transformer_cls_name = "WanTransformer3DModel" # TODO(will): can use deepcopy instead if the model is the same self.real_score_transformer = self.load_module_from_path( training_args.real_score_model_path, "transformer", @@ -138,7 +138,7 @@ def load_modules(self, if training_args.fake_score_model_path: logger.info("Loading fake score transformer from: %s", training_args.fake_score_model_path) - training_args.override_transformer_cls_name = "WanTransformer3DModel" + # training_args.override_transformer_cls_name = "WanTransformer3DModel" self.fake_score_transformer = self.load_module_from_path( training_args.fake_score_model_path, "transformer", training_args) @@ -1208,7 +1208,8 @@ def _log_validation(self, transformer, training_args, global_step) -> None: training_args.validation_dataset_file, local_main_process_only=False) validation_dataset = ValidationDataset( - training_args.validation_dataset_file) + training_args.validation_dataset_file, + num_samples=training_args.validation_num_samples) validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0) @@ -1277,10 +1278,14 @@ def run_validation_with_ema( prompt_embeds=[], prompt_attention_mask=[], ) - result_batch = self.validation_pipeline.prompt_encoding_stage( # type: ignore - batch_negative, training_args) - self.negative_prompt_embeds, self.negative_prompt_attention_mask = result_batch.prompt_embeds[ - 0], result_batch.prompt_attention_mask[0] + if hasattr(self.validation_pipeline, "prompt_encoding_stage"): + result_batch = self.validation_pipeline.prompt_encoding_stage( # type: ignore + batch_negative, training_args) + self.negative_prompt_embeds, self.negative_prompt_attention_mask = result_batch.prompt_embeds[ + 0], result_batch.prompt_attention_mask[0] + else: + self.negative_prompt_embeds = None + self.negative_prompt_attention_mask = None logger.info( "rank: %s: rank_in_sp_group: %s, batch.prompt: %s", @@ -1308,6 +1313,7 @@ def run_validation_with_ema( x = torchvision.utils.make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8)) + frames = self._post_process_validation_frames(frames, batch) videos.append(frames) audios.append(output_batch.extra.get("audio")) audio_sample_rates.append( @@ -1441,6 +1447,8 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: fake_score_log_keys = ['generator_pred_video'] dmd_log_keys = ['faker_score_pred_video', 'real_score_pred_video'] + os.makedirs(training_args.output_dir, exist_ok=True) + for latent_key in fake_score_log_keys: latents = fake_score_latents_vis_dict[latent_key] latents = _prepare_vae_latents(latents) @@ -1460,8 +1468,20 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: video = video.cpu().float() video = video.permute(0, 2, 1, 3, 4) video = (video * 255).numpy().astype(np.uint8) + + video_filename = os.path.join(training_args.output_dir, + f"{latent_key}_step_{step}.mp4") + # [B, T, C, H, W] to [H, W, C] + video_frames = [ + np.transpose(video[0, t], (1, 2, 0)) + for t in range(video.shape[1]) + ] + video_frames = self._post_process_validation_frames( + video_frames, training_batch) + imageio.mimsave(video_filename, video_frames, fps=24) + video_artifact = self.tracker.video( - video, fps=24, format="mp4") # change to 16 for Wan2.1 + video, fps=24, format="mp4", caption=latent_key) # change to 16 for Wan2.1 if video_artifact is not None: tracker_loss_dict[latent_key] = video_artifact # Clean up references @@ -1489,8 +1509,20 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: video = video.cpu().float() video = video.permute(0, 2, 1, 3, 4) video = (video * 255).numpy().astype(np.uint8) + + video_filename = os.path.join(training_args.output_dir, + f"{latent_key}_step_{step}.mp4") + # [B, T, C, H, W] to [H, W, C] + video_frames = [ + np.transpose(video[0, t], (1, 2, 0)) + for t in range(video.shape[1]) + ] + video_frames = self._post_process_validation_frames( + video_frames, training_batch) + imageio.mimsave(video_filename, video_frames, fps=24) + video_artifact = self.tracker.video( - video, fps=24, format="mp4") # change to 16 for Wan2.1 + video, fps=24, format="mp4", caption=latent_key) # change to 16 for Wan2.1 if video_artifact is not None: tracker_loss_dict[latent_key] = video_artifact # Clean up references diff --git a/fastvideo/training/trackers.py b/fastvideo/training/trackers.py index 281d79325..02578c58a 100644 --- a/fastvideo/training/trackers.py +++ b/fastvideo/training/trackers.py @@ -11,6 +11,7 @@ import copy import os import pathlib +import shutil import time from dataclasses import dataclass from enum import Enum @@ -92,6 +93,21 @@ def log_artifacts(self, artifacts: dict[str, Any], step: int) -> None: def finish(self) -> None: # pragma: no cover - interface """Finalize the tracker session.""" + def log_file( + self, + path: str, + *, + name: str | None = None, + ) -> None: + """Log a local file to the tracker run (best-effort). + + Useful for attaching the exact YAML config used for a run. + + Trackers that do not support files should treat this as a no-op. + """ + + del path, name + def video( self, data: Any, @@ -134,12 +150,13 @@ def __init__( import wandb - pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) + self._log_dir = os.path.abspath(str(log_dir)) + pathlib.Path(self._log_dir).mkdir(parents=True, exist_ok=True) self._wandb = wandb self._run = wandb.init( project=experiment_name, - dir=log_dir, + dir=self._log_dir, config=config, name=run_name, ) @@ -154,6 +171,45 @@ def log(self, metrics: dict[str, Any], step: int) -> None: def finish(self) -> None: self._run.finish() + def log_file(self, path: str, *, name: str | None = None) -> None: + resolved = os.path.abspath(os.path.expanduser(str(path))) + if not os.path.isfile(resolved): + logger.warning("W&B log_file skipped; file not found: %s", resolved) + return + + target_name = str(name).strip() if name is not None and str(name).strip() else None + if target_name is None: + target_name = os.path.basename(resolved) + + # Prefer placing files directly under the W&B run directory to avoid + # symlink-based saves (which may not sync reliably on some clusters). + run_dir = getattr(self._run, "dir", None) + dest_root = self._log_dir if not isinstance(run_dir, str) else run_dir + dest_root = os.path.abspath(str(dest_root)) + + save_path = resolved + dest_path = os.path.join(dest_root, target_name) + try: + pathlib.Path(dest_root).mkdir(parents=True, exist_ok=True) + shutil.copyfile(resolved, dest_path) + except Exception: + logger.exception( + "Failed to copy file for W&B upload: %s -> %s", + resolved, + dest_path, + ) + else: + save_path = dest_path + + try: + self._run.save( + save_path, + base_path=os.path.dirname(save_path), + policy="now", + ) + except Exception: + logger.exception("Failed to upload file to W&B: %s", save_path) + def video( self, data: Any, @@ -201,6 +257,10 @@ def log_artifacts(self, artifacts: dict[str, Any], step: int) -> None: tracker.log_artifacts(artifacts, step) self._timed_metrics = {} + def log_file(self, path: str, *, name: str | None = None) -> None: + for tracker in self._trackers: + tracker.log_file(path, name=name) + def finish(self) -> None: for tracker in self._trackers: tracker.finish() diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index d9aa6e227..5a5da77b2 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -15,6 +15,7 @@ import torch import torch.distributed as dist import torchvision +from diffusers import FlowMatchEulerDiscreteScheduler from einops import rearrange from torch.utils.data import DataLoader from torchdata.stateful_dataloader import StatefulDataLoader @@ -47,9 +48,9 @@ initialize_trackers, Trackers) from fastvideo.training.training_utils import ( clip_grad_norm_while_handling_failing_dtensor_cases, - compute_density_for_timestep_sampling, count_trainable, get_scheduler, - get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint, - shard_latents_across_sp) + compute_density_for_timestep_sampling, count_trainable, + count_trainable_total, get_scheduler, get_sigmas, load_checkpoint, + normalize_dit_input, save_checkpoint, shard_latents_across_sp) from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict) @@ -117,7 +118,7 @@ def initialize_training_pipeline(self, training_args: TrainingArgs): # Set random seeds for deterministic training assert self.seed is not None, "seed must be set" - set_random_seed(self.seed + self.global_rank) + set_random_seed(self.seed) self.transformer.train() if training_args.enable_gradient_checkpointing_type is not None: self.transformer = apply_activation_checkpointing( @@ -193,7 +194,8 @@ def initialize_training_pipeline(self, training_args: TrainingArgs): text_padding_length=training_args.pipeline_config. text_encoder_configs[0].arch_config. text_len, # type: ignore[attr-defined] - seed=self.seed) + seed=self.seed, + reshuffle_each_epoch=training_args.reshuffle_each_epoch) self.noise_scheduler = noise_scheduler if self.training_args.boundary_ratio is not None: @@ -257,6 +259,8 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: if batch is None: self.current_epoch += 1 logger.info("Starting epoch %s", self.current_epoch) + # Reshuffle dataset order each epoch + self.train_dataset.sampler.set_epoch(self.current_epoch) # Reset iterator for next epoch self.train_loader_iter = iter(self.train_dataloader) # Get first batch of new epoch @@ -586,26 +590,40 @@ def train(self) -> None: local_main_process_only=False) if not self.post_init_called: self.post_init() - num_trainable_params = count_trainable(self.transformer) - logger.info("Starting training with %s B trainable parameters", - round(num_trainable_params / 1e9, 3)) + local_trainable = count_trainable(self.transformer) + total_trainable = count_trainable_total( + self.transformer, + get_local_torch_device(), + ) + logger.info( + "Starting training with %s B trainable parameters (total); " + "this rank shard: %s B", + round(total_trainable / 1e9, 3), + round(local_trainable / 1e9, 3), + ) if getattr(self, "transformer_2", None) is not None: - num_trainable_params = count_trainable(self.transformer_2) + local_trainable_2 = count_trainable(self.transformer_2) + total_trainable_2 = count_trainable_total( + self.transformer_2, + get_local_torch_device(), + ) logger.info( - "Transformer 2: Starting training with %s B trainable parameters", - round(num_trainable_params / 1e9, 3)) + "Transformer 2: %s B trainable parameters (total); " + "this rank shard: %s B", + round(total_trainable_2 / 1e9, 3), + round(local_trainable_2 / 1e9, 3), + ) # Set random seeds for deterministic training - self.noise_random_generator = torch.Generator( - device="cpu").manual_seed(self.seed + self.global_rank) + self.noise_random_generator = torch.Generator(device="cpu").manual_seed( + self.seed) self.noise_gen_cuda = torch.Generator( - device=current_platform.device_name).manual_seed(self.seed + - self.global_rank) + device=current_platform.device_name).manual_seed(self.seed) self.validation_random_generator = torch.Generator( - device="cpu").manual_seed(self.seed + self.global_rank) - logger.info("Initialized random seeds with seed: %s", - self.seed + self.global_rank) + device="cpu").manual_seed(self.seed) + logger.info("Initialized random seeds with seed: %s", self.seed) + self.noise_scheduler = FlowMatchEulerDiscreteScheduler() if self.training_args.resume_from_checkpoint: @@ -617,6 +635,9 @@ def train(self) -> None: self._log_training_info() + self._best_mf_angle_err_mean = float('inf') + self._last_mf_angle_err_mean = float('inf') + self._log_validation(self.transformer, self.training_args, self.init_steps) @@ -726,6 +747,43 @@ def train(self) -> None: "GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage, trainable_params) + best_start = self.training_args.best_checkpoint_start_step + if (best_start > 0 + and step >= best_start + and self._last_mf_angle_err_mean + < self._best_mf_angle_err_mean): + self._best_mf_angle_err_mean = ( + self._last_mf_angle_err_mean) + logger.info( + "New best mf_angle_err_mean=%.6f at step %d, " + "saving best checkpoint", + self._best_mf_angle_err_mean, step) + save_checkpoint( + self.transformer, self.global_rank, + self.training_args.output_dir, "best", + self.optimizer, self.train_dataloader, + self.lr_scheduler, + self.noise_random_generator) + if self.global_rank == 0: + import json + meta_path = os.path.join( + self.training_args.output_dir, + "checkpoint-best", + "best_metric.json") + with open(meta_path, "w") as f: + json.dump({ + "step": step, + "mf_angle_err_mean": + self._best_mf_angle_err_mean, + }, f, indent=2) + self.tracker.log({ + "best/mf_angle_err_mean": + self._best_mf_angle_err_mean, + "best/step": step, + }, step) + self.transformer.train() + self.sp_group.barrier() + self.tracker.finish() save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, @@ -808,11 +866,42 @@ def _prepare_validation_batch(self, sampling_param: SamplingParam, return batch + def _post_process_validation_frames( + self, frames: list[np.ndarray], + batch: ForwardBatch) -> list[np.ndarray]: + """Post-process validation frames before saving. + + Override this method in subclasses to add custom processing, + e.g., overlay action indicators for action-conditioned models. + + Args: + frames: List of numpy arrays (H, W, C) representing video frames + batch: The ForwardBatch containing input data (may include action data) + + Returns: + Processed frames (same format as input) + """ + return frames + + def _evaluate_validation_video( + self, + video_path: str, + caption: str, + action_path: str | None, + global_step: int, + num_inference_steps: int, + ) -> dict[str, float] | None: + """Optionally evaluate a saved validation video and return scalars.""" + del video_path, caption, action_path, global_step + del num_inference_steps + return None + @torch.no_grad() def _log_validation(self, transformer, training_args, global_step) -> None: """ Generate a validation video and log it to the configured tracker to check the quality during training. """ + self._last_mf_angle_err_mean = float('inf') training_args.inference_mode = True training_args.dit_cpu_offload = False if not training_args.log_validation: @@ -831,7 +920,8 @@ def _log_validation(self, transformer, training_args, global_step) -> None: training_args.validation_dataset_file, local_main_process_only=False) validation_dataset = ValidationDataset( - training_args.validation_dataset_file) + training_args.validation_dataset_file, + num_samples=training_args.validation_num_samples) validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0) @@ -855,15 +945,16 @@ def _log_validation(self, transformer, training_args, global_step) -> None: local_main_process_only=False) step_videos: list[np.ndarray] = [] step_captions: list[str] = [] - - step_audio: list[np.ndarray | None] = [] - step_sample_rates: list[int | None] = [] + step_action_paths: list[str | None] = [] for validation_batch in validation_dataloader: batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch, num_inference_steps) + action_path = validation_batch.get("action_path") + if not isinstance(action_path, str): + action_path = None logger.info("rank: %s: rank_in_sp_group: %s, batch.prompt: %s", self.global_rank, self.rank_in_sp_group, @@ -899,75 +990,126 @@ def _log_validation(self, transformer, training_args, global_step) -> None: x = torchvision.utils.make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8)) + + # Apply optional post-processing (e.g., overlay for action-conditioned models) + frames = self._post_process_validation_frames(frames, batch) step_videos.append(frames) + step_action_paths.append(action_path) # Only sp_group leaders (rank_in_sp_group == 0) need to send their # results to global rank 0 - if self.rank_in_sp_group == 0 and self.global_rank == 0: - # Global rank 0 collects results from all sp_group leaders - all_videos = step_videos # Start with own results - all_captions = step_captions - all_audios = step_audio - all_sample_rates = step_sample_rates - - # Receive from other sp_group leaders - for sp_group_idx in range(1, num_sp_groups): - src_rank = sp_group_idx * self.sp_world_size # Global rank of other sp_group leaders - recv_videos = world_group.recv_object(src=src_rank) - recv_captions = world_group.recv_object(src=src_rank) - recv_audios = world_group.recv_object(src=src_rank) - recv_sample_rates = world_group.recv_object(src=src_rank) - - all_videos.extend(recv_videos) - all_captions.extend(recv_captions) - all_audios.extend(recv_audios) - all_sample_rates.extend(recv_sample_rates) - - video_filenames = [] - for i, (video, caption, audio, sample_rate) in enumerate( - zip(all_videos, - all_captions, - all_audios, - all_sample_rates, + if self.rank_in_sp_group == 0: + local_video_filenames: list[str] = [] + local_validation_metrics: list[dict[str, float]] = [] + local_eval_error: str | None = None + + for i, (video, caption, action_path) in enumerate( + zip(step_videos, + step_captions, + step_action_paths, strict=True)): os.makedirs(training_args.output_dir, exist_ok=True) filename = os.path.join( training_args.output_dir, - f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4" + f"validation_step_{global_step}_inference_steps_{num_inference_steps}_rank_{self.global_rank}_video_{i}.mp4" ) imageio.mimsave(filename, video, fps=sampling_param.fps) - # Mux audio if available - if (audio is not None and sample_rate is not None - and not self._mux_audio( - filename, - audio, - sample_rate, - )): - logger.warning( - "Audio mux failed for validation video %s; saved video without audio.", - filename) - video_filenames.append(filename) - - artifacts = [] - for filename, caption in zip(video_filenames, - all_captions, - strict=True): - video_artifact = self.tracker.video(filename, - caption=caption) - if video_artifact is not None: - artifacts.append(video_artifact) - if artifacts: - logs = { - f"validation_videos_{num_inference_steps}_steps": - artifacts - } - self.tracker.log_artifacts(logs, global_step) - elif self.rank_in_sp_group == 0: - # Other sp_group leaders send their results to global rank 0 - world_group.send_object(step_videos, dst=0) - world_group.send_object(step_captions, dst=0) - world_group.send_object(step_audio, dst=0) - world_group.send_object(step_sample_rates, dst=0) + local_video_filenames.append(filename) + + try: + sample_metrics = self._evaluate_validation_video( + video_path=filename, + caption=caption, + action_path=action_path, + global_step=global_step, + num_inference_steps=num_inference_steps, + ) + if sample_metrics: + local_validation_metrics.append(sample_metrics) + except Exception as e: + local_eval_error = ( + f"rank {self.global_rank} validation eval failed " + f"for {filename}: {e}") + logger.exception(local_eval_error) + break + + if self.global_rank == 0: + all_video_filenames = local_video_filenames + all_captions = step_captions + validation_metrics = local_validation_metrics + eval_errors: list[str] = [] + if local_eval_error: + eval_errors.append(local_eval_error) + + # Receive from other sp_group leaders + for sp_group_idx in range(1, num_sp_groups): + src_rank = sp_group_idx * self.sp_world_size + recv_video_filenames = world_group.recv_object( + src=src_rank) + recv_captions = world_group.recv_object(src=src_rank) + recv_metrics = world_group.recv_object(src=src_rank) + recv_error = world_group.recv_object(src=src_rank) + + all_video_filenames.extend(recv_video_filenames) + all_captions.extend(recv_captions) + validation_metrics.extend(recv_metrics) + if recv_error: + eval_errors.append(str(recv_error)) + + if eval_errors: + raise RuntimeError( + "Validation flow evaluation failed:\n" + + "\n".join(eval_errors)) + + artifacts = [] + for filename, caption in zip(all_video_filenames, + all_captions, + strict=True): + video_artifact = self.tracker.video(filename, + caption=caption) + if video_artifact is not None: + artifacts.append(video_artifact) + if artifacts: + logs = { + f"validation_videos_{num_inference_steps}_steps": + artifacts + } + self.tracker.log_artifacts(logs, global_step) + + if validation_metrics: + metric_logs: dict[str, float] = {} + metric_keys = sorted( + {k for row in validation_metrics for k in row.keys()}) + for metric_key in metric_keys: + metric_vals = [ + row[metric_key] for row in validation_metrics + if metric_key in row + and np.isfinite(row[metric_key]) + ] + if not metric_vals: + continue + metric_logs[f"metrics/{metric_key}"] = float( + np.mean(metric_vals)) + self.tracker.log(metric_logs, global_step) + + mf_val = metric_logs.get( + "metrics/mf_angle_err_mean") + if mf_val is not None: + self._last_mf_angle_err_mean = mf_val + else: + # Other sp_group leaders send their local results to rank 0 + world_group.send_object(local_video_filenames, dst=0) + world_group.send_object(step_captions, dst=0) + world_group.send_object(local_validation_metrics, dst=0) + world_group.send_object(local_eval_error, dst=0) + if local_eval_error: + raise RuntimeError(local_eval_error) + + # Broadcast the latest mf_angle_err_mean from rank 0 to all ranks + _mf_tensor = torch.tensor( + [self._last_mf_angle_err_mean], device=self.device) + dist.broadcast(_mf_tensor, src=0) + self._last_mf_angle_err_mean = _mf_tensor.item() # Re-enable gradients for training training_args.inference_mode = False diff --git a/fastvideo/training/training_utils.py b/fastvideo/training/training_utils.py index 1291ffcfe..c97e9f835 100644 --- a/fastvideo/training/training_utils.py +++ b/fastvideo/training/training_utils.py @@ -1770,9 +1770,27 @@ def _local_numel(p: torch.Tensor) -> int: def count_trainable(model: torch.nn.Module) -> int: + """Return this rank's trainable parameter count (FSDP local shard).""" return sum(_local_numel(p) for p in model.parameters() if p.requires_grad) +def count_trainable_total( + model: torch.nn.Module, + device: torch.device | None = None, +) -> int: + """Return total trainable parameter count across all ranks (FSDP-safe). + + When device is provided and dist is initialized, torch.distributed.all_reduce(SUM) + with the default world group is used. Otherwise returns local count. + """ + local = count_trainable(model) + if device is not None and dist.is_initialized(): + t = torch.tensor([local], dtype=torch.long, device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t.item() + return local + + class EMA_FSDP: """ FSDP2-friendly EMA with two modes: diff --git a/fastvideo/utils.py b/fastvideo/utils.py index d3efd69c8..d722f4054 100644 --- a/fastvideo/utils.py +++ b/fastvideo/utils.py @@ -935,7 +935,7 @@ def save_decoded_latents_as_video(decoded_latents: list[torch.Tensor], for x in videos: x = make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) - frames.append((x * 255).numpy().astype(np.uint8)) + frames.append((x * 255).cpu().numpy().astype(np.uint8)) os.makedirs(os.path.dirname(output_path), exist_ok=True) imageio.mimsave(output_path, frames, fps=fps, format="mp4")