Module infinibatch.torch.data
Expand source code
import torch
from infinibatch.iterators import CheckpointableIterator
from infinibatch.datasets import chunked_dataset_iterator
from typing import Union, Iterable, Any
# @TODO: This has been tested once, but we have no regression test presently. I am worried tests will fail if Torch is not installed.
class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Wraps a CheckpointableIterator into a PyTorch IterableDataset, so that PyTorch's DataLoader
    class recognizes it, by its type, as an iterable-style dataset.

    Only single-process loading (DataLoader ``num_workers=0``) or exactly one worker process
    (``num_workers=1``) is supported. With more workers, each worker would iterate its own copy of
    the same source (yielding duplicate data), and the checkpoint of each worker cannot be retrieved.

    Args:
        source: the checkpointable iterator whose items this dataset yields.
    """

    def __init__(self, source: "CheckpointableIterator"):
        super().__init__()
        self._source = source

    def __iter__(self):
        # When the DataLoader uses worker processes, this is called in the worker's (forked) copy of the dataset.
        worker_info = torch.utils.data.get_worker_info()
        # Explicit check instead of `assert`, so the guard is not stripped when running under `python -O`.
        if worker_info is not None and worker_info.num_workers != 1:
            raise ValueError(
                "IterableCheckpointedDataset supports at most one DataLoader worker, "
                f"got num_workers={worker_info.num_workers}: the checkpoint of each worker cannot be retrieved"
            )
        return iter(self._source)
# @TODO: This is currently untested, and may not work presently.
class IterableChunkedDataset(torch.utils.data.IterableDataset):
    """
    PyTorch IterableDataset over ``chunked_dataset_iterator``, partitioned across ranks and DataLoader workers.

    ``chunked_dataset_iterator`` is created with ``num_instances = world_size * num_workers_per_rank``.
    When iterated, DataLoader worker ``worker_id`` of rank ``rank`` selects instance
    ``rank * num_workers_per_rank + worker_id``.

    Args:
        paths: path(s) forwarded unchanged to ``chunked_dataset_iterator``.
        shuffle, buffer_size, transform, seed: forwarded unchanged to ``chunked_dataset_iterator``.
        world_size: total number of ranks sharing the data.
        rank: index of this rank, in ``[0, world_size)``.
        num_workers_per_rank: number of DataLoader worker processes used by each rank. It must equal the
            DataLoader's ``num_workers``; single-process loading (``num_workers=0``) requires it to be 1.

    Raises:
        ValueError: on construction, if ``rank`` is outside ``[0, world_size)`` or
            ``num_workers_per_rank < 1``; on iteration, if the DataLoader's worker configuration does not
            match ``num_workers_per_rank``.
    """

    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 2**20,
                 transform=None, seed: Union[int, None] = None, world_size: int = 1, rank: int = 0,
                 num_workers_per_rank: int = 1):
        super().__init__()
        if num_workers_per_rank < 1:
            raise ValueError(f"num_workers_per_rank must be at least 1, got {num_workers_per_rank}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, world_size={world_size}), got {rank}")
        self.rank = rank
        self.num_workers_per_rank = num_workers_per_rank
        # The initial instance_rank is that of worker 0 of this rank (identical to `rank` when
        # num_workers_per_rank == 1); __iter__ overwrites it with the actual worker's instance.
        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                                transform=transform, seed=seed,
                                                num_instances=world_size * num_workers_per_rank,
                                                instance_rank=rank * num_workers_per_rank)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            # A single process can only read one instance; with num_workers_per_rank > 1 it would
            # silently miss the other instances assigned to this rank.
            if self.num_workers_per_rank != 1:
                raise ValueError(
                    f"num_workers_per_rank={self.num_workers_per_rank}, but the DataLoader runs in "
                    f"single-process mode (num_workers=0); use num_workers_per_rank=1 or "
                    f"num_workers={self.num_workers_per_rank}"
                )
            instance_rank = self.rank
        else:
            # Explicit check instead of `assert`, so it is not stripped under `python -O`.
            if worker_info.num_workers != self.num_workers_per_rank:
                raise ValueError(
                    f"DataLoader num_workers={worker_info.num_workers} does not match "
                    f"num_workers_per_rank={self.num_workers_per_rank}"
                )
            instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
        # NOTE(review): assumes chunked_dataset_iterator reads its private `_instance_rank` attribute when
        # iteration starts, so that setting it here takes effect; not verified (see the @TODO on this class).
        self.dataset._instance_rank = instance_rank
        return iter(self.dataset)
Classes
class IterableCheckpointedDataset (source: CheckpointableIterator)-
Wraps a CheckpointableIterator into a PyTorch IterableDataset, so that PyTorch's DataLoader class recognizes it, by its type, as an iterable-style dataset.
Expand source code
class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Wraps a CheckpointableIterator into a PyTorch IterableDataset, so that PyTorch's DataLoader
    class recognizes it, by its type, as an iterable-style dataset.

    Only single-process loading (DataLoader ``num_workers=0``) or exactly one worker process
    (``num_workers=1``) is supported. With more workers, each worker would iterate its own copy of
    the same source (yielding duplicate data), and the checkpoint of each worker cannot be retrieved.

    Args:
        source: the checkpointable iterator whose items this dataset yields.
    """

    def __init__(self, source: "CheckpointableIterator"):
        super().__init__()
        self._source = source

    def __iter__(self):
        # When the DataLoader uses worker processes, this is called in the worker's (forked) copy of the dataset.
        worker_info = torch.utils.data.get_worker_info()
        # Explicit check instead of `assert`, so the guard is not stripped when running under `python -O`.
        if worker_info is not None and worker_info.num_workers != 1:
            raise ValueError(
                "IterableCheckpointedDataset supports at most one DataLoader worker, "
                f"got num_workers={worker_info.num_workers}: the checkpoint of each worker cannot be retrieved"
            )
        return iter(self._source)
# Ancestors
- torch.utils.data.dataset.IterableDataset
- torch.utils.data.dataset.Dataset
class IterableChunkedDataset (paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 1048576, transform=None, seed: int = None, world_size: int = 1, rank: int = 0, num_workers_per_rank: int = 1)-
An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it. This form of dataset is particularly useful when data comes from a stream.
All subclasses should override :meth:`__iter__`, which should return an iterator of samples in this dataset.
When a subclass is used with :class:`~torch.utils.data.DataLoader`, each item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator. When :attr:`num_workers > 0`, each worker process will have a different copy of the dataset object, so it is often desirable to configure each copy independently to avoid having duplicate data returned from the workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker process, returns information about the worker. It can be used in either the dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader`'s :attr:`worker_init_fn` option to modify each copy's behavior.
Example 1: splitting workload across all workers in :meth:
__iter__::>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20))) [3, 4, 5, 6]Example 2: splitting workload across all workers using :attr:
worker_init_fn::>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]Expand source code
class IterableChunkedDataset(torch.utils.data.IterableDataset):
    """
    PyTorch IterableDataset over ``chunked_dataset_iterator``, partitioned across ranks and DataLoader workers.

    ``chunked_dataset_iterator`` is created with ``num_instances = world_size * num_workers_per_rank``.
    When iterated, DataLoader worker ``worker_id`` of rank ``rank`` selects instance
    ``rank * num_workers_per_rank + worker_id``.

    Args:
        paths: path(s) forwarded unchanged to ``chunked_dataset_iterator``.
        shuffle, buffer_size, transform, seed: forwarded unchanged to ``chunked_dataset_iterator``.
        world_size: total number of ranks sharing the data.
        rank: index of this rank, in ``[0, world_size)``.
        num_workers_per_rank: number of DataLoader worker processes used by each rank. It must equal the
            DataLoader's ``num_workers``; single-process loading (``num_workers=0``) requires it to be 1.

    Raises:
        ValueError: on construction, if ``rank`` is outside ``[0, world_size)`` or
            ``num_workers_per_rank < 1``; on iteration, if the DataLoader's worker configuration does not
            match ``num_workers_per_rank``.
    """

    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 2**20,
                 transform=None, seed: Union[int, None] = None, world_size: int = 1, rank: int = 0,
                 num_workers_per_rank: int = 1):
        super().__init__()
        if num_workers_per_rank < 1:
            raise ValueError(f"num_workers_per_rank must be at least 1, got {num_workers_per_rank}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, world_size={world_size}), got {rank}")
        self.rank = rank
        self.num_workers_per_rank = num_workers_per_rank
        # The initial instance_rank is that of worker 0 of this rank (identical to `rank` when
        # num_workers_per_rank == 1); __iter__ overwrites it with the actual worker's instance.
        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                                transform=transform, seed=seed,
                                                num_instances=world_size * num_workers_per_rank,
                                                instance_rank=rank * num_workers_per_rank)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            # A single process can only read one instance; with num_workers_per_rank > 1 it would
            # silently miss the other instances assigned to this rank.
            if self.num_workers_per_rank != 1:
                raise ValueError(
                    f"num_workers_per_rank={self.num_workers_per_rank}, but the DataLoader runs in "
                    f"single-process mode (num_workers=0); use num_workers_per_rank=1 or "
                    f"num_workers={self.num_workers_per_rank}"
                )
            instance_rank = self.rank
        else:
            # Explicit check instead of `assert`, so it is not stripped under `python -O`.
            if worker_info.num_workers != self.num_workers_per_rank:
                raise ValueError(
                    f"DataLoader num_workers={worker_info.num_workers} does not match "
                    f"num_workers_per_rank={self.num_workers_per_rank}"
                )
            instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
        # NOTE(review): assumes chunked_dataset_iterator reads its private `_instance_rank` attribute when
        # iteration starts, so that setting it here takes effect; not verified.
        self.dataset._instance_rank = instance_rank
        return iter(self.dataset)
# Ancestors
- torch.utils.data.dataset.IterableDataset
- torch.utils.data.dataset.Dataset