Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 92 additions & 97 deletions src/backprojector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2140,7 +2140,7 @@ void BackProjector::symmetrise(int nr_helical_asu, RFLOAT helical_twist, RFLOAT
enforceHermitianSymmetry();

// Then apply helical and point group symmetry (order irrelevant?)
applyHelicalSymmetry(nr_helical_asu, helical_twist, helical_rise);
applyHelicalSymmetry(nr_helical_asu, helical_twist, helical_rise, threads);

applyPointGroupSymmetry(threads);
}
Expand All @@ -2164,118 +2164,105 @@ void BackProjector::enforceHermitianSymmetry()
}
}

void BackProjector::applyHelicalSymmetry(int nr_helical_asu, RFLOAT helical_twist, RFLOAT helical_rise)
void BackProjector::applyHelicalSymmetry(int nr_helical_asu, RFLOAT helical_twist, RFLOAT helical_rise, int threads)
{
if ( (nr_helical_asu < 2) || (ref_dim != 3) )
return;

int rmax2 = ROUND(r_max * padding_factor) * ROUND(r_max * padding_factor);

Matrix2D<RFLOAT> R(4, 4); // A matrix from the list
MultidimArray<RFLOAT> sum_weight;
MultidimArray<Complex > sum_data;
RFLOAT x, y, z, fx, fy, fz, xp, yp, zp, r2;
bool is_neg_x;
int x0, x1, y0, y1, z0, z1;
Complex d000, d001, d010, d011, d100, d101, d110, d111;
Complex dx00, dx01, dx10, dx11, dxy0, dxy1, ddd;
RFLOAT dd000, dd001, dd010, dd011, dd100, dd101, dd110, dd111;
RFLOAT ddx00, ddx01, ddx10, ddx11, ddxy0, ddxy1;
MultidimArray<RFLOAT> sum_weight(weight);
MultidimArray<Complex > sum_data(data);

// First symmetry operator (not stored in SL) is the identity matrix
sum_weight = weight;
sum_data = data;
int h_min = -nr_helical_asu/2;
int h_max = -h_min + nr_helical_asu%2;

Matrix2D<RFLOAT> R(4, 4); // A matrix from the list

for (int hh = h_min; hh < h_max; hh++)
{
if (hh != 0) // h==0 is done before the for loop (where sum_data = data)
if (hh == 0) continue; // h==0 is done before the for loop (where sum_data = data)

RFLOAT rot_ang = hh * (-helical_twist);
rotation3DMatrix(rot_ang, 'Z', R);
R.setSmallValuesToZero(); // TODO: invert rotation matrix?

//Only the positive half of the z-range has to be calculated.
//The other half is calculated by the x<0 part of the x-loop.
#pragma omp parallel for schedule(dynamic) num_threads(threads) default(none) shared(data,weight,sum_data,sum_weight,rmax2,helical_twist,helical_rise,ori_size,padding_factor) firstprivate(hh,R)
for (long int k=STARTINGZ(sum_weight); k<=FINISHINGZ(sum_weight); k++)
{
RFLOAT rot_ang = hh * (-helical_twist);
rotation3DMatrix(rot_ang, 'Z', R);
R.setSmallValuesToZero(); // TODO: invert rotation matrix?
// Allocate minimal 2D slices per thread instead of full 3D arrays
MultidimArray<RFLOAT> slice_weight(YSIZE(data), XSIZE(data));
MultidimArray<Complex> slice_data(YSIZE(data), XSIZE(data));
slice_weight.xdim = sum_weight.xdim;
slice_weight.ydim = sum_weight.ydim;
slice_weight.yxdim = sum_weight.yxdim;
slice_weight.xinit = sum_weight.xinit;
slice_weight.yinit = sum_weight.yinit;
slice_data.xdim = sum_data.xdim;
slice_data.ydim = sum_data.ydim;
slice_data.yxdim = sum_data.yxdim;
slice_data.xinit = sum_data.xinit;
slice_data.yinit = sum_data.yinit;

slice_weight.initZeros(weight.ydim,weight.xdim);
slice_data.initZeros(data.ydim,data.xdim);

// Loop over all points in the output (i.e. rotated, or summed) array
FOR_ALL_ELEMENTS_IN_ARRAY3D(sum_weight)
for (long int i=STARTINGY(sum_weight); i<=FINISHINGY(sum_weight); i++)
for (long int j=STARTINGX(sum_weight); j<=FINISHINGX(sum_weight); j++)
{
x = (RFLOAT)j; // STARTINGX(sum_weight) is zero!
y = (RFLOAT)i;
z = (RFLOAT)k;
r2 = x*x + y*y + z*z;
RFLOAT x = (RFLOAT)j; // STARTINGX(sum_weight) is zero!
RFLOAT y = (RFLOAT)i;
RFLOAT z = (RFLOAT)k;
RFLOAT r2 = x*x + y*y + z*z;

if (r2 <= rmax2)
{
// coords_output(x,y) = A * coords_input (xp,yp)
xp = x * R(0, 0) + y * R(0, 1) + z * R(0, 2);
yp = x * R(1, 0) + y * R(1, 1) + z * R(1, 2);
zp = x * R(2, 0) + y * R(2, 1) + z * R(2, 2);
RFLOAT xp = x * R(0, 0) + y * R(0, 1);
RFLOAT yp = x * R(1, 0) + y * R(1, 1);
RFLOAT zp = z; // Z remains unchanged

// Only asymmetric half is stored
if (xp < 0)
bool is_neg_x = xp < 0;
if (is_neg_x)
{
// Get complex conjugated hermitian symmetry pair
xp = -xp;
yp = -yp;
zp = -zp;
is_neg_x = true;
}
else
{
is_neg_x = false;
}

// Trilinear interpolation (with physical coords)
// Subtract STARTINGY and STARTINGZ to accelerate access to data (STARTINGX=0)
// In that way use DIRECT_A3D_ELEM, rather than A3D_ELEM
x0 = FLOOR(xp);
fx = xp - x0;
x1 = x0 + 1;
int x0 = FLOOR(xp);
RFLOAT fx = xp - x0;
int x1 = x0 + 1;

y0 = FLOOR(yp);
fy = yp - y0;
y0 -= STARTINGY(data);
y1 = y0 + 1;
int y0 = FLOOR(yp);
RFLOAT fy = yp - y0;
y0 -= STARTINGY(data);
int y1 = y0 + 1;

z0 = FLOOR(zp);
fz = zp - z0;
int z0 = FLOOR(zp);
z0 -= STARTINGZ(data);
z1 = z0 + 1;

#ifdef CHECK_SIZE
if (x0 < 0 || y0 < 0 || z0 < 0 ||
x1 < 0 || y1 < 0 || z1 < 0 ||
x0 >= XSIZE(data) || y0 >= YSIZE(data) || z0 >= ZSIZE(data) ||
x1 >= XSIZE(data) || y1 >= YSIZE(data) || z1 >= ZSIZE(data) )
{
std::cerr << " x0= " << x0 << " y0= " << y0 << " z0= " << z0 << std::endl;
std::cerr << " x1= " << x1 << " y1= " << y1 << " z1= " << z1 << std::endl;
data.printShape();
REPORT_ERROR("BackProjector::applyPointGroupSymmetry: checksize!!!");
}
#endif
// First interpolate (complex) data
d000 = DIRECT_A3D_ELEM(data, z0, y0, x0);
d001 = DIRECT_A3D_ELEM(data, z0, y0, x1);
d010 = DIRECT_A3D_ELEM(data, z0, y1, x0);
d011 = DIRECT_A3D_ELEM(data, z0, y1, x1);
d100 = DIRECT_A3D_ELEM(data, z1, y0, x0);
d101 = DIRECT_A3D_ELEM(data, z1, y0, x1);
d110 = DIRECT_A3D_ELEM(data, z1, y1, x0);
d111 = DIRECT_A3D_ELEM(data, z1, y1, x1);

dx00 = LIN_INTERP(fx, d000, d001);
dx01 = LIN_INTERP(fx, d100, d101);
dx10 = LIN_INTERP(fx, d010, d011);
dx11 = LIN_INTERP(fx, d110, d111);
dxy0 = LIN_INTERP(fy, dx00, dx10);
dxy1 = LIN_INTERP(fy, dx01, dx11);
Complex d00 = DIRECT_A3D_ELEM(data, z0, y0, x0);
Complex d01 = DIRECT_A3D_ELEM(data, z0, y0, x1);
Complex d10 = DIRECT_A3D_ELEM(data, z0, y1, x0);
Complex d11 = DIRECT_A3D_ELEM(data, z0, y1, x1);

Complex dx00 = LIN_INTERP(fx ,d00, d01);
Complex dx10 = LIN_INTERP(fx ,d10, d11);

// Take complex conjugated for half with negative x
ddd = LIN_INTERP(fz, dxy0, dxy1);
Complex ddd = LIN_INTERP(fy, dx00, dx10);

if (is_neg_x)
ddd = conj(ddd);

// Also apply a phase shift for helical translation along Z
if (ABS(helical_rise) > 0.)
{
RFLOAT zshift = hh * helical_rise;
Expand All @@ -2291,32 +2278,40 @@ void BackProjector::applyHelicalSymmetry(int nr_helical_asu, RFLOAT helical_twis
ddd = Complex(ac - bd, ab_cd - ac - bd);
}
// Accumulated sum of the data term
A3D_ELEM(sum_data, k, i, j) += ddd;
A2D_ELEM(slice_data,i,j) += ddd;

// Then interpolate (real) weight
dd000 = DIRECT_A3D_ELEM(weight, z0, y0, x0);
dd001 = DIRECT_A3D_ELEM(weight, z0, y0, x1);
dd010 = DIRECT_A3D_ELEM(weight, z0, y1, x0);
dd011 = DIRECT_A3D_ELEM(weight, z0, y1, x1);
dd100 = DIRECT_A3D_ELEM(weight, z1, y0, x0);
dd101 = DIRECT_A3D_ELEM(weight, z1, y0, x1);
dd110 = DIRECT_A3D_ELEM(weight, z1, y1, x0);
dd111 = DIRECT_A3D_ELEM(weight, z1, y1, x1);

ddx00 = LIN_INTERP(fx, dd000, dd001);
ddx01 = LIN_INTERP(fx, dd100, dd101);
ddx10 = LIN_INTERP(fx, dd010, dd011);
ddx11 = LIN_INTERP(fx, dd110, dd111);
ddxy0 = LIN_INTERP(fy, ddx00, ddx10);
ddxy1 = LIN_INTERP(fy, ddx01, ddx11);

A3D_ELEM(sum_weight, k, i, j) += LIN_INTERP(fz, ddxy0, ddxy1);

} // end if r2 <= rmax2
} // end loop over all elements of sum_weight
} // end if hh!=0
} // end loop over hh
RFLOAT dd00 = DIRECT_A3D_ELEM(weight, z0, y0, x0);
RFLOAT dd01 = DIRECT_A3D_ELEM(weight, z0, y0, x1);
RFLOAT dd10 = DIRECT_A3D_ELEM(weight, z0, y1, x0);
RFLOAT dd11 = DIRECT_A3D_ELEM(weight, z0, y1, x1);

RFLOAT ddx00 = LIN_INTERP(fx, dd00, dd01);
RFLOAT ddx10 = LIN_INTERP(fx, dd10, dd11);

A2D_ELEM(slice_weight,i,j) += LIN_INTERP(fy, ddx00, ddx10);
}
}
#pragma omp critical
{
for (long int i=STARTINGY(sum_weight); i<=FINISHINGY(sum_weight); i++)
for (long int j=STARTINGX(sum_weight); j<=FINISHINGX(sum_weight); j++)
{
RFLOAT x = (RFLOAT)j; // STARTINGX(sum_weight) is zero!
RFLOAT y = (RFLOAT)i;
RFLOAT z = (RFLOAT)k;
RFLOAT r2 = x*x + y*y + z*z;
if (r2 <= rmax2)
{
A3D_ELEM(sum_data,k,i,j) += A2D_ELEM(slice_data,i,j);
A3D_ELEM(sum_weight,k,i,j) += A2D_ELEM(slice_weight,i,j);
}
}
}
}
}

// Update original arrays
data = sum_data;
weight = sum_weight;
}
Expand Down
2 changes: 1 addition & 1 deletion src/backprojector.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ class BackProjector: public Projector

/* Applies helical symmetry. Note that helical_rise is in PIXELS here, as BackProjector doesn't know angpix
*/
void applyHelicalSymmetry(int nr_helical_asu = 1, RFLOAT helical_twist = 0., RFLOAT helical_rise = 0.);
void applyHelicalSymmetry(int nr_helical_asu = 1, RFLOAT helical_twist = 0., RFLOAT helical_rise = 0., int threads = 1);

/* Applies the symmetry from the SymList object to the weight and the data array
*/
Expand Down
4 changes: 2 additions & 2 deletions src/ml_optimiser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4743,7 +4743,7 @@ void MlOptimiser::symmetriseReconstructions()
wsum_model.BPref[ith_recons].applyHelicalSymmetry(
mymodel.helical_nr_asu,
mymodel.helical_twist[ith_recons],
mymodel.helical_rise[ith_recons] / mymodel.pixel_size);
mymodel.helical_rise[ith_recons] / mymodel.pixel_size, nr_threads);

if (fn_multi_sym.size() > ith_recons) // Always false if size=0
{
Expand All @@ -4767,7 +4767,7 @@ void MlOptimiser::symmetriseReconstructions()
wsum_model.BPref[iclass_half].applyHelicalSymmetry(
mymodel.helical_nr_asu,
mymodel.helical_twist[ith_recons],
mymodel.helical_rise[ith_recons] / mymodel.pixel_size);
mymodel.helical_rise[ith_recons] / mymodel.pixel_size, nr_threads);

if (fn_multi_sym.size() > ith_recons) // Always false if size=0
{
Expand Down