diff --git a/src/maxtext/input_pipeline/data_processing_utils.py b/src/maxtext/input_pipeline/data_processing_utils.py
index e336e814cb..95ccbae4e3 100644
--- a/src/maxtext/input_pipeline/data_processing_utils.py
+++ b/src/maxtext/input_pipeline/data_processing_utils.py
@@ -29,9 +29,9 @@ def parse_and_keep_features(dataset, config, data_columns, tokenize):
   """Parse arrayrecord features or keep specified columns for other formats."""
   if config.grain_file_type in ("arrayrecord", "tfrecord"):
     dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
-    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
   else:
     dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
+  dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
   return dataset
 
 
diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py
index 1b9910f433..ee989bb24e 100644
--- a/src/maxtext/input_pipeline/input_pipeline_utils.py
+++ b/src/maxtext/input_pipeline/input_pipeline_utils.py
@@ -89,6 +89,34 @@ def _process_string(string_tensor):
   return features
 
 
+def validate_tfds_data_types(dataset, data_keys, tokenize):
+  """Validate that TFDS dataset types match the tokenization configuration."""
+  import tensorflow as tf  # pylint: disable=import-outside-toplevel
+
+  spec = dataset.element_spec
+  for k in data_keys:
+    if k not in spec:
+      continue
+
+    # For TFDS, string/bytes columns are exposed as tf.string.
+    is_string = spec[k].dtype == tf.string
+
+    if tokenize and not is_string:
+      raise ValueError(
+          f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is enabled. "
+          "This often happens if the dataset is already tokenized (contains integers) "
+          "but 'tokenize_train_data=True' (or 'tokenize_eval_data=True') is set in the configuration. "
+          "If your data is already tokenized, please set it to False."
+      )
+    if not tokenize and is_string:
+      raise ValueError(
+          f"TFDS dataset column '{k}' has type {spec[k].dtype}, but tokenization is disabled. "
+          "This often happens if the dataset is NOT tokenized (contains strings) "
+          "but 'tokenize_train_data=False' (or 'tokenize_eval_data=False') is set in the configuration. "
+          "If your data is not tokenized, please set it to True."
+      )
+
+
 ########## Functions used by HF pipeline
 
 
@@ -544,7 +572,7 @@ def make_tfrecord_iter_dataset(path: str):
 
 @dataclasses.dataclass
 class ParseFeatures(grain.MapTransform):
-  """Parse serialized example"""
+  """Parse serialized example proto into a dictionary of arrays."""
 
   def __init__(self, data_columns, tokenize):
     self.data_columns = data_columns
@@ -575,17 +603,61 @@ def map(self, element):
 
 @dataclasses.dataclass
 class NormalizeFeatures(grain.MapTransform):
-  """Normalize text feature keys."""
+  """Select configured data columns, normalize their values, and validate them against the tokenize setting."""
 
   def __init__(self, column_names, tokenize):
     self.column_names = column_names
     self.tokenize = tokenize
 
   def map(self, element):
-    if self.tokenize:
-      return {col: element[col][0].decode() for col in self.column_names}
-    else:
-      return {col: element[col] for col in self.column_names}
+    """Select configured columns, normalize types, and validate tokenize/data consistency."""
+    res = {}
+    for col in self.column_names:
+      val = element[col]
+      if self.tokenize:
+        # ArrayRecord/TFRecord: ParseFeatures wraps bytes_list.value in np.ndarray(dtype=object).
+        # An empty array means the proto stored data in int64_list (already tokenized), which is a config error.
+        if isinstance(val, (list, np.ndarray)):
+          if len(val) != 1:
+            raise ValueError(
+                f"Expected a single-element string/bytes list for column '{col}' because tokenization is enabled, "
+                f"but got a list of length {len(val)}. This often happens if the dataset is already tokenized "
+                "(contains integers) but tokenization is enabled in the configuration. "
+                "If your data is already tokenized, please set 'tokenize_train_data=False' "
+                "(or 'tokenize_eval_data=False')."
+            )
+          val = val[0]  # unwrap the single-element array from ParseFeatures
+
+        # ArrayRecord/TFRecord: proto bytes_list values are Python bytes after unwrapping above.
+        if isinstance(val, bytes):
+          val = val.decode("utf-8")
+
+        # Parquet: string columns arrive as scalar str (no unwrapping needed).
+        # Any other type indicates a misconfiguration (e.g. already-tokenized integers).
+        if not isinstance(val, str):
+          raise ValueError(
+              f"Expected string or bytes for column '{col}' but got type {type(val)}. "
+              "If your data is already tokenized, please set 'tokenize_train_data=False' "
+              "(or 'tokenize_eval_data=False') in the configuration."
+          )
+        res[col] = val
+      else:
+        # Parquet: a text column arrives as a scalar str/bytes, meaning the data was not pre-tokenized.
+        if isinstance(val, (str, bytes)):
+          raise ValueError(
+              f"Expected tokenized integers for column '{col}' because tokenization is disabled, "
+              "but got a string. If your data is NOT tokenized, please set 'tokenize_train_data=True' "
+              "(or 'tokenize_eval_data=True') in the configuration."
+          )
+        # ArrayRecord/TFRecord: ParseFeatures reads from int64_list; an empty array means the data
+        # was stored in bytes_list (not tokenized), so tokenization should have been enabled.
+        if isinstance(val, (list, np.ndarray)) and len(val) == 0:
+          raise ValueError(
+              f"Column '{col}' is empty. This often happens if the dataset contains strings "
+              "but tokenization is disabled (looking for integers)."
+          )
+        res[col] = val
+    return res
 
 
 @dataclasses.dataclass
diff --git a/src/maxtext/input_pipeline/tfds_data_processing.py b/src/maxtext/input_pipeline/tfds_data_processing.py
index 7795e621f0..73c89c3a06 100644
--- a/src/maxtext/input_pipeline/tfds_data_processing.py
+++ b/src/maxtext/input_pipeline/tfds_data_processing.py
@@ -102,6 +102,8 @@ def preprocessing_pipeline(
         "Please set train_data_columns or eval_data_columns accordingly."
     )
 
+  input_pipeline_utils.validate_tfds_data_types(dataset, data_column_names, tokenize)
+
   if not use_dpo:
     assert len(data_column_names) == 1
     dataset = dataset.map(
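
Illustrative usage sketch (not part of the patch): the snippet below exercises the relocated NormalizeFeatures transform directly, showing the normal path for ParseFeatures-style output and the new fail-fast error for already-tokenized data. The module path and the "text" column name are assumptions for illustration, taken from the src/ layout in the diff headers.

import numpy as np

# Assumed import path, based on the src/ layout shown above.
from maxtext.input_pipeline.input_pipeline_utils import NormalizeFeatures

norm = NormalizeFeatures(column_names=["text"], tokenize=True)

# ArrayRecord/TFRecord path: ParseFeatures yields a single-element object array of bytes,
# which map() unwraps and decodes to str.
print(norm.map({"text": np.array([b"hello world"], dtype=object)}))  # {'text': 'hello world'}

# Already-tokenized data with tokenize=True now fails fast with a descriptive ValueError
# instead of an opaque error from the old element[col][0].decode() path.
try:
  norm.map({"text": np.array([101, 2023, 102], dtype=np.int64)})
except ValueError as err:
  print(err)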