Variational Autoencoder KL divergence loss explodes and the model returns NaN

I'm training a convolutional VAE on MRI brain images (2D slices). The model's output layer is a sigmoid, and the reconstruction loss is binary cross-entropy:

x = input, x_hat = output

rec_loss = nn.functional.binary_cross_entropy(x_hat.view(-1, 128 ** 2), x.view(-1, 128 ** 2), reduction='sum')

But my problem is actually with the KL divergence loss:

KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
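
The two terms are added to form the total loss; here is a minimal sketch of that computation (the vae_loss helper name is just for illustration, not my actual code layout):

import torch
import torch.nn as nn

def vae_loss(x, x_hat, mu, logvar):
    # Reconstruction term: summed per-pixel BCE over the flattened 128x128 slice
    rec_loss = nn.functional.binary_cross_entropy(
        x_hat.view(-1, 128 ** 2), x.view(-1, 128 ** 2), reduction='sum')
    # KL divergence between the approximate posterior N(mu, exp(logvar)) and N(0, I)
    KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec_loss + KL_loss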

At some point during training, the KL divergence loss becomes extremely large (effectively infinite),


and then I get the error shown below, which is probably because the output is NaN. Any suggestions on how to avoid this explosion?



Solution 1:[1]

You could use the mean as the reduction method in the BCE and in the KL divergence as well:

KL_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

The KL divergence often has spikes that can be orders of magnitude higher than the other values. Why this happens I don't know, but it annoys me as well :)
Your model probably crashes quite early. If you plot your KL divergence on a log scale, you will see that spikes can still occur later on, but they are smaller since the whole KL divergence term gets smaller.
That is the KL divergence of one of my training runs with mean reduction, plotted on a log scale.
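
A minimal sketch of the suggested change, with mean reduction applied to both terms (the vae_loss_mean helper name is illustrative, not from the original answer):

import torch
import torch.nn as nn

def vae_loss_mean(x, x_hat, mu, logvar):
    # Mean-reduced BCE keeps the reconstruction term at a per-pixel scale
    rec_loss = nn.functional.binary_cross_entropy(
        x_hat.view(-1, 128 ** 2), x.view(-1, 128 ** 2), reduction='mean')
    # Mean-reduced KL term, so occasional spikes stay much smaller in magnitude
    KL_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return rec_loss + KL_loss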

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Error404