Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions python/prewarm/src/prewarm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@
# Pattern to strip version specifiers from dependency strings.
_VERSION_SPEC_RE = re.compile(r"[<>=!~;\[\]].*$")

# Well-known pip-package → import-name mappings for packages where the
# import name doesn't match the normalized package name. This is a
# best-effort table for common data-science and notebook packages;
# it doesn't need to be exhaustive — unmapped packages fall through
# to the hyphen→underscore heuristic which works for most packages.
_KNOWN_IMPORT_NAMES: dict[str, str] = {
"pillow": "PIL",
"scikit-learn": "sklearn",
"scikit-image": "skimage",
"python-dateutil": "dateutil",
"pyyaml": "yaml",
"beautifulsoup4": "bs4",
"opencv-python": "cv2",
"opencv-python-headless": "cv2",
"attrs": "attr",
"python-dotenv": "dotenv",
"protobuf": "google.protobuf",
"google-cloud-storage": "google.cloud.storage",
"google-cloud-bigquery": "google.cloud.bigquery",
"psycopg2-binary": "psycopg2",
}


def warm(
modules: list[str],
Expand Down Expand Up @@ -88,18 +110,28 @@ def warm(
def normalize_module_name(spec: str) -> str | None:
"""Convert a dependency specifier to a Python import name.

Strips version specifiers (``>=``, ``==``, etc.), extras (``[extra]``),
and replaces hyphens with underscores. Returns ``None`` if the result
is not a valid Python identifier.
First checks ``_KNOWN_IMPORT_NAMES`` for well-known packages where
the import name differs from the package name. Falls back to
stripping version specifiers and replacing hyphens with underscores.

>>> normalize_module_name("numpy>=1.24")
'numpy'
>>> normalize_module_name("scikit-learn>=1.0")
'scikit_learn'
'sklearn'
>>> normalize_module_name("Pillow")
'PIL'
>>> normalize_module_name("")
"""
name = _VERSION_SPEC_RE.sub("", spec).strip().replace("-", "_")
if not name or not name.isidentifier():
pkg = _VERSION_SPEC_RE.sub("", spec).strip()
if not pkg:
return None
# Check known mappings (case-insensitive lookup on the raw package name)
known = _KNOWN_IMPORT_NAMES.get(pkg.lower())
if known:
return known
# Fallback: hyphen → underscore heuristic
name = pkg.replace("-", "_")
if not name.isidentifier():
return None
return name

Expand Down
20 changes: 17 additions & 3 deletions python/prewarm/tests/test_prewarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_collect_modules_normalizes_specs():

result = _collect_modules(["numpy>=1.24", "scikit-learn>=1.0"], include_conda=False)
assert "numpy" in result
assert "scikit_learn" in result
assert "sklearn" in result
assert "numpy>=1.24" not in result


Expand All @@ -63,10 +63,24 @@ def test_normalize_module_name():
from prewarm import normalize_module_name

assert normalize_module_name("numpy>=1.24") == "numpy"
assert normalize_module_name("scikit-learn>=1.0") == "scikit_learn"
assert normalize_module_name("pandas") == "pandas"
assert normalize_module_name("Pillow[extra]") == "Pillow"
assert normalize_module_name("") is None
# Fallback heuristic: hyphen → underscore
assert normalize_module_name("some-package") == "some_package"


def test_normalize_known_packages():
"""Well-known packages map to their correct import names."""
from prewarm import normalize_module_name

assert normalize_module_name("scikit-learn>=1.0") == "sklearn"
assert normalize_module_name("Pillow") == "PIL"
assert normalize_module_name("pillow[extra]") == "PIL"
assert normalize_module_name("pyyaml") == "yaml"
assert normalize_module_name("beautifulsoup4") == "bs4"
assert normalize_module_name("opencv-python") == "cv2"
assert normalize_module_name("python-dateutil") == "dateutil"
assert normalize_module_name("psycopg2-binary") == "psycopg2"


def test_build_warmup_script_critical_imports_not_wrapped():
Expand Down
Loading