Debugging ML Models: Difference between revisions

From David's Wiki
No edit summary
Line 1: Line 1:
Notes on debugging ML models, primarilly CNNs.   
Notes on debugging ML models, primarilly CNNs.   
Most of this is advice I've found online or gotten through mentors.
Most of this is advice I've found online or gotten through mentors or experience.


==Debugging==
==Debugging==

Revision as of 17:11, 5 April 2021

Notes on debugging ML models, primarilly CNNs.
Most of this is advice I've found online or gotten through mentors or experience.

Debugging

  • Train on a single example and see if it overfits.
    • If it doesn't overfit, there may be an issue with your code.
    • You can try increasing the capacity (e.g number of filters or number of nodes in FC) 2-4x.
      • For CNNs, these days, you can fit O(10 million) parameters on a single GPU.
      If the input is 3 channels, then the first conv layer should have more than 3 channels.
    • Check that your loss is implemented correctly and taken against the correct ground truth image.
  • Dump all inputs and outputs into TensorBoard. You may have an unexpected input or output somewhere.
  • Make sure there is no activation on the final layer.
  • If the loss is unstable or increasing, drop the learning rate to O(1e-3) or O(1e-4).
  • Try taking the loss closer to the output of the network.
    • If you apply some transformations \(f\) after the output, do \(loss = loss\_fn(f^{-1}(gt), output)\) instead of \(loss = loss\_fn(gt, f(output))\).
    • This shortens the paths the gradients need to flow through.
    • Note that this may change the per-pixel weights of the loss function.

Underfitting

If it looks like it is underfitting (e.g. if the training output and validation output are both blurry), then you can try the following.

  • Train for 4x as long until the training loss and validation loss both flatten.
  • Increase or decrease the learning rate one magnitude.
  • Make sure the batch size is a multiple of 2. Try increasing it to get more stable gradient updates or decreasing it to get faster iterations with more noise.
  • Try disabling any tricks you have like dropout.

Overfitting

Overfitting occurs when your training loss is below your validation loss.
Historically this was a big concern for ML models and people relied heavily on regularization to address overfitting.
Recently though, overfitting has become less of a concern with larger ML models.

  • Increase the capacity and depth of the model.
  • Add more training data if you can.
  • Depending on your task, you can also try data augmentation (e.g. random crops, random rotations, random hue).

NaNs and Infs

You can get NaNs and Infs in either the forward pass or the backward pass.
To determine where, check the outputs of each layer and check the gradients.

In PyTorch, you can do:

assert torch.isfinite(my_tensor).all(), "my_tensor has NaNs or Infs"

In Tensorflow, you can do:

def all_finite(tensor):
  is_finite = tf.math.is_finite(tensor)
  all_finite = tf.math.reduce_all(is_finite)
  return all_finite.numpy()

assert all_finite(my_tensor), "my_tensor has NaNs or Infs"

# Or
tf.debugging.assert_all_finite(my_tensor, "my_tensor has NaNs or Infs")

Typically, you get Infs and NaNs when there is an division by ~0 in the forward or backward pass.
However it is also possible that the learning rate is too high or your model is broken.
I typically debug by:

  • Dropping the learning rate to something super small 1e-10
  • Determining whether the NaN/Inf is in the forward or backward pass.
  • Checking that the training data has no NaNs or Infs.
  • Checking that there are no divides anywhere in the code or that all divides are safe.
  • Checking the gradients of trig functions in the code.

\( \DeclareMathOperator{\arcsec}{arcsec} \DeclareMathOperator{\arccot}{arccot} \DeclareMathOperator{\arccsc}{arccsc} \) The following functions have unstable (e.g. Inf) gradients at certain points:

  • Derivative of \(\log(x)\) is \(1/x\) which is Inf near 0
  • Derivative of \(\arcsin(x)\) and \(\arccos(x)\) are \(\frac{\pm 1}{\sqrt{1-x^2}}\) which are Inf near 1.
  • The derivatives of \(\arccsc(x)\) and \(\arcsec(x)\) are Inf near 0 and 1.

If you must use one of these functions near the unstable points, I suggest performing a linear or taylor series approximation near 0 and 1.

Other ways to mitigate NaNs and Infs are:

Soft Operations

The idea of soft operations are to make sure that gradients flow through the entire network rather than one specific path.
One example of this is softmax which allows you to apply gradients using a one-hot encoding.

  • Rather than regressing a real value \(\displaystyle x\) directly, output a probability distribution.
    • Output scores for \(\displaystyle P(x=j)\) for some fixed set of \(\displaystyle j\), do softmax, and take the expected value.
    • Or output \(\displaystyle \mu, \sigma\) and normalize the loss based on \(\displaystyle \sigma\).