Skip to content

Commit 9a8aae0

Browse files
committed
updated tests and removed redundant 1 heavy atom code
1 parent f5d5431 commit 9a8aae0

2 files changed

Lines changed: 33 additions & 34 deletions

File tree

CodeEntropy/levels/axes.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def get_residue_axes(self, data_container, index: int, residue=None):
123123
# No UAS are bonded to other residues
124124
# Use a custom principal axes, from a MOI tensor that uses positions of
125125
# heavy atoms only, but including masses of heavy atom + bonded H.
126+
126127
moi_tensor = self.get_moment_of_inertia_tensor(
127128
center_of_mass=np.array(residue.center_of_mass()),
128129
positions=uas.positions,
@@ -225,7 +226,6 @@ def get_UA_axes(self, data_container, index: int, res_position):
225226
"""
226227
index = int(index) # bead index
227228
heavy_atoms = data_container.atoms.select_atoms("mass 2 to 999")
228-
229229
# use the same customPI trans axes as the residue level
230230
if len(heavy_atoms) > 1:
231231
if len(data_container.residues) == 1:
@@ -289,15 +289,6 @@ def get_UA_axes(self, data_container, index: int, res_position):
289289
trans_center = np.array(backbone.center_of_mass())
290290
trans_axes = self.get_residue_custom_axes(edges, trans_center)
291291

292-
else:
293-
# only one heavy atom or hydrogen molecule
294-
make_whole(data_container.atoms)
295-
residue = data_container
296-
# trans_center is center of mass
297-
trans_center = np.array(data_container.center_of_mass())
298-
trans_axes = data_container.atoms.principal_axes()
299-
300-
if len(heavy_atoms) > 1:
301292
residue_heavy_atoms = residue.atoms.select_atoms("mass 2 to 999")
302293
# look for heavy atoms in residue of interest
303294
heavy_atom_indices = []
@@ -307,19 +298,30 @@ def get_UA_axes(self, data_container, index: int, res_position):
307298
# where n is the bead index
308299
heavy_atom_index = heavy_atom_indices[index]
309300
heavy_atom = residue.atoms.select_atoms(f"index {heavy_atom_index}")
301+
rot_center = heavy_atom.positions[0]
302+
rot_axes, moment_of_inertia = self.get_bonded_axes(
303+
system=data_container,
304+
atom=heavy_atom[0],
305+
dimensions=data_container.dimensions[:3],
306+
)
307+
310308
else:
311-
# only the one heavy atom
309+
# 1 heavy atom in the data_container
312310
heavy_atom = heavy_atoms[0]
311+
# trans and rot centres are centre of mass
312+
rot_center = data_container.center_of_mass()
313+
rot_axes, moment_of_inertia = self.get_bonded_axes(
314+
system=data_container,
315+
atom=heavy_atom[0],
316+
dimensions=data_container.dimensions[:3],
317+
)
318+
trans_center = rot_center
319+
# principal axes
320+
trans_axes = rot_axes
321+
313322
if trans_axes is None:
314323
raise ValueError("Unable to compute translation axes for UA bead.")
315324

316-
rot_center = heavy_atom.positions[0]
317-
rot_axes, moment_of_inertia = self.get_bonded_axes(
318-
system=data_container,
319-
atom=heavy_atom[0],
320-
dimensions=data_container.dimensions[:3],
321-
)
322-
323325
if rot_axes is None or moment_of_inertia is None:
324326
raise ValueError("Unable to compute bonded axes for UA bead.")
325327

@@ -765,7 +767,6 @@ def get_moment_of_inertia_tensor(
765767
"""
766768
r = self.get_vector(center_of_mass, positions, dimensions)
767769
r2 = np.sum(r**2, axis=1)
768-
769770
masses_arr = np.asarray(list(masses), dtype=float)
770771
moment_of_inertia_tensor = np.eye(3) * np.sum(masses_arr * r2)
771772
moment_of_inertia_tensor -= np.einsum("i,ij,ik->jk", masses_arr, r, r)
@@ -797,6 +798,7 @@ def get_custom_principal_axes(
797798
- principal_axes: (3, 3) principal axes (rows).
798799
- moment_of_inertia: (3,) principal moments.
799800
"""
801+
800802
eigenvalues, eigenvectors = np.linalg.eig(moment_of_inertia_tensor)
801803
order = np.abs(eigenvalues).argsort()[::-1] # descending order
802804
transposed = np.transpose(eigenvectors) # columns -> rows

tests/unit/CodeEntropy/levels/test_axes.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,15 @@ def test_get_UA_axes_uses_principal_axes_when_single_heavy(monkeypatch):
166166
u = MagicMock()
167167
u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90])
168168
u.atoms.principal_axes.return_value = np.eye(3)
169+
u.center_of_mass.return_value = np.array([[4.0, 0.0, 0.0]])
169170

170171
# heavy_atoms length <= 1 => principal_axes path
171-
heavy_atom = MagicMock(index=5)
172+
heavy_atom = MagicMock(index=5, position=np.array([4.0, 0.0, 0.0]))
172173
heavy_atoms = [heavy_atom]
173174

174175
def _sel(q):
175-
if q == "prop mass > 1.1":
176+
if q == "mass 2 to 999":
177+
# return heavy atoms group
176178
return heavy_atoms
177179
if q.startswith("index "):
178180
# return atom group with positions
@@ -181,6 +183,7 @@ def _sel(q):
181183
ag.__getitem__.return_value = MagicMock(
182184
mass=12.0, position=np.array([4.0, 0.0, 0.0]), index=5
183185
)
186+
184187
return ag
185188
return []
186189

@@ -513,18 +516,16 @@ def _select_atoms(q):
513516
assert np.allclose(moi, np.array([3.0, 2.0, 1.0]))
514517

515518

516-
def test_get_residue_axes_with_bonds_vanilla_path(monkeypatch):
519+
def test_get_residue_axes_vanilla_path(monkeypatch):
517520
ax = AxesCalculator()
518-
519521
residue = MagicMock()
520522
residue.__len__.return_value = 1
521-
residue.atoms.principal_axes.return_value = np.eye(3) * 2
522523
residue.atoms.center_of_mass.return_value = np.array([1.0, 2.0, 3.0])
523524
residue.center_of_mass.return_value = np.array([1.0, 2.0, 3.0])
525+
residue.select_atoms.return_value = MagicMock(positions=np.zeros((1, 3)))
524526

525527
u = MagicMock()
526528
u.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90])
527-
u.atoms.principal_axes.return_value = np.eye(3) * 2
528529

529530
def _select_atoms(q):
530531
if q.startswith("(resindex"):
@@ -537,7 +538,9 @@ def _select_atoms(q):
537538

538539
monkeypatch.setattr("CodeEntropy.levels.axes.make_whole", lambda _ag: None)
539540
monkeypatch.setattr(
540-
ax, "get_vanilla_axes", lambda mol: (np.eye(3) * 2, np.array([9.0, 8.0, 7.0]))
541+
ax,
542+
"get_custom_principal_axes",
543+
lambda mol: (np.eye(3) * 2, np.array([9.0, 8.0, 7.0])),
541544
)
542545

543546
trans, rot, center, moi = ax.get_residue_axes(u, index=10, residue=residue)
@@ -647,19 +650,13 @@ def center_of_mass(self, *args, **kwargs):
647650
def __getitem__(self, idx):
648651
return system_atom
649652

650-
def _select_atoms(q):
651-
if q == "prop mass > 1.1":
652-
return heavy_atoms
653-
if q.startswith("index "):
654-
return heavy_atom_selection
655-
return _FakeAtomGroup([])
656-
657653
data_container = MagicMock()
658654
data_container.atoms = _Atoms()
659655
data_container.dimensions = np.array([10.0, 10.0, 10.0, 90, 90, 90], dtype=float)
656+
_FakeAtomGroup.atoms = heavy_atom_selection
660657

661658
def _select_atoms(q):
662-
if q == "prop mass > 1.1":
659+
if q == "mass 2 to 999":
663660
return heavy_atoms
664661
if q.startswith("index "):
665662
return heavy_atom_selection

0 commit comments

Comments
 (0)