Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"pylibraft",
"dask",
"cuvs",
"spatialdata",
]
default_role = "literal"
napoleon_google_docstring = False
Expand Down Expand Up @@ -126,6 +127,7 @@
"statsmodels": ("https://www.statsmodels.org/stable/", None),
"omnipath": ("https://omnipath.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
"spatialdata": ("https://spatialdata.scverse.org/en/stable/", None),
}

# List of patterns, relative to source directory, that match files and
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ dev = [
"pre-commit",
]

[project.entry-points."squidpy.backends"]
rapids_singlecell = "rapids_singlecell.squidpy_backend:RscSquidpyBackend"

[project.urls]
Documentation = "https://rapids-singlecell.readthedocs.io"
Source = "https://github.com/scverse/rapids_singlecell"
Expand Down
5 changes: 5 additions & 0 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from scipy.sparse import csc_matrix as csc_matrix_cpu
from scipy.sparse import csr_matrix as csr_matrix_cpu

try:
from spatialdata import SpatialData
except ImportError:
SpatialData = None


def _meta_dense(dtype):
return cp.zeros([0], dtype=dtype)
Expand Down
30 changes: 30 additions & 0 deletions src/rapids_singlecell/squidpy_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Squidpy backend adapter for rapids_singlecell.

The dispatch decorator introspects the real RSC function signatures
(lazily imported on first access), so no need to duplicate them here.
"""

from __future__ import annotations

import importlib


class RscSquidpyBackend:
"""Backend adapter exposing rapids_singlecell GPU implementations to squidpy."""

name = "rapids_singlecell"
aliases = ["rapids-singlecell", "rsc", "cuda", "gpu"]

# squidpy function name -> module that implements it
_functions = {
"spatial_autocorr": "rapids_singlecell.squidpy_gpu",
"co_occurrence": "rapids_singlecell.squidpy_gpu",
"ligrec": "rapids_singlecell.squidpy_gpu",
}

def __getattr__(self, name: str):
if name in self._functions:
func = getattr(importlib.import_module(self._functions[name]), name)
setattr(self, name, func) # cache on instance
return func
raise AttributeError(f"{type(self).__name__!r} has no attribute {name!r}")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the separate class? couldn't it just be a module that lazy re-exports the relevant functions?

5 changes: 4 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy import sparse
from statsmodels.stats.multitest import multipletests

from rapids_singlecell._compat import SpatialData
from rapids_singlecell.preprocessing._utils import _sparse_to_dense

from ._gearysc import _gearys_C_cupy
Expand Down Expand Up @@ -49,7 +50,7 @@ def _to_cupy(vals, *, use_sparse: bool, dtype):


def spatial_autocorr(
adata: AnnData,
adata: AnnData | SpatialData,
*,
connectivity_key: str = "spatial_connectivities",
genes: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -118,6 +119,8 @@ def spatial_autocorr(
DataFrame containing the autocorrelation scores, p-values, and corrected p-values for each gene. \
If `copy` is False, the results are stored in `adata.uns` and None is returned.
"""
if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
if genes is None:
if "highly_variable" in adata.var:
genes = adata[:, adata.var["highly_variable"]].var_names.values
Expand Down
5 changes: 4 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_co_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from cuml.metrics import pairwise_distances

from rapids_singlecell._compat import SpatialData
from rapids_singlecell._cuda import _cooc_cuda as _co
from rapids_singlecell._utils import (
_calculate_blocks_per_pair,
Expand All @@ -21,7 +22,7 @@


def co_occurrence(
adata: AnnData,
adata: AnnData | SpatialData,
cluster_key: str,
*,
spatial_key: str = "spatial",
Expand Down Expand Up @@ -65,6 +66,8 @@ def co_occurrence(
computed at ``interval``.
"""

if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
_assert_categorical_obs(adata, key=cluster_key)
_assert_spatial_basis(adata, key=spatial_key)
spatial = cp.array(adata.obsm[spatial_key]).astype(np.float32)
Expand Down
6 changes: 5 additions & 1 deletion src/rapids_singlecell/squidpy_gpu/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from cupyx.scipy.sparse import issparse as cpissparse
from scipy.sparse import csc_matrix, issparse

from rapids_singlecell._compat import SpatialData

from ._utils import _assert_categorical_obs, _create_sparse_df

SOURCE = "source"
Expand Down Expand Up @@ -118,7 +120,7 @@ def _check_tuple_needles(needles, haystack, *, msg: str, reraise: bool = True):


def ligrec(
adata: AnnData,
adata: AnnData | SpatialData,
cluster_key: str,
*,
clusters: list | None = None,
Expand Down Expand Up @@ -233,6 +235,8 @@ def ligrec(
interacting components was 0 or it didn't pass the threshold percentage of \
cells being expressed within a given cluster.
"""
if SpatialData is not None and isinstance(adata, SpatialData):
adata = adata.table
# Get and Check interactions
if interactions is None:
interactions = _get_interactions(
Expand Down
11 changes: 11 additions & 0 deletions tests/test_backend_conformance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Run squidpy's backend conformance suite against the RSC backend."""

from __future__ import annotations

from squidpy.testing.backend_conformance import validate_backend


def test_conformance():
results = validate_backend("rapids_singlecell")
for name, status in results.items():
assert status == "PASSED", f"{name}: {status}"
Loading