# JaxPM — Overview of Open PRs (May 2026)

## TL;DR

```
#54 → #56 → #55 → #57
             │
             ↓
            #53 → #42
#58 → can land any time after #54
```
This document tracks the seven open PRs against `DifferentiableUniverseInitiative/JaxPM` from @ASKabalan, their dependencies, and a recommended merge order. It mirrors the format of jax_cosmo#149.

## Summary of PRs

### #54 — `extend-normal_field`

Reworks `get_local_shape` to derive the local shape by mapping axis names to mesh device sizes (instead of assuming the array rank matches the sharding spec), and introduces `get_sharding_for_shape`, which trims a `PartitionSpec` to the rank of the target array. `normal_field` now calls `get_sharding_for_shape` before generating. Tested by a new parametrised `test_normal_field`. Every later PR depends on this, and most have already merged it in.

The goal is to allow 1D, 2D and 3D distributed normal sampling. jax-fli uses this when sampling distributed fields of any type (3D density, 2D flat sky, or 1D spherical).

### #55 — `update-growth-to-match-jax-cosmo`

Tracks the upstream refactor in jax_cosmo#143 (cache rewrite). Replaces JaxPM's local `_growth_factor_ODE` cache machinery with `jax_cosmo._compute_growth_tables` and drops the now-removed `cosmo._workspace` dict (the tests no longer set `cosmo._workspace = {}`). The CI workflow pins `jax_cosmo` to `ASKabalan:better-cache` until upstream (jax_cosmo#143 → #146) ships.

Robust caching of the growth computation matters. As I mentioned in the jax_cosmo PR, in an end-to-end differentiable pipeline it is easy to write to the cache dictionary in parallel, which causes a leak. Others have reproduced this independently (Tilman Troester, Natalia ...).

### #56 — `jaxdecomp-fixes`

Catches JaxPM up with the new `jaxDecomp` API: `Halo` is now passed as a tuple instead of a positional integer, and an `at[]` indexing bug in `kernels.py` is fixed. Pure compatibility PR: no behaviour change beyond what jaxDecomp now requires.

### #57 — `fastpm-odes`

Adds symplectic FastPM ODE integrators in the new `jaxpm/ode.py` (collecting the kick/drift factors).

### #58 — `discodj-forces`

Brings `pm_forces` and the LPT2 path in line with DISCO-DJ's force convention, adding dealiasing for LPT2. Required by jax-fli's DISCO-DJ comparison test. Mostly self-contained: touches `jaxpm/pm.py` only.

### #53 — `better-sharded-spherical-painting`

Rewrites 2D painting so the flat-sky projection is sharded all the way through and only communicates at the final scatter. Rewrites spherical painting so `healpy.ang2pix` / `vec2pix` run on every device with no all-gather; the only collective is the final aggregation of the HEALPix map.

Measured: 2D RBF goes from 17.4 MB → 3.3 MB, BILINEAR from 2.5 MB → 832 KB, and all "all-gather detected" warnings disappear.

Depends on #54 (sharded shape helpers), #55 (new growth API used in tests) and #56 (new kernels API). It also depends externally on jax-healpy#2 (multi-dim batch shapes in `get_interp_weights` / `get_all_neighbours`). The CI workflow temporarily installs `jax-healpy` from `main` until 0.6.1+ is released.

### #42 — `41-spherical-lensing`

Rewrites `jaxpm/lensing.py` for sharded painting. It replaces the old `density_plane` with a closure-based `density_plane_fn` suitable for `diffeqsolve`. It adds `spherical_density_fn`, which drives the new sharded `paint_particles_spherical`, and updates CIC paint calls for the shape-preserving `cic_paint_2d`. It also adds `notebooks/08-convergence-vs-glass.ipynb`, which validates the results against GLASS.

Built on top of #53. The branch was previously stuck on an old tip of #53 and was failing CI (`test_spherical_painting_methods.py` broadcast error). It has just been re-merged with the current tip of #53, which brings in the missing jaxdecomp + jax-healpy fixes.

## Dependency graph
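The dependencies spelled out in the recommended merge order can be encoded as a small graph. Any topological order of that graph is a valid merge sequence, and the hand-picked order is one of them.

Below is a minimal sketch using Python's standard-library `graphlib` (Python 3.9+). The `DEPS` table is transcribed from the merge-order notes. `DEPS` and `is_valid_order` are hypothetical helpers for illustration, not part of JaxPM or its CI.

```python
from graphlib import TopologicalSorter

# Hypothetical table: PR -> PRs that must merge first (transcribed from the merge-order notes).
DEPS = {
    54: set(),
    56: {54},
    55: {54, 56},
    57: {55},
    58: {54},
    53: {54, 55, 56},
    42: {53},
}


def is_valid_order(order, deps=DEPS):
    """Return True if `order` lists every PR exactly once, each after all its prerequisites."""
    position = {pr: i for i, pr in enumerate(order)}
    if len(order) != len(deps) or set(position) != set(deps):
        return False  # missing, extra or duplicated PRs
    return all(position[req] < position[pr] for pr, reqs in deps.items() for req in reqs)


# One valid order derived from the graph. Ties may be broken differently
# from the hand-picked order in this document.
computed = list(TopologicalSorter(DEPS).static_order())

recommended = [54, 56, 55, 57, 58, 53, 42]
print(computed)
print(is_valid_order(recommended))  # True
```

Moving #58 to land right after #54 (`[54, 58, 56, 55, 57, 53, 42]`) also passes the check. This matches the note that #58 can be slotted in any time after #54.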
## Recommended merge order
1. **#54 — `extend-normal_field`** (merge as-is)

   Foundation. No prerequisites. Merging it first removes the most-shared rebase target (`get_local_shape` / `get_sharding_for_shape`) from every other branch's diff. After this, every other open PR should be rebased onto `main` and re-pushed.

2. **#56 — `jaxdecomp-fixes`** (merge after #54)

   Tiny, mechanical migration to the new jaxdecomp API. Doing it before #55 / #57 means later branches don't have to carry the `kernels.py` rewrite. After merge, rebase #55, #57, #58 and #53. They will pick up the new `Halo`-tuple signature for free.

3. **#55 — `update-growth-to-match-jax-cosmo`** (merge after #54 + #56)

   Brings in the `jax_cosmo._compute_growth_tables` refactor. Block until jax_cosmo#143 (or its successor #146) is actually merged on `main`. Then drop the `ASKabalan:better-cache` pin in `tests.yml` and merge. Rebase #57 and #53 afterwards.

4. **#57 — `fastpm-odes`** (merge after #55)

   Requires the new `_compute_growth_tables` signature from #55. Otherwise self-contained (`jaxpm/ode.py` is new). Notebooks 01–05 will rebase cleanly.

5. **#58 — `discodj-forces`** (merge any time after #54)

   No overlap with the lensing/painting/growth stack. Can be slotted in wherever convenient.

6. **#53 — `better-sharded-spherical-painting`** (merge after #54 + #55 + #56)

   Blocked on jax-healpy#2 being released as 0.6.1+. Once jax-healpy ships, drop the `git+https://github.com/CMBSciPol/jax-healpy` pin in `tests.yml` and merge.

7. **#42 — `41-spherical-lensing`** (merge last)

   Built on #53. With #53 just re-merged, CI should now pass. `pip install jax-healpy@main` covers the broadcast bug, and the new `kernels.py` / `Halo` tuple satisfies jaxdecomp. After #53 merges to `main`, this PR's diff collapses to the lensing-specific changes (`lensing.py` + `notebooks/08-convergence-vs-glass.ipynb`).

## Open external blockers
| Upstream PR | Temporary install in CI |
|---|---|
| jax_cosmo#143 (cache refactor) | `pip install git+https://github.com/ASKabalan/jax_cosmo@better-cache` |
| jax-healpy#2 (multi-dim batch shapes) | `pip install --force-reinstall --no-deps git+https://github.com/CMBSciPol/jax-healpy` |

Once both upstream PRs are merged and new `jax_cosmo` / `jax-healpy` releases are out, all of the `git+https://` lines in `.github/workflows/tests.yml` can be removed.

**Update:** the jax-healpy PR was merged; waiting for a new release.