diff --git a/src/maxtext/input_pipeline/tfds_data_processing.py b/src/maxtext/input_pipeline/tfds_data_processing.py index 7795e621f0..e62afd2e95 100644 --- a/src/maxtext/input_pipeline/tfds_data_processing.py +++ b/src/maxtext/input_pipeline/tfds_data_processing.py @@ -305,6 +305,7 @@ def make_tfds_eval_iterator( use_dpo=config.use_dpo, hf_access_token=config.hf_access_token, ) + global_shape = (config.global_batch_size_to_load_eval, config.max_target_length) return multihost_dataloading.RemoteIteratorWrapper( - get_ds_fn, preprocessing_fn, config, global_mesh, checkpoint_path=config.checkpoint_dir + get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path=config.checkpoint_dir )