Skip to content

Add JAXBench: TPU kernel benchmark suite#34

Open
aryatschand wants to merge 1 commit intomainfrom
add-jaxbench
Open

Add JAXBench: TPU kernel benchmark suite#34
aryatschand wants to merge 1 commit intomainfrom
add-jaxbench

Conversation

@aryatschand
Copy link
Copy Markdown
Collaborator

Summary

  • Adds JAXBench, a suite of 50 curated JAX/TPU kernel workloads with a production-ready evaluation harness for benchmarking AI-generated kernel optimizations.
  • benchmark/: 50 workloads (17 priority production operators + 33 KernelBench fused ops) with consistent interface. 8 have hand-optimized Pallas TPU kernel variants.
  • harness/: Evaluation pipeline with device-side profiling via jax.profiler.trace(), correctness checking (atol=1e-2, rtol=1e-2), and three-way comparison (baseline XLA vs candidate vs Pallas reference).
  • CLI: python -m JAXBench {evaluate,run,list} for both agent and user workflows.

Test plan

  • python -m JAXBench list shows all 50 workloads (works on CPU)
  • python -m JAXBench list --json returns valid JSON
  • python -m JAXBench run --workload 8p_GEMM --tpu v6e produces timing results (requires TPU)
  • python -m JAXBench evaluate --workload 8p_GEMM --kernel JAXBench/benchmark/8p_GEMM/optimized.py --json returns structured eval output (requires TPU)
  • python -m JAXBench run --all --tpu v6e produces results.json + results.csv (requires TPU)

🤖 Generated with Claude Code

@google-cla
Copy link
Copy Markdown

google-cla Bot commented May 1, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Adds JAXBench, a suite of 50 curated JAX/TPU kernel workloads with a
production-ready evaluation harness for benchmarking AI-generated kernel
optimizations.

- benchmark/: 50 workloads (17 priority + 33 KernelBench) with consistent
  interface (CONFIG, create_inputs, workload). 8 have hand-optimized Pallas
  variants.
- harness/: Evaluation pipeline with device-side profiling via
  jax.profiler.trace(), correctness checking (atol=1e-2, rtol=1e-2), and
  three-way comparison (baseline XLA vs candidate vs Pallas reference).
- CLI: python -m JAXBench {evaluate,run,list} for agent and user workflows.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants