Debugging ML Models
Notes on debugging ML models, primarily CNNs.
Most of this is advice I've found online or gotten through mentors or experience.
==Debugging==
* First, check that your model can overfit a small sample of your data, e.g. a single batch (see the sketch after this list).
** If it doesn't overfit, there may be an issue with your code.
** You can try increasing the capacity (e.g. number of filters or number of nodes in the FC layers) 2-4x.
*** For CNNs, these days, you can fit O(10 million) parameters on a single GPU.
**: If the input is 3 channels, then the first conv layer should have more than 3 channels.
** Check that your loss is implemented correctly and taken against the correct ground truth image.
* Dump all inputs and outputs into [[TensorBoard]]. You may have an unexpected input or output somewhere.
* Make sure there is no activation on the final layer.
* If the loss is unstable or increasing, drop the learning rate to <code>O(1e-3)</code> or <code>O(1e-4)</code>.
* Try taking the loss closer to the output of the network.
** If you apply some transformations \(f\) after the output, do \(\text{loss} = \text{loss\_fn}(f^{-1}(\text{gt}), \text{output})\) instead of \(\text{loss} = \text{loss\_fn}(\text{gt}, f(\text{output}))\).
** This shortens the paths the gradients need to flow through.
** Note that this may change the per-pixel weights of the loss function.
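As a concrete illustration of the single-batch check, here is a rough PyTorch sketch; <code>model</code>, <code>loss_fn</code>, and <code>train_loader</code> are placeholders for your own objects, and the hyperparameters are arbitrary:
<syntaxhighlight lang="python">
import torch

# Grab one fixed batch and try to drive the training loss to ~0 on it.
inputs, targets = next(iter(train_loader))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(500):
    optimizer.zero_grad()
    outputs = model(inputs)          # remember: no activation on the final layer
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())

# If the loss does not get close to zero, suspect the model, the loss, or the data pipeline.
</syntaxhighlight>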
==Underfitting==
* Train for 4x as long, until the training loss and validation loss both flatten.
* Increase or decrease the learning rate by an order of magnitude.
* Make sure the batch size is a power of 2. Try increasing it to get more stable gradient updates or decreasing it to get faster iterations with more noise.
* Try disabling any tricks you have, like dropout (see the sketch after this list).
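For the learning-rate and dropout bullets, a rough PyTorch sketch (assuming <code>model</code> and <code>optimizer</code> already exist; the exact factors are arbitrary):
<syntaxhighlight lang="python">
import torch.nn as nn

# Disable dropout in place without editing the architecture definition.
for module in model.modules():
    if isinstance(module, nn.Dropout):
        module.p = 0.0

# Move the learning rate up or down by one order of magnitude.
for group in optimizer.param_groups:
    group["lr"] *= 10.0  # or *= 0.1
</syntaxhighlight>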
==Overfitting==
Overfitting occurs when your model begins learning attributes specific to your training data, causing your validation loss to increase.
Historically, this was a big concern for ML models, and people relied heavily on regularization to address overfitting.
Recently though, overfitting has become less of a concern with larger ML models.
* Add data augmentation (see the sketch after this list).
** PyTorch has many augmentations in [https://pytorch.org/docs/stable/torchvision/transforms.html torchvision.transforms].
** TF/Keras has augmentations as well in tf.image. See [https://www.tensorflow.org/tutorials/images/data_augmentation Tutorials: Data Augmentation].
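A minimal torchvision pipeline as an example; the specific transforms and parameters here are my own arbitrary choices:
<syntaxhighlight lang="python">
import torchvision.transforms as T

# Light geometric and photometric augmentation for image inputs.
train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
])

# Hypothetical usage with an ImageFolder-style dataset:
# dataset = torchvision.datasets.ImageFolder("train/", transform=train_transform)
</syntaxhighlight>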
==NaNs and Infs==
You can get NaNs and Infs in either the forward pass or the backward pass.
To determine where, check the outputs of each layer and check the gradients.
In PyTorch, you can do:
<syntaxhighlight lang="python">
import torch

# Fails loudly if any element of the tensor is NaN or +/-Inf.
assert torch.isfinite(my_tensor).all(), "my_tensor has NaNs or Infs"
</syntaxhighlight>
In TensorFlow, you can do:
<syntaxhighlight lang="python">
import tensorflow as tf

def all_finite(tensor):
    # True only if every element of the tensor is finite (no NaNs or Infs).
    is_finite = tf.math.is_finite(tensor)
    return tf.math.reduce_all(is_finite).numpy()

assert all_finite(my_tensor), "my_tensor has NaNs or Infs"

# Or let TensorFlow do the check and raise for you:
tf.debugging.assert_all_finite(my_tensor, "my_tensor has NaNs or Infs")
</syntaxhighlight>
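To narrow down which layer first produces a non-finite value, one option is to register forward hooks on every submodule and use anomaly detection for the backward pass. This is only a sketch, assuming an arbitrary <code>nn.Module</code> called <code>model</code>:
<syntaxhighlight lang="python">
import torch
import torch.nn as nn

def add_finite_checks(model: nn.Module):
    # Forward pass: flag the first layer whose output contains NaN/Inf.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"Non-finite output in layer: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# Backward pass: anomaly detection reports the op that produced a NaN gradient.
# with torch.autograd.detect_anomaly():
#     loss = loss_fn(model(x), y)   # loss_fn, x, y are your own objects
#     loss.backward()
</syntaxhighlight>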
Typically, you get Infs and NaNs when there is a division by ~0 in the forward or backward pass.
However, it is also possible that the learning rate is too high or your model is broken.
I typically debug by:
* Dropping the learning rate to something very small, e.g. <code>1e-10</code>.
* Determining whether the NaN/Inf is in the forward or backward pass.
* Checking that the training data has no NaNs or Infs (see the sketch after this list).
* Checking that there are no divides anywhere in the code, or that all divides are safe.
** See [https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan <code>tf.math.divide_no_nan</code>].
* Checking the gradients of trig functions in the code.
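A quick way to check the training data (a sketch, assuming a PyTorch <code>DataLoader</code> named <code>train_loader</code> that yields input/target pairs):
<syntaxhighlight lang="python">
import torch

for step, (inputs, targets) in enumerate(train_loader):
    for name, tensor in (("inputs", inputs), ("targets", targets)):
        if not torch.isfinite(tensor).all():
            raise ValueError(f"Batch {step}: non-finite values in {name}")
</syntaxhighlight>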
\(
\DeclareMathOperator{\arcsec}{arcsec}
\DeclareMathOperator{\arccot}{arccot}
\DeclareMathOperator{\arccsc}{arccsc}
\)
The following functions have unstable (e.g. Inf) gradients at certain points:
* The derivative of \(\log(x)\) is \(1/x\), which is Inf near 0.
* The derivatives of \(\arcsin(x)\) and \(\arccos(x)\) are \(\frac{\pm 1}{\sqrt{1-x^2}}\), which are Inf near \(\pm 1\).
* The derivatives of \(\arccsc(x)\) and \(\arcsec(x)\) are \(\mp \frac{1}{|x|\sqrt{x^2-1}}\), which are Inf near \(\pm 1\).
If you must use one of these functions near its unstable points, I suggest performing a linear or Taylor series approximation there.
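A cruder alternative to a linear or Taylor approximation is to simply clamp the argument just inside the stable region (my sketch, not from the original notes; <code>EPS</code> is an arbitrary margin):
<syntaxhighlight lang="python">
import torch

EPS = 1e-4  # margin away from the singular points; tune for your precision needs

def safe_asin(x: torch.Tensor) -> torch.Tensor:
    # Keep |x| < 1 so the gradient 1/sqrt(1 - x^2) stays finite.
    return torch.asin(x.clamp(-1.0 + EPS, 1.0 - EPS))

def safe_log(x: torch.Tensor) -> torch.Tensor:
    # Keep x away from 0 so the gradient 1/x stays finite.
    return torch.log(x.clamp(min=EPS))

# Note: wherever the clamp is active, the gradient with respect to x is zero.
</syntaxhighlight>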
Other ways to mitigate NaNs and Infs are:
* Decrease the learning rate and increase the batch size.
* Gradient clipping.
** For PyTorch see [https://pytorch.org/docs/master/generated/torch.nn.utils.clip_grad_norm_.html torch.nn.utils.clip_grad_norm_] and [https://pytorch.org/docs/master/generated/torch.nn.utils.clip_grad_value_.html torch.nn.utils.clip_grad_value_].
** For TensorFlow see [https://www.tensorflow.org/api_docs/python/tf/clip_by_norm tf.clip_by_norm] and [https://www.tensorflow.org/api_docs/python/tf/clip_by_value tf.clip_by_value].
* Using a safe divide which forces the denominator to have values with abs > EPS (see the sketch after this list).
** Note that this can cut off gradients.
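As a concrete example of the gradient-clipping and safe-divide bullets (a PyTorch sketch; <code>EPS</code>, <code>model</code>, <code>loss</code>, and <code>optimizer</code> are placeholders for your own values and objects):
<syntaxhighlight lang="python">
import torch

EPS = 1e-8  # threshold for "too close to zero"

def safe_divide(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    # Force |denom| >= EPS while preserving its sign.
    sign = torch.where(denom >= 0, torch.ones_like(denom), -torch.ones_like(denom))
    return num / (sign * denom.abs().clamp(min=EPS))

# Gradient clipping, applied between backward() and the optimizer step:
# loss.backward()
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# optimizer.step()
</syntaxhighlight>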
==Soft Operations==
The idea of soft operations is to make sure that gradients flow through the entire network rather than through one specific path.
One example of this is softmax, which lets gradients flow through a soft weighting instead of a hard one-hot (argmax) selection.
* Rather than regressing a real value <math>x</math> directly, output a probability distribution (see the sketch after this list).
** Output scores for <math>P(x=j)</math> for some fixed set of <math>j</math>, do softmax, and take the expected value.
** Or output <math>\mu, \sigma</math> and normalize the loss based on <math>\sigma</math>.
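A minimal sketch of the first approach; the value range, number of bins, and names are my own assumptions:
<syntaxhighlight lang="python">
import torch
import torch.nn.functional as F

# Hypothetical setup: predict a value in [0, 1] using 64 bins.
NUM_BINS = 64
bin_centers = torch.linspace(0.0, 1.0, NUM_BINS)  # shape: (NUM_BINS,)

def soft_regression(scores: torch.Tensor) -> torch.Tensor:
    # scores: (batch, NUM_BINS) raw outputs of the network (logits).
    probs = F.softmax(scores, dim=-1)           # differentiable "soft" selection
    return (probs * bin_centers).sum(dim=-1)    # expected value, shape (batch,)

# Usage: loss = F.l1_loss(soft_regression(scores), targets)
</syntaxhighlight>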