Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,4 @@ if(TMOL_BUILD_TESTS)
)
endif()

# --- TORCH_LIBRARY test extension (not pybind) ---
add_library(test_custom_op MODULE
tmol/tests/utility/torchscript/custom_op.cpp
)
set_target_properties(test_custom_op PROPERTIES
PREFIX ""
OUTPUT_NAME "_custom_op"
SUFFIX "${_PY_EXT_SUFFIX}"
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/tmol/tests/utility/torchscript"
)
target_compile_definitions(test_custom_op PRIVATE TORCH_EXTENSION_NAME=_custom_op)
tmol_set_cpp_flags(test_custom_op)
install(TARGETS test_custom_op DESTINATION "tmol/tests/utility/torchscript")

endif()
28 changes: 19 additions & 9 deletions tmol/io/pose_stack_from_biotite.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def pose_stack_from_biotite(
torch_device: torch.device,
param_db: ParameterDatabase | None = None,
missing_density_distance_threshold: float = 2.4,
no_optH: bool = False,
**kwargs: object,
) -> PoseStack | tuple[PoseStack, dict]:
"""Build a PoseStack from the output generated by Biotite.
Expand All @@ -104,6 +105,12 @@ def pose_stack_from_biotite(
Adjacent residues whose closest inter-atom distance exceeds this
value are treated as disconnected (upper/lower connects broken).
Set to 0 to disable. Default is 2.4.
no_optH: When False (default), all residues with complete heavy atoms
are packed with OptHSampler to place and optimize hydrogen positions
and NHQ flips, while residues with missing heavy atoms are rebuilt
with DunbrackChiSampler. When True, only missing heavy-atom
sidechains are rebuilt with Dunbrack; hydrogens are left at the
kinematically ideal positions produced during pose construction.
**kwargs: Additional arguments passed to pose_stack_from_canonical_form.

Returns:
Expand All @@ -112,9 +119,7 @@ def pose_stack_from_biotite(
are returned as dictionary in the second value of a tuple.
"""
from tmol.io.pose_stack_construction import pose_stack_from_canonical_form
from tmol.pack.build_missing_sidechains import (
build_missing_sidechains_with_missing_atoms,
)
from tmol.pack.build_missing_sidechains import build_missing_sidechains
from tmol.pack.rotamer.dunbrack.dunbrack_chi_sampler import (
create_dunbrack_sampler_from_database,
)
Expand All @@ -137,21 +142,26 @@ def pose_stack_from_biotite(
pose_stack, opt_return_vals = result
block_has_missing_atoms = opt_return_vals["block_has_missing_atoms"]

if block_has_missing_atoms is not None and torch.any(block_has_missing_atoms):
needs_packing = block_has_missing_atoms is not None and (
torch.any(block_has_missing_atoms) or not no_optH
)
if needs_packing:
db = context.parameter_database

sfxn = beta2016_score_function(torch_device, param_db=db)
dunbrack_sampler = create_dunbrack_sampler_from_database(db, torch_device)

logger.info(
"%i missing sidechains", torch.count_nonzero(block_has_missing_atoms)
)
pose_stack = build_missing_sidechains_with_missing_atoms(
if torch.any(block_has_missing_atoms):
logger.info(
"%i blocks with missing heavy atoms",
torch.count_nonzero(block_has_missing_atoms),
)
pose_stack = build_missing_sidechains(
pose_stack,
sfxn,
dunbrack_sampler,
block_has_missing_atoms,
context.restype_set,
no_optH=no_optH,
)

# This code tries to faithfully return what the caller expects based on the optional
Expand Down
103 changes: 69 additions & 34 deletions tmol/pack/build_missing_sidechains.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,96 @@
from tmol.pack.packer_task import PackerTask, PackerPalette
from tmol.pack.rotamer.dunbrack.dunbrack_chi_sampler import DunbrackChiSampler
from tmol.pack.rotamer.fixed_aa_chi_sampler import FixedAAChiSampler

# from tmol.pack.rotamer.include_current_sampler import IncludeCurrentSampler
from tmol.pack.pack_rotamers import pack_rotamers


def build_missing_sidechains_with_missing_atoms(
def build_missing_sidechains(
pose_stack: PoseStack,
sfxn: ScoreFunction,
dunbrack_sampler: DunbrackChiSampler,
block_has_missing_atoms: Tensor[torch.bool][:, :],
rts,
no_optH: bool = False,
) -> PoseStack:
"""Build missing sidechains using the packer with explicit missing atoms information.
"""Build missing sidechains and place hydrogens using per-block sampler assignment.

Assigns samplers on a per-block basis in a single packing run:

- Blocks with missing non-leaf (heavy) atoms: DunbrackChiSampler +
FixedAAChiSampler. The input conformation is not included as a rotamer
because the sidechain is incomplete.
- All other real blocks (leaf-only or no missing atoms): OptHSampler, which
keeps heavy atoms fixed and samples proton chi angles and NHQ flips.
FallbackSampler (always present by default) covers residue types that
OptH does not handle (ALA, GLY, etc.).

When no_optH=True the old behavior is preserved: only Dunbrack runs for
blocks with missing heavy atoms; all other blocks are frozen.

This function examines the block_has_missing_atoms tensor to determine
which blocks have missing sidechain atoms. For blocks with missing atoms,
it adds the DunbrackChiSampler and FixedAAChiSampler to the PackerTask.
For blocks without missing atoms, it adds the IncludeCurrentSampler to
preserve existing sidechains. Then it calls pack_rotamers to build the
missing sidechains.
Note: IncludeCurrentSampler is intentionally not used. For Dunbrack
blocks the native conformation is broken and must not appear as a rotamer.
For OptH blocks, OptH includes native as rotamer-0 for NHQ residues and
FallbackSampler covers the rest.

Args:
pose_stack: The pose stack containing the structures to process
sfxn: The score function to use for packing (typically beta2016)
dunbrack_sampler: The DunbrackChiSampler configured with the default database
block_has_missing_atoms: Boolean tensor indicating which blocks have missing atoms
Shape: [n_poses, max_n_blocks]
pose_stack: The pose stack to process.
sfxn: Score function used for packing.
dunbrack_sampler: DunbrackChiSampler configured from the parameter DB.
block_has_missing_atoms: Boolean tensor [n_poses, max_n_blocks]; True
for blocks that have missing non-leaf (heavy) atoms.
rts: ResidueTypeSet (unused directly; kept for API compatibility).
no_optH: When True, skip OptH and preserve old Dunbrack-only behavior.

Returns:
PoseStack: A new pose stack with missing sidechains built
PoseStack with missing sidechains built and (by default) hydrogens
placed and optimized.
"""
from tmol.pack.rotamer.opth_sampler import OptHSampler

restype_set = pose_stack.packed_block_types.restype_set

# Create a PackerPalette and PackerTask
palette = PackerPalette(restype_set)
task = PackerTask(pose_stack, palette)
task.set_include_current()
task.restrict_to_repacking() # no design
task.restrict_to_repacking()

fixed_sampler = FixedAAChiSampler()
opth_sampler = None if no_optH else OptHSampler()

# Iterate through the block level tasks and either disable packing if the sidechain already
# exists, or else make sure we dont try and load the current sidechain with missing atoms
for pose_ind in range(block_has_missing_atoms.size(0)):
for block_ind in range(block_has_missing_atoms.size(1)):
if pose_stack.is_real_block(pose_ind, block_ind):
has_missing = block_has_missing_atoms[pose_ind, block_ind]
if has_missing:
task.blts[pose_ind][block_ind].include_current = False
else:
task.blts[pose_ind][block_ind].disable_packing()

# Add the samplers
fixed_sampler = FixedAAChiSampler()
task.add_conformer_sampler(dunbrack_sampler)
task.add_conformer_sampler(fixed_sampler)
if not pose_stack.is_real_block(pose_ind, block_ind):
continue
blt = task.blts[pose_ind][block_ind]
if block_has_missing_atoms[pose_ind, block_ind]:
# Missing heavy atoms: rebuild sidechain from Dunbrack library.
# Do not include the broken input conformation as a rotamer.
blt.add_conformer_sampler(dunbrack_sampler)
blt.add_conformer_sampler(fixed_sampler)
elif no_optH:
blt.disable_packing()
else:
# Complete heavy atoms: optimize proton placement with OptH.
blt.add_conformer_sampler(opth_sampler)

# Call pack_rotamers to build the missing sidechains
return pack_rotamers(pose_stack, sfxn, task, verbose=False)


def build_missing_sidechains_with_missing_atoms(
pose_stack: PoseStack,
sfxn: ScoreFunction,
dunbrack_sampler: DunbrackChiSampler,
block_has_missing_atoms: Tensor[torch.bool][:, :],
rts,
) -> PoseStack:
"""Backward-compatible wrapper around build_missing_sidechains.

Calls build_missing_sidechains with no_optH=True, preserving the original
behavior: Dunbrack for blocks with missing heavy atoms, all others frozen.
"""
return build_missing_sidechains(
pose_stack,
sfxn,
dunbrack_sampler,
block_has_missing_atoms,
rts,
no_optH=True,
)
13 changes: 7 additions & 6 deletions tmol/pack/packer_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ def default_conformer_samplers(self, block_type):

Each block must have coordinates represented in the tensor with the other
rotamers, and the easiest way to do that is to create a rotamer with the
DOFs of the input conformation. The IncludeCurrentSampler copies these
DOFs from the inverse-folded coordinates of the starting Pose's blocks.
DOFs of the input conformation. The FallbackSampler copies these DOFs
from the inverse-folded coordinates of the starting Pose's blocks, but
only for positions where no other sampler provides rotamers (e.g. residue
types not covered by DunbrackChiSampler). Positions with at least one
other sampler are left to that sampler exclusively.
Future versions of PackerPalette have the option to override this method.
"""
from tmol.pack.rotamer.include_current_sampler import (
IncludeCurrentSampler,
)
from tmol.pack.rotamer.fallback_sampler import FallbackSampler

return [IncludeCurrentSampler()]
return [FallbackSampler()]


# TO DO: BLT should hold "considered block types" as a boolean vector
Expand Down
128 changes: 128 additions & 0 deletions tmol/pack/rotamer/fallback_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import numpy
import torch
import attr

from typing import Tuple

from tmol.types.torch import Tensor
from tmol.types.functional import validate_args

from tmol.chemical.restypes import RefinedResidueType
from tmol.pose.packed_block_types import PackedBlockTypes
from tmol.pose.pose_stack import PoseStack
from tmol.kinematics.datatypes import KinForest
from tmol.pack.rotamer.conformer_sampler import ConformerSampler
from tmol.pack.rotamer.include_current_sampler import (
create_full_dof_inds_to_copy_from_orig_to_rotamers_for_include_current_sampler,
)


@attr.s(auto_attribs=True, frozen=True)
class FallbackSampler(ConformerSampler):
"""Include the input conformation as a rotamer only for positions that have
no rotamers from any other sampler.

This is the default sampler in PackerPalette. Unlike IncludeCurrentSampler,
it does not unconditionally add a rotamer for every position; instead it
activates only when every other sampler in the block-level task returns
False from defines_rotamers_for_rt for the original block type, ensuring
that positions covered by, e.g., DunbrackChiSampler do not accumulate an
extra current-conformation rotamer.

The disable_packing case (all block types disallowed) is also handled: a
rotamer from the input conformation is always produced so the packer has
something to represent for fixed residues.
"""

@classmethod
def sampler_name(cls):
return "FallbackSampler"

@validate_args
def annotate_residue_type(self, rt: RefinedResidueType):
pass

@validate_args
def annotate_packed_block_types(self, packed_block_types: PackedBlockTypes):
pass

@validate_args
def defines_rotamers_for_rt(self, rt: RefinedResidueType):
return True

@validate_args
def first_sc_atoms_for_rt(self, rt: RefinedResidueType) -> Tuple[str, ...]:
return (rt.default_jump_connection_atom,)

def create_samples_for_poses(
self,
pose_stack: PoseStack,
task: "PackerTask", # noqa: F821
) -> Tuple[ # noqa F821
Tensor[torch.int32][:], # n_rots_for_gbt
Tensor[torch.int32][:], # gbt_for_rotamer
dict,
]:
n_rots_for_gbt_list = [
(
1
if bt is blt.original_block_type
and (
not numpy.any(blt.block_type_allowed)
or not any(
s
for s in blt.conformer_samplers
if not isinstance(s, FallbackSampler)
and s.defines_rotamers_for_rt(bt)
)
)
else 0
)
for one_pose_blts in task.blts
for blt in one_pose_blts
for bt in blt.considered_block_types
]
n_rots_for_gbt = torch.tensor(
n_rots_for_gbt_list, dtype=torch.int32, device=pose_stack.device
)
gbt_for_rotamer = torch.nonzero(n_rots_for_gbt, as_tuple=True)[0]
return (n_rots_for_gbt, gbt_for_rotamer, {})

def fill_dofs_for_samples(
self,
pose_stack: PoseStack,
task: "PackerTask", # noqa: F821
orig_kinforest: KinForest,
orig_dofs_kto: Tensor[torch.float32][:, 9],
gbt_for_conformer: Tensor[torch.int64][:],
block_type_ind_for_conformer: Tensor[torch.int64][:],
n_dof_atoms_offset_for_conformer: Tensor[torch.int64][:],
conformer_built_by_sampler: Tensor[torch.bool][:],
conf_inds_for_sampler: Tensor[torch.int64][:],
sampler_n_rots_for_gbt: Tensor[torch.int32][:],
sampler_gbt_for_rotamer: Tensor[torch.int32][:],
sample_dict: dict,
conf_dofs_kto: Tensor[torch.float32][:, 9],
):
n_rots = sampler_gbt_for_rotamer.shape[0]
if n_rots == 0:
return

if torch.cuda.is_available():
torch.cuda.synchronize()
dst, src = (
create_full_dof_inds_to_copy_from_orig_to_rotamers_for_include_current_sampler(
pose_stack,
task,
gbt_for_conformer,
block_type_ind_for_conformer,
conf_inds_for_sampler,
sampler_n_rots_for_gbt,
sampler_gbt_for_rotamer,
n_dof_atoms_offset_for_conformer,
)
)

conf_dofs_kto[dst + 1, :] = orig_dofs_kto[src + 1, :]
if torch.cuda.is_available():
torch.cuda.synchronize()
1 change: 0 additions & 1 deletion tmol/pack/rotamer/fixed_aa_chi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def sample_chi_for_poses(
for one_pose_blts in task.blts
for blt in one_pose_blts
for i, bt in enumerate(blt.considered_block_types)
if self in blt.conformer_samplers
],
dtype=object,
)
Expand Down
Loading
Loading