Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/tools/test_generate_cmake_presets_cudacxx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json

import tools.generate_cmake_presets as presets


def test_generate_presets_prefers_cudacxx(monkeypatch, tmp_path):
output_path = tmp_path / "CMakeUserPresets.json"
cudacxx = tmp_path / "cu-bridge" / "bin" / "nvcc"

monkeypatch.setenv("CUDACXX", str(cudacxx))
monkeypatch.setattr(presets, "CUDA_HOME", str(tmp_path / "cuda"))
monkeypatch.setattr(presets.os.path, "exists", lambda path: True)
monkeypatch.setattr(presets, "which", lambda name: None)
monkeypatch.setattr(presets, "get_cpu_cores", lambda: 8)

presets.generate_presets(output_path=str(output_path), force_overwrite=True)

data = json.loads(output_path.read_text())
compiler = data["configurePresets"][0]["cacheVariables"]["CMAKE_CUDA_COMPILER"]
assert compiler == str(cudacxx)
7 changes: 6 additions & 1 deletion tools/generate_cmake_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False)

# Detect NVCC
nvcc_path = None
if CUDA_HOME:
cudacxx = os.environ.get("CUDACXX")
if cudacxx and os.path.exists(cudacxx):
nvcc_path = cudacxx
print(f"Found CUDA compiler via CUDACXX: {nvcc_path}")
Comment on lines +48 to +51

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the CUDACXX environment variable is set to a compiler command name (e.g., nvcc or clang++) rather than an absolute path, os.path.exists(cudacxx) will return False. This causes the script to ignore the user's specified compiler and fall back to detecting nvcc elsewhere, which is incorrect. Resolving the compiler using which(cudacxx) when it is not a direct path ensures that command names in the system PATH are correctly detected and respected.

Suggested change
cudacxx = os.environ.get("CUDACXX")
if cudacxx and os.path.exists(cudacxx):
nvcc_path = cudacxx
print(f"Found CUDA compiler via CUDACXX: {nvcc_path}")
cudacxx = os.environ.get("CUDACXX")
if cudacxx:
resolved_cudacxx = cudacxx if os.path.exists(cudacxx) else which(cudacxx)
if resolved_cudacxx:
nvcc_path = resolved_cudacxx
print(f"Found CUDA compiler via CUDACXX: {nvcc_path}")


if not nvcc_path and CUDA_HOME:
prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc")
if os.path.exists(prospective_path):
nvcc_path = prospective_path
Expand Down