daart Modules
daart.callbacks Module
Callback classes to control training.
Classes
Abstract base class for callbacks.
Stop training when a monitored quantity has stopped improving.
Linearly increase a value in an hparam dict.
Implement the PseudoLabels algorithm.
Implement the uncertainty-aware pseudo-labels algorithm.
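The early-stopping and linear-annealing callbacks listed above can be illustrated with a minimal sketch. The names `EarlyStoppingSketch` and `linear_anneal` are hypothetical and are not the daart API. The sketch assumes the monitored quantity is a loss to be minimized, and that annealing is linear in the epoch index and then held at the final value.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping for a quantity to minimize."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = math.inf
        self.counter = 0
        self.should_stop = False

    def on_epoch_end(self, value):
        """Record this epoch's monitored value; return True once training should stop."""
        if value < self.best - self.min_delta:
            # improvement: remember the new best value and reset the patience counter
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop


def linear_anneal(start, end, epoch, n_epochs):
    """Linearly interpolate from `start` to `end` over `n_epochs`, then hold at `end`."""
    frac = min(max(epoch / n_epochs, 0.0), 1.0)
    return start + frac * (end - start)
```

For example, a hypothetical loss weight `hparams["lambda_weak"]` could be set to `linear_anneal(0.0, 1.0, epoch, 10)` at the start of each epoch.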
daart.data Module
Classes for splitting and serving data to models.
The data generator classes contained in this module inherit from the
torch.utils.data.Dataset class. The user-facing class is
DataGenerator, which can manage one or more datasets. Each dataset is composed
of trials, which are split into training, validation, and testing trials by the
split_trials() function. The default data generator can handle the following data types:
markers: pose-estimation markers, e.g. from DLC or DGP
labels_strong: discrete behavioral labels
labels_weak: noisy discrete behavioral labels
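The following is a minimal sketch of block-wise trial splitting. It assumes that trials are grouped into consecutive blocks and that each block contributes a fixed number of train, validation, and test trials. `split_trials_sketch` and its default block sizes are hypothetical and are not necessarily what `split_trials()` does.

```python
import random


def split_trials_sketch(n_trials, train_tr=8, val_tr=1, test_tr=1, seed=0):
    """Assign trial indices to train/val/test in repeating blocks (hypothetical)."""
    pattern = ["train"] * train_tr + ["val"] * val_tr + ["test"] * test_tr
    rng = random.Random(seed)  # seeded so the split is reproducible
    splits = {"train": [], "val": [], "test": []}
    for start in range(0, n_trials, len(pattern)):
        labels = pattern[:]
        # vary where the val/test trials fall inside each block
        rng.shuffle(labels)
        for offset, label in enumerate(labels):
            trial = start + offset
            if trial < n_trials:  # the last block may be partial
                splits[label].append(trial)
    return splits
```

Splitting by contiguous blocks rather than by individual time points keeps validation and test trials from being interleaved frame by frame with training data.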
Functions
Split trials into train/val/test blocks.
Compute sequences of temporally contiguous data points.
Compute padding needed to account for convolutions.
Load markers from csv file assuming DLC format.
Load markers from csv file assuming an alternative format (see the function's full docstring).
Load markers from hdf5 file assuming DLC format.
Load labels from csv file assuming a standard format.
Load labels from pkl file assuming a standard format.
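Chunking into contiguous sequences and computing convolutional padding can be sketched as follows. The sketch assumes a stack of non-causal 1D convolutions with odd kernel sizes. Each such layer needs `(k - 1) * d // 2` frames of context on each side, where `k` is the kernel size and `d` is the dilation. Both function names are hypothetical.

```python
def contiguous_sequences(n_timepoints, seq_len, pad=0):
    """Return (start, stop) index pairs of contiguous chunks of length `seq_len`,
    each extended by `pad` frames of context per side and clipped to the data."""
    seqs = []
    for start in range(0, n_timepoints, seq_len):
        stop = min(start + seq_len, n_timepoints)
        seqs.append((max(start - pad, 0), min(stop + pad, n_timepoints)))
    return seqs


def conv_pad(kernel_sizes, dilations):
    """Frames of context needed on each side by a stack of non-causal 1D convs."""
    return sum((k - 1) * d // 2 for k, d in zip(kernel_sizes, dilations))
```

For example, three layers with kernel size 3 and dilations 1, 2, and 4 need `conv_pad([3, 3, 3], [1, 2, 4])`, which equals 7 frames per side. In this sketch, chunks at the trial edges get less context because of clipping. Zero-padding the edges is a common alternative.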
Classes
Dataset class for a single dataset.
Dataset generator for serving PyTorch models.
daart.eval Module
Evaluation functions for the daart package.
Functions
Compute precision and recall for a classifier.
Compute intersection over union for two 1D arrays.
Compute distribution of run lengths for an array with integer entries.
Create training plots for each term in the objective function.
Load metrics csv file and return as a pandas dataframe for easy plotting.
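The following minimal sketch shows the three metrics described above in plain Python. The function names are hypothetical. Several choices are assumptions: classes with no predictions or no instances get a score of 0.0, IoU treats nonzero entries as "on", and run lengths are grouped by integer value.

```python
from collections import defaultdict


def precision_recall(true, pred, n_classes):
    """Per-class precision and recall from integer label sequences."""
    tp = [0] * n_classes
    fp = [0] * n_classes
    fn = [0] * n_classes
    for t, p in zip(true, pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # missed an instance of class t
    precision = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in range(n_classes)]
    recall = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in range(n_classes)]
    return precision, recall


def iou(a, b):
    """Intersection over union of two binary 1D sequences."""
    inter = sum(1 for x, y in zip(a, b) if x and y)
    union = sum(1 for x, y in zip(a, b) if x or y)
    return inter / union if union else 0.0


def run_lengths(arr):
    """Map each integer value to the lengths of its consecutive runs, in order."""
    runs = defaultdict(list)
    if not arr:
        return {}
    current, length = arr[0], 1
    for x in arr[1:]:
        if x == current:
            length += 1
        else:
            runs[current].append(length)
            current, length = x, 1
    runs[current].append(length)
    return dict(runs)
```

For example, `run_lengths([0, 0, 1, 1, 1, 0, 2])` groups the runs by value as `{0: [2, 1], 1: [3], 2: [1]}`.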
daart.io Module
File IO for daart package.
Functions
Get all first-level subdirectories in a given path (no recursion).
Construct experiment directory given base directory and list of experiment ids.
Helper function to construct model directory from model param dict.
Return dict containing all params considered essential for defining a model of a given type.
Search test-tube versions to find whether an experiment with the same hyperparameters has already been fit.
Read csv file that contains expt id info.
Export list of expt ids to csv file.
Export hyperparameter dictionary as a yaml file.
Utility function for creating necessary directories for a specified filename.
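Several of these IO helpers map directly onto the Python standard library. The sketch below is a minimal version built on `os.scandir`, `os.makedirs`, and the `csv` module. Its function names are hypothetical, and so is the csv layout, which assumes a single `expt_id` column with a header row.

```python
import csv
import os


def get_subdirs_sketch(path):
    """Sorted names of first-level subdirectories in `path` (no recursion)."""
    return sorted(entry.name for entry in os.scandir(path) if entry.is_dir())


def make_dir_if_not_exists_sketch(filename):
    """Create the parent directories of `filename` if they do not exist yet."""
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)


def export_expt_ids_sketch(expt_ids, filename):
    """Write one expt id per row under an 'expt_id' header (hypothetical layout)."""
    make_dir_if_not_exists_sketch(filename)
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["expt_id"])
        writer.writerows([eid] for eid in expt_ids)


def read_expt_ids_sketch(filename):
    """Read back the expt ids written by `export_expt_ids_sketch`."""
    with open(filename, newline="") as f:
        return [row["expt_id"] for row in csv.DictReader(f)]
```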
daart.losses Module
Custom losses for PyTorch models.
Functions
Compute element-wise KL(q(z) || N(0, 1)) where q(z) is a normal distribution parameterized by mu and logvar.
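For a normal q(z) with mean mu and log-variance logvar, the KL divergence to a standard normal has a closed form per dimension: KL = 0.5 * (mu^2 + exp(logvar) - 1 - logvar). The plain-Python sketch below evaluates that formula element-wise. The function name is hypothetical.

```python
import math


def kl_div_to_std_normal(mu, logvar):
    """Element-wise KL(N(mu, exp(logvar)) || N(0, 1)) for equal-length sequences."""
    return [0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mu, logvar)]
```

With PyTorch tensors, the same expression reads `0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar)`. The result is zero exactly when mu = 0 and logvar = 0, and positive otherwise.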
daart.testtube Module
Test-tube helper functions for use in fitting scripts.
Functions
Nicely formatted hparams string.
Create test-tube experiment for organizing model fits.
Delete all (unnecessary) subdirectories in the model directory (created by test-tube).
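A "nicely formatted hparams string" could be as simple as sorted, aligned `key : value` lines. The sketch below is one hypothetical way to produce such a string and does not reproduce daart's actual output format.

```python
def format_hparams_sketch(hparams):
    """Return one aligned 'key : value' line per hparam, sorted by key."""
    if not hparams:
        return ""
    width = max(len(str(k)) for k in hparams)
    items = sorted(hparams.items(), key=lambda kv: str(kv[0]))
    return "\n".join(f"{str(k).ljust(width)} : {v}" for k, v in items)
```

Sorting the keys makes the string stable across runs, which is convenient for printing or diffing hyperparameters between fits.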
daart.train Module
Helper functions for model training.
Classes
Base class for logging loss metrics.
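A loss-metric logger typically accumulates per-batch values and reports batch-size-weighted averages at the end of an epoch. The minimal sketch below follows that pattern. The class name and its `update`/`averages`/`reset` methods are hypothetical.

```python
class LossLoggerSketch:
    """Hypothetical running-average logger for per-batch loss metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear accumulated metrics, e.g. at the start of each epoch."""
        self.sums = {}
        self.counts = {}

    def update(self, metrics, n=1):
        """Accumulate a dict of scalar metrics computed on a batch of size `n`."""
        for key, value in metrics.items():
            self.sums[key] = self.sums.get(key, 0.0) + value * n
            self.counts[key] = self.counts.get(key, 0) + n

    def averages(self):
        """Batch-size-weighted mean of each metric since the last reset."""
        return {k: self.sums[k] / self.counts[k] for k in self.sums}
```

Weighting by `n` keeps the epoch average correct when the final batch is smaller than the others.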
daart.transforms Module
Transform classes to process data.
Data generator objects can apply these transforms to data upon loading.
Classes
Composes several transforms together.
Abstract base class for transforms.
Shuffle blocks of contiguous discrete states within each trial.
Turn a categorical vector into a one-hot vector.
Compute motion energy across the batch dimension.
Place each channel (mostly) in [0, 1].
z-score channel activity.
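Several of the transforms above can be sketched in plain Python on data stored as a list of rows (time points by channels). All class names here are hypothetical. The sketch makes three assumptions. Motion energy is the absolute frame-to-frame difference, with a zero first row. The [0, 1] scaling uses each channel's min and max, whereas "(mostly)" in the original suggests a more robust variant. Constant channels are divided by 1.0 to avoid division by zero.

```python
import math


class ComposeSketch:
    """Apply a sequence of transforms in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


class ZScoreSketch:
    """Subtract each channel's mean and divide by its (population) std."""

    def __call__(self, data):
        cols = list(zip(*data))
        means = [sum(c) / len(c) for c in cols]
        # constant channels get std 1.0 so they map to zeros instead of dividing by 0
        stds = [math.sqrt(sum((x - m) ** 2 for x in c) / len(c)) or 1.0
                for c, m in zip(cols, means)]
        return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in data]


class MinMaxUnitizeSketch:
    """Rescale each channel to [0, 1] using its min and max."""

    def __call__(self, data):
        cols = list(zip(*data))
        lows = [min(c) for c in cols]
        spans = [(max(c) - lo) or 1.0 for c, lo in zip(cols, lows)]
        return [[(x - lo) / s for x, lo, s in zip(row, lows, spans)] for row in data]


class MotionEnergySketch:
    """Absolute difference between consecutive time points; first row is zeros."""

    def __call__(self, data):
        zeros = [0.0] * len(data[0])
        diffs = [[abs(b - a) for a, b in zip(prev, cur)]
                 for prev, cur in zip(data[:-1], data[1:])]
        return [zeros] + diffs


class MakeOneHotSketch:
    """Turn a vector of integer class labels into one-hot rows."""

    def __init__(self, n_classes):
        self.n_classes = n_classes

    def __call__(self, labels):
        return [[1 if c == lab else 0 for c in range(self.n_classes)] for lab in labels]
```

For example, `ComposeSketch([MotionEnergySketch(), MinMaxUnitizeSketch()])` first converts markers to motion energy and then rescales each resulting channel to [0, 1].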
daart.utils Module
Utility functions for daart package.
Functions
Helper function to build a data generator from hparam dict.
Helper function to build a list of callbacks from hparam dict.
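Building objects from an hparam dict is often implemented with a registry that maps configuration flags to builder functions. The sketch below shows that pattern for callbacks. Every flag name (`enable_early_stop`, `early_stop_patience`, `enable_anneal`, `anneal_key`, `anneal_epochs`) and both stub classes are hypothetical and are not daart's actual hparam keys or classes.

```python
class EarlyStopStub:
    """Placeholder standing in for an early-stopping callback."""

    def __init__(self, patience):
        self.patience = patience


class AnnealStub:
    """Placeholder standing in for an hparam-annealing callback."""

    def __init__(self, key, n_epochs):
        self.key = key
        self.n_epochs = n_epochs


# Hypothetical registry: hparam flag -> function that builds a callback from hparams.
CALLBACK_BUILDERS = {
    "enable_early_stop": lambda hp: EarlyStopStub(patience=hp.get("early_stop_patience", 5)),
    "enable_anneal": lambda hp: AnnealStub(key=hp["anneal_key"], n_epochs=hp["anneal_epochs"]),
}


def build_callbacks_sketch(hparams, builders=CALLBACK_BUILDERS):
    """Return one callback per truthy flag in `hparams`, in registry order."""
    return [build(hparams) for flag, build in builders.items() if hparams.get(flag)]
```

Keeping the flag-to-builder mapping in one dict means a new callback type only needs a new registry entry, not changes to the training loop.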