Gumbel-Softmax: Difference between revisions
No edit summary |
|||
Line 1: | Line 1: | ||
Gumbel-softmax<ref name="jang2017gumbel"></ref> is a method to differentiably sample from a categorical distribution. | Gumbel-softmax<ref name="jang2017gumbel"></ref><ref name="madison2017concrete"></ref> is a method to differentiably sample from a categorical distribution. | ||
It is available in PyTorch as [https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html torch.nn.functional.gumbel_softmax]. | It is available in PyTorch as [https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html torch.nn.functional.gumbel_softmax]. | ||
Line 26: | Line 26: | ||
{{reflist|refs= | {{reflist|refs= | ||
<ref name="jang2017gumbel">Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee</ref> | <ref name="jang2017gumbel">Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee</ref> | ||
<ref name="madison2017concrete">Maddison, C. J., Mnih, A., & Teh, Y. W. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=S1jE5L5gl</ref> | |||
}} | }} |
Revision as of 19:04, 25 January 2023
Gumbel-softmax[1][2] is a method to differentiably sample from a categorical distribution.
It is available in PyTorch as torch.nn.functional.gumbel_softmax.
Background
Suppose we have some logits \(\displaystyle \{a_i\}_{i=0}^{K}\) (aka unnormalized log probabilities).
If we want convert this into a distribution to sample from, we can take the softmax.
However, there are scenarios where we may want to instead take a sample from the distribution and also backprop through it.
You can sample from a categorical distribution by using the Gumbel-max reparameterization trick:
- Sample K i.i.d. values from Gumbel(0, 1)
- Add these to our logits.
- Take the argmax.
The original advantage to this reparameterization trick is that you can precompute the i.i.d. values and then avoid expensive exponents when sampling at at runtime.
However, by replacing the argmax with softmax, you get Gumbel-softmax which allows you to differentiate through the sampling procedure.
Misc
- If \(x \sim Exponential(1)\), then \(-log(x) \sim Gumbel(0, 1)\)
- If \(x \sim Uniform(0,1)\), then \(-log(-log(x)) \sim Gumbel(0, 1)\)
Resources
References
<templatestyles src="Reflist/styles.css" />
- ↑ Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee
- ↑ Maddison, C. J., Mnih, A., & Teh, Y. W. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=S1jE5L5gl