Source code for neural_pipeline.data_producer.data_producer

import itertools
from abc import ABCMeta, abstractmethod
from random import shuffle
from typing import Callable, List, Optional

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


__all__ = ['AbstractDataset', 'DataProducer']


class AbstractDataset(metaclass=ABCMeta):
    """
    Abstract dataset interface. Implementations must provide ``__len__`` and ``__getitem__``
    """

    @abstractmethod
    def __len__(self):
        pass

    @abstractmethod
    def __getitem__(self, item):
        pass
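
# A minimal sketch (illustrative, not part of this module) of an
# ``AbstractDataset`` implementation backed by an in-memory list; the name
# ``ListDataset`` is hypothetical:
#
#     class ListDataset(AbstractDataset):
#         def __init__(self, items: list):
#             self._items = items
#
#         def __len__(self):
#             return len(self._items)
#
#         def __getitem__(self, item):
#             return self._items[item]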


class DataProducer:
    """
    Data Producer. Accumulates one or more datasets and passes their data in batches for processing.
    It uses the built-in PyTorch :class:`DataLoader` to increase the performance of data delivery.

    :param datasets: list of datasets. Every dataset must be indexable (contains the methods ``__getitem__`` and ``__len__``)
    :param batch_size: size of the output batch
    :param num_workers: number of processes that load data from the datasets and pass it to the output
    """

    def __init__(self, datasets: List[AbstractDataset], batch_size: int = 1, num_workers: int = 0):
        self.__datasets = datasets
        self.__batch_size = batch_size
        self.__num_workers = num_workers

        self._shuffle_datasets_order = False
        self._glob_shuffle = False
        self._pin_memory = False
        self._collate_fn = default_collate
        self._drop_last = False
        self._need_pass_indices = False

        self._update_datasets_idx_space()

        self._indices = None

    def drop_last(self, need_drop: bool) -> 'DataProducer':
        """
        Set whether to drop the last incomplete batch

        :param need_drop: whether the last batch needs to be dropped
        :return: self object
        """
        self._drop_last = need_drop
        return self
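
    # Usage sketch (``train_set_a`` and ``train_set_b`` are hypothetical
    # datasets): every configuration method returns ``self``, so the setup
    # can be chained fluently:
    #
    #     producer = DataProducer([train_set_a, train_set_b],
    #                             batch_size=32, num_workers=4) \
    #         .global_shuffle(True) \
    #         .pin_memory(True) \
    #         .drop_last(True)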

    def shuffle_datasets_order(self, is_need: bool) -> 'DataProducer':
        """
        Set whether the datasets order needs to be shuffled. Shuffling is performed on every access to index 0

        :param is_need: whether shuffling is needed
        :return: self object
        """
        self._shuffle_datasets_order = is_need
        return self

    def global_shuffle(self, is_need: bool) -> 'DataProducer':
        """
        Set whether global shuffling is needed. If global shuffling is enabled, batches are compiled from
        random indices across all datasets. In this case the datasets order shuffling is ignored

        :param is_need: whether global shuffling is needed
        :return: self object
        """
        self._glob_shuffle = is_need
        return self

    def pin_memory(self, is_need: bool) -> 'DataProducer':
        """
        Set whether to pin memory while loading. Pinning memory increases the data loading performance
        (especially when the data is loaded to the GPU) but is incompatible with swap

        :param is_need: whether memory pinning is needed
        :return: self object
        """
        self._pin_memory = is_need
        return self

    def pass_indices(self, need_pass: bool) -> 'DataProducer':
        """
        Pass the indices of the data in every batch. Disabled by default

        :param need_pass: whether the indices need to be passed
        :return: self object
        """
        self._need_pass_indices = need_pass
        return self

    def set_indices(self, indices: List[str]) -> 'DataProducer':
        """
        Set indices for the :class:`DataProducer`. After that, the :class:`DataProducer` starts to produce
        data only by these indices

        :param indices: list of indices in the format ``<dataset_idx>_<data_idx>``, e.g. ['0_0', '0_1', '1_0']
        :return: self object
        """
        self._indices = indices
        return self
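
    # Index format sketch: with two datasets of lengths 2 and 1, the full
    # index space is ['0_0', '0_1', '1_0']; restricting the producer to a
    # subset could look like this (illustrative values):
    #
    #     producer.set_indices(['0_1', '1_0'])  # skip the first item of dataset 0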

    def get_indices(self) -> Optional[List[str]]:
        """
        Get the current indices

        :return: list of the current indices, or None if :meth:`set_indices` hasn't been called
        """
        return self._indices

    def _is_passed_indices(self) -> bool:
        """
        Internal method for checking whether the :class:`DataProducer` passes indices

        :return: whether indices are passed
        """
        return self._need_pass_indices

    def get_data(self, dataset_idx: int, data_idx: int) -> object:
        """
        Get a single data item by dataset index and data index

        :param dataset_idx: index of the dataset
        :param data_idx: index of the data item in this dataset
        :return: dataset output
        """
        data = self.__datasets[dataset_idx][data_idx]
        if self._need_pass_indices:
            if not isinstance(data, dict):
                data = {'data': data}
            return dict(data, **{'data_idx': str(dataset_idx) + "_" + str(data_idx)})
        return data
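
    # Behaviour sketch for ``get_data`` with ``pass_indices(True)`` set: a
    # non-dict item is wrapped into a dict and the index string is attached
    # (the returned value here is illustrative):
    #
    #     producer.pass_indices(True)
    #     producer.get_data(0, 3)  # -> {'data': <item>, 'data_idx': '0_3'}
    #
    # Without ``pass_indices``, the raw dataset item is returned unchanged.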

    def set_collate_func(self, func: Callable) -> 'DataProducer':
        """
        Set a custom collate function for the underlying :class:`DataLoader`

        :param func: collate function
        :return: self object
        """
        self._collate_fn = func
        return self

    def __len__(self):
        if self._indices is None:
            return self.__overall_len
        else:
            return len(self._indices)

    def __getitem__(self, item):
        if item == 0 and self._indices is None and (not self._glob_shuffle) and self._shuffle_datasets_order:
            self._update_datasets_idx_space()

        if self._indices is None:
            dataset_idx, data_idx = 0, item
            for i in range(len(self.__datasets)):
                if item > self._datasets_idx_space[i]:
                    dataset_idx = i + 1
                    data_idx = item - self._datasets_idx_space[i] - 1
        else:
            dataset_idx, data_idx = self._indices[item].split('_')
            dataset_idx, data_idx = int(dataset_idx), int(data_idx)

        return self.get_data(dataset_idx, data_idx)

    def get_loader(self, indices: Optional[List[str]] = None) -> DataLoader:
        """
        Get a PyTorch :class:`DataLoader` object that aggregates this :class:`DataProducer`.
        If ``indices`` is specified, the DataLoader outputs data only by these indices.
        In this case the indices will not be passed.

        :param indices: list of indices. Each item of the list is a string in the format '{}_{}'.format(dataset_idx, data_idx)
        :return: :class:`DataLoader` object
        """
        if indices is not None:
            return self._get_loader_by_indices(indices)
        return DataLoader(self, batch_size=self.__batch_size, num_workers=self.__num_workers,
                          shuffle=self._glob_shuffle, pin_memory=self._pin_memory,
                          collate_fn=self._collate_fn, drop_last=self._drop_last)
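
    # Iteration sketch (assumes dataset items that ``default_collate`` can
    # stack, e.g. tensors or numbers; ``process`` is a hypothetical
    # training step):
    #
    #     loader = producer.get_loader()
    #     for batch in loader:
    #         process(batch)
    #
    #     # Or restrict the loader to specific items:
    #     loader = producer.get_loader(indices=['0_0', '1_2'])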

    def _get_loader_by_indices(self, indices: List[str]) -> DataLoader:
        """
        Get a loader that produces data only by the specified indices

        :param indices: required indices
        :return: :class:`DataLoader` object
        """
        return DataLoader(_ByIndices(self.__datasets, indices),
                          batch_size=self.__batch_size, num_workers=self.__num_workers,
                          shuffle=self._glob_shuffle, pin_memory=self._pin_memory,
                          collate_fn=self._collate_fn, drop_last=self._drop_last)

    def _update_datasets_idx_space(self) -> None:
        """
        Update the index space of the datasets. The index space is used to correctly map a global index
        to the data index of the corresponding dataset
        """
        if self._shuffle_datasets_order:
            shuffle(self.__datasets)

        datasets_len = [len(d) for d in self.__datasets]
        self.__overall_len = sum(datasets_len)
        self._datasets_idx_space = []
        cur_len = 0
        for dataset_len in datasets_len:
            self._datasets_idx_space.append(dataset_len + cur_len - 1)
            cur_len += dataset_len
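
# Worked example of the index space: for datasets of lengths [3, 5],
# ``_datasets_idx_space`` becomes [2, 7] (the last global index of each
# dataset). A global index of 4 then falls past 2, so ``__getitem__``
# resolves it to dataset 1 with local index 4 - 2 - 1 = 1.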

class _ByIndices(DataProducer):
    def __init__(self, datasets: List[AbstractDataset], indices: list):
        super().__init__(datasets)
        self.shuffle_datasets_order(False)
        self.indices = list(itertools.chain(*indices))

    def __getitem__(self, item):
        dataset_idx, data_idx = self.indices[item].split('_')
        return self.get_data(int(dataset_idx), int(data_idx))

    def __len__(self):
        return len(self.indices)
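
# Note on ``_ByIndices``: ``list(itertools.chain(*indices))`` flattens one
# level of nesting, so grouped indices like [['0_0', '0_1'], ['1_0']] become
# ['0_0', '0_1', '1_0'] before lookup (the values are illustrative). A flat
# list of strings would instead be split into single characters by ``chain``,
# so ``indices`` is expected to arrive as a list of index lists.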