Fix AttributeError when iterating IterableDatasetShard without set_epoch #4098

Open
vineethsaivs wants to merge 1 commit into huggingface:main from vineethsaivs:iterable-dataset-shard-epoch-init

@vineethsaivs

What does this PR do?

Iterating an IterableDatasetShard that wraps a dataset carrying a torch.Generator (and no set_epoch method) raises an AttributeError unless the caller calls set_epoch on the shard first:

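import torch
from torch.utils.data import IterableDataset

from accelerate.data_loader import IterableDatasetShard


class GeneratorIterableDataset(IterableDataset):
    def __init__(self):
        # A generator attribute and no set_epoch method: the shape that triggers the bug.
        self.generator = torch.Generator()

    def __iter__(self):
        yield from torch.randperm(8, generator=self.generator).tolist()


shard = IterableDatasetShard(GeneratorIterableDataset(), batch_size=2)
list(shard)  # AttributeError: 'IterableDatasetShard' object has no attribute 'epoch'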

Root cause: __init__ never initializes self.epoch; the only assignment lives in set_epoch. But __iter__'s first branch reads self.epoch to seed the generator for precisely this dataset shape (hasattr(dataset, "generator") and no set_epoch). Accelerate's own managed path calls set_epoch before iterating, which is why the crash only bites direct users of the class, a usage pattern the existing tests treat as supported (check_iterable_dataset_shards iterates shards directly; it only survives because its test dataset has no generator attribute).
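The failure mode can be shown without torch or accelerate. The sketch below is a dependency-free illustration only: ToyShard, iterate_fresh and iterate_after_set_epoch are hypothetical names, and the class is a simplified stand-in for the pattern described above, not the code in data_loader.py.

```python
class ToyShard:
    """Hypothetical stand-in for the failure pattern; not accelerate's code."""

    def __init__(self, dataset):
        self.dataset = dataset
        # self.epoch is deliberately not set here, mirroring the bug.

    def set_epoch(self, epoch):
        self.epoch = epoch  # the only place the attribute is created

    def __iter__(self):
        seed = self.epoch  # raises AttributeError on a fresh instance
        for item in self.dataset:
            yield seed, item


def iterate_fresh():
    """Return the error message from iterating a shard with no set_epoch call."""
    try:
        list(ToyShard([10, 20]))
    except AttributeError as exc:
        return str(exc)
    return None


def iterate_after_set_epoch():
    """Iterate after set_epoch, the order the managed path follows."""
    shard = ToyShard([10, 20])
    shard.set_epoch(3)
    return list(shard)
```

Because __iter__ is a generator function, the error surfaces on the first item request rather than when the iterator is created, which is why the traceback points into __iter__ and not at the call site.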

Fix: initialize self.epoch = 0 in __init__, matching SeedableRandomSampler in the same module (and transformers' IterableDatasetShard). Direct iteration now seeds with the default epoch 0; set_epoch semantics are unchanged.
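A minimal sketch of the fix pattern, under stated assumptions: SeededToyShard is a hypothetical class, and Python's random.Random stands in for the torch.Generator seeding that the real __iter__ performs. It illustrates only that a default epoch of 0 makes direct iteration work and reproducible while set_epoch keeps its meaning.

```python
import random


class SeededToyShard:
    """Hypothetical illustration of the fix; not accelerate's implementation."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.epoch = 0  # default, so __iter__ never sees a missing attribute

    def set_epoch(self, epoch):
        self.epoch = epoch  # unchanged semantics: callers can still override

    def __iter__(self):
        # Seed a fresh RNG from the epoch, so equal epochs give equal orders.
        rng = random.Random(self.epoch)
        items = list(self.dataset)
        rng.shuffle(items)
        yield from items


first = list(SeededToyShard(range(8)))
second = list(SeededToyShard(range(8)))
assert first == second  # default seeding is deterministic

explicit = SeededToyShard(range(8))
explicit.set_epoch(0)
assert list(explicit) == first  # explicit epoch 0 matches the default
```

Initializing the attribute in __init__ rather than guarding the read with getattr keeps the read site unchanged and gives every code path the same default.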

Test: test_iterable_dataset_shard_without_set_epoch reproduces the exact shape (generator-carrying dataset, no set_epoch call): fails before with the AttributeError at data_loader.py:346, passes after; also asserts the default seeding is deterministic. Full tests/test_data_loader.py: 25 passed, 13 skipped. ruff check / ruff format --check clean on both files.

IterableDatasetShard.__init__ never initialized self.epoch, but __iter__
reads it to seed the underlying dataset's generator whenever the dataset
has a torch.Generator attribute and no set_epoch method, which is exactly
the dataset shape that branch was written for. Iterating a fresh shard
without first calling set_epoch therefore raised
AttributeError: 'IterableDatasetShard' object has no attribute 'epoch'.

Initialize epoch to 0, matching SeedableRandomSampler in the same module,
so direct iteration seeds with the default epoch and set_epoch semantics
stay unchanged.