Skip to content
101 changes: 74 additions & 27 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,50 +389,97 @@ def _bbox_to_slices(self, bbox: Any) -> tuple[slice, ...] | None:

return tuple(slice(int(s), int(e)) for s, e in zip(start, stop, strict=True))

def _invalidate_from_attrs(self, attrs: dict) -> None:
def _invalidate_bbox(self, time_values: Sequence[Any], bboxes: Sequence[np.ndarray | None]) -> None:
"""
Invalidate cache region touched by node attributes.
Invalidate the cache regions covered by the given times and bboxes.

Falls back to larger invalidation windows when metadata is incomplete.
"""

time_value = attrs.get(DEFAULT_ATTR_KEYS.T)
if time_value is None:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.T}' key for cache invalidation.")
if DEFAULT_ATTR_KEYS.BBOX not in attrs:
raise ValueError(f"Node attributes must contain '{DEFAULT_ATTR_KEYS.BBOX}' key for cache invalidation.")
``time_values`` and ``bboxes`` are parallel sequences; each ``(time, bbox)``
pair is clipped to the array volume and the matching cache region is dropped.
A bbox that lies outside the array volume invalidates nothing.

try:
time = int(np.asarray(time_value).item())
except (TypeError, ValueError) as e:
raise ValueError(
f"Time attribute value must be a scalar integer, got {time_value} of type {type(time_value)}"
) from e
if not (0 <= time < self.original_shape[0]):
return
A ``GraphArrayView`` requires every node to carry a ``bbox`` attribute, so a
``None`` bbox is a programming error and raises ``ValueError``.
"""
if hasattr(time_values, "to_list"):
time_values = time_values.to_list()

slices = self._bbox_to_slices(attrs[DEFAULT_ATTR_KEYS.BBOX])
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)
for time_value, bbox in zip(time_values, bboxes, strict=True):
try:
time = int(time_value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Time attribute value must be a scalar integer, got {time_value!r} of type {type(time_value)}"
) from e
if not (0 <= time < self.original_shape[0]):
continue

if bbox is None:
raise ValueError(
f"Node at time {time} is missing a '{DEFAULT_ATTR_KEYS.BBOX}' attribute. "
"A GraphArrayView requires every node to have a bbox."
)

slices = self._bbox_to_slices(bbox)
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)

def _on_node_added(
self,
node_ids: list[int],
new_attrs: list[dict],
) -> None:
for attrs in new_attrs:
self._invalidate_from_attrs(attrs)
del node_ids
self._invalidate_bbox(
[attrs[DEFAULT_ATTR_KEYS.T] for attrs in new_attrs],
[attrs.get(DEFAULT_ATTR_KEYS.BBOX) for attrs in new_attrs],
)

def _on_node_removed(self, node_ids: list[int], old_attrs: list[dict]) -> None:
for attrs in old_attrs:
self._invalidate_from_attrs(attrs)
del node_ids
self._invalidate_bbox(
[attrs[DEFAULT_ATTR_KEYS.T] for attrs in old_attrs],
[attrs.get(DEFAULT_ATTR_KEYS.BBOX) for attrs in old_attrs],
)

def _on_node_updated(
self,
node_ids: list[int],
old_attrs: list[dict],
new_attrs: list[dict],
) -> None:
del node_ids
time_values: list[Any] = []
bboxes: list[Any] = []
for old_attr, new_attr in zip(old_attrs, new_attrs, strict=True):
self._invalidate_from_attrs(old_attr)
self._invalidate_from_attrs(new_attr)
old_t = old_attr[DEFAULT_ATTR_KEYS.T]
new_t = new_attr[DEFAULT_ATTR_KEYS.T]
old_bbox = old_attr.get(DEFAULT_ATTR_KEYS.BBOX)
new_bbox = new_attr.get(DEFAULT_ATTR_KEYS.BBOX)

moved = old_t != new_t or not np.array_equal(old_bbox, new_bbox)

if moved:
# Node relocated: clear the stale region and paint the new one.
time_values.extend((old_t, new_t))
bboxes.extend((old_bbox, new_bbox))
elif old_attr.get(self._attr_key) != new_attr.get(self._attr_key) or self._mask_changed(old_attr, new_attr):
time_values.append(new_t)
bboxes.append(new_bbox)

self._invalidate_bbox(time_values, bboxes)

@staticmethod
def _mask_changed(old_attr: dict, new_attr: dict) -> bool:
"""
Whether the painted output changed while the bbox stayed in place.

The rendered region depends on the displayed attribute value and the mask
pixels, so a mask swap with an unchanged bbox still requires invalidation.
"""
old_mask = old_attr.get(DEFAULT_ATTR_KEYS.MASK)
new_mask = new_attr.get(DEFAULT_ATTR_KEYS.MASK)
if old_mask is None and new_mask is None:
return False
elif old_mask is None or new_mask is None:
return True
return old_mask != new_mask
218 changes: 218 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from unittest.mock import MagicMock, patch

import numpy as np
import polars as pl
Expand Down Expand Up @@ -463,6 +464,125 @@ def test_graph_array_view_invalidates_old_and_new_chunks_on_update(graph_backend
assert output[5, 5] == 7


def test_graph_array_view_invalidates_once_when_attr_key_changes_but_bbox_unchanged(
graph_backend: BaseGraph,
) -> None:
"""Updating the displayed attr_key (label) without moving the node should invalidate the region exactly once."""
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1)
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

mock_invalidate = MagicMock(wraps=array_view._invalidate_bbox)
with patch.object(array_view, "_invalidate_bbox", mock_invalidate):
graph_backend.update_node_attrs(
attrs={"label": [7]},
node_ids=[node_id],
)
# bbox unchanged, but the displayed attribute changed — invalidate exactly one region
n_regions = sum(len(call.args[0]) for call in mock_invalidate.call_args_list)
assert n_regions == 1

# The affected chunk must be invalidated
expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[0, 0] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

# After recomputation, the new value should be painted
output = np.asarray(array_view[0])
assert output[1, 1] == 7


def test_graph_array_view_invalidates_twice_when_attr_key_and_bbox_change(graph_backend: BaseGraph) -> None:
"""Updating both the displayed attr_key and the bbox should invalidate old and new regions."""
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1)
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

moved_mask = _make_square_mask(5, 5)
mock_invalidate = MagicMock(wraps=array_view._invalidate_bbox)
with patch.object(array_view, "_invalidate_bbox", mock_invalidate):
graph_backend.update_node_attrs(
attrs={
"label": [7],
DEFAULT_ATTR_KEYS.MASK: [moved_mask],
DEFAULT_ATTR_KEYS.BBOX: [moved_mask.bbox],
},
node_ids=[node_id],
)
# bbox changed — must invalidate both old and new regions
n_regions = sum(len(call.args[0]) for call in mock_invalidate.call_args_list)
assert n_regions == 2

expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[0, 0] = False
expected_ready[1, 1] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

output = np.asarray(array_view[0])
assert output[1, 1] == 0
assert output[5, 5] == 7


def test_graph_array_view_no_invalidation_when_unrelated_attr_changes(graph_backend: BaseGraph) -> None:
"""Updating an attribute the view doesn't display should not invalidate any chunks."""
_add_graph_array_node_attrs(graph_backend)
graph_backend.add_node_attr_key("score", dtype=pl.Float64)

mask = _make_square_mask(1, 1)
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
"score": 0.5,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

mock_invalidate = MagicMock(wraps=array_view._invalidate_bbox)
with patch.object(array_view, "_invalidate_bbox", mock_invalidate):
graph_backend.update_node_attrs(
attrs={"score": [0.9]},
node_ids=[node_id],
)
# Neither bbox nor the displayed attribute changed — no region invalidated
n_regions = sum(len(call.args[0]) for call in mock_invalidate.call_args_list)
assert n_regions == 0

np.testing.assert_array_equal(
array_view._cache._store[0].ready,
np.ones((2, 2), dtype=bool),
)


def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph) -> None:
_add_graph_array_node_attrs(graph_backend)

Expand Down Expand Up @@ -497,3 +617,101 @@ def test_graph_array_view_invalidates_chunk_on_remove(graph_backend: BaseGraph)
output = np.asarray(array_view[0])
assert output[1, 1] == 1
assert output[5, 5] == 0


def test_graph_array_view_raises_when_bbox_missing(graph_backend: BaseGraph) -> None:
"""A GraphArrayView requires every node to have a bbox, so a missing bbox must raise."""
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1)
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

# An event whose attrs lack a bbox key is a programming error.
with pytest.raises(ValueError, match=DEFAULT_ATTR_KEYS.BBOX):
array_view._on_node_added([999], [{DEFAULT_ATTR_KEYS.T: 0, "label": 5}])


def test_graph_array_view_invalidates_once_when_mask_changes_but_bbox_unchanged(graph_backend: BaseGraph) -> None:
"""Swapping the mask pixels without moving the bbox must invalidate the region exactly once."""
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1) # bbox [1, 1, 3, 3], fully filled
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

# Same bbox [1, 1, 3, 3], but only the diagonal pixels are set.
new_mask = Mask(np.array([[True, False], [False, True]], dtype=bool), bbox=np.array([1, 1, 3, 3]))
mock_invalidate = MagicMock(wraps=array_view._invalidate_bbox)
with patch.object(array_view, "_invalidate_bbox", mock_invalidate):
graph_backend.update_node_attrs(
attrs={DEFAULT_ATTR_KEYS.MASK: [new_mask]},
node_ids=[node_id],
)
# bbox and label unchanged, but the mask pixels changed — invalidate exactly one region
n_regions = sum(len(call.args[0]) for call in mock_invalidate.call_args_list)
assert n_regions == 1

expected_ready = np.ones((2, 2), dtype=bool)
expected_ready[0, 0] = False
np.testing.assert_array_equal(array_view._cache._store[0].ready, expected_ready)

# After recomputation only the diagonal pixels carry the label.
output = np.asarray(array_view[0])
assert output[1, 1] == 1
assert output[2, 2] == 1
assert output[1, 2] == 0
assert output[2, 1] == 0


def test_graph_array_view_no_invalidation_when_mask_unchanged(graph_backend: BaseGraph) -> None:
"""Re-supplying an identical mask (same bbox and pixels) must not invalidate anything."""
_add_graph_array_node_attrs(graph_backend)

mask = _make_square_mask(1, 1)
node_id = graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
"label": 1,
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)

array_view = GraphArrayView(graph=graph_backend, shape=(2, 8, 8), attr_key="label", chunk_shape=(4, 4))
_ = np.asarray(array_view[0])
np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))

# A fresh Mask object with identical bbox and pixels: no rendered change.
same_mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([1, 1, 3, 3]))
mock_invalidate = MagicMock(wraps=array_view._invalidate_bbox)
with patch.object(array_view, "_invalidate_bbox", mock_invalidate):
graph_backend.update_node_attrs(
attrs={DEFAULT_ATTR_KEYS.MASK: [same_mask]},
node_ids=[node_id],
)
# Nothing affecting the rendered output changed — no region invalidated
n_regions = sum(len(call.args[0]) for call in mock_invalidate.call_args_list)
assert n_regions == 0

np.testing.assert_array_equal(array_view._cache._store[0].ready, np.ones((2, 2), dtype=bool))