Skip to content

Commit 566cfa7

Browse files
authored
Merge pull request #354 from CCPBioSim/306-dask-parallel-implementation
Add configurable Dask frame execution with SLURM-backed HPC support
2 parents fd68aaa + 6ace042 commit 566cfa7

17 files changed

Lines changed: 3120 additions & 780 deletions

CodeEntropy/config/argparse.py

Lines changed: 232 additions & 66 deletions
Large diffs are not rendered by default.

CodeEntropy/config/runtime.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from rich.text import Text
3434

3535
from CodeEntropy.config.argparse import ConfigResolver
36+
from CodeEntropy.core.dask_clusters import HPCDaskManager
3637
from CodeEntropy.core.logging import LoggingConfig
3738
from CodeEntropy.entropy.workflow import EntropyWorkflow
3839
from CodeEntropy.levels.dihedrals import ConformationStateBuilder
@@ -223,8 +224,9 @@ def run_entropy_workflow(self) -> None:
223224
224225
This method:
225226
- Sets up logging and prints the splash screen
226-
- Loads YAML config from CWD and parses CLI args
227+
- Loads YAML configuration from CWD and parses CLI args
227228
- Merges args with YAML per-run config
229+
- Optionally submits a master SLURM job and exits
228230
- Builds the MDAnalysis Universe (with optional force merging)
229231
- Validates user parameters
230232
- Constructs dependencies and executes EntropyWorkflow
@@ -256,6 +258,16 @@ def run_entropy_workflow(self) -> None:
256258

257259
args = self._config_manager.resolve(args, run_config)
258260

261+
if getattr(args, "submit", False):
262+
if os.environ.get("CODEENTROPY_SUBMITTED_JOB") == "1":
263+
run_logger.info(
264+
"Already running inside submitted SLURM job; "
265+
"continuing workflow."
266+
)
267+
else:
268+
HPCDaskManager(args).submit_master()
269+
return
270+
259271
log_level = (
260272
logging.DEBUG if getattr(args, "verbose", False) else logging.INFO
261273
)
@@ -298,6 +310,7 @@ def run_entropy_workflow(self) -> None:
298310
except Exception:
299311
logger.error("Run arguments at failure could not be serialized")
300312

313+
logger.exception("Fatal error during entropy calculation")
301314
raise RuntimeError("CodeEntropyRunner encountered an error") from exc
302315

303316
@staticmethod

CodeEntropy/core/dask_clusters.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
"""
2+
Helpers for setting up Dask clusters on HPC using SLURM.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import logging
8+
import os
9+
import shlex
10+
import subprocess
11+
import sys
12+
13+
import psutil
14+
from dask.distributed import Client
15+
from dask_jobqueue import SLURMCluster
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class HPCDaskManager:
21+
"""
22+
Manage SLURM-backed Dask clusters and submission utilities for HPC environments.
23+
"""
24+
25+
def __init__(self, args):
26+
"""
27+
Initialise HPCDaskManager with runtime arguments.
28+
29+
Args:
30+
args: Parsed CLI arguments containing HPC and conda configuration.
31+
"""
32+
self.args = args
33+
34+
def _conda_env(self) -> str:
35+
"""Determine the activated conda/mamba environment."""
36+
try:
37+
return os.environ["CONDA_DEFAULT_ENV"]
38+
except KeyError:
39+
logger.error("Please activate your conda/mamba environment.")
40+
raise SystemExit(1) from None
41+
42+
def _conda_exec(self) -> str:
43+
"""Determine whether conda or mamba should be used for activation."""
44+
if os.environ.get("MAMBA_EXE"):
45+
return "mamba"
46+
47+
if os.environ.get("CONDA_EXE"):
48+
return "conda"
49+
50+
logger.error(
51+
"Cannot determine your conda executable. "
52+
"Please make sure conda or mamba has been initialised."
53+
)
54+
raise SystemExit(1)
55+
56+
def _conda_path(self) -> str:
57+
"""Determine the path to the conda executable used for shell initialisation."""
58+
conda_exe = os.environ.get("CONDA_EXE")
59+
60+
if conda_exe:
61+
return conda_exe
62+
63+
logger.error("Please make sure conda is set up correctly.")
64+
raise SystemExit(1)
65+
66+
def resolve_conda_settings(self) -> None:
67+
"""
68+
Fill missing conda/mamba settings from the active environment.
69+
70+
Explicit user-provided values are preserved. Auto-detection is only used
71+
when values are missing.
72+
"""
73+
args = self.args
74+
75+
if not getattr(args, "conda_env", None):
76+
args.conda_env = self._conda_env()
77+
78+
if not getattr(args, "conda_exec", None):
79+
args.conda_exec = self._conda_exec()
80+
81+
if not getattr(args, "conda_path", None) or args.conda_path == "conda":
82+
args.conda_path = self._conda_path()
83+
84+
def check_slurm_env(self) -> None:
85+
"""
86+
Remove inherited SLURM environment variables that can break nested srun calls.
87+
88+
This is important when the master CodeEntropy process itself is already
89+
running inside a SLURM allocation and then launches Dask worker jobs.
90+
"""
91+
for variable in (
92+
"SLURM_CPU_BIND",
93+
"SLURM_MEM_PER_CPU",
94+
"SLURM_MEM_PER_GPU",
95+
"SLURM_MEM_PER_NODE",
96+
):
97+
os.environ.pop(variable, None)
98+
99+
def system_network_interface(self) -> str:
100+
"""
101+
Get the best candidate for the HPC network interface.
102+
103+
This deliberately follows the WaterEntropy-style behaviour and only
104+
selects from known HPC-safe interfaces. It avoids selecting arbitrary
105+
interfaces such as eno1, which may exist on the master node but not on
106+
worker nodes.
107+
"""
108+
hpc_nics = ["bond0", "ib0", "hsn0", "eth0"]
109+
interfaces = list(psutil.net_if_addrs().keys())
110+
111+
for iface in hpc_nics:
112+
if iface in interfaces:
113+
return iface
114+
115+
raise RuntimeError(
116+
"Could not find a known HPC network interface. "
117+
f"Available interfaces: {interfaces}. "
118+
"Expected one of: bond0, ib0, hsn0, eth0."
119+
)
120+
121+
def slurm_directives(self) -> tuple[list[str], list[str]]:
122+
"""
123+
Process additional SLURM directives and directives to skip.
124+
125+
Returns:
126+
Tuple containing extra directives and skipped directives.
127+
"""
128+
args = self.args
129+
extra: list[str] = []
130+
131+
if args.hpc_account:
132+
extra.append(f"--account={args.hpc_account}")
133+
if args.hpc_qos:
134+
extra.append(f"--qos={args.hpc_qos}")
135+
if args.hpc_constraint:
136+
extra.append(f"--constraint={args.hpc_constraint}")
137+
138+
skip = ["--mem"]
139+
140+
return extra, skip
141+
142+
def slurm_prologues(self) -> list[str]:
143+
"""
144+
Build environment setup commands for the SLURM worker job script.
145+
146+
Returns:
147+
List of shell commands executed before the Dask worker starts.
148+
"""
149+
args = self.args
150+
prologue: list[str] = []
151+
152+
for module_name in getattr(args, "hpc_modules", None) or []:
153+
prologue.append(f"module load {module_name}")
154+
155+
prologue.append("unset SLURM_MEM_PER_CPU")
156+
prologue.append("unset SLURM_MEM_PER_GPU")
157+
prologue.append("unset SLURM_MEM_PER_NODE")
158+
prologue.append("unset SLURM_CPU_BIND")
159+
160+
prologue.append(f'eval "$({args.conda_path} shell.bash hook)"')
161+
162+
if args.conda_exec == "mamba":
163+
prologue.append(f'eval "$({args.conda_exec} shell hook --shell bash)"')
164+
165+
prologue.append(f"{args.conda_exec} activate {args.conda_env}")
166+
prologue.append("export SLURM_CPU_FREQ_REQ=2250000")
167+
168+
return prologue
169+
170+
def configure_cluster(self) -> Client:
171+
"""
172+
Configure a SLURM-backed Dask cluster.
173+
174+
Returns:
175+
Dask distributed client connected to the SLURMCluster.
176+
"""
177+
args = self.args
178+
179+
self.resolve_conda_settings()
180+
181+
extra, skip = self.slurm_directives()
182+
prologue = self.slurm_prologues()
183+
iface = self.system_network_interface()
184+
185+
self.check_slurm_env()
186+
187+
cluster = SLURMCluster(
188+
cores=args.hpc_cores,
189+
processes=args.hpc_processes,
190+
memory=args.hpc_memory,
191+
queue=args.hpc_queue,
192+
job_directives_skip=skip,
193+
job_extra_directives=extra,
194+
python="srun python",
195+
walltime=args.hpc_walltime,
196+
shebang="#!/bin/bash --login",
197+
local_directory="$PWD",
198+
interface=iface,
199+
job_script_prologue=prologue,
200+
)
201+
202+
cluster.scale(jobs=args.hpc_nodes)
203+
204+
client = Client(cluster)
205+
206+
with open("dask-cluster-submit.sh", "w", encoding="utf-8") as f:
207+
f.write(cluster.job_script())
208+
209+
return client
210+
211+
def submit_master(self) -> None:
212+
"""
213+
Submit a SLURM job that runs the master CodeEntropy process.
214+
215+
This generates a temporary SLURM script and submits it via sbatch.
216+
"""
217+
self.resolve_conda_settings()
218+
219+
cli = list(sys.argv[1:])
220+
221+
if "--submit" in cli:
222+
idx = cli.index("--submit")
223+
cli.pop(idx)
224+
225+
if idx < len(cli) and str(cli[idx]).lower() in {"true", "false"}:
226+
cli.pop(idx)
227+
228+
script_name = "CodeEntropy-master-submit.sh"
229+
230+
with open(script_name, "w", encoding="utf-8") as f:
231+
f.write("#!/bin/bash --login\n\n")
232+
f.write("#SBATCH --job-name=codeentropy-master\n")
233+
f.write("#SBATCH --nodes=1\n")
234+
f.write("#SBATCH --ntasks=1\n")
235+
f.write("#SBATCH --cpus-per-task=2\n")
236+
f.write(f"#SBATCH --time={self.args.hpc_walltime}\n")
237+
f.write(f"#SBATCH --partition={self.args.hpc_queue}\n")
238+
f.write("#SBATCH --output=CodeEntropy-master-%j.out\n")
239+
f.write("#SBATCH --error=CodeEntropy-master-%j.err\n")
240+
241+
if self.args.hpc_account:
242+
f.write(f"#SBATCH --account={self.args.hpc_account}\n")
243+
244+
if self.args.hpc_qos:
245+
f.write(f"#SBATCH --qos={self.args.hpc_qos}\n")
246+
247+
if self.args.hpc_constraint:
248+
f.write(f"#SBATCH --constraint={self.args.hpc_constraint}\n")
249+
250+
f.write("\n")
251+
252+
for module_name in getattr(self.args, "hpc_modules", None) or []:
253+
f.write(f"module load {module_name}\n")
254+
255+
f.write("unset SLURM_MEM_PER_CPU\n")
256+
f.write("unset SLURM_MEM_PER_GPU\n")
257+
f.write("unset SLURM_MEM_PER_NODE\n")
258+
f.write("unset SLURM_CPU_BIND\n")
259+
260+
f.write(f'eval "$({self.args.conda_path} shell.bash hook)"\n')
261+
262+
if self.args.conda_exec == "mamba":
263+
f.write(f'eval "$({self.args.conda_exec} shell hook --shell bash)"\n')
264+
265+
f.write(f"{self.args.conda_exec} activate {self.args.conda_env}\n")
266+
f.write("export SLURM_CPU_FREQ_REQ=2250000\n")
267+
f.write("export CODEENTROPY_SUBMITTED_JOB=1\n\n")
268+
269+
command = " ".join(["srun", "CodeEntropy", shlex.join(cli)])
270+
f.write(f"{command}\n")
271+
272+
self.check_slurm_env()
273+
274+
try:
275+
result = subprocess.check_output(["sbatch", script_name])
276+
print(result.decode("utf-8"))
277+
except subprocess.CalledProcessError as exc:
278+
print(exc.output.decode("utf-8", errors="replace"))
279+
raise

CodeEntropy/entropy/workflow.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pandas as pd
2424

25+
from CodeEntropy.core.dask_clusters import HPCDaskManager
2526
from CodeEntropy.core.logging import LoggingConfig
2627
from CodeEntropy.entropy.graph import EntropyGraph
2728
from CodeEntropy.entropy.water import WaterEntropy
@@ -116,13 +117,61 @@ def execute(self) -> None:
116117
frame_selection=frame_selection,
117118
)
118119

119-
with self._reporter.progress(transient=False) as p:
120-
self._run_level_dag(shared_data, progress=p)
121-
self._run_entropy_graph(shared_data, progress=p)
120+
self._configure_parallel_frame_execution(shared_data)
121+
122+
try:
123+
with self._reporter.progress(transient=False) as p:
124+
self._run_level_dag(shared_data, progress=p)
125+
self._run_entropy_graph(shared_data, progress=p)
126+
finally:
127+
client = shared_data.get("dask_client")
128+
if client is not None:
129+
client.close()
122130

123131
self._finalize_molecule_results()
124132
self._reporter.log_tables()
125133

134+
def _configure_parallel_frame_execution(self, shared_data: SharedData) -> None:
135+
"""Attach a Dask client to shared_data if parallel frames are requested.
136+
137+
Supports:
138+
- Local Dask via --parallel_frames true / --use_dask true
139+
- SLURM-backed Dask via --hpc true
140+
"""
141+
use_parallel = bool(
142+
getattr(self._args, "parallel_frames", False)
143+
or getattr(self._args, "use_dask", False)
144+
or getattr(self._args, "hpc", False)
145+
)
146+
147+
if not use_parallel:
148+
return
149+
150+
if "dask_client" in shared_data:
151+
shared_data["parallel_frames"] = True
152+
return
153+
154+
if getattr(self._args, "hpc", False):
155+
client = HPCDaskManager(self._args).configure_cluster()
156+
shared_data["dask_client"] = client
157+
shared_data["parallel_frames"] = True
158+
return
159+
160+
try:
161+
from dask.distributed import Client
162+
except ImportError as exc:
163+
raise RuntimeError(
164+
"Parallel frame execution was requested, but dask.distributed "
165+
"is not installed."
166+
) from exc
167+
168+
shared_data["dask_client"] = Client(
169+
processes=True,
170+
n_workers=getattr(self._args, "dask_workers", None),
171+
threads_per_worker=getattr(self._args, "dask_threads_per_worker", 1),
172+
)
173+
shared_data["parallel_frames"] = True
174+
126175
def _build_frame_selection(self) -> FrameSelection:
127176
"""Build the workflow frame selection.
128177

0 commit comments

Comments
 (0)