import bisect
import warnings
import math
from typing import (
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm
from torch._utils import _accumulate
from ... import Generator, Tensor

__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should
    subclass it. All subclasses should overwrite :meth:`__getitem__`,
    supporting fetching a data sample for a given key. Subclasses could also
    optionally overwrite :meth:`__len__`, which is expected to return the size
    of the dataset by many :class:`~torch.utils.data.Sampler` implementations
    and the default options of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
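
# A minimal, hedged sketch (not part of torch.utils.data) of how a map-style
# dataset is typically written against the `Dataset` API above: override
# `__getitem__` and `__len__`. The class name `_SquaresDataset` is hypothetical
# and used only for illustration.
class _SquaresDataset(Dataset[int]):
    def __init__(self, n: int) -> None:
        # Precompute the samples; a real dataset would usually load them lazily.
        self._data = [i * i for i in range(n)]

    def __getitem__(self, index: int) -> int:
        return self._data[index]

    def __len__(self) -> int:
        return len(self._data)

# Usage sketch: `_SquaresDataset(5)[3]` returns 9 and `len(_SquaresDataset(5))` is 5.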
class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such a form of dataset is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...
        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
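
# A hedged usage sketch (not part of torch.utils.data) for `TensorDataset`:
# wrap a feature tensor and a label tensor that share their first dimension,
# then index a single sample. `_tensor_dataset_example` is a hypothetical
# helper used only for illustration.
def _tensor_dataset_example():
    import torch
    features = torch.arange(12.0).reshape(4, 3)  # 4 samples with 3 features each
    labels = torch.tensor([0, 1, 0, 1])          # one label per sample
    ds = TensorDataset(features, labels)
    x, y = ds[2]                                 # third sample and its label
    return x, y, len(ds)                         # (tensor([6., 7., 8.]), tensor(0), 4)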
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
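
# A hedged usage sketch (not part of torch.utils.data) for `ConcatDataset`:
# concatenate two `TensorDataset`s and show how a global index is mapped to a
# member dataset through the cumulative sizes computed by `cumsum`.
# `_concat_dataset_example` is a hypothetical helper used only for illustration.
def _concat_dataset_example():
    import torch
    ds_a = TensorDataset(torch.arange(3))       # samples 0, 1, 2
    ds_b = TensorDataset(torch.arange(10, 14))  # samples 10, 11, 12, 13
    cat = ConcatDataset([ds_a, ds_b])
    # cat.cumulative_sizes == [3, 7]; global index 4 lands in ds_b at local index 1.
    return len(cat), cat[4]                     # (7, (tensor(11),))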
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)  # type: ignore[arg-type]
        return total
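
# A hedged usage sketch (not part of torch.utils.data) for `ChainDataset`:
# chain two small iterable datasets and stream their samples in order.
# `_RangeIterable` and `_chain_dataset_example` are hypothetical helpers used
# only for illustration.
class _RangeIterable(IterableDataset):
    def __init__(self, start: int, end: int) -> None:
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __len__(self):
        return self.end - self.start

def _chain_dataset_example():
    chained = ChainDataset([_RangeIterable(0, 3), _RangeIterable(10, 12)])
    return list(chained), len(chained)          # ([0, 1, 2, 10, 11], 5)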
class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
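
# A hedged usage sketch (not part of torch.utils.data) for `Subset`: view every
# other sample of a `TensorDataset` without copying any data. `_subset_example`
# is a hypothetical helper used only for illustration.
def _subset_example():
    import torch
    ds = TensorDataset(torch.arange(10))
    even = Subset(ds, indices=list(range(0, 10, 2)))
    return len(even), even[1]                   # (5, (tensor(2),)) since even[1] is ds[2]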
def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]
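
# A hedged usage sketch (not part of torch.utils.data) for `random_split` with
# fractional lengths. floor(10 * 0.75) = 7 and floor(10 * 0.25) = 2 leave a
# remainder of 1, which is handed out round-robin, so the split sizes become
# [8, 2]. A fixed `torch.Generator` keeps the permutation reproducible.
# `_random_split_example` is a hypothetical helper used only for illustration.
def _random_split_example():
    import torch
    ds = TensorDataset(torch.arange(10))
    g = torch.Generator().manual_seed(0)
    train, val = random_split(ds, [0.75, 0.25], generator=g)
    return len(train), len(val)                 # (8, 2)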