Gumbel-Softmax
Gumbel-softmax<ref name="jang2017gumbel"></ref> is a method to differentiably sample from a categorical distribution. It is available in PyTorch as [https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html torch.nn.functional.gumbel_softmax].
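A minimal usage sketch of the PyTorch function (the tensor size and the toy loss below are arbitrary illustrations, not part of the API):

<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

logits = torch.randn(4, requires_grad=True)  # unnormalized log-probabilities

# Soft sample: a point on the probability simplex, differentiable w.r.t. the logits.
y_soft = F.gumbel_softmax(logits, tau=1.0)

# Hard sample: one-hot on the forward pass; gradients flow through the
# soft sample via the straight-through estimator.
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)

# A toy downstream loss so that the gradient is nonzero.
values = torch.tensor([0.0, 1.0, 2.0, 3.0])
loss = (y_hard * values).sum()
loss.backward()
print(logits.grad)  # gradients reach the logits despite the discrete sample
</syntaxhighlight>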
==Background==
Suppose we have some logits <math>\{a_i\}_{i=0}^{K}</math> (aka unnormalized log probabilities).<br>
If we want to convert these into a distribution to sample from, we can take the softmax <math>p_i = \frac{e^{a_i}}{\sum_j e^{a_j}}</math>.<br>
However, there are scenarios where we may want to instead take a sample from the distribution and also backprop through it.
==Gumbel-max trick==
To sample from this distribution, the Gumbel-max trick proceeds as follows:
# Sample i.i.d. Gumbel(0, 1) noise values <math>g_i = -\log(-\log u_i)</math>, where <math>u_i \sim \text{Uniform}(0, 1)</math>.
# Add these to our logits.
# Take the argmax, <math>\arg\max_i (a_i + g_i)</math>.
The original advantage of this reparameterization trick is that you can precompute the i.i.d. noise values and avoid evaluating expensive exponentials when sampling at runtime.
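A minimal sketch of the trick above (the function name is a hypothetical helper; PyTorch is used only because the article already references it):

<syntaxhighlight lang="python">
import torch

def sample_gumbel_max(logits: torch.Tensor) -> torch.Tensor:
    """Draw one index distributed according to softmax(logits)."""
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))
    # Perturb the logits and take the argmax; no exponentials are evaluated
    # at sampling time. The argmax itself is not differentiable.
    return torch.argmax(logits + g)
</syntaxhighlight>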
However, by replacing the argmax with a softmax, you get Gumbel-softmax, which allows you to differentiate through the sampling procedure.
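Correspondingly, a sketch of the relaxed sampler (again a hypothetical helper; the temperature argument <code>tau</code> is an assumption here, mirroring the <code>tau</code> parameter of the PyTorch function above):

<syntaxhighlight lang="python">
import torch

def gumbel_softmax_sample(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Differentiable relaxation: softmax in place of the argmax."""
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))
    # As tau -> 0 the output approaches a one-hot sample; larger tau gives
    # smoother outputs whose gradients have lower variance.
    return torch.softmax((logits + g) / tau, dim=-1)
</syntaxhighlight>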
==Misc==