Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,30 @@ def is_new_package(self) -> bool:
def summary(self) -> str:
"""Human-readable summary of the change."""
if self.is_new_package:
return f"NEW PACKAGE: {self.new_critical} critical, {self.new_high} high"
base = f"NEW PACKAGE: {self.new_critical} critical, {self.new_high} high"

# Add info about transitive dependencies with high/critical vulns
transitive_vulns = [v for v in self.new_vulnerabilities if not v.is_direct]
if transitive_vulns:
# Group by package
by_package = {}
for v in transitive_vulns:
if v.severity.lower() in ["critical", "high"]:
pkg = v.package_name
if pkg not in by_package:
by_package[pkg] = {"critical": 0, "high": 0}
by_package[pkg][v.severity.lower()] += 1

if by_package:
transitive_parts = []
for pkg, counts in sorted(by_package.items()):
if counts["critical"] > 0 or counts["high"] > 0:
transitive_parts.append(f"{pkg} ({counts['critical']}C/{counts['high']}H)")

if transitive_parts:
base += f" (from: {', '.join(transitive_parts)})"

return base

crit_delta = self.new_critical - self.old_critical
high_delta = self.new_high - self.old_high
Expand Down Expand Up @@ -246,6 +269,7 @@ def evaluate_package_change(
Evaluate security impact of a single package version change.

Compares vulnerability counts between old and new versions.
Includes transitive dependencies for both new and updated packages.

Args:
package_change: The package change from git diff
Expand All @@ -256,9 +280,9 @@ def evaluate_package_change(
"""
pkg_name = package_change.name

# Get vulnerabilities for new version from current scan
new_vulns = current_scan.get_vulns_for_package(pkg_name)
new_counts = current_scan.count_severity_for_package(pkg_name)
# Get vulnerabilities for new version INCLUDING transitive dependencies
new_vulns = current_scan.get_vulns_for_package_tree(pkg_name)
new_counts = current_scan.count_severity_for_package_tree(pkg_name)

delta = PackageSecurityDelta(
package_name=pkg_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,32 @@ def count_severity_for_package(self, package_name: str) -> Dict[str, int]:
"medium": sum(1 for v in vulns if v.severity.lower() == "medium"),
"low": sum(1 for v in vulns if v.severity.lower() == "low"),
}

def get_vulns_for_package_tree(self, package_name: str) -> "List[DependencyVulnerability]":
    """
    Get all vulnerabilities for a package and its transitive dependencies.

    A vulnerability is included when either:
    - its own ``package_name`` equals ``package_name`` (this also covers
      records that carry an empty dependency path), or
    - ``package_name`` names one of the entries in its ``dependency_path``.

    For example, if package_name is "express", this returns:
    - Direct express vulnerabilities
    - Vulnerabilities in qs (if express depends on qs)
    - Vulnerabilities in fresh (if express depends on fresh)
    etc.

    Path entries are compared by exact package name, not by substring, so
    "express" does NOT match an entry such as "express-session@1.17.0".

    Args:
        package_name: Package name to look up (a full "name@version" path
            entry is also accepted and matched exactly).

    Returns:
        The matching vulnerabilities, in the order they appear in
        ``self.vulnerabilities``.
    """
    def entry_matches(entry: str) -> bool:
        # Path entries look like "name@version". Splitting on the LAST "@"
        # keeps the leading "@" of scoped names ("@scope/pkg@1.0.0") with
        # the name. An entry without a version yields an empty head, in
        # which case the whole entry is treated as the name.
        head, _, _ = entry.rpartition("@")
        return package_name in (entry, head or entry)

    return [
        v for v in self.vulnerabilities
        if v.package_name == package_name
        or any(entry_matches(dep) for dep in v.dependency_path)
    ]

def count_severity_for_package_tree(self, package_name: str) -> Dict[str, int]:
    """Tally vulnerabilities per severity level across a package's dependency tree.

    Severity strings are compared case-insensitively. Only the four known
    levels are counted; any other severity value is ignored.
    """
    tally = dict.fromkeys(("critical", "high", "medium", "low"), 0)
    for vuln in self.get_vulns_for_package_tree(package_name):
        level = vuln.severity.lower()
        if level in tally:
            tally[level] += 1
    return tally


def run_snyk_cli(args: List[str], timeout: int = 300) -> tuple[int, str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,40 @@
LOG_DIR = os.environ.get("SNYK_HOOK_LOG_DIR", "/tmp")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious: where is SNYK_HOOK_LOG_DIR set?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't set anywhere explicitly right now, so this line will always set LOG_DIR to /tmp unless the user sets the SNYK_HOOK_LOG_DIR environment variable. I think having this as the default is fine, since it falls back to /tmp.



# =============================================================================
# PATH VALIDATION
# =============================================================================

def safe_open_within_base(file_path: Path, base_dir: Path, mode: str = 'r'):
    """
    Safely open a file after validating it's within the base directory.

    Prevents path traversal attacks by ensuring the resolved path
    stays within the allowed base directory. Both paths are resolved
    first, so ".." components are collapsed before the containment
    check is made.

    Args:
        file_path: The path to the file to open
        base_dir: The base directory that file_path must be within
        mode: File open mode (default: 'r')

    Returns:
        File handle opened on the resolved path. The caller is responsible
        for closing it (e.g. by using it in a ``with`` statement).

    Raises:
        ValueError: If path would escape the base directory. The original
            ``relative_to()`` error is attached as ``__cause__``.
    """
    resolved_path = file_path.resolve()
    resolved_base = base_dir.resolve()

    try:
        # relative_to() raises ValueError when resolved_path is not located
        # under resolved_base; the relative path itself is not needed.
        resolved_path.relative_to(resolved_base)
    except ValueError as err:
        # Chain explicitly so the traceback shows this as the direct cause
        # rather than as an unrelated error raised during handling.
        raise ValueError(
            f"Path traversal detected: {file_path} resolves outside of {resolved_base}"
        ) from err

    # Path is validated via relative_to() check above - safe to open
    return open(resolved_path, mode)  # noqa: SIM115


# =============================================================================
# LOGGING
# =============================================================================
Expand Down Expand Up @@ -257,33 +291,86 @@ def scan_sca(workspace: str) -> Dict[str, Any]:
log(f"SCA scan failed: {result.error_message}")
return {"status": "error", "error": result.error_message}

# Group vulnerabilities by package
# Group vulnerabilities by TOP-LEVEL package (first in dependency path)
# This captures the full dependency tree for each top-level package
packages: Dict[str, Dict[str, Any]] = {}

for v in result.vulnerabilities:
pkg = v.package_name
version = v.installed_version

key = f"{pkg}@{version}"
if key not in packages:
packages[key] = {
"package": pkg,
"version": version,
"vulnerabilities": [],
"severity_counts": {"critical": 0, "high": 0, "medium": 0, "low": 0}
}

packages[key]["vulnerabilities"].append({
"id": v.id,
"title": v.title,
"severity": v.severity,
"fixed_version": v.fixed_version,
"cve": v.cve
})
# First, initialize entries for all top-level packages from package.json
# This ensures we cache even packages with 0 vulnerabilities
try:
workspace_path = Path(workspace).resolve()
pkg_json_path = workspace_path / "package.json"

sev = v.severity.lower()
if sev in packages[key]["severity_counts"]:
packages[key]["severity_counts"][sev] += 1
if pkg_json_path.exists():
with safe_open_within_base(pkg_json_path, workspace_path) as f:
pkg_json = json.load(f)
deps = pkg_json.get("dependencies", {})

# Get actual installed versions from node_modules
for pkg_name in deps:
try:
# Validate package name doesn't contain path traversal
if ".." in pkg_name or pkg_name.startswith("/"):
log(f"Skipping suspicious package name: {pkg_name}")
continue

pkg_path = workspace_path / "node_modules" / pkg_name / "package.json"

if pkg_path.exists():
with safe_open_within_base(pkg_path, workspace_path) as pf:
pkg_info = json.load(pf)
version = pkg_info.get("version")
if version:
key = f"{pkg_name}@{version}"
packages[key] = {
"package": pkg_name,
"version": version,
"vulnerabilities": [],
"severity_counts": {"critical": 0, "high": 0, "medium": 0, "low": 0}
}
except ValueError as e:
log(f"Path validation error for {pkg_name}: {e}")
except Exception:
pass
except ValueError as e:
log(f"Path validation error: {e}")
except Exception:
pass

# Now add vulnerabilities to the appropriate packages
for v in result.vulnerabilities:
# Find the top-level package from dependency path
if len(v.dependency_path) >= 2:
top_level = v.dependency_path[1] # e.g., 'express@4.14.1'

if '@' in top_level:
parts = top_level.rsplit('@', 1)
if len(parts) == 2:
pkg, version = parts
key = f"{pkg}@{version}"

if key not in packages:
packages[key] = {
"package": pkg,
"version": version,
"vulnerabilities": [],
"severity_counts": {"critical": 0, "high": 0, "medium": 0, "low": 0}
}

packages[key]["vulnerabilities"].append({
"id": v.id,
"title": v.title,
"severity": v.severity,
"fixed_version": v.fixed_version,
"cve": v.cve,
"package_name": v.package_name,
"installed_version": v.installed_version,
"dependency_path": v.dependency_path
})

sev = v.severity.lower()
if sev in packages[key]["severity_counts"]:
packages[key]["severity_counts"][sev] += 1

# Cache each package
total_vulns = 0
Expand Down Expand Up @@ -365,4 +452,3 @@ def main() -> None:
import traceback
log(traceback.format_exc())
sys.exit(1)

Original file line number Diff line number Diff line change
Expand Up @@ -417,18 +417,19 @@ def get_sca_with_cache(
print_cache_status(True, f"{pkg_name}@{version}")

# Convert cached vulns back to DependencyVulnerability objects
# The cache now includes the full dependency tree
for v in cached.vulnerabilities:
all_vulns.append(DependencyVulnerability(
id=v.get("id", "unknown"),
title=v.get("title", "Unknown"),
severity=v.get("severity", "medium"),
package_name=pkg_name,
installed_version=version,
package_name=v.get("package_name", pkg_name),
installed_version=v.get("installed_version", version),
fixed_version=v.get("fixed_version"),
cve=v.get("cve"),
cvss_score=None,
is_direct=True,
dependency_path=[]
is_direct=v.get("package_name") == pkg_name,
dependency_path=v.get("dependency_path", [])
))
else:
cache_misses += 1
Expand All @@ -448,23 +449,49 @@ def get_sca_with_cache(
if v.id not in cached_ids:
all_vulns.append(v)

# Cache results per package
packages_cached: Dict[str, List[Dict]] = {}
# Cache results per TOP-LEVEL package with their full dependency trees
# Group vulnerabilities by the first package in the dependency path (the top-level package)
packages_cached: Dict[str, Dict] = {}

for v in fresh_result.vulnerabilities:
key = f"{v.package_name}@{v.installed_version}"
if key not in packages_cached:
packages_cached[key] = []
packages_cached[key].append({
"id": v.id,
"title": v.title,
"severity": v.severity,
"fixed_version": v.fixed_version,
"cve": v.cve
})
# Find the top-level package (first package after project name in dependency path)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking.

The pre-commit is getting kinda big. Maybe we should move this chunk into one of the existing files or make a new file for it

if len(v.dependency_path) >= 2:
# dependency_path looks like: ['snyk-test-repo@1.0.0', 'express@4.14.1', ...]
top_level = v.dependency_path[1] # e.g., 'express@4.14.1'

# Extract package name and version
if '@' in top_level:
parts = top_level.rsplit('@', 1)
if len(parts) == 2:
top_pkg_name, top_pkg_version = parts
key = f"{top_pkg_name}@{top_pkg_version}"

if key not in packages_cached:
packages_cached[key] = {
"package": top_pkg_name,
"version": top_pkg_version,
"vulnerabilities": []
}

packages_cached[key]["vulnerabilities"].append({
"id": v.id,
"title": v.title,
"severity": v.severity,
"fixed_version": v.fixed_version,
"cve": v.cve,
"package_name": v.package_name,
"installed_version": v.installed_version,
"dependency_path": v.dependency_path
})

# Cache each top-level package with its full tree
if DEBUG:
print(f"Caching {len(packages_cached)} top-level packages")
for key in packages_cached:
print(f" {key}: {len(packages_cached[key]['vulnerabilities'])} vulnerabilities")

for key, vulns in packages_cached.items():
pkg, ver = key.rsplit("@", 1)
cache.set_sca_result(pkg, ver, vulns)
for key, data in packages_cached.items():
cache.set_sca_result(data["package"], data["version"], data["vulnerabilities"])
else:
print_warning(f"SCA scan failed: {fresh_result.error_message}")

Expand Down