Debugging ML Models: Difference between revisions
| (15 intermediate revisions by the same user not shown) | |||
| Line 1: | Line 1: | ||
Notes on debugging ML models, primarily CNNs. | Notes on debugging ML models, primarily CNNs. | ||
Most of this is advice I've found online or gotten through mentors. | Most of this is advice I've found online or gotten through mentors or experience. | ||
==Debugging== | ==Debugging== | ||
| Line 6: | Line 6: | ||
** If it doesn't overfit, there may be an issue with your code. | ** If it doesn't overfit, there may be an issue with your code. | ||
** You can try increasing the capacity (e.g., the number of filters or the number of nodes in FC layers) by 2-4x. | ** You can try increasing the capacity (e.g., the number of filters or the number of nodes in FC layers) by 2-4x. | ||
*** For CNNs, these days, you can fit O(10 million) parameters on a single GPU. | |||
**: If the input is 3 channels, then the first conv layer should have more than 3 channels. | **: If the input is 3 channels, then the first conv layer should have more than 3 channels. | ||
** Check that your loss is implemented correctly and taken against the correct ground truth image. | ** Check that your loss is implemented correctly and taken against the correct ground truth image. | ||
* Dump all inputs and outputs into [[TensorBoard]]. You may have an unexpected input or output somewhere. | * Dump all inputs and outputs into [[TensorBoard]]. You may have an unexpected input or output somewhere. | ||
* Make sure there is no activation on the final layer. | * Make sure there is no activation on the final layer. | ||
* If the loss is unstable or increasing, drop the learning rate to <code>O(1e-3)</code> or <code>O(1e-4)</code>. | |||
* Try taking the loss closer to the output of the network. | |||
** If you apply some transformations \(f\) after the output, do \(loss = loss\_fn(f^{-1}(gt), output)\) instead of \(loss = loss\_fn(gt, f(output))\). | |||
** This shortens the paths the gradients need to flow through. | |||
** Note that this may change the per-pixel weights of the loss function. | |||
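A minimal sketch of the loss-placement advice above, in plain Python so it stays framework-agnostic. It assumes the post-processing transform is \(f = \exp\) with inverse \(f^{-1} = \log\) and a squared-error <code>loss_fn</code>; both are illustrative choices, not taken from these notes.

```python
import math

def loss_fn(target, prediction):
    """Squared error on a single value (illustrative choice of loss)."""
    return (target - prediction) ** 2

# Hypothetical post-processing transform and its inverse.
f = math.exp
f_inv = math.log

output = 1.0  # raw network output
gt = 20.0     # ground truth, expressed in the post-transform space

# Variant 1: loss taken after f, so the gradient must flow back through f.
loss_after_f = loss_fn(gt, f(output))
# Variant 2: loss taken at the network output, against the inverse-transformed target.
loss_at_output = loss_fn(f_inv(gt), output)

# d(loss)/d(output) for each variant, derived by hand for squared error.
grad_after_f = -2.0 * (gt - f(output)) * f(output)  # chain rule adds f'(output)
grad_at_output = -2.0 * (f_inv(gt) - output)        # no extra factor
```

The two variants give different loss and gradient magnitudes for the same prediction, which is the reweighting effect mentioned in the last bullet.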
==Underfitting== | ==Underfitting== | ||
| Line 16: | Line 22: | ||
* Train for 4x as long until the training loss and validation loss both flatten. | * Train for 4x as long until the training loss and validation loss both flatten. | ||
* Increase or decrease the learning rate by one order of magnitude. | * Increase or decrease the learning rate by one order of magnitude. | ||
* Make sure the batch size is a multiple of 2. Try increasing it to get more stable gradient updates or decreasing it to get faster iterations. | * Make sure the batch size is a multiple of 2. Try increasing it to get more stable gradient updates or decreasing it to get faster iterations with more noise. | ||
* Try disabling any tricks you have like dropout. | * Try disabling any tricks you have like dropout. | ||
==Overfitting== | ==Overfitting== | ||
Overfitting occurs when your training | Overfitting occurs when your model begins learning attributes specific to your training data, causing your validation loss to increase. | ||
Historically this was a big concern for ML models and people relied heavily on regularization to address overfitting. | Historically this was a big concern for ML models and people relied heavily on regularization to address overfitting. | ||
Recently though, overfitting has become less of a concern with larger ML models. | Recently though, overfitting has become less of a concern with larger ML models. | ||
| Line 47: | Line 53: | ||
assert all_finite(my_tensor), "my_tensor has NaNs or Infs" | assert all_finite(my_tensor), "my_tensor has NaNs or Infs" | ||
# Or | |||
tf.debugging.assert_all_finite(my_tensor, "my_tensor has NaNs or Infs") | |||
</syntaxhighlight> | </syntaxhighlight> | ||
Typically, you | Typically, you get Infs and NaNs when there is a division by ~0 in the forward or backward pass. | ||
However, it is also possible that the learning rate is too high or your model is broken. | However, it is also possible that the learning rate is too high or your model is broken. | ||
I typically debug by: | I typically debug by: | ||
| Line 56: | Line 65: | ||
* Checking that the training data has no NaNs or Infs. | * Checking that the training data has no NaNs or Infs. | ||
* Checking that there are no divides anywhere in the code or that all divides are safe. | * Checking that there are no divides anywhere in the code or that all divides are safe. | ||
** See [https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan <code>tf.math.divide_no_nan</code>]. | |||
* Checking the gradients of trig functions in the code. | * Checking the gradients of trig functions in the code. | ||
| Line 68: | Line 78: | ||
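* The derivatives of \(\arccsc(x)\) and \(\arcsec(x)\) are Inf at \(x = \pm 1\); both functions are undefined for \(-1 < x < 1\). | * The derivatives of \(\arccsc(x)\) and \(\arcsec(x)\) are Inf at \(x = \pm 1\); both functions are undefined for \(-1 < x < 1\). | ||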
If you must use one of these functions near the unstable points, I suggest replacing it with a linear approximation around those points. | If you must use one of these functions near the unstable points, I suggest replacing it with a linear or Taylor series approximation around those points. | ||
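One way to read the linear-approximation suggestion, sketched in plain Python for \(\arcsec(x)\). It assumes the unstable points are \(x = \pm 1\) and swaps in the tangent line taken at \(|x| = 1 + \epsilon\) for anything closer to them. The helper names and the default <code>eps</code> are hypothetical.

```python
import math

def arcsec(x):
    """Exact arcsec, valid for |x| >= 1."""
    return math.acos(1.0 / x)

def arcsec_grad(x):
    """Exact derivative 1 / (|x| * sqrt(x^2 - 1)); diverges as |x| -> 1."""
    return 1.0 / (abs(x) * math.sqrt(x * x - 1.0))

def linearized_arcsec(x, eps=1e-3):
    """arcsec with a tangent-line extension for |x| < 1 + eps."""
    if x < 0.0:
        # arcsec(-x) = pi - arcsec(x), so reuse the positive branch.
        return math.pi - linearized_arcsec(-x, eps)
    x0 = 1.0 + eps
    if x >= x0:
        return arcsec(x)
    # Tangent line at x0: the slope is large but finite.
    return arcsec(x0) + arcsec_grad(x0) * (x - x0)

def linearized_arcsec_grad(x, eps=1e-3):
    """Gradient of linearized_arcsec; bounded by arcsec_grad(1 + eps)."""
    x0 = 1.0 + eps
    return arcsec_grad(max(abs(x), x0))
```

The value near \(|x| = 1\) is slightly off from the true function, but the gradient stays finite, which is the trade being made here.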
Other ways to mitigate NaNs and Infs are: | Other ways to mitigate NaNs and Infs are: | ||
| Line 76: | Line 86: | ||
** For TensorFlow see [https://www.tensorflow.org/api_docs/python/tf/clip_by_norm tf.clip_by_norm] and [https://www.tensorflow.org/api_docs/python/tf/clip_by_value tf.clip_by_value]. | ** For TensorFlow see [https://www.tensorflow.org/api_docs/python/tf/clip_by_norm tf.clip_by_norm] and [https://www.tensorflow.org/api_docs/python/tf/clip_by_value tf.clip_by_value]. | ||
* Using a safe divide which forces the denominator to have values with abs > EPS. | * Using a safe divide which forces the denominator to have values with abs > EPS. | ||
** Note that this can cut off gradients. | |||
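The two mitigations above can be sketched in plain Python on scalars and lists. This is an illustration of the idea, not the TensorFlow implementation; <code>safe_divide</code>, <code>clip_by_norm</code> as written here, and the <code>EPS</code> value are hypothetical.

```python
import math

EPS = 1e-6  # illustrative threshold, tune for your data range

def safe_divide(numerator, denominator, eps=EPS):
    """Divide after pushing the denominator away from zero, keeping its sign."""
    if abs(denominator) < eps:
        # The clamped denominator no longer depends on the input, so in an
        # autodiff framework the gradient w.r.t. the denominator is cut off here.
        denominator = eps if denominator >= 0 else -eps
    return numerator / denominator

def clip_by_norm(gradient, max_norm):
    """Rescale a gradient vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm <= max_norm:
        return list(gradient)
    scale = max_norm / norm
    return [g * scale for g in gradient]
```

For example, <code>clip_by_norm([3.0, 4.0], 1.0)</code> keeps the direction of the gradient and shrinks its norm from 5 to 1. Clipping only bounds finite gradients; it does not repair values that are already NaN or Inf.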
==Soft Operations== | |||
The idea of soft operations is to make sure that gradients flow through the entire network rather than one specific path. | |||
One example is softmax, which lets gradients flow to every class score even when the target is a one-hot encoding. | |||
* Rather than regressing a real value <math>x</math> directly, output a probability distribution. | |||
** Output scores for <math>P(x=j)</math> for some fixed set of <math>j</math>, do softmax, and take the expected value. | |||
** Or output <math>\mu, \sigma</math> and normalize the loss based on <math>\sigma</math>. | |||
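A small plain-Python sketch of both options in the list above. The bin values are an arbitrary fixed set chosen for illustration, and reading "normalize the loss based on <math>\sigma</math>" as a Gaussian negative log-likelihood is an assumption, not something these notes specify.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_expected_value(scores, bin_values):
    """Turn per-bin scores into P(x=j) and return the expected value of x."""
    probs = softmax(scores)
    return sum(p * j for p, j in zip(probs, bin_values))

def gaussian_nll(x, mu, sigma):
    """Negative log-likelihood of x under N(mu, sigma^2), constant dropped.

    Assumed interpretation: the squared error is scaled by sigma, and the
    log(sigma) term stops the model from inflating sigma for free.
    """
    return 0.5 * ((x - mu) / sigma) ** 2 + math.log(sigma)

# Hypothetical network outputs: one score per candidate value j.
bin_values = [1.0, 2.0, 3.0]
estimate = soft_expected_value([0.0, 0.0, 0.0], bin_values)
```

Because every bin contributes to the expected value, every score receives a gradient, unlike an argmax over the bins.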