DataGenerator

class daart.data.DataGenerator(ids_list: List[str], signals_list: List[List[str]], transforms_list: List[list], paths_list: List[List[str | None]], device: str = 'cuda', as_numpy: bool = False, rng_seed: int = 0, trial_splits: str | dict | None = None, train_frac: float = 1.0, sequence_length: int = 500, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, sequence_pad: int = 0, input_type: str = 'markers')

Bases: object

Dataset generator that serves data to PyTorch models.

This class contains a list of SingleDataset generators. It handles shuffling and iterating over these datasets.
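
The nested list types in the signature suggest one entry per dataset, with the inner lists of signals_list, transforms_list, and paths_list holding one entry per signal. The following is a minimal sketch of assembling those arguments under that assumption. The session IDs, signal names, file paths, and the use of None for "no transform" are hypothetical placeholders, and check_aligned is an illustrative helper, not part of daart. The constructor call is left commented out because it requires daart to be installed and the files to exist.

```python
from typing import List, Optional

# Hypothetical session IDs; replace with real ones.
ids_list: List[str] = ["session_a", "session_b"]
# Assumed signal names, for illustration only.
signal_names = ["markers", "labels_strong"]

signals_list: List[List[str]] = []
transforms_list: List[list] = []
paths_list: List[List[Optional[str]]] = []

for sess_id in ids_list:
    signals_list.append(list(signal_names))
    # Assumption: None stands for "no transform" on that signal.
    transforms_list.append([None for _ in signal_names])
    # Hypothetical file layout: one file per signal per session.
    paths_list.append([f"/data/{sess_id}/{name}.csv" for name in signal_names])


def check_aligned(ids, signals, transforms, paths) -> bool:
    """Return True if every per-dataset list has matching lengths."""
    if not (len(ids) == len(signals) == len(transforms) == len(paths)):
        return False
    return all(
        len(s) == len(t) == len(p)
        for s, t, p in zip(signals, transforms, paths)
    )


# With daart installed and real files in place, the call might look like:
# from daart.data import DataGenerator
# data_gen = DataGenerator(
#     ids_list, signals_list, transforms_list, paths_list,
#     device="cpu", sequence_length=500, batch_size=8)
```

Checking alignment before constructing the generator is a cheap way to catch a missing path or transform entry early.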

Methods Summary

next_batch(dtype)

Return next batch of data.

reset_iterators(dtype)

Reset iterators so that all data is available.

Methods Documentation

next_batch(dtype: str) → tuple

Return next batch of data.

The data generator iterates randomly through datasets and trials. Once a dataset runs out of trials it is skipped.

Parameters:

dtype (str) – ‘train’ | ‘val’ | ‘test’

Returns:

  • sample (dict): data batch whose keys are the signal names passed to the class

  • dataset (int): index of the dataset from which the batch is drawn

Return type:

tuple
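
The sampling behavior described above (pick a dataset at random, skip any dataset that has run out of trials) can be illustrated with a toy stand-in. This is a sketch of the idea only, not daart's implementation: ToyGenerator, its arguments, the "trial" key, and the RuntimeError raised on exhaustion are all hypothetical.

```python
import random
from typing import Dict, List, Tuple


class ToyGenerator:
    """Hypothetical stand-in that mimics the documented sampling behavior."""

    def __init__(self, trials_per_dataset: List[int], rng_seed: int = 0):
        self.trials_per_dataset = trials_per_dataset
        self.rng = random.Random(rng_seed)
        self.remaining: List[List[int]] = []
        self.reset_iterators()

    def reset_iterators(self) -> None:
        # Make every trial available again, in a fresh random order.
        self.remaining = []
        for n_trials in self.trials_per_dataset:
            trials = list(range(n_trials))
            self.rng.shuffle(trials)
            self.remaining.append(trials)

    def next_batch(self) -> Tuple[Dict[str, int], int]:
        # Only datasets that still have trials are candidates.
        available = [i for i, t in enumerate(self.remaining) if t]
        if not available:
            raise RuntimeError("all datasets exhausted; reset the iterators")
        dataset = self.rng.choice(available)
        trial = self.remaining[dataset].pop()
        return {"trial": trial}, dataset


gen = ToyGenerator([3, 1])
drawn = [gen.next_batch() for _ in range(4)]
# Number of batches drawn from each dataset.
counts = [sum(1 for _, d in drawn if d == i) for i in range(2)]
```

Whatever the random order, four draws consume exactly three trials from the first toy dataset and one from the second, because an exhausted dataset is never selected again.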

reset_iterators(dtype: str) → None

Reset iterators so that all data is available.

Parameters:

dtype (str) – ‘train’ | ‘val’ | ‘test’ | ‘all’
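
A typical pattern is to reset the iterators at the start of each pass over the data and then draw batches one at a time. The sketch below shows that loop against any object exposing the two documented methods. run_epoch and StubGenerator are hypothetical helpers written for illustration, and n_batches is supplied by the caller because this page does not document how the generator reports its batch count.

```python
from typing import Any, List


def run_epoch(data_gen: Any, dtype: str, n_batches: int) -> List[int]:
    """Reset the iterators, then draw n_batches batches of the given dtype."""
    data_gen.reset_iterators(dtype)
    dataset_ids = []
    for _ in range(n_batches):
        sample, dataset = data_gen.next_batch(dtype)
        # A real loop would run the model on `sample` here.
        dataset_ids.append(dataset)
    return dataset_ids


class StubGenerator:
    """Hypothetical stub that records calls; it stands in for a real generator."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def reset_iterators(self, dtype: str) -> None:
        self.calls.append(("reset", dtype))

    def next_batch(self, dtype: str):
        self.calls.append(("next", dtype))
        return {"markers": None}, 0


stub = StubGenerator()
epoch_ids = run_epoch(stub, "train", n_batches=3)
```

Passing the generator in as an argument keeps the loop easy to test with a stub before real data is wired in.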