Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/hirad/conf/eval_real.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ width: 1088
# If you don't want to evaluate all generated samples, you can specify a range of time steps to evaluate.
# This is useful for debugging or if you only want to evaluate a subset of the generated data.
# Make sure that generated samples are available for the specified time steps.
# times_range: ['20230601-0000','20230831-2300',1]
# Use times_ranges to combine multiple seasons/years into a single evaluation run.
times_ranges:
- ['20210601-0000', '20210831-2300', 1]
- ['20220601-0000', '20220831-2300', 1]
- ['20230601-0000', '20230831-2300', 1]

# List of channels to evaluate/plot - comment out if all channels are to be used
# plot_channels: ['2t', '10u', '10v', 'tp', 't_700', 'u_700', 'v_700', 'z_700', 'q_700', 'w_700']
14 changes: 4 additions & 10 deletions src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, concat_and_group_diurnal
from hirad.eval.eval_utils import resolve_times

def save_plot(hour, means, stds, labels, ylabel, title, out_path):
hrs = np.concatenate([hour.values, [24]])
Expand Down Expand Up @@ -58,16 +59,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting computations for diurnal cycle of precipitation amount and wet-hours")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
logger.info(f"Loaded {len(times)} timesteps to process")
Expand Down
59 changes: 28 additions & 31 deletions src/hirad/eval/diurnal_cycle_precip_p99.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import get_channel_indices, load_land_sea_mask


Expand Down Expand Up @@ -69,16 +70,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Loaded {len(times)} timesteps to process")

Expand Down Expand Up @@ -111,26 +105,26 @@ def main(cfg: dict):
data_list = []
try:
for ts in times:
data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor") * land_mask
data = torch.load(out_root/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor")
data_list.append(data)
except:
logger.error(f"Error loading data for mode {mode}. Skipping.")
continue

da = xr.DataArray(
np.stack(data_list, axis=0),
dims=['time', 'lat', 'lon'],
coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]}
coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times],
'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}
)

# Group by hour and compute 99th percentile
hourly_p99 = da.groupby('time.hour').quantile(0.99, dim='time')

# Apply scaling factor for baseline
# if mode == 'baseline':
# hourly_p99 = hourly_p99 / 6.0

pct99_mean[mode] = hourly_p99.mean(dim=['lat', 'lon'])

# Select only land pixels to avoid all-NaN slices in quantile
land_bool = land_mask.notnull().stack(space=('lat', 'lon'))
da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values)

# Group by hour and compute 99th percentile over time, then spatial mean
hourly_p99 = da_land.groupby('time.hour').quantile(0.99, dim='time')
pct99_mean[mode] = hourly_p99.mean(dim='space')

# -- Predictions: compute per hour per member, then mean+std across members --
logger.info("Processing predictions")
Expand All @@ -139,21 +133,24 @@ def main(cfg: dict):
pred_data_list = []
for ts in times:
preds = torch.load(out_root/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon]
# Extract precipitation channel and convert to xarray for proper broadcasting
tp_data = preds[:, tp_out] # [n_members, lat, lon]
tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'])
pred_data_list.append(tp_da * land_mask) # apply mask

pred_da = xr.concat(pred_data_list, dim='time') # [n_members, time, lat, lon]
tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'],
coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']})
pred_data_list.append(tp_da)

pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon]
pred_da = pred_da.assign_coords({
'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
})
# Transpose to get the expected dimension order: [member, time, lat, lon]
pred_da = pred_da.transpose('member', 'time', 'lat', 'lon')


# Select only land pixels to avoid all-NaN slices in quantile
land_bool = land_mask.notnull().stack(space=('lat', 'lon'))
pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values)

logger.info('Calculating 99th percentile for predictions')
# Group by hour, compute 99th percentile across time, then spatial mean
hourly_p99_by_member = pred_da.groupby('time.hour').quantile(0.99, dim='time').mean(dim=['lat', 'lon'])
# Group by hour, compute 99th percentile across time, then spatial mean over land
hourly_p99_by_member = pred_da_land.groupby('time.hour').quantile(0.99, dim='time').mean(dim='space')

# Store ensemble statistics as xarray DataArrays
pct99_mean['prediction'] = hourly_p99_by_member.mean(dim='member')
Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/diurnal_cycle_temp_wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import get_channel_indices, load_land_sea_mask, concat_and_group_diurnal

def main(cfg: dict):
Expand All @@ -38,16 +39,9 @@ def main(cfg: dict):

# Load times
logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
logger.info(f"Loaded {len(times)} timesteps to process")
Expand Down
29 changes: 29 additions & 0 deletions src/hirad/eval/eval_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,33 @@
import numpy as np
from typing import Optional

from hirad.utils.function_utils import get_time_from_range


def resolve_times(cfg: dict, gen_cfg: dict, time_format: str = "%Y%m%d-%H%M") -> Optional[list]:
"""Resolve the list of timestep strings from eval or generation config.

Priority order (both in eval cfg and generation cfg fallback):
1. ``times_ranges`` – list of [start, end, step] ranges, concatenated.
2. ``times_range`` – single [start, end, step] range.
3. ``times`` – explicit list of strings.

Returns ``None`` when no time specification is found in either config.
"""
def _from_cfg(source: dict) -> Optional[list]:
if source.get("times_ranges"):
times = []
for tr in source["times_ranges"]:
times.extend(get_time_from_range(tr, time_format=time_format))
return times
if source.get("times_range"):
return get_time_from_range(source["times_range"], time_format=time_format)
if source.get("times"):
return source["times"]
return None

return _from_cfg(cfg) or _from_cfg(gen_cfg.get("generation", {}))


def percentiles_from_histogram(hist_counts, bin_edges, percentiles_dict):
"""
Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import get_channel_indices, load_land_sea_mask
from hirad.eval.eval_utils import percentiles_from_histogram

Expand Down Expand Up @@ -124,16 +125,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting computation for domain-mean precipitation distribution over land")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Loaded {len(times)} timesteps to process")

Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/map_precip_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import (
plot_map_precipitation, plot_map, get_channel_indices, GridConfig
)
Expand Down Expand Up @@ -166,16 +167,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting precipitation statistics generation")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Processing {len(times)} timesteps")

Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/map_wind_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import plot_map, get_channel_indices, GridConfig


Expand Down Expand Up @@ -230,16 +231,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting wind statistics generation")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Processing {len(times)} timesteps")

Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/probability_of_exceedance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import get_channel_indices, load_land_sea_mask
from hirad.eval.eval_utils import percentiles_from_histogram

Expand Down Expand Up @@ -121,16 +122,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting computation for probability of exceedance over land")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Loaded {len(times)} timesteps to process")

Expand Down
14 changes: 4 additions & 10 deletions src/hirad/eval/probability_of_exceedance_wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
from hirad.utils.function_utils import get_time_from_range
from hirad.eval.eval_utils import resolve_times
from hirad.eval.plotting import get_channel_indices
from hirad.eval.eval_utils import percentiles_from_histogram

Expand Down Expand Up @@ -144,16 +145,9 @@ def main(cfg: dict):
gen_cfg = yaml.safe_load(f)

logger.info("Starting computation for probability of exceedance for wind speed")
if cfg.get("times_range", None):
times = get_time_from_range(cfg.get("times_range"), time_format="%Y%m%d-%H%M")
elif cfg.get("times", None):
times = cfg.get("times")
elif gen_cfg.get("generation").get("times_range", None):
times = get_time_from_range(gen_cfg.get("generation").get("times_range"), time_format="%Y%m%d-%H%M")
elif gen_cfg.get("generation").get("times", None):
times = gen_cfg.get("generation").get("times")
else:
logger.error("No times or times_range specified in config or generation config.")
times = resolve_times(cfg, gen_cfg)
if times is None:
logger.error("No times, times_range, or times_ranges specified in config or generation config.")
return
logger.info(f"Loaded {len(times)} timesteps to process")

Expand Down