Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/custom-mods/custom_mods/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_centralized_dataset():
"""Load test set and return dataloader."""
# Load entire test set
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
dataset = test_dataset.with_format("torch").with_transform(apply_transforms)
dataset = test_dataset.with_transform(apply_transforms)
return DataLoader(dataset, batch_size=128)


Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart-pytorch/pytorchexample/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load_centralized_dataset():
"""Load test set and return dataloader."""
# Load entire test set
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
dataset = test_dataset.with_format("torch").with_transform(apply_transforms)
dataset = test_dataset.with_transform(apply_transforms)
return DataLoader(dataset, batch_size=128)


Expand Down
12 changes: 8 additions & 4 deletions examples/whisper-federated-finetuning/centralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from torch.utils.data import DataLoader
from transformers import WhisperProcessor

from whisper_example.dataset import get_encoding_fn, prepare_silences_dataset
from whisper_example.dataset import (
get_encoding_fn,
prepare_silences_dataset,
with_torch_transform,
)
from whisper_example.model import (
construct_balanced_sampler,
eval_model,
Expand Down Expand Up @@ -80,13 +84,13 @@ def main():
sampler = construct_balanced_sampler(full_train_dataset)

# Prepare dataloaders
train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
train_dataset = with_torch_transform(full_train_dataset)
train_loader = DataLoader(
train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
)
val_encoded = val_encoded.with_format("torch", columns=["data", "targets"])
val_encoded = with_torch_transform(val_encoded)
val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)
test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
test_dataset = with_torch_transform(test_encoded)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

# Model to cuda, set criterion, classification layer to train and optimiser
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flwr.clientapp import ClientApp
from torch.utils.data import DataLoader

from whisper_example.dataset import load_data
from whisper_example.dataset import load_data, with_torch_transform
from whisper_example.model import construct_balanced_sampler, get_model, train_one_epoch

torch.set_float32_matmul_precision(
Expand Down Expand Up @@ -58,13 +58,13 @@ def train(msg: Message, context: Context):
partition_id=partition_id,
remove_cols=context.run_config["remove-cols"],
)
trainset = partition.with_format("torch", columns=["data", "targets"])
torch.set_num_threads(og_threads)

# construct sampler in order to have balanced batches
sampler = None
if len(trainset) > batch_size:
sampler = construct_balanced_sampler(trainset)
if len(partition) > batch_size:
sampler = construct_balanced_sampler(partition)
trainset = with_torch_transform(partition)

# Construct dataloader
train_loader = DataLoader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import random

import torch
from datasets import Dataset, concatenate_datasets, load_from_disk
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import GroupedNaturalIdPartitioner
Expand Down Expand Up @@ -53,6 +54,20 @@ def load_data_from_disk(data_path):
return load_from_disk(data_path)


def _apply_torch_transform(batch):
"""Convert encoded columns to torch tensors."""
if "data" in batch:
batch["data"] = torch.as_tensor(batch["data"], dtype=torch.float32)
if "targets" in batch:
batch["targets"] = torch.as_tensor(batch["targets"], dtype=torch.long)
return batch


def with_torch_transform(dataset: Dataset) -> Dataset:
"""Return a dataset that lazily converts encoded columns to torch tensors."""
return dataset.with_transform(_apply_torch_transform, columns=["data", "targets"])


def get_encoding_fn(processor):
"""Return a function to use to pre-process/encode the SpeechCommands dataset.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.utils.data import DataLoader
from transformers import WhisperProcessor

from whisper_example.dataset import get_encoding_fn
from whisper_example.dataset import get_encoding_fn, with_torch_transform
from whisper_example.model import eval_model, get_model

# Create ServerApp
Expand Down Expand Up @@ -97,7 +97,7 @@ def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
encoded = val_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols)

torch.set_num_threads(og_threads)
val_encoded = encoded.with_format("torch", columns=["data", "targets"])
val_encoded = with_torch_transform(encoded)
val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4)

# Run global evaluation
Expand Down
Loading