Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tmol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tmol.chemical.restypes import one2three, three2one # noqa: F401
from tmol.pose.packed_block_types import PackedBlockTypes # noqa: F401
from tmol.pose.pose_stack import PoseStack # noqa: F401
from tmol.pose.constraint_set import ConstraintSet # noqa: F401

from tmol.kinematics.fold_forest import FoldForest, EdgeType # noqa: F401
from tmol.kinematics.datatypes import KinematicModuleData # noqa: F401
Expand Down Expand Up @@ -39,7 +40,9 @@
from tmol.score import beta2016_score_function # noqa: F401
from tmol.score.score_function import ScoreFunction # noqa: F401
from tmol.score.score_types import ScoreType # noqa: F401

from tmol.score.constraint.constraint_energy_term import (
ConstraintEnergyTerm,
) # noqa: F401

from tmol.optimization.kin_min import build_kinforest_network, run_kin_min # noqa: F401

Expand Down
2 changes: 1 addition & 1 deletion tmol/io/details/his_taut_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def resolve_his_tautomerization(
)

return (
torch.tensor(his_taut, dtype=torch.int32, device=coords.device),
his_taut.to(dtype=torch.int32, device=coords.device),
res_type_variants,
resolved_coords,
resolved_atom_is_present,
Expand Down
12 changes: 12 additions & 0 deletions tmol/pose/constraint_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def add_constraints_to_all_poses(self, fn, atom_indices, params=None):
self.add_constraints(fn, atom_indices, params, nposes=nposes)

def add_constraints(self, fn, atom_indices, params=None, nposes=0):
"""Add multiple constraints all using the same functional form.

fn: the constraint function to use. It should take two arguments
1. the set of atom coordinates as an [n_csts x n_atoms_per_cst x 3] tensor
2. the set of parameters as an [n_csts x n_param_vals_per_cst] tensor
atom_indices: [n_csts x n_atoms_per_cst x 3]
Atom indices should be given as tuples of (pose_index, block_index, atom_within_block_index)
All of the indices for a single constraint must live in the same pose.
params: [n_csts x n_param_vals_per_cst]
Parameters for each of the constraints; e.g. x0 and k for a harmonic constraint.
"""

def find_or_insert(value, lst):
if value in lst:
return lst.index(value)
Expand Down
9 changes: 8 additions & 1 deletion tmol/score/constraint/constraint_energy_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from tmol.score.constraint.constraint_whole_pose_module import (
ConstraintWholePoseScoringModule,
)
from tmol.score.constraint.potentials.compiled import get_torsion_angle

from tmol.chemical.restypes import RefinedResidueType
from tmol.pose.packed_block_types import PackedBlockTypes
Expand All @@ -33,6 +32,8 @@ def n_bodies(self):

@classmethod
def get_torsion_angle_test(cls, tensor):
from tmol.score.constraint.potentials.compiled import get_torsion_angle

return get_torsion_angle(tensor)

@classmethod
Expand All @@ -42,6 +43,12 @@ def harmonic(cls, atoms, params):
dist = torch.linalg.norm(atoms1 - atoms2, dim=-1)
return (dist - params[:, 0]) ** 2

@classmethod
def harmonic_coord_constraint(cls, atoms, params):
"""Harmonic penalty for the coordinates deviating from some set of target coordinates"""
dist = torch.linalg.norm(atoms[:, 0, :] - params, dim=-1)
return dist**2

@classmethod
def bounded(cls, atoms, params):
lb = params[:, 0]
Expand Down
27 changes: 18 additions & 9 deletions tmol/score/dunbrack/dunbrack_energy_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def setup_block_type(self, block_type: RefinedResidueType):
else self.global_params.scoring_db_aux.semirotameric_tableset_offsets[
s_inds[s_inds != -1]
][0]
.cpu()
.numpy()
)

empty_tor = numpy.full((4, 3), -1, dtype=numpy.int32)
Expand All @@ -117,25 +119,27 @@ def setup_block_type(self, block_type: RefinedResidueType):

dih_uaids = numpy.array([phi_uaids] + [psi_uaids] + chis)

n_chi = self.global_params.scoring_db_aux.nchi_for_table_set[rotamer_table_set]
n_chi = self.global_params.scoring_db_aux.nchi_for_table_set[
rotamer_table_set
].item()
n_rotameric_chi = n_chi - (1 if semirotameric else 0)
n_dihedrals = n_chi + 2

probability_table_offset = (
probability_table_offset = int(
self.global_params.scoring_db_aux.rotameric_prob_tableset_offsets[
rotameric_index
]
].item()
)

mean_table_offset = (
mean_table_offset = int(
self.global_params.scoring_db_aux.rotameric_meansdev_tableset_offsets[
rotamer_table_set
]
].item()
)
rotamer_index_to_table_index_offset = (
rotamer_index_to_table_index_offset = int(
self.global_params.scoring_db_aux.rotameric_chi_ri2ti_offsets[
rotamer_table_set
]
].item()
)

dunbrack_attrs = DunbrackBlockAttrs(
Expand Down Expand Up @@ -171,6 +175,7 @@ def setup_packed_block_types(self, packed_block_types: PackedBlockTypes):
packed_block_types.active_block_types,
device=self.device,
)

packed_data = [
pack(lambda f: getattr(f.dunbrack_attrs, field.name))
for field in dataclasses.fields(DunbrackBlockAttrs)
Expand All @@ -190,7 +195,7 @@ def pack_data_keyed_on_block_type(
cur = numpy.shape(bt_data)
if max_size is None:
max_size = cur
dtype = bt_data.dtype
dtype = bt_data.dtype if not isinstance(bt_data, int) else int
max_size = numpy.maximum(max_size, cur)

n_block_types = (len(active_block_types),)
Expand All @@ -215,7 +220,11 @@ def dim_slices(dim):
bt_data = field_getter(bt)
if bt_data is None:
continue
slices = [i] + [*map(dim_slices, bt_data.shape)]
slices = [i] + (
[*map(dim_slices, bt_data.shape)]
if not isinstance(bt_data, int)
else []
)
tensor[slices] = torch.tensor(
bt_data, dtype=dtype_conversion[dtype], device=device
)
Expand Down
29 changes: 24 additions & 5 deletions tmol/score/score_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def __init__(self, param_db: ParameterDatabase, device: torch.device):
self._all_terms_unordered = []
self._all_terms_out_of_date = False

self._all_score_types = []

self._one_body_terms = []
self._one_body_terms_unordered = []
self._one_body_terms_out_of_date = False
Expand Down Expand Up @@ -89,28 +91,43 @@ def all_terms(self):
Do not modify this list directly
"""
if self._all_terms_out_of_date:
self._all_terms = self.get_sorted_terms(self._all_terms_unordered)
self._all_terms, self._all_score_types = self.get_sorted_terms(
self._all_terms_unordered
)
self._all_terms_out_of_date = False

return self._all_terms

def all_score_types(self):
if self._all_terms_out_of_date:
self._all_terms, self._all_score_types = self.get_sorted_terms(
self._all_terms_unordered
)
self._all_terms_out_of_date = False

return self._all_score_types

def one_body_terms(self):
if self._one_body_terms_out_of_date:
self._one_body_terms = self.get_sorted_terms(self._one_body_terms_unordered)
self._one_body_terms, _ = self.get_sorted_terms(
self._one_body_terms_unordered
)
self._one_body_terms_out_of_date = False

return self._one_body_terms

def two_body_terms(self):
if self._two_body_terms_out_of_date:
self._two_body_terms = self.get_sorted_terms(self._two_body_terms_unordered)
self._two_body_terms, _ = self.get_sorted_terms(
self._two_body_terms_unordered
)
self._two_body_terms_out_of_date = False

return self._two_body_terms

def multi_body_terms(self):
if self._multi_body_terms_out_of_date:
self._multi_body_terms = self.get_sorted_terms(
self._multi_body_terms, _ = self.get_sorted_terms(
self._multi_body_terms_unordered
)
self._multi_body_terms_out_of_date = False
Expand Down Expand Up @@ -177,6 +194,7 @@ def weights_tensor(self):
@staticmethod
def get_sorted_terms(term_list):
sorted_term_list = []
sorted_score_type_list = []
term_covered = [False] * ScoreType.n_score_types.value
terms_by_st = [None] * ScoreType.n_score_types.value
for term in term_list:
Expand All @@ -195,7 +213,8 @@ def get_sorted_terms(term_list):
sorted_term_list.append(term)
for term_st in term.score_types():
term_covered[term_st.value] = True
return sorted_term_list
sorted_score_type_list.append(term_st)
return sorted_term_list, sorted_score_type_list


class WholePoseScoringModule:
Expand Down
1 change: 1 addition & 0 deletions tmol/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
big_pdb,
water_box_pdb,
ubq_pdb,
kin_minimized_ubq_pdb,
disulfide_pdb,
systems_bysize,
pertuzumab_pdb,
Expand Down
5 changes: 5 additions & 0 deletions tmol/tests/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def ubq_pdb():
return pdb.data["1ubq"]


@pytest.fixture(scope="session")
def kin_minimized_ubq_pdb():
return pdb.data["kin_minimized_1ubq"]


@pytest.fixture(scope="session")
def disulfide_pdb():
return pdb.data["3plc"]
Expand Down
Loading