Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions sieval/core/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _validate(
f"{_VALID_SCHEMES} (e.g. 'hf:org/name')"
)

_validate_url_basenames_unique(name, source)
_validate_staged_basenames_unique(name, source)


def url_path_basename(url: str) -> str:
Expand All @@ -187,20 +187,27 @@ def url_path_basename(url: str) -> str:
return urlparse(url).path.rsplit("/", 1)[-1]


def _validate_url_basenames_unique(name: str, source: tuple[str, ...]) -> None:
"""Reject duplicate basenames among url: sources in one dataset — two URLs
sharing a basename would overwrite each other at ``<dest>/<name>/<basename>``."""
_STAGED_SCHEMES = ("url:", "local:")


def _validate_staged_basenames_unique(name: str, source: tuple[str, ...]) -> None:
"""Reject duplicate basenames among sources that stage to a flat file in one
dataset — both url: and local: land at ``<dest>/<name>/<basename>``, so two
sources (url/url, local/local, or url/local) sharing a basename would
silently overwrite each other."""
basenames = [
url_path_basename(src[len("url:") :])
url_path_basename(src[len(scheme) :])
for src in source
if src.startswith("url:")
for scheme in _STAGED_SCHEMES
if src.startswith(scheme)
]
counter = Counter(basenames)
duplicates = {b for b, count in counter.items() if count > 1}
if duplicates:
raise ValueError(
f"url: sources in dataset {name!r} have colliding basenames: "
f"{sorted(duplicates)}; each URL must produce a unique on-disk filename"
f"url:/local: sources in dataset {name!r} have colliding basenames: "
f"{sorted(duplicates)}; each staged source must produce a unique "
f"on-disk filename"
)


Expand Down
5 changes: 3 additions & 2 deletions sieval/datasets/downloaders/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""SourceHandler Protocol + scheme registry. v0.1 ships ``hf:`` and ``url:``.
"""SourceHandler Protocol + scheme registry. v0.2 ships ``hf:``, ``url:``, ``local:``.

AI-Generated Code - Claude Haiku 4.5 (Anthropic)
"""
Expand Down Expand Up @@ -40,9 +40,10 @@ def _ensure_builtin_handlers() -> None:
if _builtin_registered:
return
from sieval.datasets.downloaders.hf import HFHandler
from sieval.datasets.downloaders.local import LocalHandler
from sieval.datasets.downloaders.url import URLHandler

for handler in (HFHandler(), URLHandler()):
for handler in (HFHandler(), URLHandler(), LocalHandler()):
if handler.scheme not in _HANDLERS:
register_handler(handler)
_builtin_registered = True
Expand Down
87 changes: 87 additions & 0 deletions sieval/datasets/downloaders/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""local scheme handler: stage a package-bundled file into ``dest_root/<name>/``.

For datasets whose corpus is generated once and committed inside the package
(under ``sieval/datasets/_data/``) rather than fetched from a remote. ``download``
copies the bundled file into the same ``{dest_root}/<dataset_name>/`` layout the
url/hf handlers use, so the runtime ``load(name_or_path)`` path is identical.

AI-Generated Code - Claude Opus 4.8 (1M context) (Anthropic)
"""

import shutil
from importlib.resources import files
from pathlib import Path
from posixpath import normpath

from sieval.core.datasets.meta import url_path_basename

# Bundled-data root inside the package; `local:<relpath>` resolves under here.
_DATA_ANCHOR = "sieval.datasets"
_DATA_SUBDIR = "_data"


class LocalHandler:
scheme = "local"

def download(
self,
source: str,
dest_root: Path,
dataset_name: str,
force: bool,
) -> None:
relpath = self._strip_scheme(source)
bundled = self._bundled_path(relpath)
target_dir = dest_root / dataset_name
target_dir.mkdir(parents=True, exist_ok=True)
target = target_dir / _basename(relpath)
if target.exists() and not force:
return
tmp = target.with_name(target.name + ".partial")
try:
shutil.copyfile(bundled, tmp)
tmp.replace(target)
except BaseException:
tmp.unlink(missing_ok=True)
raise

def is_downloaded(
self,
source: str,
dest_root: Path,
dataset_name: str,
) -> bool:
relpath = self._strip_scheme(source)
return (dest_root / dataset_name / _basename(relpath)).exists()

@staticmethod
def _strip_scheme(source: str) -> str:
if not source.startswith("local:"):
raise ValueError(f"Expected local: scheme, got {source!r}")
return source[len("local:") :]

@staticmethod
def _bundled_path(relpath: str) -> Path:
"""Resolve *relpath* under the package data root, rejecting traversal.

``local:`` must only ever read files committed inside the package, so an
absolute path or a ``..`` segment that would escape ``_data/`` is a hard
error rather than a silently-resolved path.
"""
if (
not relpath
or relpath.startswith("/")
or ".." in relpath.split("/")
or normpath(relpath) != relpath
):
raise ValueError(
f"local: path must be a normalized, package-relative path, "
f"got {relpath!r}"
)
return Path(str(files(_DATA_ANCHOR).joinpath(_DATA_SUBDIR, relpath)))


def _basename(relpath: str) -> str:
"""Filename the bundled file lands under; shares the url-handler primitive
so the on-disk name matches the ``url:`` convention."""
return url_path_basename(relpath) or "download"
11 changes: 7 additions & 4 deletions sieval/datasets/downloaders/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ def download(
with httpx.stream("GET", url, timeout=_TIMEOUT, follow_redirects=True) as r:
r.raise_for_status()
# Catches 2xx responses that dropped the connection mid-stream.
# Compare against on-the-wire bytes, not bytes written: when the
# server sends Content-Encoding (e.g. gzip), Content-Length is the
# compressed size while iter_bytes() yields the larger decompressed
# body, so comparing written bytes would falsely flag truncation.
expected = _parse_content_length(r.headers.get("content-length"))
written = 0
with tmp.open("wb") as f:
for chunk in r.iter_bytes(chunk_size=1 << 16):
f.write(chunk)
written += len(chunk)
if expected is not None and written != expected:
received = r.num_bytes_downloaded
if expected is not None and received != expected:
raise RuntimeError(
f"truncated download from {url}: "
f"Content-Length={expected} but wrote {written} bytes"
f"Content-Length={expected} but received {received} bytes"
)
tmp.replace(target)
except BaseException:
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/core/datasets/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,54 @@ def load(self, name_or_path, **kwargs):
raise NotImplementedError


def test_sieval_dataset_rejects_colliding_local_basenames():
"""Two local: sources in the same dataset with the same basename stage to
the same <dest>/<name>/<basename> and overwrite each other. Reject."""

class LocalClashSample(TypedDict):
x: str

with pytest.raises(ValueError, match="colliding basenames"):

@sieval_dataset(
name="local_clash_test",
display_name="Local Clash Test",
description="x",
source=(
"local:a/data.csv",
"local:b/data.csv",
),
categories=(Category(Level1Category.LOGIC, "BasicLogic"),),
)
class LocalClashDataset(Dataset[LocalClashSample]):
def load(self, name_or_path, **kwargs):
raise NotImplementedError


def test_sieval_dataset_rejects_url_local_basename_collision():
"""A url: and a local: source staging to the same basename overwrite each
other under <dest>/<name>/; the guard spans both staged schemes."""

class CrossSample(TypedDict):
x: str

with pytest.raises(ValueError, match="colliding basenames"):

@sieval_dataset(
name="cross_clash_test",
display_name="Cross Clash Test",
description="x",
source=(
"url:https://a.example.com/data.csv",
"local:cross_clash/data.csv",
),
categories=(Category(Level1Category.LOGIC, "BasicLogic"),),
)
class CrossClashDataset(Dataset[CrossSample]):
def load(self, name_or_path, **kwargs):
raise NotImplementedError


def test_sieval_dataset_accepts_different_url_basenames():
"""Different basenames should not trigger the validation."""

Expand Down
95 changes: 95 additions & 0 deletions tests/unit/datasets/downloaders/test_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests for the local: source handler.

AI-Generated Code - Claude Opus 4.8 (1M context) (Anthropic)
"""

from unittest.mock import patch

import pytest

from sieval.datasets.downloaders.local import LocalHandler, _basename


def test_scheme():
assert LocalHandler().scheme == "local"


def test_strip_scheme_rejects_wrong_scheme():
with pytest.raises(ValueError, match="Expected local: scheme"):
LocalHandler._strip_scheme("url:https://example.com/foo.json")


def test_basename():
assert _basename("pg/PaulGrahamEssays.json.gz") == "PaulGrahamEssays.json.gz"
assert _basename("trailing/") == "download"


@pytest.mark.parametrize(
"bad",
["", "/abs/path.json", "../escape.json", "a/../../b.json", "a/./b.json"],
)
def test_bundled_path_rejects_traversal(bad):
"""`local:` may only read normalized, package-relative paths — an absolute
path or a `..` segment that escapes the bundled `_data/` root is a hard
error, never a silently-resolved path."""
with pytest.raises(ValueError, match="package-relative"):
LocalHandler._bundled_path(bad)


def test_download_copies_to_basename(tmp_path):
"""Layout: <dest>/<dataset_name>/<basename>, copied from the bundled file."""
src = tmp_path / "bundled.json"
src.write_text("payload")
h = LocalHandler()
with patch.object(LocalHandler, "_bundled_path", return_value=src):
h.download(
"local:pg/bundled.json",
dest_root=tmp_path,
dataset_name="pg",
force=False,
)
target = tmp_path / "pg" / "bundled.json"
assert target.read_text() == "payload"


def test_download_skips_when_target_exists(tmp_path):
src = tmp_path / "bundled.json"
src.write_text("fresh")
target_dir = tmp_path / "pg"
target_dir.mkdir()
(target_dir / "bundled.json").write_text("cached")
h = LocalHandler()
with patch.object(LocalHandler, "_bundled_path", return_value=src):
h.download(
"local:pg/bundled.json",
dest_root=tmp_path,
dataset_name="pg",
force=False,
)
assert (target_dir / "bundled.json").read_text() == "cached"


def test_download_force_recopies(tmp_path):
src = tmp_path / "bundled.json"
src.write_text("fresh")
target_dir = tmp_path / "pg"
target_dir.mkdir()
(target_dir / "bundled.json").write_text("cached")
h = LocalHandler()
with patch.object(LocalHandler, "_bundled_path", return_value=src):
h.download(
"local:pg/bundled.json",
dest_root=tmp_path,
dataset_name="pg",
force=True,
)
assert (target_dir / "bundled.json").read_text() == "fresh"


def test_is_downloaded(tmp_path):
h = LocalHandler()
assert not h.is_downloaded("local:pg/bundled.json", tmp_path, "pg")
target_dir = tmp_path / "pg"
target_dir.mkdir()
(target_dir / "bundled.json").write_text("x")
assert h.is_downloaded("local:pg/bundled.json", tmp_path, "pg")
31 changes: 31 additions & 0 deletions tests/unit/datasets/downloaders/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_download_uses_per_phase_timeout(tmp_path):
mock_resp = MagicMock()
mock_resp.iter_bytes.return_value = [b"x"]
mock_resp.raise_for_status.return_value = None
mock_resp.num_bytes_downloaded = 1
mock_stream.return_value.__enter__.return_value = mock_resp
h.download("url:https://example.com/foo.csv", tmp_path, "foo", force=False)
timeout_arg = mock_stream.call_args.kwargs.get("timeout")
Expand All @@ -44,6 +45,7 @@ def test_download_writes_file_at_basename(tmp_path):
mock_resp.iter_bytes.return_value = [b"hello"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "5"}
mock_resp.num_bytes_downloaded = 5
mock_stream.return_value.__enter__.return_value = mock_resp
h.download(
"url:https://example.com/foo.csv",
Expand Down Expand Up @@ -81,6 +83,7 @@ def test_download_force_redownloads(tmp_path):
mock_resp.iter_bytes.return_value = [b"fresh"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "5"}
mock_resp.num_bytes_downloaded = 5
mock_stream.return_value.__enter__.return_value = mock_resp
h.download(
"url:https://example.com/foo.csv",
Expand All @@ -103,6 +106,7 @@ def test_download_rejects_truncated_stream(tmp_path):
mock_resp.iter_bytes.return_value = [b"abc"]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "100"}
mock_resp.num_bytes_downloaded = 3 # connection died after 3 raw bytes
mock_stream.return_value.__enter__.return_value = mock_resp
with pytest.raises(RuntimeError, match="truncated download"):
h.download(
Expand All @@ -116,6 +120,33 @@ def test_download_rejects_truncated_stream(tmp_path):
assert not (tmp_path / "foo" / "foo.csv.partial").exists()


def test_download_accepts_compressed_response(tmp_path):
"""Server sends Content-Encoding (e.g. gzip): Content-Length is the
compressed size while iter_bytes() yields the larger decompressed body.
The check must compare on-the-wire bytes (num_bytes_downloaded) against
Content-Length, not the decompressed bytes written — otherwise every
compressed download falsely trips the truncation guard.

Regression: SQuAD's train-v2.0.json is gzip-served; Content-Length=9551051
but the decoded body is ~42MB, which the old written-bytes check rejected."""
h = URLHandler()
with patch("sieval.datasets.downloaders.url.httpx.stream") as mock_stream:
mock_resp = MagicMock()
# Decompressed body is far larger than the compressed Content-Length.
mock_resp.iter_bytes.return_value = [b"x" * 42, b"y" * 42]
mock_resp.raise_for_status.return_value = None
mock_resp.headers = {"content-length": "9", "content-encoding": "gzip"}
mock_resp.num_bytes_downloaded = 9 # raw compressed bytes on the wire
mock_stream.return_value.__enter__.return_value = mock_resp
h.download(
"url:https://example.com/foo.json",
dest_root=tmp_path,
dataset_name="foo",
force=False,
)
assert (tmp_path / "foo" / "foo.json").read_bytes() == b"x" * 42 + b"y" * 42


def test_download_accepts_missing_content_length(tmp_path):
"""Chunked transfer-encoded responses often omit Content-Length. The
truncation check must degrade gracefully rather than reject every
Expand Down
Loading