diff --git a/examples/configs/sft_avlm.yaml b/examples/configs/sft_avlm.yaml new file mode 100644 index 0000000000..85e7d25414 --- /dev/null +++ b/examples/configs/sft_avlm.yaml @@ -0,0 +1,29 @@ +defaults: + - sft_vlm_3B.yaml + +sft: + val_batches: 2 + val_global_batch_size: 8 + +policy: + max_total_sequence_length: 32768 + train_global_batch_size: 8 + dtensor_cfg: + tensor_parallel_size: 1 + dynamic_batching: + enabled: true + tokenizer: + video: + num_frames: 16 + +data: + # dataset + train: + dataset_name: daily-omni + split: train + split_validation_size: 0.05 # use 5% of the training data as validation data + seed: 42 # seed for train/validation split when split_validation_size > 0 + validation: null + # default settings for all datasets + default: + prompt_file: null diff --git a/examples/run_sft.py b/examples/run_sft.py index 45fb43036f..4e80414f8d 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -66,6 +66,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): print("\n▶ Setting up data...") # setup train dataset task_data_processors = {} + task_data_preprocessors = {} data_list = [] if isinstance(data_config["train"], dict): @@ -85,6 +86,8 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): add_generation_prompt=data_config["add_generation_prompt"], ) task_data_processors[data.task_name] = (data.task_spec, data_processor) + if hasattr(data, "preprocessor") and data.preprocessor is not None: + task_data_preprocessors[data.task_name] = data.preprocessor merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( @@ -92,12 +95,14 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): tokenizer, None, task_data_processors, + task_data_preprocessors=task_data_preprocessors, max_seq_length=data_config["max_input_seq_length"], ) print(f" ✓ Training dataset loaded with {len(dataset)} samples.") # setup validation dataset val_task_data_processors = {} + val_task_data_preprocessors = {} val_data_list = [] # validation dataset from train dataset (when train dataset's split_validation_size > 0) @@ -107,6 +112,10 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): # bind task_name to task_data_processors task_name = data.task_name val_task_data_processors[task_name] = task_data_processors[task_name] + if task_name in task_data_preprocessors: + val_task_data_preprocessors[task_name] = task_data_preprocessors[ + task_name + ] # validation dataset from config if "validation" in data_config and data_config["validation"] is not None: @@ -130,6 +139,8 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): val_data.task_spec, val_data_processor, ) + if hasattr(val_data, "preprocessor") and val_data.preprocessor is not None: + val_task_data_preprocessors[val_data.task_name] = val_data.preprocessor val_dataset = None if len(val_data_list) > 0: @@ -139,6 +150,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): tokenizer, None, val_task_data_processors, + task_data_preprocessors=val_task_data_preprocessors, max_seq_length=data_config["max_input_seq_length"], ) print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index cc99033aba..8e632ca5ee 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -320,6 +320,39 @@ def get_tokenizer( processor.bos_token_id = tokenizer.bos_token_id # copy name_or_path from tokenizer to processor for logging processor.name_or_path = tokenizer.name_or_path + if hasattr(processor, "feature_extractor") and "audio" in tokenizer_config: + if ( + "sampling_rate" in tokenizer_config["audio"] + and tokenizer_config["audio"]["sampling_rate"] + != processor.feature_extractor.sampling_rate + ): + new_sampling_rate = tokenizer_config["audio"]["sampling_rate"] + warnings.warn( + f"Overriding audio sampling rate from {processor.feature_extractor.sampling_rate} to {new_sampling_rate}" + ) + processor.feature_extractor.sampling_rate = new_sampling_rate + if hasattr(processor, "video_processor") and "video" in tokenizer_config: + if ( + "fps" in tokenizer_config["video"] + and tokenizer_config["video"]["fps"] != processor.video_processor.fps + ): + # override the video loading fps + new_fps = tokenizer_config["video"]["fps"] + warnings.warn( + f"Overriding video fps from {processor.video_processor.fps} to {new_fps}" + ) + processor.video_processor.fps = new_fps + # fps and num_frames cannot co-exist, but let it crash later + if ( + "num_frames" in tokenizer_config["video"] + and tokenizer_config["video"]["num_frames"] + != processor.video_processor.num_frames + ): + new_num_frames = tokenizer_config["video"]["num_frames"] + warnings.warn( + f"Overriding video num_frames from {processor.video_processor.num_frames} to {new_num_frames}" + ) + processor.video_processor.num_frames = new_num_frames return tokenizer if processor is None else processor diff --git a/nemo_rl/data/datasets/processed_dataset.py b/nemo_rl/data/datasets/processed_dataset.py index add422e199..1971e7a12f 100644 --- a/nemo_rl/data/datasets/processed_dataset.py +++ b/nemo_rl/data/datasets/processed_dataset.py @@ -21,6 +21,7 @@ from nemo_rl.data.datasets.utils import assert_no_double_bos from nemo_rl.data.interfaces import ( DatumSpec, + TaskDataPreProcessFnCallable, TaskDataProcessFnCallable, TaskDataSpec, ) @@ -52,6 +53,9 @@ def __init__( dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] | TaskDataProcessFnCallable ), + task_data_preprocessors: Optional[ + Union[dict[str, TaskDataPreProcessFnCallable], TaskDataPreProcessFnCallable] + ] = None, max_seq_length: Optional[int] = None, ): self.dataset = dataset @@ -59,6 +63,7 @@ def __init__( # TODO @yukih: will be removed once eval datasets are adapted self.default_task_data_spec = default_task_data_spec self.task_data_processors = task_data_processors + self.task_data_preprocessors = task_data_preprocessors self.max_seq_length = max_seq_length self._bos_checked = False @@ -95,6 +100,20 @@ def __getitem__(self, idx: int) -> DatumSpec: """Return a single prompt.""" entry = self.dataset[idx] + # preprocessing + task_data_preprocessor = None + if self.task_data_preprocessors: + if isinstance(self.task_data_preprocessors, dict): + task_name = entry["task_name"] + if task_name in self.task_data_preprocessors: + task_data_preprocessor = self.task_data_preprocessors[task_name] + else: + task_data_preprocessor = self.task_data_preprocessors + + if task_data_preprocessor is not None: + entry = task_data_preprocessor(entry) + + # processing if isinstance(self.task_data_processors, dict): task_name = entry["task_name"] diff --git a/nemo_rl/data/datasets/raw_dataset.py b/nemo_rl/data/datasets/raw_dataset.py index decd722736..f425cef8d3 100644 --- a/nemo_rl/data/datasets/raw_dataset.py +++ b/nemo_rl/data/datasets/raw_dataset.py @@ -15,7 +15,11 @@ from datasets import Dataset from nemo_rl.data import PreferenceDatasetConfig, ResponseDatasetConfig -from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec +from nemo_rl.data.interfaces import ( + TaskDataPreProcessFnCallable, + TaskDataProcessFnCallable, + TaskDataSpec, +) from nemo_rl.data.processors import PROCESSOR_REGISTRY @@ -27,6 +31,7 @@ class RawDataset: val_dataset: Dataset | None processor: TaskDataProcessFnCallable task_spec: TaskDataSpec + preprocessor: TaskDataPreProcessFnCallable | None = None def split_train_validation(self, test_size: float, seed: int): if test_size > 0: diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 961b7b9ba8..eb48bb5204 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -15,11 +15,15 @@ from nemo_rl.data import ResponseDatasetConfig from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset +from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset from nemo_rl.data.datasets.response_datasets.dapo_math import ( DAPOMath17KDataset, DAPOMathAIME2024Dataset, ) from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset +from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import ( + GeneralConversationsJsonlDataset, +) from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset @@ -39,6 +43,8 @@ # built-in datasets "AIME2024": AIME2024Dataset, "clevr-cogent": CLEVRCoGenTDataset, + "daily-omni": DailyOmniDataset, + "general-conversation-jsonl": GeneralConversationsJsonlDataset, "DAPOMath17K": DAPOMath17KDataset, "DAPOMathAIME2024": DAPOMathAIME2024Dataset, "DeepScaler": DeepScalerDataset, @@ -84,6 +90,8 @@ def load_response_dataset(data_config: ResponseDatasetConfig): __all__ = [ "AIME2024Dataset", "CLEVRCoGenTDataset", + "DailyOmniDataset", + "GeneralConversationsJsonlDataset", "DAPOMath17KDataset", "DAPOMathAIME2024Dataset", "DeepScalerDataset", diff --git a/nemo_rl/data/datasets/response_datasets/daily_omni.py b/nemo_rl/data/datasets/response_datasets/daily_omni.py new file mode 100644 index 0000000000..b2307e337f --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/daily_omni.py @@ -0,0 +1,140 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any + +from huggingface_hub import snapshot_download + +from nemo_rl.data.datasets.raw_dataset import RawDataset +from nemo_rl.data.datasets.utils import ( + get_huggingface_cache_path, + load_dataset_from_path, +) + + +class DailyOmniDataset(RawDataset): + """Simple wrapper around the Daily-Omni dataset. + + Args: + split: Split name for the dataset, default is "train" + """ + + task_name = "daily-omni" + + def __init__( + self, + split: str = "train", + split_validation_size: float = 0, + seed: int = 42, + **kwargs, + ): + # train, valA, and valB are supported splits. + SPLIT_TO_HF_NAME = { + "train": "liarliar/Daily-Omni", + } + if split not in SPLIT_TO_HF_NAME: + raise ValueError(f"Invalid split: {split}. Please use 'train'.") + + self.hf_cache_dir = get_huggingface_cache_path(SPLIT_TO_HF_NAME[split]) + if not self.hf_cache_dir: + # download the dataset + self.hf_cache_dir = snapshot_download( + repo_id=SPLIT_TO_HF_NAME[split], repo_type="dataset" + ) + if not self.hf_cache_dir: + raise ValueError("Cannot download DailyOmniDataset.") + + json_file = os.path.join(self.hf_cache_dir, "qa.json") + + if not os.path.isfile(json_file): + raise ValueError(f"{json_file} cannot be found.") + + files_folder = os.path.join(self.hf_cache_dir, "Videos") + if not os.path.isdir(files_folder): + # prepare the dataset + # TODO: move untar, unzip func to utils? + import tarfile + + archive_filename = os.path.join(self.hf_cache_dir, "Videos.tar") + if not os.path.isfile(archive_filename): + raise ValueError(f"{archive_filename} cannot be found.") + try: + with tarfile.open(archive_filename, "r:*") as tar: + # Extract all contents to the specified path + tar.extractall(path=self.hf_cache_dir) + if os.path.isdir(files_folder): + print( + f"Successfully extracted '{archive_filename}' to '{files_folder}'" + ) + else: + raise ValueError( + f"Cannot find the extracted folder {files_folder}. Extraction failed." + ) + except tarfile.ReadError: + raise tarfile.ReadError( + "Error: Could not read the tar file. It might be corrupted or not a tar file." + ) + except Exception as e: + raise Exception(f"An unexpected error occurred: {e}") + + self.dataset = load_dataset_from_path(json_file) + + # format - disable features to avoid schema conflicts + self.dataset = self.dataset.add_column( + "task_name", [self.task_name] * len(self.dataset) + ) + + self.preprocessor = self.format_data + + # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation + self.val_dataset = None + self.split_train_validation(split_validation_size, seed) + + @classmethod + def get_prompt(cls, data: dict[str, Any]) -> str: + # WARNING: model could have preference of a different prompt + prompt = data["Question"] + "\n" + "\n".join(data["Choice"]) + candidate_answers = [chr(ord("A") + idx) for idx in range(len(data["Choice"]))] + candidate_answers_all_but_last = ",".join(candidate_answers[:-1]) + prompt += ( + "\n" + + "Your replies must contain only a single letter " + + f"(either {candidate_answers_all_but_last} or {candidate_answers[-1]})." + ) + return prompt + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + user_content = [ + { + "type": "video", + "video": os.path.join( + self.hf_cache_dir, + "Videos", + data["video_id"], + data["video_id"] + "_video.mp4", + ), + }, + { + "type": "text", + "text": self.get_prompt(data), + }, + ] + return { + "messages": [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": data["Answer"]}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/datasets/response_datasets/general_conversations_dataset.py b/nemo_rl/data/datasets/response_datasets/general_conversations_dataset.py new file mode 100644 index 0000000000..10651c8490 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/general_conversations_dataset.py @@ -0,0 +1,268 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import warnings +from collections import defaultdict +from functools import partial +from typing import Any, Callable, Dict, Optional + +from nemo_rl.data import multimodal_utils +from nemo_rl.data.datasets.raw_dataset import RawDataset +from nemo_rl.data.datasets.utils import load_dataset_from_path + +# map the senders from the sample to the allowed ones +conversation_sender_mapping_sample_to_allowed = { + "human": "user", + "gpt": "assistant", + "agent": "assistant", +} + + +# convert +def convert_metadata(metadata: Dict[str, Any]): + data = metadata.copy() + + for tag in multimodal_utils.MEDIA_TAGS_TO_ALLOWED: + if tag in data: + tag_mapped = multimodal_utils.MEDIA_TAGS_TO_ALLOWED[tag] + if tag_mapped not in data: + data[tag_mapped] = data[tag] + del data[tag] + else: + warnings.warn( + f"Trying to map {tag} to {tag_mapped}, but {tag_mapped} already exists in the raw data. Mapping is not carried out." + ) + + for idx, message in enumerate(data["conversations"]): + msg_str = message["value"] + for tag in multimodal_utils.MEDIA_TAGS_TO_ALLOWED: + tag_str = "<" + tag + ">" + if tag_str in msg_str: + tag_str_mapped = multimodal_utils.MEDIA_TAGS[ + multimodal_utils.MEDIA_TAGS_TO_ALLOWED[tag] + ] + msg_str = msg_str.replace(tag_str, tag_str_mapped) + message["value"] = msg_str + data["conversations"][idx] = message + + return data + + +def conversation_process_message( + metadata: Dict[str, Any], + message: Dict[str, str], + media_index: dict, + raw: Optional[Dict[str, Any]] = None, + allow_empty_text: bool = False, + check_if_media_file_exist: bool = True, + tried_default_extensions: Optional[set] = None, + process_message_fragment: Callable = lambda tag, fragment: [{tag: fragment}], +) -> list[Dict[str, Any]]: + """Convert one conversation message from a string to a list of dictionaries representing media or text. + + Args: + raw: dictionary with all webdataset compliant keys of a sample. + Emtpy for jsonl dataset, non-empty otherwise. + metadata: + """ + if raw is None: + raw = {} + if tried_default_extensions is None: + tried_default_extensions = set() + fragments = [] + parts = re.split(multimodal_utils.MEDIA_TAG_PATTERN, message["value"]) + + # Convert the parts to message fragments + empty_text = True + for i, part in enumerate(parts): + if part in multimodal_utils.MEDIA_TAGS.values(): + # process multimodal tags + tag = multimodal_utils.MEDIA_TAGS_REVERSED[part] + if tag not in metadata: + raise ValueError( + f"{part} is found in the message, but no corresponding {tag} key can be found in {metadata}" + ) + if not isinstance(metadata[tag], list): + metadata[tag] = [metadata[tag]] + # try to extract the media object from the shard + basename = os.path.basename(metadata[tag][media_index[tag]]) + ext = basename.split(".", 1)[1] if "." in basename else "" + if ( + raw + and ext not in raw + and ext not in tried_default_extensions + and tag in multimodal_utils.DEFAULT_MEDIA_EXTENSIONS + ): + # try the default extension + for ext in multimodal_utils.DEFAULT_MEDIA_EXTENSIONS[tag]: + if ext in raw: + tried_default_extensions.add(ext) + break + media_file = None + if ext in raw: + media_file = ext + elif isinstance(metadata[tag][media_index[tag]], str) and os.path.isfile( + metadata[tag][media_index[tag]] + ): + # if cannot get it from the shard files, try to find the local file + media_file = metadata[tag][media_index[tag]] + elif check_if_media_file_exist: + sample_to_print = raw if raw else metadata + raise ValueError( + f"Cannot find the media file {metadata[tag][media_index[tag]]} from {sample_to_print} or locally." + ) + else: + media_file = metadata[tag][media_index[tag]] + media_index[tag] += 1 + fragments += process_message_fragment(tag, media_file) + else: + # process text + if part.strip(): + fragments += process_message_fragment("text", part) + empty_text = False + + if not allow_empty_text and empty_text: + fragments += process_message_fragment("text", " ") + + return fragments + + +class GeneralConversationsJsonlDataset(RawDataset): + """Loads general conversation datasets that have the json (manifest) files and media files in separate files (jsonl datasets). + + Each sample can be single/multi-turn conversations with multiple modalities. + Each modality can have one or more number of media objects. + There is no requirement of where the media tag (e.g. '') should appear in the conversations. + + The structure of the jsonl files could be like this. + + Example media filenames:: + + sample_000001.2345ew.flac + sample_000001.35tags.mp4 + sample_000001.as23ds.jpg + sample_000001.gd1dtg.wav + sample_000001.gds233.jpg + sample_000002.asf234.wav + ... + + Example JSON structure:: + + { + "sound": ["sample_000001.2345ew.flac", "sample_000001.gd1dtg.wav"], + "video": "sample_000001.35tags.mp4", + "image": ["sample_000001.as23ds.jpg", "sample_000001.gds233.jpg"], + "conversations": [ + { + "from": "user", + "value": "" + }, + { + "from": "assistant", + "value": "Automatic speech recognition is a technology that allows computers to recognize and transcribe spoken language. In the NeMo Framework, ASR is used for tasks such as speech-to-text and voice recognition." + }, + { + "from": "user", + "value": "Describe what is NeMo based on the tutorial video: