Neural Networks gone wild! They can sample from discrete distributions now!


Training deep neural networks usually boils down to defining your model’s architecture and a loss function, and watching the gradients propagate.

However, sometimes it’s not that simple: some architectures incorporate a random component, so the forward pass is no longer a deterministic function of the input and weights. The stochasticity comes from sampling: part of the computation is drawn from a distribution rather than computed.

When would that happen, you ask? Whenever we want to approximate an intractable sum or integral. Then, we can form a Monte Carlo estimate. A good example is the variational autoencoder. Basically, it’s an autoencoder on steroids: the encoder’s job is to learn a distribution over the latent space. The loss function contains an intractable expectation over that distribution, so we sample from it.

As with any architecture, the gradients need to propagate to the weights of the model. Some of the weights are responsible for transforming the input into the parameters of the distribution from which we sample. Here we face a problem: the gradients can’t propagate through random nodes! Hence, these weights won’t be updated.

One solution to the problem is the reparameterization trick: you substitute the sampled random variable with a deterministic parameterized transformation of a parameterless random variable.

If you don’t know this trick I highly encourage you to read about it. I’ll demonstrate it with the Gaussian case:

Let $Z \sim \mathcal{N}(\mu(X), \sigma^2(X))$. The parameters of the Gaussian are functions of the input $X$ – e.g. the output of stacked dense layers. If we sample realizations of $Z$ directly, gradients can’t propagate to the weights of the dense layers. Instead, we can substitute $Z$ with a different random variable $Z' = \mu(X) + \sigma(X) \cdot \mathcal{E}$, where $\mathcal{E} \sim \mathcal{N}(0, 1)$; $Z'$ has the same distribution as $Z$. Now the sampling is done from $\mathcal{E}$, so gradients still won’t propagate through that path – which we don’t care about, since $\mathcal{E}$ has no parameters. Through $\mu(X)$ and $\sigma(X)$, however, they will, since that path is deterministic.
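To make the Gaussian case concrete, here is a minimal NumPy sketch. It assumes fixed scalars `mu` and `sigma` standing in for $\mu(X)$ and $\sigma(X)$, and a toy objective $f(z) = z^2$ that is not part of the original setup. Because $\mathbb{E}[Z^2] = \mu^2 + \sigma^2$, the true gradients are $2\mu$ and $2\sigma$, which gives us something to compare the Monte Carlo estimates against.

```python
import numpy as np

def reparameterized_sample(mu, sigma, n_samples, rng):
    """Draw z = mu + sigma * eps, where eps ~ N(0, 1) carries all the randomness."""
    eps = rng.standard_normal(n_samples)
    return mu + sigma * eps, eps

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.5  # illustrative stand-ins for mu(X) and sigma(X)
z, eps = reparameterized_sample(mu, sigma, 200_000, rng)

# Toy objective f(z) = z**2 (an assumption for illustration).
# Chain rule through the deterministic path:
#   df/dmu    = 2z * dz/dmu    = 2z
#   df/dsigma = 2z * dz/dsigma = 2z * eps
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)

print("estimated d/dmu:   ", grad_mu, " analytic:", 2 * mu)
print("estimated d/dsigma:", grad_sigma, " analytic:", 2 * sigma)
```

The noise `eps` is treated as a constant when differentiating, which is exactly why the gradient can reach `mu` and `sigma`.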

For many types of continuous distributions you can do the reparameterization trick. But what do you do if you need the distribution to be over a discrete set of values?

In the following sections you’ll learn:

  • what the Gumbel distribution is
  • how it is used for sampling from a discrete distribution
  • how the weights that affect the distribution’s parameters can be trained
  • how to use all of that in a toy example (with code)

The Gumbel distribution has two parameters – a location $\mu$ and a scale $\beta$. The standard Gumbel distribution, where $\mu = 0$ and $\beta = 1$, has the PDF $f(x) = e^{-(x + e^{-x})}$.
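One common way to draw standard Gumbel samples is inverse transform sampling: if $U \sim \text{Uniform}(0, 1)$, then $-\log(-\log U)$ follows the standard Gumbel distribution. The sketch below assumes NumPy; the helper names and the small `eps` guard against $\log(0)$ are my own choices. As a sanity check it compares the sample mean and variance with the known values (the Euler–Mascheroni constant and $\pi^2/6$).

```python
import numpy as np

def sample_gumbel(shape, rng, eps=1e-20):
    """Standard Gumbel samples via the inverse CDF: -log(-log(U))."""
    u = rng.uniform(0.0, 1.0, size=shape)
    return -np.log(-np.log(u + eps) + eps)

def gumbel_pdf(x):
    """PDF of the standard Gumbel distribution."""
    return np.exp(-(x + np.exp(-x)))

rng = np.random.default_rng(0)
samples = sample_gumbel(100_000, rng)

# Riemann-sum check that the PDF integrates to roughly 1.
xs = np.linspace(-10.0, 30.0, 400_001)
area = np.sum(gumbel_pdf(xs)) * (xs[1] - xs[0])

print("sample mean:", samples.mean(), " expected:", np.euler_gamma)
print("sample var: ", samples.var(), " expected:", np.pi ** 2 / 6)
print("PDF area:   ", area)
```

NumPy also ships `Generator.gumbel`, which could replace the hand-rolled sampler.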

Why should you care about this distribution? Consider the setting where you have a discrete random variable whose logits are $\{\alpha_i\}_{i=1}^k$. The logits are a function of the input and weights that need to be trained.

What I’m going to describe next is called the Gumbel-max trick. Using this trick, you can sample from the discrete distribution. The process is as follows:

  1. Draw i.i.d. samples $\{z_i\}_{i=1}^k$ from the standard Gumbel distribution.
  2. Add the samples to the logits: $\{\alpha_i + z_i\}_{i=1}^k$.
  3. Take the index of the maximal value: $\operatorname{argmax}_{i \in \{1, \dots, k\}} (\alpha_i + z_i)$.

The result will be a random sample of your original distribution. You can read the proof here.
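The three steps can be sketched in a few lines of NumPy. The logits below are made up for illustration. If the trick works, the empirical frequencies of the sampled indices should match the softmax of the logits, i.e. the distribution the logits define.

```python
import numpy as np

def gumbel_max_sample(logits, n_samples, rng):
    """Draw n_samples indices from the categorical distribution defined by logits."""
    # Step 1: i.i.d. standard Gumbel noise, one value per class and per sample.
    z = rng.gumbel(loc=0.0, scale=1.0, size=(n_samples, logits.shape[0]))
    # Steps 2 and 3: add the noise to the logits and take the argmax.
    return np.argmax(logits + z, axis=1)

rng = np.random.default_rng(0)
logits = np.array([1.0, 0.0, -1.0, 2.0])  # hypothetical logits
target = np.exp(logits) / np.exp(logits).sum()  # softmax of the logits

idx = gumbel_max_sample(logits, 200_000, rng)
empirical = np.bincount(idx, minlength=len(logits)) / len(idx)

print("target:   ", np.round(target, 3))
print("empirical:", np.round(empirical, 3))
```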

Great! So we were able to substitute our distribution with a deterministic transformation of a parameterless distribution! So if we plug it into our model, the gradients will be able to propagate to the weights of the logits, right?

Well, not so fast! Gradients can’t propagate through argmax…


Using argmax is equivalent to using a one-hot vector in which the entry corresponding to the maximal value is 1.

So instead of using a hard one-hot vector, we can approximate it with a soft one – softmax.

The process is the same as the process described above, except now you apply softmax instead of argmax.

And voila! Gradients can propagate to the weights of the logits.

There’s one hyperparameter I didn’t tell you about (yet) – the temperature:

$\frac{\exp((\alpha_i + z_i) / \tau)}{\sum_{j=1}^k \exp((\alpha_j + z_j) / \tau)}$

By dividing by a temperature $\tau > 0$, we can control how close the approximation is to argmax. When $\tau \to 0$, the entry corresponding to the maximal value tends to 1 and the other entries tend to 0. When $\tau \to \infty$, the result tends to the uniform distribution. The smaller $\tau$ is, the better the approximation gets. The problem with small values of $\tau$ is that the variance of the gradients becomes high, which makes training difficult. A good practice is to start with a high temperature and then anneal it towards small values.
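Here is a minimal NumPy sketch of the relaxed sample and the effect of the temperature. The logits are hypothetical, and the generator is re-seeded for each temperature so that the Gumbel noise is identical and only $\tau$ changes. It covers the forward pass only; in a real model you would write the same computation in an autodiff framework so gradients reach the logits.

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    z = rng.gumbel(size=logits.shape)
    y = (logits + z) / tau
    y = y - y.max()  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum()

logits = np.array([1.0, 0.0, -1.0, 2.0])  # hypothetical logits

soft_samples = {}
for tau in (0.1, 1.0, 10.0, 1000.0):
    # Same seed each time, so the noise is shared across temperatures.
    soft_samples[tau] = gumbel_softmax(logits, tau, np.random.default_rng(0))
    print(f"tau={tau:>6}: {np.round(soft_samples[tau], 3)}")

# The hard Gumbel-max sample for the same noise, for comparison.
hard_index = int(np.argmax(logits + np.random.default_rng(0).gumbel(size=logits.shape)))
print("hard sample index:", hard_index)
```

For every temperature the largest entry sits at the index that Gumbel-max would pick. Lowering $\tau$ pushes that entry towards 1, and raising it flattens the vector towards uniform.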

You can read more about the Gumbel-softmax trick here and here.


To show that the theory works in real life, I’ll use a toy problem. The data is a stream of numbers in the range 0 to 4. Each number has a different probability of coming up in the stream. Your mission, should you choose to accept it, is to find out what the distribution over the 5 numbers is.

A simple solution would be to count, but we’re going to do something much cooler (and ridiculously overcomplicated for this task): we’ll train a GAN. The generator will generate numbers from a distribution which should converge to the real one.
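For reference, the counting baseline can be sketched as follows. The probabilities in `true_probs` are invented for illustration; the post does not specify the actual distribution of the stream.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ground-truth distribution over the numbers 0..4.
true_probs = np.array([0.0, 0.1, 0.2, 0.3, 0.4])

# Simulate the data stream.
stream = rng.choice(5, size=100_000, p=true_probs)

# Counting baseline: relative frequency of each number.
counts = np.bincount(stream, minlength=5)
estimated_probs = counts / counts.sum()

print("true:     ", true_probs)
print("estimated:", np.round(estimated_probs, 3))
```

Any learned generator for this task should end up close to these frequencies.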

Here’s an intuition of why it should work: let’s say the true probability associated with the number 0 is 0. The discriminator will learn that 0’s never come with the label REAL. Therefore, the generator will incur a big loss whenever it generates 0’s. This will encourage the generator to stop generating 0’s.

