Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions python/prewarm/src/prewarm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
import compileall
import contextlib
import importlib
import re

# Critical packages that MUST import successfully for the environment to be
# considered valid. If any of these fail, the warmup script exits non-zero
# so the daemon does not admit a broken environment to the pool.
CRITICAL_MODULES = [
"ipykernel",
"IPython",
]

# Base packages always imported during warmup — these are the core
# notebook runtime dependencies whose first-import is expensive.
# Failures are non-fatal (try/except) since the env is still usable.
BASE_MODULES = [
"ipykernel",
"IPython",
"ipywidgets",
"anywidget",
"nbformat",
Expand All @@ -39,6 +47,9 @@
("ipykernel.comm", "CommManager"),
]

# Pattern to strip version specifiers from dependency strings.
_VERSION_SPEC_RE = re.compile(r"[<>=!~;\[\]].*$")


def warm(
modules: list[str],
Expand Down Expand Up @@ -74,17 +85,40 @@ def warm(
_warm_directly(all_modules, include_conda=include_conda)


def normalize_module_name(spec: str) -> str | None:
"""Convert a dependency specifier to a Python import name.

Strips version specifiers (``>=``, ``==``, etc.), extras (``[extra]``),
and replaces hyphens with underscores. Returns ``None`` if the result
is not a valid Python identifier.

>>> normalize_module_name("numpy>=1.24")
'numpy'
>>> normalize_module_name("scikit-learn>=1.0")
'scikit_learn'
>>> normalize_module_name("")
"""
name = _VERSION_SPEC_RE.sub("", spec).strip().replace("-", "_")
if not name or not name.isidentifier():
return None
return name


def _compile_site_packages(path: str) -> None:
"""Pre-compile all .py files in site-packages to .pyc."""
compileall.compile_dir(path, quiet=2, workers=0)


def _collect_modules(extra: list[str], *, include_conda: bool = False) -> list[str]:
"""Assemble the full module list: base + conda (optional) + user extras."""
"""Assemble the full module list: base + conda (optional) + normalized user extras."""
modules = list(BASE_MODULES)
if include_conda:
modules.extend(CONDA_MODULES)
modules.extend(extra)
# Normalize user-supplied dependency specifiers to import names
for spec in extra:
name = normalize_module_name(spec)
if name:
modules.append(name)
# Deduplicate while preserving order
seen: set[str] = set()
result: list[str] = []
Expand Down Expand Up @@ -143,22 +177,22 @@ def build_warmup_script(
if site_packages:
lines.append(f"compileall.compile_dir({site_packages!r}, quiet=2, workers=0)")

# Phase 2: imports
# Phase 2: critical imports — these MUST succeed or the script exits non-zero,
# preventing a broken environment from being admitted to the pool.
for m in CRITICAL_MODULES:
lines.append(f"import {m}")

# Deep imports that validate the kernel runtime is functional
lines.append("from ipykernel.kernelbase import Kernel")
lines.append("from ipykernel.ipkernel import IPythonKernel")

# Phase 3: non-critical imports — failures are silently skipped
lines.append("import importlib")

all_modules = _collect_modules(extra_modules, include_conda=include_conda)
for m in all_modules:
lines.append(f"try:\n importlib.import_module({m!r})\nexcept Exception:\n pass")

# Deep imports (always — ipykernel classes)
lines.append(
"try:\n"
" from ipykernel.kernelbase import Kernel\n"
" from ipykernel.ipkernel import IPythonKernel\n"
"except Exception:\n"
" pass"
)

if include_conda:
for mod, attr in CONDA_DEEP_IMPORTS:
lines.append(f"try:\n from {mod} import {attr}\nexcept Exception:\n pass")
Expand Down
42 changes: 40 additions & 2 deletions python/prewarm/tests/test_prewarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,46 @@ def test_collect_modules_deduplicates():
"""Duplicate modules should be removed while preserving order."""
from prewarm import _collect_modules

result = _collect_modules(["ipykernel", "numpy"], include_conda=False)
assert result.count("ipykernel") == 1
result = _collect_modules(["ipywidgets", "numpy"], include_conda=False)
assert result.count("ipywidgets") == 1


def test_collect_modules_normalizes_specs():
"""Version specifiers and hyphens are normalized to import names."""
from prewarm import _collect_modules

result = _collect_modules(["numpy>=1.24", "scikit-learn>=1.0"], include_conda=False)
assert "numpy" in result
assert "scikit_learn" in result
assert "numpy>=1.24" not in result


def test_normalize_module_name():
"""normalize_module_name strips specs and converts hyphens."""
from prewarm import normalize_module_name

assert normalize_module_name("numpy>=1.24") == "numpy"
assert normalize_module_name("scikit-learn>=1.0") == "scikit_learn"
assert normalize_module_name("pandas") == "pandas"
assert normalize_module_name("Pillow[extra]") == "Pillow"
assert normalize_module_name("") is None


def test_build_warmup_script_critical_imports_not_wrapped():
"""Critical imports (ipykernel, IPython) must NOT be in try/except."""
from prewarm import build_warmup_script

script = build_warmup_script([], include_conda=False)
# Critical imports should be bare import statements
assert "import ipykernel" in script
assert "import IPython" in script
# They should NOT be wrapped in try/except
lines = script.split("\n")
for i, line in enumerate(lines):
if line.strip() == "import ipykernel":
assert i == 0 or "try" not in lines[i - 1]
if line.strip() == "import IPython":
assert i == 0 or "try" not in lines[i - 1]


def test_build_warmup_script_basic():
Expand Down
Loading