Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ requires-python = ">=3.10"
dependencies = [
"tqdm",
"typing-extensions",
"backports.zstd; python_version < '3.14'",
]

[dependency-groups]
Expand Down
2 changes: 2 additions & 0 deletions src/pystow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ensure_open_sqlite_gz,
ensure_open_tarfile,
ensure_open_zip,
ensure_open_zstd,
ensure_pickle,
ensure_pickle_gz,
ensure_rdf,
Expand Down Expand Up @@ -80,6 +81,7 @@
"ensure_open_sqlite_gz",
"ensure_open_tarfile",
"ensure_open_zip",
"ensure_open_zstd",
"ensure_pickle",
"ensure_pickle_gz",
"ensure_rdf",
Expand Down
50 changes: 50 additions & 0 deletions src/pystow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io
import lzma
import sqlite3
import sys
import typing
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
Expand All @@ -18,6 +19,11 @@
from .impl import Module, VersionHint
from .utils.download import DownloadKwargs

if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd

if TYPE_CHECKING:
import bs4
import lxml.etree
Expand Down Expand Up @@ -50,6 +56,7 @@
"ensure_open_sqlite_gz",
"ensure_open_tarfile",
"ensure_open_zip",
"ensure_open_zstd",
"ensure_pickle",
"ensure_pickle_gz",
"ensure_rdf",
Expand Down Expand Up @@ -693,6 +700,49 @@ def ensure_open_lzma(
yield yv


@contextmanager
def ensure_open_zstd(
    key: str,
    *subkeys: str,
    url: str,
    name: str | None = None,
    force: bool = False,
    download_kwargs: DownloadKwargs | None = None,
    mode: Literal["r", "rb", "w", "wb", "rt", "wt"] = "rt",
    open_kwargs: Mapping[str, Any] | None = None,
) -> Generator[zstd.ZstdFile, None, None]:
    """Ensure a zstd-compressed file is downloaded, then open it.

    This is the functional counterpart of :meth:`Module.ensure_open_zstd`: it
    looks up the module for ``key`` and forwards every other argument to it
    unchanged.

    :param key: The name of the module. No funny characters. The envvar `<key>_HOME`
        where key is uppercased is checked first before using the default home
        directory.
    :param subkeys: A sequence of additional strings to join. If none are given, returns
        the directory for this module.
    :param url: The URL to download.
    :param name: Overrides the name of the file at the end of the URL, if given. Also
        useful for URLs that don't have proper filenames with extensions.
    :param force: Should the download be done again, even if the path already exists?
        Defaults to false.
    :param download_kwargs: Keyword arguments to pass through to
        :func:`pystow.utils.download`.
    :param mode: The mode passed to :func:`zstd.open`. Defaults to ``"rt"``
        (read as text).
    :param open_kwargs: Additional keyword arguments passed to :func:`zstd.open`

    :yields: An open file object
    """
    module = Module.from_key(key, ensure_exists=True)
    # Build the module-level context manager first, then enter it, so the
    # opened handle is closed when the caller's ``with`` block exits.
    opener = module.ensure_open_zstd(
        *subkeys,
        url=url,
        name=name,
        force=force,
        download_kwargs=download_kwargs,
        mode=mode,
        open_kwargs=open_kwargs,
    )
    with opener as handle:
        yield handle


# docstr-coverage:excused `overload`
@overload
@contextmanager
Expand Down
41 changes: 41 additions & 0 deletions src/pystow/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import pickle
import sqlite3
import sys
import tarfile
import typing
from collections.abc import Callable, Generator, Mapping, Sequence
Expand Down Expand Up @@ -42,6 +43,11 @@
)
from .utils.download import DownloadKwargs

if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd

if TYPE_CHECKING:
import botocore.client
import bs4
Expand Down Expand Up @@ -734,6 +740,41 @@ def ensure_open_lzma(
with lzma.open(path, **open_kwargs) as file:
yield file

@contextmanager
def ensure_open_zstd(
    self,
    *subkeys: str,
    url: str,
    name: str | None = None,
    force: bool = False,
    download_kwargs: DownloadKwargs | None = None,
    mode: Literal["r", "rb", "w", "wb", "rt", "wt"] = "rt",
    open_kwargs: Mapping[str, Any] | None = None,
) -> Generator[zstd.ZstdFile, None, None]:
    """Ensure a zstd-compressed file is downloaded, then open it.

    :param subkeys: A sequence of additional strings to join. If none are given,
        returns the directory for this module.
    :param url: The URL to download.
    :param name: Overrides the name of the file at the end of the URL, if given.
        Also useful for URLs that don't have proper filenames with extensions.
    :param force: Should the download be done again, even if the path already
        exists? Defaults to false.
    :param download_kwargs: Keyword arguments to pass through to
        :func:`pystow.utils.download`.
    :param mode: The mode passed to :func:`zstd.open`. Defaults to ``"rt"``
        (read as text). A ``mode`` entry in ``open_kwargs`` takes precedence
        over this argument.
    :param open_kwargs: Additional keyword arguments passed to :func:`zstd.open`

    :yields: An open file object
    """
    local_path = self.ensure(
        *subkeys,
        url=url,
        name=name,
        force=force,
        download_kwargs=download_kwargs,
    )
    # Start from the ``mode`` argument and let caller-supplied keyword
    # arguments override it; the caller's mapping itself is never mutated.
    merged_kwargs: dict[str, Any] = {"mode": mode}
    if open_kwargs is not None:
        merged_kwargs.update(open_kwargs)
    with zstd.open(local_path, **merged_kwargs) as handle:
        yield handle

# docstr-coverage:excused `overload`
@overload
@contextmanager
Expand Down
22 changes: 21 additions & 1 deletion src/pystow/utils/safe_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from __future__ import annotations

import bz2
import contextlib
import csv
import gzip
import io
import json
import lzma
import sys
import typing
import urllib.request
import zipfile
Expand All @@ -27,6 +30,11 @@
ensure_sensible_newline,
)

if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd

__all__ = [
"is_url",
"open_inner_zipfile",
Expand Down Expand Up @@ -125,6 +133,15 @@ def safe_open( # noqa:C901
if path.suffix.endswith(".gz"):
with gzip.open(path, mode=mode, encoding=encoding, newline=newline) as file:
yield file # type:ignore
elif path.suffix.endswith(".bz2"):
with bz2.open(path, mode=mode, encoding=encoding, newline=newline) as file:
yield file # type:ignore
elif path.suffix.endswith(".xz"):
with lzma.open(path, mode=mode, encoding=encoding, newline=newline) as file:
yield file # type:ignore
elif path.suffix.endswith(".zst"):
with zstd.open(path, mode=mode, encoding=encoding, newline=newline) as file:
yield file # type:ignore
else:
with open(path, mode=mode, encoding=encoding, newline=newline) as file:
yield file # type:ignore
Expand All @@ -135,7 +152,10 @@ def safe_open( # noqa:C901
"must specify `text` representation when passing through a text file-like object"
)
yield path
elif isinstance(path, typing.BinaryIO | io.BufferedReader | gzip.GzipFile):

# io.BufferedIOBase covers the LZMA, BZ2, Gzip, and ZSTD file types
# as well as io.BufferedReader
elif isinstance(path, typing.BinaryIO | io.BufferedIOBase):
if representation != "binary":
raise ValueError(
"must specify `binary` representation when passing through "
Expand Down
Binary file added tests/resources/test.txt.bz2
Binary file not shown.
Binary file added tests/resources/test.txt.xz
Binary file not shown.
Binary file added tests/resources/test.txt.zst
Binary file not shown.
11 changes: 7 additions & 4 deletions tests/test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,14 @@
TEST_TXT_CONTENT = "this is a test file\n"
TEST_TXT_MD5 = RESOURCES.joinpath("test.txt.md5")
TEST_TXT_GZ = RESOURCES.joinpath("test.txt.gz")
TEST_TXT_BZ2 = RESOURCES.joinpath("test.txt.bz2")
TEST_TXT_LZMA = RESOURCES.joinpath("test.txt.xz")
TEST_TXT_ZSTD = RESOURCES.joinpath("test.txt.zst")
TEST_TXT_VERBOSE_MD5 = RESOURCES.joinpath("test_verbose.txt.md5")
TEST_TXT_WRONG_MD5 = RESOURCES.joinpath("test_wrong.txt.md5")

TEST_PATHS = [TEST_TXT, TEST_TXT_GZ, TEST_TXT_BZ2, TEST_TXT_LZMA, TEST_TXT_ZSTD]


class _Session(requests.sessions.Session):
"""A mock session."""
Expand Down Expand Up @@ -401,7 +406,7 @@ def test_safe_open_exceptions(self) -> None:

def test_safe_open_binary(self) -> None:
"""Test safe open in binary mode."""
for path in [TEST_TXT, TEST_TXT_GZ]:
for path in TEST_PATHS:
with self.subTest(path=path):
with safe_open(path, representation="binary") as file:
self.assertEqual(
Expand All @@ -427,9 +432,7 @@ def test_safe_open_url_binary(self) -> None:

def test_safe_open_text(self) -> None:
"""Test safe open in text mode."""
for path, encoding, newline in itt.product(
[TEST_TXT, TEST_TXT_GZ], [None, "utf-8"], [None, "\n"]
):
for path, encoding, newline in itt.product(TEST_PATHS, [None, "utf-8"], [None, "\n"]):
with self.subTest(path=path, encoding=encoding, newline=newline):
with safe_open(
path, encoding=encoding, representation="text", newline=newline
Expand Down
Loading