Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [push, pull_request]
jobs:
all_jobs:
runs-on: ubuntu-latest
needs: [formatting, pytest]
needs: [type-checking, formatting, pytest]
steps:
- name: Complete
run: echo "Complete"
Expand All @@ -27,6 +27,29 @@ jobs:
run: |
uv sync --all-extras --dev

# Static type-checking job: installs the project with uv and runs pyrefly.
type-checking:
  runs-on: ubuntu-latest

  needs: [install-job]

  steps:
    - uses: actions/checkout@v4

    - name: install uv
      uses: astral-sh/setup-uv@v5
      with:
        enable-cache: true
        cache-dependency-glob: "pyproject.toml"
        python-version: "3.12"

    - name: install dependencies
      run: |
        uv sync --all-extras --dev

    # This step runs the type checker, not the formatter; the label now
    # matches the command so CI logs are not misleading.
    - name: run type checking
      run: |
        uv run pyrefly check

formatting:
runs-on: ubuntu-latest

Expand Down
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pdex"
version = "0.1.15"
version = "0.1.16"
description = "Parallel differential expression for single-cell perturbation sequencing"
readme = "README.md"
authors = [{ name = "noam teyssier", email = "noam.teyssier@arcinstitute.org" }]
Expand All @@ -23,8 +23,4 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = ["pytest>=8.3.5", "ruff>=0.11.8"]

[tool.pyright]
venvPath = "."
venv = ".venv"
dev = ["pyrefly>=0.19.0", "pytest>=8.3.5", "ruff>=0.11.8"]
2 changes: 2 additions & 0 deletions pyrefly.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Glob patterns selecting the files pyrefly checks: every path in the project.
project-includes = ["**/*"]
# Glob patterns removed from that set: anything beneath a directory whose name
# ends in "venv" (e.g. ".venv" or "venv"), at any depth.
project-excludes = ["**/*venv/**/*"]
21 changes: 14 additions & 7 deletions src/pdex/_single_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from collections.abc import Iterator
from functools import partial
from multiprocessing.shared_memory import SharedMemory
from typing import Any

import anndata as ad
import numpy as np
import pandas as pd
import polars as pl
from adjustpy import adjust
from adjustpy import adjust # type: ignore (exists but not annotated)
from numpy.typing import NDArray
from scipy.sparse import csr_matrix
from scipy.stats import anderson_ksamp, mannwhitneyu, ttest_ind
from tqdm import tqdm
Expand All @@ -25,15 +27,20 @@

def _build_shared_matrix(
data: np.ndarray | np.matrix | csr_matrix,
) -> tuple[str, tuple[int, int], np.dtype]:
) -> tuple[str, tuple[int, ...], np.dtype[Any]]:
"""Create a shared memory matrix from a numpy array."""
if isinstance(data, np.matrix):
data = np.asarray(data)
elif isinstance(data, csr_matrix):
data = data.toarray()

# After conversion enforce ndarray type
assert isinstance(data, np.ndarray)

shared_matrix = SharedMemory(create=True, size=data.nbytes)
matrix = np.ndarray(data.shape, dtype=data.dtype, buffer=shared_matrix.buf)
matrix[:] = data

return shared_matrix.name, data.shape, data.dtype


Expand All @@ -45,7 +52,7 @@ def _conclude_shared_memory(name: str):


def _combinations_generator(
target_masks: dict[str, np.ndarray],
target_masks: dict[str, NDArray[bool]],
var_indices: dict[str, int],
reference: str,
target_list: list[str],
Expand Down Expand Up @@ -168,9 +175,9 @@ def _get_obs_mask(
adata: ad.AnnData,
target_name: str,
variable_name: str = "target_gene",
) -> np.ndarray:
) -> NDArray[bool]:
"""Return a boolean mask for a specific target name in the obs variable."""
return adata.obs[variable_name] == target_name
return np.array(adata.obs[variable_name] == target_name, dtype=bool)


def _get_var_index(
Expand All @@ -191,7 +198,7 @@ def _get_var_index(


def _sample_mean(
x: np.ndarray,
x: NDArray[float],
is_log1p: bool,
exp_post_agg: bool,
) -> float:
Expand All @@ -203,7 +210,7 @@ def _sample_mean(
"""
if is_log1p:
if exp_post_agg:
return np.expm1(np.mean(x))
return np.expm1(x.mean())
else:
return np.expm1(x).mean()
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pbdex.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def build_random_anndata(
]

return ad.AnnData(
X=np.random.randint(0, MAX_UMI, size=(n_cells, n_genes)),
X=np.random.randint(low=0, high=int(MAX_UMI), size=(n_cells, n_genes)),
obs=obs,
var=pd.DataFrame(index=[f"gene.{j}" for j in np.arange(N_GENES)]),
var=pd.DataFrame(index=np.array([f"gene.{j}" for j in np.arange(N_GENES)])),
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_pdex.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def build_random_anndata(
if random_state is not None:
np.random.seed(random_state)
return ad.AnnData(
X=np.random.randint(0, MAX_UMI, size=(n_cells, n_genes)),
X=np.random.randint(0, int(MAX_UMI), size=(n_cells, n_genes)),
obs=pd.DataFrame(
{
pert_col: np.random.choice(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def build_anndata(log=False) -> ad.AnnData:
return ad.AnnData(
X=np.random.random(size=dim)
if log
else np.random.randint(0, MAX_COUNT, size=dim)
else np.random.randint(0, int(MAX_COUNT), size=dim)
)


Expand Down
Loading