From efaad1b357c8a064a035812fe1ed78d3dab4d49b Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Mon, 22 Jun 2026 18:57:44 +0000 Subject: [PATCH] Static padding and insert video --- src/maxtext/common/common_types.py | 1 + .../sft-vision-llava-video-178k.yml | 38 +++ src/maxtext/configs/types.py | 5 + .../input_pipeline/hf_data_processing.py | 31 +++ .../input_pipeline/input_pipeline_utils.py | 125 ++++++++- src/maxtext/layers/decoders.py | 2 +- src/maxtext/layers/encoders.py | 13 + src/maxtext/models/models.py | 4 + src/maxtext/models/qwen3.py | 2 +- src/maxtext/multimodal/processor.py | 26 +- .../multimodal/processor_qwen3_omni.py | 257 ++++++++++++++++-- src/maxtext/multimodal/utils.py | 113 +++++--- src/maxtext/trainers/pre_train/train.py | 14 +- src/maxtext/utils/maxtext_utils.py | 22 ++ .../download_hf_multimodal_dataset.py | 208 ++++++++++++++ 15 files changed, 788 insertions(+), 73 deletions(-) create mode 100644 src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml create mode 100644 tools/data_generation/download_hf_multimodal_dataset.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..45d07e72fa 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -84,6 +84,7 @@ class MultimodalInput: video_masks: Array | None = None audio_embeddings: Array | None = None audio_masks: Array | None = None + audio_token_masks: Array | None = None bidirectional_mask: Array | None = None bidirectional_mask_video: Array | None = None diff --git a/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml new file mode 100644 index 0000000000..f0c5b15aff --- /dev/null +++ b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml @@ -0,0 +1,38 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +base_config: "base.yml" + +use_sft: true +use_tunix_gradient_accumulation: true +use_multimodal: true +sft_train_on_completion_only: true +packing: false # packing is not supported yet +freeze_vision_encoder_params: true +learning_rate: 2.e-5 + +# -------------- Model -------------- +model_name: "qwen3-omni-30b-a3b" +tokenizer_path: "Qwen/Qwen3-Omni-30B-A3B-Instruct" + +# -------------- HF pipeline -------------- +dataset_type: "hf" +hf_path: "parquet" +hf_train_files: "gs://YOUR_BUCKET/path/to/parquet/*.parquet" +train_split: "train" +train_data_columns: ["query", "label"] +train_image_column: "video" + +# Local SSD path for videos on the TPU VM +video_directory: "/path/to/video_directory" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..2747944c03 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1872,6 +1872,7 @@ class MultimodalGeneral(BaseModel): description="Maximum number of images per example for training with image lists. -1 means no limit.", ) video_path: PathStr = Field("", description="Path to a video for decoding.") + video_directory: PathStr = Field("", description="Local directory path containing video files for SFT.") audio_path: PathStr = Field("", description="Path to an audio file for decoding.") video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.") audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.") @@ -1879,6 +1880,10 @@ class MultimodalGeneral(BaseModel): use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.") mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.") position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).") + filter_sft_sequences_by_length: bool = Field( + False, + description="Filter out multimodal SFT sequences that exceed max_prefill_predict_length or max_target_length.", + ) class VisionTower(BaseModel): diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index 370f1895bd..347dcb321d 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -55,6 +55,25 @@ def vision_sft_preprocessing_pipeline( """pipeline for multimodal SFT with HF dataset""" assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}" + + # Format conversations if columns are missing + features_keys = list(dataset.features.keys()) if dataset.features else [] + if "conversations" in features_keys and not all(col in features_keys for col in text_columns): + def format_llava_video_dataset(example): + conversations = example["conversations"] + query = "" + label = "" + for turn in conversations: + if turn["from"] == "human" and not query: + query = turn["value"] + elif turn["from"] == "gpt" and not label: + label = turn["value"] + example[text_columns[0]] = query + example[text_columns[1]] = label + return example + + dataset = dataset.map(format_llava_video_dataset) + # Tunix GA requires per-micro-batch slicing at the data level, # whereas Native GA processes the full batch and splits it internally. if config.elastic_enabled: @@ -137,6 +156,18 @@ def vision_sft_preprocessing_pipeline( fn_kwargs={"column_name": text_columns[0], "config": config}, ) + # Filter out sequences exceeding max_prefill_predict_length or max_target_length + if getattr(config, "filter_sft_sequences_by_length", False): + max_prefill = getattr(config, "max_prefill_predict_length", 8192) + max_target = getattr(config, "max_target_length", 8192 + 512) + + def filter_by_length(example): + prefill_len = len(example[text_columns[0]]) + response_len = len(example[text_columns[1]]) + return (prefill_len <= max_prefill) and (prefill_len + response_len <= max_target) + + dataset = dataset.filter(filter_by_length) + dataset = input_pipeline_utils.HFDataSource( dataset=dataset, dataloading_host_index=dataloading_host_index, diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py index 621b79bb47..c944a96186 100644 --- a/src/maxtext/input_pipeline/input_pipeline_utils.py +++ b/src/maxtext/input_pipeline/input_pipeline_utils.py @@ -92,17 +92,25 @@ def _process_string(string_tensor): def reformat_prompt(example, column, image_placeholder, model_name): """reformat prompt for multimodal SFT""" - if isinstance(example["images"], list): - num_images = len(example["images"]) + if isinstance(example["images"], str): + example[column] = mm_processor.reformat_prompt( + example[column], image_placeholder, model_name, num_images=0, video_placeholder=image_placeholder, num_videos=1 + ) else: - num_images = 1 - example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images) + if isinstance(example["images"], list): + num_images = len(example["images"]) + else: + num_images = 1 + example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images) return example def reformat_response(example, column, model_name): """reformat response for multimodal SFT""" - example[column] = mm_processor.reformat_response(example[column][0], model_name) + val = example[column] + if isinstance(val, (list, tuple)) and len(val) > 0: + val = val[0] + example[column] = mm_processor.reformat_response(val, model_name) return example @@ -120,9 +128,17 @@ def merge_image_columns(example, image_columns, max_num_images_per_example): def pre_process_image_sft(example, image_column, config): - """pre-process image for multimodal SFT""" + """pre-process image or video for multimodal SFT""" def _process_image_fn(image): + if isinstance(image, str): + import os + + video_directory = getattr(config, "video_directory", "") + if video_directory: + image = os.path.join(video_directory, image) + return mm_processor.preprocess_image_for_training(image, config) + if isinstance(image, list): image = [np.array(mm_utils.convert_to_RGB(img)) for img in image] else: @@ -131,7 +147,7 @@ def _process_image_fn(image): image = mm_processor.preprocess_image_for_training(image, config) return image - example[image_column] = _process_image_fn(example[image_column]) + example[image_column] = _process_image_fn(example[image_column]) if example.get(image_column) is not None else None return example @@ -702,12 +718,80 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) - if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput): raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}") - if preprocessed_image.pixel_values is None: - raise ValueError("Input preprocessed_image must have pixel_values to pad images.") - if self.config.model_name and self.config.model_name.startswith("qwen3-omni"): + # Pad video_values and audio_values to fixed shapes so grain.Batch can stack them. + video_values = getattr(preprocessed_image, "video_values", None) + video_grid_thw = getattr(preprocessed_image, "video_grid_thw", None) + if video_values is not None: + target_shape = mm_processor.get_dummy_video_shape_for_init( + self.config.model_name, batch_size=1, config=self.config + ) + # target_shape = (1, C, max_T_px, max_H_px, max_W_px) + padded = np.zeros(target_shape[1:], dtype=video_values.dtype) # (C, max_T, max_H, max_W) + _, c, t, h, w = video_values.shape + max_t_px, max_h_px, max_w_px = target_shape[2], target_shape[3], target_shape[4] + t_clip = min(t, max_t_px) + h_clip = min(h, max_h_px) + w_clip = min(w, max_w_px) + padded[:, :t_clip, :h_clip, :w_clip] = video_values[0, :, :t_clip, :h_clip, :w_clip] + preprocessed_image.video_values = padded + + if video_grid_thw is not None: + from maxtext.multimodal.processor_qwen3_omni import VIDEO_MAX_GRID_T, VIDEO_MAX_GRID_H, VIDEO_MAX_GRID_W + merge_size = getattr(self.config, "spatial_merge_size_for_vit", 2) + max_t = VIDEO_MAX_GRID_T + max_h_merged = VIDEO_MAX_GRID_H // merge_size + max_w_merged = VIDEO_MAX_GRID_W // merge_size + + actual_t, actual_h, actual_w = video_grid_thw[0] + actual_t = min(actual_t, VIDEO_MAX_GRID_T) + actual_h = min(actual_h, VIDEO_MAX_GRID_H) + actual_w = min(actual_w, VIDEO_MAX_GRID_W) + + actual_h_merged = actual_h // merge_size + actual_w_merged = actual_w // merge_size + + mask_3d = np.zeros((max_t, max_h_merged, max_w_merged), dtype=np.int32) + mask_3d[:actual_t, :actual_h_merged, :actual_w_merged] = 1 + preprocessed_image.video_mask = mask_3d.flatten() + + print( + f"[SFT_DEBUG] Padding Video: Original Pixel Shape: {video_values.shape}, Padded Pixel Shape: {padded.shape}. " + f"Original Grid (THW): {video_grid_thw[0]}, Clipped Grid: [{actual_t}, {actual_h}, {actual_w}]. " + f"Video Mask Shape (flattened): {preprocessed_image.video_mask.shape}, Valid tokens count: {np.sum(preprocessed_image.video_mask)}" + ) + + audio_values = getattr(preprocessed_image, "audio_values", None) + audio_lengths = getattr(preprocessed_image, "audio_lengths", None) + if audio_values is not None: + target_audio = mm_processor.get_dummy_audio_shape_for_sft( + self.config.model_name, batch_size=1, config=self.config + ) + # target_audio = (1, num_mel_bins, AUDIO_MAX_TIME) + _, mel, t_audio = audio_values.shape + padded_audio = np.zeros(target_audio[1:], dtype=audio_values.dtype) # (mel, max_time) + padded_audio[:, :t_audio] = audio_values[0] + preprocessed_image.audio_values = padded_audio + + if audio_lengths is not None: + from maxtext.multimodal.processor_qwen3_omni import AUDIO_MAX_TIME, _get_feat_extract_output_lengths + max_audio_tokens = _get_feat_extract_output_lengths(AUDIO_MAX_TIME) + actual_audio_tokens = audio_lengths[0] + + audio_token_mask = np.zeros(max_audio_tokens, dtype=np.int32) + audio_token_mask[:actual_audio_tokens] = 1 + preprocessed_image.audio_token_mask = audio_token_mask + + print( + f"[SFT_DEBUG] Padding Audio: Original Mel Shape: {audio_values.shape}, Padded Mel Shape: {padded_audio.shape}. " + f"Audio Mask Shape: {preprocessed_image.audio_token_mask.shape}, Valid audio tokens count: {np.sum(preprocessed_image.audio_token_mask)}" + ) + return preprocessed_image + if preprocessed_image.pixel_values is None: + raise ValueError("Input preprocessed_image must have pixel_values to pad images.") + # Determine the maximum number of images/masks allowed. image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image) single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0] @@ -812,6 +896,27 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]: if preprocessed_image.pixel_mask is not None: output["image_masks"] = preprocessed_image.pixel_mask + # Extract video and audio tensors from Qwen3OmniPreprocessorOutput. + video_values = getattr(preprocessed_image, "video_values", None) + if video_values is not None: + output["videos"] = video_values + video_grid_thw = getattr(preprocessed_image, "video_grid_thw", None) + if video_grid_thw is not None: + output["video_grid_thw"] = video_grid_thw + video_mask = getattr(preprocessed_image, "video_mask", None) + if video_mask is not None: + output["video_masks"] = video_mask + + audio_values = getattr(preprocessed_image, "audio_values", None) + if audio_values is not None: + output["audios"] = audio_values + audio_lengths = getattr(preprocessed_image, "audio_lengths", None) + if audio_lengths is not None: + output["audio_lengths"] = audio_lengths + audio_token_mask = getattr(preprocessed_image, "audio_token_mask", None) + if audio_token_mask is not None: + output["audio_token_masks"] = audio_token_mask + return output diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index b28b6dcb7a..8e6bd9ed1e 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -752,7 +752,7 @@ def _apply_embedding( text_embeddings=y, multimodal_embeddings=audio_embeddings, mask=audio_masks, - token_masks=None, + token_masks=getattr(multimodal_input, "audio_token_masks", None), ) else: raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") diff --git a/src/maxtext/layers/encoders.py b/src/maxtext/layers/encoders.py index 0db94ad9d6..f7b831e784 100644 --- a/src/maxtext/layers/encoders.py +++ b/src/maxtext/layers/encoders.py @@ -94,6 +94,13 @@ def __call__(self, input_images, deterministic=False): else: embeddings = encoder_output + if self.config.model_name in ["qwen3-omni-30b-a3b"]: + jax.debug.print( + "[SFT_DEBUG] VisionEncoder: Input Shape: {x}, Encoder Output Shape: {y}", + x=input_images.shape, + y=embeddings.shape + ) + if self.config.freeze_vision_encoder_params: embeddings = jax.lax.stop_gradient(embeddings) if deep_feats is not None: @@ -103,6 +110,12 @@ def __call__(self, input_images, deterministic=False): projector = getattr(self, self.projector_name) embeddings = projector(embeddings) + if self.config.model_name in ["qwen3-omni-30b-a3b"]: + jax.debug.print( + "[SFT_DEBUG] VisionEncoder: Projector Output Shape: {x}", + x=embeddings.shape + ) + return embeddings, deep_feats diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index ac908c0f96..00b2842586 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -130,6 +130,7 @@ def __call__( encoder_videos: None | jnp.ndarray = None, encoder_video_masks: None | jnp.ndarray = None, encoder_audios: None | jnp.ndarray = None, + encoder_audio_token_masks: None | jnp.ndarray = None, enable_dropout=True, model_mode=MODEL_MODE_TRAIN, previous_chunk=None, @@ -195,6 +196,7 @@ def __call__( video_masks=encoder_video_masks, audio_embeddings=audio_embeddings, audio_masks=audio_masks, + audio_token_masks=encoder_audio_token_masks, bidirectional_mask=bidirectional_mask_image, bidirectional_mask_video=bidirectional_mask_video, ) @@ -443,6 +445,7 @@ def __call__( encoder_videos: jax.Array | None = None, encoder_video_masks: jax.Array | None = None, encoder_audios: jax.Array | None = None, + encoder_audio_token_masks: jax.Array | None = None, enable_dropout=True, model_mode=MODEL_MODE_TRAIN, previous_chunk=None, @@ -522,6 +525,7 @@ def __call__( video_masks=encoder_video_masks, audio_embeddings=audio_embeddings, audio_masks=audio_masks, + audio_token_masks=encoder_audio_token_masks, bidirectional_mask=bidirectional_mask_image, bidirectional_mask_video=bidirectional_mask_video, ) diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index 3ddcb7ed12..7e9ef3ec49 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -1906,7 +1906,7 @@ def __init__(self, config: Config, *, mesh=None, rngs: nnx.Rngs = None): num_kv_heads=self.config.num_attention_heads_for_vit, head_dim=head_dim, max_target_length=self.config.num_position_embeddings_for_vit, - attention_kernel="dot_product", + attention_kernel="autoselected", inputs_q_shape=(1, 1, self.config.hidden_size_for_vit), inputs_kv_shape=(1, 1, self.config.hidden_size_for_vit), float32_qk_product=self.config.float32_qk_product, diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py index 7c99800f2a..211d47f002 100644 --- a/src/maxtext/multimodal/processor.py +++ b/src/maxtext/multimodal/processor.py @@ -69,9 +69,13 @@ def preprocess_image_for_training(image, config): return preprocess_mm_data_llama4(image) elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: - from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel + from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training, preprocess_mm_data_qwen3_omni_for_training_video # pylint: disable=import-outside-toplevel - return preprocess_mm_data_qwen3_omni_for_training(image, config) + if isinstance(image, str): + use_audio_in_video = getattr(config, "use_audio_in_video", False) + return preprocess_mm_data_qwen3_omni_for_training_video(image, config) + else: + return preprocess_mm_data_qwen3_omni_for_training(image, config) else: raise ValueError(f"Model {config.model_name} not supported for image preprocessing.") @@ -188,6 +192,24 @@ def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_seque return image_shape +def get_dummy_video_shape_for_init(model_name, batch_size=1, config=None): + """Return the fixed padded shape for video batch tensors used in SFT training.""" + if model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + from maxtext.multimodal.processor_qwen3_omni import get_dummy_video_shape_for_init_qwen3_omni # pylint: disable=import-outside-toplevel + + return get_dummy_video_shape_for_init_qwen3_omni(batch_size, config) + return () + + +def get_dummy_audio_shape_for_sft(model_name, batch_size=1, config=None): + """Return the fixed padded shape for audio batch tensors used in SFT training.""" + if model_name in ["qwen3-omni-30b-a3b"]: + from maxtext.multimodal.processor_qwen3_omni import get_dummy_audio_shape_for_init_qwen3_omni_sft # pylint: disable=import-outside-toplevel + + return get_dummy_audio_shape_for_init_qwen3_omni_sft(batch_size, config) + return () + + def get_dummy_audio_shape_for_init(config): """Return the shape of the dummy audio for specific model's initialization. diff --git a/src/maxtext/multimodal/processor_qwen3_omni.py b/src/maxtext/multimodal/processor_qwen3_omni.py index b29b8acc84..6d4c54ae8f 100644 --- a/src/maxtext/multimodal/processor_qwen3_omni.py +++ b/src/maxtext/multimodal/processor_qwen3_omni.py @@ -19,6 +19,7 @@ import math import os +import logging from dataclasses import dataclass import numpy as np @@ -58,7 +59,18 @@ DITHER = 0.0 # Amount of dithering to apply to audio signal. QWEN3_TEMPORAL_PATCH_SIZE = 2 -QWEN3_OMNI_IMAGE_SIZE = 768 +QWEN3_OMNI_IMAGE_SIZE = 512 + +# Max grid sizes for fixed-shape video batching (used for padding before grain.Batch). +# These bound the T, H, W patch dimensions. Derived from VIDEO_TOTAL_PIXELS / typical FPS: +# max_grid_t = 32 (64 frames / temporal_patch_size=2) +# max_grid_h = 32 (392 pixels; h*w ≤ VIDEO_MAX_PIXELS/patch_size² = 768 patches) +# max_grid_w = 32 (392 pixels) +VIDEO_MAX_GRID_T = 32 +VIDEO_MAX_GRID_H = 32 +VIDEO_MAX_GRID_W = 32 +# Max audio time steps for fixed-shape audio batching (30 s @ 16 kHz / hop 160 ≈ 3000). +AUDIO_MAX_TIME = 3000 QWEN_SPECIAL_TOKEN_CONFIGS = { @@ -129,11 +141,13 @@ class Qwen3OmniPreprocessorOutput(mm_utils.PreprocessorOutput): video_values: None | np.ndarray = None video_grid_thw: None | np.ndarray = None video_second_per_grid: None | np.ndarray = None + video_mask: None | np.ndarray = None # Audio attributes. num_audios: int = 0 audio_values: None | np.ndarray = None audio_mask: None | np.ndarray = None audio_lengths: None | np.ndarray = None + audio_token_mask: None | np.ndarray = None def smart_resize( @@ -344,6 +358,36 @@ def floor_by_factor(number: int, factor: int) -> int: return nframes +def _read_video_opencv(video_path, idx) -> np.ndarray: + """Robust fallback video reader using OpenCV.""" + import cv2 + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise RuntimeError(f"OpenCV failed to open video file: {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # OpenCV reads in BGR format, convert to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + + if len(frames) == 0: + raise RuntimeError(f"OpenCV decoded zero frames from video: {video_path}") + + selected_frames = [] + for i in idx: + clamped_i = min(i, len(frames) - 1) + selected_frames.append(frames[clamped_i]) + + video = np.stack(selected_frames, axis=0) + video = np.transpose(video, (0, 3, 1, 2)) + return video + + def _read_video_decord(video_path, video_start=0.0, video_end=None) -> tuple[np.ndarray, float]: """Read video using decord.VideoReader (torch-free version) @@ -370,24 +414,45 @@ def _read_video_decord(video_path, video_start=0.0, video_end=None) -> tuple[np. } try: vr = decord.VideoReader(video_path) - except Exception as e: - raise RuntimeError(f"Failed to read video from {video_path}: {e}") from e - total_frames, video_fps = len(vr), vr.get_avg_fps() - start_frame, end_frame, total_frames = calculate_video_frame_range( - video_config, - total_frames, - video_fps, - ) - nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) - - # Use numpy linspace instead of torch.linspace - idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() - - video = vr.get_batch(idx).asnumpy() - # Convert from THWC to TCHW format using numpy - video = np.transpose(video, (0, 3, 1, 2)) + total_frames, video_fps = len(vr), vr.get_avg_fps() + start_frame, end_frame, total_frames = calculate_video_frame_range( + video_config, + total_frames, + video_fps, + ) + nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) + idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() + video = vr.get_batch(idx).asnumpy() + video = np.transpose(video, (0, 3, 1, 2)) + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + except Exception as decord_error: + logging.warning( + f"Decord failed to load/decode video {video_path} due to: {decord_error}. " + "Falling back to OpenCV video reader." + ) + try: + import cv2 + cap = cv2.VideoCapture(video_path) + video_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 100 + cap.release() + + start_frame, end_frame, total_frames = calculate_video_frame_range( + video_config, + total_frames, + video_fps, + ) + nframes = smart_nframes(video_config, total_frames=total_frames, video_fps=video_fps) + idx = np.linspace(start_frame, end_frame, nframes).round().astype(int).tolist() + + video = _read_video_opencv(video_path, idx) + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + except Exception as cv2_error: + raise RuntimeError( + f"Both Decord and OpenCV failed to decode video. " + f"Decord error: {decord_error}. OpenCV error: {cv2_error}" + ) from decord_error - sample_fps = nframes / max(total_frames, 1e-6) * video_fps return video, sample_fps @@ -554,6 +619,107 @@ def preprocess_mm_data_qwen3_omni_for_training(images, config): ) +def preprocess_mm_data_qwen3_omni_for_training_video(video_path, config): + """Preprocesses video (and audio) for Qwen3-Omni SFT training.""" + + import os + + if not os.path.exists(video_path): + raise FileNotFoundError( + f"Video file not found at local path: '{video_path}'. " + "Please make sure you have fully downloaded the dataset using the " + "download utility before running SFT training." + ) + + try: + video_array, _ = _read_video_decord(video_path) + video_processed, video_grid_thw = preprocess_video(video_array, config) + orig_t, orig_h, orig_w = video_grid_thw[0] + video_values = np.reshape( + video_processed, + ( + 1, + config.num_channels_for_vit, + config.temporal_patch_size_for_vit * orig_t, + config.patch_size_for_vit * orig_h, + config.patch_size_for_vit * orig_w, + ), + ) + + # Clip grid to maximum limits to prevent TPU compilation OOMs and ensure static shapes + clip_t = min(orig_t, VIDEO_MAX_GRID_T) + clip_h = min(orig_h, VIDEO_MAX_GRID_H) + clip_w = min(orig_w, VIDEO_MAX_GRID_W) + + # Crop video_values to the clipped dimensions + t_px = config.temporal_patch_size_for_vit * clip_t + h_px = config.patch_size_for_vit * clip_h + w_px = config.patch_size_for_vit * clip_w + video_values = video_values[:, :, :t_px, :h_px, :w_px] + + # Update video_grid_thw + video_grid_thw = np.array([[clip_t, clip_h, clip_w]], dtype=np.int32) + print( + f"[SFT_DEBUG] Preprocessing Video: Raw path: {video_path}, Raw size: {video_array.shape}. " + f"Original Grid (THW): [{orig_t}, {orig_h}, {orig_w}], Clipped Grid (THW): {video_grid_thw[0]}, " + f"Cropped Pixel Shape: {video_values.shape}" + ) + except Exception as e: + logging.warning( + "\n" + "="*80 + "\n" + f"[DATASET CORRUPTION WARNING] BOTH DECORD AND OPENCV FAILED TO DECODE VIDEO:\n" + f"Path: {video_path}\n" + f"Error: {e}\n" + "Substituting dummy zero-video to prevent SFT training crash!\n" + "Please check if this video file is completely corrupted or empty.\n" + + "="*80 + "\n" + ) + grid_t = 1 + grid_h = 16 + grid_w = 16 + video_grid_thw = np.array([[grid_t, grid_h, grid_w]], dtype=np.int32) + fallback_t = config.temporal_patch_size_for_vit * grid_t + fallback_h = config.patch_size_for_vit * grid_h + fallback_w = config.patch_size_for_vit * grid_w + video_values = np.zeros( + (1, config.num_channels_for_vit, fallback_t, fallback_h, fallback_w), + dtype=np.float32 + ) + + processor_outputs = Qwen3OmniPreprocessorOutput( + num_videos=1, + video_values=video_values, + video_grid_thw=video_grid_thw, + video_second_per_grid=np.asarray([config.temporal_patch_size_for_vit], dtype=np.float32), + ) + + use_audio_in_video = getattr(config, "use_audio_in_video", False) + if use_audio_in_video: + try: + mt_audio = mm_utils.load_audio(video_path, sample_rate=SAMPLE_RATE) + mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio) + processor_outputs.audio_values = mt_audio + processor_outputs.audio_mask = mt_audio_mask + audio_mask_sum = np.sum(mt_audio_mask, axis=-1) + audio_lengths = _get_feat_extract_output_lengths(audio_mask_sum) + processor_outputs.audio_lengths = np.array(audio_lengths, dtype=np.int32) + print( + f"[SFT_DEBUG] Preprocessing Audio: Audio shape: {mt_audio.shape}, Audio lengths (tokens): {processor_outputs.audio_lengths}" + ) + except Exception as e: + logging.warning(f"Audio extraction failed for {video_path}: {e}. Using dummy audio.") + dummy_audio = np.zeros((1, 128, 3000), dtype=np.float32) + dummy_mask = np.zeros((1, 3000), dtype=np.int32) + processor_outputs.audio_values = dummy_audio + processor_outputs.audio_mask = dummy_mask + processor_outputs.audio_lengths = np.array([0], dtype=np.int32) + print( + f"[SFT_DEBUG] Preprocessing Audio (fallback): Audio shape: {dummy_audio.shape}, Audio lengths (tokens): {processor_outputs.audio_lengths}" + ) + + return processor_outputs + + def preprocess_mm_data_qwen3_omni(config): """Placeholder for multimodal data preprocessing.""" processor_outputs = Qwen3OmniPreprocessorOutput() @@ -579,18 +745,40 @@ def preprocess_mm_data_qwen3_omni(config): if config.video_path: video_array, _ = _read_video_decord(config.video_path) video_processed, video_grid_thw = preprocess_video(video_array, config) + orig_t, orig_h, orig_w = video_grid_thw[0] video_values = np.reshape( video_processed, ( 1, config.num_channels_for_vit, - config.temporal_patch_size_for_vit * video_grid_thw[0, 0], - config.patch_size_for_vit * video_grid_thw[0, 1], - config.patch_size_for_vit * video_grid_thw[0, 2], + config.temporal_patch_size_for_vit * orig_t, + config.patch_size_for_vit * orig_h, + config.patch_size_for_vit * orig_w, ), ) + + # Clip grid to maximum limits to prevent TPU compilation OOMs and ensure static shapes + clip_t = min(orig_t, VIDEO_MAX_GRID_T) + clip_h = min(orig_h, VIDEO_MAX_GRID_H) + clip_w = min(orig_w, VIDEO_MAX_GRID_W) + + # Crop video_values to the clipped dimensions + t_px = config.temporal_patch_size_for_vit * clip_t + h_px = config.patch_size_for_vit * clip_h + w_px = config.patch_size_for_vit * clip_w + video_values = video_values[:, :, :t_px, :h_px, :w_px] + + # Update video_grid_thw + video_grid_thw = np.array([[clip_t, clip_h, clip_w]], dtype=np.int32) + processor_outputs.video_values = video_values processor_outputs.video_grid_thw = video_grid_thw + print( + f"[SFT_DEBUG] Preprocessing Video (inference): Raw path: {config.video_path}, Raw size: {video_array.shape}. " + f"Original Grid (THW): [{orig_t}, {orig_h}, {orig_w}], Clipped Grid (THW): {video_grid_thw[0]}, " + f"Cropped Pixel Shape: {video_values.shape}" + ) + processor_outputs.video_grid_thw = video_grid_thw processor_outputs.video_second_per_grid = np.asarray([config.temporal_patch_size_for_vit], dtype=np.float32) processor_outputs.num_videos = 1 # Only one video for now. @@ -757,6 +945,33 @@ def get_dummy_image_shape_for_init_qwen3_omni(batch_size): return image_shape +def get_dummy_video_shape_for_init_qwen3_omni(batch_size, config): + """Return the fixed padded shape for a video batch element for Qwen3-Omni. + + All video_values tensors are zero-padded to this shape in PadOrTrimToMaxLength + so that grain.Batch can stack them into a uniform batch. + Shape: (batch, C, max_T_pixels, max_H_pixels, max_W_pixels) + """ + tps = config.temporal_patch_size_for_vit + ps = config.patch_size_for_vit + return ( + batch_size, + mm_utils.NUM_IMAGE_CHANNELS, + VIDEO_MAX_GRID_T * tps, + VIDEO_MAX_GRID_H * ps, + VIDEO_MAX_GRID_W * ps, + ) + + +def get_dummy_audio_shape_for_init_qwen3_omni_sft(batch_size, config): + """Return the fixed padded shape for an audio batch element for Qwen3-Omni SFT. + + All audio_values tensors are zero-padded to this shape so that grain.Batch + can stack them. Shape: (batch, num_mel_bins, AUDIO_MAX_TIME) + """ + return (batch_size, config.num_mel_bins_for_audio, AUDIO_MAX_TIME) + + def get_dummy_audio_shape_for_init_qwen3_omni(config): """Return the shape of the dummy audio for Qwen3-Omni model's initialization.""" # Audio shape: (batch, num_mel_bins, audio_length) diff --git a/src/maxtext/multimodal/utils.py b/src/maxtext/multimodal/utils.py index 65b5670fc1..fa692c41ae 100644 --- a/src/maxtext/multimodal/utils.py +++ b/src/maxtext/multimodal/utils.py @@ -148,22 +148,36 @@ def merge_mm_embeddings( # Process Optional Token Masks flat_token_masks_processed = None if token_masks is not None: - # Handle the tiled case where token_masks batch dimension is (B * N) - if token_masks.shape[0] != batch_size: - if token_masks.shape[0] % batch_size != 0: - raise ValueError( - "Batch dimension of token_masks must be a multiple of the text" - f" batch size. Got {token_masks.shape[0]} and {batch_size}." - ) - # Reshape from (B * N, T) to (B, N * T) - flat_tile_masks = token_masks.reshape(batch_size, -1) + if multimodal_embeddings.ndim == 3: + # Already flat, no need to repeat (e.g. Qwen3-Omni video/audio) + flat_token_masks_processed = token_masks.reshape(batch_size, -1) else: - # This handles cases where token_masks is already (B, ...) - flat_tile_masks = token_masks.reshape(batch_size, -1) - - # Expand the tile-level mask to a token-level mask to match the embeddings. - # A mask of shape (B, N*T) becomes (B, N*T*K) by repeating each element K times. - flat_token_masks_processed = jnp.repeat(flat_tile_masks, repeats=num_toks_per_token, axis=1) + # Handle the tiled case where token_masks batch dimension is (B * N) + if token_masks.shape[0] != batch_size: + if token_masks.shape[0] % batch_size != 0: + raise ValueError( + "Batch dimension of token_masks must be a multiple of the text" + f" batch size. Got {token_masks.shape[0]} and {batch_size}." + ) + # Reshape from (B * N, T) to (B, N * T) + flat_tile_masks = token_masks.reshape(batch_size, -1) + else: + # This handles cases where token_masks is already (B, ...) + flat_tile_masks = token_masks.reshape(batch_size, -1) + + # Expand the tile-level mask to a token-level mask to match the embeddings. + # A mask of shape (B, N*T) becomes (B, N*T*K) by repeating each element K times. + flat_token_masks_processed = jnp.repeat(flat_tile_masks, repeats=num_toks_per_token, axis=1) + + if flat_token_masks_processed is not None: + valid_counts = jnp.sum(flat_token_masks_processed, axis=1) + jax.debug.print( + "[SFT_DEBUG] merge_mm_embeddings: Input shape: {x}, Token Mask shape: {m}, " + "Valid tokens per batch element: {counts}", + x=flat_multimodal_embeddings.shape, + m=flat_token_masks_processed.shape, + counts=valid_counts + ) # Vmap the inner merge function over the batch dimension return jax.vmap( @@ -767,25 +781,52 @@ def window_function( def load_audio(data_path: str, sample_rate: int = 16000) -> np.ndarray: - """Load audio from a file path. - - Args: - data_path (str): The path to the audio file or video file. - sample_rate (int): The target sample rate in Hz. Default is 16000. - - Returns: - np.ndarray: The loaded audio waveform. - - Raises: - FileNotFoundError: If the audio file does not exist. - RuntimeError: If the audio file cannot be loaded. - """ + """Load audio from a file path (supporting both audio and video files).""" if not os.path.isfile(data_path): - raise FileNotFoundError(f"Audio file not found at path {data_path}. Please specify a valid audio file path") - if librosa is None: - raise ImportError("librosa is required for audio processing but not installed.") - try: - audio = librosa.load(data_path, sr=sample_rate)[0] - return audio - except Exception as e: - raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e + raise FileNotFoundError(f"Audio file not found at path {data_path}.") + + import soundfile as sf + import subprocess + import tempfile + + is_video = data_path.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv", ".webm")) + + if is_video: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: + temp_wav_path = temp_wav.name + + try: + cmd = [ + "ffmpeg", + "-y", + "-i", + data_path, + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + str(sample_rate), + "-ac", + "1", + temp_wav_path, + ] + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True) + audio, sr = sf.read(temp_wav_path) + assert sr == sample_rate, f"Sample rate mismatch: expected {sample_rate}, got {sr}" + return audio + except Exception as e: + raise RuntimeError(f"Failed to extract and load audio from video {data_path}: {e}") + finally: + if os.path.exists(temp_wav_path): + os.remove(temp_wav_path) + else: + try: + audio, sr = sf.read(data_path) + if sr != sample_rate: + if librosa is not None: + audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate) + else: + raise RuntimeError(f"Audio sample rate {sr} does not match target {sample_rate} and librosa is not installed.") + return audio + except Exception as e: + raise RuntimeError(f"Failed to load audio from {data_path}: {e}") diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..16e5a6cbc5 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -104,10 +104,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_train_on, :] + if v is not None: + data[k] = v[: config.micro_batch_size_to_train_on, :] else: for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_eval_on, :] + if v is not None: + data[k] = v[: config.micro_batch_size_to_eval_on, :] mutable_collections = ["intermediates"] if config.mtp_num_layers > 0 and is_train: # The single model.apply call now triggers the entire chain if MTP is enabled: @@ -142,6 +144,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr decoder_segment_ids=data["inputs_segmentation"], encoder_images=data["images"] if config.use_multimodal else None, encoder_image_masks=data["image_masks"] if config.use_multimodal and "image_masks" in data else None, + encoder_videos=data["videos"] if config.use_multimodal and "videos" in data else None, + encoder_video_masks=data["video_masks"] if config.use_multimodal and "video_masks" in data else None, + encoder_audios=data["audios"] if "audios" in data else None, + encoder_audio_token_masks=data["audio_token_masks"] if "audio_token_masks" in data else None, enable_dropout=config.enable_dropout if is_train else False, rngs={"dropout": rng1, "params": aqt_rng}, mutable=mutable_collections, @@ -191,6 +197,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr decoder_segment_ids=data["inputs_segmentation"], encoder_images=data["images"] if config.use_multimodal else None, encoder_image_masks=data["image_masks"] if config.use_multimodal and "image_masks" in data else None, + encoder_videos=data["videos"] if config.use_multimodal and "videos" in data else None, + encoder_video_masks=data["video_masks"] if config.use_multimodal and "video_masks" in data else None, + encoder_audios=data["audios"] if "audios" in data else None, + encoder_audio_token_masks=data["audio_token_masks"] if "audio_token_masks" in data else None, enable_dropout=config.enable_dropout if is_train else False, decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 238758da92..0f6f6cd801 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -172,6 +172,28 @@ def get_shaped_batch(config, batch_sharding=None): ) shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32, sharding=batch_sharding) shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32, sharding=batch_sharding) + video_shape = mm_processor.get_dummy_video_shape_for_init( + config.model_name, batch_size=config.micro_batch_size_to_train_on, config=config + ) + if video_shape: + shaped_batch["videos"] = jax.ShapeDtypeStruct(video_shape, jnp.float32, sharding=batch_sharding) + + from maxtext.multimodal.processor_qwen3_omni import VIDEO_MAX_GRID_T, VIDEO_MAX_GRID_H, VIDEO_MAX_GRID_W + merge_size = getattr(config, "spatial_merge_size_for_vit", 2) + video_mask_len = VIDEO_MAX_GRID_T * (VIDEO_MAX_GRID_H // merge_size) * (VIDEO_MAX_GRID_W // merge_size) + video_mask_shape = (config.micro_batch_size_to_train_on, video_mask_len) + shaped_batch["video_masks"] = jax.ShapeDtypeStruct(video_mask_shape, jnp.int32, sharding=batch_sharding) + + audio_sft_shape = mm_processor.get_dummy_audio_shape_for_sft( + config.model_name, batch_size=config.micro_batch_size_to_train_on, config=config + ) + if audio_sft_shape: + shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_sft_shape, jnp.float32, sharding=batch_sharding) + + from maxtext.multimodal.processor_qwen3_omni import AUDIO_MAX_TIME, _get_feat_extract_output_lengths + max_audio_tokens = _get_feat_extract_output_lengths(AUDIO_MAX_TIME) + audio_mask_shape = (config.micro_batch_size_to_train_on, max_audio_tokens) + shaped_batch["audio_token_masks"] = jax.ShapeDtypeStruct(audio_mask_shape, jnp.int32, sharding=batch_sharding) if config.use_audio: audio_shape = mm_processor.get_dummy_audio_shape_for_init(config) shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32, sharding=batch_sharding) diff --git a/tools/data_generation/download_hf_multimodal_dataset.py b/tools/data_generation/download_hf_multimodal_dataset.py new file mode 100644 index 0000000000..cdc426e37c --- /dev/null +++ b/tools/data_generation/download_hf_multimodal_dataset.py @@ -0,0 +1,208 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Download a HuggingFace multimodal dataset, unzip video archives, +and generate local Parquet metadata files. +""" + +import argparse +import os +import re +import shutil +import tarfile +import zipfile +from huggingface_hub import hf_hub_download, list_repo_files +from datasets import load_dataset +import pyarrow.parquet as pq + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Download and prepare local multimodal dataset from HuggingFace Hub." + ) + parser.add_argument( + "--repo_id", + required=True, + help="HuggingFace dataset repository ID (e.g. lmms-lab/LLaVA-Video-178K)", + ) + parser.add_argument( + "--subset", + required=True, + help="Subset directory inside repository (e.g. 0_30_s_academic_v0_1)", + ) + parser.add_argument( + "--dataset_dir", + required=True, + help="Target local directory to write both videos and parquets.", + ) + parser.add_argument( + "--split", + default="all", + choices=["all", "caption", "open_ended", "multi_choice"], + help="Specific split to prepare (default: all).", + ) + parser.add_argument( + "--token", + default=None, + help="HuggingFace access token for gated datasets.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Ensure the target local directory exists + os.makedirs(args.dataset_dir, exist_ok=True) + + print(f"Connecting to HuggingFace Hub to scan '{args.repo_id}' under subset '{args.subset}'...") + try: + all_files = list_repo_files(repo_id=args.repo_id, repo_type="dataset", token=args.token) + except Exception as e: + print(f"Error accessing HuggingFace repository: {e}") + return + + subset_prefix = args.subset.strip("/") + "/" + target_files = [f for f in all_files if f.startswith(subset_prefix)] + + if not target_files: + print(f"Error: No files found in HuggingFace repo matching subset prefix '{subset_prefix}'") + return + + # 1. Separate JSON metadata files based on split choice + split_patterns = { + "caption": r".*cap.*\.json", + "open_ended": r".*oe.*\.json", + "multi_choice": r".*mc.*\.json", + } + + json_files = [] + if args.split == "all": + json_files = [f for f in target_files if f.endswith(".json")] + else: + pattern = split_patterns[args.split] + json_files = [f for f in target_files if f.endswith(".json") and re.match(pattern, os.path.basename(f))] + + if not json_files: + print(f"Error: No metadata JSON files found matching split choice '{args.split}'") + return + + # 2. Identify video archive files + tar_files = [f for f in target_files if f.endswith(".tar.gz") or f.endswith(".tar") or f.endswith(".zip")] + + print("\n" + "="*80) + print(f"DATASET PREPARATION PLAN") + print(f"HuggingFace Repo: {args.repo_id}") + print(f"Subset / Split: {args.subset} / {args.split}") + print(f"Local Directory: {args.dataset_dir}") + print(f"Metadata JSONs: {len(json_files)}") + print(f"Video Archives: {len(tar_files)}") + print("="*80 + "\n") + + # 3. Download and extract video archives + staging_dir = os.path.join(args.dataset_dir, ".staging") + os.makedirs(staging_dir, exist_ok=True) + + downloaded_archives = [] + for i, f in enumerate(tar_files): + filename = os.path.basename(f) + print(f"[{i+1}/{len(tar_files)}] Downloading video archive: {filename} ...") + try: + local_path = hf_hub_download( + repo_id=args.repo_id, + filename=f, + repo_type="dataset", + local_dir=staging_dir, + token=args.token + ) + downloaded_archives.append(local_path) + except Exception as e: + print(f"Failed to download archive {filename}: {e}") + return + + for i, archive_path in enumerate(downloaded_archives): + print(f"[{i+1}/{len(downloaded_archives)}] Extracting video archive: {os.path.basename(archive_path)} ...") + try: + if archive_path.endswith(".tar.gz") or archive_path.endswith(".tar"): + with tarfile.open(archive_path, "r:gz" if archive_path.endswith(".tar.gz") else "r:") as tar: + tar.extractall(path=args.dataset_dir) + elif archive_path.endswith(".zip"): + with zipfile.ZipFile(archive_path, "r") as zip_ref: + zip_ref.extractall(path=args.dataset_dir) + except Exception as e: + print(f"Failed to extract archive {archive_path}: {e}") + return + + # Clean up staging directory + shutil.rmtree(staging_dir) + print("Video archives extracted successfully. Staging directory cleaned.") + + # 4. Download and convert JSON files to local Parquet files + local_json_paths = [] + for i, f in enumerate(json_files): + filename = os.path.basename(f) + print(f"[{i+1}/{len(json_files)}] Downloading metadata JSON: {filename} ...") + try: + local_path = hf_hub_download( + repo_id=args.repo_id, + filename=f, + repo_type="dataset", + local_dir=args.dataset_dir, + token=args.token + ) + local_json_paths.append(local_path) + except Exception as e: + print(f"Failed to download metadata JSON {filename}: {e}") + return + + print("\nConverting JSON files to local Parquet format...") + try: + ds = load_dataset("json", data_files=local_json_paths, split="train") + table = ds.data.table + + # Target filename indicates subset/split configurations + parquet_filename = f"llava-video-178k-{args.split}-00000-of-00001.parquet" + output_parquet_path = os.path.join(args.dataset_dir, parquet_filename) + pq.write_table(table, output_parquet_path, compression="zstd") + + print(f"Success! Local parquet file generated at: {output_parquet_path}") + except Exception as e: + print(f"Error during JSON-to-Parquet conversion: {e}") + return + finally: + # Always clean up temporary JSON files + for p in local_json_paths: + if os.path.exists(p): + os.remove(p) + + # Clean up empty directories created for JSON files + for p in local_json_paths: + dirname = os.path.dirname(p) + target = os.path.abspath(args.dataset_dir) + while dirname and os.path.abspath(dirname) != target: + try: + if not os.listdir(dirname): + os.rmdir(dirname) + else: + break + except OSError: + break + dirname = os.path.dirname(dirname) + + print(f"\nAll operations completed successfully! Dataset is ready locally at: {args.dataset_dir}\n") + + +if __name__ == "__main__": + main()