Skip to content
10 changes: 10 additions & 0 deletions src/cell_eval/_cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def parse_args_run(parser: ap.ArgumentParser):
type=str,
help="Metrics to skip (comma-separated for multiple) (see docs for more details)",
)
parser.add_argument(
"-k",
"--topk",
type=int,
default=10,
help="k for top_k_accuracy (number of nearest neighbors) [default: %(default)s]",
)
parser.add_argument(
"--version",
action="version",
Expand Down Expand Up @@ -142,6 +149,9 @@ def run_evaluation(args: ap.Namespace):
else {}
)

# Always pass top-k for top_k_accuracy
metric_kwargs.setdefault("top_k_accuracy", {})["k"] = args.topk

skip_metrics = args.skip_metrics.split(",") if args.skip_metrics else None

if args.celltype_col is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/cell_eval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
mse,
mse_delta,
pearson_delta,
top_k_accuracy,
)
from ._de import (
DEDirectionMatch,
Expand All @@ -31,6 +32,7 @@
"mse_delta",
"mae_delta",
"discrimination_score",
"top_k_accuracy",
# DE metrics
"DEDirectionMatch",
"DESpearmanSignificant",
Expand Down
58 changes: 58 additions & 0 deletions src/cell_eval/metrics/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,64 @@ def discrimination_score(

return norm_ranks

def top_k_accuracy(
data,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The data parameter is missing a type hint. For consistency with other metric functions in this file and for better code clarity, please add the type hint PerturbationAnndataPair.

Suggested change
data,
data: PerturbationAnndataPair,

k: int = 10,
metric: str = "l2",
embed_key: str | None = None,
) -> dict[str, float]:
"""
Top-k accuracy over pseudo-bulked perturbation profiles.
For each perturbation, we compute one vector for real and one for predicted
(pseudobulk/mean per perturbation). We then compare each predicted
perturbation vector against all real perturbation vectors and mark a hit if
the correct real perturbation is within the top-k closest.
Args:
data: PerturbationAnndataPair
k: number of nearest neighbors to consider per perturbation
metric: one of {"l2", "euclidean", "cosine"}
embed_key: optional key for .obsm
"""

if k <= 0:
raise ValueError("Parameter `k` must be positive.")

metric = metric.lower()
if metric in {"l2", "euclidean"}:
dist_metric = "euclidean"
elif metric == "cosine":
dist_metric = "cosine"
else:
raise ValueError(f"Unsupported metric: {metric}")

# Build one vector per perturbation (exclude control) in a consistent order
real_vectors: list[np.ndarray] = []
pred_vectors: list[np.ndarray] = []
perts_order: list[str] = []
for bulk in data.iter_bulk_arrays(embed_key=embed_key):
perts_order.append(bulk.key)
real_vectors.append(bulk.pert_real)
pred_vectors.append(bulk.pert_pred)

if not real_vectors:
return {}

real_mat = np.vstack(real_vectors)
pred_mat = np.vstack(pred_vectors)

# Compute distance matrix between predicted and real pseudo-bulks
D = skm.pairwise_distances(pred_mat, real_mat, metric=dist_metric)

n_real = D.shape[1]
k_eff = int(min(max(1, k), n_real))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The max(1, k) check is redundant because there's already a validation on line 219 that ensures k is a positive integer. You can simplify this line to improve clarity.

Suggested change
k_eff = int(min(max(1, k), n_real))
k_eff = int(min(k, n_real))


scores: dict[str, float] = {}
for i, pert in enumerate(perts_order):
# indices of k smallest distances
idx = np.argpartition(D[i], k_eff - 1)[:k_eff]
scores[str(pert)] = 1.0 if i in idx else 0.0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pert variable is already a string, as it comes from perts_order which is a list[str]. The call to str(pert) is redundant and can be removed.

Suggested change
scores[str(pert)] = 1.0 if i in idx else 0.0
scores[pert] = 1.0 if i in idx else 0.0


return scores

def _generic_evaluation(
data: PerturbationAnndataPair,
Expand Down
9 changes: 9 additions & 0 deletions src/cell_eval/metrics/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
mse,
mse_delta,
pearson_delta,
top_k_accuracy,
)
from ._de import (
DEDirectionMatch,
Expand Down Expand Up @@ -72,6 +73,14 @@
kwargs={"metric": distance_metric},
)

metrics_registry.register(
name="top_k_accuracy",
metric_type=MetricType.ANNDATA_PAIR,
description="Top-k retrieval accuracy of predicted perturbation profiles",
best_value=MetricBestValue.ONE,
func=top_k_accuracy,
)

metrics_registry.register(
name="pearson_edistance",
metric_type=MetricType.ANNDATA_PAIR,
Expand Down
Loading