2 changes: 1 addition & 1 deletion src/maxtext/input_pipeline/data_processing_utils.py
@@ -29,9 +29,9 @@ def parse_and_keep_features(dataset, config, data_columns, tokenize):
  """Parse arrayrecord features or keep specified columns for other formats."""
  if config.grain_file_type in ("arrayrecord", "tfrecord"):
    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  return dataset
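
For orientation, a hedged sketch (sample values invented for illustration, not from this PR) of the element shapes each branch hands to NormalizeFeatures:

```python
import numpy as np

# ArrayRecord/TFRecord branch: ParseFeatures emits single-element object arrays per column.
parsed_element = {"text": np.asarray([b"hello world"], dtype=object)}

# Other formats (e.g. Parquet): KeepFeatures keeps the configured columns as plain scalars.
kept_element = {"text": "hello world"}

# NormalizeFeatures (see input_pipeline_utils.py below) maps both to the same form:
# {"text": "hello world"}
```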


84 changes: 78 additions & 6 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
@@ -89,6 +89,34 @@ def _process_string(string_tensor):
  return features


def validate_tfds_data_types(dataset, data_keys, tokenize):
  """Validate that TFDS dataset types match the tokenization configuration."""
  import tensorflow as tf  # pylint: disable=import-outside-toplevel

  spec = dataset.element_spec
  for k in data_keys:
    if k not in spec:
      continue

    # For TFDS, strings/bytes are usually tf.string
    is_string = spec[k].dtype == tf.string

    if tokenize and not is_string:
      raise ValueError(
          f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is enabled. "
          "This often happens if the dataset is already tokenized (contains integers) "
          "but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
          "If your data is already tokenized, please set it to False."
      )
    if not tokenize and is_string:
      raise ValueError(
          f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is disabled. "
          "This often happens if the dataset is NOT tokenized (contains strings) "
          "but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
          "If your data is not tokenized, please set it to True."
      )
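
A quick sanity check of the new validator; a minimal sketch assuming a local TensorFlow install, with `tf.data` datasets standing in for real TFDS splits:

```python
import numpy as np
import tensorflow as tf

# Stand-ins for TFDS splits: one raw-text dataset, one pre-tokenized dataset.
text_ds = tf.data.Dataset.from_tensor_slices({"text": ["hello world"]})
token_ds = tf.data.Dataset.from_tensor_slices({"text": np.array([[1, 2, 3]], dtype=np.int64)})

validate_tfds_data_types(text_ds, ["text"], tokenize=True)    # passes: tf.string + tokenize
validate_tfds_data_types(token_ds, ["text"], tokenize=False)  # passes: int64, no tokenize

try:
  validate_tfds_data_types(token_ds, ["text"], tokenize=True)  # dtype/config mismatch
except ValueError as e:
  print(e)  # suggests setting tokenize_train_data=False
```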


########## Functions used by HF pipeline


@@ -544,7 +572,7 @@ def make_tfrecord_iter_dataset(path: str):

@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
"""Parse serialized example"""
"""Parse serialized example proto into a dictionary of arrays."""

def __init__(self, data_columns, tokenize):
self.data_columns = data_columns
@@ -575,17 +603,61 @@ def map(self, element):

@dataclasses.dataclass
class NormalizeFeatures(grain.MapTransform):
  """Normalize text feature keys."""
  """Normalize features across input formats and validate them against the tokenize setting."""

  def __init__(self, column_names, tokenize):
    self.column_names = column_names
    self.tokenize = tokenize

  def map(self, element):
    if self.tokenize:
      return {col: element[col][0].decode() for col in self.column_names}
    else:
      return {col: element[col] for col in self.column_names}
    """Selects configured columns, normalizes types, and validates tokenize/data consistency."""
    res = {}
    for col in self.column_names:
      val = element[col]
      if self.tokenize:
        # ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
        # An empty array means the proto stored data in int64_list (already tokenized), a user config error.
        if isinstance(val, (list, np.ndarray)):
          if len(val) != 1:
            raise ValueError(
                f"Expected a single-element string/bytes list for column '{col}' because tokenization is enabled, "
                f"but got a list of length {len(val)}. An empty list often means the dataset is already tokenized "
                "(contains integers) while tokenization is enabled in the configuration. "
                "If your data is already tokenized, please set 'tokenize_train_data=False' "
                "(or 'tokenize_eval_data=False')."
            )
          val = val[0]  # unwrap the single-element array from ParseFeatures

        # ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
        if isinstance(val, bytes):
          val = val.decode("utf-8")

        # Parquet: string columns arrive as scalar str (no unwrapping needed).
        # Any other type indicates a misconfiguration (e.g. already-tokenized integers).
        if not isinstance(val, str):
          raise ValueError(
              f"Expected string or bytes for column '{col}' but got type {type(val)}. "
              "If your data is already tokenized, please set 'tokenize_train_data=False' "
              "(or 'tokenize_eval_data=False') in the configuration."
          )
        res[col] = val
      else:
        # Parquet: a text column arriving as scalar str/bytes means the data was not pre-tokenized.
        if isinstance(val, (str, bytes)):
          raise ValueError(
              f"Expected tokenized integers for column '{col}' because tokenization is disabled, "
              "but got a string. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
              "(or 'tokenize_eval_data=True') in the configuration."
          )
        # ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
        # was stored in bytes_list (not tokenized), i.e. tokenization should have been enabled.
        if isinstance(val, (list, np.ndarray)) and len(val) == 0:
          raise ValueError(
              f"Column '{col}' is empty. This often happens if the dataset contains strings "
              "but tokenization is disabled (looking for integers)."
          )
        res[col] = val
    return res
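
To exercise the normalizer's contract end to end, a hedged sketch (sample data invented for illustration) feeding it the element shapes described in the comments above:

```python
import numpy as np

# ArrayRecord/TFRecord path: ParseFeatures output is a single-element object array of bytes.
norm = NormalizeFeatures(column_names=["text"], tokenize=True)
print(norm.map({"text": np.asarray([b"hello world"], dtype=object)}))  # {'text': 'hello world'}

# Parquet path: a scalar str passes through the same branch without unwrapping.
print(norm.map({"text": "hello world"}))  # {'text': 'hello world'}

# Pre-tokenized data with tokenize=False is passed through unchanged.
norm_ids = NormalizeFeatures(column_names=["ids"], tokenize=False)
print(norm_ids.map({"ids": np.array([1, 2, 3], dtype=np.int64)}))  # {'ids': array([1, 2, 3])}

# Mismatch: raw text while tokenization is disabled raises the guidance error above.
try:
  norm_ids.map({"ids": "hello world"})
except ValueError as e:
  print(e)
```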


@dataclasses.dataclass
2 changes: 2 additions & 0 deletions src/maxtext/input_pipeline/tfds_data_processing.py
@@ -102,6 +102,8 @@ def preprocessing_pipeline(
        "Please set train_data_columns or eval_data_columns accordingly."
    )

  input_pipeline_utils.validate_tfds_data_types(dataset, data_column_names, tokenize)

  if not use_dpo:
    assert len(data_column_names) == 1
    dataset = dataset.map(