Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions openff/utilities/_tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import threading

import pytest

Expand Down Expand Up @@ -90,9 +91,12 @@ def test_temporary_cd():

assert compare_paths(os.getcwd(), original_directory)

# Move to a temporary directory
with temporary_cd():
assert not compare_paths(os.getcwd(), original_directory)
# Move to a temporary directory: CWD is NOT changed in the no-argument form.
# The caller must use the yielded path for file operations.
with temporary_cd() as tmpdir:
assert compare_paths(os.getcwd(), original_directory)
assert os.path.isdir(tmpdir)
assert not compare_paths(tmpdir, original_directory)

assert compare_paths(os.getcwd(), original_directory)

Expand All @@ -108,6 +112,53 @@ def test_temporary_cd():
assert compare_paths(os.getcwd(), original_directory)


def test_temporary_cd_yields_path():
"""Tests that temporary_cd yields the absolute path of the directory."""

original_directory = os.getcwd()

# Auto temp dir: yielded path should be an absolute path different from original
with temporary_cd() as tmpdir:
assert os.path.isabs(tmpdir)
assert not compare_paths(tmpdir, original_directory)
# CWD is NOT changed in the no-argument form (thread-safe behaviour)
assert compare_paths(os.getcwd(), original_directory)

# Specific directory: yielded path should match the absolute path of the given dir
with temporary_cd(os.pardir) as tmpdir:
expected = os.path.abspath(os.path.join(original_directory, os.pardir))
assert compare_paths(tmpdir, expected)
assert compare_paths(os.getcwd(), expected)

# Empty string: yielded path should be the original directory
with temporary_cd("") as tmpdir:
assert compare_paths(tmpdir, original_directory)


def test_temporary_cd_thread_safety():
"""Tests that temporary_cd is thread-safe when using the yielded absolute path."""

errors: list[str] = []

def func(n: int) -> None:
with temporary_cd() as tmpdir:
file_path = os.path.join(tmpdir, "f.txt")
with open(file_path, "w") as f:
f.write(str(n))
with open(file_path) as f:
read_value = int(f.read())
if read_value != n:
errors.append(f"Expected {n} but found {read_value}")

threads = [threading.Thread(target=func, args=(i,)) for i in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()

assert not errors, f"Thread safety errors: {errors}"


def test_has_package():
assert has_package("os")
assert has_package("pytest")
Expand Down
48 changes: 29 additions & 19 deletions openff/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,39 +150,49 @@ def _is_executable(fpath: str) -> bool:


@contextmanager
def temporary_cd(directory_path: str | None = None) -> Generator[None, None, None]:
def temporary_cd(directory_path: str | None = None) -> Generator[str, None, None]:
"""Temporarily move the current working directory to the path
specified. If no path is given, a temporary directory will be
created, moved into, and then destroyed when the context manager
is closed.
created and its path yielded; the directory is destroyed when the
context manager is closed.

When ``directory_path`` is ``None`` (the default), ``os.chdir`` is
**not** called, making this form safe to use from multiple threads
concurrently. Use the yielded absolute path for all file
operations inside the block (e.g.
``open(os.path.join(tmpdir, "file"))``).

When an explicit ``directory_path`` is given, ``os.chdir`` is still
called for backward compatibility. Note that ``os.chdir`` is a
process-wide operation and is therefore **not** thread-safe.

Parameters
----------
directory_path: str, optional

Returns
-------
Yields
------
str
The absolute path of the directory.

"""

if directory_path is not None and len(directory_path) == 0:
yield
yield os.getcwd()
return

old_directory = os.getcwd()
if directory_path is None:
with TemporaryDirectory() as new_directory:
yield new_directory

try:
if directory_path is None:
with TemporaryDirectory() as new_directory:
os.chdir(new_directory)
yield

else:
os.chdir(directory_path)
yield

finally:
os.chdir(old_directory)
else:
abs_directory = os.path.abspath(directory_path)
old_directory = os.getcwd()
try:
os.chdir(abs_directory)
yield abs_directory
finally:
os.chdir(old_directory)


def get_data_dir_path(relative_path: str, package_name: str) -> str:
Expand Down