Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
Changelog
==========

Unreleased
----------------------------
+ Refactor SLURM account, partition, qos handling logic for cleaner code
+ Increase requested Slurm memory by 1.5x after out-of-memory task failures
and log when retry memory is increased.
+ Make ``dynamic_partition`` rules with missing runtime attributes fail to
match, allowing optional attributes such as ``gpuCount`` to select rules.

version 0.6.0
----------------------------
+ Add support for ``slurm_qos`` and ``slurm_qos_gpu`` runtime options.
Expand Down
17 changes: 17 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,20 @@ modify the partition, the comment, or something else.
"args": "--comment default"
}
]

Rules with conditions only match when all referenced runtime attributes are
present. This allows rules to switch arguments based on optional runtime values
such as ``gpuCount``:

.. code-block:: ini

[slurm]
dynamic_partition = [
{
"gpuCount__ge": 1,
"args": "--partition gpu --account gpu_account --qos gpu_qos"
},
{
"args": "--partition cpu --account cpu_account --qos cpu_qos"
}
]
97 changes: 70 additions & 27 deletions src/miniwdl_slurm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import subprocess
import sys
from contextlib import ExitStack
from typing import Callable, Dict, List, Union
from typing import Callable, Dict, List, Optional, Set, Union

from WDL import Type, Value
from WDL.runtime import config
Expand All @@ -33,6 +33,13 @@


class SlurmSingularity(SingularityContainer):
OOM_EXIT_CODES: Set[int] = {137, 253}

def __init__(self, cfg: config.Loader, run_id: str, host_dir: str) -> None:
super().__init__(cfg, run_id, host_dir)
self._last_exit_code: Optional[int] = None
self._oom_retry_count = 0

@classmethod
def global_init(cls, cfg: config.Loader, logger: logging.Logger) -> None:
# Set resources to maxsize. The base class (_SubProcessScheduler)
Expand Down Expand Up @@ -120,7 +127,7 @@ def process_runtime(self,
Type.String()).value
self.runtime_values["slurm_constraint"] = slurm_constraint

def _slurm_invocation(self):
def _slurm_invocation(self, logger: logging.Logger):
# We use sbatch --wait as this makes the submitted job behave like a
# local job.
# Using sbatch also makes sure the resources are requested, even
Expand Down Expand Up @@ -148,33 +155,29 @@ def _slurm_invocation(self):
# If no gpuType is given, use the default GPU type.
sbatch_args.extend(["--gres", f"gpu:{gpuCount}"])

account = self.runtime_values.get("slurm_account", None)
account_gpu = self.runtime_values.get("slurm_account_gpu", None)
if gpuCount is not None and account_gpu is not None:
sbatch_args.extend(["--account", account_gpu])
elif account is not None:
sbatch_args.extend(["--account", account])
if gpuCount is not None:
partition = self.runtime_values.get("slurm_partition_gpu", None)
qos = self.runtime_values.get("slurm_qos_gpu", None)
account = self.runtime_values.get("slurm_account_gpu", None)
else:
partition = self.runtime_values.get("slurm_partition", None)
qos = self.runtime_values.get("slurm_qos", None)
account = self.runtime_values.get("slurm_account", None)

partition = self.runtime_values.get("slurm_partition", None)
partition_gpu = self.runtime_values.get("slurm_partition_gpu", None)
if gpuCount is not None and partition_gpu is not None:
sbatch_args.extend(["--partition", partition_gpu])
elif partition is not None:
if partition is not None:
sbatch_args.extend(["--partition", partition])

qos = self.runtime_values.get("slurm_qos", None)
qos_gpu = self.runtime_values.get("slurm_qos_gpu", None)
if gpuCount is not None and qos_gpu is not None:
sbatch_args.extend(["--qos", qos_gpu])
elif qos is not None:
if account is not None:
sbatch_args.extend(["--account", account])
if qos is not None:
sbatch_args.extend(["--qos", qos])

cpu = self.runtime_values.get("cpu", None)
if cpu is not None:
sbatch_args.extend(["--cpus-per-task", str(cpu)])

memory = self.runtime_values.get("memory_reservation", None)
memory = self._retry_adjusted_memory_reservation()
if memory is not None:
self._log_retry_memory_adjustment(logger, memory)
# Round to the nearest megabyte.
sbatch_args.extend(["--mem", f"{round(memory / (1024 ** 2))}M"])

Expand Down Expand Up @@ -207,7 +210,7 @@ def _run_invocation(self, logger: logging.Logger, cleanup: ExitStack,
image: str) -> List[str]:
singularity_command = super()._run_invocation(logger, cleanup, image)

slurm_invocation = self._slurm_invocation()
slurm_invocation = self._slurm_invocation(logger)
slurm_invocation.extend(singularity_command)
logger.info("Slurm invocation: " + ' '.join(
shlex.quote(part) for part in slurm_invocation))
Expand All @@ -220,9 +223,12 @@ def _run(self,
) -> int:
# Line copied from base class as value is not publicly exposed.
cli_log_filename = os.path.join(self.host_dir, f"{self.cli_name}.log.txt")
exit_code = None
try:
return super()._run(logger, terminating, command)
exit_code = super()._run(logger, terminating, command)
return exit_code
finally:
self._record_exit_code(exit_code)
if terminating(): # Cancel the job if terminating
with open(cli_log_filename, "rt") as submit_log:
# "job_id" or "job_id;cluster_name" are output with --parsable.
Expand All @@ -237,18 +243,55 @@ def _run(self,
'memory': 'memory_reservation',
}

def _record_exit_code(self, exit_code: Optional[int]) -> None:
self._last_exit_code = exit_code
if exit_code in self.OOM_EXIT_CODES:
self._oom_retry_count += 1
else:
self._oom_retry_count = 0

def _retry_adjusted_memory_reservation(self) -> Union[int, float, None]:
memory = self.runtime_values.get("memory_reservation", None)
if memory is not None and self.try_counter > 1 and self._oom_retry_count:
memory *= 1.5 ** self._oom_retry_count
return memory

def _log_retry_memory_adjustment(self, logger: logging.Logger,
adjusted_memory: Union[int, float]) -> None:
original_memory = self.runtime_values.get("memory_reservation", None)
if original_memory is None or adjusted_memory == original_memory:
return

logger.info(
"Retrying with increased Slurm memory reservation: "
f"try {self.try_counter}; "
f"consecutive out-of-memory exit codes seen {self._oom_retry_count}; "
f"previous exit code {self._last_exit_code}; "
f"{round(original_memory / (1024 ** 2))}M -> "
f"{round(adjusted_memory / (1024 ** 2))}M"
)

def _rule_runtime_value(self, attr_name: str) -> Union[int, float, None]:
if attr_name == "memory_reservation":
return self._retry_adjusted_memory_reservation()

runtime_value = self.runtime_values.get(attr_name)
if isinstance(runtime_value, (int, float)):
return runtime_value
return None

def _rules_match(self, rule: dict[str, Union[int, float]]):
for rule_pair, expected_value in rule.items():
if '__' not in rule_pair:
continue

(attribute, comparator) = rule_pair.split('__', 1)
attr_name = self.attribute_lookup.get(attribute, attribute)
runtime_value = self.runtime_values.get(attr_name)
if runtime_value is not None:
if not self._rule_pair_matches(expected_value, comparator,
runtime_value):
return False
runtime_value = self._rule_runtime_value(attr_name)
if runtime_value is None:
return False
if not self._rule_pair_matches(expected_value, comparator, runtime_value):
return False
return True

def _rule_pair_matches(self, value: Union[int, float], comparator: str,
Expand Down
Loading