Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/wmel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,25 @@
)
from wmel.metrics import (
BradleyTerryRanking,
CorrelationResult,
EpisodeResult,
McNemarResult,
Scorecard,
ac_ci_half_width,
action_success_rate,
average_planning_latency_ms,
average_steps_to_success,
bootstrap_correlation_ci,
compute_scorecard,
detectable_gap_at_n,
holm_correction,
kendall_tau,
mcnemar_exact,
newcombe_paired_diff_ci,
paired_bradley_terry_ranking,
perturbation_recovery_rate,
required_n_for_half_width,
spearman_rho,
)

__version__ = "0.18.0"
Expand All @@ -53,6 +57,7 @@
"BenchmarkRunner",
"BradleyTerryRanking",
"CompositePerturbation",
"CorrelationResult",
"DropNextActions",
"EnvPerturbation",
"EpisodeResult",
Expand All @@ -66,15 +71,18 @@
"action_success_rate",
"average_planning_latency_ms",
"average_steps_to_success",
"bootstrap_correlation_ci",
"compute_scorecard",
"detectable_gap_at_n",
"holm_correction",
"horizon_sweep",
"kendall_tau",
"mcnemar_exact",
"newcombe_paired_diff_ci",
"paired_bradley_terry_ranking",
"perturbation_recovery_rate",
"required_n_for_half_width",
"spearman_rho",
"print_horizon_sweep",
"print_scorecard",
"to_json_report",
Expand Down
152 changes: 152 additions & 0 deletions src/wmel/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,155 @@ def paired_bradley_terry_ranking(
n_boot=n_boot,
prior=prior,
)


# --- Rank correlation (for offline-metric vs downstream-performance studies) --

@dataclass(frozen=True)
class CorrelationResult:
    """Immutable record of a rank correlation and its bootstrap interval.

    Intended use: checking whether a cheap offline metric tracks downstream
    decision quality, with one observation per (model, env, planner) cell
    paired against that cell's CPG / success. A rank statistic is used
    because, with only a handful of cells measured on incomparable scales,
    a monotone relationship is the meaningful one (not a linear one).

    Note that ``n_boot`` counts the *valid* (non-degenerate) resamples that
    contributed to the interval; it may be smaller than the number requested.
    """

    # Point estimate computed on the full, un-resampled data.
    rho: float
    # Lower / upper ends of the bootstrap confidence interval.
    ci_low: float
    ci_high: float
    # Number of (x, y) pairs the estimate was computed from.
    n_pairs: int
    # Name of the statistic used (e.g. "spearman" or "kendall").
    method: str
    # Number of non-degenerate bootstrap resamples actually used.
    n_boot: int


def _rankdata(xs: Sequence[float]) -> list[float]:
    """Return 1-based fractional ranks of ``xs``; tied values share their mean rank.

    The result is aligned with ``xs`` (``result[k]`` is the rank of ``xs[k]``).
    An empty input yields an empty list.
    """
    count = len(xs)
    by_value = sorted(range(count), key=lambda idx: xs[idx])
    result = [0.0] * count
    start = 0
    while start < count:
        # Grow [start, stop) over the run of sorted positions holding equal values.
        stop = start + 1
        while stop < count and xs[by_value[stop]] == xs[by_value[start]]:
            stop += 1
        # Mean of the 1-based positions start+1 .. stop.
        shared_rank = (start + (stop - 1)) / 2.0 + 1.0
        for pos in range(start, stop):
            result[by_value[pos]] = shared_rank
        start = stop
    return result


def _pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the Pearson correlation of ``a`` and ``b``, kept within [-1, 1].

    Args:
        a: First sample.
        b: Second sample, paired element-wise with ``a``.

    Returns:
        The correlation coefficient. Floating-point rounding in the
        denominator can push a mathematically perfect correlation a hair
        outside the valid range (e.g. 1.0000000000000002 for identical
        tie-heavy rank vectors), so the result is clamped to [-1, 1].
        A NaN result (from NaN inputs) is passed through unchanged.

    Raises:
        ValueError: if either side has zero variance (correlation undefined).
    """
    mean_a, mean_b = fmean(a), fmean(b)
    num = sum((ai - mean_a) * (bi - mean_b) for ai, bi in zip(a, b))
    den = math.sqrt(sum((ai - mean_a) ** 2 for ai in a)) * math.sqrt(
        sum((bi - mean_b) ** 2 for bi in b)
    )
    if den == 0.0:
        raise ValueError("correlation undefined: zero variance in an input")
    r = num / den
    # Explicit comparisons (not min/max) so that NaN is not silently turned
    # into a bound: both tests are False for NaN and it falls through.
    if r > 1.0:
        return 1.0
    if r < -1.0:
        return -1.0
    return r


def spearman_rho(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Return the Spearman rank correlation of the paired samples.

    Computed as the Pearson correlation of the average (tie-aware) ranks of
    each input.

    Raises:
        ValueError: if the inputs differ in length, hold fewer than two
            pairs, or either input is constant (correlation undefined).
    """
    n_x, n_y = len(xs), len(ys)
    if n_x != n_y:
        raise ValueError(f"length mismatch: {n_x} vs {n_y}")
    if n_x < 2:
        raise ValueError("need at least two pairs")
    ranks_x = _rankdata(xs)
    ranks_y = _rankdata(ys)
    return _pearson(ranks_x, ranks_y)


def kendall_tau(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Return the tie-corrected Kendall tau-b of the paired samples.

    Every unordered pair of observations is inspected once (quadratic in the
    number of pairs), which is fine for the very small samples this targets.

    Raises:
        ValueError: if the inputs differ in length, hold fewer than two
            pairs, or the tau-b denominator is zero (a constant input).
    """
    size = len(xs)
    if size != len(ys):
        raise ValueError(f"length mismatch: {len(xs)} vs {len(ys)}")
    if size < 2:
        raise ValueError("need at least two pairs")

    concordant = discordant = 0
    ties_x = ties_y = 0
    for i in range(size - 1):
        x_i, y_i = xs[i], ys[i]
        for j in range(i + 1, size):
            dx = x_i - xs[j]
            dy = y_i - ys[j]
            tied_x = dx == 0
            tied_y = dy == 0
            if tied_x:
                ties_x += 1
            if tied_y:
                ties_y += 1
            if tied_x or tied_y:
                continue  # a tie on either side is neither concordant nor discordant
            if (dx > 0) == (dy > 0):
                concordant += 1
            else:
                discordant += 1

    total_pairs = size * (size - 1) // 2
    denom = math.sqrt((total_pairs - ties_x) * (total_pairs - ties_y))
    if denom == 0.0:
        raise ValueError("Kendall tau-b undefined: a constant input")
    return (concordant - discordant) / denom


def bootstrap_correlation_ci(
    xs: Sequence[float],
    ys: Sequence[float],
    *,
    method: str = "spearman",
    n_boot: int = 10_000,
    alpha: float = 0.05,
    seed: int = 0,
) -> CorrelationResult:
    """Estimate a rank correlation and a paired percentile-bootstrap interval.

    Each bootstrap draw samples ``len(xs)`` cell indices with replacement and
    keeps every ``(x_i, y_i)`` pair intact, then recomputes the chosen
    statistic. A draw on which the statistic raises ``ValueError`` (a
    constant arm, so the correlation is undefined) is dropped, and the
    returned ``n_boot`` is the count of draws that survived. The interval is
    therefore conditional on non-degenerate draws, which can make it
    optimistically narrow for a near-constant arm at very small n. Results
    are reproducible for a fixed ``seed``.

    Args:
        xs: First sample.
        ys: Second sample, paired element-wise with ``xs``.
        method: ``"spearman"`` or ``"kendall"``.
        n_boot: Number of bootstrap draws to attempt (>= 1).
        alpha: Two-sided miscoverage level, strictly between 0 and 1.
        seed: Seed for the private ``random.Random`` generator.

    Raises:
        ValueError: on mismatched lengths, fewer than two pairs, an invalid
            ``n_boot`` / ``alpha`` / ``method``, an undefined point estimate,
            or fewer than two valid draws.
    """
    if len(xs) != len(ys):
        raise ValueError(f"length mismatch: {len(xs)} vs {len(ys)}")
    n = len(xs)
    if n < 2:
        raise ValueError("need at least two pairs")
    if n_boot < 1:
        raise ValueError("n_boot must be >= 1")
    if not (0.0 < alpha < 1.0):
        raise ValueError("alpha must be in (0, 1)")
    estimators = {"spearman": spearman_rho, "kendall": kendall_tau}
    if method not in estimators:
        raise ValueError(
            f"method must be one of {sorted(estimators)} (got {method!r})"
        )
    estimator = estimators[method]

    # Point estimate on the full data; raises if it is undefined.
    point = estimator(xs, ys)

    rng = random.Random(seed)
    draws: list[float] = []
    for _ in range(n_boot):
        picks = [rng.randrange(n) for _ in range(n)]
        try:
            stat = estimator([xs[k] for k in picks], [ys[k] for k in picks])
        except ValueError:
            continue  # degenerate resample (a constant arm): no information
        draws.append(stat)
    if len(draws) < 2:
        raise ValueError("correlation bootstrap degenerate: too few valid resamples")

    ordered = sorted(draws)
    m = len(ordered)
    last = m - 1
    tail = alpha / 2.0
    # Floor-based percentile positions, clamped into [0, m - 1].
    lo_idx = max(0, min(int(tail * m), last))
    hi_idx = max(0, min(int((1.0 - tail) * m) - 1, last))
    return CorrelationResult(
        rho=point,
        ci_low=ordered[lo_idx],
        ci_high=ordered[hi_idx],
        n_pairs=n,
        method=method,
        n_boot=m,
    )
115 changes: 115 additions & 0 deletions tests/test_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for the rank-correlation primitives (Spearman, Kendall, bootstrap CI).

These back the offline-metric vs downstream-performance study. The strongest
checks are closed-form: perfect monotone data gives +/-1, and small hand-worked
examples pin the tie handling.
"""

from __future__ import annotations

import pytest

from wmel.metrics import (
CorrelationResult,
bootstrap_correlation_ci,
kendall_tau,
spearman_rho,
)
from wmel.metrics import _rankdata # noqa: PLC2701 (private, tie-handling check)


# --- rank helper ------------------------------------------------------------

def test_rankdata_average_ranks_with_ties():
    """Tied values share the mean of the 1-based positions they occupy."""
    cases = (
        ([10, 20, 20, 30], [1.0, 2.5, 2.5, 4.0]),
        ([5, 5, 5], [2.0, 2.0, 2.0]),
    )
    for values, expected in cases:
        assert _rankdata(values) == expected


# --- Spearman ---------------------------------------------------------------

def test_spearman_perfect_monotone():
    """Strictly increasing / decreasing pairings give +1 / -1."""
    xs = [1, 2, 3, 4, 5]
    increasing = spearman_rho(xs, [10, 20, 30, 40, 50])
    assert increasing == pytest.approx(1.0)
    decreasing = spearman_rho(xs, [9, 7, 5, 3, 1])
    assert decreasing == pytest.approx(-1.0)


def test_spearman_known_value():
    """Hand-worked tie-free case: Spearman equals Pearson on the raw values."""
    # [1,2,3,4] vs [1,3,2,4]: covariance sum 4.0, variance sum 5.0 on each
    # side, so rho = 4 / 5 = 0.8.
    observed = spearman_rho([1, 2, 3, 4], [1, 3, 2, 4])
    assert observed == pytest.approx(0.8)


def test_spearman_rejects_degenerate_and_mismatched():
    """Constant input, unequal lengths, and a single pair all raise ValueError."""
    bad_inputs = (
        ([1, 2, 3], [5, 5, 5]),  # constant -> undefined
        ([1, 2], [1, 2, 3]),  # length mismatch
        ([1.0], [2.0]),  # fewer than two pairs
    )
    for xs, ys in bad_inputs:
        with pytest.raises(ValueError):
            spearman_rho(xs, ys)


# --- Kendall tau-b ----------------------------------------------------------

def test_kendall_perfect_monotone():
    """Strictly increasing / decreasing pairings give tau = +1 / -1."""
    xs = [1, 2, 3, 4]
    increasing = kendall_tau(xs, [2, 4, 6, 8])
    assert increasing == pytest.approx(1.0)
    decreasing = kendall_tau(xs, [8, 6, 4, 2])
    assert decreasing == pytest.approx(-1.0)


def test_kendall_known_value():
    """Hand-worked tie-free case with a single swapped pair."""
    # [1,2,3,4] vs [1,3,2,4]: 5 concordant, 1 discordant, no ties -> 4/6.
    expected = 4 / 6
    observed = kendall_tau([1, 2, 3, 4], [1, 3, 2, 4])
    assert observed == pytest.approx(expected)


def test_kendall_tie_corrected():
    """Ties on each side shrink the tau-b denominator."""
    # [1,1,2] vs [1,2,2]: one concordant pair, one x-tie, one y-tie ->
    # tau_b = (1-0)/sqrt((3-1)(3-1)) = 0.5.
    observed = kendall_tau([1, 1, 2], [1, 2, 2])
    assert observed == pytest.approx(0.5)


def test_kendall_rejects_degenerate():
    """A constant input makes tau-b undefined and must raise ValueError."""
    constant_xs = [1, 1, 1]
    varying_ys = [1, 2, 3]
    with pytest.raises(ValueError):
        kendall_tau(constant_xs, varying_ys)


# --- bootstrap CI -----------------------------------------------------------

def test_bootstrap_perfect_correlation_is_tight():
    """Perfectly monotone data gives rho = 1 and an interval collapsed at 1."""
    values = list(range(10))
    result = bootstrap_correlation_ci(values, list(range(10)), n_boot=500, seed=0)
    assert isinstance(result, CorrelationResult)
    assert result.rho == pytest.approx(1.0)
    assert result.ci_low == pytest.approx(1.0)
    assert result.ci_high == pytest.approx(1.0)
    assert result.n_pairs == 10
    assert result.method == "spearman"


def test_bootstrap_is_deterministic_given_seed():
    """Two runs with the same seed produce identical results."""
    xs = [1, 2, 3, 4, 5, 6, 7, 8]
    ys = [2, 1, 4, 3, 6, 5, 8, 7]

    def summary(res):
        return (res.rho, res.ci_low, res.ci_high, res.n_boot)

    first = bootstrap_correlation_ci(xs, ys, n_boot=400, seed=3)
    second = bootstrap_correlation_ci(xs, ys, n_boot=400, seed=3)
    assert summary(first) == summary(second)


def test_bootstrap_ci_brackets_point():
    """For strong but imperfect monotone data the interval contains rho."""
    xs = list(range(12))
    # Adjacent swaps only: strong but imperfect monotone relationship.
    ys = [0, 1, 3, 2, 4, 6, 5, 7, 9, 8, 10, 11]
    result = bootstrap_correlation_ci(xs, ys, n_boot=1000, seed=0)
    assert result.ci_low <= result.rho
    assert result.rho <= result.ci_high
    assert 0.0 < result.rho < 1.0  # genuinely partial


def test_bootstrap_kendall_method_and_bad_method():
    """``method="kendall"`` is honoured; an unknown method raises ValueError."""
    result = bootstrap_correlation_ci(
        list(range(8)), list(range(8)), method="kendall", n_boot=200, seed=1
    )
    assert result.method == "kendall"
    assert result.rho == pytest.approx(1.0)
    with pytest.raises(ValueError):
        bootstrap_correlation_ci([1, 2, 3], [1, 2, 3], method="pearson")


def test_bootstrap_input_validation():
    """Each invalid argument combination is rejected with ValueError."""
    invalid_calls = (
        (([1, 2], [1, 2, 3]), {}),  # length mismatch
        (([1.0], [2.0]), {}),  # fewer than two pairs
        (([1, 2, 3], [1, 2, 3]), {"alpha": 0.0}),  # alpha outside (0, 1)
        (([1, 2, 3], [1, 2, 3]), {"n_boot": 0}),  # no resamples requested
    )
    for args, kwargs in invalid_calls:
        with pytest.raises(ValueError):
            bootstrap_correlation_ci(*args, **kwargs)