"""Base models/modules in PyTorch."""
import numpy as np
from scipy.special import softmax as scipy_softmax
from scipy.stats import entropy
import torch
from sklearn.metrics import accuracy_score, r2_score
from torch import nn, save
from daart import losses
# to ignore imports for sphix-autoapidoc
__all__ = [
'reparameterize_gaussian',
'get_activation_func_from_str',
'BaseModel',
'Segmenter',
'Ensembler',
]
[docs]def reparameterize_gaussian(mu, logvar):
"""Sample from N(mu, var)
Parameters
----------
mu : torch.Tensor
vector of mean parameters
logvar : torch.Tensor
vector of log variances; only mean field approximation is currently implemented
Returns
-------
torch.Tensor
sampled vector of shape (n_sequences, sequence_length, embedding_dim)
"""
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
[docs]def get_activation_func_from_str(activation_str):
if activation_str == 'linear':
activation_func = None
elif activation_str == 'relu':
activation_func = nn.ReLU()
elif activation_str == 'lrelu':
activation_func = nn.LeakyReLU(0.05)
elif activation_str == 'sigmoid':
activation_func = nn.Sigmoid()
elif activation_str == 'tanh':
activation_func = nn.Tanh()
else:
raise ValueError('"%s" is an invalid activation function' % activation_str)
return activation_func
[docs]class BaseModel(nn.Module):
"""Template for PyTorch models."""
def __init__(self, *args, **kwargs):
super().__init__()
def __str__(self):
"""Pretty print model architecture."""
raise NotImplementedError
[docs] def build_model(self):
"""Build model from hparams."""
raise NotImplementedError
@staticmethod
def _build_linear(global_layer_num, name, in_size, out_size):
linear_layer = nn.Sequential()
# add layer (cross entropy loss handles activation)
layer = nn.Linear(in_features=in_size, out_features=out_size)
layer_name = str('dense(%s)_layer_%02i' % (name, global_layer_num))
linear_layer.add_module(layer_name, layer)
return linear_layer
@staticmethod
def _build_mlp(
global_layer_num, in_size, hid_size, out_size, n_hid_layers=1, activation='lrelu'):
mlp = nn.Sequential()
in_size_ = in_size
# loop over hidden layers (0 layers <-> linear model)
for i_layer in range(n_hid_layers + 1):
if i_layer == n_hid_layers:
out_size_ = out_size
else:
out_size_ = hid_size
# add layer
layer = nn.Linear(in_features=in_size_, out_features=out_size_)
name = str('dense_layer_%02i' % global_layer_num)
mlp.add_module(name, layer)
# add activation
if i_layer == n_hid_layers:
# no activation for final layer
activation_func = None
else:
activation_func = get_activation_func_from_str(activation)
if activation_func:
name = '%s_%02i' % (activation, global_layer_num)
mlp.add_module(name, activation_func)
# update layer info
global_layer_num += 1
in_size_ = out_size_
return mlp
[docs] def forward(self, *args, **kwargs):
"""Push data through model."""
raise NotImplementedError
[docs] def training_step(self, *args, **kwargs):
"""Compute loss."""
raise NotImplementedError
[docs] def save(self, filepath):
"""Save model parameters."""
save(self.state_dict(), filepath)
[docs] def get_parameters(self):
"""Get all model parameters that have gradient updates turned on."""
return filter(lambda p: p.requires_grad, self.parameters())
[docs] def load_parameters_from_file(self, filepath):
"""Load parameters from .pt file."""
self.load_state_dict(torch.load(filepath, map_location=lambda storage, loc: storage))
[docs]class Segmenter(BaseModel):
"""General wrapper class for behavioral segmentation models."""
def __init__(self, hparams):
"""
Parameters
----------
hparams : dict
- backbone (str): 'temporal-mlp' | 'dtcn' | 'lstm' | 'gru'
- rng_seed_model (int): random seed to control weight initialization
- input_size (int): number of input channels
- output_size (int): number of classes
- task_size (int): number of regression tasks
- batch_pad (int): padding needed to account for convolutions
- n_hid_layers (int): hidden layers of network architecture
- n_hid_units (int): hidden units per layer
- n_lags (int): number of lags in input data to use for temporal convolution
- activation (str): 'linear' | 'relu' | 'lrelu' | 'sigmoid' | 'tanh'
- classifier_type (str): 'multiclass' | 'binary' | 'multibinary'
- class_weights (array-like): weights on classes
- variational (bool): whether or not model is variational
- lambda_weak (float): hyperparam on weak label classification
- lambda_strong (float): hyperparam on srong label classification
- lambda_pred (float): hyperparam on next step prediction
- lambda_task (float): hyperparam on task regression
"""
super().__init__()
self.hparams = hparams
# model dict will contain some or all of the following components:
# - encoder: inputs -> latents
# - classifier: latents -> hand labels
# - classifier_weak: latents -> heuristic/pseudo labels
# - task_predictor: latents -> tasks
# - decoder: latents[t] -> inputs[t]
# - predictor: latents[t] -> inputs[t+1]
self.model = nn.ModuleDict()
self.build_model()
# label loss based on cross entropy; don't compute gradient when target = 0
classifier_type = hparams.get('classifier_type', 'multiclass')
if classifier_type == 'multiclass':
# multiple mutually exclusive classes, 0 is backgroud class
ignore_index = hparams.get('ignore_class', 0)
elif classifier_type == 'binary':
# single class
ignore_index = -100 # pytorch default
elif classmethod == 'multibinary':
# multiple non-mutually exclusive classes (each a binary classification)
raise NotImplementedError
else:
raise NotImplementedError("classifier type must be 'multiclass' or 'binary'")
weight = hparams.get('class_weights', None)
if weight is not None:
weight = torch.tensor(weight, dtype=torch.float32)
self.class_loss = nn.CrossEntropyLoss(
weight=weight, ignore_index=ignore_index, reduction='mean')
self.pred_loss = nn.MSELoss(reduction='mean')
self.task_loss = nn.MSELoss(reduction='mean')
def __str__(self):
"""Pretty print model architecture."""
format_str = '\n%s architecture\n' % self.hparams['backbone'].upper()
format_str += '------------------------\n'
format_str += 'Encoder:\n'
for i, module in enumerate(self.model['encoder'].model):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if self.hparams.get('variational', False):
format_str += 'Variational Layers:\n'
for latent in ['latent_mean', 'latent_logvar']:
for i, module in enumerate(self.model[latent]):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if 'decoder' in self.model:
format_str += 'Decoder:\n'
for i, module in enumerate(self.model['decoder'].model):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if 'predictor' in self.model:
format_str += 'Predictor:\n'
for i, module in enumerate(self.model['predictor'].model):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if 'classifier' in self.model:
format_str += 'Classifier:\n'
for i, module in enumerate(self.model['classifier']):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if 'classifier_weak' in self.model:
format_str += 'Classifier Weak:\n'
for i, module in enumerate(self.model['classifier_weak']):
format_str += str(' {}: {}\n'.format(i, module))
format_str += '\n'
if 'task_predictor' in self.model:
format_str += 'Task Predictor:\n'
for i, module in enumerate(self.model['task_predictor']):
format_str += str(' {}: {}\n'.format(i, module))
return format_str
[docs] def build_model(self):
"""Construct the model using hparams."""
# set random seeds for control over model initialization
rng_seed_model = self.hparams.get('rng_seed_model', 0)
torch.manual_seed(rng_seed_model)
np.random.seed(rng_seed_model)
# select backbone network
if self.hparams['backbone'].lower() == 'temporal-mlp':
from daart.models.temporalmlp import TemporalMLP as Module
elif self.hparams['backbone'].lower() == 'tcn':
raise NotImplementedError('deprecated; use dtcn instead')
elif self.hparams['backbone'].lower() == 'dtcn':
from daart.models.tcn import DilatedTCN as Module
elif self.hparams['backbone'].lower() in ['lstm', 'gru']:
from daart.models.rnn import RNN as Module
elif self.hparams['backbone'].lower() == 'tgm':
raise NotImplementedError
# from daart.models.tgm import TGM as Module
else:
raise ValueError('"%s" is not a valid backbone network' % self.hparams['backbone'])
global_layer_num = 0
# build encoder module
self.model['encoder'] = Module(self.hparams, type='encoder')
if self.hparams.get('variational', False):
self.hparams['kl_weight'] = 1 # weight in front of kl term; anneal this using callback
self.model['latent_mean'] = self._build_linear(
global_layer_num=len(self.model['encoder'].model), name='latent_mean',
in_size=self.hparams['n_hid_units'], out_size=self.hparams['n_hid_units'])
self.model['latent_logvar'] = self._build_linear(
global_layer_num=len(self.model['encoder'].model), name='latent_logvar',
in_size=self.hparams['n_hid_units'], out_size=self.hparams['n_hid_units'])
# build decoder module
if self.hparams.get('lambda_recon', 0) > 0:
self.model['decoder'] = Module(self.hparams, type='decoder')
# build predictor module
if self.hparams.get('lambda_pred', 0) > 0:
self.model['predictor'] = Module(self.hparams, type='decoder')
# classifier: single linear layer for hand labels
if self.hparams.get('lambda_strong', 0) > 0:
self.model['classifier'] = self._build_linear(
global_layer_num=global_layer_num, name='classification',
in_size=self.hparams['n_hid_units'], out_size=self.hparams['output_size'])
# classifier: single linear layer for heuristic labels
if self.hparams.get('lambda_weak', 0) > 0:
self.model['classifier_weak'] = self._build_linear(
global_layer_num=global_layer_num, name='classification',
in_size=self.hparams['n_hid_units'], out_size=self.hparams['output_size'])
# task regression: single linear layer
if self.hparams.get('lambda_task', 0) > 0:
self.model['task_predictor'] = self._build_mlp(
global_layer_num=global_layer_num, in_size=self.hparams['n_hid_units'],
hid_size=self.hparams['n_hid_units'], out_size=self.hparams['task_size'],
n_hid_layers=1)
[docs] def forward(self, x):
"""Process input data.
Parameters
----------
x : torch.Tensor
input data of shape (n_sequences, sequence_length, n_markers)
Returns
-------
dict of model outputs/internals as torch tensors
- 'labels' (torch.Tensor): model classification
shape of (n_sequences, sequence_length, n_classes)
- 'labels_weak' (torch.Tensor): model classification of weak/pseudo labels
shape of (n_sequences, sequence_length, n_classes)
- 'reconstruction' (torch.Tensor): input decoder prediction
shape of (n_sequences, sequence_length, n_markers)
- 'prediction' (torch.Tensor): one-step-ahead prediction
shape of (n_sequences, sequence_length, n_markers)
- 'task_prediction' (torch.Tensor): prediction of regression tasks
(n_sequences, sequence_length, n_tasks)
- 'embedding' (torch.Tensor): behavioral embedding used for classification/prediction
in non-variational models
shape of (n_sequences, sequence_length, embedding_dim)
- 'mean' (torch.Tensor): mean of appx posterior of latents in variational models
shape of (n_sequences, sequence_length, embedding_dim)
- 'logvar' (torch.Tensor): logvar of appx posterior of latents in variational models
shape of (n_sequences, sequence_length, embedding_dim)
- 'sample' (torch.Tensor): sample from appx posterior of latents in variational models
shape of (n_sequences, sequence_length, embedding_dim)
"""
# push data through encoder to get latent embedding
# x = B x T x N (e.g. B = 2, T = 500, N = 16)
x = self.model['encoder'](x)
if self.hparams.get('variational', False):
mean = self.model['latent_mean'](x)
logvar = self.model['latent_logvar'](x)
z = reparameterize_gaussian(mean, logvar)
else:
mean = x
logvar = None
z = x
# push embedding through classifiers to get hand labels
if self.hparams.get('lambda_strong', 0) > 0:
y = self.model['classifier'](z)
else:
y = None
# push embedding through linear layer to heuristic/pseudo labels
if self.hparams.get('lambda_weak', 0) > 0:
y_weak = self.model['classifier_weak'](z)
else:
y_weak = None
# push embedding through linear layer to get task predictions
if self.hparams.get('lambda_task', 0) > 0:
w = self.model['task_predictor'](z)
else:
w = None
# push embedding through decoder network to get data at current time point
if self.hparams.get('lambda_recon', 0) > 0:
xt = self.model['decoder'](z)
else:
xt = None
# push embedding through predictor network to get data at subsequent time points
if self.hparams.get('lambda_pred', 0) > 0:
xtp1 = self.model['predictor'](z)
else:
xtp1 = None
return {
'labels': y, # (n_sequences, sequence_length, n_classes)
'labels_weak': y_weak, # (n_sequences, sequence_length, n_classes)
'reconstruction': xt, # (n_sequences, sequence_length, n_markers)
'prediction': xtp1, # (n_sequences, sequence_length, n_markers)
'task_prediction': w, # (n_sequences, sequence_length, n_tasks)
'embedding': mean, # (n_sequences, sequence_length, embedding_dim)
'latent_mean': mean, # (n_sequences, sequence_length, embedding_dim)
'latent_logvar': logvar, # (n_sequences, sequence_length, embedding_dim)
'sample': z, # (n_sequences, sequence_length, embedding_dim)
}
[docs] def predict_labels(self, data_generator, return_scores=False, remove_pad=True, mode='eval'):
"""
Parameters
----------
data_generator : DataGenerator object
data generator to serve data batches
return_scores : bool
return scores before they've been passed through softmax
remove_pad : bool
remove batch padding from model outputs before returning
mode : str
'eval' | 'train'
Returns
-------
dict
- 'predictions' (list of lists): first list is over datasets; second list is over
batches in the dataset; each element is a numpy array of the label probability
distribution
- 'weak_labels' (list of lists): corresponding weak labels
- 'labels' (list of lists): corresponding labels
"""
if mode == 'eval':
self.eval()
elif mode == 'train':
self.train()
else:
raise NotImplementedError(f'select mode="eval" or mode="train", not mode="{mode}"')
pad = self.hparams.get('sequence_pad', 0)
softmax = nn.Softmax(dim=1)
# initialize containers
# softmax outputs
labels = [[] for _ in range(data_generator.n_datasets)]
# logits
scores = [[] for _ in range(data_generator.n_datasets)]
# latent representation
embedding = [[] for _ in range(data_generator.n_datasets)]
# predictions on regression task
task_predictions = [[] for _ in range(data_generator.n_datasets)]
for sess, dataset in enumerate(data_generator.datasets):
labels[sess] = [np.array([]) for _ in range(dataset.n_sequences)]
scores[sess] = [np.array([]) for _ in range(dataset.n_sequences)]
embedding[sess] = [np.array([]) for _ in range(dataset.n_sequences)]
task_predictions[sess] = [np.array([]) for _ in range(dataset.n_sequences)]
# partially fill container (gap trials will be included as nans)
dtypes = ['train', 'val', 'test']
for dtype in dtypes:
data_generator.reset_iterators(dtype)
for i in range(data_generator.n_tot_batches[dtype]):
data, sess_list = data_generator.next_batch(dtype)
outputs_dict = self.forward(data['markers'])
# remove padding if necessary
if pad > 0 and remove_pad:
for key, val in outputs_dict.items():
outputs_dict[key] = val[:, pad:-pad] if val is not None else None
# loop over sequences in batch
for s, sess in enumerate(sess_list):
batch_idx = data['batch_idx'][s].item()
# push through log-softmax, since this is included in the loss and not model
labels[sess][batch_idx] = \
softmax(outputs_dict['labels'][s]).cpu().detach().numpy()
embedding[sess][batch_idx] = \
outputs_dict['embedding'][s].cpu().detach().numpy()
if return_scores:
scores[sess][batch_idx] = \
outputs_dict['labels'][s].cpu().detach().numpy()
if outputs_dict.get('task_prediction', None) is not None:
task_predictions[sess][batch_idx] = \
outputs_dict['task_prediction'][s].cpu().detach().numpy()
return {
'labels': labels,
'scores': scores,
'embedding': embedding,
'task_predictions': task_predictions,
}
[docs] def training_step(self, data, accumulate_grad=True, **kwargs):
"""Calculate negative log-likelihood loss for supervised models.
The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory
requirements low; gradients are accumulated across all chunks before a gradient step is
taken.
Parameters
----------
data : dict
signals are of shape (n_sequences, sequence_length, n_channels)
accumulate_grad : bool, optional
accumulate gradient for training step
Returns
-------
dict
- 'loss' (float): total loss (negative log-like under specified noise dist)
- other loss terms depending on model hyperparameters
"""
# define hyperparams
lambda_weak = self.hparams.get('lambda_weak', 0)
lambda_strong = self.hparams.get('lambda_strong', 0)
lambda_pred = self.hparams.get('lambda_pred', 0)
lambda_task = self.hparams.get('lambda_task', 0)
kl_weight = self.hparams.get('kl_weight', 1)
# index padding for convolutions
pad = self.hparams.get('sequence_pad', 0)
# push data through model
markers_wpad = data['markers']
outputs_dict = self.forward(markers_wpad)
# remove padding from supplied data
if lambda_strong > 0:
if pad > 0:
labels_strong = data['labels_strong'][:, pad:-pad, ...]
else:
labels_strong = data['labels_strong']
# reshape to fit into class loss; needs to be (n_examples,)
labels_strong = torch.flatten(labels_strong)
else:
labels_strong = None
if lambda_weak > 0:
if pad > 0:
labels_weak = data['labels_weak'][:, pad:-pad, ...]
else:
labels_weak = data['labels_weak']
# reshape to fit into class loss; needs to be (n_examples,)
labels_weak = torch.flatten(labels_weak)
else:
labels_weak = None
if lambda_task > 0:
if pad > 0:
tasks = data['tasks'][:, pad:-pad, ...]
else:
tasks = data['tasks']
else:
tasks = None
# remove padding from model output
if pad > 0:
markers = markers_wpad[:, pad:-pad, ...]
# remove padding from model output
for key, val in outputs_dict.items():
outputs_dict[key] = val[:, pad:-pad, ...] if val is not None else None
else:
markers = markers_wpad
# initialize loss to zero
loss = 0
loss_dict = {}
# ------------------------------------
# compute loss on weak labels
# ------------------------------------
if lambda_weak > 0:
# reshape predictions to fit into class loss; needs to be (n_examples, n_classes)
labels_weak_reshape = torch.reshape(
outputs_dict['labels_weak'], (-1, outputs_dict['labels_weak'].shape[-1]))
# only compute loss where strong labels do not exist [indicated by a zero]
if labels_strong is not None:
idxs_ = labels_strong == 0
if torch.sum(idxs_) > 0:
loss_weak = self.class_loss(labels_weak_reshape[idxs_], labels_weak[idxs_])
else:
# if all timepoints are labeled, set weak loss to zero
loss_weak = torch.tensor([0.], device=labels_strong.device)
else:
loss_weak = self.class_loss(labels_weak_reshape, labels_weak)
loss += lambda_weak * loss_weak
loss_weak_val = loss_weak.item()
loss_dict['loss_weak'] = loss_weak_val
# compute fraction correct on weak labels
if 'labels' in outputs_dict.keys():
fc = accuracy_score(
labels_weak.cpu().detach().numpy().flatten(),
np.argmax(outputs_dict['labels'].cpu().detach().numpy(), axis=2).flatten(),
)
# log
loss_dict['fc'] = fc
# ------------------------------------
# compute loss on strong labels
# ------------------------------------
if lambda_strong > 0:
# reshape predictions to fit into class loss; needs to be (n_examples, n_classes)
labels_strong_reshape = torch.reshape(
outputs_dict['labels'], (-1, outputs_dict['labels'].shape[-1]))
loss_strong = self.class_loss(labels_strong_reshape, labels_strong)
loss += lambda_strong * loss_strong
loss_strong_val = loss_strong.item()
# log
loss_dict['loss_strong'] = loss_strong_val
# ------------------------------------
# compute loss on one-step predictions
# ------------------------------------
if lambda_pred > 0:
loss_pred = self.pred_loss(markers[:, 1:], outputs_dict['prediction'][:, :-1])
loss += lambda_pred * loss_pred
loss_pred_val = loss_pred.item()
# log
loss_dict['loss_pred'] = loss_pred_val
# ------------------------------------
# compute regression loss on tasks
# ------------------------------------
if lambda_task > 0:
loss_task = self.task_loss(tasks, outputs_dict['task_prediction'])
loss += lambda_task * loss_task
loss_task_val = loss_task.item()
r2 = r2_score(
tasks.cpu().detach().numpy().flatten(),
outputs_dict['task_prediction'].cpu().detach().numpy().flatten(),
)
# log
loss_dict['loss_task'] = loss_task_val
loss_dict['task_r2'] = r2
# ------------------------------------
# compute kl divergence on appx posterior
# ------------------------------------
if self.hparams.get('variational', False):
# multiply by 2 to take into account the fact that we're computing raw mse for decoding
# and prediction rather than (1 / 2\sigma^2) * MSE
loss_kl = 2.0 * losses.kl_div_to_std_normal(
outputs_dict['latent_mean'], outputs_dict['latent_logvar'])
loss += kl_weight * loss_kl
# log
loss_dict['kl_weight'] = kl_weight
loss_dict['loss_kl'] = loss_kl.item()
if accumulate_grad:
loss.backward()
# collect loss vals
loss_dict['loss'] = loss.item()
return loss_dict
[docs]class Ensembler(object):
"""Ensemble of models."""
def __init__(self, models):
self.models = models
self.n_models = len(models)
[docs] def predict_labels(self, data_generator, combine_before_softmax=False, weights=None):
"""Combine class predictions from multiple models by averaging before softmax.
Parameters
----------
data_generator : DataGenerator object
data generator to serve data batches
combine_before_softmax : bool, optional
True to combine logits across models before taking softmax; False to take softmax for
each model then combine probabilities
weights: array-like, str, or NoneType, optional
array-like: weight for each model
str: 'entropy': weight each model at each time point by inverse entropy of distribution
None: uniform weight for each model
Returns
-------
dict
- 'labels' (list of lists): corresponding labels
"""
# initialize container for labels
labels = [[] for _ in range(data_generator.n_datasets)]
for sess, dataset in enumerate(data_generator.datasets):
labels[sess] = [np.array([]) for _ in range(dataset.n_sequences)]
# process data for each model
labels_all = []
for model in self.models:
outputs_curr = model.predict_labels(
data_generator, return_scores=combine_before_softmax)
if combine_before_softmax:
labels_all.append(outputs_curr['scores'])
else:
labels_all.append(outputs_curr['labels'])
# labels_all is a list of list of lists
# access: labels_all[idx_model][idx_dataset][idx_batch]
# ensemble prediction across models
for sess, labels_sess in enumerate(labels):
for batch, labels_batch in enumerate(labels_sess):
# labels_curr is of shape (n_models, sequence_len, n_classes)
labels_curr = np.vstack([lab[sess][batch][None, ...] for lab in labels_all])
# combine predictions across models
if weights is None:
# simple average across models
labels_curr = np.mean(labels_curr, axis=0)
elif isinstance(weights, str) and weights == 'entropy':
# weight each model at each time point by inverse entropy of distribution
# so that more confident models have a higher weight
# compute entropy across labels
ent = entropy(labels_curr, axis=-1)
# low entropy = high confidence, weight these more
w = 1.0 / ent
# normalize over models
w /= np.sum(w, axis=0) # shape of (n_models, sequence_len)
labels_curr = np.mean(labels_curr * w[..., None], axis=0)
elif isinstance(weights, (list, tuple, np.ndarray)):
# weight each model according to user-supplied weights
labels_curr = np.average(labels_curr, axis=0, weights=weights)
if combine_before_softmax:
labels[sess][batch] = scipy_softmax(labels_curr, axis=-1)
else:
labels[sess][batch] = labels_curr
return {'labels': labels}