Skip to content

Optimization #9754

Description

@signore662-beep

This is an optimization for TPU.

import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

class AsyncDataLoader(DataLoader):
    """A ``DataLoader`` whose batches are fetched through an ``async`` hook.

    Each batch is produced by awaiting :meth:`async_load_batch`; subclasses
    can override that coroutine to perform real asynchronous I/O.

    Iteration itself is synchronous (a plain ``for`` loop) and works both
    from ordinary code and from inside a running event loop. Batch indices
    come from the loader's ``batch_sampler``, so ``shuffle``, ``sampler``,
    ``batch_sampler`` and ``drop_last`` are honoured and ``len(loader)``
    matches the number of batches yielded.

    Batches are returned as plain lists of samples; ``collate_fn`` is not
    applied. ``num_workers`` is forwarded to ``DataLoader``, but this
    iterator does not start worker processes.
    """

    def __init__(self, dataset, batch_size=1, num_workers=1, *args, **kwargs):
        """Configure the loader.

        Args:
            dataset: Map-style dataset (supports ``len()`` and indexing).
            batch_size: Number of samples per batch.
            num_workers: Forwarded unchanged to ``DataLoader``.
            *args: Not supported; any extra positional argument would be
                bound to ``DataLoader``'s ``batch_size`` slot and clash
                with the keyword passed below.
            **kwargs: Any other ``DataLoader`` keyword option.

        Raises:
            TypeError: If extra positional arguments are supplied.
        """
        if args:
            raise TypeError(
                "AsyncDataLoader accepts only dataset, batch_size and "
                "num_workers positionally; pass other DataLoader options "
                "by keyword"
            )
        super().__init__(
            dataset, batch_size=batch_size, num_workers=num_workers, **kwargs
        )

    async def async_load_batch(self, batch_indices):
        """Asynchronously load the samples at ``batch_indices``.

        Args:
            batch_indices: Dataset indices that make up one batch.

        Returns:
            A list with ``self.dataset[i]`` for each index, in order.
        """
        return [self.dataset[i] for i in batch_indices]

    def __iter__(self):
        """Yield one list of samples per batch.

        Raises:
            TypeError: If the dataset is an ``IterableDataset`` or automatic
                batching is disabled (``batch_size=None`` without a
                ``batch_sampler``), since batches are built by index.
        """
        if isinstance(self.dataset, IterableDataset) or self.batch_sampler is None:
            raise TypeError(
                "AsyncDataLoader needs a map-style dataset and automatic "
                "batching (batch_size or batch_sampler)"
            )

        # A private loop per iteration: nothing is captured at construction
        # time, and the loop is always closed when iteration ends.
        loop = asyncio.new_event_loop()
        helper = None
        try:
            for batch_indices in self.batch_sampler:
                coro = self.async_load_batch(list(batch_indices))
                try:
                    asyncio.get_running_loop()
                except RuntimeError:
                    # No loop is running in this thread; drive ours directly.
                    batch = loop.run_until_complete(coro)
                else:
                    # A loop is already running in this thread, and asyncio
                    # forbids running a second one here, so drive the
                    # private loop from a helper thread and wait for it.
                    if helper is None:
                        helper = ThreadPoolExecutor(max_workers=1)
                    batch = helper.submit(loop.run_until_complete, coro).result()
                yield batch
        finally:
            if helper is not None:
                helper.shutdown()
            loop.close()

Example Dataset for demonstration

class ExampleDataset(Dataset):
    """A simple map-style dataset for demonstration purposes.

    Item ``i`` is the one-element tensor ``torch.tensor([i])``.
    """

    def __init__(self, size: int) -> None:
        """Build ``size`` single-element tensors holding ``0 .. size - 1``."""
        self.data = [torch.tensor([i]) for i in range(size)]

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the tensor stored at ``index``."""
        return self.data[index]

Tests

async def test_async_data_loader():
    """Check that a 100-item dataset is served as ten batches of ten."""
    dataset = ExampleDataset(100)
    async_loader = AsyncDataLoader(dataset, batch_size=10)

    batch_count = 0
    for batch in async_loader:
        assert len(batch) == 10, "Batch size does not match"
        batch_count += 1
    # Without this check the loop above would pass vacuously if the loader
    # yielded no batches at all.
    assert batch_count == 10, "Unexpected number of batches"
    print("All tests passed!")

# Run the self-test only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(test_async_data_loader())

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions