Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d7a4b56
fix: make stepk match galsim
beckermr Apr 23, 2026
8e1c000
test: remove regression test
beckermr Apr 23, 2026
11586e7
fix: fudge maxk just a bit bigger for now
beckermr Apr 23, 2026
6fd5ac6
fix: remove commented code
beckermr Apr 23, 2026
723a6fe
fix: use smallest dimension
beckermr Apr 23, 2026
22855a3
fix: lower fudge for maxk
beckermr Apr 23, 2026
dff7164
doc: get comment right
beckermr Apr 23, 2026
1d8aee0
fix: pad out maxk by some number of pixels
beckermr Apr 24, 2026
1832461
test: some test changes sicne there are more bugs here
beckermr Apr 24, 2026
ae2a60d
fix: adjust tests
beckermr Apr 24, 2026
02740b5
test: ensure jax-galsim is more conservative
beckermr Apr 24, 2026
15eac76
Update tests/jax/test_interpolatedimage_utils.py
beckermr Apr 24, 2026
26d1209
fix: address code review
beckermr Apr 24, 2026
62cebd6
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
5a4050c
style: pre-commit
beckermr Apr 24, 2026
d556039
fix: stupid copy paste bug
beckermr Apr 24, 2026
513a250
fix: clean out unused code
beckermr Apr 24, 2026
930e189
fix: clean out unused code
beckermr Apr 24, 2026
5fa99dc
test: add explicit test for behavior of force stepk and maxk
beckermr Apr 24, 2026
323c12a
Update jax_galsim/interpolatedimage.py
beckermr Apr 24, 2026
b8a319b
doc: notes on purpose of separate ii class
beckermr Apr 24, 2026
75cf5a3
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
2f1bdb4
style: pre-the-commit
beckermr Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def _stepk(self):
if self._jax_aux_data["_force_stepk"] > 0:
return self._jax_aux_data["_force_stepk"]
else:
return super()._stepk
# galsim uses a different way to handle the WCS effects on stepk
# for interpolated images. IDK why. - MRB
# super()._stepk
return self._original.stepk / self._original._wcs._minScale()

@property
@implements(_galsim.interpolatedimage.InterpolatedImage.x_interpolant)
Expand Down Expand Up @@ -403,7 +406,6 @@ def __init__(
# this class does a ton of munging of the inputs that I don't want to reconstruct when
# flattening and unflattening the class.
# thus I am going to make some refs here so we have it when we need it
self._workspace = {}
self._jax_children = (
image,
dict(
Expand Down Expand Up @@ -530,13 +532,10 @@ def tree_unflatten(cls, aux_data, children):
return ret

def __getstate__(self):
d = self.__dict__.copy()
d.pop("_workspace")
return d
return self.__dict__.copy()

def __setstate__(self, d):
self.__dict__ = d
self._workspace = {}

@property
def x_interpolant(self):
Expand Down Expand Up @@ -687,13 +686,15 @@ def _maxk(self):
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
return self._jax_aux_data["_force_maxk"] * minor
else:
return self._getMaxK(self._jax_aux_data["calculate_maxk"])
# the factor of 1.1 here is a fudge to make jax_galsim a bit more
# conservative when computing maxk
return self._getMaxK(self._jax_aux_data["calculate_maxk"]) * 1.1

@property
def _stepk(self):
if self._jax_aux_data["_force_stepk"]:
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
return self._jax_aux_data["_force_stepk"] * minor
major, _ = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
return self._jax_aux_data["_force_stepk"] * major
Comment thread
beckermr marked this conversation as resolved.
Outdated
else:
return self._getStepK(self._jax_aux_data["calculate_stepk"])

Expand Down Expand Up @@ -1175,7 +1176,7 @@ def _flux_frac(a, x, y, cenx, ceny):
dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1))
d = jnp.arange(a.shape[0])
d = jnp.reshape(d, (1, 1, -1))
msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d)
msk = (jnp.abs(dx) <= d) & (jnp.abs(dy) <= d)
Comment thread
beckermr marked this conversation as resolved.
res = jnp.sum(
jnp.where(
msk,
Expand All @@ -1184,7 +1185,6 @@ def _flux_frac(a, x, y, cenx, ceny):
),
axis=(0, 1),
)
res = jnp.where(res > 0, res, -jnp.inf)
return res


Expand All @@ -1193,28 +1193,20 @@ def _calculate_size_containing_flux(image, thresh):
cenx, ceny = image.center.x, image.center.y
x, y = image.get_pixel_centers()
fluxes = _flux_frac(image.array, x, y, cenx, ceny)
msk = fluxes >= -jnp.inf
fluxes = jnp.where(msk, fluxes, jnp.max(fluxes))
# msk = fluxes >= -jnp.inf
# fluxes = jnp.where(msk, fluxes, jnp.max(fluxes))
d = jnp.arange(image.array.shape[0]) + 1.0
Comment thread
beckermr marked this conversation as resolved.
Outdated
# below we use a linear interpolation table to find the maximum size
# in pixels that contains a given flux (called thresh here)
# expfac controls how much we oversample the interpolation table
# in order to return a more accurate result
# we have it hard coded at 4 to compromise between speed and accuracy
expfac = 4.0
dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0
fluxes = jnp.interp(dint, d, fluxes)
msk = fluxes <= thresh
p = jnp.sign(thresh)
msk = (p * fluxes) >= (p * thresh)
return (
jnp.argmax(
jnp.argmin(
jnp.where(
msk,
dint,
-jnp.inf,
d,
jnp.inf,
)
)
/ expfac
+ 1.0
+ 0.5
)


Expand Down
229 changes: 158 additions & 71 deletions tests/jax/test_interpolatedimage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from jax_galsim.interpolatedimage import (
_draw_with_interpolant_kval,
_draw_with_interpolant_xval,
_flux_frac,
)

FRAC_TEST_TO_KEEP = 0.5


@pytest.mark.parametrize(
"interp",
Expand Down Expand Up @@ -208,7 +209,7 @@ def test_interpolatedimage_utils_comp_to_galsim(
)

rng = np.random.RandomState(seed=seed)
if rng.uniform() < 0.75:
if rng.uniform() < FRAC_TEST_TO_KEEP:
pytest.skip(
"Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time."
)
Expand Down Expand Up @@ -359,74 +360,160 @@ def test_interpolatedimage_interpolant_sample(interp):
np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}")


def test_interpolatedimage_flux_frac():
obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2)
img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel")
true_val = [
0.02186161,
0.06551123,
0.10894079,
0.15200604,
0.19456641,
0.23648629,
0.27763629,
0.31789470,
0.35714823,
0.39529300,
0.43223542,
0.46789303,
0.50219434,
0.53507960,
0.56650090,
0.59642231,
0.62481892,
0.65167749,
0.67699528,
0.70077991,
0.72304893,
0.74382806,
0.76315117,
0.78105938,
0.79759991,
0.81282544,
0.82679272,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
0.83956224,
]
@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"])
@pytest.mark.parametrize("normalization", ["sb", "flux"])
@pytest.mark.parametrize("use_true_center", [True, False])
@pytest.mark.parametrize(
"wcs",
[
_galsim.PixelScale(0.2),
_galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23),
_galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)),
],
)
@pytest.mark.parametrize(
"offset_x",
[
-4.35,
-0.45,
0.0,
0.67,
3.78,
],
)
@pytest.mark.parametrize(
"offset_y",
[
-2.12,
-0.33,
0.0,
0.12,
1.45,
],
)
@pytest.mark.parametrize(
"ref_array",
[
_galsim.Gaussian(fwhm=0.9)
.shear(g1=0.3, g2=-0.2)
.drawImage(nx=33, ny=33, scale=0.2)
.array,
_galsim.Gaussian(fwhm=0.9)
.shear(g1=-0.03, g2=0.1)
.drawImage(nx=32, ny=32, scale=0.2)
.array,
],
)
def test_interpolatedimage_utils_comp_stepk_maxk_to_galsim(
ref_array,
offset_x,
offset_y,
wcs,
use_true_center,
normalization,
x_interp,
):
seed = max(
abs(
int(
hashlib.sha1(
f"{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode(
"utf-8"
)
).hexdigest(),
16,
)
)
% (10**7),
1,
)

rng = np.random.RandomState(seed=seed)
if rng.uniform() < FRAC_TEST_TO_KEEP:
pytest.skip(
"Skipping `test_interpolatedimage_utils_comp_stepk_maxk_to_galsim` case at random to save time."
)

nse = rng.uniform(size=ref_array.shape) * ref_array.max() * 0.05

gimage_in = _galsim.Image(ref_array + nse, scale=0.2)
jgimage_in = jax_galsim.Image(ref_array + nse, scale=0.2)

np.testing.assert_allclose(gimage_in.center.x, jgimage_in.center.x)
np.testing.assert_allclose(gimage_in.center.y, jgimage_in.center.y)

x, y = img.get_pixel_centers()
cenx = img.center.x
ceny = img.center.y
val = _flux_frac(img.array, x, y, cenx, ceny)
np.testing.assert_allclose(
val,
true_val,
rtol=0,
atol=1e-6,
gii = _galsim.InterpolatedImage(
gimage_in,
wcs=wcs,
offset=_galsim.PositionD(offset_x, offset_y),
use_true_center=use_true_center,
normalization=normalization,
x_interpolant=x_interp,
flux=20,
)
jgii = jax_galsim.InterpolatedImage(
jgimage_in,
wcs=jax_galsim.BaseWCS.from_galsim(wcs),
offset=jax_galsim.PositionD(offset_x, offset_y),
use_true_center=use_true_center,
normalization=normalization,
x_interpolant=x_interp,
flux=20,
)

gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux
gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh)

from jax_galsim.interpolatedimage import _calculate_size_containing_flux

jgthresh = (
1.0 - jgii._original.gsparams.folding_threshold
) * jgii._original._image_flux
jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh)

lgR = _galsim_stepk_loop(gii._image, gthresh)
ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh)

np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x)
np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y)
np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0))
np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum())
np.testing.assert_allclose(gthresh, jgthresh, rtol=0, atol=1e-6)
np.testing.assert_allclose(gR, jgR, rtol=0, atol=1e-6)
np.testing.assert_allclose(gR, ljgR, rtol=0, atol=1e-6)
np.testing.assert_allclose(lgR, jgR, rtol=0, atol=1e-6)

np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0, atol=1e-6)
# FIXME: make maxk match
np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0)


# this is a copy of the galsim C++ algorithm in a pure python
# loop to help with debugging and testing
def _galsim_stepk_loop(im, target_flux):
if target_flux > 0:
p = 1.0
else:
p = -1.0

b = im.bounds
dmax = int(min((b.getXMax() - b.getXMin()) / 2, (b.getYMax() - b.getYMin()) / 2))

flux = im(0, 0)
d = 1
while d <= dmax:
# Add the left, right, top and bottom sides of box
for x in range(-d, d):
# Note: All 4 corners are added exactly once by including x=-d but omitting
# x=d from the loop.
flux += im(x, -d) # bottom
flux += im(d, x) # right
flux += im(-x, d) # top
flux += im(-d, -x) # left

if p * flux >= p * target_flux:
break

d += 1

return d + 0.5
Loading