-
Notifications
You must be signed in to change notification settings - Fork 9
AKR-K-Band Scores #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
AKR-K-Band Scores #119
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ def es_ensemble( | |
| ens_w: "Array" = None, | ||
| estimator: str = "nrg", | ||
| backend: "Backend" = None, | ||
| **kwargs, | ||
| ) -> "Array": | ||
| r"""Compute the Energy Score for a finite multivariate ensemble. | ||
|
|
||
|
|
@@ -71,18 +72,39 @@ def es_ensemble( | |
| Some theoretical background on scoring rules for multivariate forecasts. | ||
| """ | ||
| obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend) | ||
|
|
||
| if estimator == "akr_kband": | ||
| k = kwargs.get("k", 1) | ||
|
|
||
| if ens_w is None: | ||
| if backend == "numba": | ||
| estimator_check(estimator, energy.estimator_gufuncs) | ||
| return energy.estimator_gufuncs[estimator](obs, fct) | ||
| if estimator == "akr_kband": | ||
| return energy.estimator_gufuncs[estimator](obs, fct, k) | ||
| else: | ||
| return energy.estimator_gufuncs[estimator](obs, fct) | ||
|
Comment on lines
+76
to
+85
|
||
| else: | ||
| if estimator == "akr_kband": | ||
| return energy.es(obs, fct, estimator=estimator, backend=backend, k=k) | ||
| return energy.es(obs, fct, estimator=estimator, backend=backend) | ||
| else: | ||
| ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend) | ||
| if backend == "numba": | ||
| estimator_check(estimator, energy.estimator_gufuncs_w) | ||
| return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w) | ||
| if estimator == "akr_kband": | ||
| return energy.estimator_gufuncs_w[estimator](obs, fct, k, ens_w) | ||
| else: | ||
| return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w) | ||
| else: | ||
| if estimator == "akr_kband": | ||
| return energy.es_w( | ||
| obs, | ||
| fct, | ||
| ens_w, | ||
| estimator=estimator, | ||
| backend=backend, | ||
| k=k, | ||
| ) | ||
| return energy.es_w(obs, fct, ens_w, estimator=estimator, backend=backend) | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -83,6 +83,22 @@ def _energy_score_akr_circperm_gufunc( | |||||||||||||
| out[0] = e_1 / M - 0.5 * 1 / M * e_2 | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @guvectorize("(d),(m,d),()->()") | ||||||||||||||
| def _energy_score_akr_kband_gufunc( | ||||||||||||||
| obs: np.ndarray, fct: np.ndarray, k: int, out: np.ndarray | ||||||||||||||
| ): | ||||||||||||||
| """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" | ||||||||||||||
| M = fct.shape[0] | ||||||||||||||
| e_1 = 0.0 | ||||||||||||||
| e_2 = 0.0 | ||||||||||||||
| for i in range(M): | ||||||||||||||
| e_1 += float(np.linalg.norm(fct[i] - obs)) | ||||||||||||||
| for j in range(1, k + 1): | ||||||||||||||
| e_2 += 2 * float(np.linalg.norm(fct[i] - fct[(i + j) % M])) | ||||||||||||||
|
|
||||||||||||||
| out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2 | ||||||||||||||
|
||||||||||||||
| out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2 | |
| if k <= 0: | |
| # Avoid division by zero / invalid k; propagate a sentinel value. | |
| out[0] = np.nan | |
| else: | |
| out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2 |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -79,6 +79,28 @@ def _energy_score_akr_circperm_gufunc_w( | |||
| out[0] = e_1 - 0.5 * e_2 | ||||
|
|
||||
|
|
||||
| @guvectorize("(d),(m,d),(),(m)->()") | ||||
| def _energy_score_akr_kband_gufunc_w( | ||||
| obs: np.ndarray, fct: np.ndarray, k: int, ens_w: np.ndarray, out: np.ndarray | ||||
| ): | ||||
| """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" | ||||
| M = fct.shape[0] | ||||
|
|
||||
| e_1 = 0.0 | ||||
| e_2 = 0.0 | ||||
| for i in range(M): | ||||
| e_1 += float(np.linalg.norm(fct[i] - obs)) * ens_w[i] | ||||
| for j in range(1, k + 1): | ||||
| e_2 += ( | ||||
| 2 | ||||
| * float(np.linalg.norm(fct[i] - fct[(i + j) % M])) | ||||
| * ens_w[i] | ||||
| * ens_w[(i + j) % M] | ||||
|
||||
| * ens_w[(i + j) % M] |
Copilot
AI
Mar 31, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new numba gufunc divides by k (1/k), but there is no guard for k <= 0. Since the public wrapper currently accepts k via **kwargs and doesn't validate it on the numba path, k=0 will trigger division-by-zero / invalid output. Validate k in the Python wrapper before calling this gufunc (or add a safe guard here if feasible).
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,7 +7,11 @@ | |||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def es_ensemble( | ||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", fct: "Array", estimator: str = "nrg", backend=None | ||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", | ||||||||||||||||||||||||||||||||||||||||||||
| fct: "Array", | ||||||||||||||||||||||||||||||||||||||||||||
| estimator: str = "nrg", | ||||||||||||||||||||||||||||||||||||||||||||
| backend=None, | ||||||||||||||||||||||||||||||||||||||||||||
| k: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||
| ) -> "Array": | ||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| Compute the energy score based on a finite ensemble. | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -22,9 +26,11 @@ def es_ensemble( | |||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr(obs, fct, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||
| elif estimator == "akr_circperm": | ||||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr_circperm(obs, fct, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||
| elif estimator == "akr_kband": | ||||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr_kband(obs, fct, k=k, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||
| f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', and 'akr_circperm'." | ||||||||||||||||||||||||||||||||||||||||||||
| f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', 'akr_circperm', and 'akr_kband'." | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -89,6 +95,28 @@ def _es_ensemble_akr_circperm( | |||||||||||||||||||||||||||||||||||||||||||
| return E_1 - 0.5 * E_2 | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| def _es_ensemble_akr_kband( | ||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", fct: "Array", k: int = 1, backend: "Backend" = None | ||||||||||||||||||||||||||||||||||||||||||||
| ) -> "Array": | ||||||||||||||||||||||||||||||||||||||||||||
| """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" | ||||||||||||||||||||||||||||||||||||||||||||
| B = backends.active if backend is None else backends[backend] | ||||||||||||||||||||||||||||||||||||||||||||
| M: int = fct.shape[-2] | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| if k < 1: | ||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("For estimator='akr_kband', k must be >= 1.") | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | ||||||||||||||||||||||||||||||||||||||||||||
| E_1 = B.sum(err_norm, -1) / M | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| E_2 = 0.0 | ||||||||||||||||||||||||||||||||||||||||||||
| for j in range(1, k + 1): | ||||||||||||||||||||||||||||||||||||||||||||
| spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1) | ||||||||||||||||||||||||||||||||||||||||||||
| E_2 += 2 * B.sum(spread_norm, -1) | ||||||||||||||||||||||||||||||||||||||||||||
| E_2 = E_2 / (M * k) | ||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+108
to
+115
|
||||||||||||||||||||||||||||||||||||||||||||
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | |
| E_1 = B.sum(err_norm, -1) / M | |
| E_2 = 0.0 | |
| for j in range(1, k + 1): | |
| spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1) | |
| E_2 += 2 * B.sum(spread_norm, -1) | |
| E_2 = E_2 / (M * k) | |
| # Clamp k to the maximum meaningful band width (number of unique nontrivial offsets). | |
| # For M == 1, this keeps k_eff at least 1 to avoid division by zero; spread terms are zero anyway. | |
| max_bandwidth = max(1, M - 1) | |
| k_eff = min(k, max_bandwidth) | |
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | |
| E_1 = B.sum(err_norm, -1) / M | |
| E_2 = 0.0 | |
| for j in range(1, k_eff + 1): | |
| spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1) | |
| E_2 += 2 * B.sum(spread_norm, -1) | |
| E_2 = E_2 / (M * k_eff) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,7 +7,12 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def es_ensemble_w( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", fct: "Array", ens_w: "Array", estimator: str = "nrg", backend=None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fct: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ens_w: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| estimator: str = "nrg", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| backend=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> "Array": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Compute the energy score based on a finite ensemble. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -22,9 +27,11 @@ def es_ensemble_w( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr_w(obs, fct, ens_w, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif estimator == "akr_circperm": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr_circperm_w(obs, fct, ens_w, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif estimator == "akr_kband": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = _es_ensemble_akr_kband_w(obs, fct, ens_w, k=k, backend=backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', and 'akr_circperm'." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', 'akr_circperm', and 'akr_kband'." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -102,6 +109,32 @@ def _es_ensemble_akr_circperm_w( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return E_1 - 0.5 * E_2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _es_ensemble_akr_kband_w( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| obs: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fct: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ens_w: "Array", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| backend: "Backend" = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> "Array": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Compute the weighted Energy Score using the AKR with k-band approximation.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B = backends.active if backend is None else backends[backend] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if k < 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("For estimator='akr_kband', k must be >= 1.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| E_1 = B.sum(err_norm * ens_w, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| E_2 = 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for j in range(1, k + 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fct_shift = B.roll(fct, shift=-j, axis=-2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ens_w_shift = B.roll(ens_w, shift=-j, axis=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| spread_norm = B.norm(fct - fct_shift, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+131
to
+133
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ens_w_shift = B.roll(ens_w, shift=-j, axis=-1) | |
| spread_norm = B.norm(fct - fct_shift, -1) | |
| E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1) | |
| spread_norm = B.norm(fct - fct_shift, -1) | |
| # Weight the spread term with ens_w only, consistent with other AKR estimators. | |
| E_2 += 2 * B.sum(spread_norm * ens_w, -1) |
Copilot
AI
Mar 31, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like the unweighted implementation, this only checks k >= 1. For k >= M (ensemble size), the rolled offsets repeat and (when j is a multiple of M) compare members with themselves, while still dividing by k. Consider validating/clamping k to 1 <= k <= M-1 (or the maximum unique band width you intend) to avoid duplicated work and hard-to-interpret scaling.
| if k < 1: | |
| raise ValueError("For estimator='akr_kband', k must be >= 1.") | |
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | |
| E_1 = B.sum(err_norm * ens_w, -1) | |
| E_2 = 0.0 | |
| for j in range(1, k + 1): | |
| fct_shift = B.roll(fct, shift=-j, axis=-2) | |
| ens_w_shift = B.roll(ens_w, shift=-j, axis=-1) | |
| spread_norm = B.norm(fct - fct_shift, -1) | |
| E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1) | |
| return E_1 - 0.5 * E_2 / k | |
| M: int = fct.shape[-2] | |
| if M < 2: | |
| raise ValueError( | |
| "For estimator='akr_kband', ensemble size M must be >= 2." | |
| ) | |
| if k < 1: | |
| raise ValueError("For estimator='akr_kband', k must be >= 1.") | |
| # Clamp k to the maximum unique band width (M - 1) to avoid | |
| # repeated cyclic permutations and self-comparisons when k >= M. | |
| k_eff = min(k, M - 1) | |
| err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) | |
| E_1 = B.sum(err_norm * ens_w, -1) | |
| E_2 = 0.0 | |
| for j in range(1, k_eff + 1): | |
| fct_shift = B.roll(fct, shift=-j, axis=-2) | |
| ens_w_shift = B.roll(ens_w, shift=-j, axis=-1) | |
| spread_norm = B.norm(fct - fct_shift, -1) | |
| E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1) | |
| return E_1 - 0.5 * E_2 / k_eff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
es_ensemblenow accepts arbitrary**kwargs, but the docstring doesn't document any additional keyword arguments and (for non-akr_kbandestimators) extra kwargs are silently ignored. This makes it easy for user typos to go unnoticed. Prefer adding an explicit keyword-onlyk: int = 1parameter (documented in the Parameters section) and rejecting unexpected kwargs (or, if you keep**kwargs, validate thatkwargsis empty when the estimator doesn't consume them).