
Few updates and fixes JPM edition #59


@ASKabalan

# JaxPM: Overview of Open PRs (May 2026)

This document tracks the seven open PRs against DifferentiableUniverseInitiative/JaxPM from @ASKabalan, their dependencies, and a recommended merge order. It mirrors the format of jax_cosmo#149.

## Summary of PRs

### #54 extend-normal_field

Reworks get_local_shape to derive the local shape by mapping axis names to mesh device sizes (instead of assuming the array rank matches the sharding spec), and introduces get_sharding_for_shape, which trims a PartitionSpec to the rank of the target array. normal_field now calls get_sharding_for_shape before generating the field. Tested by a new parametrised test_normal_field. Every later PR depends on this, and most have already merged it in.

The goal is to allow 1D, 2D and 3D distributed normal sampling, which jax-fli uses when sampling distributed fields of any type (3D density, 2D flat sky, or 1D spherical).
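The shape logic described for #54 can be sketched in plain Python. This is a minimal sketch, not the PR's code: trim_spec_to_rank and local_shape are hypothetical stand-ins for get_sharding_for_shape and get_local_shape, the partition spec is a plain tuple rather than a jax.sharding.PartitionSpec, and mesh_shape is a dict from axis name to device count (JAX's Mesh.shape provides the same kind of mapping).

```python
from math import prod


def trim_spec_to_rank(spec, ndim):
    """Hypothetical helper: fit a partition spec to an array of rank `ndim`.

    Extra trailing entries are dropped; missing ones are padded with None
    (unsharded dimensions).
    """
    spec = tuple(spec)[:ndim]
    return spec + (None,) * (ndim - len(spec))


def local_shape(global_shape, spec, mesh_shape):
    """Hypothetical helper: per-device shape of a sharded array.

    `mesh_shape` maps mesh axis names to device counts, e.g. {"x": 2, "y": 4}.
    Each spec entry is None, a single axis name, or a tuple of axis names.
    """
    spec = trim_spec_to_rank(spec, len(global_shape))
    shape = []
    for size, axes in zip(global_shape, spec):
        if axes is None:
            n_dev = 1
        elif isinstance(axes, str):
            n_dev = mesh_shape[axes]
        else:  # several mesh axes sharding the same array dimension
            n_dev = prod(mesh_shape[name] for name in axes)
        if size % n_dev:
            raise ValueError(f"dimension of size {size} is not divisible by {n_dev} devices")
        shape.append(size // n_dev)
    return tuple(shape)


mesh_shape = {"x": 2, "y": 4}
spec = ("x", "y", None)
print(local_shape((64, 64, 64), spec, mesh_shape))  # -> (32, 16, 64), 3D density
print(local_shape((64, 64), spec, mesh_shape))      # -> (32, 16), 2D flat-sky plane
print(local_shape((3072,), spec, mesh_shape))       # -> (1536,), 1D spherical map
```

Because the spec is trimmed to each array's rank, one ("x", "y", None) spec can serve 3D, 2D and 1D fields alike.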

### #55 update-growth-to-match-jax-cosmo

Tracks the upstream refactor in jax_cosmo#143 (cache rewrite). Replaces JaxPM's local _growth_factor_ODE cache machinery with jax_cosmo._compute_growth_tables and drops the now-removed cosmo._workspace dict (the tests no longer set cosmo._workspace = {}). The CI workflow pins jax_cosmo to ASKabalan:better-cache until the upstream change (jax_cosmo#143 / #146) ships.

Robust caching of the growth computation is important because, as I mentioned in the jax-cosmo PR, in an end-to-end differentiable pipeline it is easy to write to the cache dictionary in parallel, which causes a leak (this has been reproduced by others, e.g. Tilman Troester, Natalia ...).
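The fix in #55 amounts to computing growth tables as plain values instead of writing them into a shared mutable dict. A minimal pure-Python sketch of that stateless style follows. growth_tables and lookup are hypothetical names, not jax_cosmo's _compute_growth_tables (whose signature is not given here), and the model is an assumption: flat ΛCDM without radiation, integrated with fixed-step RK4 in ln(a).

```python
import math
from bisect import bisect_left


def growth_tables(omega_m, a_min=1e-3, n_steps=2000):
    """Hypothetical stateless growth-table builder for flat LambdaCDM (no radiation).

    Solves D'' + (2 - 1.5 Om(a)) D' - 1.5 Om(a) D = 0, where ' = d/dln(a),
    with fixed-step RK4 on a uniform ln(a) grid, starting from the
    matter-dominated growing mode D = a. Returns (a, D) lists with D(1) = 1.
    Nothing is cached: every call returns fresh tables.
    """
    def om_a(a):
        # Omega_m(a) = Om a^-3 / E(a)^2, with E(a)^2 = Om a^-3 + (1 - Om)
        x = omega_m * a**-3
        return x / (x + 1.0 - omega_m)

    def rhs(lna, d, dp):
        om = om_a(math.exp(lna))
        return dp, -(2.0 - 1.5 * om) * dp + 1.5 * om * d

    lna0 = math.log(a_min)
    h = -lna0 / n_steps
    d, dp = a_min, a_min  # D = a and dD/dln(a) = a in matter domination
    lnas, ds = [lna0], [d]
    for i in range(n_steps):
        t = lna0 + i * h
        k1 = rhs(t, d, dp)
        k2 = rhs(t + h / 2, d + h / 2 * k1[0], dp + h / 2 * k1[1])
        k3 = rhs(t + h / 2, d + h / 2 * k2[0], dp + h / 2 * k2[1])
        k4 = rhs(t + h, d + h * k3[0], dp + h * k3[1])
        d += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        dp += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        lnas.append(t + h)
        ds.append(d)
    return [math.exp(t) for t in lnas], [x / ds[-1] for x in ds]


def lookup(a, a_tab, d_tab):
    """Linear interpolation of D at scale factor a (a_tab must be increasing)."""
    i = min(max(bisect_left(a_tab, a), 1), len(a_tab) - 1)
    w = (a - a_tab[i - 1]) / (a_tab[i] - a_tab[i - 1])
    return (1 - w) * d_tab[i - 1] + w * d_tab[i]


a_tab, d_tab = growth_tables(omega_m=0.3)
print(lookup(0.5, a_tab, d_tab))  # roughly 0.6 for Omega_m = 0.3
```

Since the tables are returned rather than stored on a shared object, concurrent or traced callers have nothing shared to write into; they pass the tables along explicitly.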

### #56 jaxdecomp-fixes

Catches JaxPM up with the new jaxDecomp API: Halo is now passed as a tuple instead of a positional integer, and an at[] indexing bug in kernels.py is fixed. Pure compatibility PR: no behaviour change beyond what jaxDecomp now requires.

### #57 fastpm-odes

Adds symplectic FastPM ODE integrators in the new jaxpm/ode.py, which collects the kick and drift factors.
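FastPM's kick and drift factors are built from growth functions so that a step reproduces linear (Zel'dovich) motion exactly, whatever the step size. A minimal pure-Python sketch under stated assumptions: the equations of motion are taken as dpos/da = vel / (a^3 E) and dvel/da = 1.5 Omega_m F / (a^2 E), G_p(a) = a^3 E(a) dD/da is the momentum growth function, and drift_factor, kick_factor, fastpm_kdk and the arithmetic midpoint are illustrative choices rather than the jaxpm/ode.py API.

```python
def drift_factor(D, Gp, a0, a1, a_mid):
    """Position change per unit momentum over [a0, a1], momentum taken at a_mid.

    Exact for Zel'dovich motion: pos = D s, vel = Gp s  =>  dpos = vel * dD / Gp.
    """
    return (D(a1) - D(a0)) / Gp(a_mid)


def kick_factor(D, Gp, a0, a1, a_force):
    """Momentum change per unit force over [a0, a1], force evaluated at a_force.

    Exact for Zel'dovich motion: F = D s  =>  dvel = F * dGp / D.
    """
    return (Gp(a1) - Gp(a0)) / D(a_force)


def fastpm_kdk(pos, vel, force, D, Gp, a_steps):
    """Kick-drift-kick stepping with FastPM-style factors."""
    for a0, a1 in zip(a_steps[:-1], a_steps[1:]):
        a_mid = 0.5 * (a0 + a1)  # assumption: arithmetic midpoint in a
        vel = vel + force(pos) * kick_factor(D, Gp, a0, a_mid, a0)
        pos = pos + vel * drift_factor(D, Gp, a0, a1, a_mid)
        vel = vel + force(pos) * kick_factor(D, Gp, a_mid, a1, a1)
    return pos, vel


# Einstein-de Sitter toy check: D(a) = a and E(a) = a**-1.5, so Gp(a) = a**1.5.
def D(a):
    return a


def Gp(a):
    return a**1.5


# In this convention linear theory gives F = D s, i.e. the force equals the
# displacement, so with q = 0 the toy force is force(pos) = pos.
pos, vel = fastpm_kdk(0.1, 0.1**1.5, lambda x: x, D, Gp, [0.1, 0.4, 0.7, 1.0])
print(pos, vel)  # both 1.0 up to rounding: D(1) s and Gp(1) s with s = 1
```

Three coarse steps land on the linear solution, which is the property these factors are designed for.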

### #58 discodj-forces

Brings pm_forces and the LPT2 path in line with DISCO-DJ's force convention and adds dealiasing for LPT2. Required by jax-fli's DISCO-DJ comparison test. Mostly self-contained: it touches only jaxpm/pm.py.
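The second-order LPT source is quadratic in second derivatives of the first-order potential, and the product of two band-limited fields contains modes up to twice their band limit, which fold back onto resolved modes on a finite grid. The PR's exact dealiasing scheme is not described here; as an assumption, the NumPy sketch below uses Orszag's 2/3 rule, with hypothetical names two_thirds_mask and dealiased_product (not jaxpm/pm.py functions).

```python
import numpy as np


def two_thirds_mask(shape):
    """Boolean Fourier-space mask keeping modes with |k_i| < n_i / 3 on every axis.

    This is Orszag's 2/3 rule: if both factors of a quadratic product are
    truncated this way, every aliased contribution lands on a discarded mode.
    """
    # Integer wavenumbers in NumPy's FFT ordering: [0, 1, ..., -2, -1]
    grids = np.meshgrid(*[np.fft.fftfreq(n) * n for n in shape], indexing="ij")
    return np.logical_and.reduce([np.abs(k) < n / 3 for k, n in zip(grids, shape)])


def dealiased_product(f, g):
    """Real-space product f * g with inputs and output 2/3-truncated."""
    mask = two_thirds_mask(f.shape)
    f_t = np.fft.ifftn(np.fft.fftn(f) * mask).real
    g_t = np.fft.ifftn(np.fft.fftn(g) * mask).real
    return np.fft.ifftn(np.fft.fftn(f_t * g_t) * mask).real


# In a 2LPT source, f and g would be second derivatives of the first-order potential.
x = 2 * np.pi * np.arange(12) / 12
f = np.cos(3 * x)
print(np.allclose(dealiased_product(f, f), 0.5))  # True: the cos(6x) term is removed
```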

### #53 better-sharded-spherical-painting

Rewrites 2D painting so that the flat-sky projection stays sharded all the way through and only communicates at the final scatter, and rewrites spherical painting so that healpy.ang2pix / vec2pix run on every device with no all-gather (the only collective is the final aggregation of the HEALPix map). Measured: 2D RBF 17.4 MB → 3.3 MB, BILINEAR 2.5 MB → 832 KB, and all "all-gather detected" warnings disappear. Depends on #54 (sharded shape helpers), #55 (new growth API used in tests), #56 (new kernels API), and externally on jax-healpy#2 (multi-dim batch shapes in get_interp_weights / get_all_neighbours). The CI workflow temporarily installs jax-healpy from main until 0.6.1+ is released.
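The spherical path in #53 has each device compute pixel indices for its own particles and paint them locally, so particle data never needs an all-gather; only the map is reduced. A minimal NumPy sketch of that pattern, with devices simulated as array shards: the pixel indices are random stand-ins for per-device ang2pix output, and paint_local / paint_sharded are hypothetical names, not the paint_particles_spherical API.

```python
import numpy as np


def paint_local(pix, weights, npix):
    """Scatter-add one device's particles into a full-size partial map (no communication)."""
    return np.bincount(pix, weights=weights, minlength=npix)


def paint_sharded(pix_shards, weight_shards, npix):
    """Paint each shard locally, then combine the partial maps with one reduction.

    The final sum stands in for the single collective (e.g. a psum) that
    aggregates the HEALPix map across devices.
    """
    partial_maps = [paint_local(p, w, npix) for p, w in zip(pix_shards, weight_shards)]
    return np.sum(partial_maps, axis=0)


nside = 4
npix = 12 * nside**2  # HEALPix pixel count
rng = np.random.default_rng(0)
pix = rng.integers(0, npix, size=1000)  # stand-in for per-device ang2pix output
w = rng.random(1000)
pix_shards, w_shards = np.array_split(pix, 4), np.array_split(w, 4)
full = paint_sharded(pix_shards, w_shards, npix)
print(np.allclose(full, np.bincount(pix, weights=w, minlength=npix)))  # True
```

With this layout each device contributes npix values to the reduction, independent of how many particles it holds.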

### #42 41-spherical-lensing

Rewrites jaxpm/lensing.py for sharded painting: replaces the old density_plane with a closure-based density_plane_fn suitable for diffeqsolve, adds spherical_density_fn driving the new sharded paint_particles_spherical, and updates CIC paint calls for the shape-preserving cic_paint_2d. It also adds notebooks/08-convergence-vs-glass.ipynb, which validates the results against GLASS. Built on top of #53: the branch was previously stuck on an old tip of #53 and was failing CI (broadcast error in test_spherical_painting_methods.py). It has just been re-merged with the current tip of #53, which brings in the missing jaxdecomp and jax-healpy fixes.
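A closure lets the static plane configuration be fixed once while the returned function keeps the (t, y, args) signature that diffrax's SaveAt(fn=...) expects. A minimal NumPy sketch under stated assumptions: make_density_plane_fn is hypothetical, the state y is assumed to be a (positions, velocities) pair, and nearest-grid-point binning with np.histogram2d stands in for the CIC painting (cic_paint_2d) used in the PR.

```python
import numpy as np


def make_density_plane_fn(box_size, plane_resolution, center, width):
    """Hypothetical factory: returns fn(t, y, args) that paints one lens-plane slab."""
    # Static configuration captured by the closure; only (t, y, args) vary per call.
    lo, hi = center - width / 2, center + width / 2
    xy_range = [[0.0, box_size[0]], [0.0, box_size[1]]]

    def density_plane_fn(t, y, args):
        pos = y[0]  # assumption: y = (positions of shape (N, 3), velocities)
        in_slab = (pos[:, 2] >= lo) & (pos[:, 2] < hi)
        plane, _, _ = np.histogram2d(
            pos[in_slab, 0], pos[in_slab, 1], bins=plane_resolution, range=xy_range
        )
        return plane

    return density_plane_fn


fn = make_density_plane_fn(box_size=(100.0, 100.0, 100.0), plane_resolution=4, center=50.0, width=10.0)
pos = np.array([[10.0, 10.0, 50.0], [90.0, 90.0, 52.0], [50.0, 50.0, 80.0]])
plane = fn(0.5, (pos, np.zeros_like(pos)), None)
print(plane.sum())  # 2.0: the third particle lies outside the slab
```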

## Dependency graph

                                main
                                  │
        ┌─────────────────────────┤
        │                         │
   #54 extend-normal_field   #58 discodj-forces
        │                         │
        ├────► #56 jaxdecomp-fixes
        │            │
        ├────► #55 update-growth-to-match-jax-cosmo
        │            │
        │            └────► #57 fastpm-odes
        │
        ▼
   #53 better-sharded-spherical-painting   ◄── needs #54 + #55 + #56
        │                                       + jax-healpy#2 (external)
        ▼
   #42 41-spherical-lensing                ◄── needs #53

External:
  jax_cosmo#143 / #146    →   pinned via ASKabalan:better-cache
  jax-healpy#2 / 0.6.1    →   pinned via CMBSciPol/jax-healpy@main

## Recommended merge order

  1. #54 extend-normal_field (merge as-is)
    Foundation. No prerequisites. Merging it first removes the most-shared rebase target (get_local_shape / get_sharding_for_shape) from every other branch's diff. After this, every other open PR should be rebased onto main and re-pushed.

  2. #56 jaxdecomp-fixes (merge after #54)
    Tiny, mechanical migration to the new jaxdecomp API. Doing it before #55 / #57 means later branches don't have to carry the kernels.py rewrite. After merging, rebase #55, #57, #58 and #53; they'll pick up the new Halo-tuple signature for free.

  3. #55 update-growth-to-match-jax-cosmo (merge after #54 + #56)
    Brings in the jax_cosmo._compute_growth_tables refactor. Hold it until jax_cosmo#143 (or its successor, #146) is actually merged on main, then drop the ASKabalan:better-cache pin in tests.yml and merge. Rebase #57 and #53 afterwards.

  4. #57 fastpm-odes (merge after #55)
    Requires the new _compute_growth_tables signature from #55. Otherwise self-contained (jaxpm/ode.py is new). Notebooks 01–05 will rebase cleanly.

  5. #58 discodj-forces (merge any time after #54)
    No overlap with the lensing/painting/growth stack. Can be slotted in wherever convenient.

  6. #53 better-sharded-spherical-painting (merge after #54 + #55 + #56)
    Blocked on jax-healpy#2 being released as 0.6.1+. Once jax-healpy ships, drop the git+https://github.com/CMBSciPol/jax-healpy pin in tests.yml and merge.

  7. #42 41-spherical-lensing (merge last)
    Built on #53. With #53 just re-merged, CI should now pass: installing jax-healpy from main covers the broadcast bug, and the new kernels.py / Halo tuple satisfies jaxdecomp. After #53 merges to main, this PR's diff collapses to the lensing-specific changes (lensing.py + notebooks/08-convergence-vs-glass.ipynb).
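The ordering constraints in this list can be checked mechanically. A small sketch using Python's standard graphlib module: DEPS transcribes the "merge after" clauses from the list by hand (it is not read from GitHub), and is_valid_merge_order is a hypothetical helper. External blockers such as jax_cosmo#143 are not modelled.

```python
from graphlib import TopologicalSorter

# PR -> PRs that must be merged first, transcribed from the merge-order list
DEPS = {
    54: set(),
    56: {54},
    55: {54, 56},
    57: {55},
    58: {54},
    53: {54, 55, 56},
    42: {53},
}


def is_valid_merge_order(order, deps):
    """True if every PR appears exactly once and after all of its prerequisites."""
    if sorted(order) != sorted(deps):
        return False
    position = {pr: i for i, pr in enumerate(order)}
    return all(position[dep] < position[pr] for pr, reqs in deps.items() for dep in reqs)


recommended = [54, 56, 55, 57, 58, 53, 42]
print(is_valid_merge_order(recommended, DEPS))  # True
print(list(TopologicalSorter(DEPS).static_order()))  # another valid order
```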

## Open external blockers

| Upstream PR | Blocks | Workaround in CI |
|---|---|---|
| jax_cosmo#143 (cache refactor) | #55, #57, #53, #42 | `pip install git+https://github.com/ASKabalan/jax_cosmo@better-cache` |
| jax-healpy#2 (multi-dim batch shapes) | #53, #42 | `pip install --force-reinstall --no-deps git+https://github.com/CMBSciPol/jax-healpy` |

Once both upstream PRs are merged and a new jax_cosmo / jax-healpy release is out, all of the git+https:// lines in .github/workflows/tests.yml can be removed.

Update: the jax-healpy change has been merged; waiting for a new release.

## TL;DR

#54  →  #56  →  #55  →  #57
                  │
                  ↓
                 #53  →  #42

#58  → can land any time after #54
