reparameterize_gaussian
- daart.models.base.reparameterize_gaussian(mu, logvar)
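Sample from N(mu, var), where var = exp(logvar), via the reparameterization trick so that the sample is differentiable with respect to mu and logvar.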
- Parameters:
mu (torch.Tensor) – tensor of mean parameters
logvar (torch.Tensor) – tensor of log variances; only the mean-field approximation (diagonal covariance, i.e. independent dimensions) is currently implemented
- Returns:
sampled tensor of shape (n_sequences, sequence_length, embedding_dim)
- Return type:
torch.Tensor
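A minimal sketch of what a mean-field Gaussian reparameterization typically looks like, assuming the standard form z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I). The name `reparameterize_gaussian_sketch` is hypothetical and this is not daart's source; consult `daart.models.base` for the actual implementation. The tensor sizes in the usage section are arbitrary illustrative values laid out as (n_sequences, sequence_length, embedding_dim) to match the documented return shape.

```python
import torch


def reparameterize_gaussian_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, exp(logvar)) with the reparameterization trick.

    Hypothetical illustration, not daart's implementation. Assumes a
    diagonal (mean-field) Gaussian, so every element is sampled independently.
    """
    std = torch.exp(0.5 * logvar)  # sigma = sqrt(var) = exp(logvar / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I); all randomness lives here
    return mu + eps * std          # deterministic in mu and logvar, so gradients flow


if __name__ == "__main__":
    torch.manual_seed(0)
    # Arbitrary sizes, ordered as (n_sequences, sequence_length, embedding_dim)
    n_sequences, sequence_length, embedding_dim = 4, 100, 8
    mu = torch.zeros(n_sequences, sequence_length, embedding_dim, requires_grad=True)
    logvar = torch.zeros(n_sequences, sequence_length, embedding_dim, requires_grad=True)

    z = reparameterize_gaussian_sketch(mu, logvar)
    print(z.shape)  # torch.Size([4, 100, 8])

    # Backpropagating through the sample populates gradients for both parameters
    z.sum().backward()
    print(mu.grad is not None, logvar.grad is not None)  # True True
```

Because the noise eps is drawn independently of the parameters, the sample is a deterministic, differentiable function of mu and logvar; this is what lets a VAE-style encoder be trained by ordinary backpropagation. Parameterizing the log variance rather than the variance keeps var = exp(logvar) positive without any explicit constraint.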