Skip to content
1 change: 1 addition & 0 deletions devel/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
failed*.fits
2 changes: 2 additions & 0 deletions devel/roman/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ stars_13906_test_v251105.parquet
Roman_SRR_WFC_Pupil_Mask_Shortwave_2048_reformatted.fits.gz
*.png
*.piff
configs
output
1 change: 1 addition & 0 deletions piff/convolvepsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _finish_write(self, writer, logger):
'last_delta_chisq' : self.last_delta_chisq,
'dof' : self.dof,
'nremoved' : self.nremoved,
'niter' : self.niter,
}
writer.write_struct('chisq', chisq_dict)
logger.debug("Wrote the chisq info to extension %s", writer.get_full_name('chisq'))
Expand Down
2 changes: 2 additions & 0 deletions piff/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,13 @@ class InputFiles(Input):
:sky: The sky level to subtract from the image values. [default: None]
Note: It is an error to specify both sky and sky_col. If both are None,
no sky level will be subtracted off.

.. note::
The special value sky = 'median' means to compute the median of
the image and use that as the sky level. Any other string value
(rather than a float) indicates to get the value from the FITS
header.

:gain: The gain to use for adding Poisson noise to the weight map. [default:
None] It is an error for both gain and gain_col to be specified.
If both are None, then no additional noise will be added to account
Expand Down
4 changes: 3 additions & 1 deletion piff/pixelgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def fit(self, star, logger=None, convert_func=None, draw_method=None):
# covariance of dp is C = (AT A)^-1
# params_var = diag(C)
try:
params_var = np.diagonal(scipy.linalg.inv(star1.fit.A.T.dot(star1.fit.A)))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
params_var = np.diagonal(scipy.linalg.inv(star1.fit.A.T.dot(star1.fit.A)))
except np.linalg.LinAlgError as e:
# If we get an error, set the variance to "infinity".
logger.verbose("Caught error %s making params_var. Setting all to 1.e100",e)
Expand Down
82 changes: 65 additions & 17 deletions piff/roman/roman_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import galsim.roman
import numpy as np
import scipy.linalg
import copy

from ..interp import Interp
from ..model import Model
from ..outliers import Outliers
from ..psf import PSF
from ..star import Star
from ..config import LoggerWrapper
from ..util import run_multi

# Global control of GalSim Roman pupil-plane resolution in getPSF calls.
# Kept module-level so tests can override to faster values (e.g. 8 or 16).
Expand Down Expand Up @@ -146,13 +148,15 @@ def __init__(
max_zernike=22,
aberration_interp='constant',
aberration_prior_sigma=0.05,
nproc=1,
logger=None,
):
self.logger = logger
self.filter = filter
self.chromatic = chromatic
self.max_zernike = int(max_zernike)
self.aberration_interp = str(aberration_interp)
self.nproc = int(nproc)
self.set_num(None)

if self.max_zernike < 4 or self.max_zernike > 22:
Expand All @@ -175,10 +179,17 @@ def __init__(
'max_zernike': self.max_zernike,
'aberration_interp': self.aberration_interp,
'aberration_prior_sigma': self.prior_sigma,
'nproc': self.nproc,
}
self.sca_size = float(galsim.roman.n_pix)
self.clear_cache()

def __getstate__(self):
# Do not pickle logger instances for multiprocessing jobs.
state = dict(self.__dict__)
state['logger'] = None
return state

@property
def param_len(self):
return self.max_zernike - 3
Expand All @@ -189,7 +200,6 @@ def initialize_iteration(self):

def clear_cache(self):
self._corner_cache = {}
self._sca_wcs = {}

def _make_extra_aberrations(self, params):
# GalSim expects extra_aberrations indexed by Zernike number. Our parameter vector
Expand Down Expand Up @@ -262,7 +272,24 @@ def fit(self, star, logger=None, convert_func=None, draw_method=None):
draw_method=draw_method,
)[0]

@staticmethod
def _fit_sca_group_worker(model, sca, stars, convert_funcs, draw_method, logger):
fit_group, mean_params, corner_profiles = model._fit_sca_group(
sca,
stars,
convert_funcs=convert_funcs,
logger=logger,
draw_method=draw_method,
)
return sca, fit_group, mean_params, corner_profiles

def fit_many(self, stars, logger=None, convert_funcs=None, draw_method=None):
logger = LoggerWrapper(logger)

if len(stars) == 0:
self._last_sca_mean = {}
return []

if convert_funcs is None:
convert_funcs = [None] * len(stars)
elif len(convert_funcs) != len(stars):
Expand All @@ -277,20 +304,37 @@ def fit_many(self, stars, logger=None, convert_funcs=None, draw_method=None):
grouped[sca]['stars'].append(star)
grouped[sca]['convert_funcs'].append(convert_func)

out = [None] * len(stars)
sca_mean = {}
args = []
for sca, group in grouped.items():
fit_group, mean_params = self._fit_sca_group(
sca,
group['stars'],
logger=logger,
draw_method=draw_method,
convert_funcs=group['convert_funcs'],
# Keep multiprocessing payload small by only sending the relevant cache entry
# for this SCA to each worker.
worker_model = copy.copy(self)
worker_model._corner_cache = (
{sca: self._corner_cache[sca]} if sca in self._corner_cache else {}
)
args.append(
(worker_model, sca, group['stars'], group['convert_funcs'], draw_method)
)

fit_results = run_multi(
self._fit_sca_group_worker,
self.nproc,
raise_except=True,
args=args,
logger=logger,
)

out = [None] * len(stars)
sca_mean = {}
for fit_result in fit_results:
sca, fit_group, mean_params, corner_profiles = fit_result
indices = grouped[sca]['indices']
sca_mean[sca] = mean_params
wcs = stars[indices[0]].image.wcs
self._corner_cache[sca] = (mean_params, wcs, corner_profiles)
# Output the stars in the same order as input, so they stay matched with
# convert_funcs array if there is one.
for i, star in zip(group['indices'], fit_group):
for i, star in zip(indices, fit_group):
out[i] = star
self._last_sca_mean = sca_mean
return out
Expand Down Expand Up @@ -367,7 +411,7 @@ def _fit_sca_group(self, sca, stars, convert_funcs, logger=None, draw_method=Non
),
)
)
return out, sca_params
return out, sca_params, corner_sca_profiles

def draw(self, star, copy_image=True):
params = star.fit.get_params(self._num)
Expand Down Expand Up @@ -402,14 +446,18 @@ def _draw_model_image_from_corners(self, star, corner_profiles, convert_func=Non
def _get_corner_profiles(self, star, params, cache=True, sca=None):
if sca is None:
sca = _get_sca(star)
wcs = star.image.wcs
if sca in self._corner_cache:
cached_params, cached_profiles = self._corner_cache[sca]
if np.array_equal(cached_params, params):
cached_params, cached_wcs, cached_profiles = self._corner_cache[sca]
same_wcs = cached_wcs is wcs
if not same_wcs:
try:
same_wcs = (cached_wcs == wcs)
except Exception:
same_wcs = False
if same_wcs and np.array_equal(cached_params, params):
return cached_profiles

if sca not in self._sca_wcs:
self._sca_wcs[sca] = star.data.local_wcs
wcs = self._sca_wcs[sca]
wavelength = None if self.chromatic else self.bandpass.effective_wavelength
corners = (
galsim.PositionD(0.0, 0.0),
Expand All @@ -431,7 +479,7 @@ def _get_corner_profiles(self, star, params, cache=True, sca=None):
for corner, p in zip(corners, corner_params)
)
if cache:
self._corner_cache[sca] = (params, profiles)
self._corner_cache[sca] = (params, wcs, profiles)
return profiles

def _interpolate_corners(self, star, corner_profiles):
Expand Down
8 changes: 8 additions & 0 deletions piff/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,14 @@ def rejectStars(self, stars, logger=None):
logger.debug("star properties = %s",star.data.properties)
continue

# If doing any selections that rely on hsm, reject any that fail hsm measurement.
if self.hsm_size_reject != 0 or self.max_pixel_cut is not None:
flag = star.hsm[6]
if flag != 0:
logger.verbose("Skipping star at position %f,%f because hsm failed",
star.image_pos.x, star.image_pos.y)
continue

# Add Poisson noise now. It's not a rejection step, but it's something we want
# to do to all the stars at the start, so they have the right noise level.
# We didn't do it earlier for efficiency reasons, in case the full set of objects
Expand Down
6 changes: 5 additions & 1 deletion piff/star_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def compute(self, psf, stars, logger=None):
if not s.is_reserve and (self.include_flagged or not s.is_flagged)]
possible_indices = sorted(possible_indices)

if self.nplot == 0 or self.nplot >= len(stars):
if self.nplot == 0 or self.nplot >= len(possible_indices):
# select all viable stars
self.indices = possible_indices
else:
Expand Down Expand Up @@ -167,12 +167,16 @@ def plot(self, logger=None, **kwargs):
index = self.indices[i]
u = star.data.properties['u']
v = star.data.properties['v']
x = star.data.properties['x']
y = star.data.properties['y']
chipnum = star.data.properties['chipnum']

title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
title += f'\n({chipnum}, {x:.0f}, {y:.0f})'
Comment thread
HyeongHan marked this conversation as resolved.
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
axs[ii][jj+2].set_title('Star - PSF')
Expand Down
1 change: 1 addition & 0 deletions piff/sumpsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def _finish_write(self, writer, logger):
'last_delta_chisq' : self.last_delta_chisq,
'dof' : self.dof,
'nremoved' : self.nremoved,
'niter' : self.niter,
}
writer.write_struct('chisq', chisq_dict)
logger.debug("Wrote the chisq info to %s", writer.get_full_name('chisq'))
Expand Down
3 changes: 2 additions & 1 deletion piff/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def run_multi(func, nproc, raise_except, args, logger, kwargs=None):

def log_output(result): # pragma: no cover (It is covered, but in an async process.)
i, out, log = result
logger.verbose(log)
if log.strip():
logger.verbose(log)
if isinstance(out, Exception):
logger.warning("Caught exception in multiprocessing job: %r",out)
err_list[i] = out
Expand Down
1 change: 1 addition & 0 deletions tests/input/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
test_input*.fits
test_input*.fits.fz
test_input_cat*
1 change: 1 addition & 0 deletions tests/test_convolvepsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_trivial_convolve1():
assert psf.chisq_thresh == 0.2
assert psf.max_iter == 10
assert psf.min_iter == 2
assert psf.niter > 0

for i, star in enumerate(psf.stars):
target = targets[i]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ def test_stars():
select = piff.FlagSelect(config['select'], logger=logger)
stars = input.makeStars(logger=logger)
stars = select.rejectStars(stars, logger=logger)
assert len(stars) == 90
assert len(stars) == 89
del config['select']['max_snr_weight']
del config['select']['max_pixel_cut']

Expand All @@ -1717,14 +1717,14 @@ def test_stars():
select = piff.FlagSelect(config['select'], logger=logger)
stars = input.makeStars(logger=logger)
stars = select.rejectStars(stars, logger=logger)
assert len(stars) == 88
assert len(stars) == 87

# hsm_size_reject can also be a float. (True is equivalent to 10.)
config['select']['hsm_size_reject'] = 100.
select = piff.FlagSelect(config['select'], logger=logger)
stars = input.makeStars(logger=logger)
stars = select.rejectStars(stars, logger=logger)
assert len(stars) == 90
assert len(stars) == 89
config['select']['hsm_size_reject'] = 3.
select = piff.FlagSelect(config['select'], logger=logger)
stars = input.makeStars(logger=logger)
Expand All @@ -1734,7 +1734,7 @@ def test_stars():
select = piff.FlagSelect(config['select'], logger=logger)
stars = input.makeStars(logger=logger)
stars = select.rejectStars(stars, logger=logger)
assert len(stars) == 88
assert len(stars) == 87
del config['select']['hsm_size_reject']

# alt_x and alt_y also include some object completely off the image, which are always skipped.
Expand Down
64 changes: 63 additions & 1 deletion tests/test_roman.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_roman_corner_cache():
assert profiles1 is profiles2
assert len(psf.model._corner_cache) == 1
assert 5 in psf.model._corner_cache
assert len(psf.model._corner_cache[5][1]) == 4
assert len(psf.model._corner_cache[5][2]) == 4


@timer
Expand Down Expand Up @@ -459,6 +459,68 @@ def test_roman_fit_many():
np.testing.assert_allclose(p0, truth_params, atol=0.0, rtol=1.e-3)


@timer
def test_roman_fit_many_nproc():
"""Check `fit_many` multiprocessing path and accuracy with `nproc > 1`.
"""
with fast_pupil_bin():
model = piff.roman.RomanOpticalModel(
filter='H158',
chromatic=False,
max_zernike=6,
aberration_prior_sigma=1.0e6,
nproc=2,
)
stars = [
piff.Star.makeTarget(
x=64.2,
y=64.1,
stamp_size=25,
scale=0.11,
properties={'sca': 5},
).withFlux(1.0, (0.0, 0.0)),
piff.Star.makeTarget(
x=171.8,
y=162.7,
stamp_size=25,
scale=0.11,
properties={'sca': 5},
).withFlux(1.0, (0.0, 0.0)),
]
stars = [model.initialize(s) for s in stars]
truth_params = np.array([0.004, -0.003, 0.005])
truth = [
model.draw(
piff.Star(
s.data,
s.fit.newParams(
truth_params,
params_var=np.zeros_like(truth_params),
),
)
)
for s in stars
]
fit_stars = [
piff.Star(
s.data,
stars[i].fit.newParams(
np.zeros_like(truth_params),
params_var=np.zeros_like(truth_params),
),
)
for i, s in enumerate(truth)
]

for _ in range(3):
fit_stars = model.fit_many(fit_stars)

assert [int(s['sca']) for s in fit_stars] == [5, 5]
for s in fit_stars:
np.testing.assert_allclose(s.fit.params, truth_params, atol=0.0, rtol=2.e-3)
assert list(model._corner_cache.keys()) == [5]


@timer
def test_roman_optics_convert_funcs():
"""Check aberration recovery when fitting with a nontrivial convert_func (profile shear).
Expand Down
Loading
Loading