Releases: p-gueguen/rctd-py
v0.3.6
Changed
- Pre-stage spatial_counts on the device once in _doublet.py (steps 3/4/6) and _multi.py (forward-selection iterations). Previously each batch did torch.tensor(spatial_counts[pix_idx], device=device), so a pixel that appears in many triples/tasks was gathered and copied repeatedly. Now both spatial_counts and spatial_numi move to the device once at the start of each mode, and per-batch access is an on-device gather. Numerical output is unchanged: a cross-worktree replay against unmodified main reproduced the v0.3.5 output bit-for-bit on a synthetic doublet workload. GPU bench on fgcz-r-023 (L40S, sm_89, N=20000, K=30, G=500): 76.2 s → 75.0 s (1.5%; modest because the H2D was already memcpy-bound).
Documented (no code change)
- fp32 concordance on GPU is now empirically verified. The fp32 path has been exposed via RCTDConfig(dtype="float32") / CLI --dtype float32 since the initial release, and the perf test suite already times it for all three modes, but no test ever asserted numerical agreement with fp64, and CLAUDE.md flagged the spline-index floor(sqrt(lam/delta)) in _calc_q_all_impl as a precision-sensitive site. New tests/test_fp32_concordance.py asserts spot_class agreement ≥ 99% (observed: 100%), first_type agreement ≥ 99% (observed: 100%), and weights_doublet max diff < 1e-2 (observed: <1e-6 on the synthetic fixture). On L40S sm_89 at N=20000, K=30 the fp32 doublet run produced an identical spot_class hash to fp64 (63dfa94a10f7aa93) while cutting wall time from 76 s → 39 s (~2×). README now points consumer-GPU users at float32 with the empirical numbers.
  Caveat: only synthetic data was stress-tested. The floor(sqrt(...)) spline-index sensitivity could shift on real-world lam distributions near integer boundaries; the assertion is set at ≥99% to admit a small drift before failing.
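The concordance thresholds above reduce to two small metrics: a label-agreement fraction and a maximum absolute difference. The following is a minimal sketch of such checks in plain Python. The helper names agreement_fraction and max_abs_diff and the toy inputs are hypothetical; they are not taken from tests/test_fp32_concordance.py.

```python
from typing import Sequence


def agreement_fraction(a: Sequence[int], b: Sequence[int]) -> float:
    """Fraction of positions where two label vectors agree."""
    if len(a) != len(b):
        raise ValueError("label vectors must have the same length")
    if len(a) == 0:
        raise ValueError("label vectors must be non-empty")
    return sum(x == y for x, y in zip(a, b)) / len(a)


def max_abs_diff(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> float:
    """Largest element-wise absolute difference between two weight matrices."""
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


# Toy stand-ins for fp64 and fp32 results (hypothetical values).
spot_class_fp64 = [0, 1, 2, 2, 1]
spot_class_fp32 = [0, 1, 2, 2, 1]
weights_fp64 = [[0.70, 0.30], [0.55, 0.45]]
weights_fp32 = [[0.7000001, 0.2999999], [0.55, 0.45]]

# Same thresholds as quoted in the release note.
assert agreement_fraction(spot_class_fp64, spot_class_fp32) >= 0.99
assert max_abs_diff(weights_fp64, weights_fp32) < 1e-2
```

Setting the label threshold below 100% is what lets a handful of pixels flip class near a spline-index boundary without failing the suite.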
Internal
- New tests/test_doublet_prestage.py runs doublet mode through the pre-staged path and asserts shape, simplex sums, valid class/type ranges, and that ≥90% of pixels reach a non-degenerate split. Tolerance-based rather than byte-hash: early CI on this branch caught that hash equality across PyTorch/numpy minor versions is too brittle.
v0.3.5
Fixed
- CPU-eigh thread oversubscription (issue #22, reported by @meisproject). _psd_batch now caps PyTorch's intra-op thread count to 1 around the CPU torch.linalg.eigh call and restores the caller's previous count on exit. On hosts with many CPU cores, the default OpenBLAS thread count oversubscribed under batched syevd: V100 + 64 cores at K=38 stalled at Step 1 = 3086 s in v0.3.4. The auto-cap now produces the bounded-threads behavior out of the box, no env vars required.
  Empirically confirmed on Tesla V100 (smei, #22 comment): with the auto-cap on main, no OMP_NUM_THREADS env var, and no --eigh-threshold flag, Step 1 = 27.4 s (down from 3086 s in v0.3.4 with the same bare command, a ~113× speedup). Total doublet-mode wall time on smei's 6113-pixel × K=38 workload: 57.7 s, vs the original ~52 min.
  Numerical output is bit-identical to v0.3.4 (existing K=78 atol=1e-9 equivalence test passes). Users on Hopper/Blackwell are unaffected: they stay on GPU eigh and never enter the CPU branch. Users who already set OMP/MKL/OPENBLAS_NUM_THREADS=1 see no behavior change.
  Note: smei's earlier A/B also disconfirmed bumping the per-arch K threshold default for sm_<9: forcing GPU eigh at K=38 on Volta was 4× slower (126.2 s) than CPU eigh with bounded threads (33.8 s). The --eigh-threshold flag from v0.3.4 remains as a diagnostic / power-user knob; the default behavior is now correct on every architecture we have empirical data for (Volta + V100, Ada + L40S, Hopper, Blackwell).
Internal
- New .pre-commit-config.yaml mirroring the CI lint + format checks (ruff-format + ruff-check --fix, pinned to v0.15.6 matching the dev extra). Install once after cloning: uv pip install pre-commit && pre-commit install. CONTRIBUTING.md updated with the workflow.
v0.3.4
Bundles two _psd_batch improvements. v0.3.3 was prepared and merged to main (CPU eigh crash fix for #20) but never tagged to PyPI; both changes ship together here.
Fixed
- _LinAlgError crash in doublet mode at K≈49 on the CPU eigh path (reported by @EduardGhemes-ICR, #20). _psd_batch previously called torch.linalg.eigh raw on the CPU branch; a single non-finite or near-degenerate batch element would crash LAPACK syevd with "error code: 99" and kill multi-hour Xenium runs. The CPU branch now mirrors the GPU branch's NaN guard (extended to ±Inf) and adds a small-diagonal-jitter retry ladder (1e-6 → 1e-4 → ε·I last resort). Happy-path output is bit-identical to v0.3.2; only previously-crashing inputs are affected. Triggered most often on older arches (Volta / Turing / Ampere / Ada / L40S) where K > 16 falls through to CPU eigh, but the guard is unconditional and applies to CPU-only deployments as well.
Added
- --eigh-threshold CLI flag and RCTDConfig.eigh_threshold (reported by @meisproject, #22). Manually override the K cutoff for staying on GPU eigh inside _psd_batch. The arch-based default (K≤16 on sm_<9, K≤128 on sm_≥9) was derived from L40S benchmarks at K=45 where CPU OpenBLAS won, but only with OMP_NUM_THREADS capped. Users on Volta (V100, sm_70), Turing, Ampere, or Ada (L20/L40S, sm_89) who hit Step 1 perf cliffs at K∈[17, 64] (e.g. K=38 reported at 3086 s for 6113 pixels) can now force GPU eigh via --eigh-threshold 64 without waiting on a per-arch benchmark / release. Setting --eigh-threshold 0 forces CPU eigh on every arch (diagnostic counter-case). Default None preserves v0.3.2 arch-gated behavior bit-for-bit.
  Caveat: this ships the override mechanism, not a confirmed perf win on V100/L20. The maintainer has no V100 or L20 hardware to bench against; whether GPU eigh actually beats CPU offload at K=38 on those arches is unverified. Recommended diagnostic sequence: try OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 alone first (often the real fix is the BLAS thread cap, not the dispatch), then layer --eigh-threshold 64 if Step 1 is still slow.
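The dispatch rule described above (arch-based default, optional manual override) can be written as a small pure function. This is a sketch under stated assumptions: the function names are hypothetical, the architecture is represented only by the major compute capability, and the exact comparison used inside _psd_batch is assumed to be K ≤ threshold as the note implies.

```python
from typing import Optional


def default_eigh_threshold(sm_major: int) -> int:
    """Arch-based default K cutoff: 128 on sm_>=9, 16 on older architectures."""
    return 128 if sm_major >= 9 else 16


def use_gpu_eigh(k: int, sm_major: int, eigh_threshold: Optional[int] = None) -> bool:
    """Return True if eigh should stay on the GPU for K = k cell types."""
    threshold = default_eigh_threshold(sm_major) if eigh_threshold is None else eigh_threshold
    return k <= threshold


# V100 is sm_70, so its major compute capability is 7.
print(use_gpu_eigh(38, sm_major=7))                     # arch default
print(use_gpu_eigh(38, sm_major=7, eigh_threshold=64))  # forced GPU eigh
print(use_gpu_eigh(38, sm_major=9, eigh_threshold=0))   # forced CPU eigh
```

Under this rule a threshold of 0 can never be satisfied by a positive K, which is why it acts as the "always CPU" diagnostic setting.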
v0.3.2 — class_df hierarchical fallback + Blackwell perf fix
Added
- Hierarchical cell type fallback (class_df) for doublet mode (#14). When a granular reference makes type-level resolution ambiguous, RCTD now reports the best two subtypes alongside first_class / second_class boolean flags indicating that the assignment is only trustworthy at the parent-class level, mirroring R spacexr exactly.
  - Python API: RCTDConfig(class_df={"T_CD4": "T_cell", ...})
  - CLI: --class-df path.tsv (TSV with columns cell_type and class)
  - New result fields on DoubletResult: first_class_name, second_class_name (string arrays, populated only when class_df is provided)
  - When class_df is omitted (default), behavior is bit-identical to v0.3.1, verified by an explicit identity-mapping regression test.
- Arch-gated GPU eigh threshold (_psd_batch). On Hopper (sm_90+) and Blackwell (sm_100+), the K-cutoff for staying on GPU eigh is bumped from 16 to 128 via torch.cuda.get_device_capability. Older architectures (Volta, Turing, Ampere, Ada / L40S) keep the K≤16 cutoff that earlier benchmarks showed wins via CPU OpenBLAS.
- TorchScript-fused box-QP as the compile=False path. The eager Python Gauss-Seidel loop is replaced by _solve_box_qp_batch_adaptive_jit (@torch.jit.script, separate from torch.compile / Inductor). Fuses 50 sweeps × K coords into a single TorchScript graph and adds batch-level early exit, eliminating the kernel-launch storm at K>16.
- 23 new perf-regression tests (tests/test_blackwell_perf.py) covering arch-detection across 8 GPU architectures, eager/JIT numerical equivalence at K=3,8,16,32,78,100, ill-conditioned matrices, active lower-bound constraints, CPU path preservation, CPU perf neutrality, and end-to-end full + doublet integration regression.
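For the class_df option listed above, the TSV form (columns cell_type and class) and the dict form passed to RCTDConfig(class_df=...) carry the same mapping. The sketch below shows one way to go from one to the other with the standard library. load_class_df is a hypothetical helper, not rctd-py's own parser, and the cell type names other than T_CD4 are made up for illustration.

```python
import csv
import io
from typing import Dict, TextIO


def load_class_df(handle: TextIO) -> Dict[str, str]:
    """Read a TSV with columns cell_type and class into a {cell_type: class} dict."""
    reader = csv.DictReader(handle, delimiter="\t")
    missing = {"cell_type", "class"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"class_df TSV is missing columns: {sorted(missing)}")
    return {row["cell_type"]: row["class"] for row in reader}


# In-memory stand-in for path.tsv (hypothetical contents).
tsv_text = "cell_type\tclass\nT_CD4\tT_cell\nT_CD8\tT_cell\nB_naive\tB_cell\n"
class_df = load_class_df(io.StringIO(tsv_text))
print(class_df)
```

Checking the header up front gives a clear error for a mislabeled file instead of a KeyError on the first row.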
Fixed
- Blackwell perf cliff at K>16, doublet mode (reported by @litj). Before this release, doublet mode at K≈78 with ~100k pixels would stall for 8+ hours on Blackwell + CUDA 13 with --no-compile, despite GPU memory being allocated and the process running. Root cause: the K>16 path in _psd_batch unconditionally CPU-offloaded eigendecomposition, which then oversubscribed all CPU cores via OpenBLAS while the GPU sat at 0% utilization. Reproduced on FGCZ Blackwell node and verified end-to-end:

  | Configuration | Step 1 (full-mode fit) on K=78, 102k pixels |
  | --- | --- |
  | v0.3.0 + --no-compile, no env caps | >8h, killed without completing |
  | v0.3.2 patched (K≤128 GPU eigh + JIT box-QP + OMP_NUM_THREADS=1) | 2800 s (~47 min) |

  Convergence rate 1.000, GPU util sustained at 85%. The OMP_NUM_THREADS=1 env var is still recommended on Blackwell to prevent OpenBLAS from spawning threads for incidental CPU LAPACK calls.
- L40S / Ampere unaffected by the dispatch change. Regression-tested on fgcz-r-023 (L40S, sm_89): the JIT box-QP path is 1.6–1.85× faster than the previous eager Python loop at K=45 and K=78, with max numerical diff ~1e-6. Arch gating preserves the CPU eigh offload that earlier L40S benchmarks validated.
Notes
- --no-compile and RCTDConfig(compile=False) semantics are unchanged. Users who previously selected this path now get the JIT-script box-QP automatically; no API change.
- For users on Blackwell hitting the K>16 perf cliff on v0.3.0 / v0.3.1: upgrading to v0.3.2 is sufficient; the recommended env vars (OMP_NUM_THREADS=1, MKL_NUM_THREADS=1, OPENBLAS_NUM_THREADS=1) still apply as belt-and-suspenders.
v0.3.0
What's Changed
Bug Fixes
- counts_MIN pixel filter now enforced (fixes #11): R spacexr calls restrict_counts() twice; the second call, with gene_list_bulk, was missing from rctd-py. Pixels with fewer than counts_MIN=10 counts in the DE gene set are now correctly removed. Validated: exact pixel count match with R spacexr on Xenium Region 1 (n_filtered=13,936).
- torch.compile fallback for environments without CUDA headers (fixes #10): torch.compile fails at runtime on GPU nodes without CUDA development headers (cuda.h) because Triton attempts to compile CUDA code. Added lazy auto-detection with graceful fallback to eager mode, plus RCTDConfig(compile=False) and --no-compile CLI flag for explicit control.
- cuSOLVER batch-size crash fix: torch.linalg.eigh has an undocumented batch-size limit in CUDA 12.8 (~27k-31k depending on K). Added _eigh_safe() that sub-batches at 25k, fixing crashes at --batch-size 50000.
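The sub-batching idea behind the cuSOLVER fix above can be illustrated without a GPU: split the batch into consecutive chunks no larger than a safe size, run the solver per chunk, and concatenate. This is a generic sketch, not the body of _eigh_safe(); apply_in_sub_batches and fragile_solver are hypothetical names, and the 27,000 limit in the stand-in solver is chosen only to mimic the reported range.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def apply_in_sub_batches(
    items: Sequence[T],
    fn: Callable[[Sequence[T]], List[R]],
    max_batch: int = 25_000,
) -> List[R]:
    """Apply fn to consecutive slices of at most max_batch items and concatenate results."""
    if max_batch < 1:
        raise ValueError("max_batch must be positive")
    out: List[R] = []
    for start in range(0, len(items), max_batch):
        out.extend(fn(items[start:start + max_batch]))
    return out


# Stand-in for a batched solver that rejects oversized batches (hypothetical limit).
def fragile_solver(batch: Sequence[int]) -> List[int]:
    if len(batch) > 27_000:
        raise RuntimeError("batch too large")
    return [x * 2 for x in batch]


# 50,000 items would fail in one call but succeed as two chunks of 25,000.
results = apply_in_sub_batches(list(range(50_000)), fragile_solver)
```

Choosing a chunk size below the lowest observed failure point leaves headroom for the limit varying with K.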
New Features
- pixel_mask in result types (fixes #8, fixes #9): FullResult, DoubletResult, and MultiResult now include a pixel_mask field (boolean array matching the input AnnData shape). Maps results back to original barcodes:

      result = run_rctd(spatial, reference)
      weights_df = pd.DataFrame(
          result.weights,
          index=spatial.obs_names[result.pixel_mask],
          columns=result.cell_type_names,
      )
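To make the role of the mask concrete, here is a toy sketch in plain Python, assuming pixel_mask marks the pixels that survived filtering such as the counts_MIN check. The helper counts_min_mask, the barcodes, and the count values are all hypothetical; rctd-py's own filter operates on AnnData and is not reproduced here.

```python
from typing import List, Sequence


def counts_min_mask(
    counts: Sequence[Sequence[int]],
    gene_idx: Sequence[int],
    counts_min: int = 10,
) -> List[bool]:
    """True for pixels whose total count over the selected genes is >= counts_min."""
    return [sum(row[g] for g in gene_idx) >= counts_min for row in counts]


# Toy data: 4 pixels x 3 genes; only genes 0 and 2 are in the gene set (hypothetical).
barcodes = ["AAAC", "AAAG", "AACT", "AAGT"]
counts = [[6, 50, 5], [1, 90, 2], [10, 0, 0], [4, 3, 5]]
pixel_mask = counts_min_mask(counts, gene_idx=[0, 2])

# Results exist only for kept pixels; the mask recovers their barcodes in order.
kept_barcodes = [bc for bc, keep in zip(barcodes, pixel_mask) if keep]
print(pixel_mask)
print(kept_barcodes)
```

Note that the second pixel has a high total count but is still dropped, because only counts within the selected gene set are summed.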
Improvements
- Memory: sparse-aware reference profiles: Large references (370k+ cells) no longer require .todense() during profile computation. Sparse mat-vec products keep memory usage proportional to non-zero entries.
- Numerical precision: _longdouble_sum() uses numpy longdouble (80-bit) for bulk reductions, matching R's extended precision on x86-64.
- Tutorial notebook fixed: Marimo figures now render in static HTML export.
Breaking Changes
- counts_MIN=10 is now enforced; result pixel counts will differ from v0.2.x (fewer pixels, matching R spacexr).
- FullResult, DoubletResult, MultiResult gain a pixel_mask field (default None, backward-compatible for direct run_*_mode() callers).
- RCTDConfig gains a compile field (default True).
Validation
- 100/100 tests pass (Python 3.10-3.12)
- Xenium Region 1: n_filtered=13,936 exact match with R, dominant_type_agreement=0.9973, pixel_corr_median=1.0
Full Changelog: v0.2.2...v0.3.0
v0.2.2: Fix GPU multi mode crash
Bug fix
- Fix cuSOLVER crash in multi mode on GPU: NVIDIA's batched eigendecomposition (cusolverDnXsyevBatched) fails on 1×1 matrices, which occur during multi mode's iterative type selection (K_sub=1). Added analytical K=1 path and NaN guard for degenerate Hessians.
Upgrade
uv pip install --upgrade rctd-py==0.2.2
Full mode and doublet mode were unaffected.
v0.2.1: device control, performance optimizations, CLI
New features
- device parameter in RCTDConfig: force CPU/GPU with device="cpu" / "cuda" / "auto"
- rctd run CLI command for full/doublet/multi modes
- Auto batch sizing based on available VRAM
- Analytical K=2 solvers for faster doublet mode
Performance
- Shared-profile IRWLS solver (28% faster, 17% less VRAM)
- Batched log-likelihood computation
- torch.compile integration
Bug fixes
- Correct 0-indexed spot_class labels in tutorial (#4)
- Handle corrupt Q-matrices download with automatic retry
- Fix flaky test tolerances
v0.2.0 — PyTorch backend, device control, consolidated CI
What's Changed
Breaking: JAX → PyTorch migration
- Replaced JAX/jaxlib with PyTorch as the sole compute backend
- All dependencies updated: torch>=2.0 replaces jax>=0.4.20, jaxlib>=0.4.20
New features
- device parameter: force CPU or GPU via RCTDConfig(device="cpu") or "cuda" (default "auto" preserves existing behavior)
- sigma_override: bypass sigma auto-calibration with a known value (e.g. from R) for exact concordance
Testing & validation
- Added R concordance tests with pre-computed spacexr v2.2.1 fixtures (no R required to run)
- 99.7% dominant type agreement on 14k-cell Xenium, 100% with sigma_override
CI & packaging
- Consolidated lint + test into single CI workflow
- Removed codecov (no token configured)
- Updated README: uv pip install, benchmark tables, device docs
Full Changelog: v0.1.1...v0.2.0
v0.1.1 — Sigma estimation 23× speedup
What's new in v0.1.1
Performance
- 23× faster sigma estimation via three targeted optimizations:
  - Cache the 437×437 tridiagonal matrix inverse (eliminated ~144 redundant O(n³) inversions)
  - Precompute all 126 spline coefficient matrices once at startup
  - Vectorize sigma candidate evaluation with jax.vmap / jax.jit (85 sequential → 1 fused kernel)
- Total end-to-end time on Blackwell B200: ~3.5 min vs ~51 min for R spacexr (15× speedup)
New
- Validation report with spatial cell-type maps: https://p-gueguen.github.io/rctd-py/
- q_matrices.npz is now auto-downloaded on first use (not bundled in wheel)
Fixes
- Lint: remove unused variable in _likelihood.py
Installation
uv pip install rctd-py
rctd-py v0.1.0: Initial release
rctd-py v0.1.0
GPU-accelerated Robust Cell Type Decomposition (RCTD) for spatial transcriptomics.
Highlights
- JAX reimplementation of spacexr RCTD with 63x GPU speedup (L40S) over R
- 99.7% agreement with R spacexr on 58k Xenium pixels
- Three deconvolution modes: full, doublet, multi
- Pure Python — no R dependency