Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d734d07
Eliminate all nrot^2 allocations by compressing ops to block level
fdimaio Apr 8, 2026
6d42ad3
Modify SA to ONLY consider neighbors when substituting
fdimaio Apr 8, 2026
5aeb11d
Update SA. Change code so warp calculates single rot change. reduce…
fdimaio Apr 10, 2026
b4a68f4
Grab bag of changes: a) fix 1H/2H flips on ASN and GLN; b) add a dist…
fdimaio Apr 10, 2026
1c3b20c
Bugfix for chi1 corrections: since they are backbone-depoendent, we c…
fdimaio Apr 13, 2026
2a6c32e
First pass implementation of bump check / background node elimination
aleaverfay Apr 13, 2026
6b13170
linting
fdimaio Apr 14, 2026
808f51d
Small tweaks
aleaverfay Apr 14, 2026
c8a2328
save debugging process
Apr 14, 2026
deca32e
Get bump check working.
aleaverfay Apr 14, 2026
45e5171
More bump check cleanups
Apr 14, 2026
8788dbd
Minimize RPE lifetime
aleaverfay Apr 14, 2026
c6e01fd
Rename var in pose_stack_from_biotite. In SA, only thread 0 does exp…
fdimaio Apr 14, 2026
cb9f346
Filter residues with incomplete backbones
fdimaio Apr 14, 2026
e51ffea
Merge remote-tracking branch 'origin/dimaio/packing_modifications' in…
aleaverfay Apr 15, 2026
1c2acd4
Save test work to move to laptop
aleaverfay Apr 20, 2026
435300f
Progress in getting the sharp edges on bump check sanded down
aleaverfay Apr 20, 2026
c5b7ce9
Merge branch 'master' into aleaverfay/bump_check
aleaverfay Apr 20, 2026
c2d4727
Fix circular dependencies issues; update build_rotamers unit test
aleaverfay Apr 20, 2026
54afd4e
Fix import
aleaverfay Apr 20, 2026
b1656ed
In our 1ubq.pdb, swap location of 1HD2/2HD2 in ASN & 1HE2/2HE2 in GLN
aleaverfay Apr 20, 2026
a93bc09
Merge branch 'master' into aleaverfay/bump_check; bump minimum biotit…
aleaverfay Apr 20, 2026
86fdb98
Restore biotite to 1.2.0 for Python3.10 compatibility
aleaverfay Apr 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tmol/io/pose_stack_from_biotite.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,7 @@ def _filter_supported_atoms_and_connectivity(
)
valid_res[i] = False

valid_atoms = valid_res[
biotite.structure.get_all_residue_positions(biotite_structure)
]
valid_atoms = valid_res[get_all_residue_positions(biotite_structure)]

lower = numpy.roll(valid_res, 1)
lower[0] = True
Expand Down
20 changes: 18 additions & 2 deletions tmol/pack/compiled/annealer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ template <
typename Int>
struct InteractionGraphBuilder {
static auto f(
int const verbose,
int const chunk_size,
int const max_n_block_types,
TView<Int, 1, D> n_rots_for_pose,
TView<Int, 1, D> rot_offset_for_pose,
TView<Int, 2, D> n_rots_for_block,
Expand All @@ -27,10 +29,24 @@ struct InteractionGraphBuilder {
TView<int32_t, 2, D> sparse_inds,
TView<Real, 1, D> sparse_energies)
-> std::tuple<
TPack<Real, 1, D>,
TPack<
int64_t,
1,
tmol::Device::CPU>, // max_n_bump_checked_rotamers_per_pose
TPack<Int, 1, D>, // n_molten_blocks_per_pose
TPack<Int, 1, D>, // n_bc_rots_per_pose
TPack<Int, 1, D>, // bc_rot_offset_for_pose
TPack<Int, 2, D>, // n_bc_rots_for_molten_block
TPack<Int, 2, D>, // bc_rot_offset_for_molten_block
TPack<Int, 1, D>, // molten_block_ind_for_bc_rot
TPack<int64_t, 2, D>, // rotamer_for_nonmolten_block
TPack<int64_t, 1, D>, // bc_rot_to_orig_rot

TPack<Real, 1, D>, // bg/bg energies
TPack<Real, 1, D>, // energy1b
TPack<int64_t, 3, D>,
TPack<int64_t, 1, D>,
TPack<Real, 1, D> >;
TPack<Real, 1, D> >; // energy2b
};

template <tmol::Device D>
Expand Down
1,013 changes: 952 additions & 61 deletions tmol/pack/compiled/compiled.impl.hh

Large diffs are not rendered by default.

49 changes: 43 additions & 6 deletions tmol/pack/compiled/compiled.ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ namespace compiled {
using torch::Tensor;

std::vector<Tensor> build_interaction_graph(
int64_t const verbose,
int64_t const chunk_size,
int64_t const max_n_block_types,
Tensor n_rots_for_pose,
Tensor rot_offset_for_pose,
Tensor n_rots_for_block,
Expand All @@ -33,6 +35,18 @@ std::vector<Tensor> build_interaction_graph(
Tensor sparse_inds,
Tensor sparse_energies) {
nvtx_range_push("pack_build_ig");

at::Tensor max_n_bump_checked_rotamers_per_pose;
at::Tensor n_molten_blocks_per_pose;
at::Tensor n_bc_rots_per_pose;
at::Tensor bc_rot_offset_for_pose;
at::Tensor n_bc_rots_for_molten_block;
at::Tensor bc_rot_offset_for_molten_block;
at::Tensor molten_block_ind_for_bc_rot;
at::Tensor rotamer_for_nonmolten_block;
at::Tensor bc_rot_to_orig_rot;

at::Tensor bg_bg_energies;
at::Tensor energy1b;
at::Tensor chunk_pair_offset_for_block_pair;
at::Tensor chunk_pair_offset;
Expand All @@ -50,7 +64,9 @@ std::vector<Tensor> build_interaction_graph(
Dev,
Real,
Int>::
f(chunk_size,
f(verbose,
chunk_size,
max_n_block_types,
TCAST(n_rots_for_pose),
TCAST(rot_offset_for_pose),
TCAST(n_rots_for_block),
Expand All @@ -60,14 +76,35 @@ std::vector<Tensor> build_interaction_graph(
TCAST(block_ind_for_rot),
TCAST(sparse_inds),
TCAST(sparse_energies));
energy1b = std::get<0>(result).tensor;
chunk_pair_offset_for_block_pair = std::get<1>(result).tensor;
chunk_pair_offset = std::get<2>(result).tensor;
energy2b = std::get<3>(result).tensor;

max_n_bump_checked_rotamers_per_pose = std::get<0>(result).tensor;
n_molten_blocks_per_pose = std::get<1>(result).tensor;
n_bc_rots_per_pose = std::get<2>(result).tensor;
bc_rot_offset_for_pose = std::get<3>(result).tensor;
n_bc_rots_for_molten_block = std::get<4>(result).tensor;
bc_rot_offset_for_molten_block = std::get<5>(result).tensor;
molten_block_ind_for_bc_rot = std::get<6>(result).tensor;
rotamer_for_nonmolten_block = std::get<7>(result).tensor;
bc_rot_to_orig_rot = std::get<8>(result).tensor;
bg_bg_energies = std::get<9>(result).tensor;
energy1b = std::get<10>(result).tensor;
chunk_pair_offset_for_block_pair = std::get<11>(result).tensor;
chunk_pair_offset = std::get<12>(result).tensor;
energy2b = std::get<13>(result).tensor;
}));

std::vector<torch::Tensor> result(
{energy1b,
{max_n_bump_checked_rotamers_per_pose,
n_molten_blocks_per_pose,
n_bc_rots_per_pose,
bc_rot_offset_for_pose,
n_bc_rots_for_molten_block,
bc_rot_offset_for_molten_block,
molten_block_ind_for_bc_rot,
rotamer_for_nonmolten_block,
bc_rot_to_orig_rot,
bg_bg_energies,
energy1b,
chunk_pair_offset_for_block_pair,
chunk_pair_offset,
energy2b});
Expand Down
50 changes: 45 additions & 5 deletions tmol/pack/impose_rotamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
def impose_top_rotamer_assignments(
orig_pose_stack: PoseStack,
rotamer_set: RotamerSet,
assignment: Tensor[torch.int32][:, :, :],
rotamer_for_nonmolten_block: Tensor[torch.int64][:, :],
n_molten_blocks_per_pose: Tensor[torch.int64][:],
bc_rot_offset_for_molten_block: Tensor[torch.int64][:, :],
bc_rot_to_orig_rot: Tensor[torch.int64][:],
bc_assignment: Tensor[torch.int32][:, :, :],
):
"""Impose the lowest-energy rotamer assignemnt to each pose in the original PoseStack."""

Expand All @@ -35,15 +39,51 @@ def impose_top_rotamer_assignments(
n_poses = orig_pose_stack.n_poses
max_n_blocks = orig_pose_stack.max_n_blocks
max_n_atoms_per_block = orig_pose_stack.max_n_atoms
max_n_molten_blocks = bc_assignment.shape[2]
bc_assignment = bc_assignment[:, 0, :]
# n_assignments = bc_assignment.shape[1]
# print("bc_assignment", bc_assignment.shape, bc_assignment.dtype)

# lets figure out how many atoms per pose
# map from the subset of blocks and bump-checked rotamer
# to an assignment in the original rotamer-set indexing
assignment = torch.full(
(n_poses, max_n_blocks), -1, dtype=torch.int64, device=device
)
is_nonmolten_block = rotamer_for_nonmolten_block != -1
assignment[is_nonmolten_block] = rotamer_for_nonmolten_block[is_nonmolten_block]

is_real_block = orig_pose_stack.block_type_ind64 != -1
is_molten_block = torch.logical_and(
rotamer_for_nonmolten_block == -1, is_real_block
)
# print("is_molten_block.unsqueeze(1).expand(-1, n_assignments, -1)", is_molten_block.unsqueeze(1).expand(-1, n_assignments, -1).shape)
# print("bc_assignment[is_molten_block.unsqueeze(1).expand(-1, n_assignments, -1)]", bc_assignment[is_molten_block.unsqueeze(1).expand(-1, n_assignments, -1)].shape)
# print("assignment[is_molten_block, :]", assignment[is_molten_block, :].shape)

molten_block_arange = (
torch.arange(max_n_molten_blocks, dtype=torch.int64, device=device)
.unsqueeze(0)
.expand(n_poses, -1)
)
is_real_molten_block = molten_block_arange < n_molten_blocks_per_pose.view(-1, 1)
bc_assignment_global = (
bc_assignment[is_real_molten_block]
+ bc_rot_offset_for_molten_block[is_real_molten_block]
)

assignment[is_molten_block] = bc_rot_to_orig_rot[bc_assignment_global].reshape(-1)

# print("assignment", assignment.shape)
# print("assignment", assignment)

# lets figure out how many atoms per pose
new_block_type_ind64 = torch.full(
(n_poses, max_n_blocks), -1, dtype=torch.int64, device=device
)
new_rot_for_block64 = (
assignment[:, 0, :].to(torch.int64) + rotamer_set.rot_offset_for_block
)
# new_rot_for_block64 = (
# assignment[:, 0, :].to(torch.int64) + rotamer_set.rot_offset_for_block
# )
new_rot_for_block64 = assignment

is_real_block = orig_pose_stack.block_type_ind64 != -1

Expand Down
138 changes: 94 additions & 44 deletions tmol/pack/pack_rotamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tmol.pose.pose_stack import PoseStack
from tmol.score.score_function import ScoreFunction

from tmol.pack.compiled.compiled import build_interaction_graph
from tmol.pack.packer_task import PackerTask
from tmol.pack.rotamer.build_rotamers import build_rotamers
from tmol.pack.datatypes import PackerEnergyTables
Expand All @@ -19,6 +18,7 @@ def pack_rotamers(
verbose=False,
**sa_params,
):

if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
Expand All @@ -29,49 +29,16 @@ def pack_rotamers(
torch.cuda.synchronize()
end_time1 = time.perf_counter()

rotamer_scoring_module = sfxn.render_rotamer_scoring_module(pose_stack, rotamer_set)

energies = rotamer_scoring_module(rotamer_set.coords)
energies = energies.coalesce()

if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time2 = time.perf_counter()

chunk_size = 16

energy1b, chunk_pair_offset_for_block_pair, chunk_pair_offset, energy2b = (
build_interaction_graph(
chunk_size,
rotamer_set.n_rots_for_pose,
rotamer_set.rot_offset_for_pose,
rotamer_set.n_rots_for_block,
rotamer_set.rot_offset_for_block,
rotamer_set.pose_for_rot,
rotamer_set.block_type_ind_for_rot,
rotamer_set.block_ind_for_rot,
energies.indices().to(torch.int32),
energies.values(),
)
)
if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time3 = time.perf_counter()
(
packer_energy_tables,
rotamer_for_nonmolten_block,
n_molten_blocks_per_pose,
bc_rot_offset_for_molten_block,
bc_rot_to_orig_rot,
end_time2,
end_time3,
) = _calculate_packer_energies(pose_stack, sfxn, rotamer_set, verbose=verbose)

packer_energy_tables = PackerEnergyTables(
max_n_rotamers_per_pose=rotamer_set.max_n_rots_per_pose,
pose_n_res=pose_stack.n_res_per_pose,
pose_n_rotamers=rotamer_set.n_rots_for_pose,
pose_rotamer_offset=rotamer_set.rot_offset_for_pose,
nrotamers_for_res=rotamer_set.n_rots_for_block,
oneb_offsets=rotamer_set.rot_offset_for_block,
res_for_rot=rotamer_set.block_ind_for_rot,
chunk_size=chunk_size,
chunk_offset_offsets=chunk_pair_offset_for_block_pair,
chunk_offsets=chunk_pair_offset,
energy1b=energy1b,
energy2b=energy2b,
)
if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time4 = time.perf_counter()
Expand All @@ -82,8 +49,17 @@ def pack_rotamers(
if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time5 = time.perf_counter()

# print("rotamer_for_nonmolten_block", rotamer_for_nonmolten_block.dtype)

new_pose_stack = impose_top_rotamer_assignments(
pose_stack, rotamer_set, rotamer_assignments
pose_stack,
rotamer_set,
rotamer_for_nonmolten_block,
n_molten_blocks_per_pose,
bc_rot_offset_for_molten_block,
bc_rot_to_orig_rot,
rotamer_assignments,
)
if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
Expand All @@ -98,3 +74,77 @@ def pack_rotamers(
)

return new_pose_stack


def _calculate_packer_energies(pose_stack, sfxn, rotamer_set, verbose=False):
from tmol.pack.compiled.compiled import build_interaction_graph

pbt = pose_stack.packed_block_types
rotamer_scoring_module = sfxn.render_rotamer_scoring_module(pose_stack, rotamer_set)

energies = rotamer_scoring_module(rotamer_set.coords)
energies = energies.coalesce()

if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time2 = time.perf_counter()

chunk_size = 16

(
max_n_bump_checked_rotamers_per_pose_tensor,
n_molten_blocks_per_pose,
n_bc_rots_per_pose,
bc_rot_offset_for_pose,
n_bc_rots_for_molten_block,
bc_rot_offset_for_molten_block,
molten_block_ind_for_bc_rot,
rotamer_for_nonmolten_block,
bc_rot_to_orig_rot,
bg_bg_energies,
energy1b,
chunk_pair_offset_for_block_pair,
chunk_pair_offset,
energy2b,
) = build_interaction_graph(
verbose,
chunk_size,
pbt.n_types,
rotamer_set.n_rots_for_pose,
rotamer_set.rot_offset_for_pose,
rotamer_set.n_rots_for_block,
rotamer_set.rot_offset_for_block,
rotamer_set.pose_for_rot,
rotamer_set.block_type_ind_for_rot,
rotamer_set.block_ind_for_rot,
energies.indices().to(torch.int32),
energies.values(),
)
if verbose and torch.cuda.is_available():
torch.cuda.synchronize()
end_time3 = time.perf_counter()

packer_energy_tables = PackerEnergyTables(
max_n_rotamers_per_pose=max_n_bump_checked_rotamers_per_pose_tensor.item(),
pose_n_res=n_molten_blocks_per_pose,
pose_n_rotamers=n_bc_rots_per_pose,
pose_rotamer_offset=bc_rot_offset_for_pose,
nrotamers_for_res=n_bc_rots_for_molten_block,
oneb_offsets=bc_rot_offset_for_molten_block,
res_for_rot=molten_block_ind_for_bc_rot,
chunk_size=chunk_size,
chunk_offset_offsets=chunk_pair_offset_for_block_pair,
chunk_offsets=chunk_pair_offset,
energy1b=energy1b,
energy2b=energy2b,
)

return (
packer_energy_tables,
rotamer_for_nonmolten_block,
n_molten_blocks_per_pose,
bc_rot_offset_for_molten_block,
bc_rot_to_orig_rot,
end_time2,
end_time3,
)
Loading
Loading