Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
9b25914
Merge branch 'main' of github.com:mosaicml/streaming
knighton Dec 25, 2023
7867b10
Move epoch_size arg.
knighton Dec 26, 2023
afe835a
Move allow_unsafe_types arg.
knighton Dec 26, 2023
010c613
Fix usage.
knighton Dec 26, 2023
2f4875b
Propagate allow_unsafe_types as a normal Stream argument.
knighton Dec 26, 2023
333605e
Explicitly list the kwargs for Stream.apply_defaults().
knighton Dec 26, 2023
2d8a905
Tweak docstrings.
knighton Dec 26, 2023
d298479
Complete rewrite of local dir collision detection using regular files.
knighton Dec 26, 2023
eff80e6
Add psutil.
knighton Dec 26, 2023
8fb9dca
Fix.
knighton Dec 26, 2023
223f3ca
Fix.
knighton Dec 26, 2023
664cbc1
Fix.
knighton Dec 26, 2023
02416d7
Fix.
knighton Dec 26, 2023
b0b3b56
Fix.
knighton Dec 26, 2023
23505c2
Fix.
knighton Dec 26, 2023
67b936f
Fix.
knighton Dec 26, 2023
fa14130
Fix.
knighton Dec 26, 2023
c54887b
Fix.
knighton Dec 26, 2023
c214589
Fix.
knighton Dec 26, 2023
c0c82bd
Remove dist from StreamingDataset init.
knighton Dec 26, 2023
105ee16
Sleep first out of race paranoia.
knighton Dec 26, 2023
afafcd8
Fix.
knighton Dec 26, 2023
1012b21
Organize world, job/, and shmem/ into streaming/base/coord/.
knighton Dec 30, 2023
0eb3327
Fix.
knighton Dec 30, 2023
ea7c7ad
Fix.
knighton Dec 30, 2023
5267bbb
Fix.
knighton Dec 30, 2023
e8ad400
Merge branch 'main' into james/nodist2
knighton Dec 30, 2023
6581886
Merge branch 'james/nodist2' of github.com:mosaicml/streaming into ja…
knighton Dec 30, 2023
74646a2
Keep around the lazily initialized FileLock.
knighton Dec 30, 2023
6b3783f
What if it's the filelock?
knighton Dec 30, 2023
94d4136
Handle case where a process dies while holding a soft file lock.
knighton Dec 30, 2023
2d481af
Merge branch 'james/nodist2' of github.com:mosaicml/streaming into ja…
knighton Dec 30, 2023
f38e99f
Fix.
knighton Dec 30, 2023
9743cd6
Rewrite the homebrew soft file lock.
knighton Dec 30, 2023
b327858
Docstrings.
knighton Dec 30, 2023
0e47545
Switch all filelock.FileLock to streaming.base.coord.file.SoftFileLock.
knighton Dec 30, 2023
0658ab9
Fix.
knighton Dec 30, 2023
1d90d42
Rewrite StreamingDataset/SimulationDataset args handling to be rigorous.
knighton Dec 30, 2023
da9693f
MMap-based cross-process Array, Barrier, Buffer, Number.
knighton Dec 30, 2023
3cd5fc7
First attempt at replacing all SD shmem -> mmap.
knighton Dec 30, 2023
8b99425
Complete rewrite of all the mmap stuff
knighton Jan 2, 2024
a302c46
Merge branch 'main' into james/nodist2
knighton Jan 2, 2024
f60d2b8
Tweak.
knighton Jan 2, 2024
d88e7a2
Merge branch 'james/nodist2' of github.com:mosaicml/streaming into ja…
knighton Jan 2, 2024
85d4810
Fix.
knighton Jan 2, 2024
7d69d66
Fix.
knighton Jan 2, 2024
43e5a72
Fix.
knighton Jan 2, 2024
4fa5114
Fix.
knighton Jan 2, 2024
372e5eb
Fix.
knighton Jan 2, 2024
fd37096
Fix.
knighton Jan 2, 2024
f85bb6b
Add lock usage to wait_for_existence, wait_for_removal.
knighton Jan 3, 2024
e660c96
Generalize waiting, etc.
knighton Jan 3, 2024
61a5d42
Fix.
knighton Jan 3, 2024
507744e
Stop doing pytest in ten parts, as the file handle issue is now fixed…
knighton Jan 3, 2024
0e6d717
Fix (docstring).
knighton Jan 3, 2024
986bef8
Split file format functionality out of base.py into file.py
knighton Jan 3, 2024
82993ec
SD init barrier: cruft flag file -> proper MemMapBarrier.
knighton Jan 5, 2024
34efcf7
Misc.
knighton Jan 20, 2024
44df36e
Merge branch 'main' of github.com:mosaicml/streaming
knighton Jan 20, 2024
dbdd58e
Merge branch 'main' into james/nodist2
knighton Jan 20, 2024
15808f3
Update the new files' copyright years.
knighton Jan 20, 2024
32c3e5e
Benchmark separately.
knighton Jan 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,4 @@ jobs:
id: tests
run: |
set -ex
pytest --splits 10 --group 1 --cov-fail-under=10
pytest --splits 10 --group 2 --cov-fail-under=10
pytest --splits 10 --group 3 --cov-fail-under=10
pytest --splits 10 --group 4 --cov-fail-under=10
pytest --splits 10 --group 5 --cov-fail-under=10
pytest --splits 10 --group 6 --cov-fail-under=10
pytest --splits 10 --group 7 --cov-fail-under=10
pytest --splits 10 --group 8 --cov-fail-under=10
pytest --splits 10 --group 9 --cov-fail-under=10
pytest --splits 10 --group 10 --cov-fail-under=10
pytest --cov-fail-under 50
3 changes: 1 addition & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,13 @@ def _modules_to_rst() -> List[types.ModuleType]:
document_modules: List[types.Module] = [
streaming,
streaming.base.compression,
streaming.base.coord,
streaming.base.format,
streaming.base.hashing,
streaming.base.partition,
streaming.base.shared,
streaming.base.shuffle,
streaming.base.storage,
streaming.base.util,
streaming.base.world,
]
exclude_modules: List[types.Module] = [streaming.base, streaming._version]
for name in streaming.__dict__:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
'azure-storage-blob>=12.0.0,<13',
'azure-storage-file-datalake>=12.11.0,<13',
'azure-identity>=1.13.0',
'psutil>=5.9.4',
]

extra_deps = {}
Expand Down
204 changes: 85 additions & 119 deletions simulation/core/sim_dataset.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion simulation/core/sim_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Contains info about the nodes, ranks, and workers of the run for simulation purposes."""

from streaming.base.world import World
from streaming.base.coord.world import World


class SimulationWorld(World):
Expand Down
30 changes: 24 additions & 6 deletions simulation/core/yaml_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,29 @@ def create_simulation_dataset(nodes: int, devices: int, workers: int, global_bat
sampling_granularity = train_dataset.get('sampling_granularity', 1)
batching_method = train_dataset.get('batching_method', 'random')

dataset = SimulationDataset(nodes, devices, workers, streams, remote, local, split,
download_retry, download_timeout, validate_hash, keep_zip,
epoch_size, predownload, cache_limit, partition_algo,
num_canonical_nodes, batch_size, shuffle, shuffle_algo,
shuffle_seed, shuffle_block_size, sampling_method,
sampling_granularity, batching_method)
dataset = SimulationDataset(nodes=nodes,
devices=devices,
workers=workers,
streams=streams,
remote=remote,
local=local,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
epoch_size=epoch_size,
predownload=predownload,
cache_limit=cache_limit,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
batch_size=batch_size,
shuffle=shuffle,
shuffle_algo=shuffle_algo,
shuffle_seed=shuffle_seed,
shuffle_block_size=shuffle_block_size,
sampling_method=sampling_method,
sampling_granularity=sampling_granularity,
batching_method=batching_method)

return dataset
2 changes: 1 addition & 1 deletion streaming/base/batching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from streaming.base.batching.per_stream import generate_work_per_stream_batching
from streaming.base.batching.random import generate_work_random_batching
from streaming.base.batching.stratified import generate_work_stratified_batching
from streaming.base.world import World
from streaming.base.coord.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset
Expand Down
5 changes: 1 addition & 4 deletions streaming/base/batching/per_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import numpy as np
from numpy.typing import NDArray

from streaming.base.coord.world import World
from streaming.base.partition import get_partitions
from streaming.base.shuffle import get_shuffle
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset
Expand Down Expand Up @@ -63,9 +63,6 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
Expand Down
5 changes: 1 addition & 4 deletions streaming/base/batching/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import numpy as np
from numpy.typing import NDArray

from streaming.base.coord.world import World
from streaming.base.partition import get_partitions
from streaming.base.shuffle import get_shuffle
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset
Expand Down Expand Up @@ -58,9 +58,6 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch

# If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
if dataset.shuffle:
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)
Expand Down
5 changes: 1 addition & 4 deletions streaming/base/batching/stratified.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import numpy as np
from numpy.typing import NDArray

from streaming.base.coord.world import World
from streaming.base.partition import get_partitions
from streaming.base.shuffle import get_shuffle
from streaming.base.world import World

if TYPE_CHECKING:
from streaming.base.dataset import StreamingDataset
Expand Down Expand Up @@ -75,9 +75,6 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
Expand Down
14 changes: 14 additions & 0 deletions streaming/base/coord/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Coordination among ranks and workers."""

from streaming.base.coord.job import JobDirectory, JobRegistry
from streaming.base.coord.shmem import (SharedArray, SharedBarrier, SharedMemory, SharedScalar,
get_shm_prefix)
from streaming.base.coord.world import World

__all__ = [
'JobDirectory', 'JobRegistry', 'SharedArray', 'SharedBarrier', 'SharedMemory',
'get_shm_prefix', 'SharedScalar', 'World'
]
9 changes: 9 additions & 0 deletions streaming/base/coord/file/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Coordinating processes using files."""

from streaming.base.coord.file.lock import SoftFileLock
from streaming.base.coord.file.waiting import create_file, wait_for_creation, wait_for_deletion

__all__ = ['create_file', 'wait_for_creation', 'wait_for_deletion', 'SoftFileLock']
184 changes: 184 additions & 0 deletions streaming/base/coord/file/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Soft file locking via file open mode 'x'."""

import os
from types import TracebackType
from typing import Optional, Type, Union

from typing_extensions import Self

from streaming.base.coord.process import get_live_processes
from streaming.base.coord.waiting import wait

__all__ = ['SoftFileLock']


class SoftFileLock:
"""Soft file locking via file open mode 'x'.

Args:
filename (str): Path to lock.
timeout (float, optional): How long to wait in seconds before raising an exception.
Set to ``None`` to never time out. Defaults to ``30``.
tick (float): Check interval in seconds. Defaults to ``0.007``.
"""

def __init__(
self,
filename: str,
timeout: Optional[float] = 30,
tick: float = 0.007,
) -> None:
if not filename:
raise ValueError('Path to file lock is empty.')

if timeout is not None:
if timeout <= 0:
raise ValueError(
f'Timeout must be positive float seconds, but got: {timeout} sec.')

if tick <= 0:
raise ValueError(f'Tick must be positive float seconds, but got: {tick} sec.')

self.filename = filename
self.timeout = timeout
self.tick = tick

self._normalize(filename)

@classmethod
def _write(cls, filename: str, pid: int) -> None:
"""Write the locking process's pid.

Args:
filename (str): Path to lock.
"""
with open(filename, 'x') as file:
file.write(str(pid))

@classmethod
def _read(cls, filename: str) -> int:
"""Read the locking process's pid.

Args:
filename (str): Path to lock.
"""
with open(filename, 'r') as file:
return int(file.read())

@classmethod
def _normalize(cls, filename: str) -> None:
"""Ensure parent dirs exist and lock files held by dead processes do not exist.

Args:
filename (str): Path to lock.
"""
# Ensure the file's parent directory exists so we can write it in one shot.
dirname = os.path.dirname(filename)
if dirname:
os.makedirs(dirname, exist_ok=True)

# If no file, we don't need to do anything.
if not os.path.exists(filename):
return

# If we fail to open the file and parse the pid, bail out while deleting it.
try:
pid = cls._read(filename)
except:
os.remove(filename)
return

# If the pid is not among the living, delete the file.
if pid not in get_live_processes():
os.remove(filename)

@classmethod
def _get_timeout(
cls,
init_timeout: Optional[float],
timeout: Optional[Union[str, float]] = 'auto',
) -> Optional[float]:
"""Determine the timeout for a given acquire().

Args:
init_timeout (float, optional): Default timeout provided to init.
timeout (str | float, optional): Override timeout for just this method call.

Returns:
float, optional: Normalized timeout as positive float seconds or ``None`` to disable.
"""
if timeout is None:
# No timeout.
ret = timeout
elif isinstance(timeout, float):
# Override timeout.
if timeout <= 0:
raise ValueError(
f'Timeout must be positive float seconds, but got: {timeout} sec.')
ret = timeout
elif timeout == 'auto':
# Default timeout.
ret = init_timeout
else:
raise ValueError(f'Timeout must either be positive float seconds, ``None`` to ' +
f'disable timing out, or ``auto`` to use the default passed to ' +
f'init, but got: {timeout}.')
return ret

def acquire(
self,
timeout: Optional[Union[str, float]] = 'auto',
) -> None:
"""Acquire this lock.

Args:
timeout (str | float, optional): Override timeout for just this method call.
"""

def stop() -> bool:
try:
with open(self.filename, 'x') as out:
text = str(os.getpid())
out.write(text)
return True
except:
return False

norm_timeout = self._get_timeout(self.timeout, timeout)
wait(stop, norm_timeout, self.tick)

def release(self) -> None:
"""Release this lock."""
if os.path.isfile(self.filename):
os.remove(self.filename)
elif os.path.exists(self.filename):
raise ValueError(f'Path exists, but is not a file: {self.filename}.')
else:
raise ValueError(f'Path does not exist: {self.filename}.')

def __enter__(self) -> Self:
"""Enter context manager.

Returns:
Self: This lock.
"""
self.acquire()
return self

def __exit__(
self,
err_type: Optional[Type[BaseException]] = None,
err: Optional[BaseException] = None,
trace: Optional[TracebackType] = None,
) -> None:
"""Exit context manager.

Args:
err_type (Type[BaseException], optional): Exc type.
err (BaseException, optional): Exc.
trace (TracebackType, optional): Traceback.
"""
self.release()
Loading