Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
)


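# HF `datasets` IterableDatasets created via `.skip()`/`.take()` (and similar) forbid
# reshuffling their data sources between epochs. Accelerate sets a nonzero epoch on the
# dataset before each pass, which makes such datasets raise `DataSourcesShufflingDisallowed`
# at iteration time. Import the exception defensively: when it is unavailable, fall back to
# an empty tuple so that `except DataSourcesShufflingDisallowed:` matches nothing, while the
# dataloaders below can still recover whenever the real exception is raised.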
try:
from datasets.iterable_dataset import DataSourcesShufflingDisallowed
except ImportError:
DataSourcesShufflingDisallowed = ()


logger = get_logger(__name__)

# kwargs of the DataLoader in min version 2.0
Expand Down Expand Up @@ -583,6 +593,18 @@ def __iter__(self):
# We iterate one batch ahead to check when we are at the end
try:
current_batch = next(dataloader_iter)
except DataSourcesShufflingDisallowed:
# The wrapped HF `datasets` IterableDataset forbids reshuffling its data sources
# between epochs (e.g. it was built via `.skip()`/`.take()`). Reset its epoch so it
# can still be iterated; per-epoch source reshuffling is unavailable for such datasets.
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(0)
dataloader_iter = self.base_dataloader.__iter__()
try:
current_batch = next(dataloader_iter)
except StopIteration:
self.end()
return
except StopIteration:
self.end()
return
Expand Down Expand Up @@ -881,7 +903,16 @@ def __iter__(self):
stop_iteration = False
self._stop_iteration = False
first_batch = None
next_batch, next_batch_info = self._fetch_batches(main_iterator)
try:
next_batch, next_batch_info = self._fetch_batches(main_iterator)
except DataSourcesShufflingDisallowed:
# The wrapped HF `datasets` IterableDataset forbids reshuffling its data sources
# between epochs (e.g. it was built via `.skip()`/`.take()`). Reset its epoch so it
# can still be iterated; per-epoch source reshuffling is unavailable for such datasets.
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(0)
main_iterator = self.base_dataloader.__iter__()
next_batch, next_batch_info = self._fetch_batches(main_iterator)
batch_index = 0
while not stop_iteration:
batch, batch_info = next_batch, next_batch_info
Expand Down
17 changes: 17 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,23 @@ def collate_fn(features):
assert isinstance(d["tensor"], torch.Tensor)
assert d["non_tensor"] == "non_tensor_value"

@require_datasets
def test_iterable_dataset_blocked_source_shuffling_multiple_epochs(self):
    # Regression test for #4080. An HF `datasets` IterableDataset produced by `.skip()`/`.take()`
    # refuses to reshuffle its data sources between epochs. Because Accelerate sets a nonzero
    # epoch on the dataset before each pass, the second epoch used to fail with
    # `DataSourcesShufflingDisallowed`. The dataloader is expected to reset the dataset epoch
    # and keep iterating, yielding the same batches on every pass.
    import datasets

    num_rows, num_skipped, batch_size = 20, 10, 2
    # Rows 10..19 (the ones left after the skip), grouped into consecutive pairs.
    expected = [list(range(start, start + batch_size)) for start in range(num_skipped, num_rows, batch_size)]

    # Exercise both the dispatching and the sharded dataloader code paths.
    for dispatch_batches in (True, False):
        source = datasets.Dataset.from_dict({"a": list(range(num_rows))})
        dataset = source.to_iterable_dataset().skip(num_skipped)
        base_dataloader = DataLoader(dataset, batch_size=batch_size)
        dataloader = prepare_data_loader(
            base_dataloader, dispatch_batches=dispatch_batches, put_on_device=True
        )
        # Two full passes: the second one must not raise DataSourcesShufflingDisallowed.
        for _ in range(2):
            seen = [batch["a"].tolist() for batch in dataloader]
            assert seen == expected

@parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
def test_reproducibility(self, num_processes):
set_seed(21)
Expand Down