Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sieval/datasets/downloaders/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""SourceHandler Protocol + scheme registry. v0.1 ships ``hf:`` and ``url:``.
"""SourceHandler Protocol + scheme registry. v0.2 ships ``hf:``, ``url:``, ``local:``.

AI-Generated Code - Claude Haiku 4.5 (Anthropic)
"""
Expand Down Expand Up @@ -40,9 +40,10 @@ def _ensure_builtin_handlers() -> None:
if _builtin_registered:
return
from sieval.datasets.downloaders.hf import HFHandler
from sieval.datasets.downloaders.local import LocalHandler
from sieval.datasets.downloaders.url import URLHandler

for handler in (HFHandler(), URLHandler()):
for handler in (HFHandler(), URLHandler(), LocalHandler()):
if handler.scheme not in _HANDLERS:
register_handler(handler)
_builtin_registered = True
Expand Down
81 changes: 81 additions & 0 deletions sieval/datasets/downloaders/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""local scheme handler: stage a package-bundled file into ``dest_root/<name>/``.

For datasets whose corpus is generated once and committed inside the package
(under ``sieval/datasets/_data/``) rather than fetched from a remote. ``download``
copies the bundled file into the same ``{dest_root}/<dataset_name>/`` layout the
url/hf handlers use, so the runtime ``load(name_or_path)`` path is identical.

AI-Generated Code - Claude Opus 4.8 (1M context) (Anthropic)
"""

import shutil
from importlib.resources import files
from pathlib import Path
from posixpath import normpath

from sieval.core.datasets.meta import url_path_basename

# Bundled-data root inside the package; `local:<relpath>` resolves under here.
_DATA_ANCHOR = "sieval.datasets"
_DATA_SUBDIR = "_data"


class LocalHandler:
    """Source handler for the ``local:`` scheme.

    Stages a file bundled inside the package (under ``sieval/datasets/_data/``)
    into ``<dest_root>/<dataset_name>/<basename>``.
    """

    scheme = "local"

    def download(
        self,
        source: str,
        dest_root: Path,
        dataset_name: str,
        force: bool,
    ) -> None:
        """Copy the bundled file named by *source* into the dataset directory.

        Args:
            source: ``local:<relpath>``; *relpath* is resolved under the
                package data root.
            dest_root: Root directory that holds one sub-directory per dataset.
            dataset_name: Name of the sub-directory to stage the file into.
            force: When true, overwrite an already-staged file.

        Raises:
            ValueError: If *source* lacks the ``local:`` prefix or the relative
                path is empty, absolute, non-normalized, or contains ``..``.
            OSError: If the bundled file cannot be read or the copy fails.
        """
        relpath = self._strip_scheme(source)
        bundled = self._bundled_path(relpath)
        target_dir = dest_root / dataset_name
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / _basename(relpath)
        if target.exists() and not force:
            return
        # Stage into a sibling ``.partial`` file and rename into place, so an
        # interrupted or failed copy never leaves a truncated file at the
        # final path (which ``is_downloaded`` would report as complete).
        tmp = target.with_name(target.name + ".partial")
        try:
            shutil.copyfile(bundled, tmp)
            tmp.replace(target)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise

    def is_downloaded(
        self,
        source: str,
        dest_root: Path,
        dataset_name: str,
    ) -> bool:
        """Return whether the staged copy for *source* exists on disk.

        Raises:
            ValueError: If *source* does not start with ``local:``.
        """
        relpath = self._strip_scheme(source)
        return (dest_root / dataset_name / _basename(relpath)).exists()

    @staticmethod
    def _strip_scheme(source: str) -> str:
        """Return *source* without its ``local:`` prefix.

        Raises:
            ValueError: If *source* does not start with ``local:``.
        """
        if not source.startswith("local:"):
            raise ValueError(f"Expected local: scheme, got {source!r}")
        return source[len("local:") :]

    @staticmethod
    def _bundled_path(relpath: str) -> Path:
        """Resolve *relpath* under the package data root, rejecting traversal.

        ``local:`` must only ever read files committed inside the package, so an
        absolute path or a ``..`` segment that would escape ``_data/`` is a hard
        error rather than a silently-resolved path.

        Raises:
            ValueError: If *relpath* is empty, starts with ``/``, contains a
                ``..`` segment, or is not already in ``posixpath.normpath``
                form (e.g. contains ``.`` segments or doubled slashes).
        """
        if (
            not relpath
            or relpath.startswith("/")
            or ".." in relpath.split("/")
            or normpath(relpath) != relpath
        ):
            raise ValueError(
                "local: path must be a normalized, package-relative path, "
                f"got {relpath!r}"
            )
        return Path(str(files(_DATA_ANCHOR).joinpath(_DATA_SUBDIR, relpath)))


def _basename(relpath: str) -> str:
    """Return the filename under which the bundled file is staged.

    Delegates to ``url_path_basename``, the primitive shared with the url
    handler, so the on-disk name follows the ``url:`` convention; falls back
    to ``"download"`` when that helper returns a falsy value.
    """
    name = url_path_basename(relpath)
    if name:
        return name
    return "download"
11 changes: 7 additions & 4 deletions sieval/datasets/downloaders/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ def download(
with httpx.stream("GET", url, timeout=_TIMEOUT, follow_redirects=True) as r:
r.raise_for_status()
# Catches 2xx responses that dropped the connection mid-stream.
# Compare against on-the-wire bytes, not bytes written: when the
# server sends Content-Encoding (e.g. gzip), Content-Length is the
# compressed size while iter_bytes() yields the larger decompressed
# body, so comparing written bytes would falsely flag truncation.
expected = _parse_content_length(r.headers.get("content-length"))
written = 0
with tmp.open("wb") as f:
for chunk in r.iter_bytes(chunk_size=1 << 16):
f.write(chunk)
written += len(chunk)
if expected is not None and written != expected:
received = r.num_bytes_downloaded
if expected is not None and received != expected:
raise RuntimeError(
f"truncated download from {url}: "
f"Content-Length={expected} but wrote {written} bytes"
f"Content-Length={expected} but received {received} bytes"
)
tmp.replace(target)
except BaseException:
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/datasets/downloaders/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests for the local: source handler.

AI-Generated Code - Claude Opus 4.8 (1M context) (Anthropic)
"""

from unittest.mock import patch

import pytest

from sieval.datasets.downloaders.local import LocalHandler, _basename


def test_scheme():
    """The handler advertises the ``local`` scheme."""
    handler = LocalHandler()
    assert handler.scheme == "local"


def test_strip_scheme_rejects_wrong_scheme():
    """A source carrying another scheme's prefix raises ValueError."""
    foreign_source = "url:https://example.com/foo.json"
    with pytest.raises(ValueError, match="Expected local: scheme"):
        LocalHandler._strip_scheme(foreign_source)


def test_basename():
    """The last path segment is used; a trailing slash falls back to ``download``."""
    expected_names = {
        "pg/PaulGrahamEssays.json.gz": "PaulGrahamEssays.json.gz",
        "trailing/": "download",
    }
    for relpath, expected in expected_names.items():
        assert _basename(relpath) == expected


@pytest.mark.parametrize(
    "relpath",
    [
        "",
        "/abs/path.json",
        "../escape.json",
        "a/../../b.json",
        "a/./b.json",
    ],
)
def test_bundled_path_rejects_traversal(relpath):
    """Only normalized, package-relative paths are accepted by `local:`.

    Empty, absolute, `..`-bearing and non-normalized inputs must all raise
    rather than resolve to a path outside the bundled `_data/` root.
    """
    with pytest.raises(ValueError, match="package-relative"):
        LocalHandler._bundled_path(relpath)


def test_download_copies_to_basename(tmp_path):
    """The bundled file is copied to <dest>/<dataset_name>/<basename>."""
    bundled_file = tmp_path / "bundled.json"
    bundled_file.write_text("payload")
    handler = LocalHandler()
    # Point the handler at the temp file instead of real package data.
    with patch.object(LocalHandler, "_bundled_path", return_value=bundled_file):
        handler.download(
            "local:pg/bundled.json",
            dest_root=tmp_path,
            dataset_name="pg",
            force=False,
        )
    staged = tmp_path / "pg" / "bundled.json"
    assert staged.read_text() == "payload"


def test_download_skips_when_target_exists(tmp_path):
    """Without ``force``, an already-staged file is left untouched."""
    bundled_file = tmp_path / "bundled.json"
    bundled_file.write_text("fresh")
    dataset_dir = tmp_path / "pg"
    dataset_dir.mkdir()
    staged = dataset_dir / "bundled.json"
    staged.write_text("cached")
    handler = LocalHandler()
    with patch.object(LocalHandler, "_bundled_path", return_value=bundled_file):
        handler.download(
            "local:pg/bundled.json",
            dest_root=tmp_path,
            dataset_name="pg",
            force=False,
        )
    assert staged.read_text() == "cached"


def test_download_force_recopies(tmp_path):
    """With ``force=True``, the staged file is overwritten by the bundled one."""
    bundled_file = tmp_path / "bundled.json"
    bundled_file.write_text("fresh")
    dataset_dir = tmp_path / "pg"
    dataset_dir.mkdir()
    staged = dataset_dir / "bundled.json"
    staged.write_text("cached")
    handler = LocalHandler()
    with patch.object(LocalHandler, "_bundled_path", return_value=bundled_file):
        handler.download(
            "local:pg/bundled.json",
            dest_root=tmp_path,
            dataset_name="pg",
            force=True,
        )
    assert staged.read_text() == "fresh"


def test_is_downloaded(tmp_path):
    """``is_downloaded`` flips to true once the staged file exists."""
    handler = LocalHandler()
    source = "local:pg/bundled.json"
    assert not handler.is_downloaded(source, tmp_path, "pg")
    dataset_dir = tmp_path / "pg"
    dataset_dir.mkdir()
    (dataset_dir / "bundled.json").write_text("x")
    assert handler.is_downloaded(source, tmp_path, "pg")
31 changes: 31 additions & 0 deletions tests/unit/datasets/downloaders/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_download_uses_per_phase_timeout(tmp_path):
mock_resp = MagicMock()
mock_resp.iter_bytes.return_value = [b"x"]
mock_resp.raise_for_status.return_value = None
mock_resp.num_bytes_downloaded = 1
mock_stream.return_value.__enter__.return_value = mock_resp
h.download("url:https://example.com/foo.csv", tmp_path, "foo", force=False)
timeout_arg = mock_stream.call_args.kwargs.get("timeout")
Expand All @@ -44,6 +45,7 @@ def test_download_writes_file_at_basename(tmp_path):
mock_resp.iter_bytes.return_value = [b"hello"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "5"}
mock_resp.num_bytes_downloaded = 5
mock_stream.return_value.__enter__.return_value = mock_resp
h.download(
"url:https://example.com/foo.csv",
Expand Down Expand Up @@ -81,6 +83,7 @@ def test_download_force_redownloads(tmp_path):
mock_resp.iter_bytes.return_value = [b"fresh"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "5"}
mock_resp.num_bytes_downloaded = 5
mock_stream.return_value.__enter__.return_value = mock_resp
h.download(
"url:https://example.com/foo.csv",
Expand All @@ -103,6 +106,7 @@ def test_download_rejects_truncated_stream(tmp_path):
mock_resp.iter_bytes.return_value = [b"abc"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "100"}
mock_resp.num_bytes_downloaded = 3 # connection died after 3 raw bytes
mock_stream.return_value.__enter__.return_value = mock_resp
with pytest.raises(RuntimeError, match="truncated download"):
h.download(
Expand All @@ -116,6 +120,33 @@ def test_download_rejects_truncated_stream(tmp_path):
assert not (tmp_path / "foo" / "foo.csv.partial").exists()


def test_download_accepts_compressed_response(tmp_path):
    """A Content-Encoding'd response must not trip the truncation guard.

    With e.g. gzip, Content-Length is the compressed size while iter_bytes()
    yields the larger decompressed body, so the check has to compare the
    on-the-wire byte count (num_bytes_downloaded) against Content-Length
    rather than the number of decompressed bytes written.

    Regression: SQuAD's train-v2.0.json is gzip-served; Content-Length=9551051
    but the decoded body is ~42MB, which the old written-bytes check rejected.
    """
    # Decompressed body is far larger than the compressed Content-Length.
    decoded_chunks = [b"x" * 42, b"y" * 42]
    handler = URLHandler()
    with patch("sieval.datasets.downloaders.url.httpx.stream") as mock_stream:
        response = MagicMock()
        response.iter_bytes.return_value = decoded_chunks
        response.raise_for_status.return_value = None
        response.headers = {"content-length": "9", "content-encoding": "gzip"}
        response.num_bytes_downloaded = 9  # raw compressed bytes on the wire
        mock_stream.return_value.__enter__.return_value = response
        handler.download(
            "url:https://example.com/foo.json",
            dest_root=tmp_path,
            dataset_name="foo",
            force=False,
        )
    written = (tmp_path / "foo" / "foo.json").read_bytes()
    assert written == b"x" * 42 + b"y" * 42


def test_download_accepts_missing_content_length(tmp_path):
"""Chunked transfer-encoded responses often omit Content-Length. The
truncation check must degrade gracefully rather than reject every
Expand Down
Loading