kl_div_to_std_normal
- daart.losses.kl_div_to_std_normal(mu: Tensor, logvar: Tensor) → Tensor
Compute the element-wise KL divergence KL(q(z) || N(0, 1)), where q(z) is a normal distribution with mean mu and log variance logvar.
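For reference, the standard closed form of this divergence for a single latent element, with variance recovered from the log variance, is given below. This is a known identity for Gaussians, not an excerpt from the daart source. Summing it over the n_dims axis gives the per-step quantity described under Returns.

```latex
\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right)
  = \tfrac{1}{2}\left(\sigma^2 + \mu^2 - 1 - \log \sigma^2\right),
\qquad \sigma^2 = \exp(\mathrm{logvar})
```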
- Parameters:
mu (torch.Tensor) – mean parameter of shape (n_sequences, sequence_length, n_dims)
logvar (torch.Tensor) – log variance parameter of shape (n_sequences, sequence_length, n_dims)
- Returns:
KL divergence summed across the latent dimensions (n_dims) and averaged across the batch
- Return type:
torch.Tensor
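The following is a minimal PyTorch sketch of the computation described above. It is not daart's implementation. It assumes the closed-form Gaussian KL and reads "averaged across batch" as a mean over every leading axis (both n_sequences and sequence_length). The name kl_div_to_std_normal_sketch is hypothetical. The sketch cross-checks its result against torch.distributions.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_div_to_std_normal_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of KL(N(mu, exp(logvar)) || N(0, 1)).

    Assumption: "averaged across batch" is taken to mean a mean over every
    leading axis (n_sequences and sequence_length).
    """
    # closed-form KL for each latent element: 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
    kl_elementwise = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar)
    # sum over the latent axis (n_dims, the last axis)
    kl_per_step = kl_elementwise.sum(dim=-1)
    # average over all remaining entries, returning a scalar tensor
    return kl_per_step.mean()


# usage: random inputs of shape (n_sequences, sequence_length, n_dims)
mu = torch.randn(4, 100, 8)
logvar = torch.randn(4, 100, 8)
kl = kl_div_to_std_normal_sketch(mu, logvar)

# cross-check against torch.distributions (std = exp(0.5 * logvar))
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, p).sum(dim=-1).mean()
print(torch.allclose(kl, reference, atol=1e-5))  # True
```

Using logvar rather than the variance itself keeps the variance positive without constraining the network output. If daart averages over a different set of axes, only the final reduction would change.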