Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ dependencies = [

[project.optional-dependencies]
spatial = ["spatial-graph"]
motile = ["motile"]
test = [
"spatial-graph",
"motile",
"pytest>=7.0",
"pytest-cov",
"pytest-html",
Expand Down
2 changes: 2 additions & 0 deletions src/tracksdata/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tracksdata.functional._division import shift_division
from tracksdata.functional._edges import join_node_attrs_to_edges
from tracksdata.functional._labeling import ancestral_connected_edges
from tracksdata.functional._motile import to_motile_graph
from tracksdata.functional._napari import rx_digraph_to_napari_dict, to_napari_format

__all__ = [
Expand All @@ -13,5 +14,6 @@
"join_node_attrs_to_edges",
"rx_digraph_to_napari_dict",
"shift_division",
"to_motile_graph",
"to_napari_format",
]
80 changes: 80 additions & 0 deletions src/tracksdata/functional/_motile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import networkx as nx

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph._base_graph import BaseGraph

if TYPE_CHECKING:
import motile
else:
motile = Any


def to_motile_graph(
graph: BaseGraph,
*,
node_attr_keys: Sequence[str] | None = None,
edge_attr_keys: Sequence[str] | None = None,
frame_attribute: str = DEFAULT_ATTR_KEYS.T,
) -> "motile.TrackGraph":
"""
Convert a tracksdata graph into a [`motile.TrackGraph`](https://funkelab.github.io/motile/).

Node and edge attributes are copied over so they can be used as `motile` costs.
Each node keeps its ``frame_attribute`` (time) value, which `motile` requires.

Parameters
----------
graph : BaseGraph
The graph to convert.
node_attr_keys : Sequence[str] | None
Node attribute keys to copy. If None, all node attributes are copied.
``NODE_ID`` and ``frame_attribute`` are always included.
edge_attr_keys : Sequence[str] | None
Edge attribute keys to copy. If None, all edge attributes are copied.
frame_attribute : str
Node attribute used as the time/frame dimension. Defaults to ``"t"``.

Returns
-------
motile.TrackGraph
A `motile` track graph with the same nodes, edges, and copied attributes.
"""
try:
import motile
except ImportError as e:
raise ImportError(
"`motile` is required to convert a graph to a `motile.TrackGraph`.\n"
"Please install it with `pip install motile`."
) from e

if node_attr_keys is not None:
node_attr_keys = list(dict.fromkeys([DEFAULT_ATTR_KEYS.NODE_ID, frame_attribute, *node_attr_keys]))

nodes_df = graph.node_attrs(attr_keys=node_attr_keys)

if frame_attribute not in nodes_df.columns:
raise ValueError(
f"Frame attribute '{frame_attribute}' not found in the graph node attributes {nodes_df.columns}."
)

nx_graph = nx.DiGraph()
for node_data in nodes_df.iter_rows(named=True):
nx_graph.add_node(node_data[DEFAULT_ATTR_KEYS.NODE_ID], **node_data)

# avoid querying edge columns that may not be registered yet when there are no edges
if graph.num_edges() == 0:
edge_attr_keys = []
elif edge_attr_keys is not None:
edge_attr_keys = list(edge_attr_keys)

edges_df = graph.edge_attrs(attr_keys=edge_attr_keys)
for edge_data in edges_df.iter_rows(named=True):
source = edge_data.pop(DEFAULT_ATTR_KEYS.EDGE_SOURCE)
target = edge_data.pop(DEFAULT_ATTR_KEYS.EDGE_TARGET)
edge_data.pop(DEFAULT_ATTR_KEYS.EDGE_ID, None)
nx_graph.add_edge(source, target, **edge_data)

return motile.TrackGraph(nx_graph, frame_attribute=frame_attribute)
64 changes: 64 additions & 0 deletions src/tracksdata/functional/_test/test_motile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import polars as pl
import pytest

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph import RustWorkXGraph

motile = pytest.importorskip("motile")

from tracksdata.functional import to_motile_graph # noqa: E402


def _build_graph() -> tuple[RustWorkXGraph, list[int], list[int]]:
graph = RustWorkXGraph()
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0})
node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.0})
node2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, "x": 2.0})

edge0 = graph.add_edge(node0, node1, {DEFAULT_ATTR_KEYS.EDGE_DIST: -1.0})
edge1 = graph.add_edge(node1, node2, {DEFAULT_ATTR_KEYS.EDGE_DIST: -1.0})

return graph, [node0, node1, node2], [edge0, edge1]


def test_to_motile_graph() -> None:
graph, node_ids, _ = _build_graph()

track_graph = to_motile_graph(graph)

assert isinstance(track_graph, motile.TrackGraph)
assert set(track_graph.nodes) == set(node_ids)
assert (node_ids[0], node_ids[1]) in track_graph.edges
assert (node_ids[1], node_ids[2]) in track_graph.edges

# frame attribute and node attributes are copied over
assert track_graph.nodes[node_ids[0]][DEFAULT_ATTR_KEYS.T] == 0
assert track_graph.nodes[node_ids[2]]["x"] == 2.0
assert track_graph.get_frames() == (0, 3)

# edge attributes are copied over, ids/source/target are not added as attributes
edge_data = track_graph.edges[(node_ids[0], node_ids[1])]
assert edge_data[DEFAULT_ATTR_KEYS.EDGE_DIST] == -1.0
assert DEFAULT_ATTR_KEYS.EDGE_ID not in edge_data
assert DEFAULT_ATTR_KEYS.EDGE_SOURCE not in edge_data


def test_to_motile_graph_subset_of_attrs() -> None:
graph, node_ids, _ = _build_graph()

track_graph = to_motile_graph(graph, node_attr_keys=[], edge_attr_keys=[])

# only NODE_ID and frame attribute are kept
assert set(track_graph.nodes[node_ids[0]]) == {DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T}
# edge has no copied attributes
assert track_graph.edges[(node_ids[0], node_ids[1])] == {}


def test_to_motile_graph_method() -> None:
graph, node_ids, _ = _build_graph()
track_graph = graph.to_motile_graph()
assert isinstance(track_graph, motile.TrackGraph)
assert set(track_graph.nodes) == set(node_ids)
44 changes: 42 additions & 2 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tracksdata.utils._multiprocessing import multiprocessing_apply

if TYPE_CHECKING:
import motile
from traccuracy import TrackingGraph

from tracksdata.graph.filters._base_filter import BaseFilter
Expand Down Expand Up @@ -2040,14 +2041,20 @@ def _remove_metadata(self, key: str) -> None:
Backend-specific metadata removal implementation without public key validation.
"""

def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph":
def to_traccuracy_graph(
self, array_view_kwargs: dict[str, Any] | None = None, location_keys: list[str] | None = None
) -> "TrackingGraph":
"""
Convert the graph to a `traccuracy.TrackingGraph`.

Parameters
----------
array_view_kwargs : dict[str, Any] | None
Additional keyword arguments to pass to the `GraphArrayView` constructor used to create the segmentation.
location_keys : list[str] | None
The keys of the location attributes to use for the segmentation.
If None, the location keys are inferred from the intersection of the graph node attributes and
the list [DEFAULT_ATTR_KEYS.Z, DEFAULT_ATTR_KEYS.Y, DEFAULT_ATTR_KEYS.X].

Returns
-------
Expand All @@ -2056,7 +2063,40 @@ def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -
"""
from tracksdata.metrics._traccuracy import to_traccuracy_graph

return to_traccuracy_graph(self, array_view_kwargs=array_view_kwargs)
return to_traccuracy_graph(self, array_view_kwargs=array_view_kwargs, location_keys=location_keys)

def to_motile_graph(
self,
*,
node_attr_keys: Sequence[str] | None = None,
edge_attr_keys: Sequence[str] | None = None,
frame_attribute: str = DEFAULT_ATTR_KEYS.T,
) -> "motile.TrackGraph":
"""
Convert the graph to a [`motile.TrackGraph`](https://funkelab.github.io/motile/).

Parameters
----------
node_attr_keys : Sequence[str] | None
Node attribute keys to copy. If None, all node attributes are copied.
edge_attr_keys : Sequence[str] | None
Edge attribute keys to copy. If None, all edge attributes are copied.
frame_attribute : str
Node attribute used as the time/frame dimension. Defaults to ``"t"``.

Returns
-------
motile.TrackGraph
A `motile` track graph.
"""
from tracksdata.functional._motile import to_motile_graph

return to_motile_graph(
self,
node_attr_keys=node_attr_keys,
edge_attr_keys=edge_attr_keys,
frame_attribute=frame_attribute,
)

@abc.abstractmethod
def has_node(self, node_id: int) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3020,6 +3020,8 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None:

traccuracy_graph = graph_backend.to_traccuracy_graph()

assert traccuracy_graph.location_keys == ["y", "x"]

# trivial matching with itself
ctc_results, _ = run_metrics(
gt_data=traccuracy_graph,
Expand Down
28 changes: 27 additions & 1 deletion src/tracksdata/metrics/_traccuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,28 @@
def to_traccuracy_graph(
graph: BaseGraph,
array_view_kwargs: dict[str, Any] | None = None,
location_keys: list[str] | None = None,
) -> "TrackingGraph":
"""
Convert a tracksdata graph to a traccuracy graph.

Parameters
----------
graph : BaseGraph
The graph to convert.
array_view_kwargs : dict[str, Any] | None
Additional keyword arguments to pass to the `GraphArrayView` constructor used to create the segmentation.
location_keys : list[str] | None
The keys of the location attributes to use for the segmentation.
If None, the location keys are inferred from the intersection of the graph node attributes and
the list [DEFAULT_ATTR_KEYS.Z, DEFAULT_ATTR_KEYS.Y, DEFAULT_ATTR_KEYS.X].

Returns
-------
TrackingGraph
A traccuracy graph.
"""

try:
from traccuracy import TrackingGraph
except ImportError as e:
Expand All @@ -39,4 +60,9 @@ def to_traccuracy_graph(

segmentation = GraphArrayView(graph, attr_key=DEFAULT_ATTR_KEYS.NODE_ID, **array_view_kwargs)

return TrackingGraph(nx_graph, segmentation)
if location_keys is None:
location_keys = [DEFAULT_ATTR_KEYS.Z, DEFAULT_ATTR_KEYS.Y, DEFAULT_ATTR_KEYS.X]
node_attr_keys = graph.node_attr_keys()
location_keys = [key for key in location_keys if key in node_attr_keys]

return TrackingGraph(nx_graph, segmentation, location_keys=location_keys)
Loading