Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions DyneTrion/inference_DyneTrion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
from omegaconf import OmegaConf
from torch.utils import data
from typing import Dict
import concurrent.futures

from src.data import DyneTrion_data_loader_dynamic
from src.data import utils as du
import DyneTrion.train_DyneTrion as train_DyneTrion





class Evaluator:
def __init__(
self,
Expand All @@ -44,10 +41,8 @@ def __init__(
self._exp_conf = conf.experiment

# Set-up GPU
if torch.cuda.is_available():
self.device = 'cuda:0'
else:
self.device = 'cpu'
self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
self._conf.experiment.device = self.device
self._log.info(f'Using device: {self.device}')
# model weight
self._weights_path = self._eval_conf.weights_path
Expand All @@ -57,8 +52,6 @@ def __init__(
self._log.info(f'Saving results to {self._output_dir}')
# Load models and experiment
self._load_ckpt(conf_overrides)




def _load_ckpt(self, conf_overrides):
Expand Down Expand Up @@ -86,6 +79,9 @@ def _load_ckpt(self, conf_overrides):

self.model = self.model.to(self.device)
self.model.eval()

self.model = torch.compile(self.model, dynamic=True)

self.diffuser = self.exp.diffuser

self._log.info(f"Loading model Successfully from {self._weights_path}!!!")
Expand All @@ -107,8 +103,42 @@ def start_evaluation(self):
# we need to call the MD simulation to get the data
# maybe add some func in the dateset class


print("Preparing data and starting warmup...")
test_dataset = self.create_dataset(is_random=self._conf.eval.random_sample)

num_to_run = len(test_dataset)
print(f"Total proteins scheduled for inference: {num_to_run}")

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
# max_seq_len
current_batch_df = test_dataset.csv.iloc[:num_to_run]
max_idx_in_batch = current_batch_df['seq_len'].idxmax()
max_len = current_batch_df['seq_len'].max()
pdb_id_of_max = current_batch_df.loc[max_idx_in_batch, 'pdb_id']
relative_idx = current_batch_df.index.get_loc(max_idx_in_batch)

print(f"==== Warmup: Using protein with max sequence length [ID: {pdb_id_of_max}, Length: {max_len}] ====")

# warmup 2 steps
with torch.no_grad():
warmup_feats, _ = test_dataset._get_row(relative_idx)
for k, v in warmup_feats.items():
if torch.is_tensor(v):
warmup_feats[k] = v.to(self.device)

f_time, l_len = warmup_feats['res_mask'].shape
z_rot_all = torch.randn(100, f_time, l_len, 3, device=self.device)
z_trans_all = torch.randn(100, f_time, l_len, 3, device=self.device)

self.exp.inference_fn(warmup_feats,num_t=2,min_t=0.01,aux_traj=True,
z_rot_all=z_rot_all,
z_trans_all=z_trans_all
)
torch.cuda.synchronize()
# torch.cuda.empty_cache()
print("==== [Warmup] Complete. GPU memory successfully allocated. ====")

future = executor.submit(test_dataset._get_row, 0)

eval_dir = self._output_dir
os.makedirs(eval_dir, exist_ok=True)
Expand All @@ -123,10 +153,16 @@ def start_evaluation(self):
self.exp._set_seed(42)
pdb_base_path, ref_base_path = self.exp._prepare_extension_eval_dirs(eval_dir)
extrapolation_time = self.exp._conf.eval.extrapolation_time
for i in range(len(test_dataset)):
valid_feats, pdb_names = test_dataset._get_row(i)


for i in range(num_to_run):
valid_feats, pdb_names = future.result()
seq_len = valid_feats['aatype'].shape[-1]
print(f"\n>>>> [Progress: {i+1}] Processing PDB: {pdb_names} | Length: {seq_len} <<<<")
if i + 1 < num_to_run:
future = executor.submit(test_dataset._get_row, i + 1)
for k,v in valid_feats.items():
valid_feats[k] = v.unsqueeze(0)
valid_feats[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
self.exp._process_one_protein_extrapolation(
extrapolation_time,
valid_feats,
Expand All @@ -135,12 +171,15 @@ def start_evaluation(self):
pdb_base_path,
device=self.device,
noise_scale=self.exp._exp_conf.noise_scale,
executor=executor,
)


executor.shutdown(wait=True)

@hydra.main(version_base=None, config_path="./config", config_name="eval_DyneTrion")
def run(conf: DictConfig) -> None:

torch.set_float32_matmul_precision('high')

# Read model checkpoint.
print('Starting inference')
start_time = time.time()
Expand Down
Loading