diff --git a/ChartQA_results.csv b/ChartQA_results.csv
new file mode 100644
index 0000000000..be80a514f5
--- /dev/null
+++ b/ChartQA_results.csv
@@ -0,0 +1,6 @@
+question ID,question,label,output,is_correct
+1,How many food item is shown in the bar graph?,14,12,False
+2,What is the difference in value between Lamb and Corn?,0.57,9.07,False
+3,How many bars are shown in the chart?,3,3,True
+4,Is the sum value of Madagascar more then Fiji?,No,No,True
+5,What's the value of the lowest bar?,23,23,True
diff --git a/benchmarks/multimodal/__init__.py b/benchmarks/multimodal/__init__.py
new file mode 100644
index 0000000000..4a62083b81
--- /dev/null
+++ b/benchmarks/multimodal/__init__.py
@@ -0,0 +1,15 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/benchmarks/multimodal/multimodal_eval.py b/benchmarks/multimodal/multimodal_eval.py
new file mode 100644
index 0000000000..d28ee1922a
--- /dev/null
+++ b/benchmarks/multimodal/multimodal_eval.py
@@ -0,0 +1,309 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A simple multimodal benchmark script for a trained checkpoint.
+
+Supported dataset: HuggingFaceM4/ChartQA (https://huggingface.co/datasets/HuggingFaceM4/ChartQA)
+
+Usage:
+# Gemma3-4b on a single TPU v4-8 VM
+python3 -m benchmarks.multimodal.multimodal_eval MaxText/configs/base.yml \
+  model_name=gemma3-4b tokenizer_path=assets/tokenizer.gemma3 \
+  load_parameters_path=gs://maxtext-model-checkpoints/gemma3-4b/multimodal/2025-05-21-23-23-59/checkpoints/0/items \
+  base_output_directory=$YOUR_GCS_PATH \
+  per_device_batch_size=1 run_name=mmeval_test steps=1 async_checkpointing=false \
+  scan_layers=false use_multimodal=true attention=\'dot_product\' \
+  max_prefill_predict_length=550 max_target_length=570 \
+  hf_data_dir=HuggingFaceM4/ChartQA hf_eval_split=test
+
+# Llama4-17b-16e on a TPU v5p-128 cluster (images resized to 336x336 for simplicity)
+python -m benchmarks.multimodal.multimodal_eval \
+  MaxText/configs/base.yml model_name=llama4-17b-16e image_resize=336 \
+  tokenizer_path=meta-llama/Llama-4-Scout-17B-16E \
+  load_parameters_path=gs://maxtext-model-checkpoints/llama4-17b-16e/hybrid/2025-07-22-11-03-20/0/items \
+  base_output_directory=$YOUR_GCS_PATH \
+  per_device_batch_size=1 run_name=mmeval_test steps=1 async_checkpointing=false \
+  scan_layers=true use_multimodal=true attention=\'dot_product\' \
+  max_prefill_predict_length=350 max_target_length=370 \
+  hf_data_dir=HuggingFaceM4/ChartQA hf_eval_split=test hf_access_token=\'$YOUR_HF_ACCESS_TOKEN\' \
+  ici_fsdp_parallelism=1 ici_expert_parallelism=16 ici_tensor_parallelism=4
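+
+The script writes one row per evaluated question to {dataset}_results.csv with
+columns (question ID, question, label, output, is_correct); see
+ChartQA_results.csv in this change for a sample. When base_output_directory is
+set, the CSV is also uploaded to GCS.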
+"""
+
+import re
+import sys
+from dataclasses import dataclass
+from datetime import datetime
+from typing import List, Optional
+
+from absl import flags
+from absl import logging as absl_logging
+import datasets
+import jax
+import numpy as np
+import pandas as pd
+from PIL import Image
+from tqdm import tqdm
+
+from maxtext.configs import pyconfig
+from maxtext.inference.maxengine import maxengine
+from maxtext.multimodal import processor as mm_processor
+from maxtext.multimodal import utils as mm_utils
+import maxtext.multimodal.processor_gemma3 as processor_gemma3
+from maxtext.utils import gcs_utils
+from maxtext.utils import max_logging
+from maxtext.utils import max_utils
+
+
+absl_logging.set_verbosity(absl_logging.INFO)  # for max_logging.log
+
+
+ASCII_UPPERCASE_A = ord("A")  # ASCII value for uppercase 'A'
+SUPPORTED_DATASETS = ["HuggingFaceM4/ChartQA"]
+
+DEFAULT_PROMPT_TEMPLATE = """You are an expert at answering questions based
+on provided charts. Your task is to extract the exact answer from the
+given context or determine that it's not present.
+{image_placeholder} Question: {question}
+For numerical answers, provide only the number.
+For text answers, provide only the exact text.
+For judgement questions, respond with "Yes" or "No".
+If not found, output "N/A".
+Your output must be only the exact answer within <answer></answer>, with no extra content.
+
+Example:
+Question: What is the capital of France?
+Your answer: <answer>Paris</answer>
+"""
+
+SFT_PROMPT_TEMPLATE = "{image_placeholder}{question}"
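+
+# For illustration: for question 3 in ChartQA_results.csv above, the template
+# fills in as "... {image_placeholder} Question: How many bars are shown in
+# the chart? ..." and a well-behaved model should reply with something like
+# "<answer>3</answer>". The model-specific image placeholder token and
+# user/model turn tags are added later by mm_processor.reformat_prompt.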
+
+
+@dataclass
+class ParsedDatasetExample:
+  """Parsed example from the HuggingFace dataset."""
+
+  question: Optional[str] = None
+  image_np: Optional[np.ndarray] = None
+  choices: Optional[List[str]] = None
+  answer: Optional[str] = None
+
+
+def parse_dataset_example(example, hf_dataset_name, config):
+  """Parse a single example from the HuggingFace dataset."""
+  parsed_example = ParsedDatasetExample()
+  if hf_dataset_name == "HuggingFaceM4/ChartQA":
+    parsed_example.question = example["query"]
+    parsed_example.image_np = np.asarray(example["image"].convert("RGB"))  # Convert PIL object to np array
+    parsed_example.answer = example["label"][0]
+  else:
+    raise ValueError(f"Unsupported dataset: {hf_dataset_name}")
+
+  # Resize the image if specified. This simplifies Llama4's tiling by giving us a fixed input size.
+  if config.image_resize != -1:
+    pil_img = Image.fromarray(parsed_example.image_np)
+    pil_img = pil_img.resize((config.image_resize, config.image_resize))
+    parsed_example.image_np = np.asarray(pil_img.convert("RGB"))
+
+  return parsed_example
+
+
+def construct_prompt(parsed_dataset_example: ParsedDatasetExample, config, system_message: Optional[str] = None):
+  """Construct the prompt from a parsed dataset example."""
+  # image_placeholder = mm_utils.get_image_placeholder(config.model_name) if config.use_multimodal else ""
+  image_placeholder = ""
+  choices_text = (
+      "\n".join(f"{chr(ASCII_UPPERCASE_A + idx)}. {choice}" for idx, choice in enumerate(parsed_dataset_example.choices))
+      if parsed_dataset_example.choices
+      else ""
+  )
+  # Prompt for raw pretrained checkpoints. Note: DEFAULT_PROMPT_TEMPLATE has no
+  # {choices} field, so the `choices` kwarg below is silently ignored by str.format.
+  prompt = DEFAULT_PROMPT_TEMPLATE.format(
+      image_placeholder=image_placeholder,
+      question=parsed_dataset_example.question,
+      choices=choices_text if choices_text else "N/A",
+  )
+  # # Prompt for SFT checkpoints, same as the original SFT prompt
+  # prompt = SFT_PROMPT_TEMPLATE.format(
+  #     image_placeholder=image_placeholder,
+  #     question=parsed_dataset_example.question
+  # )
+  # Add extra model-specific formatting such as user/model/assistant tags
+  prompt = mm_processor.reformat_prompt(prompt, image_placeholder, config.model_name, num_images=1 if config.use_multimodal else 0)
+  prompt = system_message + "\n\n" + prompt if system_message else prompt
+  return prompt
+
+
+def parse_answer(output_string):
+  """Extract the answer from the model's raw output, or return None."""
+  # Try to match the <answer>?</answer> template (e.g., <answer>Paris</answer> from any pretrained model)
+  match_xml = re.search(r"<answer>(.*?)</answer>", output_string, re.DOTALL)
+  if match_xml:
+    return match_xml.group(1).strip()
+
+  # If not found, try to match the ['?'] template (e.g., ['Paris'] from HuggingFaceM4/ChartQA SFT)
+  match_list = re.search(r"\['(.*?)'\]", output_string)
+  if match_list:
+    return match_list.group(1).strip()
+
+  # If neither template is found, return None
+  return None
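+
+
+# Quick reference for parse_answer behavior (illustrative inputs):
+#   parse_answer("Sure! <answer>42</answer>")  -> "42"
+#   parse_answer("['Paris']")                  -> "Paris"
+#   parse_answer("no recognizable template")   -> None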
+
+
+def main(config):
+  engine = maxengine.MaxEngine(config)
+  params = engine.load_params()
+
+  metadata = engine.get_tokenizer()
+  tokenizer = engine.build_tokenizer(metadata)
+
+  max_prefill_predict_length = getattr(config, "max_prefill_predict_length", 1024)
+  max_target_length = getattr(config, "max_target_length", 2048)
+
+  # Initialize counters for overall accuracy
+  correct_count = 0
+  total_count = 0
+
+  # Get the HuggingFace dataset path and name from the config
+  hf_data_dir = config.hf_data_dir
+  hf_data_name = hf_data_dir.split("/")[-1] if "/" in hf_data_dir else hf_data_dir
+  hf_eval_split = config.hf_eval_split
+
+  # Config for saving CSV results
+  timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+  results_file_name = f"{hf_data_name}_results.csv"
+  result_gcs_path = f"{config.base_output_directory}/{timestamp}.csv" if config.base_output_directory else None
+  max_logging.log(f"Results will be saved to {results_file_name} and uploaded to GCS: {result_gcs_path}")
+  results_data = []
+
+  test_ds = datasets.load_dataset(hf_data_dir, "default", split=hf_eval_split)
+  for idx, example in enumerate(tqdm(test_ds, desc=f"Evaluating {hf_data_dir} dataset")):
+    prefill_length = config.max_prefill_predict_length
+    parsed_dataset_example = parse_dataset_example(example, hf_data_dir, config)
+    prompt = construct_prompt(parsed_dataset_example, config)
+    processor_output = processor_gemma3.preprocess_mm_data_gemma3(parsed_dataset_example.image_np)
+    # Reserve room in the prefill window for the image's soft tokens
+    prefill_length -= mm_processor.get_image_offsets(config=config, processor_output=processor_output)
+    print("\n" + "*" * 50)  # Visual separator between examples
+
+    # Tokenize the input
+    tokens, true_length = tokenizer.encode(prompt, is_bos=True, prefill_lengths=[prefill_length])
+    if config.use_multimodal:
+      tokens = mm_processor.prepare_text_for_image_fusion(tokens=tokens, config=config, processor_output=processor_output)
+      image_offsets = mm_processor.get_image_offsets(config=config, processor_output=processor_output)
+      true_length += image_offsets
+      if true_length > max_prefill_predict_length:
+        max_logging.log(
+            f"Warning: Prompt length {true_length} exceeds max prefill length {max_prefill_predict_length}. Truncating."
+        )
+        tokens = tokens[:max_prefill_predict_length]
+        true_length = max_prefill_predict_length
+    assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in this script yet"
+    assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in this script yet"
+
+    # Perform prefill
+    prefill_result, first_token = engine.prefill(
+        params=params, padded_tokens=tokens, images=processor_output.pixel_values, true_length=true_length
+    )
+    slot = 0
+
+    # Initialize decode state
+    decode_state = engine.init_decode_state()
+    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
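+
+    # The loop below re-decodes the accumulated tokens each step and exits
+    # early once parse_answer finds a complete answer span. For example
+    # (hypothetical trace): once the decoded text contains
+    # "<answer>3</answer>", parse_answer returns "3" and we stop generating.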
+
+    steps = range(max_prefill_predict_length, max_target_length)
+    sampled_tokens = [first_token.get_result_at_slot(slot).tokens.item()]
+
+    output = ""  # Decoded text so far; initialized here in case the loop body never runs
+    predicted_answer = ""
+
+    for _ in steps:
+      # Decode the tokens generated so far
+      output = tokenizer.decode(sampled_tokens)
+      predicted_answer = parse_answer(output)
+      if predicted_answer:
+        break
+
+      # Generate the next token
+      decode_state, sampled_token = engine.generate(params, decode_state)
+      sampled_tokens.append(sampled_token.get_result_at_slot(slot).tokens.item())
+      if sampled_tokens[-1] == tokenizer.eos_id:
+        break
+
+    correct_answer = parsed_dataset_example.answer
+    # For open-ended answers, we do substring matching on the raw output.
+    # TODO(hengtaoguo): More robust correctness checking (e.g. numerical tolerance for numbers, Gemini APIs)
+    # is_correct = correct_answer in predicted_answer if predicted_answer is not None else False
+    is_correct = correct_answer in output
+
+    # Log answer
+    max_logging.log(
+        f"{total_count + 1} | {parsed_dataset_example.question}\n"
+        f"[Model output] {output}\n"
+        f"[Label answer] {correct_answer}\n"
+        f"Matching: {is_correct}"
+    )
+
+    # Save results for CSV
+    results_data.append({
+        "question ID": total_count + 1,
+        "question": parsed_dataset_example.question,
+        "label": parsed_dataset_example.answer,
+        "output": output,
+        "is_correct": is_correct,
+    })
+
+    # Update overall accuracy
+    if is_correct:
+      correct_count += 1
+    total_count += 1
+    max_logging.log(f"Running accuracy: {correct_count / total_count:.4f} | Processed: {total_count}/{len(test_ds)}")
+
+    if idx >= 4:  # For debugging: limit to the first 5 examples. Remove this to evaluate the full split.
+      break
+
+    # Every 100 rows, save intermediate results to CSV and upload to GCS
+    if idx % 100 == 0 and result_gcs_path is not None and jax.process_index() == 0:
+      results_df = pd.DataFrame(results_data)
+      results_df.to_csv(results_file_name, index=False)
+      gcs_utils.upload_blob(result_gcs_path, results_file_name)
+      max_logging.log(f"Uploaded the results file to GCS bucket: {result_gcs_path}")
+
+  # Final accuracy
+  if total_count > 0:
+    accuracy = correct_count / total_count
+    max_logging.log(f"\nFinal accuracy on {hf_data_dir} dataset: {accuracy:.4f}")
+  else:
+    max_logging.log("No valid predictions were made.")
+
+  # Save predictions to CSV and upload to GCS
+  if result_gcs_path is not None and jax.process_index() == 0:
+    results_df = pd.DataFrame(results_data)
+    results_df.to_csv(results_file_name, index=False)
+    max_logging.log(f"Saved predictions to {results_file_name}")
+    gcs_utils.upload_blob(result_gcs_path, results_file_name)
+    max_logging.log(f"Uploaded the results file to GCS bucket: {result_gcs_path}")
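+
+
+# To recompute accuracy offline from a saved results CSV (e.g. the
+# ChartQA_results.csv included in this change), one option is:
+#   import pandas as pd
+#   df = pd.read_csv("ChartQA_results.csv")
+#   print(f"accuracy: {df['is_correct'].mean():.4f}")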
+
+
+def validate_config(config):
+  assert not config.load_full_state_path, (
+      "This script doesn't operate on full states! Convert to a parameter"
+      " checkpoint first, using generate_param_only_checkpoint."
+  )
+  assert config.hf_data_dir, (
+      "For benchmark evaluation, please specify the HuggingFace dataset name using the hf_data_dir config field."
+  )
+  assert config.hf_data_dir in SUPPORTED_DATASETS, (
+      f"Unsupported dataset {config.hf_data_dir}. Supported datasets are: {SUPPORTED_DATASETS}."
+      " Please add support for your desired dataset in multimodal_eval.py."
+  )
+
+
+if __name__ == "__main__":
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  flags.FLAGS(sys.argv)
+  cfg = pyconfig.initialize(sys.argv)
+  validate_config(cfg)
+  max_utils.print_system_information()
+  main(cfg)
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 708772e374..50353b03ba 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -1106,6 +1106,7 @@ posemb_type_for_vit: "learn"
 # Set it to avoid unnecessary padding if you know the maximum number of images per example.
 max_num_images_per_example: -1
 vision_output_length: -1 # The output length (number of soft tokens) from vision encoder, used in Gemma4.
+image_resize: -1 # Resize images to this square size; -1 means no resize. Simplifies Llama4 multimodal decoding by avoiding tiling.

 ### llama4 multi modal configs
 hidden_size_for_vit: 1408
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 12d02272ea..93b2dc2de9 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1714,7 +1714,7 @@ class MultimodalGeneral(BaseModel):
   use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.")
   mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.")
   position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).")
-
+  image_resize: int = Field(-1, description="Resize images for simpler multimodal decoding; -1 disables resizing.")

 class VisionTower(BaseModel):
   """Configuration for the Vision Tower (Encoder) in a multimodal model."""