Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pdex"
version = "0.1.24"
version = "0.1.25"
description = "Parallel differential expression for single-cell perturbation sequencing"
readme = "README.md"
authors = [{ name = "noam teyssier", email = "noam.teyssier@arcinstitute.org" }]
Expand Down
9 changes: 7 additions & 2 deletions src/pdex/_single_cell.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging
import math
import multiprocessing as mp
import os
from collections.abc import Iterator
from functools import partial
from multiprocessing.shared_memory import SharedMemory

import anndata as ad
import os
import numpy as np
import pandas as pd
import polars as pl
from numba import njit, prange, get_num_threads, get_thread_id
from numba import get_num_threads, get_thread_id, njit, prange
from scipy.sparse import csc_matrix, csr_matrix
from scipy.stats import anderson_ksamp, false_discovery_control, mannwhitneyu, ttest_ind
from tqdm import tqdm
Expand Down Expand Up @@ -345,6 +345,11 @@ def parallel_differential_expression(

if not is_log1p:
is_log1p = guess_is_log(adata)
if is_log1p:
logger.info("Auto-Detected log1p for dataset.")
else:
logger.info("Auto-Detected non-log1p for dataset.")
logger.info("Log1p status: %s", is_log1p)

# Precompute the number of combinations and batches
n_combinations = len(unique_targets) * len(unique_features)
Expand Down
39 changes: 15 additions & 24 deletions src/pdex/_utils.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
import anndata as ad
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix

# A heuristic to determine if the data is log-transformed
# Checks if the mean cell umi count is greater than a certain threshold
# If the mean cell umi count is < UPPER_LIMIT_LOG, it is assumed that the data is log-transformed
#
# This limit is set to 15 (log-data with >15 average UMI counts would mean an
# average UMI count of ($ e^{15} - 1 = 3.26M $ ) which is unlikely at this point)
UPPER_LIMIT_LOG = 15
# Fractional parts whose magnitude is at or below this value are treated as zero.
EPSILON = 1e-3


def guess_is_log(adata: "ad.AnnData") -> bool:
    """
    Make an *educated* guess whether the provided anndata is log-transformed.

    Checks whether any value of the matrix has a fractional part whose
    magnitude is greater than ``EPSILON``. For CSR/CSC matrices only the
    explicitly stored values are inspected, since implicit zeros have no
    fractional part.

    This *cannot* tell the difference between log and normalized data.

    Returns True if at least one value is non-integral, False otherwise
    (including for an empty matrix).

    Raises ValueError if ``adata.X`` is None.
    """
    if adata.X is None:
        raise ValueError("adata.X is None")

    if isinstance(adata.X, (csr_matrix, csc_matrix)):
        values = adata.X.data
    else:
        values = adata.X

    frac, _ = np.modf(values)  # type: ignore

    # np.modf gives the fractional part the same sign as its input, so compare
    # magnitudes; otherwise negative non-integer values would be missed.
    return bool(np.any(np.abs(frac) > EPSILON))
36 changes: 27 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import anndata as ad
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix

from pdex._utils import guess_is_log

Expand All @@ -8,18 +9,35 @@
MAX_COUNT = 1e6


def build_anndata(log=False, sparse: str | None = None) -> ad.AnnData:
    """Build a random AnnData of shape (N_CELLS, N_GENES) for the tests.

    With ``log=True`` the matrix holds floats drawn by ``np.random.random``;
    otherwise it holds integers drawn by ``np.random.randint`` from
    ``[0, int(MAX_COUNT))``.

    ``sparse`` selects the storage of ``X``: ``"csr"`` or ``"csc"`` converts
    the matrix to that scipy sparse format, any other value keeps it dense.
    The conversion is applied to both the log and the count matrix.
    """
    dim = (N_CELLS, N_GENES)

    # Pick the dense values first so the sparse conversion below applies to
    # whichever matrix is actually used, not only to the log one.
    if log:
        matrix = np.random.random(size=dim)
    else:
        matrix = np.random.randint(0, int(MAX_COUNT), size=dim)

    if sparse == "csr":
        matrix = csr_matrix(matrix)
    elif sparse == "csc":
        matrix = csc_matrix(matrix)

    return ad.AnnData(X=matrix)


def test_log_guess_logtrue():
    """Fractional (log-like) data is detected as log for dense, CSC and CSR input."""
    for layout in (None, "csc", "csr"):
        adata = build_anndata(log=True, sparse=layout)
        assert guess_is_log(adata)


def test_log_guess_logfalse():
    """Integer count data is not detected as log for any requested layout."""
    for layout in (None, "csc", "csr"):
        adata = build_anndata(log=False, sparse=layout)
        assert not guess_is_log(adata)