Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 62 additions & 58 deletions trackedit/TrackEditClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,8 @@ def add_instanseg_cell_at_position(self, viewer, current_time):
volumes.append(np.asarray(ch_data))
image_volume = np.stack(volumes, axis=0) # (C, Z, Y, X) or (C, Y, X)

# Get view direction for ray casting
# Get view direction for ray casting (unavailable in napari's 2D canvas mode)
vd_world = getattr(viewer.cursor, "_view_direction", None)
if vd_world is None or np.allclose(vd_world, 0):
show_warning(
"Could not get view direction from napari cursor. Try clicking again."
)
return None

# Get cursor position in world coordinates
cursor_pos_world = np.asarray(viewer.cursor.position, dtype=float)
Expand All @@ -713,64 +708,73 @@ def add_instanseg_cell_at_position(self, viewer, current_time):
[self.databasehandler.y_scale, self.databasehandler.x_scale]
)

# Convert view direction and cursor position to data coordinates
vd_world = np.asarray(vd_world, dtype=float)
# View direction is also 4D (time + spatial), extract only spatial components
vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world
vd_data = vd_spatial_world / scale_array
norm = np.linalg.norm(vd_data)
if norm == 0:
show_warning("Invalid view direction. Try clicking again.")
return None
vd_data /= norm

origin_data = cursor_spatial_world / scale_array

# image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim
spatial_shape = image_volume.shape[1:]
nuclear_volume = image_volume[0] # use nuclear channel (idx 0) for ray casting

# Cast ray through volume in both directions
diag = int(np.sqrt(sum(s**2 for s in spatial_shape)))
t_values = np.arange(-diag, diag + 1, dtype=float)
ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :]
ray_voxels = np.round(ray_points).astype(int)

# Keep only voxels inside the volume
if self.databasehandler.ndim == 4:
valid = (
(ray_voxels[:, 0] >= 0)
& (ray_voxels[:, 0] < spatial_shape[0])
& (ray_voxels[:, 1] >= 0)
& (ray_voxels[:, 1] < spatial_shape[1])
& (ray_voxels[:, 2] >= 0)
& (ray_voxels[:, 2] < spatial_shape[2])
)
if vd_world is None or np.allclose(vd_world, 0):
# 2D canvas mode: no view direction available, use cursor position directly
position_data = tuple(np.round(origin_data).astype(float))
else:
valid = (
(ray_voxels[:, 0] >= 0)
& (ray_voxels[:, 0] < spatial_shape[0])
& (ray_voxels[:, 1] >= 0)
& (ray_voxels[:, 1] < spatial_shape[1])
)
# 3D canvas mode: cast ray and find brightest voxel along it
# Convert view direction to data coordinates
vd_world = np.asarray(vd_world, dtype=float)
# View direction is also 4D (time + spatial), extract only spatial components
vd_spatial_world = vd_world[1:] if len(vd_world) == 4 else vd_world
vd_data = vd_spatial_world / scale_array
norm = np.linalg.norm(vd_data)
if norm == 0:
show_warning("Invalid view direction. Try clicking again.")
return None
vd_data /= norm

# image_volume is (C, Z, Y, X) or (C, Y, X); spatial shape excludes channel dim
spatial_shape = image_volume.shape[1:]
nuclear_volume = image_volume[
0
] # use nuclear channel (idx 0) for ray casting

# Cast ray through volume in both directions
diag = int(np.sqrt(sum(s**2 for s in spatial_shape)))
t_values = np.arange(-diag, diag + 1, dtype=float)
ray_points = origin_data[None, :] + t_values[:, None] * vd_data[None, :]
ray_voxels = np.round(ray_points).astype(int)

# Keep only voxels inside the volume
if self.databasehandler.ndim == 4:
valid = (
(ray_voxels[:, 0] >= 0)
& (ray_voxels[:, 0] < spatial_shape[0])
& (ray_voxels[:, 1] >= 0)
& (ray_voxels[:, 1] < spatial_shape[1])
& (ray_voxels[:, 2] >= 0)
& (ray_voxels[:, 2] < spatial_shape[2])
)
else:
valid = (
(ray_voxels[:, 0] >= 0)
& (ray_voxels[:, 0] < spatial_shape[0])
& (ray_voxels[:, 1] >= 0)
& (ray_voxels[:, 1] < spatial_shape[1])
)

ray_voxels = ray_voxels[valid]
if len(ray_voxels) == 0:
show_warning("Ray does not intersect volume. Try clicking inside the data.")
return None
ray_voxels = ray_voxels[valid]
if len(ray_voxels) == 0:
show_warning(
"Ray does not intersect volume. Try clicking inside the data."
)
return None

# Deduplicate and find voxel with maximum intensity in nuclear channel
ray_voxels = np.unique(ray_voxels, axis=0)
if self.databasehandler.ndim == 4:
intensities = nuclear_volume[
ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2]
]
else:
intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]]
# Deduplicate and find voxel with maximum intensity in nuclear channel
ray_voxels = np.unique(ray_voxels, axis=0)
if self.databasehandler.ndim == 4:
intensities = nuclear_volume[
ray_voxels[:, 0], ray_voxels[:, 1], ray_voxels[:, 2]
]
else:
intensities = nuclear_volume[ray_voxels[:, 0], ray_voxels[:, 1]]

best_idx = int(np.argmax(intensities))
best_voxel = ray_voxels[best_idx]
position_data = tuple(best_voxel.astype(float))
best_idx = int(np.argmax(intensities))
best_voxel = ray_voxels[best_idx]
position_data = tuple(best_voxel.astype(float))

# Prepare scale tuple and handle 2D vs 3D
if self.databasehandler.ndim == 4:
Expand Down
62 changes: 62 additions & 0 deletions trackedit/motile_overwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,68 @@ def fixed_click(layer, event):
# TrackLabels.get_status = get_status


def create_center_view(DB_handler):
"""Override TracksLayerGroup.center_view to convert world positions to data
coordinates when setting viewer.dims.current_step.

Positions in the trackgraph are stored in world (scaled) units, but
current_step expects voxel indices, so we divide by the layer scale.
"""

def center_view_with_scale(self, node):
if self.seg_layer is None or self.seg_layer.mode == "pan_zoom":
location = self.tracks.get_positions([node], incl_time=True)[0].tolist()
assert (
len(location) == self.viewer.dims.ndim
), f"Location {location} does not match viewer number of dims {self.viewer.dims.ndim}"

# Build per-dimension scale: dim 0 is time (unscaled)
if DB_handler.ndim == 4: # (t, z, y, x)
scale_by_dim = {
0: 1.0,
1: DB_handler.z_scale,
2: DB_handler.y_scale,
3: DB_handler.x_scale,
}
else: # ndim == 3: (t, y, x)
scale_by_dim = {
0: 1.0,
1: DB_handler.y_scale,
2: DB_handler.x_scale,
}

step = list(self.viewer.dims.current_step)
for dim in self.viewer.dims.not_displayed:
scale = scale_by_dim.get(dim, 1.0)
step[dim] = int(location[dim] / scale + 0.5)
self.viewer.dims.current_step = step

# Camera centering uses world coordinates — no scale conversion needed
example_layer = self.points_layer
corner_coordinates = example_layer.corner_pixels
dims_displayed = self.viewer.dims.displayed
x_dim = dims_displayed[-1]
y_dim = dims_displayed[-2]

_min_x = corner_coordinates[0][x_dim]
_max_x = corner_coordinates[1][x_dim]
_min_y = corner_coordinates[0][y_dim]
_max_y = corner_coordinates[1][y_dim]

if not (
(location[x_dim] > _min_x and location[x_dim] < _max_x)
and (location[y_dim] > _min_y and location[y_dim] < _max_y)
):
camera_center = self.viewer.camera.center
self.viewer.camera.center = (
camera_center[0],
location[y_dim],
location[x_dim],
)

return center_view_with_scale


# --- Custom keybindings ---


Expand Down
3 changes: 3 additions & 0 deletions trackedit/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
DeleteNodes,
)
from motile_tracker.data_views import TracksViewer
from motile_tracker.data_views.views.layers.tracks_layer_group import TracksLayerGroup
from trackedit.DatabaseHandler import DatabaseHandler
from trackedit.motile_overwrites import (
create_center_view,
create_db_add_edges,
create_db_add_nodes,
create_db_delete_edges,
Expand Down Expand Up @@ -132,6 +134,7 @@ def run_trackedit(
DeleteEdges._apply = create_db_delete_edges(DB_handler)
AddEdges._apply = create_db_add_edges(DB_handler)
AddNodes._apply = create_db_add_nodes(DB_handler)
TracksLayerGroup.center_view = create_center_view(DB_handler)
TracksViewer._refresh = create_tracks_viewer_and_segments_refresh(
layer_name=layer_name
)
Expand Down
Loading