Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions cellfinder/core/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
CuboidArrayDataset,
CuboidBatchSampler,
)
from cellfinder.core.classify.tools import get_model
from cellfinder.core.classify.tools import get_model, model_input_channels
from cellfinder.core.tools.tools import deprecate_positional_args
from cellfinder.core.train.train_yaml import depth_type, models

Expand All @@ -25,7 +25,7 @@ def main(
*,
points: List[Cell],
signal_array: types.array,
background_array: types.array,
background_array: Optional[types.array],
n_free_cpus: int,
voxel_sizes: Tuple[float, float, float],
network_voxel_sizes: Tuple[float, float, float],
Expand All @@ -48,8 +48,10 @@ def main(
The potential cells to classify.
signal_array : numpy.ndarray or dask array
3D array representing the signal data in z, y, x order.
background_array : numpy.ndarray or dask array
3D array representing the signal data in z, y, x order.
background_array : numpy.ndarray or dask array, optional
3D array representing the background data in z, y, x order. If
``None``, a single-channel (signal-only) cube is built and a
single-channel model must be used.
n_free_cpus : int
How many CPU cores to leave free.
voxel_sizes : 3-tuple of floats
Expand Down Expand Up @@ -94,7 +96,7 @@ def main(
"""
if signal_array.ndim != 3:
raise IOError("Signal data must be 3D")
if background_array.ndim != 3:
if background_array is not None and background_array.ndim != 3:
raise IOError("Background data must be 3D")

# Too many workers doesn't increase speed, and uses huge amounts of RAM
Expand Down Expand Up @@ -138,13 +140,32 @@ def main(
model_weights = trained_model
trained_model = None

if (
trained_model is None
and model_weights
and (Path(model_weights).suffix == ".keras")
):
trained_model = model_weights
model_weights = None

model = get_model(
existing_model=trained_model,
model_weights=model_weights,
network_depth=models[network_depth],
inference=True,
num_channels=dataset.num_channels,
)

model_channels = model_input_channels(model)
if model_channels != dataset.num_channels:
raise ValueError(
f"The classification model expects {model_channels}-channel "
f"input but {dataset.num_channels} channel(s) were provided. "
f"Use a `trained_model` whose channel count matches the data, "
f"or provide data matching the model (signal only for 1, "
f"signal + background for 2)."
)

logger.info("Running inference")
if workers:
dataset.start_dataset_thread(workers)
Expand Down
10 changes: 10 additions & 0 deletions cellfinder/core/classify/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
from cellfinder.core.classify.resnet import build_model, layer_type


def model_input_channels(model: Model) -> int:
"""The number of input channels the model expects."""
return tuple(model.inputs[0].shape)[-1]


def get_model(
existing_model: Optional[os.PathLike] = None,
model_weights: Optional[os.PathLike] = None,
network_depth: Optional[layer_type] = None,
learning_rate: float = 0.0001,
inference: bool = False,
continue_training: bool = False,
num_channels: int = 2,
) -> Model:
"""Returns the correct model based on the arguments passed
:param existing_model: An existing, trained model. This is returned if it
Expand All @@ -28,6 +34,9 @@ def get_model(
by using the default one
:param continue_training: If True, will ensure that a trained model
exists. E.g. by using the default one
:param num_channels: Number of input channels for a freshly built model.
``2`` for the standard signal+background model, ``1`` for a single-channel
(signal-only) model. Ignored when ``existing_model`` is loaded.
:return: A keras model

"""
Expand All @@ -37,6 +46,7 @@ def get_model(
else:
logger.debug(f"Creating a new instance of model: {network_depth}")
model = build_model(
shape=(50, 50, 20, num_channels),
network_depth=network_depth,
learning_rate=learning_rate,
)
Expand Down
13 changes: 11 additions & 2 deletions cellfinder/core/download/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,28 @@


MODEL_URL = "https://gin.g-node.org/cellfinder/models/raw/master"
HF_1CH_URL = "https://huggingface.co/brainglobe/cellfinder_single_channel_default/resolve/main" # noqa: E501

model_filenames = {
"resnet50_tv": "resnet50_tv.h5",
"resnet50_all": "resnet50_weights.h5",
"resnet50_1ch": "resnet50_single_channel.keras",
}

model_urls = {
"resnet50_tv": f"{MODEL_URL}/resnet50_tv.h5",
"resnet50_all": f"{MODEL_URL}/resnet50_weights.h5",
"resnet50_1ch": f"{HF_1CH_URL}/resnet50_single_channel.keras",
}

model_hashes = {
"resnet50_tv": "63d36af456640590ba6c896dc519f9f29861015084f4c40777a54c18c1fc4edd", # noqa: E501
"resnet50_all": None,
"resnet50_1ch": "4c0af5e916195603266fc18686a84e7156683cbd6e91b27385e9d6e0b5ef5a55", # noqa: E501
}


model_type = Literal["resnet50_tv", "resnet50_all"]
model_type = Literal["resnet50_tv", "resnet50_all", "resnet50_1ch"]


def download_models(
Expand All @@ -55,7 +64,7 @@ def download_models(
download_path = Path(download_path)
filename = model_filenames[model_name]
model_path = pooch.retrieve(
url=f"{MODEL_URL}/{filename}",
url=model_urls[model_name],
known_hash=model_hashes[model_name],
path=download_path,
fname=filename,
Expand Down
10 changes: 7 additions & 3 deletions cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def main(
*,
signal_array: types.array,
background_array: types.array,
background_array: Optional[types.array],
voxel_sizes: Tuple[float, float, float],
start_plane: int = 0,
end_plane: int = -1,
Expand Down Expand Up @@ -58,8 +58,12 @@ def main(
----------
signal_array : numpy.ndarray or dask array
3D array representing the signal data in z, y, x order.
background_array : numpy.ndarray or dask array
3D array representing the signal data in z, y, x order.
background_array : numpy.ndarray or dask array, optional
3D array representing the background (autofluorescence) data in
z, y, x order. If ``None``, classification runs on the signal
channel alone (single-channel mode). A single-channel
``trained_model`` must then be supplied, since the default pretrained
weights are two-channel.
voxel_sizes : 3-tuple of floats
Size of your voxels in the z, y, and x dimensions (microns).
start_plane : int
Expand Down
6 changes: 4 additions & 2 deletions cellfinder/core/train/train_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,18 @@ def run(
f"in {len(yaml_file)} yaml files"
)

filenames_train, cells_train = make_tiff_lists(tiff_files)
num_channels = len(filenames_train[0])

model = get_model(
existing_model=trained_model,
model_weights=model_weights,
network_depth=models[network_depth],
learning_rate=learning_rate,
continue_training=continue_training,
num_channels=num_channels,
)

filenames_train, cells_train = make_tiff_lists(tiff_files)

n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
n_processes = min(n_processes, max_workers)
if test_fraction > 0:
Expand Down
28 changes: 28 additions & 0 deletions tests/core/test_integration/test_train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import sys

import keras
import pytest
from pytest_mock.plugin import MockerFixture

from cellfinder.core.classify.tools import model_input_channels
from cellfinder.core.train.train_yaml import cli as train_run

data_dir = os.path.join(
Expand All @@ -12,6 +14,9 @@
cell_cubes = os.path.join(data_dir, "cells")
non_cell_cubes = os.path.join(data_dir, "non_cells")
training_yaml_file = os.path.join(data_dir, "training.yaml")
training_yaml_single_channel = os.path.join(
data_dir, "training_single_channel.yaml"
)


EPOCHS = "2"
Expand Down Expand Up @@ -40,6 +45,29 @@ def test_train(tmpdir):
assert os.path.exists(model_file)


@pytest.mark.slow
def test_train_single_channel(tmpdir):
tmpdir = str(tmpdir)

train_args = [
"cellfinder_train",
"-y",
training_yaml_single_channel,
"-o",
tmpdir,
"--epochs",
EPOCHS,
]
sys.argv = train_args
train_run()

model_file = os.path.join(tmpdir, "model.keras")
assert os.path.exists(model_file)

model = keras.models.load_model(model_file)
assert model_input_channels(model) == 1


@pytest.mark.parametrize("lr_schedule", [True, False])
def test_train_lr_schedule(mocker: MockerFixture, tmpdir, lr_schedule):
tmpdir = str(tmpdir)
Expand Down
74 changes: 74 additions & 0 deletions tests/core/test_unit/test_classify/test_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import MagicMock

import numpy as np
import pytest
from brainglobe_utils.cells.cells import Cell

from cellfinder.core.classify import classify


def test_classify_channel_mismatch_raises(synthetic_single_spot, mocker):
"""A model whose channel count differs from the data raises clearly."""
signal_array, _background, c_xyz = synthetic_single_spot
signal_array = signal_array.astype(np.uint16)
points = [Cell(tuple(int(c) for c in c_xyz), Cell.UNKNOWN)]

# a two-channel model fed single-channel (background=None) data
fake_model = MagicMock()
fake_model.inputs = [MagicMock(shape=(None, 50, 50, 20, 2))]
mocker.patch(
"cellfinder.core.classify.classify.get_model",
return_value=fake_model,
)

with pytest.raises(ValueError, match="expects 2-channel input but 1"):
classify.main(
points,
signal_array,
None,
0,
(5, 1, 1),
(5, 1, 1),
1,
50,
50,
20,
None,
None,
"50",
)


def test_classify_keras_weights_loaded_as_model(synthetic_single_spot, mocker):
"""A ``.keras`` path passed as weights is loaded as a full model."""
signal_array, _background, c_xyz = synthetic_single_spot
signal_array = signal_array.astype(np.uint16)
points = [Cell(tuple(int(c) for c in c_xyz), Cell.UNKNOWN)]

fake_model = MagicMock()
fake_model.inputs = [MagicMock(shape=(None, 50, 50, 20, 2))]
get_model = mocker.patch(
"cellfinder.core.classify.classify.get_model",
return_value=fake_model,
)

with pytest.raises(ValueError, match="expects 2-channel input but 1"):
classify.main(
points,
signal_array,
None,
0,
(5, 1, 1),
(5, 1, 1),
1,
50,
50,
20,
None,
"model.keras",
"50",
)

_, kwargs = get_model.call_args
assert kwargs["existing_model"] == "model.keras"
assert kwargs["model_weights"] is None
8 changes: 8 additions & 0 deletions tests/core/test_unit/test_classify/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ def test_incorrect_weights(mock_build_model):
inference=True,
model_weights="incorrect_weights.h5",
)


@pytest.mark.parametrize("num_channels", [1, 2])
def test_get_model_num_channels(num_channels):
model = tools.get_model(
network_depth="18-layer", num_channels=num_channels
)
assert tools.model_input_channels(model) == num_channels
13 changes: 13 additions & 0 deletions tests/core/test_unit/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from cellfinder.core.download import download


def test_model_registry_consistent():
assert set(download.model_filenames) == set(download.model_hashes)
assert set(download.model_filenames) == set(download.model_urls)


def test_single_channel_model_registered():
assert "resnet50_1ch" in download.model_filenames
assert "resnet50_1ch" in download.model_hashes
assert download.model_filenames["resnet50_1ch"].endswith(".keras")
assert "huggingface.co" in download.model_urls["resnet50_1ch"]
16 changes: 16 additions & 0 deletions tests/core/test_unit/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,19 @@ def test_valid_weights_allows_detection(

mock_prep_weights.assert_called_once()
mock_detect.assert_called_once()


@patch("cellfinder.core.detect.detect.main", return_value=[])
@patch("cellfinder.core.tools.prep.prep_model_weights")
def test_optional_background_allows_detection(
mock_prep_weights, mock_detect, signal_array
):
mock_prep_weights.return_value = "/some/weights.h5"

main(
signal_array=signal_array,
background_array=None,
voxel_sizes=(5, 2, 2),
)

mock_detect.assert_called_once()
11 changes: 11 additions & 0 deletions tests/data/integration/training/training_single_channel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
data:
- bg_channel: -1
cell_def: ''
cube_dir: tests/data/integration/training/cells
signal_channel: 0
type: cell
- bg_channel: -1
cell_def: ''
cube_dir: tests/data/integration/training/cells
signal_channel: 0
type: no_cell
Loading