Source code for daart.models.temporalmlp

"""Temporal MLP model implemented in PyTorch."""

from torch import nn
from daart.models.base import BaseModel, get_activation_func_from_str

# to ignore imports for sphix-autoapidoc
__all__ = ['TemporalMLP']


[docs]class TemporalMLP(BaseModel): """MLP network with initial 1D convolution layer.""" def __init__(self, hparams, type='encoder', in_size=None, hid_size=None, out_size=None): super().__init__() self.hparams = hparams self.activation_func = get_activation_func_from_str(self.hparams['activation']) self.model = nn.ModuleList() if type == 'encoder': in_size_ = hparams['input_size'] if in_size is None else in_size hid_size_ = hparams['n_hid_units'] if hid_size is None else hid_size out_size_ = hparams['n_hid_units'] if out_size is None else out_size self.build_encoder(in_size=in_size_, hid_size=hid_size_, out_size=out_size_) else: in_size_ = hparams['n_hid_units'] if in_size is None else in_size hid_size_ = hparams['n_hid_units'] if hid_size is None else hid_size out_size_ = hparams['input_size'] if out_size is None else out_size self.build_decoder(in_size=in_size_, hid_size=hid_size_, out_size=out_size_)
[docs] def build_encoder(self, in_size, hid_size, out_size): """Construct the encoder using hparams.""" global_layer_num = 0 # ------------------------------------------------------------- # first layer is 1d conv for incorporating past/future activity # ------------------------------------------------------------- layer = nn.Conv1d( in_channels=in_size, out_channels=hid_size, kernel_size=self.hparams['n_lags'] * 2 + 1, # window around t padding=self.hparams['n_lags']) # same output name = str('conv1d_layer_%02i' % global_layer_num) self.model.add_module(name, layer) # add activation if self.hparams['n_hid_layers'] == 0: activation = None # cross entropy loss handles this else: activation = self.activation_func if activation: name = '%s_%02i' % (self.hparams['activation'], global_layer_num) self.model.add_module(name, activation) # update layer info global_layer_num += 1 # ------------------------------------------------------------- # remaining layers # ------------------------------------------------------------- # loop over hidden layers (0 layers <-> linear model) for i_layer in range(self.hparams['n_hid_layers']): if i_layer == self.hparams['n_hid_layers'] - 1: activation = None # cross entropy loss handles this out_size_ = out_size else: activation = self.activation_func out_size_ = hid_size # add layer layer = nn.Linear(in_features=hid_size, out_features=out_size_) name = str('dense_layer_%02i' % global_layer_num) self.model.add_module(name, layer) # add activation if activation: name = '%s_%02i' % (self.hparams['activation'], global_layer_num) self.model.add_module(name, activation) # update layer info global_layer_num += 1 return global_layer_num
[docs] def build_decoder(self, in_size, hid_size, out_size): """Construct the vanilla MLP decoder using hparams.""" global_layer_num = 0 in_size_ = in_size # loop over hidden layers (0 layers <-> linear model) for i_layer in range(self.hparams['n_hid_layers'] + 1): if i_layer == self.hparams['n_hid_layers']: out_size_ = out_size else: out_size_ = hid_size # add layer layer = nn.Linear(in_features=in_size_, out_features=out_size_) name = str('dense_layer_%02i' % global_layer_num) self.model.add_module(name, layer) # add activation if i_layer == self.hparams['n_hid_layers']: # no activation for final layer activation = None else: activation = self.activation_func if activation: name = '%s_%02i' % (self.hparams['activation'], global_layer_num) self.model.add_module(name, activation) # update layer info global_layer_num += 1 in_size_ = out_size_ return global_layer_num
[docs] def forward(self, x, **kwargs): """Process input data. Parameters ---------- x : torch.Tensor object input data of shape (n_sequences, sequence_length, n_markers) Returns ------- torch.Tensor shape (n_sequences, sequence_length, n) where n is the embedding dimension if an encoder, or n_markers if a decoder/predictor """ # push data through encoder to get latent embedding for name, layer in self.model.named_children(): if name == 'conv1d_layer_00': # conv1d layer input is batch x in_channels x time # conv1d layer output is batch x out_channels x time # x = B x T x N (e.g. B = 2, T = 500, N = 16) # x.transpose(1, 2) -> x = B x N x T # x = layer(x) -> x = B x M x T # x.transpose(1, 2) -> x = B x T x M x = layer(x.transpose(1, 2)).transpose(1, 2) else: x = layer(x) return x