Gumbel-Softmax
Gumbel-softmax[1][2] is a method to differentiably sample from a categorical distribution.
It is available in PyTorch as torch.nn.functional.gumbel_softmax.
Background
Suppose we have some logits \(\displaystyle \{a_i\}_{i=1}^{K}\) (also known as unnormalized log probabilities) over \(K\) categories.
To convert these into a distribution to sample from, we can take the softmax.
However, there are scenarios where we want to draw a sample from that distribution and also backprop through it. An ordinary categorical draw returns a discrete index, so no useful gradient flows back to the logits.
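For reference, the non-differentiable baseline (softmax, then an ordinary categorical draw) might look like this in plain Python. This is a minimal sketch: the helper names `softmax` and `sample_categorical` are illustrative, and the draw itself is done by the standard library's `random.choices`.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical(logits, rng=random):
    """Draw one index i with probability softmax(logits)[i].

    The result is a discrete index; nothing on this path carries a
    gradient with respect to the logits.
    """
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 0.5, -1.0]
print(softmax(logits))
print(sample_categorical(logits, random.Random(0)))
```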
You can sample from a categorical distribution by using the Gumbel-max reparameterization trick:
- Sample K i.i.d. values from Gumbel(0, 1)
- Add these to our logits.
- Take the argmax.
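The three steps above can be sketched in plain Python (standard library only). The helper names `sample_gumbel` and `gumbel_max_sample` are illustrative, not taken from any library; the Gumbel noise is generated as \(-\log(-\log(u))\) with \(u\) uniform on \((0, 1)\). The last lines compare empirical frequencies with the softmax probabilities, which they should approximately match.

```python
import math
import random
from collections import Counter

def sample_gumbel(rng):
    """Draw from Gumbel(0, 1) as -log(-log(u)) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, rng):
    """Gumbel-max trick: return argmax_i (a_i + g_i) with g_i i.i.d. Gumbel(0, 1)."""
    perturbed = [a + sample_gumbel(rng) for a in logits]          # steps 1 and 2
    return max(range(len(perturbed)), key=perturbed.__getitem__)  # step 3

rng = random.Random(0)
logits = [2.0, 0.5, -1.0]
n = 100_000
counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
empirical = [counts[i] / n for i in range(len(logits))]

# Reference probabilities softmax(logits); the frequencies above should be close.
m = max(logits)
exps = [math.exp(a - m) for a in logits]
probs = [e / sum(exps) for e in exps]
print([round(p, 3) for p in empirical])
print([round(p, 3) for p in probs])
```

With these logits the softmax probabilities are roughly 0.786, 0.175 and 0.039, and the empirical frequencies should agree with them up to sampling noise.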
The original advantage of this reparameterization trick is that the i.i.d. Gumbel values can be precomputed, so sampling at runtime avoids expensive exponentiation.
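However, replacing the argmax with a softmax yields Gumbel-softmax, which lets you differentiate through the sampling procedure. With Gumbel noise \(g_i \sim \mathrm{Gumbel}(0, 1)\) and a temperature \(\tau > 0\), each relaxed sample is
\(\displaystyle y_i = \frac{\exp\left((a_i + g_i)/\tau\right)}{\sum_j \exp\left((a_j + g_j)/\tau\right)}\).
As \(\tau \to 0\), \(y\) approaches the one-hot vector of the corresponding Gumbel-max sample; larger \(\tau\) gives smoother samples.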
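A minimal forward-pass sketch of the relaxation, again in plain Python so it runs without a deep-learning framework. It reuses one fixed noise vector across several temperatures to show that the relaxed sample keeps the same argmax as the Gumbel-max sample and becomes more one-hot as \(\tau\) shrinks. Plain Python floats carry no gradients; in an autodiff framework the same arithmetic would be differentiable with respect to the logits. The helper names are illustrative; the PyTorch function mentioned earlier exposes the temperature as its `tau` argument.

```python
import math
import random

def sample_gumbel(rng):
    """Draw from Gumbel(0, 1) as -log(-log(u)), excluding u == 0.0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax_sample(logits, tau, noise):
    """Relaxed sample y_i = softmax((a_i + g_i) / tau) for given Gumbel noise g."""
    z = [(a + g) / tau for a, g in zip(logits, noise)]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(1)
logits = [2.0, 0.5, -1.0]
noise = [sample_gumbel(rng) for _ in logits]
# Index the Gumbel-max trick would return for this same noise.
hard = max(range(len(logits)), key=lambda i: logits[i] + noise[i])

results = {}
for tau in (5.0, 1.0, 0.1):
    y = gumbel_softmax_sample(logits, tau, noise)
    results[tau] = y
    print(tau, [round(v, 3) for v in y])
```

Large \(\tau\) pushes the relaxed sample toward uniform, while small \(\tau\) concentrates almost all of its mass on the index `hard`.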
Misc
- If \(x \sim \mathrm{Exponential}(1)\), then \(-\log(x) \sim \mathrm{Gumbel}(0, 1)\)
- If \(x \sim \mathrm{Uniform}(0, 1)\), then \(-\log(-\log(x)) \sim \mathrm{Gumbel}(0, 1)\)
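Both identities can be sanity-checked numerically. The sketch below (plain Python, standard library only) draws samples both ways and compares their sample mean and variance with the known moments of the standard Gumbel distribution: mean equal to the Euler-Mascheroni constant \(\gamma \approx 0.5772\) and variance \(\pi^2/6 \approx 1.645\). Matching two moments is a check, not a proof; the helper names are illustrative.

```python
import math
import random
import statistics

EULER_GAMMA = 0.5772156649015329  # mean of Gumbel(0, 1)
GUMBEL_VAR = math.pi ** 2 / 6     # variance of Gumbel(0, 1)

def positive_uniform(rng):
    """Uniform(0, 1) draw that excludes 0.0, so log() is always defined."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u

def positive_exponential(rng):
    """Exponential(1) draw that excludes 0.0, so log() is always defined."""
    x = rng.expovariate(1.0)
    while x == 0.0:
        x = rng.expovariate(1.0)
    return x

rng = random.Random(0)
n = 200_000
via_exponential = [-math.log(positive_exponential(rng)) for _ in range(n)]
via_uniform = [-math.log(-math.log(positive_uniform(rng))) for _ in range(n)]

stats = {}
for name, xs in (("exponential", via_exponential), ("uniform", via_uniform)):
    mean = statistics.fmean(xs)
    var = statistics.fmean((x - mean) ** 2 for x in xs)
    stats[name] = (mean, var)
    print(name, round(mean, 4), round(var, 4))
print("target", round(EULER_GAMMA, 4), round(GUMBEL_VAR, 4))
```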
Resources
References
- [1] Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee
- [2] Maddison, C. J., Mnih, A., & Teh, Y. W. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=S1jE5L5gl