Skip to content
Draft
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
99c75ad
Add Clenshaw-Curtis quadrature rule implementation
matt-graham Apr 28, 2026
ae28cc1
Change to inverse RFFT based implementation
matt-graham Apr 30, 2026
f18db35
Correct CC weights expression for even L
matt-graham Apr 30, 2026
2e47ee4
Add Fejer-2 weights and refactor CC implementation to reuse
matt-graham May 19, 2026
efa2fb4
Add dimensions for CC and F2 sampling grids
matt-graham May 19, 2026
8f2c2d9
Add CC and F2 sampling methods to numpy quadrature wrapper
matt-graham May 19, 2026
2311a4f
Fix F2 theta positions
matt-graham May 19, 2026
8e7259a
Separate out CC and F2 phi grid sizes
matt-graham May 19, 2026
3462e95
Add tests for quadrature rules on known 1D integrals
matt-graham May 20, 2026
00baa27
Link quadrature weight and ntheta defs for CC and F2 rules
matt-graham May 21, 2026
21ed461
Simplify index to phi computation using nphi def
matt-graham May 21, 2026
9e1f84c
Increase number of theta samples for F2 rule by one
matt-graham May 21, 2026
f5e7278
Factor out m offset 1 and pole singularity sets
matt-graham May 21, 2026
63790b8
Factor out Gl quadrature theta only to match other rules
matt-graham May 21, 2026
a0e0d72
Match signature of DH theta only quadrature weight function to other …
matt-graham May 21, 2026
37bac50
Use nphi_equiang function uniformly in quadrature weight implementations
matt-graham May 21, 2026
a0c2891
Refactor index to theta to use ntheta
matt-graham May 21, 2026
530babe
Double CC and F2 scheme number of theta samples
matt-graham May 21, 2026
1ac5fc3
Include CC and F2 in m offset 1 schemes
matt-graham May 21, 2026
aa09f7c
Correct bug in determining scheme for dealing with pole singularity
matt-graham May 21, 2026
9366a6c
Correct bug in using MW rather than MWSS to compute nphi
matt-graham May 21, 2026
417b374
Generalize additional implicit special casing for handling pole singu…
matt-graham May 21, 2026
0daaeaa
Add CC and F2 to sample scheme lists in docstrings
matt-graham May 21, 2026
7f64fd7
Include CC and F2 schemes in spherical transform tests
matt-graham May 21, 2026
f801d6c
Move array-api-extra to package rather than build dependencies
matt-graham May 21, 2026
ed9dcc5
Remove CC + F2 from tests using SSHT
matt-graham May 21, 2026
176b82f
Further removing of CC + F2 from tests using SSHT
matt-graham May 21, 2026
832476b
Add equiangular sampling scheme set
matt-graham May 22, 2026
304bee5
Adjust tolerances for spherical precompute tests
matt-graham May 22, 2026
3756e22
Handle CC and F2 schemes in precompute kernel construction
matt-graham May 22, 2026
0be15f1
Generalize handling of poles in JAX precompute function
matt-graham May 22, 2026
bff50d3
Account for L=1 case for CC quadrature weights
matt-graham May 22, 2026
e0c447d
Add additional quadrature weights tests
matt-graham May 22, 2026
4fbdded
Remove unneeded special case for CC rule with n_theta = 2
matt-graham May 22, 2026
f6b4145
Give quadrature test using transform more descriptive name
matt-graham May 22, 2026
71de50f
Add further tests for quadrature rule exceptions
matt-graham May 22, 2026
e2aea88
Add tests for HEALPix quadrature weights
matt-graham May 22, 2026
b3efcf7
Add tests for Wigner kernel exceptions and n sample function
matt-graham May 22, 2026
9049377
Relax test condition on number of samples to non-negativity
matt-graham May 22, 2026
53573ed
Add CC and F2 sampling scheme details to docs
matt-graham May 22, 2026
3447e48
Minor fixes to sampling scheme doc page formatting
matt-graham May 22, 2026
87fb759
Merge branch 'main' into mmg/clenshaw-curtis-quadrature
matt-graham Jun 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/api/utility/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ Utility Functions
- Compute MW quadrature weights for :math:`\theta` and :math:`\phi` integration.
* - :func:`~s2fft.utils.quadrature.quad_weights_mwss`
- Compute MWSS quadrature weights for :math:`\theta` and :math:`\phi` integration.
* - :func:`~s2fft.utils.quadrature.quad_weight_dh_theta_only`
- Compute DH quadrature weight for :math:`\theta` integration (only), for given :math:`\theta`.
* - :func:`~s2fft.utils.quadrature.quad_weights_dh_theta_only`
- Compute DH quadrature weight for :math:`\theta` integration (only).
* - :func:`~s2fft.utils.quadrature.quad_weights_mw_theta_only`
- Compute MW quadrature weights for :math:`\theta` integration (only).
* - :func:`~s2fft.utils.quadrature.quad_weights_mwss_theta_only`
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ description = "Differentiable and accelerated spherical transforms with JAX"
dependencies = [
"numpy>=1.20",
"jax>=0.5.0",
"array-api-extra >= 0.10.1",
]
dynamic = [
"version",
Expand Down
32 changes: 16 additions & 16 deletions s2fft/base_transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def inverse(
spin (int, optional): Harmonic spin. Defaults to 0.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int, optional): HEALPix Nside resolution parameter. Only required
if sampling="healpix". Defaults to None.
Expand Down Expand Up @@ -80,7 +80,7 @@ def _inverse(
spin (int, optional): Harmonic spin. Defaults to 0.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

method (str, optional): Harmonic transform algorithm. Supported algorithms include
{"direct", "sov", "sov_fft", "sov_fft_vectorized"}. Defaults to
Expand Down Expand Up @@ -154,7 +154,7 @@ def forward(
spin (int, optional): Harmonic spin. Defaults to 0.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int, optional): HEALPix Nside resolution parameter. Only required
if sampling="healpix". Defaults to None.
Expand Down Expand Up @@ -217,7 +217,7 @@ def _forward(
spin (int, optional): Harmonic spin. Defaults to 0.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

method (str, optional): Harmonic transform algorithm. Supported algorithms include
{"direct", "sov", "sov_fft", "sov_fft_vectorized"}. Defaults to
Expand Down Expand Up @@ -304,7 +304,7 @@ def _compute_inverse_direct(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand Down Expand Up @@ -391,7 +391,7 @@ def _compute_inverse_sov(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand Down Expand Up @@ -465,7 +465,7 @@ def _compute_inverse_sov_fft(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand All @@ -486,7 +486,7 @@ def _compute_inverse_sov_fft(
assert L >= 2 * nside

ftm = np.zeros(samples.ftm_shape(L, sampling, nside), dtype=np.complex128)
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
m_offset = 1 if sampling in samples.M_OFFSET_1_SCHEMES else 0

for t, theta in enumerate(thetas):
phi_ring_offset = (
Expand Down Expand Up @@ -558,7 +558,7 @@ def _compute_inverse_sov_fft_vectorized(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand All @@ -576,7 +576,7 @@ def _compute_inverse_sov_fft_vectorized(

"""
ftm = np.zeros(samples.ftm_shape(L, sampling, nside), dtype=np.complex128)
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
m_offset = 1 if sampling in samples.M_OFFSET_1_SCHEMES else 0

for t, theta in enumerate(thetas):
phase_shift = (
Expand Down Expand Up @@ -634,7 +634,7 @@ def _compute_forward_direct(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand Down Expand Up @@ -726,7 +726,7 @@ def _compute_forward_sov(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand Down Expand Up @@ -822,7 +822,7 @@ def _compute_forward_sov_fft(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand All @@ -844,7 +844,7 @@ def _compute_forward_sov_fft(
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
ftm = np.zeros_like(f).astype(np.complex128)

m_offset = 1 if sampling in ["mwss", "healpix"] else 0
m_offset = 1 if sampling in samples.M_OFFSET_1_SCHEMES else 0

if sampling.lower() == "healpix":
ftm = hp.healpix_fft(f, L, nside, "numpy", reality)
Expand Down Expand Up @@ -938,7 +938,7 @@ def _compute_forward_sov_fft_vectorized(
spin (int): Harmonic spin.

sampling (str): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "healpix"}.
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}.

thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere.

Expand All @@ -960,7 +960,7 @@ def _compute_forward_sov_fft_vectorized(
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
ftm = np.zeros_like(f).astype(np.complex128)

m_offset = 1 if sampling in ["mwss", "healpix"] else 0
m_offset = 1 if sampling in samples.M_OFFSET_1_SCHEMES else 0
if reality:
m_conj = (-1) ** (np.arange(1, L) % 2)

Expand Down
88 changes: 45 additions & 43 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@
PM_MAX_STABLE_SPIN = 6


def _n_sample_wigner_fourier_inverse_fft(sampling: str, n_theta: int) -> int:
"""
Number of samples for inverse FFT over Wigner Fourier coefficients.

Args:
sampling: String specifier of sampling scheme.
n_theta: Number of (co)latitude samples in sampling scheme.

Returns:
Number of samples.

"""
if sampling == "mw":
return 2 * n_theta - 1
elif sampling == "mwss":
return 2 * n_theta - 2
elif sampling == "dh":
return 2 * n_theta
elif sampling == "cc":
return 2 * n_theta - 2
elif sampling == "f2":
return 2 * n_theta + 2
Comment on lines +34 to +37

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jasonmcewen I arrived at these by trial and error as I could not see the pattern in how the values for the other schemes was derived - if there some underlying relationship here it would be good to document. It might also be worth moving this to one of modules under s2fft.sampling

else:
raise ValueError(f"Equiangular sampling scheme {sampling} not recognised")


def spin_spherical_kernel(
L: int,
spin: int = 0,
Expand All @@ -39,7 +65,7 @@ def spin_spherical_kernel(
Defaults to False.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".
Expand Down Expand Up @@ -110,18 +136,12 @@ def spin_spherical_kernel(
delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
dl[:, el] = delta[:, m_start_ind:, L - 1 - spin]

# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
# MW, MWSS, DH, CC & F2 sampling ARE uniform in theta therefore CAN be calculated
# using the Fourier decomposition of Wigner d-functions.
# - The complexity of this approach is O(L^3LogL).
# - This approach is stable for arbitrary abs(spins) <= L.
if sampling.lower() in ["mw", "mwss", "dh"]:
# Number of samples for inverse FFT over Wigner Fourier coefficients.
if sampling.lower() == "mw":
nsamps = 2 * len(thetas) - 1
elif sampling.lower() == "mwss":
nsamps = 2 * len(thetas) - 2
elif sampling.lower() == "dh":
nsamps = 2 * len(thetas)
if sampling.lower() in samples.EQUIANGULAR_SCHEMES:
nsamps = _n_sample_wigner_fourier_inverse_fft(sampling.lower(), len(thetas))
delta = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)

# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
Expand Down Expand Up @@ -187,7 +207,7 @@ def spin_spherical_kernel_jax(
Defaults to False.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".
Expand Down Expand Up @@ -243,12 +263,12 @@ def spin_spherical_kernel_jax(
dl = jnp.swapaxes(dl, 0, 1)

# North pole singularity
if sampling.lower() == "mwss":
if sampling.lower() in samples.INCLUDES_NORTH_POLE_SCHEMES:
dl = dl.at[0].set(0)
dl = dl.at[0, :, L - 1 - spin].set(1)

# South pole singularity
if sampling.lower() in ["mw", "mwss"]:
if sampling.lower() in samples.INCLUDES_SOUTH_POLE_SCHEMES:
dl = dl.at[-1].set(0)
dl = dl.at[-1, :, L - 1 + spin].set((-1) ** (jnp.arange(L) - spin))
dl = dl.at[:, : jnp.abs(spin)].multiply(0)
Expand Down Expand Up @@ -277,14 +297,8 @@ def spin_spherical_kernel_jax(
# using the Fourier decomposition of Wigner d-functions.
# - The complexity of this approach is O(L^3LogL).
# - This approach is stable for arbitrary abs(spins) <= L.
elif sampling.lower() in ["mw", "mwss", "dh"]:
# Number of samples for inverse FFT over Wigner Fourier coefficients.
if sampling.lower() == "mw":
nsamps = 2 * len(thetas) - 1
elif sampling.lower() == "mwss":
nsamps = 2 * len(thetas) - 2
elif sampling.lower() == "dh":
nsamps = 2 * len(thetas)
elif sampling.lower() in samples.EQUIANGULAR_SCHEMES:
nsamps = _n_sample_wigner_fourier_inverse_fft(sampling.lower(), len(thetas))
delta = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)

# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
Expand Down Expand Up @@ -355,7 +369,7 @@ def wigner_kernel(
Defaults to False.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".
Expand All @@ -371,15 +385,15 @@ def wigner_kernel(
np.ndarray: Transform kernel for Wigner transform.

"""
if mode.lower() == "fft" and sampling.lower() not in ["mw", "mwss", "dh"]:
if mode.lower() == "fft" and sampling.lower() not in samples.EQUIANGULAR_SCHEMES:
raise ValueError(
f"Fourier based recursion is not valid for {sampling} sampling."
)
# Determine operational mode automatically.
# - Can only use the FFT approach when uniformly sampling in theta.
# - FFT approach is only more efficient when N <= L/Log(L) roughly.
if mode.lower() == "auto":
if sampling.lower() in ["mw", "mwss", "dh"]:
if sampling.lower() in samples.EQUIANGULAR_SCHEMES:
mode = "fft" if N <= int(L / np.log(L)) else "direct"
else:
mode = "direct"
Expand Down Expand Up @@ -410,20 +424,14 @@ def wigner_kernel(
delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n]

# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
# MW, MWSS, DH, CC & F2 sampling ARE uniform in theta therefore CAN be calculated
# using the Fourier decomposition of Wigner d-functions.
# - The complexity of this approach is O(NL^3LogL).
# - This approach is stable for arbitrary abs(spins) <= L.
# Therefore when NL^3LogL <= L^4 i.e. when N <= L/LogL, the Fourier based approach
# is more efficient. This can be a large difference for large L >> N.
elif mode.lower() == "fft":
# Number of samples for inverse FFT over Wigner Fourier coefficients.
if sampling.lower() == "mw":
nsamps = 2 * len(thetas) - 1
elif sampling.lower() == "mwss":
nsamps = 2 * len(thetas) - 2
elif sampling.lower() == "dh":
nsamps = 2 * len(thetas)
nsamps = _n_sample_wigner_fourier_inverse_fft(sampling.lower(), len(thetas))
delta = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)

# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
Expand Down Expand Up @@ -494,7 +502,7 @@ def wigner_kernel_jax(
Defaults to False.

sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
{"mw", "mwss", "dh", "gl", "healpix", "cc", "f2"}. Defaults to "mw".

nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".
Expand All @@ -510,15 +518,15 @@ def wigner_kernel_jax(
jnp.ndarray: Transform kernel for Wigner transform.

"""
if mode.lower() == "fft" and sampling.lower() not in ["mw", "mwss", "dh"]:
if mode.lower() == "fft" and sampling.lower() not in samples.EQUIANGULAR_SCHEMES:
raise ValueError(
f"Fourier based recursion is not valid for {sampling} sampling."
)
# Determine operational mode automatically.
# - Can only use the FFT approach when uniformly sampling in theta.
# - FFT approach is only more efficient when N <= L/Log(L) roughly.
if mode.lower() == "auto":
if sampling.lower() in ["mw", "mwss", "dh"]:
if sampling.lower() in samples.EQUIANGULAR_SCHEMES:
mode = "fft" if N <= int(L / np.log(L)) else "direct"
else:
mode = "direct"
Expand Down Expand Up @@ -550,20 +558,14 @@ def wigner_kernel_jax(
delta = vfunc(delta, thetas, L, el)
dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n])

# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
# MW, MWSS, DH, CC & F2 sampling ARE uniform in theta therefore CAN be calculated
# using the Fourier decomposition of Wigner d-functions.
# - The complexity of this approach is O(NL^3LogL).
# - This approach is stable for arbitrary abs(spins) <= L.
# Therefore when NL^3LogL <= L^4 i.e. when N <= L/LogL, the Fourier based approach
# is more efficient. This can be a large difference for large L >> N.
elif mode.lower() == "fft":
# Number of samples for inverse FFT over Wigner Fourier coefficients.
if sampling.lower() == "mw":
nsamps = 2 * len(thetas) - 1
elif sampling.lower() == "mwss":
nsamps = 2 * len(thetas) - 2
elif sampling.lower() == "dh":
nsamps = 2 * len(thetas)
nsamps = _n_sample_wigner_fourier_inverse_fft(sampling.lower(), len(thetas))
delta = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)

# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
Expand Down
Loading
Loading