"""Temporal Convolution model implemented in PyTorch."""
from torch import nn
from daart.models.base import BaseModel, get_activation_func_from_str
# restrict the public API to the model class so that sphinx-autoapi does not document the
# names imported above (nn, BaseModel, get_activation_func_from_str)
__all__ = ['DilatedTCN']
class DilatedTCN(BaseModel):
    """Temporal Convolutional Model with dilated convolutions and no temporal downsampling.

    Code adapted from: https://www.kaggle.com/ceshine/pytorch-temporal-convolutional-networks

    Parameters
    ----------
    hparams : dict
        model hyperparameters; keys read by this class are 'input_size', 'n_hid_units',
        'n_hid_layers', 'n_lags', 'activation' and (optionally, default 0.2) 'dropout'
    type : str
        'encoder' builds an encoder; any other value builds a decoder
    in_size : int or NoneType
        number of input channels; if None, taken from `hparams`
    hid_size : int or NoneType
        number of hidden channels; if None, `hparams['n_hid_units']`
    out_size : int or NoneType
        number of output channels; if None, taken from `hparams`

    """

    # NOTE: `type` shadows the builtin but is part of the public signature, so it is kept
    def __init__(self, hparams, type='encoder', in_size=None, hid_size=None, out_size=None):
        super().__init__()
        self.hparams = hparams
        self.model = nn.Sequential()
        if type == 'encoder':
            # encoder: input_size -> n_hid_units
            in_size_ = hparams['input_size'] if in_size is None else in_size
            hid_size_ = hparams['n_hid_units'] if hid_size is None else hid_size
            out_size_ = hparams['n_hid_units'] if out_size is None else out_size
            self.build_encoder(in_size=in_size_, hid_size=hid_size_, out_size=out_size_)
        else:
            # decoder: n_hid_units -> input_size
            in_size_ = hparams['n_hid_units'] if in_size is None else in_size
            hid_size_ = hparams['n_hid_units'] if hid_size is None else hid_size
            out_size_ = hparams['input_size'] if out_size is None else out_size
            self.build_decoder(in_size=in_size_, hid_size=hid_size_, out_size=out_size_)

    def build_encoder(self, in_size, hid_size, out_size):
        """Construct encoder model using hparams.

        Stacks `hparams['n_hid_layers']` DilationBlocks; block `i` uses dilation `2 ** i`.
        The first block reads `in_size` channels, intermediate blocks map `hid_size` channels to
        `hid_size` channels, and the last block writes `out_size` channels.

        Parameters
        ----------
        in_size : int
            number of input channels
        hid_size : int
            number of hidden channels
        out_size : int
            number of output channels

        Returns
        -------
        int
            number of blocks added to `self.model`

        """
        n_layers = self.hparams['n_hid_layers']
        global_layer_num = 0
        for i_layer in range(n_layers):
            is_first = i_layer == 0
            is_last = i_layer == n_layers - 1
            # conv -> activation -> dropout (+ residual)
            tcn_block = DilationBlock(
                input_size=in_size if is_first else hid_size,
                int_size=hid_size,
                output_size=out_size if is_last else hid_size,
                kernel_size=self.hparams['n_lags'], stride=1, dilation=2 ** i_layer,
                activation=self.hparams['activation'], dropout=self.hparams.get('dropout', 0.2))
            self.model.add_module('tcn_block_%02i' % global_layer_num, tcn_block)
            global_layer_num += 1
        return global_layer_num

    def build_decoder(self, in_size, hid_size, out_size):
        """Construct the decoder using hparams.

        Stacks `hparams['n_hid_layers']` DilationBlocks whose dilation decreases by powers of 2
        (the reverse of the encoder), followed by a 1x1 convolution that acts as a dense layer
        applied independently at each time point.

        Parameters
        ----------
        in_size : int
            number of input channels
        hid_size : int
            number of hidden channels
        out_size : int
            number of output channels

        Returns
        -------
        int
            number of DilationBlocks added to `self.model` (the final dense layer is not counted)

        """
        n_layers = self.hparams['n_hid_layers']
        global_layer_num = 0
        prev_out_size = in_size  # channel count of whatever feeds into the next layer
        for i_layer in range(n_layers):
            out_size_ = out_size if i_layer == n_layers - 1 else hid_size
            # conv -> activation -> dropout (+ residual)
            tcn_block = DilationBlock(
                input_size=prev_out_size, int_size=hid_size, output_size=out_size_,
                kernel_size=self.hparams['n_lags'], stride=1,
                dilation=2 ** (n_layers - i_layer - 1),  # down by powers of 2
                activation=self.hparams['activation'],
                final_activation=self.hparams['activation'],
                dropout=self.hparams.get('dropout', 0.2), predictor_block=False)
            self.model.add_module('tcn_block_%02i' % global_layer_num, tcn_block)
            global_layer_num += 1
            prev_out_size = out_size_
        # add final fully-connected layer (kernel_size=1 <=> dense layer at each time point);
        # in_channels follows the actual incoming channel count so this stays consistent even
        # when no DilationBlocks were built (n_hid_layers == 0)
        dense = nn.Conv1d(in_channels=prev_out_size, out_channels=out_size, kernel_size=1)
        self.model.add_module('final_dense_%02i' % global_layer_num, dense)
        return global_layer_num

    def forward(self, x, **kwargs):
        """Process input data.

        Parameters
        ----------
        x : torch.Tensor object
            input data of shape (n_sequences, sequence_length, n_markers)

        Returns
        -------
        torch.Tensor
            shape (n_sequences, sequence_length, n) where n is the embedding dimension if an
            encoder, or n_markers if a decoder/predictor

        """
        # Conv1d layers expect channels in dim 1:
        # B x T x N -> B x N x T -> (model) -> B x M x T -> B x T x M
        return self.model(x.transpose(1, 2)).transpose(1, 2)
class DilationBlock(nn.Module):
    """Residual Temporal Block module for use with DilatedTCN class.

    Computes `final_activation(block(x) + res(x))`, where `block` is
    conv -> activation -> dropout -> conv [-> activation -> dropout] and `res` is the identity,
    or a 1x1 convolution when `input_size != output_size`.

    Parameters
    ----------
    input_size : int
        number of input channels
    int_size : int
        number of channels between the two convolutions
    output_size : int
        number of output channels
    kernel_size : int
        half-width of the convolution window; each convolution spans `2 * kernel_size + 1` taps
    stride : int
        convolution stride; the padding preserves sequence length only for stride=1
    dilation : int
        convolution dilation
    activation : str
        name of the intermediate activation, resolved by `get_activation_func_from_str`
    dropout : float
        probability of zeroing an entire channel
    final_activation : str or NoneType
        activation applied after the residual sum; defaults to `activation`
    predictor_block : bool
        if True, the activation/dropout after the second convolution is omitted

    """

    def __init__(
            self, input_size, int_size, output_size, kernel_size, stride=1, dilation=2,
            activation='relu', dropout=0.2, final_activation=None, predictor_block=False):
        super(DilationBlock, self).__init__()

        # symmetric window of `kernel_size` steps on each side of t; padding of
        # kernel_size * dilation keeps the output length equal to the input length (stride=1)
        self.conv0 = nn.utils.weight_norm(nn.Conv1d(
            in_channels=input_size,
            out_channels=int_size,
            stride=stride,
            dilation=dilation,
            kernel_size=kernel_size * 2 + 1,  # window around t
            padding=kernel_size * dilation))  # same output
        self.conv1 = nn.utils.weight_norm(nn.Conv1d(
            in_channels=int_size,
            out_channels=output_size,
            stride=stride,
            dilation=dilation,
            kernel_size=kernel_size * 2 + 1,  # window around t
            padding=kernel_size * dilation))  # same output

        # intermediate activations
        self.activation = get_activation_func_from_str(activation)

        # final activation
        if final_activation is None:
            final_activation = activation
        self.final_activation = get_activation_func_from_str(final_activation)

        # channel-wise dropout on (N, C, L) input: drops entire features in the `C` dimension.
        # nn.Dropout1d (torch >= 1.12) does exactly this; nn.Dropout2d on 3D input does the same
        # on current releases but warns and is documented to reinterpret 3D input as unbatched
        # (C, H, W) in the future, so it is only used as a fallback for older torch versions
        dropout_cls = getattr(nn, 'Dropout1d', nn.Dropout2d)
        self.dropout = dropout_cls(dropout)

        # build net
        self.block = nn.Sequential()

        # conv -> activation -> dropout block # 0
        self.block.add_module('conv1d_layer_0', self.conv0)
        self.block.add_module('%s_0' % activation, self.activation)
        self.block.add_module('dropout_0', self.dropout)

        # conv -> activation -> dropout block # 1 (no activation/dropout for predictor blocks)
        self.block.add_module('conv1d_layer_1', self.conv1)
        if not predictor_block:
            self.block.add_module('%s_1' % activation, self.activation)
            self.block.add_module('dropout_1', self.dropout)

        # 1x1 conv on the residual path when channel counts differ
        if input_size != output_size:
            self.downsample = nn.Conv1d(input_size, output_size, kernel_size=1)
        else:
            self.downsample = None

        self.init_weights()

    def __str__(self):
        format_str = 'DilationBlock\n'
        for i, module in enumerate(self.block):
            format_str += '    {}: {}\n'.format(i, module)
        format_str += '    {}: residual connection\n'.format(i + 1)
        format_str += '    {}: {}\n'.format(i + 2, self.final_activation)
        return format_str

    def init_weights(self):
        """Draw convolution weights from N(0, 0.01)."""
        # NOTE(review): conv0/conv1 are wrapped with the hook-based `nn.utils.weight_norm`, whose
        # forward pre-hook recomputes `weight` from `weight_g`/`weight_v`, so these two writes are
        # overwritten on the first forward pass. Left as-is to preserve existing training
        # behavior; initializing `weight_v`/`weight_g` instead would change model initialization.
        self.conv0.weight.data.normal_(0, 0.01)
        self.conv1.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            # not weight-normalized, so this initialization does take effect
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply the block with a residual connection to input of shape (N, C, L)."""
        out = self.block(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.final_activation(out + res)