reparameterize_gaussian

daart.models.base.reparameterize_gaussian(mu, logvar)

Sample from N(mu, var), where var = exp(logvar), using the reparameterization trick.

Parameters:
  • mu (torch.Tensor) – vector of mean parameters

  • logvar (torch.Tensor) – vector of log variances; only the mean-field (diagonal covariance) approximation is currently implemented

Returns:

sampled vector of shape (n_sequences, sequence_length, embedding_dim)

Return type:

torch.Tensor
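For illustration, the reparameterization trick under a mean-field Gaussian can be sketched as z = mu + eps * exp(0.5 * logvar), with eps drawn from a standard normal. Because the randomness lives entirely in eps, the sample stays differentiable with respect to mu and logvar. The function name reparameterize_gaussian_sketch below is hypothetical; this is a minimal sketch of the technique, not necessarily daart's exact implementation.

```python
import torch


def reparameterize_gaussian_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, diag(exp(logvar))) via the reparameterization trick.

    Hypothetical sketch; not necessarily daart's exact implementation.
    """
    # Standard deviation from log variance: std = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    # Standard-normal noise with the same shape, dtype, and device as std
    eps = torch.randn_like(std)
    # The sample is a deterministic, differentiable function of mu and logvar
    return mu + eps * std


# Usage: shapes follow (n_sequences, sequence_length, embedding_dim)
mu = torch.zeros(2, 10, 4, requires_grad=True)
logvar = torch.zeros(2, 10, 4, requires_grad=True)
z = reparameterize_gaussian_sketch(mu, logvar)
print(z.shape)  # torch.Size([2, 10, 4])

# Gradients flow back to both mu and logvar through the sample
z.sum().backward()
```

Since dz/dmu is 1 for every element, mu.grad is a tensor of ones after the backward pass; a very negative logvar shrinks std toward zero, so the sample collapses onto mu.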