# This is an optimization for TPU.
import torch
from torch.utils.data import DataLoader, Dataset
import asyncio
class AsyncDataLoader(DataLoader):
    """A ``DataLoader`` whose batches can be consumed from sync or async code.

    Each batch is a plain ``list`` of ``dataset[i]`` items (``collate_fn`` is
    not applied). Batch indices come from the batch sampler that
    ``DataLoader.__init__`` builds from the constructor arguments, so
    ``shuffle``, ``sampler``, ``generator`` and ``drop_last`` are honoured and
    ``len(loader)`` matches the number of batches yielded. Intended for
    map-style datasets (indexable with integer keys).

    * ``for batch in loader`` works from synchronous code *and* from inside a
      running coroutine; it never touches an event loop.
    * ``async for batch in loader`` awaits ``async_load_batch`` per batch and
      yields control to the event loop between batches.
    """

    def __init__(self, dataset, batch_size=1, num_workers=1, *args, **kwargs):
        # Extra positional arguments continue DataLoader's own positional
        # order after ``batch_size`` (shuffle, sampler, batch_sampler, ...).
        # Forwarding them directly after ``dataset`` made the first one
        # collide with the ``batch_size=`` keyword and raise TypeError.
        super().__init__(
            dataset, batch_size, *args, num_workers=num_workers, **kwargs
        )
        # No event loop is captured here: a loop fetched at construction time
        # may be running later (e.g. under ``asyncio.run``), and calling
        # ``run_until_complete`` on a running loop raises RuntimeError.

    def _index_batches(self):
        """Return an iterator over per-batch lists of dataset indices."""
        if self.batch_sampler is None:
            # DataLoader leaves batch_sampler unset when batch_size is None.
            raise TypeError(
                "AsyncDataLoader requires an integer batch_size, got None"
            )
        return iter(self.batch_sampler)

    def _load_batch(self, batch_indices):
        """Return ``[self.dataset[i] for i in batch_indices]`` in order."""
        return [self.dataset[i] for i in batch_indices]

    async def async_load_batch(self, batch_indices):
        """Asynchronously load a batch of data.

        The item lookups are ordinary synchronous ``__getitem__`` calls; this
        coroutine lets the batch be awaited from async code.
        """
        return self._load_batch(batch_indices)

    def __iter__(self):
        """Yield batches synchronously, one list of items per batch."""
        for batch_indices in self._index_batches():
            yield self._load_batch(batch_indices)

    async def __aiter__(self):
        """Yield batches to ``async for``, one list of items per batch."""
        for batch_indices in self._index_batches():
            yield await self.async_load_batch(batch_indices)
            # Let other tasks on the event loop run between batches.
            await asyncio.sleep(0)
# Example dataset for demonstration.
class ExampleDataset(Dataset):
    """Toy map-style dataset: item ``k`` is the one-element tensor ``[k]``.

    Args:
        size: number of items; items are built eagerly for ``0 .. size-1``.
    """

    def __init__(self, size):
        items = []
        for position in range(size):
            items.append(torch.tensor([position]))
        self.data = items

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored tensor at ``index`` (plain list indexing)."""
        item = self.data[index]
        return item
# Tests
def test_async_data_loader():
    """Check that AsyncDataLoader yields every item, in order, in full batches.

    Runs synchronously. The previous ``async`` version built and iterated the
    loader inside ``asyncio.run``, where the loader's ``run_until_complete``
    targets the already-running loop and raises ``RuntimeError``. A plain
    function is also collected and executed by pytest without an async plugin.
    """
    dataset = ExampleDataset(100)
    async_loader = AsyncDataLoader(dataset, batch_size=10)
    batches = list(async_loader)
    assert len(batches) == 10, "Unexpected number of batches"
    for batch in batches:
        assert len(batch) == 10, "Batch size does not match"
    # Each item is a one-element tensor, so int() extracts its value.
    loaded = [int(item) for batch in batches for item in batch]
    assert loaded == list(range(100)), "Items are missing or out of order"
    print("All tests passed!")


if __name__ == "__main__":
    test_async_data_loader()
# This is an optimization for TPU.
import torch
from torch.utils.data import DataLoader, Dataset
import asyncio
# NOTE: a truncated duplicate of ``class AsyncDataLoader`` was removed here.
# Re-declaring it as an empty ``DataLoader`` subclass rebound the module-level
# name and shadowed the full implementation defined earlier in this module.
# Example dataset for demonstration.
# NOTE: a truncated duplicate of ``class ExampleDataset`` was removed here.
# The duplicate had no ``__init__(size)``, so it shadowed the real dataset and
# made ``ExampleDataset(100)`` fail with TypeError.
# Tests
# NOTE: a truncated duplicate of ``test_async_data_loader`` and of the script
# entry-point guard was removed here. It redefined the test without any
# assertions (shadowing the real one) and ran the entry point a second time.