Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,18 @@ def raw_data(data_raw_file):
return np.float32(np.copy(data_raw_file["data"]))


@pytest.fixture
def raw_data_cupy(raw_data):
    """The ``raw_data`` fixture transferred to the GPU as a CuPy array."""
    device_array = cp.asarray(raw_data)
    return device_array


@pytest.fixture
def flats(data_raw_file):
    """Flat-field frames from the raw data file, copied and cast to float32."""
    flat_fields = np.copy(data_raw_file["flats"])
    return np.float32(flat_fields)


@pytest.fixture
def darks(data_raw_file):
    """Dark-field frames from the raw data file, copied and cast to float32.

    Fix: the rendered diff left both the old multi-line signature and the new
    one-line signature in place, producing a duplicated (invalid) ``def``;
    this keeps only the post-change single-line signature.
    """
    return np.float32(np.copy(data_raw_file["darks"]))


Expand Down
33 changes: 33 additions & 0 deletions tests/test_RecToolsIRCuPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,39 @@ def test_FISTA_cp_lc_known_3D(data_cupy, angles, ensure_clean_memory):
assert Iter_rec.shape == (128, 160, 160)


def test_FISTA_SWLS_cp_3D(data_cupy, raw_data_cupy, angles, ensure_clean_memory):
    """Run FISTA with the SWLS data-fidelity model on 3D CuPy data and pin the result."""
    _, det_y, det_x = cp.shape(data_cupy)

    rec_tools = RecToolsIRCuPy(
        DetectorsDimH=det_x,  # Horizontal detector dimension
        DetectorsDimH_pad=0,  # Padding size of horizontal detector
        DetectorsDimV=det_y,  # Vertical detector dimension (3D case)
        CenterRotOffset=0.0,  # Center of Rotation scalar or a vector
        AnglesVec=angles,  # A vector of projection angles in radians
        ObjSize=det_x,  # Reconstructed object dimensions (scalar)
        device_projector=0,  # define the device
    )

    # Data dictionary: SWLS needs both the log-corrected and the raw projections.
    data_dict = {
        "data_fidelity": "SWLS",
        "projection_data": data_cupy,
        "projection_raw_data": raw_data_cupy,
        "beta_SWLS": 1.0,
        "data_axes_labels_order": ["angles", "detY", "detX"],
    }
    # Estimate the Lipschitz constant, then run a fixed number of FISTA iterations.
    lipschitz = rec_tools.powermethod(data_dict)
    algorithm_params = {"iterations": 10, "lipschitz_const": lipschitz}
    reconstruction = rec_tools.FISTA(data_dict, algorithm_params)

    host_rec = reconstruction.get()
    assert_allclose(np.min(host_rec), -0.0061533, rtol=1e-04)
    assert_allclose(np.max(host_rec), 0.008243, rtol=1e-04)
    assert host_rec.dtype == np.float32
    assert host_rec.shape == (128, 160, 160)


def test_FISTA_cp_3D(data_cupy, angles, ensure_clean_memory):
detX = cp.shape(data_cupy)[2]
detY = cp.shape(data_cupy)[1]
Expand Down
24 changes: 24 additions & 0 deletions tomobar/cuda_kernels/stripe_weighted_least_squares.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <cuda_fp16.h>

template <typename T>
__device__ __forceinline__ void stripe_weighted_least_squares(T *res, T *weights, T *weights_mul_res, T *weights_dot_res, T *weight_sum, int dimX, int dimY, int dimZ)
{
    // One thread per element of the (dimZ, dimY, dimX) volume.
    const long tx = blockDim.x * blockIdx.x + threadIdx.x;
    const long ty = blockDim.y * blockIdx.y + threadIdx.y;
    const long tz = blockDim.z * blockIdx.z + threadIdx.z;

    if (tx >= dimX || ty >= dimY || tz >= dimZ)
    {
        return;
    }

    // Full 3D index into res / weights / weights_mul_res.
    const long long index = static_cast<long long>(tz) * dimY * dimX + static_cast<long long>(ty) * dimX + static_cast<long long>(tx);
    // 2D (tz, tx) index into the arrays already reduced over the Y axis.
    // Fix: cast before multiplying, consistent with `index` above — the original
    // computed tz * dimX + tx in `long`, which can overflow on LLP64 platforms
    // (Windows) where long is 32-bit.
    const long long collapsed_projection_index = static_cast<long long>(tz) * dimX + static_cast<long long>(tx);

    // res -= (1 / weight_sum) * weights_dot_res * weights, elementwise.
    // Fix: use a literal of type T instead of the bare double literal 1.0, which
    // promoted the whole expression to fp64 (slow on most GPUs) before truncating
    // the result back to T on store.
    res[index] = weights_mul_res[index] - T(1.0) / weight_sum[collapsed_projection_index] * weights_dot_res[collapsed_projection_index] * weights[index];
}

extern "C" __global__ void stripe_weighted_least_squares_float(float *res, float *weights, float *weights_mul_res, float *weights_dot_res, float *weight_sum, int dimX, int dimY, int dimZ)
{
    stripe_weighted_least_squares(res, weights, weights_mul_res, weights_dot_res, weight_sum, dimX, dimY, dimZ);
}
31 changes: 26 additions & 5 deletions tomobar/data_fidelities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import cupy as cp
from typing import Optional
from tomobar.cuda_kernels import load_cuda_module


def grad_data_term(
Expand All @@ -10,8 +11,8 @@ def grad_data_term(
b: cp.ndarray,
use_os: bool,
sub_ind: int,
indVec: Optional[cp.ndarray] = None,
w: Optional[cp.ndarray] = None,
w_sum: Optional[cp.ndarray] = None,
) -> cp.ndarray:
"""Calculation of the gradient of the data fidelity term
Args:
Expand All @@ -25,16 +26,36 @@ def grad_data_term(
Returns:
cp.ndarray: gradient of the data fidelity as a 3D CuPy array.
"""
half_precision = False
kernel_name = (
f"stripe_weighted_least_squares_{'half' if half_precision else 'float'}"
)
module = load_cuda_module("stripe_weighted_least_squares")
stripe_weighted_least_squares = module.get_function(kernel_name)

if self.data_fidelity in ["LS", "PWLS"]:
# Least-Squares (LS)
res = self._Ax(x, sub_ind, use_os) - b
if w is not None:
# Penalised-Weighted least squares
if use_os:
cp.multiply(res, w[:, indVec, :], out=res)
else:
cp.multiply(res, w, out=res)
cp.multiply(res, w, out=res)
elif self.data_fidelity == "KL":
# Kullback-Leibler term. Note that b in that case should be given as pre-log data (raw)
res = 1 - b / cp.clip(self._Ax(x, sub_ind, use_os), 1e-8, None)
elif self.data_fidelity == "SWLS":
res = self._Ax(x, sub_ind, use_os) - b
weights_mul_res = cp.multiply(w, res)
weights_dot_res = cp.sum(weights_mul_res, axis=1)

dz, dy, dx = res.shape
block_dims = (128, 1, 1)
grid_dims = tuple(
(res.T.shape[i] + block_dims[i] - 1) // block_dims[i] for i in range(3)
)
stripe_weighted_least_squares(
grid_dims,
block_dims,
(res, w, weights_mul_res, weights_dot_res, w_sum, dx, dy, dz),
)

return self._Atb(res, sub_ind, use_os)
67 changes: 57 additions & 10 deletions tomobar/methodsIR_CuPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,11 @@ def __common_initialisation(
self.Atools.detectors_x_pad,
cupyrun=True,
)
_data_upd_["projection_raw_data"] = _apply_horiz_detector_padding(
_data_upd_["projection_raw_data"],
self.Atools.detectors_x_pad,
cupyrun=True,
)

if _algorithm_upd_.get("lipschitz_const") is None:
_algorithm_upd_["lipschitz_const"] = self.powermethod(_data_upd_)
Expand All @@ -389,14 +394,27 @@ def __common_initialisation(

use_os = self.OS_number > 1

if _data_["data_fidelity"] in ["PWLS"]:
w = cp.asarray(_data_upd_["projection_data"]) # weights for PWLS model
if _data_["data_fidelity"] in ["PWLS", "SWLS"]:
w = cp.asarray(_data_upd_["projection_raw_data"]) # weights for PWLS model
w = cp.maximum(w, 1e-6)
w /= w.max()
else:
w = None

return (_data_upd_, _algorithm_upd_, _regularisation_upd_, x0, w, use_os)
if _data_["data_fidelity"] in ["SWLS"]:
beta_SWLS = _data_upd_["beta_SWLS"]
else:
beta_SWLS = None

return (
_data_upd_,
_algorithm_upd_,
_regularisation_upd_,
x0,
w,
beta_SWLS,
use_os,
)

def FISTA(
self,
Expand Down Expand Up @@ -430,34 +448,63 @@ def FISTA(
_regularisation_upd_,
x0,
w,
beta_SWLS,
use_os,
) = self.__common_initialisation(
_data_, _algorithm_, _regularisation_, method_run="FISTA"
)

L_const_inv = 1.0 / _algorithm_upd_["lipschitz_const"]

proj_data = _data_upd_["projection_data"]
indVec = None
t = cp.float32(1.0)
X_t = cp.copy(x0)
X = cp.copy(x0)

if use_os:
proj_data = [None] * self.OS_number
weights = [None] * self.OS_number
weight_sums = [None] * self.OS_number

for sub_ind in range(self.OS_number):
# select a specific set of indeces for the subset (OS)
indVec = self.Atools.newInd_Vec[sub_ind, :]
if indVec[self.Atools.NumbProjBins - 1] == 0:
indVec = indVec[:-1] # shrink vector size

proj_data[sub_ind] = _data_upd_["projection_data"][:, indVec, :]
weights[sub_ind] = None if w is None else w[:, indVec, :]

weight_subset = weights[sub_ind]
weight_sums[sub_ind] = (
None
if weight_subset is None
else (cp.sum(weight_subset, axis=1) + beta_SWLS)
)
else:
proj_data_subset = _data_upd_["projection_data"]
weight_subset = w
weight_subset_sum = cp.sum(weight_subset, axis=1) + beta_SWLS

# FISTA iterations
for _ in range(_algorithm_upd_["iterations"]):
# loop over subsets (OS)
for sub_ind in range(self.OS_number):
X_old = X
t_old = t
if use_os:
# select a specific set of indeces for the subset (OS)
indVec = self.Atools.newInd_Vec[sub_ind, :]
if indVec[self.Atools.NumbProjBins - 1] == 0:
indVec = indVec[:-1] # shrink vector size
proj_data = _data_upd_["projection_data"][:, indVec, :]
proj_data_subset = proj_data[sub_ind]
weight_subset = weights[sub_ind]
weight_subset_sum = weight_sums[sub_ind]

grad_data = grad_data_term(
self, X_t, proj_data, use_os, sub_ind, indVec, w
self,
X_t,
proj_data_subset,
use_os,
sub_ind,
weight_subset,
weight_subset_sum,
)

X = X_t - L_const_inv * grad_data
Expand Down
10 changes: 9 additions & 1 deletion tomobar/supp/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,25 @@ def dicts_check(
_data_["data_axes_labels_order"],
correct_labels_order,
)
_data_["projection_raw_data"] = _data_dims_swapper(
_data_["projection_raw_data"],
_data_["data_axes_labels_order"],
correct_labels_order,
)
# we need to reset the swap option here as the data already been modified so we don't swap it again in the method itself
_data_["data_axes_labels_order"] = None

if data2dinput:
_data_["projection_data"] = cp.expand_dims(
_data_["projection_data"], axis=0
)
_data_["projection_raw_data"] = cp.expand_dims(
_data_["projection_raw_data"], axis=0
)

if _data_.get("data_fidelity") is None:
_data_["data_fidelity"] = "LS"
if _data_["data_fidelity"] not in {"LS", "PWLS", "KL"}:
if _data_["data_fidelity"] not in {"LS", "PWLS", "KL", "SWLS"}:
raise ValueError(
"_data_['data_fidelity'] should be provided as 'LS', 'PWLS', 'KL'."
)
Expand Down
Loading