Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ enable_profiler: False
enable_ml_diagnostics: True
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
enable_ondemand_xprof: True
skip_first_n_steps_for_profiler: 0
profiler_steps: 5

replicate_vae: False

Expand Down
112 changes: 84 additions & 28 deletions src/maxdiffusion/generate_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log("Could not retrieve Git commit hash.")

checkpoint_loader = LTX2Checkpointer(config=config)
load_time = 0.0
if pipeline is None:
t0_load = time.perf_counter()
# Use the config flag to determine if the upsampler should be loaded
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
Expand Down Expand Up @@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
scan_layers=config.scan_layers,
dtype=config.weights_dtype,
)
load_time = time.perf_counter() - t0_load

pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
Expand All @@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
)

out = call_pipeline(config, pipeline, prompt, negative_prompt)

# out should have .frames and .audio
videos = out.frames if hasattr(out, "frames") else out[0]
audios = out.audio if hasattr(out, "audio") else None

max_logging.log("===================== Model details =======================")
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
Expand All @@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
max_logging.log("============================================================")

original_enable_profiler = config.get_keys().get("enable_profiler", False)
original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False)
original_num_steps = config.get_keys().get("num_inference_steps", 40)

# ---------------------------------------------------------
# Run 1: Warmup Compilation (Original steps, NO profiling)
# ---------------------------------------------------------
config.get_keys()["enable_profiler"] = False
config.get_keys()["enable_ml_diagnostics"] = False

max_logging.log(f"🚀 Starting warmup compilation pass ({original_num_steps} steps)...")
_ = call_pipeline(config, pipeline, prompt, negative_prompt)

compile_time = time.perf_counter() - s0
max_logging.log(f"compile_time: {compile_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/compile_time", compile_time, global_step=0)

# ---------------------------------------------------------
# Run 2: Actual Generation (Original steps, NO profiling)
# ---------------------------------------------------------

s0 = time.perf_counter()
max_logging.log("🚀 Starting actual full-length generation pass...")
out = call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time = time.perf_counter() - s0
max_logging.log(f"generation_time: {generation_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
num_devices = jax.device_count()
num_videos = num_devices * config.per_device_batch_size
if num_videos > 0:
generation_time_per_video = generation_time / num_videos
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")

# out should have .frames and .audio
videos = out.frames if hasattr(out, "frames") else out[0]
audios = out.audio if hasattr(out, "audio") else None

saved_video_path = []
audio_sample_rate = (
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000
Expand All @@ -210,29 +244,51 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
if config.output_dir.startswith("gs://"):
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)

s0 = time.perf_counter()
call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time = time.perf_counter() - s0
max_logging.log(f"generation_time: {generation_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
num_devices = jax.device_count()
num_videos = num_devices * config.per_device_batch_size
if num_videos > 0:
generation_time_per_video = generation_time / num_videos
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
max_logging.log(
f"\n{'=' * 50}\n"
f" TIMING SUMMARY\n"
f"{'=' * 50}\n"
f" Load (checkpoint): {load_time:>7.1f}s\n"
f" Compile: {compile_time:>7.1f}s\n"
f" {'─' * 40}\n"
f" Inference: {generation_time:>7.1f}s\n"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to print a component-wise split here for quick analysis, now that we are timing all the components?

Elisa has done something like this for the WAN pipelines here

f"{'=' * 50}"
)

s0 = time.perf_counter()
if max_utils.profiler_enabled(config):
with max_utils.Profiler(config):
call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
# Free memory before profiling
del out
del videos
del audios

# ---------------------------------------------------------
# Run 3: Profiling Run (Only if profiling was originally enabled)
# ---------------------------------------------------------
if original_enable_profiler or original_enable_mld:
skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0)
if skip_first_n_steps_for_profiler != 0:
max_logging.log(
"\n⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n"
)

profiling_steps = config.get_keys().get("profiler_steps", 5)

config.get_keys()["enable_profiler"] = False
config.get_keys()["enable_ml_diagnostics"] = False
config.get_keys()["num_inference_steps"] = profiling_steps

max_logging.log(f"🚀 Warmup for profiling pass ({profiling_steps} steps)...")
_ = call_pipeline(config, pipeline, prompt, negative_prompt)

config.get_keys()["enable_profiler"] = original_enable_profiler
config.get_keys()["enable_ml_diagnostics"] = original_enable_mld

max_logging.log(f"🚀 Starting Profiling run ({profiling_steps} steps)...")
profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps")
profiler.start()

_ = call_pipeline(config, pipeline, prompt, negative_prompt)

profiler.stop()

return saved_video_path

Expand Down
Loading
Loading