daart.models package

daart.models.base Module

Base models/modules in PyTorch.

Functions

reparameterize_gaussian(mu, logvar)

Sample from N(mu, var) using the reparameterization trick, where var = exp(logvar).
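
The following is a minimal sketch of the reparameterization trick that this function's name and arguments point to. It assumes `logvar` holds the elementwise log-variance and is based on the standard technique, not on the daart source. The noise is drawn independently of the parameters, so gradients flow through `mu` and `logvar`.

```python
import torch


def reparameterize_gaussian(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, exp(logvar)) in a way that is differentiable in mu and logvar."""
    std = torch.exp(0.5 * logvar)  # standard deviation recovered from log-variance
    eps = torch.randn_like(std)    # N(0, 1) noise, independent of the parameters
    return mu + eps * std          # z = mu + sigma * eps


mu = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)
z = reparameterize_gaussian(mu, logvar)
z.sum().backward()  # gradients reach both mu and logvar
```

Sampling `eps` outside the parameter path is what makes the sample usable inside a loss that is optimized by backpropagation, as in variational autoencoders.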

get_activation_func_from_str(activation_str)

Return the activation function corresponding to a string name.
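
One way such a lookup might be written is shown below as a hedged sketch. The accepted strings (`linear`, `relu`, `lrelu`, `sigmoid`, `tanh`) and the choice to return `None` for `linear` are assumptions for illustration. Check daart's implementation for the names it actually supports.

```python
import torch
from torch import nn

# Hypothetical name-to-class table; the strings daart actually accepts may differ.
_ACTIVATIONS = {
    "linear": None,  # assumption: no nonlinearity is represented as None
    "relu": nn.ReLU,
    "lrelu": nn.LeakyReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
}


def get_activation_func_from_str(activation_str):
    """Return a fresh activation module (or None for 'linear') given its name."""
    try:
        cls = _ACTIVATIONS[activation_str]
    except KeyError:
        raise ValueError(f'"{activation_str}" is an invalid activation function') from None
    return None if cls is None else cls()


act = get_activation_func_from_str("relu")
print(act(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])
```

Returning a new module on each call avoids accidentally sharing one activation instance across layers.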

Classes

BaseModel(*args, **kwargs)

Template for PyTorch models.
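
The following minimal sketch shows the template pattern. It assumes the base class subclasses `torch.nn.Module` and leaves `forward` to subclasses. The `num_parameters` helper and the `TinyModel` subclass are hypothetical and only illustrate how a concrete model would plug in.

```python
import torch
from torch import nn


class BaseModel(nn.Module):
    """Hypothetical template: shared utilities live here, subclasses implement forward()."""

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def num_parameters(self, trainable_only: bool = True) -> int:
        # hypothetical helper: count scalar parameters (only trainable ones by default)
        return sum(p.numel() for p in self.parameters()
                   if p.requires_grad or not trainable_only)


class TinyModel(BaseModel):
    """Hypothetical concrete model used only to exercise the template."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
print(model.num_parameters())  # 8 (3*2 weights + 2 biases)
```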

Segmenter(hparams)

General wrapper class for behavioral segmentation models.
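
The following is a hedged sketch of a segmentation wrapper. A backbone is selected by a string in `hparams` and followed by a linear classifier that emits class logits for every frame. A frame-wise cross-entropy loss is included as well. The hparams keys (`backbone`, `input_size`, `n_hid_units`, `output_size`) and the `loss` method are hypothetical, not daart's documented configuration.

```python
import torch
from torch import nn


class Segmenter(nn.Module):
    """Hypothetical wrapper: a backbone chosen by name plus a per-frame classifier."""

    def __init__(self, hparams: dict):
        super().__init__()
        self.hparams = hparams
        in_size, hid = hparams["input_size"], hparams["n_hid_units"]
        backbone = hparams["backbone"]
        if backbone == "mlp":
            # nn.Linear acts on the last (feature) dim, so each frame is mapped independently
            self.encoder = nn.Sequential(nn.Linear(in_size, hid), nn.ReLU())
        elif backbone == "gru":
            self.encoder = nn.GRU(in_size, hid, batch_first=True)
        else:
            raise ValueError(f'unsupported backbone "{backbone}"')
        self.classifier = nn.Linear(hid, hparams["output_size"])

    def forward(self, x):
        # x: (batch, time, features) -> logits: (batch, time, n_classes)
        features = self.encoder(x)
        if isinstance(features, tuple):  # recurrent layers also return hidden states
            features = features[0]
        return self.classifier(features)

    def loss(self, x, labels):
        # frame-wise cross-entropy; labels: (batch, time) integer class ids
        logits = self.forward(x)
        return nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten())


hparams = {"backbone": "gru", "input_size": 6, "n_hid_units": 16, "output_size": 4}
model = Segmenter(hparams)
x = torch.randn(2, 50, 6)
labels = torch.randint(0, 4, (2, 50))
print(model(x).shape)  # torch.Size([2, 50, 4])
```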

Ensembler(models)

Ensemble of models.
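
The following minimal sketch shows one common ensembling rule: averaging per-frame class probabilities across models. The unweighted softmax average and the `predict_probs` method name are assumptions. daart's `Ensembler` may combine models differently.

```python
import torch
from torch import nn


class Ensembler:
    """Sketch: average the per-frame class probabilities of several models."""

    def __init__(self, models):
        self.models = list(models)
        for m in self.models:
            m.eval()  # disable dropout / use running batch-norm stats during prediction

    @torch.no_grad()
    def predict_probs(self, x):
        # softmax each model's logits, then take the unweighted mean across models
        probs = [torch.softmax(m(x), dim=-1) for m in self.models]
        return torch.stack(probs, dim=0).mean(dim=0)


torch.manual_seed(0)
models = [nn.Linear(6, 4) for _ in range(3)]
ens = Ensembler(models)
x = torch.randn(2, 50, 6)
p = ens.predict_probs(x)
print(p.shape, p.sum(dim=-1).allclose(torch.ones(2, 50)))  # torch.Size([2, 50, 4]) True
```

Averaging probabilities rather than hard labels keeps each model's confidence in the combined prediction.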

daart.models.rnn Module

RNN models (LSTM/GRU) implemented in PyTorch.

Classes

RNN(hparams[, type, in_size, hid_size, out_size])

Recurrent neural network model (LSTM or GRU).
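
The following sketch shows an LSTM/GRU sequence model with a per-time-step linear read-out. For readability it takes explicit keyword arguments (`rnn_type`, `n_layers`, `bidirectional`, all hypothetical) rather than the `hparams`-based signature above.

```python
import torch
from torch import nn


class RNN(nn.Module):
    """Sketch: LSTM or GRU encoder followed by a linear read-out at every time step."""

    def __init__(self, in_size, hid_size, out_size, rnn_type="lstm",
                 n_layers=1, bidirectional=True):
        super().__init__()
        rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type]
        self.rnn = rnn_cls(in_size, hid_size, num_layers=n_layers,
                           batch_first=True, bidirectional=bidirectional)
        n_dirs = 2 if bidirectional else 1
        # forward and backward hidden states are concatenated along the feature dim
        self.readout = nn.Linear(n_dirs * hid_size, out_size)

    def forward(self, x):
        # x: (batch, time, in_size) -> (batch, time, out_size)
        h, _ = self.rnn(x)  # second output (hidden/cell states) is unused here
        return self.readout(h)


for kind in ("lstm", "gru"):
    y = RNN(in_size=6, hid_size=8, out_size=4, rnn_type=kind)(torch.randn(2, 30, 6))
    print(kind, tuple(y.shape))  # (2, 30, 4) for both
```

With `bidirectional=True` each output frame sees both past and future context, which suits offline segmentation. Setting it to `False` gives a causal model.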

daart.models.tcn Module

Temporal convolutional network (TCN) model implemented in PyTorch.

Classes

DilatedTCN(hparams[, type, in_size, ...])

Temporal Convolutional Model with dilated convolutions and no temporal downsampling.
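
The following is a hedged sketch of the idea in this docstring. It stacks 1D convolutions whose dilation doubles per layer, padded so that the output has as many frames as the input. There are no strides or pooling, hence no temporal downsampling. The layer count, kernel size, doubling schedule, and symmetric (non-causal) padding are assumptions, not daart's exact architecture.

```python
import torch
from torch import nn


class DilatedTCN(nn.Module):
    """Sketch: stack of dilated 1D convolutions that preserves sequence length."""

    def __init__(self, in_size, hid_size, out_size, n_layers=3, kernel_size=3):
        super().__init__()
        layers = []
        for i in range(n_layers):
            dilation = 2 ** i  # 1, 2, 4, ... grows the receptive field exponentially
            layers += [
                nn.Conv1d(in_size if i == 0 else hid_size, hid_size, kernel_size,
                          dilation=dilation,
                          padding=dilation * (kernel_size - 1) // 2),  # same length for odd kernels
                nn.ReLU(),
            ]
        self.net = nn.Sequential(*layers)
        self.readout = nn.Conv1d(hid_size, out_size, kernel_size=1)  # per-frame projection

    def forward(self, x):
        # x: (batch, time, features); Conv1d expects (batch, channels, time)
        y = self.readout(self.net(x.transpose(1, 2)))
        return y.transpose(1, 2)


model = DilatedTCN(in_size=6, hid_size=16, out_size=4)
x = torch.randn(2, 100, 6)
y = model(x)
print(tuple(y.shape))  # (2, 100, 4): same number of frames as the input
```

With kernel size k and dilations 1, 2, 4, each output frame depends on 1 + (k - 1)(1 + 2 + 4) input frames, which is 15 for k = 3.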

daart.models.temporalmlp Module

Temporal MLP model implemented in PyTorch.

Classes

TemporalMLP(hparams[, type, in_size, ...])

MLP network with an initial 1D convolution layer.
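
The following minimal sketch assumes that the initial convolution mixes a window of neighboring frames and that the MLP then acts on each frame independently. The window parameter `n_lags` and the layer layout are hypothetical.

```python
import torch
from torch import nn


class TemporalMLP(nn.Module):
    """Sketch: 1D convolution over time, then a per-frame MLP."""

    def __init__(self, in_size, hid_size, out_size, n_lags=2, n_hid_layers=1):
        super().__init__()
        # kernel of 2*n_lags+1 frames centered on each time step; padding keeps length
        self.conv = nn.Conv1d(in_size, hid_size, kernel_size=2 * n_lags + 1, padding=n_lags)
        mlp = []
        for _ in range(n_hid_layers):
            mlp += [nn.ReLU(), nn.Linear(hid_size, hid_size)]
        mlp += [nn.ReLU(), nn.Linear(hid_size, out_size)]
        self.mlp = nn.Sequential(*mlp)

    def forward(self, x):
        # x: (batch, time, features) -> (batch, time, out_size)
        h = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.mlp(h)  # nn.Linear layers act on each frame's feature vector


model = TemporalMLP(in_size=6, hid_size=32, out_size=4, n_lags=2)
x = torch.randn(2, 40, 6)
y = model(x)
print(tuple(y.shape))  # (2, 40, 4)
```

Only the convolution sees temporal context, so each output frame depends on the 2 * n_lags + 1 input frames around it.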