Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d7a4b56
fix: make stepk match galsim
beckermr Apr 23, 2026
8e1c000
test: remove regression test
beckermr Apr 23, 2026
11586e7
fix: fudge maxk just a bit bigger for now
beckermr Apr 23, 2026
6fd5ac6
fix: remove commented code
beckermr Apr 23, 2026
723a6fe
fix: use smallest dimension
beckermr Apr 23, 2026
22855a3
fix: lower fudge for maxk
beckermr Apr 23, 2026
dff7164
doc: get comment right
beckermr Apr 23, 2026
1d8aee0
fix: pad out maxk by some number of pixels
beckermr Apr 24, 2026
1832461
test: some test changes sicne there are more bugs here
beckermr Apr 24, 2026
ae2a60d
fix: adjust tests
beckermr Apr 24, 2026
02740b5
test: ensure jax-galsim is more conservative
beckermr Apr 24, 2026
15eac76
Update tests/jax/test_interpolatedimage_utils.py
beckermr Apr 24, 2026
26d1209
fix: address code review
beckermr Apr 24, 2026
62cebd6
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
5a4050c
style: pre-commit
beckermr Apr 24, 2026
d556039
fix: stupid copy paste bug
beckermr Apr 24, 2026
513a250
fix: clean out unused code
beckermr Apr 24, 2026
930e189
fix: clean out unused code
beckermr Apr 24, 2026
5fa99dc
test: add explicit test for behavior of force stepk and maxk
beckermr Apr 24, 2026
323c12a
Update jax_galsim/interpolatedimage.py
beckermr Apr 24, 2026
b8a319b
doc: notes on purpose of separate ii class
beckermr Apr 24, 2026
75cf5a3
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
2f1bdb4
style: pre-the-commit
beckermr Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _stepk(self):
if self._jax_aux_data["_force_stepk"] > 0:
return self._jax_aux_data["_force_stepk"]
else:
return super()._stepk
return self._original.stepk

@property
@implements(_galsim.interpolatedimage.InterpolatedImage.x_interpolant)
Expand Down Expand Up @@ -728,7 +728,7 @@ def _getSimpleStepK(self, R):
# Add xInterp range in quadrature just like convolution:
R2 = self._x_interpolant.xrange
R = jnp.hypot(R, R2)
stepk = jnp.pi / R
stepk = jnp.pi / (R * self._wcs._minScale())
return stepk

def _getMaxK(self, calculate_maxk):
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def _flux_frac(a, x, y, cenx, ceny):
dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1))
d = jnp.arange(a.shape[0])
d = jnp.reshape(d, (1, 1, -1))
msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d)
msk = (jnp.abs(dx) <= d) & (jnp.abs(dy) <= d)
Comment thread
beckermr marked this conversation as resolved.
res = jnp.sum(
jnp.where(
msk,
Expand All @@ -1184,7 +1184,6 @@ def _flux_frac(a, x, y, cenx, ceny):
),
axis=(0, 1),
)
res = jnp.where(res > 0, res, -jnp.inf)
return res


Expand All @@ -1193,28 +1192,20 @@ def _calculate_size_containing_flux(image, thresh):
cenx, ceny = image.center.x, image.center.y
x, y = image.get_pixel_centers()
fluxes = _flux_frac(image.array, x, y, cenx, ceny)
msk = fluxes >= -jnp.inf
fluxes = jnp.where(msk, fluxes, jnp.max(fluxes))
# msk = fluxes >= -jnp.inf
# fluxes = jnp.where(msk, fluxes, jnp.max(fluxes))
d = jnp.arange(image.array.shape[0]) + 1.0
Comment thread
beckermr marked this conversation as resolved.
Outdated
# below we use a linear interpolation table to find the maximum size
# in pixels that contains a given flux (called thresh here)
# expfac controls how much we oversample the interpolation table
# in order to return a more accurate result
# we have it hard coded at 4 to compromise between speed and accuracy
expfac = 4.0
dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0
fluxes = jnp.interp(dint, d, fluxes)
msk = fluxes <= thresh
p = jnp.sign(thresh)
msk = (p * fluxes) >= (p * thresh)
return (
jnp.argmax(
jnp.argmin(
jnp.where(
msk,
dint,
-jnp.inf,
d,
jnp.inf,
)
)
/ expfac
+ 1.0
+ 0.5
)


Expand Down
151 changes: 151 additions & 0 deletions tests/jax/test_interpolatedimage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,154 @@ def test_interpolatedimage_flux_frac():
rtol=0,
atol=1e-6,
)


@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"])
@pytest.mark.parametrize("normalization", ["sb", "flux"])
@pytest.mark.parametrize("use_true_center", [True, False])
@pytest.mark.parametrize(
"wcs",
[
_galsim.PixelScale(0.2),
_galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23),
_galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)),
],
)
@pytest.mark.parametrize(
"offset_x",
[
-4.35,
-0.45,
0.0,
0.67,
3.78,
],
)
@pytest.mark.parametrize(
"offset_y",
[
-2.12,
-0.33,
0.0,
0.12,
1.45,
],
)
@pytest.mark.parametrize(
"ref_array",
[
_galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array,
_galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array,
],
)
def test_interpolatedimage_utils_comp_stepk_to_galsim(
ref_array,
offset_x,
offset_y,
wcs,
use_true_center,
normalization,
x_interp,
):
seed = max(
abs(
int(
hashlib.sha1(
f"{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode(
"utf-8"
)
).hexdigest(),
16,
)
)
% (10**7),
1,
)

rng = np.random.RandomState(seed=seed)
if rng.uniform() < 0.75:
pytest.skip(
"Skipping `test_interpolatedimage_utils_comp_stepk_to_galsim` case at random to save time."
)

nse = rng.uniform(size=ref_array.shape) * ref_array.max() * 0.05

gimage_in = _galsim.Image(ref_array + nse, scale=0.2)
jgimage_in = jax_galsim.Image(ref_array + nse, scale=0.2)

np.testing.assert_allclose(gimage_in.center.x, jgimage_in.center.x)
np.testing.assert_allclose(gimage_in.center.y, jgimage_in.center.y)

gii = _galsim.InterpolatedImage(
gimage_in,
wcs=wcs,
offset=_galsim.PositionD(offset_x, offset_y),
use_true_center=use_true_center,
normalization=normalization,
x_interpolant=x_interp,
flux=20,
)
jgii = jax_galsim.InterpolatedImage(
jgimage_in,
wcs=jax_galsim.BaseWCS.from_galsim(wcs),
offset=jax_galsim.PositionD(offset_x, offset_y),
use_true_center=use_true_center,
normalization=normalization,
x_interpolant=x_interp,
flux=20,
)

gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux
gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh)

from jax_galsim.interpolatedimage import _calculate_size_containing_flux

jgthresh = (
1.0 - jgii._original.gsparams.folding_threshold
) * jgii._original._image_flux
jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh)

lgR = _galsim_stepk_loop(gii._image, gthresh)
ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh)

np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x)
np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y)
np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0))
np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum())
np.testing.assert_allclose(gthresh, jgthresh, rtol=0, atol=1e-6)
np.testing.assert_allclose(gR, jgR, rtol=0, atol=1e-6)
np.testing.assert_allclose(gR, ljgR, rtol=0, atol=1e-6)
np.testing.assert_allclose(lgR, jgR, rtol=0, atol=1e-6)

np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0, atol=1e-6)


# this is a copy of the galsim C++ algorithm in a pure python
# loop to help with debugging and testing
def _galsim_stepk_loop(im, target_flux):
if target_flux > 0:
p = 1.0
else:
p = -1.0

b = im.bounds
dmax = int(min((b.getXMax() - b.getXMin()) / 2, (b.getYMax() - b.getYMin()) / 2))

flux = im(0, 0)
d = 1
while d <= dmax:
# Add the left, right, top and bottom sides of box
for x in range(-d, d):
# Note: All 4 corners are added exactly once by including x=-d but omitting
# x=d from the loop.
flux += im(x, -d) # bottom
flux += im(d, x) # right
flux += im(-x, d) # top
flux += im(-d, -x) # left

if p * flux >= p * target_flux:
break

d += 1

return d + 0.5
Loading