daart Modules

daart.callbacks Module

Callback classes to control training.

Classes

BaseCallback()

Abstract base class for callbacks.

EarlyStopping([patience, delta])

Stop training when a monitored quantity has stopped improving.
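The following is a minimal sketch of patience/delta logic, not the daart implementation. It assumes the monitored quantity is a validation loss to be minimized, that `delta` is the smallest decrease counted as an improvement, and that `patience` is the number of non-improving epochs tolerated. The class `EarlyStoppingSketch` and its `step` method are hypothetical.

```python
class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping on a loss to be minimized."""

    def __init__(self, patience=10, delta=0.0):
        self.patience = patience  # epochs allowed without improvement
        self.delta = delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        if val_loss < self.best - self.delta:
            # improvement: record the new best value and reset the counter
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop


stopper = EarlyStoppingSketch(patience=2, delta=0.01)
for loss in [1.0, 0.8, 0.795, 0.79]:
    if stopper.step(loss):
        break  # 0.795 and 0.79 are both within delta of 0.8, so training stops
```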

AnnealHparam(hparams, key, epoch_start, ...)

Linearly increase value in an hparam dict.
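A sketch of a linear annealing schedule, assuming the value ramps from `val_start` to `val_end` between `epoch_start` and a hypothetical `epoch_end`. The function `anneal_value` and the hparam key `lambda_weak` are illustrative only; the remaining AnnealHparam arguments are elided in the signature above and may work differently.

```python
def anneal_value(epoch, epoch_start, epoch_end, val_start=0.0, val_end=1.0):
    """Hypothetical linear schedule: val_start until epoch_start, val_end from epoch_end on."""
    if epoch <= epoch_start:
        return val_start
    if epoch >= epoch_end:
        return val_end
    frac = (epoch - epoch_start) / (epoch_end - epoch_start)
    return val_start + frac * (val_end - val_start)


hparams = {"lambda_weak": 0.0}  # hypothetical key, for illustration only
for epoch in range(12):
    # overwrite the hparam entry in place at the start of each epoch
    hparams["lambda_weak"] = anneal_value(epoch, epoch_start=2, epoch_end=10)
```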

PseudoLabels([prob_threshold, epoch_start])

Implement PseudoLabels algorithm.
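One common form of pseudo-labeling, sketched under assumptions: the model outputs per-frame class probabilities, frames whose maximum probability reaches `prob_threshold` receive the argmax class as a label, and all other frames get a hypothetical `ignore_label` of -1 so a loss can skip them. The helper `pseudo_labels` is not part of daart.

```python
import numpy as np


def pseudo_labels(probs, prob_threshold=0.95, ignore_label=-1):
    """Hypothetical sketch: confident argmax predictions become training labels."""
    probs = np.asarray(probs)
    labels = np.argmax(probs, axis=1)
    confident = np.max(probs, axis=1) >= prob_threshold
    # frames below the threshold are marked so the loss can ignore them
    return np.where(confident, labels, ignore_label)


probs = np.array([
    [0.97, 0.02, 0.01],
    [0.40, 0.35, 0.25],
    [0.01, 0.01, 0.98],
])
labels = pseudo_labels(probs)  # -> [0, -1, 2]
```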

UPS([prob_threshold, variance_threshold, ...])

Implement uncertainty-aware pseudo-labels algorithm.
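A sketch of uncertainty-aware selection, assuming several stochastic forward passes (for example with dropout left on) provide a stack of probability estimates. A frame keeps its predicted class only if the mean probability of that class reaches `prob_threshold` and the variance of that probability across passes is at most `variance_threshold`. The function `ups_labels` and the -1 ignore label are hypothetical.

```python
import numpy as np


def ups_labels(prob_samples, prob_threshold=0.95, variance_threshold=0.05, ignore_label=-1):
    """Hypothetical uncertainty-aware pseudo-labeling.

    prob_samples: array of shape (n_passes, n_frames, n_classes).
    """
    prob_samples = np.asarray(prob_samples, dtype=float)
    mean_probs = prob_samples.mean(axis=0)  # (n_frames, n_classes)
    labels = mean_probs.argmax(axis=1)  # predicted class per frame
    frames = np.arange(labels.shape[0])
    # variance across passes of the probability assigned to each frame's predicted class
    variances = prob_samples[:, frames, labels].var(axis=0)
    confident = mean_probs[frames, labels] >= prob_threshold
    certain = variances <= variance_threshold
    return np.where(confident & certain, labels, ignore_label)


prob_samples = np.array([
    [[0.99, 0.01], [0.99, 0.01]],  # pass 1
    [[0.97, 0.03], [0.55, 0.45]],  # pass 2
])
# frame 0 is confident and stable; frame 1 is confident on average but unstable
labels_ups = ups_labels(prob_samples, prob_threshold=0.7, variance_threshold=0.01)
```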

daart.data Module

Classes for splitting and serving data to models.

The data generator classes contained in this module inherit from the torch.utils.data.Dataset class. The user-facing class is the DataGenerator, which can manage one or more datasets. Each dataset is composed of trials, which are split into training, validation, and testing trials using split_trials(). The default data generator can handle the following data types:

  • markers: pose-estimation markers, e.g. from DeepLabCut (DLC) or Deep Graph Pose (DGP)

  • labels_strong: discrete behavioral labels

  • labels_weak: noisy discrete behavioral labels

Functions

split_trials(n_trials[, rng_seed, train_tr, ...])

Split trials into train/val/test blocks.
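A sketch of block-wise splitting, assuming each block of consecutive trials contains `train_tr` training, `val_tr` validation, and `test_tr` test trials whose order within the block is shuffled with `rng_seed`. The function `split_trials_sketch` and this block layout are assumptions; the elided arguments of split_trials may change the actual behavior.

```python
import numpy as np


def split_trials_sketch(n_trials, rng_seed=0, train_tr=8, val_tr=1, test_tr=1):
    """Hypothetical: assign consecutive blocks of trials to train/val/test."""
    pattern = ["train"] * train_tr + ["val"] * val_tr + ["test"] * test_tr
    block_len = len(pattern)
    n_blocks = int(np.ceil(n_trials / block_len))
    rng = np.random.default_rng(rng_seed)
    splits = {"train": [], "val": [], "test": []}
    for b in range(n_blocks):
        # shuffle which positions inside this block go to each split
        order = rng.permutation(pattern)
        for offset, name in enumerate(order):
            idx = b * block_len + offset
            if idx < n_trials:  # the last block may be partial
                splits[str(name)].append(idx)
    return splits


splits = split_trials_sketch(20, rng_seed=0)
```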

compute_sequences(data, sequence_length[, ...])

Compute sequences of temporally contiguous data points.
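A sketch assuming "sequences" are non-overlapping, contiguous chunks along the first (time) axis, with a final chunk that may be shorter. `compute_sequences_sketch` is hypothetical and ignores the elided optional arguments.

```python
import numpy as np


def compute_sequences_sketch(data, sequence_length):
    """Hypothetical: split a (T, ...) array into non-overlapping contiguous chunks."""
    data = np.asarray(data)
    n_seqs = int(np.ceil(data.shape[0] / sequence_length))
    # the final chunk may be shorter than sequence_length
    return [data[i * sequence_length:(i + 1) * sequence_length] for i in range(n_seqs)]


seqs = compute_sequences_sketch(np.arange(10).reshape(10, 1), sequence_length=4)
```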

compute_sequence_pad(hparams)

Compute padding needed to account for convolutions.

load_marker_csv(filepath)

Load markers from csv file assuming DLC format.
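Single-animal DeepLabCut csv files typically carry three header rows (scorer, bodyparts, coords) followed by one row per frame: a frame index, then an x, y, likelihood triplet for each body part. A minimal NumPy sketch under that assumption; `load_dlc_csv_sketch` and its return layout are hypothetical and may not match the daart loader.

```python
import os
import tempfile

import numpy as np


def load_dlc_csv_sketch(filepath):
    """Hypothetical DLC-style csv loader returning (xs, ys, likelihoods)."""
    # skip the scorer/bodyparts/coords header rows
    data = np.atleast_2d(np.loadtxt(filepath, delimiter=",", skiprows=3))
    xs = data[:, 1::3]  # x coordinates, one column per body part
    ys = data[:, 2::3]  # y coordinates
    likelihoods = data[:, 3::3]  # detection confidence
    return xs, ys, likelihoods


# usage: write a tiny two-frame, two-body-part file and load it back
content = (
    "scorer,net,net,net,net,net,net\n"
    "bodyparts,nose,nose,nose,tail,tail,tail\n"
    "coords,x,y,likelihood,x,y,likelihood\n"
    "0,1.0,2.0,0.9,3.0,4.0,0.8\n"
    "1,1.5,2.5,0.95,3.5,4.5,0.7\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write(content)
xs, ys, likelihoods = load_dlc_csv_sketch(f.name)
os.remove(f.name)
```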

load_feature_csv(filepath)

Load markers from csv file assuming a specific column format (see the function docstring).

load_marker_h5(filepath)

Load markers from hdf5 file assuming DLC format.

load_label_csv(filepath)

Load labels from csv file assuming a standard format.

load_label_pkl(filepath)

Load labels from pkl file assuming a standard format.

Classes

SingleDataset(id, signals, transforms, paths)

Dataset class for a single dataset.

DataGenerator(ids_list, signals_list, ...[, ...])

Dataset generator for serving PyTorch models.

daart.eval Module

Evaluation functions for the daart package.

Functions

get_precision_recall(true_classes, pred_classes)

Compute precision and recall for classifier.
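A sketch of per-class precision (correct predictions of a class divided by all predictions of that class) and recall (correct predictions divided by all true instances), assuming integer class labels in `0..n_classes-1`. `precision_recall_sketch` is hypothetical, and returning 0 for empty denominators is an assumed convention.

```python
import numpy as np


def precision_recall_sketch(true_classes, pred_classes, n_classes):
    """Hypothetical per-class precision and recall from integer label arrays."""
    true_classes = np.asarray(true_classes)
    pred_classes = np.asarray(pred_classes)
    precision = np.zeros(n_classes)
    recall = np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((pred_classes == c) & (true_classes == c))
        n_pred = np.sum(pred_classes == c)
        n_true = np.sum(true_classes == c)
        # define precision/recall as 0 when the denominator is empty
        precision[c] = tp / n_pred if n_pred > 0 else 0.0
        recall[c] = tp / n_true if n_true > 0 else 0.0
    return precision, recall


p, r = precision_recall_sketch([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], n_classes=3)
```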

int_over_union(array1, array2)

Compute intersection over union for two 1D arrays.
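For 1D label arrays, one natural reading is a per-value IoU: for each value, the number of indices where both arrays hold it, divided by the number of indices where either does. A sketch under that assumption; `iou_sketch` is hypothetical.

```python
import numpy as np


def iou_sketch(array1, array2):
    """Hypothetical per-value IoU between two equal-length integer arrays."""
    array1, array2 = np.asarray(array1), np.asarray(array2)
    iou = {}
    for val in np.union1d(array1, array2):
        intersection = np.sum((array1 == val) & (array2 == val))
        union = np.sum((array1 == val) | (array2 == val))
        iou[int(val)] = intersection / union
    return iou


res = iou_sketch([0, 0, 1, 1], [0, 1, 1, 1])
```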

run_lengths(array)

Compute distribution of run lengths for an array with integer entries.
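A sketch assuming the "distribution" is a mapping from each value to the lengths of its maximal runs of consecutive entries; `run_lengths_sketch` is hypothetical.

```python
import numpy as np


def run_lengths_sketch(array):
    """Hypothetical: map each integer value to the lengths of its consecutive runs."""
    array = np.asarray(array)
    runs = {int(v): [] for v in np.unique(array)}
    if array.size == 0:
        return runs
    # indices where the value changes mark the start of a new run
    change = np.flatnonzero(np.diff(array)) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [array.size]))
    for s, e in zip(starts, ends):
        runs[int(array[s])].append(int(e - s))
    return runs


runs = run_lengths_sketch([0, 0, 1, 1, 1, 0, 2, 2])  # -> {0: [2, 1], 1: [3], 2: [2]}
```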

plot_training_curves(metrics_file[, dtype, ...])

Create training plots for each term in the objective function.

load_metrics_csv_as_df(metric_file, metrics_list)

Load metrics csv file and return as a pandas dataframe for easy plotting.

daart.io Module

File IO for daart package.

Functions

get_subdirs(path)

Get all first-level subdirectories in a given path (no recursion).
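A minimal standard-library sketch; `get_subdirs_sketch` is hypothetical, and sorting the result is an added assumption.

```python
import os
import tempfile


def get_subdirs_sketch(path):
    """Hypothetical: names of first-level subdirectories of `path` (no recursion)."""
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "expt_a", "nested"))  # nested dir is not listed
    open(os.path.join(root, "notes.txt"), "w").close()  # plain files are skipped
    subdirs = get_subdirs_sketch(root)
```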

get_expt_dir(base_dir, expt_ids)

Construct experiment directory given base directory and list of experiment ids.

get_model_dir(base_dir, model_params)

Helper function to construct model directory from model param dict.

get_model_params(hparams)

Return dict containing all params considered essential for defining a model of that type.

find_experiment(hparams[, verbose, ...])

Search testtube versions to find if experiment with the same hyperparameters has been fit.

read_expt_info_from_csv(expt_file)

Read csv file that contains expt id info.

export_expt_info_to_csv(expt_dir, ids_list)

Export list of expt ids to csv file.

export_hparams(hparams[, filename])

Export hyperparameter dictionary as a yaml file.

make_dir_if_not_exists(save_file)

Utility function for creating necessary directories for a specified filename.
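A sketch using `os.makedirs` with `exist_ok=True`, assuming `save_file` is a file path whose parent directories should exist before writing; `make_dir_if_not_exists_sketch` is hypothetical.

```python
import os
import tempfile


def make_dir_if_not_exists_sketch(save_file):
    """Hypothetical: create the parent directories of `save_file` if they are missing."""
    save_dir = os.path.dirname(save_file)
    if save_dir:  # a bare filename has no directory component to create
        os.makedirs(save_dir, exist_ok=True)


with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "model", "version_0", "metrics.csv")
    make_dir_if_not_exists_sketch(target)
    created = os.path.isdir(os.path.dirname(target))
    make_dir_if_not_exists_sketch(target)  # calling again is a no-op
```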

daart.losses Module

Custom losses for PyTorch models.

Functions

kl_div_to_std_normal(mu, logvar)

Compute element-wise KL(q(z) || N(0, 1)) where q(z) is a normal parameterized by mu, logvar.
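For q(z) = N(mu, sigma^2) with logvar = log sigma^2, the KL divergence to a standard normal has the closed form 0.5 * (mu^2 + exp(logvar) - 1 - logvar) per element. The NumPy sketch below evaluates it for illustration; since the daart function is listed as a PyTorch loss, a tensor version would use `torch.exp` in place of `np.exp`. Whether daart reduces over elements is not stated here, so the sketch returns the element-wise values.

```python
import numpy as np


def kl_div_to_std_normal_sketch(mu, logvar):
    """Element-wise KL(N(mu, exp(logvar)) || N(0, 1)), written with NumPy."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    return 0.5 * (mu ** 2 + np.exp(logvar) - 1.0 - logvar)


# zero mean and unit variance give zero divergence; shifting the mean by 1 gives 0.5
kl = kl_div_to_std_normal_sketch([0.0, 1.0], [0.0, 0.0])
```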

daart.testtube Module

Test-tube helper functions for use in fitting scripts.

Functions

get_all_params()

print_hparams(hparams)

Nicely formatted hparams string.

create_tt_experiment(hparams)

Create test-tube experiment for organizing model fits.

clean_tt_dir(hparams)

Delete all (unnecessary) subdirectories in the model directory (created by test-tube).

daart.train Module

Helper functions for model training.

Classes

Logger([n_datasets, save_path])

Base class for logging loss metrics.

Trainer([learning_rate, l2_reg, min_epochs, ...])

daart.transforms Module

Transform classes to process data.

Data generator objects can apply these transforms to data upon loading.

Classes

Compose(transforms)

Composes several transforms together.
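A sketch of the usual composition pattern, assuming each transform is a callable that takes a signal and returns the transformed signal; `ComposeSketch` is hypothetical.

```python
class ComposeSketch:
    """Hypothetical: apply a sequence of callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, signal):
        # each transform receives the output of the previous one
        for transform in self.transforms:
            signal = transform(signal)
        return signal


pipeline = ComposeSketch([lambda x: x + 1, lambda x: x * 2])
result = pipeline(3)  # (3 + 1) * 2 -> 8
```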

Transform()

Abstract base class for transforms.

BlockShuffle(rng_seed)

Shuffle blocks of contiguous discrete states within each trial.
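A sketch assuming a "block" is a maximal run of one discrete state: the runs are split apart and concatenated in a random order, so per-state frame counts are preserved while temporal order is not. `block_shuffle_sketch` is hypothetical and handles a single trial.

```python
import numpy as np


def block_shuffle_sketch(states, rng_seed=0):
    """Hypothetical: permute the order of runs of identical discrete states."""
    states = np.asarray(states)
    change = np.flatnonzero(np.diff(states)) + 1
    blocks = np.split(states, change)  # list of runs of constant value
    rng = np.random.default_rng(rng_seed)
    order = rng.permutation(len(blocks))
    return np.concatenate([blocks[i] for i in order])


states = np.array([0, 0, 1, 1, 1, 2, 2, 0])
shuffled = block_shuffle_sketch(states, rng_seed=3)
```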

MakeOneHot([n_classes])

Turn a categorical vector into a one-hot vector.
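A NumPy sketch, assuming labels are integers in `0..n_classes-1` and that `n_classes` defaults to the largest label plus one; `make_one_hot_sketch` is hypothetical.

```python
import numpy as np


def make_one_hot_sketch(labels, n_classes=None):
    """Hypothetical: (T,) integer vector -> (T, n_classes) one-hot array."""
    labels = np.asarray(labels, dtype=int)
    if n_classes is None:
        n_classes = int(labels.max()) + 1
    one_hot = np.zeros((labels.shape[0], n_classes))
    # set a single 1 per row, in the column given by that row's label
    one_hot[np.arange(labels.shape[0]), labels] = 1
    return one_hot


one_hot = make_one_hot_sketch([2, 0, 1], n_classes=3)
```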

MotionEnergy()

Compute motion energy across batch dimension.
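A sketch assuming motion energy is the absolute frame-to-frame difference along the first axis, with a row of zeros prepended so the output keeps the input shape; `motion_energy_sketch` and the zero-padding choice are assumptions.

```python
import numpy as np


def motion_energy_sketch(signal):
    """Hypothetical: |x[t] - x[t-1]| per channel, with a zero first row."""
    signal = np.asarray(signal, dtype=float)
    diffs = np.abs(np.diff(signal, axis=0))
    return np.vstack([np.zeros((1, signal.shape[1])), diffs])


me = motion_energy_sketch([[0, 1], [2, 1], [1, 4]])
```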

Unitize()

Place each channel (mostly) in [0, 1].
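One way to get each channel "mostly" into [0, 1] is to map a low and a high percentile of the channel to 0 and 1, so outliers land slightly outside the interval. The percentile choice (5 and 95) and `unitize_sketch` are assumptions, not documented daart behavior.

```python
import numpy as np


def unitize_sketch(signal, low=5, high=95):
    """Hypothetical: map each channel's [low, high] percentiles to [0, 1]."""
    signal = np.asarray(signal, dtype=float)
    lo = np.percentile(signal, low, axis=0)
    hi = np.percentile(signal, high, axis=0)
    scale = np.where(hi > lo, hi - lo, 1.0)  # avoid dividing by zero for flat channels
    return (signal - lo) / scale


out = unitize_sketch(np.arange(101, dtype=float).reshape(-1, 1))
```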

ZScore()

z-score channel activity.
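A per-channel z-score sketch, assuming channels are columns and statistics are computed over the first (time) axis; `zscore_sketch` is hypothetical, and leaving zero-variance channels unscaled is an added choice.

```python
import numpy as np


def zscore_sketch(signal):
    """Hypothetical: subtract each channel's mean and divide by its standard deviation."""
    signal = np.asarray(signal, dtype=float)
    mean = signal.mean(axis=0)
    std = signal.std(axis=0)
    std = np.where(std > 0, std, 1.0)  # constant channels are centered but not scaled
    return (signal - mean) / std


z = zscore_sketch([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
```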

daart.utils Module

Utility functions for daart package.

Functions

build_data_generator(hparams)

Helper function to build a data generator from hparam dict.

collect_callbacks(hparams)

Helper function to build a list of callbacks from hparam dict.