Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pyqt = ">=5.15.9,<6"
numpy = "<2.2"
pre-commit = ">=4.1.0,<5"
dask = ">=2025.2.0,<2026"
scikit-learn = ">=1.8.0,<2"

[tool.pixi.feature.test.dependencies]
pytest = "*"
Expand Down
2 changes: 0 additions & 2 deletions scripts/script_neuromast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
"/hpc/projects/group.royer/people/teun.huijben/data/Thibault/"
"neuromast_model/instanseg_27527750.pt"
)
instanseg_device = None # 'cuda', 'cpu', or None for auto-detect
# *************************

if __name__ == "__main__":
Expand All @@ -64,5 +63,4 @@
annotation_mapping=annotation_mapping,
flag_allow_adding_instanseg_cell=flag_allow_adding_instanseg_cell,
instanseg_model_path=instanseg_model_path,
instanseg_device=instanseg_device,
)
698 changes: 544 additions & 154 deletions trackedit/TrackEditClass.py

Large diffs are not rendered by default.

81 changes: 79 additions & 2 deletions trackedit/_tests/test_UI_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_trackedit_widgets(
imaging_zarr_file="",
imaging_channel="",
viewer=viewer,
flag_allow_adding_spherical_cell=True, # Enable spherical cell feature for testing
flag_allow_adding_spherical_cell=True,
)

# Get the NavigationWidget directly from TrackEdit instance
Expand Down Expand Up @@ -85,6 +85,7 @@ def test_trackedit_widgets(
check_time_box(time_box)
check_editing(TV, editing_menu)
check_add_spherical_cell(track_edit, editing_menu)
check_split_cell(track_edit, TV)
check_red_flag_box(TV, red_flag_box, time_box)
check_division_box(division_box)
check_annotation(toAnnotateBox)
Expand Down Expand Up @@ -209,7 +210,7 @@ def check_add_spherical_cell(track_edit, editing_menu):

# Call the method directly to add a cell
new_node_id = track_edit.add_spherical_cell_at_position(
position_scaled=position, radius_pixels=10
position_scaled=position, radius_physical=10
)

# Verify a node was created
Expand Down Expand Up @@ -244,6 +245,82 @@ def check_add_spherical_cell(track_edit, editing_menu):
track_edit.tracksviewer.undo()


def check_split_cell(track_edit, TV):
"""Check cell splitting functionality and single-step undo."""
tc = TV.tracks_controller

def graph_nodes():
return set(tc.tracks.graph.nodes())

# --- K-means split ---
node_id = 2000001
TV.selected_nodes.add(node_id, append=False)
nodes_before = graph_nodes()

track_edit.split_cell("K-means")

nodes_after = graph_nodes()
assert node_id not in nodes_after, "Original node should be removed after split"
new_nodes = nodes_after - nodes_before
assert len(new_nodes) == 2, f"Expected 2 new nodes, got {len(new_nodes)}"
assert (
len(nodes_after) == len(nodes_before) + 1
), "Net node count should increase by 1"

# Single undo should restore the original node and remove both new ones
TV.undo()
nodes_after_undo = graph_nodes()
assert node_id in nodes_after_undo, "Original node should be restored after undo"
assert not (new_nodes & nodes_after_undo), "New nodes should be gone after undo"
assert (
nodes_after_undo == nodes_before
), "Graph should match pre-split state after undo"

# --- Watershed (distance) split ---
node_id2 = 2000002
TV.selected_nodes.add(node_id2, append=False)
nodes_before2 = graph_nodes()

track_edit.split_cell("Watershed (distance)")

nodes_after2 = graph_nodes()
assert (
node_id2 not in nodes_after2
), "Original node should be removed after watershed split"
new_nodes2 = nodes_after2 - nodes_before2
assert len(new_nodes2) == 2, f"Expected 2 new nodes, got {len(new_nodes2)}"

TV.undo()
nodes_after_undo2 = graph_nodes()
assert node_id2 in nodes_after_undo2, "Original node should be restored after undo"
assert (
nodes_after_undo2 == nodes_before2
), "Graph should match pre-split state after undo"

# --- Multi-cell split (two cells at once) ---
node_a, node_b = 2000001, 2000002
TV.selected_nodes.add(node_a, append=False)
TV.selected_nodes.add(node_b, append=True)
nodes_before_multi = graph_nodes()

track_edit.split_cell("K-means")

nodes_after_multi = graph_nodes()
assert node_a not in nodes_after_multi, "Node A should be removed"
assert node_b not in nodes_after_multi, "Node B should be removed"
new_nodes_multi = nodes_after_multi - nodes_before_multi
assert (
len(new_nodes_multi) == 4
), f"Expected 4 new nodes from splitting 2 cells, got {len(new_nodes_multi)}"

# All splits are grouped into a single undo step
TV.undo()
nodes_final = graph_nodes()
assert node_a in nodes_final, "Node A should be restored"
assert node_b in nodes_final, "Node B should be restored"
assert nodes_final == nodes_before_multi, "Graph should fully match pre-split state"


def check_red_flag_box(TV, red_flag_box, time_box):
"""Check red flag box functionality"""

Expand Down
23 changes: 23 additions & 0 deletions trackedit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click

from trackedit.utils.crop import crop_database_in_time
from trackedit.utils.geff import convert_geff_to_db


Expand Down Expand Up @@ -36,5 +37,27 @@ def geff_to_db(geff_path: Path, output: Path = None):
convert_geff_to_db(geff_path, output)


@cli.command("crop-db")
@click.argument("source_db", type=click.Path(exists=True, path_type=Path))
@click.option(
"--max-t",
required=True,
type=int,
help="Maximum time frame to include (inclusive).",
)
@click.option(
"--output",
"-o",
type=click.Path(path_type=Path),
help="Output database path (default: <input_stem>_t0-<max_t>.db)",
)
def crop_db(source_db: Path, max_t: int, output: Path = None):
"""Crop an Ultrack SQLite database to the first MAX_T frames."""
if output is None:
output = source_db.parent / f"{source_db.stem}_t0-{max_t}.db"
crop_database_in_time(source_db, output, max_t)
print(f"Cropped database written to {output}")


if __name__ == "__main__":
cli()
14 changes: 14 additions & 0 deletions trackedit/motile_overwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from qtpy.QtGui import QColor
from ultrack.core.database import NodeDB, get_node_values

import motile_tracker.data_views.views.tree_view.tree_widget_utils as _tree_widget_utils
from motile_tracker.data_model.actions import ActionGroup, AddEdges, DeleteEdges
from motile_tracker.data_model.solution_tracks import SolutionTracks
from motile_tracker.data_model.tracks_controller import TracksController
Expand All @@ -26,6 +27,19 @@
Edge: TypeAlias = tuple[Node, Node]


# Bug fix: original uses parent_map.get(current_track) which returns None for
# missing keys, and None != 0 is always True — causing an infinite loop when
# all root cells of a division are deleted.
def _patched_find_root(track_id: int, parent_map: dict) -> int:
current_track = track_id
while parent_map.get(current_track, 0) != 0:
current_track = parent_map.get(current_track)
return current_track


_tree_widget_utils.find_root = _patched_find_root


def create_db_add_nodes(DB_handler):
def db_add_nodes(self):
# don't use full old function, because it includes painting pixels in segmentation
Expand Down
6 changes: 3 additions & 3 deletions trackedit/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def run_trackedit(
image_translate: Optional[Tuple[float, ...]] = None,
viewer: Optional[napari.Viewer] = None,
flag_show_hierarchy: bool = True,
flag_allow_adding_spherical_cell: bool = False,
adding_spherical_cell_radius: int = 10,
flag_allow_adding_spherical_cell: bool = True,
adding_spherical_cell_radius: int = 5,
flag_allow_adding_instanseg_cell: bool = False,
instanseg_model_path: Optional[str] = None,
instanseg_device: Optional[str] = None,
Expand Down Expand Up @@ -71,7 +71,7 @@ def run_trackedit(
viewer: Optional existing napari viewer
flag_show_hierarchy: Show hierarchy in the viewer
flag_allow_adding_spherical_cell: Allow adding spherical cells via button (default: False)
adding_spherical_cell_radius: Radius of spherical cells in pixels (default: 10)
adding_spherical_cell_radius: Radius of spherical cells in physical units (default: 10)
flag_allow_adding_instanseg_cell: Allow adding InstanSeg-segmented cells via button (default: False)
instanseg_model_path: Path to InstanSeg TorchScript model file
(required if flag_allow_adding_instanseg_cell=True)
Expand Down
71 changes: 71 additions & 0 deletions trackedit/utils/crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import shutil
import sqlite3
from pathlib import Path


def crop_database_in_time(source_db: Path, output_db: Path, max_t: int) -> None:
"""Create a time-cropped copy of an ultrack SQLite database.

Copies nodes, links, and overlaps tables from the source database,
keeping only entries up to and including time frame `max_t`.
The gt_nodes and gt_links tables are copied as-is (assumed empty).

Args:
source_db: Path to the source .db file.
output_db: Path where the cropped copy will be written.
max_t: Maximum time frame to include (inclusive). Nodes with t <= max_t
are kept; links and overlaps are filtered to only reference
surviving node IDs.
"""
shutil.copy2(source_db, output_db)

conn = sqlite3.connect(output_db)
try:
# Remove nodes outside time range
conn.execute("DELETE FROM nodes WHERE t > ?", (max_t,))

# Collect surviving node IDs
surviving_ids = {
row[0] for row in conn.execute("SELECT id FROM nodes").fetchall()
}

# Remove links where either endpoint is gone
all_links = conn.execute(
"SELECT id, source_id, target_id FROM links"
).fetchall()
link_ids_to_delete = [
row[0]
for row in all_links
if row[1] not in surviving_ids or row[2] not in surviving_ids
]
if link_ids_to_delete:
conn.execute(
f"DELETE FROM links WHERE id IN ({','.join('?' * len(link_ids_to_delete))})",
link_ids_to_delete,
)

# Remove overlaps where either node_id or ancestor_id is gone
all_overlaps = conn.execute(
"SELECT rowid, node_id, ancestor_id FROM overlaps"
).fetchall()
overlap_rowids_to_delete = [
row[0]
for row in all_overlaps
if row[1] not in surviving_ids or row[2] not in surviving_ids
]
if overlap_rowids_to_delete:
conn.execute(
f"DELETE FROM overlaps WHERE rowid IN ({','.join('?' * len(overlap_rowids_to_delete))})",
overlap_rowids_to_delete,
)

conn.commit()
conn.execute("VACUUM")
finally:
conn.close()


if __name__ == "__main__":
from trackedit.cli import cli

cli()
16 changes: 8 additions & 8 deletions trackedit/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def calculate_bbox_from_mask(mask):

def create_cell_mask_and_bbox(
position_scaled: np.ndarray,
radius_pixels: float,
radius_physical: float,
ndim: int,
scale: tuple,
data_shape_full: tuple,
Expand All @@ -407,8 +407,8 @@ def create_cell_mask_and_bbox(
----------
position_scaled : array-like
Position in viewer coordinates (scaled)
radius_pixels : float
Radius of the sphere in pixels
radius_physical : float
Radius of the sphere in physical units (same units as scale)
ndim : int
Number of dimensions (3 for 2D+t, 4 for 3D+t)
scale : tuple
Expand All @@ -435,9 +435,9 @@ def create_cell_mask_and_bbox(
]
)
radii = (
radius_pixels / z_scale,
radius_pixels / y_scale,
radius_pixels / x_scale,
radius_physical / z_scale,
radius_physical / y_scale,
radius_physical / x_scale,
)
else:
y_scale, x_scale = scale
Expand All @@ -449,8 +449,8 @@ def create_cell_mask_and_bbox(
]
)
radii = (
radius_pixels / y_scale,
radius_pixels / x_scale,
radius_physical / y_scale,
radius_physical / x_scale,
)

# Calculate bounding box
Expand Down
Loading
Loading