Skip to content

feat: add Vlasov 2D distributed initialization#285

Merged
joglekara merged 3 commits into
mainfrom
feat/vlasov2d-distributed-init
Jun 10, 2026
Merged

feat: add Vlasov 2D distributed initialization#285
joglekara merged 3 commits into
mainfrom
feat/vlasov2d-distributed-init

Conversation

@joglekara

Copy link
Copy Markdown
Member

Summary

  • Add grid.distribution-sharding config for Vlasov-2D distribution initialization
  • Add JAX NamedSharding helpers and make_array_from_callback machinery to initialize global f(x, y, vx, vy) arrays shard-by-shard
  • Cache the initialized distribution between solver-quantity setup and state initialization to avoid double allocation
  • Add reshard helper for future global-axis FFT/all-to-all pusher work
  • Document the new sharding block in the Vlasov-2D config reference

Notes

  • Existing behavior is unchanged unless grid.distribution-sharding.enabled is true
  • Sharded initialization currently requires deterministic density profiles (noise_type: none or noise_val: 0.0) because the existing noise profile is shape-seeded and would otherwise repeat per shard
  • This PR lays down initialization and sharding metadata machinery; the exponential pusher/all-to-all FFT conversion can build on the included reshard_for_global_axis_fft helper

Verification

  • uvx ruff format adept/_vlasov2d tests/test_vlasov2d/test_distributed_init.py
  • uvx ruff check adept/_vlasov2d tests/test_vlasov2d/test_distributed_init.py
  • uv run pytest tests/test_vlasov2d/test_distributed_init.py -q → 4 passed
  • Attempted uv run pytest tests/test_vlasov2d -q; existing long solver case was still CPU-bound after ~20 minutes, so I stopped it after the targeted tests had passed

joglekara and others added 3 commits June 3, 2026 10:59
…rd_to_partitioned

- helpers.py: compute each component's density profile once and reuse for
  both n_prof_species and the unsharded f_one path (previously evaluated
  twice on the host path).
- distributed.py: docstring note on reshard_for_global_axis_fft warning that
  declaring shardings on jax.jit lets XLA's SPMD partitioner do this work
  automatically. Add symmetric reshard_to_partitioned helper that restores
  the canonical partitioned layout (typically used after an inverse FFT).
- tests: round-trip test for reshard_for_global_axis_fft + reshard_to_partitioned.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@joglekara joglekara merged commit 6e703a7 into main Jun 10, 2026
1 check passed
@joglekara joglekara deleted the feat/vlasov2d-distributed-init branch June 10, 2026 19:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant