DataGenerator
- class daart.data.DataGenerator(ids_list: List[str], signals_list: List[List[str]], transforms_list: List[list], paths_list: List[List[str | None]], device: str = 'cuda', as_numpy: bool = False, rng_seed: int = 0, trial_splits: str | dict | None = None, train_frac: float = 1.0, sequence_length: int = 500, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, sequence_pad: int = 0, input_type: str = 'markers')
Bases: object
Dataset generator for serving PyTorch models.
This class contains a list of SingleDataset generators. It handles shuffling and iterating over these datasets.
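The type hints in the signature suggest that the four list arguments are parallel: one entry per dataset, with the inner lists holding one entry per signal. The following is a minimal sketch of that layout only. The session IDs, the "labels" signal name, the file paths, and the use of None as a transform placeholder are all hypothetical; only the argument names and the 'markers' default come from the signature above.

```python
# Hypothetical session IDs; replace with your own.
ids_list = ["session_a", "session_b"]

# "markers" matches the default input_type; "labels" is an assumed signal name.
signals = ["markers", "labels"]

# One inner list per dataset, one entry per signal.
signals_list = [list(signals) for _ in ids_list]

# Assumption: None stands in for "no transform" on each signal.
transforms_list = [[None for _ in signals] for _ in ids_list]

# Hypothetical file layout: one file per (session, signal) pair.
paths_list = [
    [f"/data/{sess}/{sig}.csv" for sig in signals] for sess in ids_list
]

# The outer lists should all have one entry per dataset.
assert len(signals_list) == len(transforms_list) == len(paths_list) == len(ids_list)

# With daart installed and real files in place, the call would presumably look like:
# generator = DataGenerator(ids_list, signals_list, transforms_list, paths_list, device="cpu")
```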
Methods Summary
- next_batch(dtype): Return next batch of data.
- reset_iterators(dtype): Reset iterators so that all data is available.
Methods Documentation
- next_batch(dtype: str) → tuple
Return next batch of data.
The data generator iterates randomly through datasets and trials. Once a dataset runs out of trials it is skipped.
- Parameters:
dtype (str) – data split to draw from: ‘train’ | ‘val’ | ‘test’
- Returns:
sample (dict): data batch with keys given by signals input to class
dataset (int): dataset from which data batch is drawn
- Return type:
tuple
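The interplay between next_batch and reset_iterators can be illustrated without daart installed. The sketch below uses ToyGenerator, a hypothetical stand-in that mimics only the behaviour described above (random choice among datasets, skipping exhausted ones, returning a (sample, dataset) tuple). It is not daart code, and the RuntimeError raised when every dataset is exhausted is an assumption; the real class may behave differently in that case.

```python
import random


class ToyGenerator:
    """Hypothetical stand-in for the documented interface; not daart code."""

    def __init__(self, batches_per_dataset, rng_seed=0):
        # batches_per_dataset maps a split name to the number of batches
        # held by each dataset, e.g. {"train": [2, 3]}.
        self._counts = batches_per_dataset
        self._rng = random.Random(rng_seed)
        self._remaining = {}
        for dtype in batches_per_dataset:
            self.reset_iterators(dtype)

    def reset_iterators(self, dtype):
        # Make every batch of every dataset available again.
        self._remaining[dtype] = list(self._counts[dtype])

    def next_batch(self, dtype):
        # Exhausted datasets are skipped: only those with batches left are candidates.
        live = [i for i, n in enumerate(self._remaining[dtype]) if n > 0]
        if not live:
            # Assumption: the real class may signal exhaustion differently.
            raise RuntimeError(f"no {dtype} batches left; call reset_iterators")
        dataset = self._rng.choice(live)
        self._remaining[dtype][dataset] -= 1
        # Placeholder payload keyed by signal name.
        sample = {"markers": [self._remaining[dtype][dataset]]}
        return sample, dataset


def run_epoch(gen, dtype, n_batches):
    """Reset the iterators, draw n_batches batches, and record their source datasets."""
    gen.reset_iterators(dtype)
    drawn_from = []
    for _ in range(n_batches):
        sample, dataset = gen.next_batch(dtype)
        drawn_from.append(dataset)
    return drawn_from


gen = ToyGenerator({"train": [2, 3]})
order = run_epoch(gen, "train", n_batches=5)
print(sorted(order))  # [0, 0, 1, 1, 1]
```

The draw order depends on the seed, but each dataset contributes exactly as many batches as it holds, and calling reset_iterators at the start of each epoch makes all of them available again.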