Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b30096f
Add mokey-patch to get the bandpass option to SED.thin from GalSim PR…
rmjarvis Mar 27, 2026
b237700
Add input.bandpass option and store in PSF class with wcs and pointing.
rmjarvis Mar 30, 2026
82267c8
Add PSF.set_context as a cleaner pattern for setting wcs, pointing, b…
rmjarvis Mar 30, 2026
c387637
Store seds as unit-flux normalized, which is normally what we'll want.
rmjarvis Mar 30, 2026
7fcf9eb
Add sed_tol option
rmjarvis Mar 30, 2026
7bbdcca
Have Roman PSF (normally) get filter_name from input bandpass.
rmjarvis Mar 31, 2026
fdb9f12
Construct effective sed = sed*bandpass to thin and store
rmjarvis Mar 31, 2026
a1ec916
Add a real test of the accuracy of chromatic RomanPSF fitting
rmjarvis Mar 31, 2026
933e6ce
Need to use Add, not Sum when components can be chromatic
rmjarvis Mar 31, 2026
b60a3d6
Make reflux work correctly for chromatic objects
rmjarvis Apr 1, 2026
990b44a
Use a picklable flat_bandpass
rmjarvis Apr 1, 2026
022545e
Update piff.yaml for dev run with separate files for chrom/achrom
rmjarvis Apr 1, 2026
94f2299
Remove galsim_patch.py, since no longer necessary after 3be35999
rmjarvis Apr 1, 2026
56e2ddf
Both OpticalModel and RomanOpticalModel should have _centered=True
rmjarvis Apr 1, 2026
30815fb
Add test of chromatic psf using reflux
rmjarvis Apr 1, 2026
766f158
Add sed_max_samples option
rmjarvis Apr 2, 2026
312cb3e
this is the best chromatic run so far
rmjarvis Apr 3, 2026
921ba6f
typos
rmjarvis Apr 3, 2026
fb6e1d2
coverage
rmjarvis Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions piff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
# Also let piff.version show the version.
version = __version__

from . import galsim_patch

# We don't have any C functions, but once we do, I recommend using cffi to
# wrap them. This is the entire code we need to get C functions into Python.
if False:
Expand Down
6 changes: 3 additions & 3 deletions piff/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def process(config, logger=None):
config['select'][key] = config['input'].pop(key)

# read in the input images
objects, wcs, pointing = Input.process(config['input'], logger=logger)
objects, wcs, pointing, bandpass = Input.process(config['input'], logger=logger)
stars = Select.process(config.get('select',{}), objects, logger=logger)

psf = PSF.process(config['psf'], logger=logger)
psf.fit(stars, wcs, pointing, logger=logger)
psf = PSF.process(config['psf'], wcs, pointing, bandpass, logger=logger)
psf.fit(stars, logger=logger)

# Attach these for reference
psf.initial_objects = objects
Expand Down
92 changes: 92 additions & 0 deletions piff/galsim_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
import galsim
from astropy import units
from galsim import utilities
from galsim.table import LookupTable, _LookupTable
from galsim.sed import SED
from galsim.errors import GalSimError

# Temporary compatibility patch for GalSim < 2.9.
# Remove this module once Piff can depend on GalSim 2.9+, which includes thin(bandpass=...).

def _piff_sed_thin(self, rel_err=1.e-4, trim_zeros=True, preserve_range=True, fast_search=True,
bandpass=None): # pragma: no cover
def _bandpass_native_waves(waves):
# The thinning algorithm runs on the SED's native tabulation grid, but when a
# bandpass target is provided we need to evaluate the throughput at the same physical
# observed wavelengths. This helper converts SED-native wavelengths into the native
# wavelength units used by bandpass._tp.
if self.wave_factor:
observed_nm = waves / self.wave_factor
else:
observed_nm = (waves * self.wave_type).to(units.nm, units.spectral()).value
observed_nm *= (1.0 + self.redshift)
if bandpass.wave_factor:
return observed_nm * bandpass.wave_factor
else:
return (observed_nm * units.nm).to(bandpass.wave_type, units.spectral()).value

if bandpass is not None:
if self.blue_limit > bandpass.red_limit or self.red_limit < bandpass.blue_limit:
raise GalSimError("Bandpass does not overlap the SED wavelength range.")
wave_list, _, _ = utilities.combine_wave_list(self, bandpass)
if preserve_range and not trim_zeros:
# If we want to preserve the range, add back the limits to each end.
front = [self.blue_limit] if self.blue_limit < wave_list[0] else []
back = [self.red_limit] if self.red_limit > wave_list[-1] else []
if front or back:
wave_list = np.concatenate((front, wave_list, back))
else:
wave_list = self.wave_list

if len(wave_list) > 0:
rest_wave_native = self._get_rest_native_waves(wave_list)
spec_native = self._spec(rest_wave_native)

if bandpass is not None:
# Identify the overlapping region in the SED's native wavelength units to
# determine which portion of the wave_list is within the bandpass limits.
band_native_limits = self._get_rest_native_waves(
np.array([bandpass.blue_limit, bandpass.red_limit])
)
native_blue_limit = np.min(band_native_limits)
native_red_limit = np.max(band_native_limits)
in_band = np.logical_and(
rest_wave_native >= native_blue_limit,
rest_wave_native <= native_red_limit,
)

# Compute the product SED * bandpass, since that is the quantity whose integral we
# want to preserve for this observation.
tp_native = np.zeros_like(rest_wave_native, dtype=float)
bp_wave_native = _bandpass_native_waves(rest_wave_native[in_band])
tp_native[in_band] = bandpass._tp(bp_wave_native) / bandpass.wave_factor
spec_native *= tp_native

# Note that this is thinning in native units, not nm and photons/nm.
interpolant = (self.interpolant if not isinstance(self._spec, LookupTable)
else self._spec.interpolant)
newx, newf = utilities.thin_tabulated_values(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be renamed as thinned_lambda thinned_flux but if this module is going to be removed anyway please ignore this comment

rest_wave_native, spec_native, rel_err=rel_err,
trim_zeros=trim_zeros, preserve_range=preserve_range,
fast_search=fast_search, interpolant=interpolant)

if bandpass is not None:
# Convert the thinned product back into an SED by dividing out the bandpass
# wherever the throughput is non-zero.
in_band = np.logical_and(newx >= native_blue_limit, newx <= native_red_limit)
tp_native = np.zeros_like(newx, dtype=float)
bp_wave_native = _bandpass_native_waves(newx[in_band])
tp_native[in_band] = bandpass._tp(bp_wave_native) / bandpass.wave_factor
nz = tp_native != 0. # Don't divide by 0.
assert np.all(newf[~nz] == 0.)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an temporary compatibility patch, but for clarity an error message can be added something like "thinned sed has non-zero values where throughput is zero"

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments on this file should probably be migrated to GalSim-developers/GalSim#1355.
(I no longer have this patch in the PR.) But also, I don't think it's possible for this assert to trigger. That's why it's an assert, not a normal error.

newf[nz] /= tp_native[nz]

newspec = _LookupTable(newx, newf, interpolant=interpolant)
return SED(newspec, self.wave_type, self.flux_type, redshift=self.redshift,
fast=self.fast)
else:
return self


galsim.SED.thin = _piff_sed_thin
42 changes: 37 additions & 5 deletions piff/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def process(cls, config_input, logger=None):
:param config_input: The configuration dict.
:param logger: A logger object for logging debug info. [default: None]

:returns: stars, wcs, pointing
:returns: stars, wcs, pointing, bandpass

stars is a list of Star instances with the initial data.
wcs is a dict of WCS solutions indexed by chipnum.
pointing is either a galsim.CelestialCoord or None.
bandpass is either a galsim.Bandpass or None.
"""
# Get the class to use for handling the input data
# Default type is 'Files'
Expand All @@ -78,7 +79,10 @@ def process(cls, config_input, logger=None):
# Get the pointing (the coordinate center of the field of view)
pointing = input_handler.getPointing(logger)

return stars, wcs, pointing
# Get the observing bandpass, if any
bandpass = input_handler.getBandpass(logger)

return stars, wcs, pointing, bandpass

@classmethod
def __init_subclass__(cls):
Expand Down Expand Up @@ -148,6 +152,15 @@ def getPointing(self, logger=None):
"""
return self.pointing

def getBandpass(self, logger=None):
"""Get the observing bandpass for the input images.

:param logger: A logger object for logging debug info. [default: None]

:returns: a galsim.Bandpass or None.
"""
return None


class InputFiles(Input):
"""An Input handler that just takes a list of image files and catalog files.
Expand Down Expand Up @@ -258,6 +271,10 @@ class InputFiles(Input):
SEDs from FITS tables. [default: ``WAVE``]
:sed_flux_key: FITS-table column name to use for flux when loading SEDs
from FITS tables. [default: ``FLUX``]
:bandpass: GalSim bandpass config dict describing the observing bandpass for the
input image(s). Required when using ``sed_col`` or ``sed_file_name``.
SED thinning will target this bandpass via
``sed.thin(..., bandpass=...)``. [default: None]

.. note::

Expand Down Expand Up @@ -435,7 +452,7 @@ def __init__(self, config, logger=None):
'noise' : str,
'nstars' : int,
}
ignore = [ 'nproc', 'nimages', 'ra', 'dec', 'wcs' ]
ignore = [ 'nproc', 'nimages', 'ra', 'dec', 'wcs', 'bandpass' ]

# We're going to change the config dict a bit. Make a copy so we don't mess up the
# user's original dict (in case they care).
Expand Down Expand Up @@ -675,6 +692,12 @@ def __init__(self, config, logger=None):
sed_flux_type = params.get('sed_flux_type', None)
sed_wave_key = params.get('sed_wave_key', 'WAVE')
sed_flux_key = params.get('sed_flux_key', 'FLUX')
if (sed_col is not None or sed_file_name is not None) and 'bandpass' not in config:
raise ValueError("bandpass is required when using sed_col or sed_file_name")
if 'bandpass' in config:
bandpass = galsim.config.BuildBandpass(config, 'bandpass', base, logger)[0]
else:
bandpass = None
sky_col = params.get('sky_col', None)
gain_col = params.get('gain_col', None)
gain = params.get('gain', None)
Expand Down Expand Up @@ -709,6 +732,7 @@ def __init__(self, config, logger=None):
'sed_flux_type': sed_flux_type,
'sed_wave_key': sed_wave_key,
'sed_flux_key': sed_flux_key,
'bandpass': bandpass,
'sky_col' : sky_col,
'gain_col' : gain_col,
'sky' : sky,
Expand All @@ -719,6 +743,7 @@ def __init__(self, config, logger=None):
'stamp_size' : self.stamp_size})

self.use_partial = config.get('use_partial', False)
self.bandpass = bandpass

# Read all the wcs's, since we'll need this for the pointing, which in turn we'll
# need for when we make the stars.
Expand All @@ -730,6 +755,9 @@ def __init__(self, config, logger=None):
self.setPointing(ra, dec, logger)
self.config = galsim.config.CleanConfig(config)

def getBandpass(self, logger=None):
return self.bandpass

def load_images(self, stars, logger=None):
"""Load the image data into a list of Stars.

Expand Down Expand Up @@ -1126,7 +1154,7 @@ def _parse_flux_type(flux_type):

@staticmethod
def _read_sed_file(sed_file_name, sed_wave_type, sed_flux_type,
sed_wave_key, sed_flux_key):
sed_wave_key, sed_flux_key, bandpass):
cache_key = (sed_file_name, sed_wave_type, sed_flux_type, sed_wave_key, sed_flux_key)
if cache_key in InputFiles._sed_cache:
return InputFiles._sed_cache[cache_key]
Expand Down Expand Up @@ -1187,6 +1215,7 @@ def readStarCatalog(cat_file_name, cat_hdu, x_col, y_col,
ra_col, dec_col, ra_units, dec_units, image,
flag_col, skip_flag, use_flag, property_cols, sed_col, sed_wave_type,
sed_file_name, sed_flux_type, sed_wave_key, sed_flux_key,
bandpass,
properties, image_num, sky_col, gain_col, sky, gain, satur,
trust_pos, nstars, stamp_size, config, logger):
"""Read in the star catalogs and return lists of positions for each star in each image.
Expand Down Expand Up @@ -1214,6 +1243,7 @@ def readStarCatalog(cat_file_name, cat_hdu, x_col, y_col,
:param sed_flux_type: Flux type for SED files.
:param sed_wave_key: FITS-table wavelength column name for SED files.
:param sed_flux_key: FITS-table flux column name for SED files.
:param bandpass: galsim.Bandpass for SED thinning and unit-flux normalization.
:param sky_col: A column with sky (background) levels.
:param gain_col: A column with gain values.
:param sky: Either a float value for the sky to use for all objects or a str
Expand Down Expand Up @@ -1369,6 +1399,7 @@ def safe_to_image(wcs, ra, dec):
sed_flux_type=sed_flux_type,
sed_wave_key=sed_wave_key,
sed_flux_key=sed_flux_key,
bandpass=bandpass,
))
extra_props['sed'] = np.array(sed_values, dtype=object)
extra_props['sed_file_name'] = np.array(sed_file_name_values, dtype=object)
Expand All @@ -1381,7 +1412,8 @@ def safe_to_image(wcs, ra, dec):
sed = InputFiles._read_sed_file(
sed_file_name, sed_wave_type=sed_wave_type,
sed_flux_type=sed_flux_type, sed_wave_key=sed_wave_key,
sed_flux_key=sed_flux_key
sed_flux_key=sed_flux_key,
bandpass=bandpass
)
extra_props['sed'] = np.array([sed] * len(cat), dtype=object)
extra_props['sed_file_name'] = np.array([sed_file_name] * len(cat), dtype=object)
Expand Down
44 changes: 34 additions & 10 deletions piff/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ class PSF(object):
# Normally overridden by subclasses. But this gives a reasonable default for
# tests that bypass places where it would be potentially set in normal operations.
degenerate_points = False
wcs = None
pointing = None
bandpass = None

@classmethod
def process(cls, config_psf, logger=None):
def process(cls, config_psf, wcs=None, pointing=None, bandpass=None, logger=None):
"""Process the config dict and return a PSF instance.

As the PSF class is an abstract base class, the returned type will in fact be some
Expand All @@ -56,14 +59,24 @@ def process(cls, config_psf, logger=None):
This function merely creates a "blank" PSF object. It does not actually do any
part of the solution yet. Typically this will be followed by fit:

>>> psf = piff.PSF.process(config['psf'])
>>> stars, wcs, pointing = piff.Input.process(config['input'])
>>> psf.fit(stars, wcs, pointing)
>>> stars, wcs, pointing, bandpass = piff.Input.process(config['input'])
>>> psf = piff.PSF.process(config['psf'], wcs, pointing, bandpass)
>>> psf.fit(stars)

at which point, the ``psf`` instance would have a solution to the PSF model.

.. note::

The preferred pattern now is to provide wcs and pointing here, but these used to
be set when calling fit. The old pattern is still supported, but deprecated.

:param config_psf: A dict specifying what type of PSF to build along with the
appropriate kwargs for building it.
:param wcs: A dict of WCS solutions indexed by chipnum.
:param pointing: A galsim.CelestialCoord object giving the telescope pointing.
[Note: pointing should be None if the WCS is not a CelestialWCS]
:param bandpass: Optional galsim.Bandpass shared by the input data.
[default: None]
:param logger: A logger object for logging debug info. [default: None]

:returns: a PSF instance of the appropriate type.
Expand Down Expand Up @@ -92,6 +105,9 @@ def process(cls, config_psf, logger=None):
# At top level, the num is always None.
# Composite PSF types will turn this into a series of integer values for each component.
psf.set_num(None)
psf.wcs = wcs
psf.pointing = pointing
psf.bandpass = bandpass

return psf

Expand Down Expand Up @@ -373,7 +389,8 @@ def remove_outliers(self, stars, iteration, logger):
nremoved = 0
return stars, nremoved

def fit(self, stars, wcs, pointing, logger=None, convert_funcs=None, draw_method=None):
def fit(self, stars, wcs=None, pointing=None, logger=None,
convert_funcs=None, draw_method=None):
"""Fit interpolated PSF model to star data using standard sequence of operations.

:param stars: A list of Star instances.
Expand All @@ -391,8 +408,10 @@ def fit(self, stars, wcs, pointing, logger=None, convert_funcs=None, draw_method
from .config import LoggerWrapper
logger = LoggerWrapper(logger)

self.wcs = wcs
self.pointing = pointing
if self.wcs is None:
logger.error("WARNING: wcs and pointing should now be given in process, not fit.")
Copy link
Copy Markdown
Collaborator

@HyeongHan HyeongHan Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's warning, shouldn't it be logger.warning?
If it's an error, the message should be ERROR.
self.pointing is None condition also be added?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All such messages in Piff are done through the logger mechanism, not python's warnings module. Given that, logger.error vs logger.warning dictates what verbosity level it shows up. logger.error's are always written, which I think makes sense for deprecation warnings. The user should fix it. Warnings that are merely about something going a little weird in the processing, but might be ok are logger.warning.

self.wcs = wcs
self.pointing = pointing

# Initialize stars as needed by the PSF modeling class.
stars = self.initialize_flux_center(stars, logger=logger)
Expand Down Expand Up @@ -766,9 +785,12 @@ def _write(self, writer, name, logger):
if hasattr(self, 'stars'):
Star.write(self.stars, w, 'stars')
logger.verbose("Wrote the PSF stars to name %s", w.get_full_name('stars'))
if hasattr(self, 'wcs'):
if self.wcs is not None:
w.write_wcs_map('wcs', self.wcs, self.pointing)
logger.verbose("Wrote the PSF WCS to name %s", w.get_full_name('wcs'))
if self.bandpass is not None:
w.write_bandpass('bandpass', self.bandpass)
logger.verbose("Wrote the PSF bandpass to name %s", w.get_full_name('bandpass'))
self._finish_write(w, logger=logger)

@classmethod
Expand Down Expand Up @@ -822,15 +844,17 @@ def _read(cls, reader, name, logger):

with reader.nested(name) as r:
# Read the stars, wcs, pointing values
wcs, pointing = r.read_wcs_map('wcs', logger=logger)
bandpass = r.read_bandpass('bandpass')
stars = Star.read(r, 'stars')
if stars is not None:
logger.debug("stars = %s", stars)
psf.stars = stars
wcs, pointing = r.read_wcs_map('wcs', logger=logger)
if wcs is not None:
logger.debug("wcs = %s, pointing = %s",wcs,pointing)
logger.debug("wcs = %s, pointing = %s, bandpass = %s", wcs, pointing, bandpass)
psf.wcs = wcs
psf.pointing = pointing
psf.bandpass = bandpass

# Just in case the class needs to do something else at the end.
psf._finish_read(r, logger)
Expand Down
Loading