Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ EdgeDistance
^^^^^^^^^^^^
.. autoclass:: EdgeDistance

SymmetricDivision
^^^^^^^^^^^^^^^^^
.. autoclass:: SymmetricDivision

Features
--------

Expand Down
2 changes: 2 additions & 0 deletions motile/costs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .features import Features
from .node_selection import NodeSelection
from .split import Split
from .symmetric_division import SymmetricDivision
from .weight import Weight
from .weights import Weights

Expand All @@ -18,6 +19,7 @@
"Features",
"NodeSelection",
"Split",
"SymmetricDivision",
"Weight",
"Weights",
]
21 changes: 13 additions & 8 deletions motile/costs/edge_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,20 @@ def __init__(

def apply(self, solver: Solver) -> None:
edge_variables = solver.get_variables(EdgeSelected)
for key, index in edge_variables.items():
u, v = cast("tuple[int, int]", key)
pos_u = self.__get_node_position(solver.graph, u)
pos_v = self.__get_node_position(solver.graph, v)

feature = np.linalg.norm(pos_u - pos_v)

solver.add_variable_cost(index, feature, self.weight)
solver.add_variable_cost(index, 1.0, self.constant)
for key, index in edge_variables.items():
if solver.graph.is_hyperedge(key):
solver.add_variable_cost(index, 0.0, self.weight)
solver.add_variable_cost(index, 0.0, self.constant)
else:
u, v = cast("tuple[int, int]", key)
pos_u = self.__get_node_position(solver.graph, u)
pos_v = self.__get_node_position(solver.graph, v)

feature = np.linalg.norm(pos_u - pos_v)

solver.add_variable_cost(index, feature, self.weight)
solver.add_variable_cost(index, 1.0, self.constant)

def __get_node_position(self, graph: nx.DiGraph, node: int) -> np.ndarray:
if isinstance(self.position_attribute, tuple):
Expand Down
67 changes: 67 additions & 0 deletions motile/costs/symmetric_division.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ..variables import EdgeSelected
from .cost import Cost
from .weight import Weight

if TYPE_CHECKING:
import networkx as nx

from motile.solver import Solver


class SymmetricDivision(Cost):
"""Cost for the distance between a parent and the mean locations of its children.

This cost requires division to be represented as "hyperedges" in the graph.
That is, there is an edge (`a`, (`b`, `c`)) for every possible division with node
`a` as the parent and nodes `b` and `c` as the children.

Args:
position_attribute:
The name of the node attribute that corresponds to the spatial
position. Can also be given as a tuple of multiple coordinates,
e.g., ``('z', 'y', 'x')``.

weight:
The weight to apply to the distance to convert it into a cost.

constant:
A constant cost for each selected division. Default is ``0.0``.
"""

def __init__(
self,
position_attribute: str | tuple[str, ...],
weight: float = 1.0,
constant: float = 0.0,
) -> None:
self.position_attribute = position_attribute
self.weight = Weight(weight)
self.constant = Weight(constant)

def apply(self, solver: Solver) -> None:
edge_variables = solver.get_variables(EdgeSelected)
for key, index in edge_variables.items():
if solver.graph.is_hyperedge(key):
(start,) = key[0]
end1, end2 = key[1]
pos_start = self.__get_node_position(solver.graph, start)
pos_end1 = self.__get_node_position(solver.graph, end1)
pos_end2 = self.__get_node_position(solver.graph, end2)
feature = np.linalg.norm(pos_start - 0.5 * (pos_end1 + pos_end2))
solver.add_variable_cost(index, feature, self.weight)
solver.add_variable_cost(index, 1.0, self.constant)
else:
solver.add_variable_cost(index, 0.0, self.weight)
solver.add_variable_cost(index, 0.0, self.constant)

def __get_node_position(self, graph: nx.DiGraph, node: int) -> np.ndarray:
if isinstance(self.position_attribute, tuple):
return np.array([graph.nodes[node][p] for p in self.position_attribute])
else:
return np.array(graph.nodes[node][self.position_attribute])
67 changes: 67 additions & 0 deletions tests/test_costs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import motile
import networkx as nx
from motile.costs import (
Appear,
Disappear,
EdgeSelection,
NodeSelection,
SymmetricDivision,
)


Expand Down Expand Up @@ -81,3 +83,68 @@ def test_disappear_cost(arlo_graph):
solution_graph = solver.get_selected_subgraph()
assert list(solution_graph.nodes.keys()) == [2, 3, 4, 5, 6]
assert len(solution_graph.edges) == 0


def test_symmetric_division_cost() -> None:
"""Test that symmetric division cost favors divisions with centered children.

Graph structure:
t=0: node 0 at y=10
t=1: node 1 at y=5, node 2 at y=15, node 3 at y=10

The symmetric division (0 -> 1,2) should be chosen because the average
of y=5 and y=15 equals y=10 (the parent's position), resulting in zero cost.
The alternative division (0 -> 1,3) would have higher cost due to asymmetry.
"""
# Create nodes
cells = [
{"id": 0, "t": 0, "y": 10.0},
{"id": 1, "t": 1, "y": 5.0},
{"id": 2, "t": 1, "y": 15.0},
{"id": 3, "t": 1, "y": 10.0},
]

nx_graph = nx.DiGraph()
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])

# Add hyperedge for division (0 -> 1, 2) - the symmetric division
# Create a hypernode to represent this division
nx_graph.add_node(10) # hypernode (no frame attribute)
nx_graph.add_edge(0, 10)
nx_graph.add_edge(10, 1)
nx_graph.add_edge(10, 2)

# Add hyperedge for division (0 -> 1, 3) - asymmetric division
nx_graph.add_node(11) # another hypernode
nx_graph.add_edge(0, 11)
nx_graph.add_edge(11, 1)
nx_graph.add_edge(11, 3)

# Add hyperedge for division (0 -> 2, 3) - also asymmetric
nx_graph.add_node(12) # another hypernode
nx_graph.add_edge(0, 12)
nx_graph.add_edge(12, 2)
nx_graph.add_edge(12, 3)

graph = motile.TrackGraph(nx_graph)

solver = motile.Solver(graph)

# Add only the symmetric division cost
solver.add_cost(EdgeSelection(constant=-1.0))
solver.add_cost(SymmetricDivision(position_attribute="y", weight=1.0))

# Solve and check that the symmetric division is chosen
solver.solve()
solution = solver.get_selected_subgraph().to_nx_graph(flatten_hyperedges=True)

# The solution should select all 4 nodes (0, 1, 2, 3)
assert set(solution.nodes.keys()) == {0, 1, 2, 3}

# Check that the symmetric division hyperedge (through hypernode 10) is selected
# The solution should have edges: 0->1, 0->2
assert solution.has_edge(0, 1)
assert solution.has_edge(0, 2)

# The asymmetric divisions should not be selected
assert not solution.has_edge(0, 3)