Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion dexpv2/_tests/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging

from skimage.data import cells3d
import pytest

from dexpv2.segmentation import detect_foreground
from dexpv2.segmentation import detect_foreground, reconstruction_by_dilation
from dexpv2.utils import to_cpu

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,3 +33,31 @@ def test_foreground_detection(interactive_test: bool) -> None:
viewer.add_labels(to_cpu(foreground))

napari.run()


def test_foreground_detection_with_float16() -> None:
    """Compare the float16 cupy reconstruction against a float32 CPU reference.

    The test is skipped when ``xp`` resolves to numpy, because the float16
    code path under test is only taken for non-numpy (cupy) arrays.
    """
    import numpy as np

    nuclei = xp.asarray(cells3d()[:, 1])

    # Ensure we are using the cupy backend before doing any further work.
    if isinstance(nuclei, np.ndarray):
        pytest.skip("Skipping test as cupy is not available.")

    # Test with float16 data
    nuclei = nuclei / nuclei.max()
    nuclei = nuclei.astype(xp.float16)
    mask = xp.copy(nuclei)

    foreground_cp = reconstruction_by_dilation(nuclei, mask, iterations=10)
    foreground_cp = to_cpu(foreground_cp)

    # Build the reference on the host. The CPU implementation is not expected
    # to handle float16, so the inputs are upcast to float32; the cast is
    # exact, so it cannot change the ordering of the values.
    # NOTE(review): assumes `to_cpu` returns a numpy array, as its use on
    # `foreground_cp` above suggests - confirm.
    nuclei_np = to_cpu(nuclei).astype(np.float32)
    mask_np = to_cpu(mask).astype(np.float32)
    foreground_np = reconstruction_by_dilation(nuclei_np, mask_np, iterations=10)

    # Both results now live on the host, so they can be compared with numpy.
    assert np.allclose(foreground_cp, foreground_np)
149 changes: 147 additions & 2 deletions dexpv2/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Tuple, List

import numpy as np
from numpy.typing import ArrayLike
Expand All @@ -10,6 +11,135 @@
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

try:
import cupy as xp

LOG.info("cupy found.")
except (ModuleNotFoundError, ImportError):
import numpy as xp

LOG.info("cupy not found using numpy.")


def discretize_multiple_float16_to_uint16(
    float16_arrays: List["xp.ndarray"],  # type: ignore # 'xp' will be defined at runtime
) -> Tuple[List["xp.ndarray"], "xp.ndarray"]:  # type: ignore
    """
    Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to uint16,
    preserving order using a global mapping across all arrays.

    Every distinct value found across all inputs is assigned its rank in the
    sorted set of distinct values, so ``a < b`` in any input implies
    ``rank(a) < rank(b)`` in the outputs, and indexing the lookup table with
    an output array restores the original values.

    Args:
        float16_arrays (List[xp.ndarray]): A non-empty list of arrays
            (e.g., CuPy or NumPy), each with dtype float16. The arrays may
            have different shapes, including zero-sized ones.

    Returns:
        tuple: A tuple containing:
            - y_uint16_list (List[xp.ndarray]): A list of discretized arrays,
              each with dtype uint16 and the shape of the corresponding input.
            - uint16_to_float16_lookup (xp.ndarray): A single sorted lookup
              table (array of float16) for all arrays, where the index is the
              uint16 value and the value is the original float16 value.

    Raises:
        TypeError: If any input array's dtype is not float16.
        ValueError: If the list of arrays is empty, if the arrays cannot be
            concatenated (the original error is chained as ``__cause__``),
            or if the total number of unique values across all arrays
            exceeds the capacity of uint16 (65536).
    """
    if not float16_arrays:
        raise ValueError("Input list of arrays cannot be empty.")

    # Validate every input before doing any array work.
    for i, arr in enumerate(float16_arrays):
        if not hasattr(arr, "dtype") or arr.dtype != xp.float16:
            raise TypeError(
                f"Array at index {i} must be an 'xp.ndarray' with dtype xp.float16. "
                f"Got type {type(arr)} with dtype {getattr(arr, 'dtype', 'N/A')}."
            )

    shapes = [arr.shape for arr in float16_arrays]
    sizes = [arr.size for arr in float16_arrays]

    # Nothing to rank: return empty outputs with the original shapes.
    if sum(sizes) == 0:
        return (
            [xp.empty(shape, dtype=xp.uint16) for shape in shapes],
            xp.empty(0, dtype=xp.float16),
        )

    # Flatten everything into one 1D array so the ranking is global.
    try:
        flat_values = xp.concatenate([arr.ravel() for arr in float16_arrays])
    except Exception as e:
        # Kept as ValueError for callers, but chained so the root cause
        # (e.g. mixing array backends) is not lost.
        raise ValueError(
            f"Error during concatenation of arrays: {e}. Ensure 'xp' is correctly defined (NumPy/CuPy)."
        ) from e

    # `unique` returns the sorted distinct values (the lookup table) and, for
    # each flattened element, its index into that table (its rank).
    uint16_to_float16_lookup, ranks = xp.unique(flat_values, return_inverse=True)

    capacity = int(xp.iinfo(xp.uint16).max) + 1
    if len(uint16_to_float16_lookup) > capacity:
        raise ValueError(
            f"Number of unique values ({len(uint16_to_float16_lookup)}) across all arrays "
            f"exceeds the maximum capacity of uint16 ({capacity})."
        )

    ranks = ranks.astype(xp.uint16)

    # Cut the flat rank array back into pieces matching each input. A
    # zero-sized input yields an empty slice, which reshapes to its
    # (zero-sized) original shape without special handling.
    y_uint16_list: List["xp.ndarray"] = []  # type: ignore
    start = 0
    for size, shape in zip(sizes, shapes):
        stop = start + size
        y_uint16_list.append(ranks[start:stop].reshape(shape))
        start = stop

    return y_uint16_list, uint16_to_float16_lookup


def reconstruction_by_dilation(
seed: ArrayLike, mask: ArrayLike, iterations: int
Expand All @@ -34,14 +164,29 @@ def reconstruction_by_dilation(
-------
Image reconstructed by dilation.
"""
ndi = import_module("scipy", "ndimage")
ndi = import_module("scipy", "ndimage", seed)

import numpy as np

seed = np.minimum(seed, mask, out=seed) # just making sure
cupy_used = np != xp and not isinstance(seed, np.ndarray)

lut = None
# quick-fix for the issue https://github.com/cupy/cupy/issues/9122
if cupy_used and seed.dtype == xp.float16:
arrs, lut = discretize_multiple_float16_to_uint16([seed, mask])
seed = arrs[0]
mask = arrs[1]

seed = np.minimum(seed, mask, out=seed)

for _ in range(iterations):
seed = ndi.grey_dilation(seed, size=3, output=seed, mode="constant")
seed = np.minimum(seed, mask, out=seed)

if lut is not None:
# convert back to float16
seed = xp.take(lut, seed)

return seed


Expand Down