Skip to content

Commit 64655e7

Browse files
sullivanj91claude
andcommitted
fix: address PR review comments — percent_change prior_count, negative guard, usage docs
- _math.py: add prior_count param to percent_change(); formula becomes (x - y) / (y + prior_count), dampening explosion when ref_mean ≈ 0 - __init__.py: add ValueError guard for negative prior_count; thread prior_count through all three percent_change call sites (_pdex_ref, _pdex_all, _pdex_on_target); expand prior_count docstring with recommended usage (start with 0.5, combine with min_mean_expression for full suppression) - CLAUDE.md: update percent_change schema formula to show prior_count - tests/test_math.py: add TestPercentChangeWithPriorCount (4 unit tests) - tests/test_pdex.py: add test_negative_prior_count_raises to validation suite Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4445adb commit 64655e7

5 files changed

Lines changed: 63 additions & 10 deletions

File tree

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ The returned Polars DataFrame (or pandas DataFrame when `as_pandas=True`) has co
8080
| `target_membership` | int | Number of cells in the target group |
8181
| `ref_membership` | int | Number of cells in the reference |
8282
| `fold_change` | float | log2((target_mean + prior_count) / (ref_mean + prior_count)) — computed from pseudobulk means |
83-
| `percent_change` | float | (target_mean - ref_mean) / ref_mean — computed from pseudobulk means |
83+
| `percent_change` | float | (target_mean - ref_mean) / (ref_mean + prior_count) — computed from pseudobulk means |
8484
| `p_value` | float | Mann-Whitney U p-value (per-cell vectors) |
8585
| `statistic` | float | Mann-Whitney U statistic |
8686
| `fdr` | float | FDR-corrected p-value, applied per-group across genes. For `on_target` mode, applied across all groups. |

src/pdex/__init__.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,20 @@ def pdex(
204204
:class:`polars.DataFrame`. Requires ``pyarrow``.
205205
prior_count:
206206
Pseudocount added to both ``target_mean`` and ``ref_mean`` before computing
207-
``fold_change``. When ``prior_count > 0``, extreme fold changes from near-zero
208-
reference means (scRNA-seq sparsity artifact) are dampened toward zero.
209-
Has no effect on the Mann-Whitney U p-value or FDR.
207+
``fold_change`` and ``percent_change``. When ``prior_count > 0``, extreme
208+
values from near-zero reference means (scRNA-seq sparsity artifact) are
209+
dampened toward zero. Has no effect on the Mann-Whitney U p-value or FDR.
210210
Default ``0.0`` preserves existing behaviour.
211+
212+
**Recommended usage:** For scRNA-seq CRISPRi/CRISPRa screens where many
213+
genes are unexpressed in the reference group, start with ``prior_count=0.5``.
214+
This provides modest dampening without substantially compressing fold changes
215+
for well-expressed genes. For complete suppression of the sparsity artifact,
216+
combine with a ``min_mean_expression`` pre-filter on the reference group —
217+
``prior_count`` alone cannot eliminate low p-values arising from per-cell
218+
distributional shifts in near-zero genes.
219+
220+
Must be non-negative. Raises :class:`ValueError` if negative.
211221
**kwargs:
212222
Mode-specific keyword arguments:
213223
@@ -246,6 +256,9 @@ def pdex(
246256
adata.n_vars,
247257
)
248258

259+
if prior_count < 0:
260+
raise ValueError(f"prior_count must be non-negative, got {prior_count}")
261+
249262
# Set the global threadpool for numba
250263
set_numba_threadpool(threads)
251264

@@ -365,7 +378,7 @@ def _pdex_ref(
365378
)
366379

367380
fc = fold_change(group_bulk, ref_bulk, prior_count)
368-
pc = percent_change(group_bulk, ref_bulk)
381+
pc = percent_change(group_bulk, ref_bulk, prior_count)
369382
mwu_result = mwu(group_matrix, ref_data)
370383

371384
mwu_statistic = mwu_result.statistic
@@ -427,7 +440,7 @@ def _pdex_all(
427440
)
428441

429442
fc = fold_change(group_bulk, rest_bulk, prior_count)
430-
pc = percent_change(group_bulk, rest_bulk)
443+
pc = percent_change(group_bulk, rest_bulk, prior_count)
431444
mwu_result = mwu(group_matrix, rest_matrix)
432445

433446
mwu_statistic = mwu_result.statistic
@@ -517,7 +530,7 @@ def _pdex_on_target(
517530
fc = float(
518531
fold_change(np.array([target_mean]), np.array([ref_mean]), prior_count)[0]
519532
)
520-
pc = float(percent_change(np.array([target_mean]), np.array([ref_mean]))[0])
533+
pc = float(percent_change(np.array([target_mean]), np.array([ref_mean]), prior_count)[0])
521534

522535
mwu_result = mwu(group_col, ref_col)
523536
p_value = float(np.clip(np.asarray(mwu_result.pvalue).ravel()[0], 0, 1))

src/pdex/_math.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,14 @@ def fold_change(x: np.ndarray, y: np.ndarray, prior_count: float = 0.0) -> np.nd
117117

118118

119119
@nb.njit(parallel=True)
120-
def percent_change(x: np.ndarray, y: np.ndarray) -> np.ndarray:
121-
"""Calculates the change between two arrays."""
122-
return (x - y) / y
120+
def percent_change(x: np.ndarray, y: np.ndarray, prior_count: float = 0.0) -> np.ndarray:
121+
"""Calculates the percent change between two arrays.
122+
123+
When ``prior_count > 0``, adds a pseudocount to the denominator before
124+
computing the ratio, dampening extreme values when the reference mean is
125+
near zero (scRNA-seq sparsity artifact).
126+
"""
127+
return (x - y) / (y + prior_count)
123128

124129

125130
def mwu(

tests/test_math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,37 @@ def test_equal_means_still_zero(self):
8686
np.testing.assert_allclose(result, [0.0, 0.0])
8787

8888

89+
class TestPercentChangeWithPriorCount:
90+
def test_zero_prior_count_matches_baseline(self):
91+
"""prior_count=0.0 must be identical to calling without it."""
92+
x = np.array([4.0, 8.0, 0.1])
93+
y = np.array([2.0, 4.0, 0.001])
94+
np.testing.assert_array_equal(percent_change(x, y), percent_change(x, y, 0.0))
95+
96+
def test_dampens_extreme_pc_from_near_zero_denominator(self):
97+
"""prior_count=0.5 pulls extreme percent change toward zero."""
98+
x = np.array([0.1])
99+
y = np.array([0.001])
100+
pc_raw = percent_change(x, y)[0]
101+
pc_dampened = percent_change(x, y, 0.5)[0]
102+
assert abs(pc_dampened) < abs(pc_raw)
103+
np.testing.assert_allclose(pc_dampened, (0.1 - 0.001) / (0.001 + 0.5), rtol=1e-5)
104+
105+
def test_preserves_direction(self):
106+
"""prior_count should not flip the sign of percent change."""
107+
x = np.array([2.0, 0.5])
108+
y = np.array([1.0, 1.0])
109+
result = percent_change(x, y, 0.5)
110+
assert result[0] > 0
111+
assert result[1] < 0
112+
113+
def test_equal_means_still_zero(self):
114+
"""When target_mean == ref_mean, percent_change should be 0 regardless of prior_count."""
115+
x = np.array([0.5, 2.0])
116+
result = percent_change(x, x, 0.5)
117+
np.testing.assert_allclose(result, [0.0, 0.0])
118+
119+
89120
class TestBulkMatrixGeometric:
90121
"""Tests for bulk_matrix_geometric."""
91122

tests/test_pdex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,10 @@ def test_unknown_gene_name_warns_and_skips(self, on_target_adata):
478478

479479

480480
class TestPdexValidation:
481+
def test_negative_prior_count_raises(self, small_adata):
482+
with pytest.raises(ValueError, match="prior_count must be non-negative"):
483+
pdex(small_adata, groupby="guide", is_log1p=False, prior_count=-0.1)
484+
481485
def test_invalid_mode(self, small_adata):
482486
with pytest.raises(ValueError, match="Invalid mode"):
483487
pdex(

0 commit comments

Comments
 (0)