Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datasets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Here's a basic quickstart example of how to partition the MNIST dataset:

```
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioners import IidPartitioner
from flwr_datasets.partitioner import IidPartitioner

# The train split of the MNIST dataset will be partitioned into 100 partitions
partitioner = IidPartitioner(num_partitions=100)
Expand Down
11 changes: 11 additions & 0 deletions datasets/docs/source/how-to-use-with-local-data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ for that, e.g.:
iid_partitioner_for_cifar = IidPartitioner(num_partitions=10)
iid_partitioner_for_cifar.dataset = cifar_dataset

``IidPartitioner`` preserves contiguous dataset order by default for backwards
compatibility. If your local dataset is sorted by label or another target column,
contiguous shards would each cover only a narrow slice of the label values, so
enable shuffling before sharding:

.. code-block:: python

from flwr_datasets.partitioner import IidPartitioner

iid_partitioner = IidPartitioner(num_partitions=10, shuffle=True, seed=42)
iid_partitioner.dataset = sorted_dataset


More Resources
--------------
Expand Down
28 changes: 25 additions & 3 deletions datasets/flwr_datasets/partitioner/iid_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@


class IidPartitioner(Partitioner):
"""Partitioner creates each partition sampled randomly from the dataset.
"""Partitioner creates IID partitions from a dataset.

By default, partitions are contiguous shards of the dataset. Set ``shuffle=True``
to shuffle the dataset once before sharding. This is useful for local datasets
sorted by class or another target column.

Parameters
----------
num_partitions : int
The total number of partitions that the data will be divided into.
shuffle : bool
Whether to shuffle the dataset before sharding. The default is ``False``.
seed : Optional[int]
Seed used for dataset shuffling when ``shuffle`` is set to ``True``. The default
is ``42``. It is ignored when ``shuffle`` is ``False``.

Examples
--------
Expand All @@ -37,11 +45,16 @@ class IidPartitioner(Partitioner):
>>> partition = fds.load_partition(0)
"""

def __init__(self, num_partitions: int) -> None:
def __init__(
self, num_partitions: int, shuffle: bool = False, seed: int | None = 42
) -> None:
super().__init__()
if num_partitions <= 0:
raise ValueError("The number of partitions must be greater than zero.")
self._num_partitions = num_partitions
self._shuffle = shuffle
self._seed = seed
self._shuffled_dataset: datasets.Dataset | None = None

def load_partition(self, partition_id: int) -> datasets.Dataset:
"""Load a single IID partition based on the partition index.
Expand All @@ -56,11 +69,20 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
dataset_partition : Dataset
single dataset partition
"""
return self.dataset.shard(
dataset = self._dataset_to_partition()
return dataset.shard(
num_shards=self._num_partitions, index=partition_id, contiguous=True
)

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
return self._num_partitions

def _dataset_to_partition(self) -> datasets.Dataset:
"""Return the dataset used for sharding."""
if not self._shuffle:
return self.dataset
if self._shuffled_dataset is None:
self._shuffled_dataset = self.dataset.shuffle(seed=self._seed)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean we have two copies of the dataset?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Dataset.shuffle(...) returns another Hugging Face Dataset object with shuffled indices/cache metadata rather than eagerly duplicating all row data. So this keeps a second dataset object around, but it should not be a full in-memory copy of the underlying dataset. The cache here is intentional so repeated load_partition calls use the same shuffled order, especially when seed=None.

return self._shuffled_dataset
71 changes: 71 additions & 0 deletions datasets/flwr_datasets/partitioner/iid_partitioner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


import unittest
from collections import Counter

from parameterized import parameterized

Expand Down Expand Up @@ -111,6 +112,76 @@ def test_load_partition_correct_data(
dataset[partition_size * partition_index + row_id]["features"],
)

def test_default_partitioning_preserves_contiguous_order(self) -> None:
    """Check that the default (non-shuffled) mode still yields contiguous shards."""
    # Ten rows sorted by label: the first half is label 0, the second half label 1.
    features = list(range(10))
    labels = [0] * 5 + [1] * 5
    partitioner = IidPartitioner(num_partitions=2)
    partitioner.dataset = Dataset.from_dict({"features": features, "labels": labels})

    first_shard = partitioner.load_partition(0)

    # Without shuffling, partition 0 is exactly the first five rows.
    self.assertEqual(first_shard["features"], list(range(5)))
    self.assertEqual(Counter(first_shard["labels"]), Counter({0: 5}))

def test_shuffle_mixes_sorted_dataset(self) -> None:
    """Check that shuffled shards of a label-sorted dataset contain both labels."""
    partitioner = IidPartitioner(num_partitions=2, shuffle=True, seed=42)
    partitioner.dataset = Dataset.from_dict(
        {"features": list(range(200)), "labels": [0] * 100 + [1] * 100}
    )

    # Load both partitions up front, then inspect their label distributions.
    shards = [partitioner.load_partition(partition_id) for partition_id in (0, 1)]

    for shard in shards:
        label_counts = Counter(shard["labels"])
        for label in (0, 1):
            self.assertGreater(label_counts[label], 0)

def test_shuffle_with_same_seed_is_deterministic(self) -> None:
    """Check that two partitioners sharing a seed return the same partition."""
    rows = {"features": list(range(100)), "labels": [idx % 2 for idx in range(100)]}
    dataset = Dataset.from_dict(rows)
    first, second = (
        IidPartitioner(num_partitions=5, shuffle=True, seed=42) for _ in range(2)
    )
    first.dataset = dataset
    second.dataset = dataset

    features_from_first = first.load_partition(2)["features"]
    features_from_second = second.load_partition(2)["features"]

    self.assertEqual(features_from_first, features_from_second)

def test_shuffle_with_different_seed_changes_order(self) -> None:
    """Check that partitioners with different seeds return different partitions."""
    rows = {"features": list(range(100)), "labels": [idx % 2 for idx in range(100)]}
    dataset = Dataset.from_dict(rows)
    first, second = (
        IidPartitioner(num_partitions=5, shuffle=True, seed=seed) for seed in (42, 43)
    )
    first.dataset = dataset
    second.dataset = dataset

    features_from_first = first.load_partition(2)["features"]
    features_from_second = second.load_partition(2)["features"]

    self.assertNotEqual(features_from_first, features_from_second)

def test_shuffle_with_no_seed_is_stable_after_first_load(self) -> None:
    """Check that an unseeded shuffle is reused across loads on one partitioner."""
    rows = {"features": list(range(100)), "labels": [idx % 2 for idx in range(100)]}
    partitioner = IidPartitioner(num_partitions=5, shuffle=True, seed=None)
    partitioner.dataset = Dataset.from_dict(rows)

    # Load the same partition twice; both loads must see the same shuffled order.
    loads = [partitioner.load_partition(2)["features"] for _ in range(2)]

    self.assertEqual(loads[0], loads[1])

@parameterized.expand( # type: ignore
[
# num_partitions, num_rows
Expand Down
Loading