Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions examples/configs/sft_avlm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SFT config for audio-visual language-model training on the Daily-Omni dataset.
# Inherits everything from the 3B VLM SFT recipe and overrides only what differs.
defaults:
  - sft_vlm_3B.yaml

sft:
  val_batches: 2
  val_global_batch_size: 8

policy:
  # Long context to fit multi-frame video + audio tokens in one sequence.
  max_total_sequence_length: 32768
  train_global_batch_size: 8
  dtensor_cfg:
    tensor_parallel_size: 1
  dynamic_batching:
    enabled: true
  tokenizer:
    video:
      # Number of frames sampled per video clip by the processor.
      num_frames: 16

data:
  # dataset
  train:
    dataset_name: daily-omni
    split: train
    split_validation_size: 0.05 # use 5% of the training data as validation data
    seed: 42 # seed for train/validation split when split_validation_size > 0
  validation: null
  # default settings for all datasets
  default:
    prompt_file: null
12 changes: 12 additions & 0 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
print("\n▶ Setting up data...")
# setup train dataset
task_data_processors = {}
task_data_preprocessors = {}
data_list = []

if isinstance(data_config["train"], dict):
Expand All @@ -85,19 +86,23 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
add_generation_prompt=data_config["add_generation_prompt"],
)
task_data_processors[data.task_name] = (data.task_spec, data_processor)
if hasattr(data, "preprocessor") and data.preprocessor is not None:
task_data_preprocessors[data.task_name] = data.preprocessor

merged_data = concatenate_datasets([data.dataset for data in data_list])
dataset = AllTaskProcessedDataset(
merged_data,
tokenizer,
None,
task_data_processors,
task_data_preprocessors=task_data_preprocessors,
max_seq_length=data_config["max_input_seq_length"],
)
print(f" ✓ Training dataset loaded with {len(dataset)} samples.")

# setup validation dataset
val_task_data_processors = {}
val_task_data_preprocessors = {}
val_data_list = []

# validation dataset from train dataset (when train dataset's split_validation_size > 0)
Expand All @@ -107,6 +112,10 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
# bind task_name to task_data_processors
task_name = data.task_name
val_task_data_processors[task_name] = task_data_processors[task_name]
if task_name in task_data_preprocessors:
val_task_data_preprocessors[task_name] = task_data_preprocessors[
task_name
]

# validation dataset from config
if "validation" in data_config and data_config["validation"] is not None:
Expand All @@ -130,6 +139,8 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
val_data.task_spec,
val_data_processor,
)
if hasattr(val_data, "preprocessor") and val_data.preprocessor is not None:
val_task_data_preprocessors[val_data.task_name] = val_data.preprocessor

val_dataset = None
if len(val_data_list) > 0:
Expand All @@ -139,6 +150,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
tokenizer,
None,
val_task_data_processors,
task_data_preprocessors=val_task_data_preprocessors,
max_seq_length=data_config["max_input_seq_length"],
)
print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")
Expand Down
33 changes: 33 additions & 0 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,39 @@ def get_tokenizer(
processor.bos_token_id = tokenizer.bos_token_id
# copy name_or_path from tokenizer to processor for logging
processor.name_or_path = tokenizer.name_or_path
if hasattr(processor, "feature_extractor") and "audio" in tokenizer_config:
if (
"sampling_rate" in tokenizer_config["audio"]
and tokenizer_config["audio"]["sampling_rate"]
!= processor.feature_extractor.sampling_rate
):
new_sampling_rate = tokenizer_config["audio"]["sampling_rate"]
warnings.warn(
f"Overriding audio sampling rate from {processor.feature_extractor.sampling_rate} to {new_sampling_rate}"
)
processor.feature_extractor.sampling_rate = new_sampling_rate
if hasattr(processor, "video_processor") and "video" in tokenizer_config:
if (
"fps" in tokenizer_config["video"]
and tokenizer_config["video"]["fps"] != processor.video_processor.fps
):
# override the video loading fps
new_fps = tokenizer_config["video"]["fps"]
warnings.warn(
f"Overriding video fps from {processor.video_processor.fps} to {new_fps}"
)
processor.video_processor.fps = new_fps
# fps and num_frames cannot co-exist, but let it crash later
if (
"num_frames" in tokenizer_config["video"]
and tokenizer_config["video"]["num_frames"]
!= processor.video_processor.num_frames
):
new_num_frames = tokenizer_config["video"]["num_frames"]
warnings.warn(
f"Overriding video num_frames from {processor.video_processor.num_frames} to {new_num_frames}"
)
processor.video_processor.num_frames = new_num_frames

return tokenizer if processor is None else processor

Expand Down
19 changes: 19 additions & 0 deletions nemo_rl/data/datasets/processed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nemo_rl.data.datasets.utils import assert_no_double_bos
from nemo_rl.data.interfaces import (
DatumSpec,
TaskDataPreProcessFnCallable,
TaskDataProcessFnCallable,
TaskDataSpec,
)
Expand Down Expand Up @@ -52,13 +53,17 @@ def __init__(
dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]]
| TaskDataProcessFnCallable
),
task_data_preprocessors: Optional[
Union[dict[str, TaskDataPreProcessFnCallable], TaskDataPreProcessFnCallable]
] = None,
max_seq_length: Optional[int] = None,
):
self.dataset = dataset
self.tokenizer = tokenizer
# TODO @yukih: will be removed once eval datasets are adapted
self.default_task_data_spec = default_task_data_spec
self.task_data_processors = task_data_processors
self.task_data_preprocessors = task_data_preprocessors
self.max_seq_length = max_seq_length
self._bos_checked = False

Expand Down Expand Up @@ -95,6 +100,20 @@ def __getitem__(self, idx: int) -> DatumSpec:
"""Return a single prompt."""
entry = self.dataset[idx]

# preprocessing
task_data_preprocessor = None
if self.task_data_preprocessors:
if isinstance(self.task_data_preprocessors, dict):
task_name = entry["task_name"]
if task_name in self.task_data_preprocessors:
task_data_preprocessor = self.task_data_preprocessors[task_name]
else:
task_data_preprocessor = self.task_data_preprocessors

if task_data_preprocessor is not None:
entry = task_data_preprocessor(entry)

# processing
if isinstance(self.task_data_processors, dict):
task_name = entry["task_name"]

Expand Down
7 changes: 6 additions & 1 deletion nemo_rl/data/datasets/raw_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from datasets import Dataset

from nemo_rl.data import PreferenceDatasetConfig, ResponseDatasetConfig
from nemo_rl.data.interfaces import TaskDataProcessFnCallable, TaskDataSpec
from nemo_rl.data.interfaces import (
TaskDataPreProcessFnCallable,
TaskDataProcessFnCallable,
TaskDataSpec,
)
from nemo_rl.data.processors import PROCESSOR_REGISTRY


Expand All @@ -27,6 +31,7 @@ class RawDataset:
val_dataset: Dataset | None
processor: TaskDataProcessFnCallable
task_spec: TaskDataSpec
preprocessor: TaskDataPreProcessFnCallable | None = None

def split_train_validation(self, test_size: float, seed: int):
if test_size > 0:
Expand Down
8 changes: 8 additions & 0 deletions nemo_rl/data/datasets/response_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
from nemo_rl.data import ResponseDatasetConfig
from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset
from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset
from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset
from nemo_rl.data.datasets.response_datasets.dapo_math import (
DAPOMath17KDataset,
DAPOMathAIME2024Dataset,
)
from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset
from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import (
GeneralConversationsJsonlDataset,
)
from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset
from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset
from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset
Expand All @@ -39,6 +43,8 @@
# built-in datasets
"AIME2024": AIME2024Dataset,
"clevr-cogent": CLEVRCoGenTDataset,
"daily-omni": DailyOmniDataset,
"general-conversation-jsonl": GeneralConversationsJsonlDataset,
"DAPOMath17K": DAPOMath17KDataset,
"DAPOMathAIME2024": DAPOMathAIME2024Dataset,
"DeepScaler": DeepScalerDataset,
Expand Down Expand Up @@ -84,6 +90,8 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
__all__ = [
"AIME2024Dataset",
"CLEVRCoGenTDataset",
"DailyOmniDataset",
"GeneralConversationsJsonlDataset",
"DAPOMath17KDataset",
"DAPOMathAIME2024Dataset",
"DeepScalerDataset",
Expand Down
140 changes: 140 additions & 0 deletions nemo_rl/data/datasets/response_datasets/daily_omni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from huggingface_hub import snapshot_download

from nemo_rl.data.datasets.raw_dataset import RawDataset
from nemo_rl.data.datasets.utils import (
get_huggingface_cache_path,
load_dataset_from_path,
)


class DailyOmniDataset(RawDataset):
    """Wrapper around the Daily-Omni audio-visual QA dataset.

    Downloads the HuggingFace snapshot (if not already cached), extracts the
    bundled ``Videos.tar`` archive on first use, and exposes the QA records
    together with a ``format_data`` preprocessor that converts each record
    into a chat-style message list with a video attachment.

    Args:
        split: Split name; only "train" is currently supported.
        split_validation_size: Fraction of the training data held out as
            validation data (0 disables the split).
        seed: RNG seed for the train/validation split.
        **kwargs: Ignored; accepted for config-driven construction.
    """

    task_name = "daily-omni"

    def __init__(
        self,
        split: str = "train",
        split_validation_size: float = 0,
        seed: int = 42,
        **kwargs,
    ):
        # Only the "train" split is published on the Hub for this dataset.
        SPLIT_TO_HF_NAME = {
            "train": "liarliar/Daily-Omni",
        }
        if split not in SPLIT_TO_HF_NAME:
            raise ValueError(f"Invalid split: {split}. Please use 'train'.")

        # Reuse an existing HF cache entry when available; otherwise download.
        self.hf_cache_dir = get_huggingface_cache_path(SPLIT_TO_HF_NAME[split])
        if not self.hf_cache_dir:
            self.hf_cache_dir = snapshot_download(
                repo_id=SPLIT_TO_HF_NAME[split], repo_type="dataset"
            )
            if not self.hf_cache_dir:
                raise ValueError("Cannot download DailyOmniDataset.")

        json_file = os.path.join(self.hf_cache_dir, "qa.json")
        if not os.path.isfile(json_file):
            raise ValueError(f"{json_file} cannot be found.")

        # Extract the video archive once; later runs find the folder in place.
        files_folder = os.path.join(self.hf_cache_dir, "Videos")
        if not os.path.isdir(files_folder):
            self._extract_videos(files_folder)

        self.dataset = load_dataset_from_path(json_file)

        # Tag every row so AllTaskProcessedDataset can route it to the right
        # task processor / preprocessor.
        self.dataset = self.dataset.add_column(
            "task_name", [self.task_name] * len(self.dataset)
        )

        # Bound method consumed as the per-row preprocessor by the
        # processed-dataset pipeline.
        self.preprocessor = self.format_data

        # `self.val_dataset` is set (non-None) only when this dataset is used
        # for both training and validation.
        self.val_dataset = None
        self.split_train_validation(split_validation_size, seed)

    def _extract_videos(self, files_folder: str) -> None:
        """Extract ``Videos.tar`` from the snapshot into the cache directory.

        Args:
            files_folder: Expected output directory (``<cache>/Videos``).

        Raises:
            ValueError: If the archive is missing or extraction produced no
                output folder.
            tarfile.ReadError: If the archive is corrupt or not a tar file.
        """
        # TODO: move untar/unzip helpers to utils?
        import tarfile

        archive_filename = os.path.join(self.hf_cache_dir, "Videos.tar")
        if not os.path.isfile(archive_filename):
            raise ValueError(f"{archive_filename} cannot be found.")
        try:
            with tarfile.open(archive_filename, "r:*") as tar:
                # The "data" filter rejects path traversal and special
                # members (CVE-2007-4559 hardening); fall back for Python
                # builds that predate the `filter` backport.
                try:
                    tar.extractall(path=self.hf_cache_dir, filter="data")
                except TypeError:
                    tar.extractall(path=self.hf_cache_dir)
        except tarfile.ReadError as e:
            # Chain the original error instead of discarding the traceback.
            raise tarfile.ReadError(
                "Error: Could not read the tar file. It might be corrupted or not a tar file."
            ) from e
        if os.path.isdir(files_folder):
            print(f"Successfully extracted '{archive_filename}' to '{files_folder}'")
        else:
            raise ValueError(
                f"Cannot find the extracted folder {files_folder}. Extraction failed."
            )

    @classmethod
    def get_prompt(cls, data: dict[str, Any]) -> str:
        """Build the multiple-choice question prompt for one QA record.

        Args:
            data: Raw QA record with "Question" (str) and "Choice"
                (list[str] of answer options).

        Returns:
            The question, the listed choices, and an instruction to reply
            with a single letter (A, B, ...).
        """
        # WARNING: model could have preference of a different prompt
        prompt = data["Question"] + "\n" + "\n".join(data["Choice"])
        candidate_answers = [chr(ord("A") + idx) for idx in range(len(data["Choice"]))]
        if len(candidate_answers) == 1:
            # Degenerate single-choice record: avoid the "either  or A" wording.
            letter_clause = candidate_answers[0]
        else:
            letter_clause = (
                ",".join(candidate_answers[:-1]) + f" or {candidate_answers[-1]}"
            )
        prompt += (
            "\n"
            + "Your replies must contain only a single letter "
            + f"(either {letter_clause})."
        )
        return prompt

    def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
        """Convert one raw QA record into the chat-message format.

        Args:
            data: Raw record with "video_id", "Question", "Choice", "Answer".

        Returns:
            Dict with chat-style "messages" (user turn = video path + text
            prompt, assistant turn = gold answer letter) and "task_name".
        """
        user_content = [
            {
                "type": "video",
                "video": os.path.join(
                    self.hf_cache_dir,
                    "Videos",
                    data["video_id"],
                    data["video_id"] + "_video.mp4",
                ),
            },
            {
                "type": "text",
                "text": self.get_prompt(data),
            },
        ]
        return {
            "messages": [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": data["Answer"]},
            ],
            "task_name": self.task_name,
        }
Loading
Loading