Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/odl/applications/tomo/backends/astra_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,8 @@ def __init__(self, geometry, vol_space, proj_space):
), f"Volume space ({vol_space.impl}) != Projection space ({proj_space.impl})"

if self.geometry.ndim == 3:
if vol_space.impl == 'numpy':
self.transpose_tuple = (1,0,2)
elif vol_space.impl == 'pytorch':
self.transpose_tuple = (1,0)
else:
raise NotImplementedError("Not implemented for another backend")
self.transpose_tuple = (1,0,2) if self.geometry.det_curvature_radius is None else (2, 0, 1)
self.inverse_transpose_tuple = (1,0,2) if self.geometry.det_curvature_radius is None else (1, 2, 0)

self.fp_scaling_factor = astra_cuda_fp_scaling_factor(self.geometry)
self.bp_scaling_factor = astra_cuda_bp_scaling_factor(
Expand Down Expand Up @@ -213,7 +209,7 @@ def _call_forward_real(self, vol_data:DiscretizedSpaceElement, out=None, **kwarg
)
proj_data = out.data[None] if self.proj_ndim == 2 else out.data
if self.geometry.ndim == 3:
proj_data = proj_data.transpose(*self.transpose_tuple)
proj_data = self._proj_space.array_namespace.permute_dims(proj_data, self.transpose_tuple)

else:
proj_data = empty(
Expand Down Expand Up @@ -249,7 +245,7 @@ def _call_forward_real(self, vol_data:DiscretizedSpaceElement, out=None, **kwarg
proj_data = (
proj_data[0]
if self.geometry.ndim == 2
else proj_data.transpose(*self.transpose_tuple)
else self._proj_space.array_namespace.permute_dims(proj_data, self.inverse_transpose_tuple)
)

if out is not None:
Expand Down Expand Up @@ -323,7 +319,7 @@ def _call_backward_real(self, proj_data:DiscretizedSpaceElement, out=None, **kwa
if self.proj_ndim == 2:
proj_data = proj_data.data[None]
elif self.proj_ndim == 3:
proj_data = proj_data.data.transpose(*self.transpose_tuple)
proj_data = self._proj_space.array_namespace.permute_dims(proj_data.data, self.transpose_tuple)
else:
raise NotImplementedError

Expand Down
94 changes: 93 additions & 1 deletion src/odl/applications/tomo/backends/astra_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from odl.core.discr import DiscretizedSpace, DiscretizedSpaceElement
from odl.applications.tomo.geometry import (
DivergentBeamGeometry, Flat1dDetector, Flat2dDetector, Geometry,
ConeBeamGeometry, CylindricalDetector, DivergentBeamGeometry, Flat1dDetector, Flat2dDetector, Geometry,
ParallelBeamGeometry)
from odl.applications.tomo.util.utility import euler_matrix
from odl.core.array_API_support import get_array_and_backend
Expand Down Expand Up @@ -124,6 +124,10 @@
# next release after 1.8.3, see
# https://github.com/astra-toolbox/astra-toolbox/pull/183
'par2d_distance_driven_proj': '>1.8.3',

# Cylidrical detector geometry, see
# https://github.com/astra-toolbox/astra-toolbox/pull/444
'cyl_cone_vec': ">= 2.4.0"
}

ODL_TO_ASTRA_INDEX_PERMUTATIONS = [
Expand Down Expand Up @@ -354,6 +358,75 @@ def astra_conebeam_3d_geom_to_vec(geometry:Geometry):

return vectors

def astra_cyl_conebeam_3d_geom_to_vec(geometry:DivergentBeamGeometry):
"""Create vectors for ASTRA projection geometries from ODL geometry.

The 3D vectors are used to create an ASTRA projection geometry for
cone beam geometries with a cylindrical detector, see ``'cyl_cone_vec'``
in the `ASTRA projection geometry documentation`_.

Each row of the returned vectors corresponds to a single projection
and consists of ::

(srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ, R)

with

- ``src``: the ray source position
- ``d`` : the center of the detector
- ``u`` : tangential direction at center of detector;
the length of u is the arc length of a detector pixel
- ``v`` : the vector from detector pixel ``(0,0)`` to ``(1,0)``
- ``R`` : the radius of the detector cylinder

Parameters
----------
geometry : `Geometry`
ODL projection geometry from which to create the ASTRA geometry.

Returns
-------
vectors : `numpy.ndarray`
Array of shape ``(num_angles, 13)`` containing the vectors.

References
----------
.. _ASTRA projection geometry documentation:
http://www.astra-toolbox.com/docs/geom3d.html#projection-geometries
"""
angles = geometry.angles
vectors = np.zeros((angles.size, 13))

# Source position
vectors[:, 0:3] = geometry.src_position(angles)

# Center of detector in 3D space
# FIXME: This is not correct: det_point_position returns the zero-point of
# the detector, and not the center of the detector. For quarter-pixel-shifted
# detector these two do not coincide.
mid_pt = geometry.det_params.mid_pt
vectors[:, 3:6] = geometry.det_point_position(angles, mid_pt)

# `det_axes` gives shape (N, 2, 3), swap to get (2, N, 3)
det_axes = np.moveaxis(geometry.det_axes(angles), -2, 0)
px_sizes = geometry.det_partition.cell_sides

# `px_sizes[0]` is angular partition; scale by radius to get arc length
# NB: For flat panel detector we swap the u and v axes to get a better
# memory layout. For cylindrical detectors this is (currently) not possible
# since both ODL and Astra have the v direction along the axial direction.
vectors[:, 6:9] = det_axes[0] * px_sizes[0] * geometry.det_curvature_radius
vectors[:, 9:12] = det_axes[1] * px_sizes[1]

# detector curvature radius
vectors[:, 12] = geometry.det_curvature_radius

# ASTRA has (z, y, x) axis convention, in contrast to (x, y, z) in ODL,
# so we need to adapt to this by changing the order.
vectors = vectors[:, [*ODL_TO_ASTRA_INDEX_PERMUTATIONS, 12]]

return vectors

def astra_fanflat_2d_geom_to_conebeam_vec(geometry:Geometry):
""" Create vectors for ASTRA projection geometry.
This is required for the CUDA implementation of fanflat geometry.
Expand Down Expand Up @@ -609,6 +682,23 @@ def astra_projection_geometry(geometry: Geometry, astra_impl: str):
vec = astra_conebeam_3d_geom_to_vec(geometry)
proj_geom = astra.create_proj_geom('cone_vec', det_row_count,
det_col_count, vec)

elif (isinstance(geometry, DivergentBeamGeometry) and
isinstance(geometry.detector, CylindricalDetector) and
geometry.ndim == 3):

if not astra_supports('cyl_cone_vec'):
req_ver = astra_versions_supporting('cyl_cone_vec')
raise NotImplementedError(
f"support for cylindrical detector geometry requires ASTRA {req_ver}"
)
# Do NOT swap detector axes (see astra_cyl_conebeam_3d_geom_to_vec)
det_row_count = geometry.det_partition.shape[1]
det_col_count = geometry.det_partition.shape[0]
vec = astra_cyl_conebeam_3d_geom_to_vec(geometry)
proj_geom = astra.create_proj_geom('cyl_cone_vec', det_row_count,
det_col_count, vec)

else:
raise NotImplementedError(f"unknown ASTRA geometry type {geometry}")

Expand Down Expand Up @@ -750,6 +840,8 @@ def astra_projector(
valid_proj_types = ['linear3d', 'cuda3d']
elif astra_geom in {'cone', 'cone_vec'}:
valid_proj_types = ['linearcone', 'cuda3d']
elif astra_geom in {'cyl_cone_vec'}:
valid_proj_types = ['cuda3d']
else:
raise ValueError(f"invalid geometry type {astra_geom}")

Expand Down
1 change: 1 addition & 0 deletions src/odl/applications/tomo/geometry/conebeam.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,7 @@ def __repr__(self):
posargs = [self.motion_partition, self.det_partition]
optargs = [('src_radius', self.src_radius, -1),
('det_radius', self.det_radius, -1),
('det_curvature_radius', self.det_curvature_radius, None),
('pitch', self.pitch, 0)
]

Expand Down