Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/tools/test_generate_cmake_presets_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json

import tools.generate_cmake_presets as presets


def test_generate_presets_accepts_parallelism_overrides(monkeypatch, tmp_path):
output_path = tmp_path / "CMakeUserPresets.json"

monkeypatch.setattr(presets, "CUDA_HOME", str(tmp_path))
monkeypatch.setattr(presets.os.path, "exists", lambda path: True)
monkeypatch.setattr(presets, "which", lambda name: None)
monkeypatch.setattr(presets, "get_cpu_cores", lambda: 64)

presets.generate_presets(
output_path=str(output_path),
force_overwrite=True,
cmake_jobs=3,
nvcc_threads=2,
)

data = json.loads(output_path.read_text())
cache = data["configurePresets"][0]["cacheVariables"]
assert cache["NVCC_THREADS"] == "2"
assert data["buildPresets"][0]["jobs"] == 3
33 changes: 29 additions & 4 deletions tools/generate_cmake_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def get_cpu_cores():
return multiprocessing.cpu_count()


def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False):
def generate_presets(
output_path="CMakeUserPresets.json",
force_overwrite=False,
cmake_jobs=None,
nvcc_threads=None,
):
"""Generates the CMakeUserPresets.json file."""

print("Attempting to detect your system configuration...")
Expand Down Expand Up @@ -74,8 +79,14 @@ def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False)

# Get CPU cores
cpu_cores = get_cpu_cores()
nvcc_threads = min(4, cpu_cores)
cmake_jobs = max(1, cpu_cores // nvcc_threads)
if nvcc_threads is None:
nvcc_threads = min(4, cpu_cores)
elif nvcc_threads < 1:
raise ValueError("nvcc_threads must be at least 1")
if cmake_jobs is None:
cmake_jobs = max(1, cpu_cores // nvcc_threads)
elif cmake_jobs < 1:
raise ValueError("cmake_jobs must be at least 1")
Comment on lines +89 to +96

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If get_cpu_cores() returns None (which multiprocessing.cpu_count() can do on some platforms) or returns 0 (which can happen in certain containerized or restricted environments), the current logic will raise a TypeError or a ZeroDivisionError.

Specifically, if cpu_cores is 0, nvcc_threads is set to 0 via min(4, 0). Then, calculating cmake_jobs using cpu_cores // nvcc_threads results in 0 // 0, which raises a ZeroDivisionError.

We can make this logic robust against both None and 0 values by defaulting cpu_cores to at least 1 and ensuring nvcc_threads is at least 1.

Suggested change
if nvcc_threads is None:
nvcc_threads = min(4, cpu_cores)
elif nvcc_threads < 1:
raise ValueError("nvcc_threads must be at least 1")
if cmake_jobs is None:
cmake_jobs = max(1, cpu_cores // nvcc_threads)
elif cmake_jobs < 1:
raise ValueError("cmake_jobs must be at least 1")
if nvcc_threads is None:
nvcc_threads = max(1, min(4, cpu_cores or 1))
elif nvcc_threads < 1:
raise ValueError("nvcc_threads must be at least 1")
if cmake_jobs is None:
cmake_jobs = max(1, (cpu_cores or 1) // nvcc_threads)
elif cmake_jobs < 1:
raise ValueError("cmake_jobs must be at least 1")

print(
f"Detected {cpu_cores} CPU cores. "
f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}."
Expand Down Expand Up @@ -171,11 +182,25 @@ def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--cmake-jobs",
type=int,
help="Override build preset jobs instead of deriving it from CPU cores",
)
parser.add_argument(
"--nvcc-threads",
type=int,
help="Override NVCC_THREADS instead of deriving it from CPU cores",
)
parser.add_argument(
"--force-overwrite",
action="store_true",
help="Force overwrite existing CMakeUserPresets.json without prompting",
)

args = parser.parse_args()
generate_presets(force_overwrite=args.force_overwrite)
generate_presets(
force_overwrite=args.force_overwrite,
cmake_jobs=args.cmake_jobs,
nvcc_threads=args.nvcc_threads,
)