Gumbel-Softmax
Gumbel-softmax[1] is a method to differentiably sample from a categorical distribution.
It is available in PyTorch as torch.nn.functional.gumbel_softmax.
Background
Suppose we have some logits \(\displaystyle \{a_i\}_{i=1}^{K}\) (i.e., unnormalized log probabilities).
If we want to convert these into a probability distribution, we can take the softmax.
However, there are scenarios where we instead want to draw a sample from the distribution and backpropagate through the sampling step.
You can sample from a categorical distribution by using the Gumbel-max reparameterization trick:
- Sample K i.i.d. values from Gumbel(0, 1)
- Add these to our logits.
- Take the argmax.
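The three steps above can be sketched in plain Python (a minimal illustration, not the PyTorch implementation); the empirical sampling frequencies should approach softmax of the logits:

```python
import math
import random
from collections import Counter

def gumbel_max_sample(logits, rng=random):
    """Draw one sample from Categorical(softmax(logits)) via the Gumbel-max trick."""
    # g_i = -log(-log(U_i)) with U_i ~ Uniform(0, 1) is a Gumbel(0, 1) draw.
    noisy = [a - math.log(-math.log(rng.random())) for a in logits]
    # The argmax of the noise-perturbed logits is a categorical sample.
    return max(range(len(noisy)), key=noisy.__getitem__)

# Sanity check: empirical frequencies should approach softmax(logits).
random.seed(0)
logits = [1.0, 0.0, -1.0]
z = max(logits)
exps = [math.exp(a - z) for a in logits]
probs = [e / sum(exps) for e in exps]
counts = Counter(gumbel_max_sample(logits) for _ in range(50_000))
freqs = [counts[i] / 50_000 for i in range(3)]
```

With 50,000 draws, each empirical frequency lands within a couple of percentage points of the corresponding softmax probability.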
The main advantage of this reparameterization trick is that the i.i.d. Gumbel noise can be precomputed, avoiding expensive exponentiation at sampling time.
However, by replacing the argmax with a softmax (controlled by a temperature parameter \(\tau\)), you get Gumbel-softmax, which allows you to differentiate through the sampling procedure.
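A minimal sketch of the relaxed sample in plain Python (the temperature `tau` is the standard Gumbel-softmax temperature; this is an illustration, not PyTorch's implementation):

```python
import math
import random

def gumbel_softmax_sample(logits, tau=1.0, rng=random):
    """Relaxed categorical sample: softmax((logits + Gumbel noise) / tau)."""
    # Perturb each logit with Gumbel(0, 1) noise, then scale by the temperature.
    noisy = [(a - math.log(-math.log(rng.random()))) / tau for a in logits]
    # Numerically stable softmax over the perturbed, scaled logits.
    z = max(noisy)
    exps = [math.exp(v - z) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
y = gumbel_softmax_sample([1.0, 0.0, -1.0], tau=0.5)
```

As \(\tau \to 0\) the output approaches a one-hot vector (recovering Gumbel-max); larger \(\tau\) gives smoother, more uniform samples. In PyTorch the equivalent call is `torch.nn.functional.gumbel_softmax(logits, tau=0.5)`.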
Misc
- If \(x \sim Exponential(1)\), then \(-log(x) \sim Gumbel(0, 1)\)
- If \(x \sim Uniform(0,1)\), then \(-log(-log(x)) \sim Gumbel(0, 1)\)
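Both identities above can be checked empirically: a Gumbel(0, 1) variable has mean equal to the Euler-Mascheroni constant \(\gamma \approx 0.5772\), so sample means under either transform should land near it.

```python
import math
import random

random.seed(0)
N = 200_000
# Identity 1: -log(X) with X ~ Exponential(1) is Gumbel(0, 1).
g1 = [-math.log(random.expovariate(1.0)) for _ in range(N)]
# Identity 2: -log(-log(U)) with U ~ Uniform(0, 1) is Gumbel(0, 1).
g2 = [-math.log(-math.log(random.random())) for _ in range(N)]

gamma = 0.5772156649  # mean of Gumbel(0, 1)
mean1 = sum(g1) / N
mean2 = sum(g2) / N
```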
Resources
References
- [1] Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee