Gumbel-Softmax: Difference between revisions

From David's Wiki
(Created page with "Gumbel-softmax<ref name="jang2017gumbel"></ref> is a method to differentiably sample from a categorical distribution. It is available in PyTorch as [https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html torch.nn.functional.gumbel_softmax]. ==Background== Suppose we have some logits <math>\{a_i\}_{i=0}^{K}</math> (aka unnormalized log probabilities).<br> If we want convert this into a distribution, we can take the softmax.<br> However, there...")
 
Line 5: Line 5:
==Background==
==Background==
Suppose we have some logits <math>\{a_i\}_{i=0}^{K}</math> (aka unnormalized log probabilities).<br>
Suppose we have some logits <math>\{a_i\}_{i=0}^{K}</math> (aka unnormalized log probabilities).<br>
If we want convert this into a distribution, we can take the softmax.<br>
If we want convert this into a distribution to sample from, we can take the softmax.<br>
However, there are scenarios where we may want to instead take a sample from the distribution and also backprop through it.
However, there are scenarios where we may want to instead take a sample from the distribution and also backprop through it.


Line 12: Line 12:
# Add these to our logits.
# Add these to our logits.
# Take the argmax.
# Take the argmax.
The main advantage to this reparameterization trick is that you can precompute the iid values and then avoid expensive exponents.
The original advantage to this reparameterization trick is that you can precompute the i.i.d. values and then avoid expensive exponents when sampling at at runtime.


Howver, by replacing the argmax with softmax, you get Gumbel-softmax which allows you to differentiate through the sampling procedure.
However, by replacing the argmax with softmax, you get Gumbel-softmax which allows you to differentiate through the sampling procedure.


==Misc==
==Misc==

Revision as of 18:35, 25 January 2023

Gumbel-softmax[1] is a method to differentiably sample from a categorical distribution.

It is available in PyTorch as torch.nn.functional.gumbel_softmax.

Background

Suppose we have some logits \(\displaystyle \{a_i\}_{i=0}^{K}\) (aka unnormalized log probabilities).
If we want convert this into a distribution to sample from, we can take the softmax.
However, there are scenarios where we may want to instead take a sample from the distribution and also backprop through it.

You can sample from a categorical distribution by using the Gumbel-max reparameterization trick:

  1. Sample K i.i.d. values from Gumbel(0, 1)
  2. Add these to our logits.
  3. Take the argmax.

The original advantage to this reparameterization trick is that you can precompute the i.i.d. values and then avoid expensive exponents when sampling at at runtime.

However, by replacing the argmax with softmax, you get Gumbel-softmax which allows you to differentiate through the sampling procedure.

Misc

  • If \(x \sim Exponential(1)\), then \(-log(x) \sim Gumbel(0, 1)\)
  • If \(x \sim Uniform(0,1)\), then \(-log(-log(x)) \sim Gumbel(0, 1)\)

Resources

References

  1. Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=rkE3y85ee